I am looking for the analogue of np.delete in TensorFlow. I have batches of tensors, each batch of shape (batch_size, variable_length), and I want to get a tensor of shape (batch_size, 2 * variable_length / 3), i.e. drop every third element along the second axis. Each batch has a different length, which is stored in and read from the tfrecord. I am a bit at a loss as to which API I should use for that.
Related (for numpy):
where the solution would simply be np.delete(x, slice(2, None, 3)) (after a reshape to account for batch_size).
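To make the target behaviour concrete, here is a minimal NumPy sketch of the operation described above. The input array is made up for illustration, and the flatten-then-reshape route only matches the per-row result when the row length is a multiple of 3.

```python
import numpy as np

# Hypothetical batch: 2 rows, 6 values per row (row length is a multiple of 3).
x = np.arange(12).reshape(2, 6)

# Without an axis, np.delete operates on the flattened array, so the
# result is 1-D and has to be reshaped back to (batch_size, -1).
flat = np.delete(x, slice(2, None, 3))
via_delete = flat.reshape(x.shape[0], -1)

# Passing axis=1 deletes every third column directly, no reshape needed.
via_axis = np.delete(x, slice(2, None, 3), axis=1)

# The same selection written as a boolean mask over the column indices.
keep = np.arange(x.shape[1]) % 3 != 2
via_mask = x[:, keep]
```

All three variants keep columns 0, 1, 3, 4 of each row. The mask form is the one that seems easiest to carry over to a tensor library.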
As requested in the comments, here is the code for parsing a single example proto, although I am interested in deleting every nth (3rd) element of a tensor as a standalone question.
    @classmethod
    def parse_single_example(cls, example_proto):
        instance = cls()
        features_dict = cls._get_features_dict(example_proto)
        instance.path_length = features_dict['path_length']
        ...
        instance.coords = tf.decode_raw(features_dict['coords'], DATA_TYPE)  # the tensor
        ...
        return instance.coords, ...

    @classmethod
    def _get_features_dict(cls, value):
        features_dict = tf.parse_single_example(value,
            features={'coords': tf.FixedLenFeature([], tf.string),
                      ...
                      'path_length': tf.FixedLenFeature([], tf.int64)})
        return features_dict
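
For the standalone question (dropping every third element along the second axis), one possible sketch in TensorFlow builds a boolean mask over the column indices and applies it with tf.boolean_mask. This assumes a TensorFlow version in which tf.boolean_mask accepts an axis argument. drop_every_third is a hypothetical helper name, and the sketch is not tied to the tfrecord parsing code above.

```python
import tensorflow as tf

def drop_every_third(x):
    """Remove columns 2, 5, 8, ... from a 2-D tensor (hypothetical helper)."""
    # The length is read dynamically, so it may differ from batch to batch.
    length = tf.shape(x)[1]
    # True for the columns to keep, False for every third column.
    keep = tf.not_equal(tf.range(length) % 3, 2)
    # Apply the 1-D mask along axis 1, leaving the batch axis untouched.
    return tf.boolean_mask(x, keep, axis=1)

# Made-up input for illustration: 2 rows of 6 values each.
example = tf.constant([[0, 1, 2, 3, 4, 5],
                       [6, 7, 8, 9, 10, 11]])
result = drop_every_third(example)
```

With eager execution, result can be inspected directly. In graph mode it would have to be evaluated in a session.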