I want to assign values to slices of an input tensor in one of my models in TensorFlow 2.x (I am using 2.2, but would also accept a solution for 2.1). A non-working template of what I am trying to do is:
import tensorflow as tf
from tensorflow.keras.models import Model


class AddToEven(Model):
    def call(self, inputs):
        outputs = inputs
        outputs[:, ::2] += inputs[:, ::2]
        return outputs
Of course, when building this with AddToEven().build(tf.TensorShape([None, None])), I get the following error:
TypeError: 'Tensor' object does not support item assignment
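The same error can be reproduced outside of a Keras model. The following is a minimal sketch in eager mode (the variable names `t` and `raised` are only illustrative): an augmented slice assignment reads the slice, adds to it, and then tries to write the result back, and it is this last step that fails because tensors are immutable.

```python
import tensorflow as tf

t = tf.ones([2, 4])

try:
    # Reads t[:, ::2], adds 1.0, then attempts to write the result back into t.
    t[:, ::2] += 1.0
    raised = False
except TypeError:
    # Tensors do not implement item assignment, so the write-back fails.
    raised = True

print(raised)
```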
I can achieve this simple example via the following:
class AddToEvenScatter(Model):
    def call(self, inputs):
        batch_size = tf.shape(inputs)[0]
        n = tf.shape(inputs)[-1]
        update_indices = tf.range(0, n, delta=2)[:, None]
        scatter_nd_perm = [1, 0]
        inputs_reshaped = tf.transpose(inputs, scatter_nd_perm)
        outputs = tf.tensor_scatter_nd_add(
            inputs_reshaped,
            indices=update_indices,
            updates=inputs_reshaped[::2],
        )
        outputs = tf.transpose(outputs, scatter_nd_perm)
        return outputs
You can sanity-check it with:
model = AddToEvenScatter()
model.build(tf.TensorShape([None, None]))
model(tf.ones([1, 10]))
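For a stricter check, the scatter logic can be compared against the NumPy slice assignment it is meant to imitate. This is only a sketch: `add_to_even_scatter` is a hypothetical standalone function that repeats the body of the `call` method above, and the test input is arbitrary.

```python
import numpy as np
import tensorflow as tf


def add_to_even_scatter(inputs):
    # Same steps as AddToEvenScatter.call, written as a plain function.
    n = tf.shape(inputs)[-1]
    update_indices = tf.range(0, n, delta=2)[:, None]
    transposed = tf.transpose(inputs, [1, 0])
    outputs = tf.tensor_scatter_nd_add(
        transposed,
        indices=update_indices,
        updates=transposed[::2],
    )
    return tf.transpose(outputs, [1, 0])


# Reference result computed with in-place NumPy slice assignment.
x = np.arange(12, dtype=np.float32).reshape(3, 4)
expected = x.copy()
expected[:, ::2] += x[:, ::2]

result = add_to_even_scatter(tf.constant(x)).numpy()
np.testing.assert_allclose(result, expected)
```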
But as you can see, it is very complicated to write, and this only covers a static number of updates (here, 1) on a 1D (plus batch dimension) tensor.
What I want to do is a bit more involved, and I think writing it with tensor_scatter_nd_add is going to be a nightmare.
A lot of the existing Q&As on the topic cover the case of variables but not tensors (see e.g. this or this). It is mentioned here that PyTorch does support this, so I am surprised to see no recent response from any TensorFlow team members on the topic. This answer doesn't really help me, because I would need some kind of mask generation, which is going to be awful as well.
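To make the mask-based route concrete, here is a rough sketch of what it could look like for the toy example only. It is not taken from the linked answer; `add_to_even_mask` is a hypothetical name, and the approach assumes that `tf.where` broadcasts its condition against the inputs (the TF 2.x behaviour). Even here the mask has to be built by hand from `tf.range`, which is the part that gets awkward for more involved slicing patterns.

```python
import tensorflow as tf


def add_to_even_mask(inputs):
    # Boolean mask over the last axis: True at even column indices.
    n = tf.shape(inputs)[-1]
    is_even = tf.equal(tf.range(n) % 2, 0)
    # The mask of shape [n] is broadcast across the batch dimension.
    return tf.where(is_even, inputs + inputs, inputs)


masked_result = add_to_even_mask(tf.ones([1, 10]))
print(masked_result.numpy())
```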
The question is thus: how can I do slice assignment efficiently (computation-wise, memory-wise and code-wise) without tensor_scatter_nd_add? The trick is that I want this to be as dynamic as possible, meaning that the shape of the inputs could be variable.
(For anyone curious, I am trying to translate this code into TensorFlow.)
This question was originally posted in a GitHub issue.