I want to compare the elements at even and odd positions in each row of the batch and swap them if needed. I ended up with two tensors that I now want to interleave:
def tf_oplu(x, name=None):
    # split each row of the batch into even- and odd-indexed columns
    even = x[:, ::2]
    odd = x[:, 1::2]
    # flatten both tensors in row-major order to apply the function across them
    even_flatten = tf.reshape(even, [-1])
    odd_flatten = tf.reshape(odd, [-1])
    # float masks: compare is 1.0 where even < odd, compare_not is 1.0 elsewhere
    compare = tf.to_float(even_flatten < odd_flatten)
    compare_not = tf.to_float(even_flatten >= odd_flatten)
    # def oplu(x, y):       # trivial function
    #     if x < y:         # (x < y) == 1
    #         return y, x
    #     else:
    #         return x, y   # (x < y) == 0
    even_flatten_new = odd_flatten * compare + even_flatten * compare_not
    odd_flatten_new = odd_flatten * compare_not + even_flatten * compare
    # reshape back to the original layout
    even_new = tf.reshape(even_flatten_new, [100, 128])
    odd_new = tf.reshape(odd_flatten_new, [100, 128])
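As a side note, the commented-out oplu function returns the larger value first and the smaller value second, so the mask arithmetic is equivalent to an elementwise maximum and minimum. A minimal sketch in NumPy, assuming plain float arrays; the helper name oplu_pairs is hypothetical. TensorFlow offers tf.maximum and tf.minimum for the same elementwise operations.

```python
import numpy as np

def oplu_pairs(even, odd):
    """Return (larger, smaller) of each pair, elementwise (hypothetical helper)."""
    return np.maximum(even, odd), np.minimum(even, odd)

# small illustrative inputs, not the real [100, 128] tensors
even = np.array([[1.0, 5.0], [3.0, 3.0]])
odd = np.array([[2.0, 4.0], [0.0, 3.0]])
even_new, odd_new = oplu_pairs(even, odd)

# the mask-based version from the question, for comparison
compare = (even < odd).astype(even.dtype)
compare_not = (even >= odd).astype(even.dtype)
even_masked = odd * compare + even * compare_not
odd_masked = odd * compare_not + even * compare
```

Both versions should agree on finite inputs, and the maximum/minimum form needs no flattening or reshaping.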
Now I want to get back a [100, 256] tensor with the even and odd positions filled in. In NumPy I would of course do:
    # allocate a 2-D result with twice as many columns as each half
    y = np.empty((even_new.shape[0], even_new.shape[1] + odd_new.shape[1]), dtype=even_new.dtype)
    y[:, 0::2] = even_new
    y[:, 1::2] = odd_new
    return y
But this is not possible in TensorFlow, because tensors are immutable. I suppose it could be done with either a sparse tensor or tf.gather_nd, but both require generating an array of indices, which is again a non-trivial task for me.
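One route that might avoid building indices is to stack the two halves along a new last axis and then reshape, since a row-major reshape interleaves them. A minimal sketch in NumPy, for illustration only; the helper name interleave_columns is hypothetical. tf.stack and tf.reshape follow the same row-major semantics, so the same two steps should carry over to tensors.

```python
import numpy as np

def interleave_columns(even, odd):
    """Interleave the columns of two equally shaped 2-D arrays (hypothetical helper)."""
    # shape (rows, cols, 2): each [r, c] slot holds the pair (even[r, c], odd[r, c])
    stacked = np.stack([even, odd], axis=-1)
    # row-major reshape to (rows, 2 * cols) puts each pair side by side
    return stacked.reshape(even.shape[0], -1)

# small illustrative inputs, not the real [100, 128] tensors
even_new = np.array([[0, 2, 4], [10, 12, 14]])
odd_new = np.array([[1, 3, 5], [11, 13, 15]])
y = interleave_columns(even_new, odd_new)
```

Here y has shape (2, 6), with even_new in columns 0, 2, 4 and odd_new in columns 1, 3, 5. No in-place assignment is involved, which is why this pattern suits immutable tensors.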
One more note: I do not want to use any Python functions via tf.py_func, as I checked that they run on CPU only. Maybe a lambda and tf.map_fn could help somehow? Thanks!