
I have a tensor 'a', and I want to modify one of its elements.

 a = tf.convert_to_tensor([[1.0, 1.0, 1.0],
                           [1.0, 2.0, 1.0],
                           [1.0, 1.0, 1.0]], dtype=tf.float32)

I can get the index of that element:

 index = tf.where(tf.equal(a, 2.0))
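The single-argument form of tf.where returns the coordinates of the matching elements, not their values. The following is a minimal sketch, assuming TensorFlow 2.x so that eager execution and .numpy() are available. It uses tf.equal because in TensorFlow 1.x `a == 2` on a tensor compares object identity rather than element values.

```python
import tensorflow as tf

# Assumes TensorFlow 2.x (eager execution), so .numpy() is available.
a = tf.constant([[1.0, 1.0, 1.0],
                 [1.0, 2.0, 1.0],
                 [1.0, 1.0, 1.0]], dtype=tf.float32)

# With one argument, tf.where returns the coordinates of every True element
# as an int64 tensor of shape [num_matches, rank].
index = tf.where(tf.equal(a, 2.0))

print(index.dtype)              # <dtype: 'int64'>
print(index.numpy().tolist())   # [[1, 1]]
```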

How can I derive 'b' from 'a'?

 b = tf.convert_to_tensor([[1.0, 1.0, 1.0],
                           [1.0, 0.0, 1.0],
                           [1.0, 1.0, 1.0]], dtype=tf.float32)

I know from this post that I cannot modify a tensor in place.
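Because tensors are immutable, "modifying" an element means building a new tensor from the old one. As an alternative to the tf.sparse_to_dense approach in the answer, here is a minimal sketch that assumes TensorFlow 2.x (where tf.tensor_scatter_nd_update is available). It feeds the coordinates from tf.where straight into a scatter update.

```python
import tensorflow as tf

# Assumes TensorFlow 2.x. tf.tensor_scatter_nd_update returns a NEW tensor
# and leaves `a` unchanged.
a = tf.constant([[1.0, 1.0, 1.0],
                 [1.0, 2.0, 1.0],
                 [1.0, 1.0, 1.0]], dtype=tf.float32)

# Coordinates of the element(s) to replace, shape [num_matches, 2].
index = tf.where(tf.equal(a, 2.0))

# One replacement value per matched coordinate (0.0 for each match).
updates = tf.zeros(tf.shape(index)[:1], dtype=a.dtype)

b = tf.tensor_scatter_nd_update(a, index, updates)
print(b.numpy())
```

The updates tensor has one entry per row of `index`, so the same code works unchanged if several elements match.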

Zehao Shi

1 Answer


I solved it by using tf.sparse_to_dense():

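import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.sparse_to_dense)

a = tf.convert_to_tensor([[1.0, 1.0, 1.0],
                         [1.0, 2.0, 1.0],
                         [1.0, 1.0, 1.0]], dtype=tf.float32)

# int64 coordinates of the elements to replace.
index = tf.where(a > 1)
# Mask: 0. at the matched coordinates, 1. everywhere else.
zero = tf.sparse_to_dense(index, tf.shape(a, out_type=tf.int64), 0., 1.)
# Replacement values: the 3rd argument (0. here) at the matched coordinates, 0. elsewhere.
update = tf.sparse_to_dense(index, tf.shape(a, out_type=tf.int64), 0., 0.)
# Keep the unmatched elements of a and add the replacement values.
b = a * zero + update

with tf.Session() as sess:
  print(sess.run(b))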
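The answer computes b = a * zero + update, where zero is 0 at the matched positions and 1 elsewhere, and update holds the replacement value at those positions. Both tf.sparse_to_dense and tf.Session come from the TensorFlow 1.x API. The following is a sketch of the same mask trick assuming TensorFlow 2.x, with tf.scatter_nd swapped in to build the two dense tensors. The names replace_where and new_value are hypothetical, chosen for illustration.

```python
import tensorflow as tf

# Assumes TensorFlow 2.x. `replace_where` and `new_value` are hypothetical
# names; tf.scatter_nd stands in for tf.sparse_to_dense from the answer.
def replace_where(a, condition, new_value):
    index = tf.where(condition)             # int64 coordinates, shape [n, rank]
    shape = tf.shape(a, out_type=tf.int64)  # scatter_nd needs the same int type as index
    n = tf.shape(index)[:1]                 # number of matches, as a 1-D shape
    # 1.0 at each matched position, 0.0 elsewhere.
    hit = tf.scatter_nd(index, tf.ones(n, dtype=a.dtype), shape)
    # Plays the role of `zero` in the answer: 0 at matches, 1 elsewhere.
    keep = 1.0 - hit
    # Plays the role of `update`: new_value at matches, 0 elsewhere.
    update = tf.scatter_nd(index, tf.fill(n, tf.cast(new_value, a.dtype)), shape)
    return a * keep + update

a = tf.constant([[1.0, 1.0, 1.0],
                 [1.0, 2.0, 1.0],
                 [1.0, 1.0, 1.0]])
b = replace_where(a, a > 1, 0.0)  # the 'b' from the question
c = replace_where(a, a > 1, 5.0)  # same mask, different replacement value
print(b.numpy())
```

One caveat of the multiply-by-mask design: if the element being replaced is inf or NaN, multiplying it by 0 yields NaN, so the result is not the replacement value. tf.tensor_scatter_nd_update or the three-argument tf.where avoid that arithmetic entirely.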
Zehao Shi