
I don't know if this is a bug in TensorFlow or if I'm doing something wrong in my code. Anyway, I have tried to pass a tensor as the axes argument of tf.tensordot and get odd output when I run the code. My code looks as follows:

t1 = tf.ones([2,3,4], dtype=tf.float32)
t2 = tf.ones([2,2,3,4], dtype=tf.float32)

axes = tf.constant([[2],[3]], dtype=tf.int32)
out = tf.tensordot(t1, t2, axes)
print(out)

with tf.Session() as sess:
    sess.run(out)

If I run this code, I get the output from the print statement:

Tensor("Tensordot:0", dtype=float32)

Obviously, the calculations have gone through, and the program isn't complaining about any errors. However, the created tensor doesn't seem to have any registered shape. If I run the same code again, but switch the axes variable to:

axes = [[2],[3]]

I get the following output from the print statement:

Tensor("Tensordot:0", shape=(2, 3, 2, 2, 3), dtype=float32)

I have tried looking around to see if there's anything I'm doing wrong, or if I'm thinking about this incorrectly, but I haven't found much info on this topic. I suspect this is a bug in TensorFlow, and if it is, are there any workarounds where I can still use a tensor to define axes, or where axes gets its shape from the tensor info?

To give you a hint of where I'm coming from: in my main algorithm, I'm trying to multiply the last dimensions of t1 and t2 (of course, when they have arbitrary ranks) with tensordot.
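For reference, the intended contraction can be sketched with NumPy, whose np.tensordot accepts the same list-of-axes format; this is only an analogue of the TensorFlow snippet above, not a replacement for it:

```python
import numpy as np

# Contract the last axis of t1 (length 4) with the last axis of t2
# (length 4); the remaining axes of t1 and t2 are concatenated.
t1 = np.ones((2, 3, 4), dtype=np.float32)
t2 = np.ones((2, 2, 3, 4), dtype=np.float32)

out = np.tensordot(t1, t2, axes=[[2], [3]])
print(out.shape)  # (2, 3, 2, 2, 3)
```

Each entry of out is a sum of 4 products of ones, i.e. 4.0, and the result shape is the leftover axes of t1 followed by the leftover axes of t2.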

Andreas Forslöw

1 Answer


I actually managed to solve my problem after having looked at tf.tensordot's source code. Apparently, tensor rank can be accessed both through tf.rank (which returns a tensor) and through Tensor.get_shape().ndims (which returns a plain Python int when the rank is statically known). Therefore, I could write the following code:

t1 = tf.ones([2,3,4], dtype=tf.float32)
t2 = tf.ones([2,2,3,4], dtype=tf.float32)

t1_rank = t1.get_shape().ndims
t2_rank = t2.get_shape().ndims

axes = [[t1_rank-1], [t2_rank-1]]
out = tf.tensordot(t1, t2, axes)
print(out)

with tf.Session() as sess:
    sess.run(out)

To get the following output from the print statement:

Tensor("Tensordot:0", shape=(2, 3, 2, 2, 3), dtype=float32)

I have a small suspicion that this is an inefficient way of doing it, since the tensor rank now has to go through Python before being inserted into axes and, eventually, the tf.tensordot op. However, this seems to be how tensordot itself resolves ranks, and it happens while the graph is being built, so it should be a one-time calculation. I don't know much about how the graph is set up, though, so don't take my word for it.
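The same rank-based construction of axes can be sketched in NumPy (last_axis_dot is a hypothetical helper name, not anything from the TensorFlow API); ndarray.ndim plays the role of get_shape().ndims here:

```python
import numpy as np

def last_axis_dot(a, b):
    # Contract the last axis of a with the last axis of b, for arrays
    # of arbitrary rank. The ranks are read as plain Python ints and
    # used to build the axes list, mirroring the answer's use of
    # get_shape().ndims.
    return np.tensordot(a, b, axes=[[a.ndim - 1], [b.ndim - 1]])

res = last_axis_dot(np.ones((2, 3, 4)), np.ones((2, 2, 3, 4)))
print(res.shape)  # (2, 3, 2, 2, 3)
```

Because the ranks are ordinary integers at graph-construction time, the axes list is fixed before any computation runs, which is why the resulting tensor gets a fully registered static shape.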

Andreas Forslöw