I can use tf.matmul(A, B) to do batch matrix multiplication when A.shape == (..., a, b) and B.shape == (..., b, c), where the leading ... dimensions are the same in both tensors.
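For this matching-batch case, np.matmul follows the same rule as tf.matmul, so a small NumPy sketch can show what the existing call computes without needing a TensorFlow session. The sizes below are arbitrary, and NumPy is used only as a stand-in for checking shapes.

```python
import numpy as np

# Arbitrary sizes for illustration: 7 batches of (m x k) @ (k x n).
batch, m, k, n = 7, 2, 3, 4

A = np.random.uniform(0, 1, (batch, m, k))
B = np.random.uniform(0, 1, (batch, k, n))

# Leading (batch) dimensions match, so each pair A[i] (m x k) and
# B[i] (k x n) is multiplied independently.
C = np.matmul(A, B)
print(C.shape)  # (7, 2, 4)

# Same result as multiplying one pair at a time.
assert np.allclose(C[0], A[0] @ B[0])
```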
But I want additional broadcasting over the batch dimensions:
A.shape == (a, b, 2, d)
B.shape == (a, 1, d, c)
result.shape == (a, b, 2, c)
I expect the result to be a x b batches of matrix multiplications between a (2, d) matrix and a (d, c) matrix.
How can I do this?
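A minimal NumPy sketch of the shapes involved, assuming tf.matmul broadcasts batch dimensions the same way np.matmul does (as far as I know this holds in TensorFlow 2.x, while older 1.x releases required identical batch dimensions). The idea is to give the second operand a size-1 axis so its batch dims become (a, 1). The TensorFlow counterparts would be tf.matmul(x, tf.expand_dims(y, 1)) and, where batch broadcasting is not supported, tf.matmul(x, tf.tile(tf.expand_dims(y, 1), [1, b, 1, 1])); those TF lines are an untested assumption here.

```python
import numpy as np

a, b, c, d = 3, 4, 5, 6
x = np.random.uniform(0, 1, (a, b, 2, d))
y = np.random.uniform(0, 1, (a, d, c))

# Insert a size-1 axis so y's batch dims become (a, 1) instead of (a,).
y_b = y[:, np.newaxis, :, :]  # shape (a, 1, d, c)

# Batch dims (a, b) and (a, 1) broadcast to (a, b), so every (2, d) block
# x[k, i] is multiplied by the (d, c) matrix y[k].
z = np.matmul(x, y_b)
print(z.shape)  # (3, 4, 2, 5)

# Explicit copy along the b axis instead of relying on broadcasting
# (the NumPy analogue of tiling when batch dims must match exactly).
z_tiled = np.matmul(x, np.broadcast_to(y_b, (a, b, d, c)))
```

Broadcasting avoids materializing b copies of y; the tiled variant trades that memory for compatibility with APIs that insist on equal batch shapes.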
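Test code (a reference implementation that loops over the b axis; here y has shape (a, d, c), i.e. B without its singleton axis):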
import tensorflow as tf
import numpy as np

a = 3
b = 4
c = 5
d = 6

x_shape = (a, b, 2, d)
y_shape = (a, d, c)  # B without the singleton axis
z_shape = (a, b, 2, c)

x = np.random.uniform(0, 1, x_shape)
y = np.random.uniform(0, 1, y_shape)
z = np.empty(z_shape)

# TF 1.x graph-mode session.
with tf.Session() as sess:
    # One batched matmul per index along the b axis:
    # (a, 2, d) @ (a, d, c) -> (a, 2, c).
    for i in range(b):
        x_now = x[:, i, :, :]
        z[:, i, :, :] = sess.run(
            tf.matmul(x_now, y)
        )

print(z)
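The loop above can be checked against a single vectorized call. This sketch uses np.einsum, with np.matmul standing in for tf.matmul in the reference loop; tf.einsum should accept the same equation string 'abij,ajk->abik', but that is an assumption not exercised here.

```python
import numpy as np

a, b, c, d = 3, 4, 5, 6
x = np.random.uniform(0, 1, (a, b, 2, d))
y = np.random.uniform(0, 1, (a, d, c))

# Reference result: the per-b loop from the test code, with np.matmul
# standing in for tf.matmul.
z_loop = np.empty((a, b, 2, c))
for i in range(b):
    z_loop[:, i, :, :] = np.matmul(x[:, i, :, :], y)

# Vectorized: 'a' is shared by both operands, 'b' appears only in x,
# so y[a] is reused for every b; 'j' (the d axis) is summed over.
z_einsum = np.einsum('abij,ajk->abik', x, y)

print(np.allclose(z_loop, z_einsum))  # True
```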