
First, I found another question here: No broadcasting for tf.matmul in TensorFlow.
But that question does not solve my problem.

My problem is multiplying a batch of matrices by a batch of vectors.

x=tf.placeholder(tf.float32,shape=[10,1000,3,4])
y=tf.placeholder(tf.float32,shape=[1000,4])

x is a batch of matrices. There are 10*1000 matrices, each of shape [3,4].
y is a batch of vectors. There are 1000 vectors, each of shape [4].
Dim 1 of x and dim 0 of y are the same (here, 1000).
If tf.matmul supported broadcasting, I could write

y=tf.reshape(y,[1,1000,4,1])
result=tf.matmul(x,y)
result=tf.reshape(result,[10,1000,3])

But tf.matmul does not support broadcasting.
If I use the approach from the question I referenced above:

x=tf.reshape(x,[10*1000*3,4])
y=tf.transpose(y,perm=[1,0]) #[4,1000]
result=tf.matmul(x,y)
result=tf.reshape(result,[10,1000,3,1000])

The result is of shape [10,1000,3,1000], not [10,1000,3], and I don't know how to remove the redundant 1000.
How can I get the same result as a tf.matmul that supports broadcasting?
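To make the desired semantics concrete, here is a sketch of what a broadcasting matmul should compute, written in plain NumPy (NumPy is used here only for illustration; the actual question is about the TensorFlow graph ops):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((10, 1000, 3, 4)).astype(np.float32)
y = rng.standard_normal((1000, 4)).astype(np.float32)

# Desired result, spelled out with explicit loops:
# for every batch index b and every n, multiply the [3,4] matrix
# x[b, n] by the length-4 vector y[n].
expected = np.empty((10, 1000, 3), dtype=np.float32)
for b in range(10):
    for n in range(1000):
        expected[b, n] = x[b, n] @ y[n]   # [3,4] @ [4] -> [3]

# The same contraction in a single einsum call:
compact = np.einsum('bnij,nj->bni', x, y)
assert np.allclose(expected, compact, atol=1e-4)
```

The einsum spelling also shows which axis is shared: the `n` (size 1000) axis of both `x` and `y`, with the contraction over `j` (size 4).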

海牙客移库

2 Answers


I solved it myself.

x=tf.transpose(x,perm=[1,0,2,3]) #[1000,10,3,4]
x=tf.reshape(x,[1000,30,4])
y=tf.reshape(y,[1000,4,1])
result=tf.matmul(x,y) #[1000,30,1]
result=tf.reshape(result,[1000,10,3])
result=tf.transpose(result,perm=[1,0,2]) #[10,1000,3]
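The same transpose/reshape trick, translated into NumPy for illustration (the array shapes and steps mirror the TensorFlow code above exactly), can be checked against an einsum that expresses the intended contraction directly:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((10, 1000, 3, 4)).astype(np.float32)
y = rng.standard_normal((1000, 4)).astype(np.float32)

# Move the shared axis (1000) to the front, fold the remaining
# batch axes into one, and do an ordinary batched matmul:
xt = np.transpose(x, (1, 0, 2, 3))        # [1000, 10, 3, 4]
xt = xt.reshape(1000, 30, 4)              # [1000, 30, 4]
yt = y.reshape(1000, 4, 1)                # [1000, 4, 1]
result = np.matmul(xt, yt)                # [1000, 30, 1]
result = result.reshape(1000, 10, 3)      # [1000, 10, 3]
result = np.transpose(result, (1, 0, 2))  # [10, 1000, 3]

# Cross-check against the intended contraction:
assert np.allclose(result, np.einsum('bnij,nj->bni', x, y), atol=1e-4)
```

The key idea is that `matmul` on stacks of matrices already batches over the leading axis, so moving the shared size-1000 axis to the front lets each `y[n]` multiply all 10*3 rows that belong to it in one shot.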

As indicated here, you can use a function as a workaround:

def broadcast_matmul(A, B):
  "Compute A @ B, broadcasting over the first `N-2` ranks"
  with tf.variable_scope("broadcast_matmul"):
    return tf.reduce_sum(A[..., tf.newaxis] * B[..., tf.newaxis, :, :],
                         axis=-2)
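To see why this works, here is the same multiply-and-reduce trick in NumPy (a sketch for illustration, applied to the shapes from the question; note `y` must first be given a trailing unit axis so it looks like a stack of [4,1] matrices):

```python
import numpy as np

def broadcast_matmul_np(A, B):
    """A @ B with broadcasting over all leading (batch) axes.

    A[..., np.newaxis] has shape [..., m, k, 1] and
    B[..., np.newaxis, :, :] has shape [..., 1, k, n]; their product
    broadcasts to [..., m, k, n], and summing over the k axis (-2)
    yields the matrix product [..., m, n].
    """
    return np.sum(A[..., np.newaxis] * B[..., np.newaxis, :, :], axis=-2)

rng = np.random.default_rng(0)
x = rng.standard_normal((10, 1000, 3, 4)).astype(np.float32)
y = rng.standard_normal((1000, 4)).astype(np.float32)

out = broadcast_matmul_np(x, y.reshape(1000, 4, 1))  # [10, 1000, 3, 1]
result = out[..., 0]                                 # [10, 1000, 3]
assert np.allclose(result, np.einsum('bnij,nj->bni', x, y), atol=1e-4)
```

This avoids any transpose/reshape bookkeeping, at the cost of materializing the broadcast [..., m, k, n] intermediate, which uses more memory than a fused matmul.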
Yi Bill