For numpy we have

import numpy as np

threshold = 3
a = np.array([1, 2, 3, 4, 5, 6])
a[a >= threshold] = 199

# a is [1, 2, 199, 199, 199, 199]

How can I write equivalent code in TensorFlow 2, where the array is held in a variable?

import tensorflow as tf

b = tf.Variable(a)

Thanks.

Asked by jason; edited by khelwood.

1 Answer

Sure, you can use tf.where to select values conditionally. Note that it returns a new tensor rather than modifying b in place:

import tensorflow as tf

b = tf.Variable(a)
tf.where(b >= 3, 199, b)
# <tf.Tensor: shape=(6,), dtype=int64, numpy=array([  1,   2, 199, 199, 199, 199])>
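If you want the variable b itself to end up holding the new values, as a[a >= threshold] = 199 does for the NumPy array, one option is to compute the replacement with tf.where and write it back with Variable.assign. This is a minimal sketch under that assumption, not part of the original answer. Creating the scalar 199 with b's dtype is a precaution added here so that both branches of tf.where share one dtype.

```python
import numpy as np
import tensorflow as tf

threshold = 3
a = np.array([1, 2, 3, 4, 5, 6])
b = tf.Variable(a)

# tf.where builds a NEW tensor: 199 where the condition holds, b's value elsewhere.
# The scalar is created with b's dtype (a precaution) so both branches match.
replaced = tf.where(b >= threshold, tf.constant(199, dtype=b.dtype), b)

# assign() writes the result back into the variable, mirroring a[a >= threshold] = 199.
b.assign(replaced)

print(b.numpy().tolist())  # [1, 2, 199, 199, 199, 199]
```

Without the assign call, b keeps its original values and only the returned tensor contains the 199s.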
Answered by cs95
  • Hi @cs95, thanks for the answer. Will `tf.where(b>=3, 199, b)` be plugged into the computational graph and affect the gradient of `b` during backpropagation? Thanks, – jason Dec 25 '20 at 16:35
  • @jason Couldn't tell you for sure, sorry. Perhaps you could open a follow-up question. – cs95 Dec 25 '20 at 16:37
  • Sure, the follow-up question is here: https://stackoverflow.com/questions/65449945/what-types-of-operations-will-will-not-plugin-the-computational-graph-in-the-ten – jason Dec 25 '20 at 16:45
  • @jason Good luck with the question, I'll see if I can do any digging. – cs95 Dec 25 '20 at 16:46