Building on this question, I want to update the value of the first element in each row of a 2-D tensor where the tf.where condition is met. Here is the sample code I am using to simulate the problem:
```python
import tensorflow as tf  # TensorFlow 1.x API (tf.compat.v1 in TF 2.x)

tf.reset_default_graph()
graph = tf.Graph()
with graph.as_default():
    val = "hello"
    new_val = "goodbye"
    matrix = tf.constant([["word", "hello", "hello"],
                          ["word", "other", "hello"],
                          ["hello", "hello", "hello"],
                          ["word", "word", "word"]])
    # (row, col) coordinates of every element equal to val, in row-major order
    matching_indices = tf.where(tf.equal(matrix, val))
    # smallest column index per row, i.e. the first match in each row
    first_matching_idx = tf.segment_min(data=matching_indices[:, 1],
                                        segment_ids=matching_indices[:, 0])

sess = tf.InteractiveSession(graph=graph)
print(sess.run(first_matching_idx))
```
This prints [1 2 0]: the first "hello" in the first row is at column 1, in the second row at column 2, and in the third row at column 0. The fourth row contains no "hello", so it has no entry.
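To make the index bookkeeping concrete, here is a plain-Python model (no TensorFlow) of what tf.where and tf.segment_min compute for this matrix. It is only an illustration: the segment_min function below is a hypothetical stand-in that assumes sorted segment ids, and it leaves empty segments as None, which is not what TensorFlow does.

```python
matrix = [["word", "hello", "hello"],
          ["word", "other", "hello"],
          ["hello", "hello", "hello"],
          ["word", "word", "word"]]
val = "hello"

# Like tf.where(tf.equal(matrix, val)): one [row, col] pair per match, row-major order.
matching_indices = [[r, c] for r, row in enumerate(matrix)
                    for c, x in enumerate(row) if x == val]


def segment_min(data, segment_ids):
    """Hypothetical model of a sorted segment minimum; empty segments stay None."""
    out = [None] * (max(segment_ids) + 1)
    for d, s in zip(data, segment_ids):
        out[s] = d if out[s] is None else min(out[s], d)
    return out


# Column indices are the data, row indices are the segment ids.
first_matching_idx = segment_min([c for _, c in matching_indices],
                                 [r for r, _ in matching_indices])
print(matching_indices)    # [[0, 1], [0, 2], [1, 2], [2, 0], [2, 1], [2, 2]]
print(first_matching_idx)  # [1, 2, 0]
```

Because the output is indexed by segment id, a row with no match after the last matching row (here the fourth) simply produces no entry.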
However, I can't figure out how to use these indices to write the new value: I want the first "hello" in each row to become "goodbye". I have tried tf.scatter_update(), but it does not seem to work for this on a 2-D tensor (it operates on a tf.Variable and replaces whole slices along the first dimension, i.e. entire rows). Is there any way to modify the 2-D tensor as described?
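One possible approach, sketched here as an assumption rather than a definitive answer, avoids scatter ops altogether: build a boolean mask that is True only at the first match in each row (a running count of matches along the row equals 1 there), then use tf.where with three arguments to pick "goodbye" under the mask and the original value elsewhere. The helper name replace_first_match is hypothetical, not a TensorFlow API, and the usage line assumes TF 2.x eager execution; in the question's TF 1.x graph you would evaluate the result with sess.run instead.

```python
import tensorflow as tf


def replace_first_match(matrix, val, new_val):
    """Hypothetical helper: replace the first `val` in each row of a 2-D string tensor."""
    is_match = tf.equal(matrix, val)
    # Running count of matches along each row; the first match is where the count is 1.
    match_count = tf.cumsum(tf.cast(is_match, tf.int32), axis=1)
    is_first = tf.logical_and(is_match, tf.equal(match_count, 1))
    # Same-shape replacement tensor, because TF 1.x tf.where does not broadcast x and y.
    replacement = tf.fill(tf.shape(matrix), new_val)
    return tf.where(is_first, replacement, matrix)


matrix = tf.constant([["word", "hello", "hello"],
                      ["word", "other", "hello"],
                      ["hello", "hello", "hello"],
                      ["word", "word", "word"]])
result = replace_first_match(matrix, "hello", "goodbye")
print(result)  # TF 1.x graph mode: print(sess.run(result)) instead
```

Rows with no match, such as the fourth row, are left unchanged because the mask is False everywhere in them, so no per-row index bookkeeping is needed.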