I don't know whether this is a bug in TensorFlow or a mistake in my code. I have tried passing a tensor as the `axes` argument of `tf.tensordot`, and the result looks odd when I run the code. My code looks as follows:
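```python
import tensorflow as tf

t1 = tf.ones([2, 3, 4], dtype=tf.float32)
t2 = tf.ones([2, 2, 3, 4], dtype=tf.float32)
# Contract axis 2 of t1 with axis 3 of t2; the axes are given as a tensor
axes = tf.constant([[2], [3]], dtype=tf.int32)
out = tf.tensordot(t1, t2, axes)
print(out)
with tf.Session() as sess:
    sess.run(out)
```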
If I run this code, the print statement outputs:
```
Tensor("Tensordot:0", dtype=float32)
```
The graph is built and runs, and the program doesn't raise any errors. However, the resulting tensor doesn't seem to have any shape, or at least no registered (static) shape. If I run the same code again but change the `axes` variable to:
```python
axes = [[2], [3]]
```
I get the following output from the print statement:
```
Tensor("Tensordot:0", shape=(2, 3, 2, 2, 3), dtype=float32)
```
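For reference, the shape in that second output follows directly from the list form of `axes`: the free (non-contracted) dimensions of `t1` come first, followed by the free dimensions of `t2`. The plain-Python sketch below mimics that static shape computation with a hypothetical helper, `tensordot_shape` (not a TensorFlow function). It needs the concrete axis values, which is presumably why graph-mode TensorFlow cannot register a shape when `axes` is a tensor whose value is only known at run time.

```python
def tensordot_shape(shape_a, shape_b, axes):
    """Static output shape of tensordot(a, b, axes) for list-style axes.

    Hypothetical helper for illustration only; axes is [[a_axes...], [b_axes...]].
    """
    a_axes, b_axes = axes
    if len(a_axes) != len(b_axes):
        raise ValueError("both axis lists must have the same length")
    # Normalize negative indices so that -1 means the last axis.
    a_axes = [ax % len(shape_a) for ax in a_axes]
    b_axes = [ax % len(shape_b) for ax in b_axes]
    # Each contracted pair of dimensions must have the same size.
    for i, j in zip(a_axes, b_axes):
        if shape_a[i] != shape_b[j]:
            raise ValueError("size mismatch: %d vs %d" % (shape_a[i], shape_b[j]))
    # Output = free dims of a (in order) followed by free dims of b (in order).
    free_a = [d for i, d in enumerate(shape_a) if i not in a_axes]
    free_b = [d for j, d in enumerate(shape_b) if j not in b_axes]
    return tuple(free_a + free_b)


print(tensordot_shape([2, 3, 4], [2, 2, 3, 4], [[2], [3]]))  # (2, 3, 2, 2, 3)
```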
I have looked around to see whether I'm doing something wrong or misunderstanding how this works, but I haven't found much information on this topic. I suspect this is a bug in TensorFlow. If it is, is there a workaround that still lets me use a tensor to define `axes`, or that lets the output get its shape from the tensor information?
To give you some context: in my main algorithm, I'm trying to use `tensordot` to contract (sum the products over) the last dimension of `t1` with the last dimension of `t2`, where `t1` and `t2` can have arbitrary ranks.
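One possible workaround for this last-dimension case, assuming the ranks of `t1` and `t2` are known when the graph is built: read each rank from the static shape (`.shape.ndims`, a Python int or `None`) and pass plain Python lists as `axes`, so `tensordot` can still infer the output shape. The helper name `contract_last_axes` is hypothetical. Its fallback branch, used only when a rank is unknown, builds `axes` as a tensor and will therefore leave the output shape unknown, just as in the question. This is a sketch intended to run in both TF 1.x graph mode and TF 2.x eager mode; behavior on very old 1.x releases is an assumption.

```python
import tensorflow as tf


def contract_last_axes(a, b):
    """Contract the last axis of `a` with the last axis of `b` via tf.tensordot.

    Hypothetical helper: uses Python ints for `axes` whenever the static
    ranks are known, so the result keeps a static shape.
    """
    rank_a = a.shape.ndims  # Python int, or None if the rank is unknown
    rank_b = b.shape.ndims
    if rank_a is not None and rank_b is not None:
        # Plain Python lists: tensordot can compute the output shape statically.
        return tf.tensordot(a, b, axes=[[rank_a - 1], [rank_b - 1]])
    # Fallback: build axes as a [2, 1] int32 tensor from the dynamic ranks.
    # As in the question, the output then has no static shape.
    axes = tf.reshape(tf.stack([tf.rank(a) - 1, tf.rank(b) - 1]), [2, 1])
    return tf.tensordot(a, b, axes=axes)


t1 = tf.ones([2, 3, 4], dtype=tf.float32)
t2 = tf.ones([2, 2, 3, 4], dtype=tf.float32)
out = contract_last_axes(t1, t2)
print(out.shape)  # (2, 3, 2, 2, 3)
```

If `axes` really has to start out as a tensor, another option in recent TensorFlow versions may be `tf.get_static_value(axes)`, which returns a NumPy array when the value can be determined at graph-construction time (for example, a `tf.constant`) and `None` otherwise; calling `.tolist()` on that array gives a list that can be passed as `axes`.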