Is it possible to interleave two RaggedTensors in TensorFlow? For example, I have two RaggedTensors with the same "shape":

a = [[[10,10]],[[20,20],[21,21]]]
b = [[[30,30]],[[40,40],[41,41]]]

I would like to interleave them so the resulting tensor looks like this:

c = [[[10,10],[30,30]],[[20,20],[40,40],[21,21],[41,41]]]

Note that both tensors a and b always have the same "shape".

I have tried the stack and concat functions, but both return shapes I do not want:

tf.stack([a,b],axis=-1)
c = [[[[10, 30], [10, 30]]], [[[20, 40], [20, 40]], [[21, 41], [21, 41]]]]

tf.concat([a,b],axis=-1)
c = [[[10, 10, 30, 30]], [[20, 20, 40, 40], [21, 21, 41, 41]]]

I have seen other solutions for regular tensors that reshape the resulting tensor c after applying stack/concat. For example:

a = [[[10, 10]], [[20, 20]]]

b = [[[30, 30]], [[40, 40]]]

tf.reshape(
    tf.concat([a[..., tf.newaxis], b[..., tf.newaxis]], axis=1),
    [a.shape[0], -1, a.shape[-1]])

c = [[[10, 10],[30, 30]],[[20, 20],[40, 40]]]

However, as far as I know, because I am using RaggedTensors, the shape in some dimensions is None, so this reshape approach does not carry over (I am using TF 2.6).
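To make the point about None dimensions concrete, here is a minimal sketch. It assumes the tensors are built with tf.ragged.constant, which the question does not state; a_uniform is a hypothetical variant added only for comparison.

```python
import tensorflow as tf

# By default tf.ragged.constant treats every nested dimension as ragged,
# so only the outermost dimension has a static size.
a = tf.ragged.constant([[[10, 10]], [[20, 20], [21, 21]]])
print(a.shape)  # (2, None, None)

# Assumption for illustration: declaring the innermost dimension as uniform
# (ragged_rank=1) keeps its static size, but the row dimension is still None.
a_uniform = tf.ragged.constant([[[10, 10]], [[20, 20], [21, 21]]], ragged_rank=1)
print(a_uniform.shape)  # (2, None, 2)
```

Either way, the middle dimension has no static size, which is why a fixed target shape for tf.reshape is not available here.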

AloneTogether

1 Answer

If this is part of a preprocessing step, the tf.data.Dataset API is one route. Its "interleave" function lets you vary the interleaving pattern through different block_length settings, and it can interleave an arbitrary number of lists.

import tensorflow as tf

# a is made a little longer here to emphasize raggedness
a = tf.ragged.constant([[[10,10]],[[20,20],[21,21],[23,23]]])
b = tf.ragged.constant([[[30,30]],[[40,40],[41,41]]])
c = tf.concat([a,b],axis=1)

NUM_ELEMENTS=7

# Option 1) more flexible
def ragged_to_ds(x):
  return tf.data.Dataset.from_tensor_slices(x)

tf.data.Dataset.from_tensor_slices(c).interleave(ragged_to_ds, block_length=1).batch(NUM_ELEMENTS).get_single_element()

# Option 2) less mess but unbatch makes copies of the data
tf.data.Dataset.from_tensor_slices(c).unbatch().batch(NUM_ELEMENTS).get_single_element()

The tf.data.Dataset API can be powerful and expressive and can help with a large number of data processing and rearranging tasks.

Mike Holcomb
  • Thank you for your reply! Unfortunately, both a and b are tensors generated from the output of a Layer inside the call() function of a bigger model. – Miquel Ferriol Galmés Oct 27 '21 at 23:14
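Given that constraint, one possible route that stays inside ordinary TensorFlow ops is to stack the two tensors along a new axis right after the ragged row axis and then merge those two axes with RaggedTensor.merge_dims. This is only a sketch: interleave_ragged is a hypothetical helper name, it relies on a and b having identical row partitions as stated in the question, and it has not been checked specifically against TF 2.6 or inside a Keras call().

```python
import tensorflow as tf


def interleave_ragged(a, b):
    """Interleave the rows of two RaggedTensors with identical row partitions.

    Hypothetical helper, not part of TensorFlow.
    """
    # Pair up matching rows: shape goes from [batch, (rows), ...]
    # to [batch, (rows), 2, ...].
    stacked = tf.stack([a, b], axis=2)
    # Collapse the row axis and the new pair axis into one ragged axis,
    # which yields a0, b0, a1, b1, ... within each batch element.
    return stacked.merge_dims(1, 2)


a = tf.ragged.constant([[[10, 10]], [[20, 20], [21, 21]]])
b = tf.ragged.constant([[[30, 30]], [[40, 40], [41, 41]]])

c = interleave_ragged(a, b)
print(c.to_list())
# [[[10, 10], [30, 30]], [[20, 20], [40, 40], [21, 21], [41, 41]]]
```

Unlike the tf.data approach, this needs no static shapes and no dataset, so it should in principle be usable on intermediate layer outputs.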