I am given a 2D tensor with stochastic rows. After applying tf.math.greater() and tf.cast(..., tf.int32), I am left with a tensor of 0s and 1s. I now want to apply tf.reduce_sum() to each row of that matrix, but with a condition: once at least one 1 has been summed and a 0 follows, all following 1s should be ignored as well. For example, the row 1 0 1 should result in 1 instead of 2.
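To pin down the intended reduction, here is a plain-Python reference for a single row: count the first contiguous run of 1s and stop at the first 0 that follows it. The helper name `first_run_length` is hypothetical. It is only meant as a specification to check tensor versions against, not as a TensorFlow solution.

```python
def first_run_length(row):
    """Length of the first contiguous run of 1s in a sequence of 0s and 1s."""
    count = 0
    for value in row:
        if value == 1:
            count += 1
        elif count > 0:
            # A 0 after at least one 1 ends the run; everything after it is ignored.
            break
    return count


print(first_run_length([1, 0, 1]))           # 1
print(first_run_length([0, 0, 0, 1, 0, 1]))  # 1
print(first_run_length([0, 1, 1, 0, 1]))     # 2
```

Leading 0s are skipped because the `elif` branch only stops the loop once `count` is positive.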
I have tried to solve the problem with tf.scan(), but I have not yet been able to come up with a function that handles leading 0s, because a row might look like this: 0 0 0 1 0 1
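One way to handle leading 0s with tf.scan is to scan over the columns (so each step sees one value per row) and carry a small state per row: the running count and a flag saying whether the first run has already ended. This is a minimal sketch assuming TensorFlow 2.x eager execution; `first_run_length_scan` is a hypothetical name, and the (count, stopped) state is my own choice, not something from the question.

```python
import tensorflow as tf


def first_run_length_scan(z):
    """Per-row length of the first run of 1s in a 0/1 matrix, computed with tf.scan."""
    z = tf.convert_to_tensor(z, dtype=tf.int32)
    # Scan over columns: each element of `columns` holds one value per row.
    columns = tf.transpose(z)
    init = (
        tf.zeros_like(z[:, 0]),                 # running count per row
        tf.zeros_like(z[:, 0], dtype=tf.bool),  # has the first run ended?
    )

    def step(state, col):
        count, stopped = state
        # The run ends at the first 0 that comes after at least one counted 1.
        stopped = tf.logical_or(stopped, tf.logical_and(count > 0, tf.equal(col, 0)))
        # Add 1s only while the run is still open; leading 0s add nothing.
        count = count + col * tf.cast(tf.logical_not(stopped), tf.int32)
        return count, stopped

    counts, _ = tf.scan(step, columns, initializer=init)
    # tf.scan stacks the state after every column; the last one is the result.
    return counts[-1]


z = tf.constant([[1, 0, 1, 0, 0, 0],
                 [0, 0, 0, 1, 0, 1],
                 [0, 1, 1, 0, 1, 1],
                 [0, 0, 0, 0, 0, 0]])
print(first_run_length_scan(z).numpy())  # [1 1 2 0]
```

Because `fn` returns a state with exactly the structure and shapes of `initializer`, tf.scan accepts it at every step.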
One idea was to set the lower part of the matrix to 1 (because I know everything left of the diagonal will always be 0) and then run a function with tf.scan() to filter out the unwanted positions (see the code and error message below).
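As an alternative that needs neither tf.scan nor the lower-triangle trick, the "has a 1 been seen yet" condition can be computed directly with a cumulative sum, and a second cumulative sum marks everything from the first gap onwards. This cumulative-sum masking is a swapped-in technique, not the approach from the question; it assumes TensorFlow 2.x, and `first_run_length_vectorized` is a hypothetical name.

```python
import tensorflow as tf


def first_run_length_vectorized(z):
    """Per-row length of the first run of 1s, using cumulative sums instead of tf.scan."""
    z = tf.convert_to_tensor(z, dtype=tf.int32)
    # True from the first 1 in each row onwards; leading 0s stay False.
    started = tf.cumsum(z, axis=1) > 0
    # A 0 inside the started region is a gap that ends the first run...
    gap = tf.logical_and(started, tf.equal(z, 0))
    # ...and every position from that gap onwards is discarded.
    ended = tf.cumsum(tf.cast(gap, tf.int32), axis=1) > 0
    keep = z * tf.cast(tf.logical_not(ended), tf.int32)
    return tf.reduce_sum(keep, axis=1)


z = tf.constant([[1, 0, 1, 0, 0, 0],
                 [0, 0, 0, 1, 0, 1],
                 [0, 1, 1, 0, 1, 1]])
print(first_run_length_vectorized(z).numpy())  # [1 1 2]
```

Since everything is element-wise or a cumulative op along `axis=1`, it works for any number of rows and columns and does not depend on where the diagonal is.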
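Let z be the matrix after tf.cast:
# lower triangle (including the diagonal) filled with ones
helper = tf.matrix_band_part(tf.ones_like(z), -1, 0)
# set every entry on or below the diagonal to 1
z = tf.math.logical_or(tf.cast(z, tf.bool), tf.cast(helper, tf.bool))
z = tf.cast(z, tf.int32)
# intended: keep x only while the previous value is 1
z = tf.scan(lambda a, x: x if a == 1 else 0, z)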
Resulting in:
ValueError: Incompatible shape for value ([]), expected ([5])
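My reading of the error: without an `initializer`, tf.scan iterates over the rows of `z` and uses the first row (shape [5]) as the accumulator, so `fn` must return a tensor of shape [5]. The lambda instead returns the Python scalar 0 (shape []). Assuming TensorFlow 1.x (which `tf.matrix_band_part` suggests), `a == 1` on a tensor compares object identity rather than values, so that `else` branch is always taken. A shape-consistent version of the same rule scans over the columns and multiplies: for 0/1 values, `a * x` equals `x if a == 1 else 0` element-wise. This sketch assumes TensorFlow 2.x; `keep_leading_ones` is a hypothetical name.

```python
import tensorflow as tf


def keep_leading_ones(z):
    """Apply 'x if previous == 1 else 0' along each row using a shape-consistent tf.scan."""
    z = tf.convert_to_tensor(z, dtype=tf.int32)
    # Scan over columns so the accumulator holds one value per row (shape [n_rows]).
    columns = tf.transpose(z)
    # For 0/1 values, a * x is 1 only while every value so far in the row was 1.
    kept = tf.scan(lambda a, x: a * x, columns)
    # Transpose back to the original [n_rows, n_cols] layout.
    return tf.transpose(kept)


z = tf.constant([[1, 1, 0, 1],
                 [1, 0, 1, 1],
                 [0, 1, 1, 1]])
print(tf.reduce_sum(keep_leading_ones(z), axis=1).numpy())  # [2 1 0]
```

This is the same as tf.math.cumprod(z, axis=1). It only counts the run that starts at column 0, so rows with leading 0s still need them overwritten first (as in the lower-triangle idea above) or a different approach.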