Does index_matrix_to_pairs() behave differently from what its comment says?
The comment says: [[3,1,2], [2,3,1]] -> [[[0, 3], [1, 1], [2, 2]], [[0, 2], [1, 3], [2, 1]]]
But when I test it, it outputs this: [[[0 3],[0 1],[0 2]],[[1 2],[1 3],[1 1]]]
This is the test code:

with tf.Session():
    print(index_matrix_to_pairs(tf.constant([[3, 1, 2], [2, 3, 1]])).eval())
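For reference, both layouts can be reproduced in plain Python without TensorFlow. This is only an illustrative sketch, and the function names pairs_as_commented and pairs_as_observed are hypothetical. The comment pairs each value with its position within the row. The current code instead pairs each value with its row (batch) index:

```python
def pairs_as_commented(index_matrix):
    """Pair each value with its position in the row: [pos, value] (layout in the comment)."""
    return [[[pos, value] for pos, value in enumerate(row)] for row in index_matrix]


def pairs_as_observed(index_matrix):
    """Pair each value with its row (batch) index: [row, value] (layout the code returns)."""
    return [[[row_idx, value] for value in row] for row_idx, row in enumerate(index_matrix)]


example = [[3, 1, 2], [2, 3, 1]]
print(pairs_as_commented(example))
# [[[0, 3], [1, 1], [2, 2]], [[0, 2], [1, 3], [2, 1]]]
print(pairs_as_observed(example))
# [[[0, 3], [0, 1], [0, 2]], [[1, 2], [1, 3], [1, 1]]]
```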
I have the same question.
Should it be fixed as follows?

def index_matrix_to_pairs(index_matrix):
    # [[3,1,2], [2,3,1]] -> [[[0, 3], [1, 1], [2, 2]],
    #                        [[0, 2], [1, 3], [2, 1]]]
    replicated_first_indices = tf.range(tf.shape(index_matrix)[-1])  # [0, 1, 2]
    rank = len(index_matrix.get_shape())
    if rank == 2:
        replicated_first_indices = tf.tile(
            tf.expand_dims(replicated_first_indices, axis=0),  # [[0, 1, 2]]
            [tf.shape(index_matrix)[0], 1])  # [[0, 1, 2], [0, 1, 2]]
    return tf.stack([replicated_first_indices, index_matrix], axis=rank)
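Whether the code or the comment should change probably depends on how the pairs are consumed, and the thread doesn't say. Suppose, hypothetically, that they are passed to tf.gather_nd together with a [batch, length] tensor, so that each [i, j] pair selects element j of row i. The pure-Python sketch below uses a hypothetical gather_nd_2d helper that only mimics that lookup. Batch-first pairs (the current output) pick the indexed element from each row. Position-first pairs (the comment's layout) use the column position as a row index and run past the batch dimension:

```python
def gather_nd_2d(params, indices):
    """Hypothetical helper mimicking tf.gather_nd for [row, col] index pairs."""
    return [[params[r][c] for r, c in row_pairs] for row_pairs in indices]


# Assumed setup: params has shape [batch, length] = [2, 4].
params = [[10, 11, 12, 13],
          [20, 21, 22, 23]]
index_matrix = [[3, 1, 2], [2, 3, 1]]

# Pairs in the layout the current code returns: [row, value].
batch_first = [[[b, v] for v in row] for b, row in enumerate(index_matrix)]
# Pairs in the layout the comment describes: [position, value].
position_first = [[[p, v] for p, v in enumerate(row)] for row in index_matrix]

print(gather_nd_2d(params, batch_first))  # [[13, 11, 12], [22, 23, 21]]

try:
    gather_nd_2d(params, position_first)
except IndexError:
    # The pair [2, 2] asks for row 2, but params only has rows 0 and 1.
    print("position-first pairs index a row that does not exist")
```

Under that assumption, the comment rather than the code might be what needs fixing. With a different consumer, the proposed change could be the right one.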