I'm using TensorFlow Keras backend and I have two tensors a, b of the same shape: (None, 4, 7), where None represents the batch dimension.

I want to do a batched matrix multiplication and expect a result of shape (None, 4, 4),
i.e., one matmul per batch element: (4,7)·(7,4) = (4,4)

Here's my code --

K.dot(a, K.reshape(b, (-1, 7, 4)))

This code gives a tensor of shape (None, 4, None, 4)

How does high-dimensional matrix multiplication work, and what's the right way to do this?

o_yeah
  • https://stackoverflow.com/a/43829731/12162096 tf.matmul is exactly what I want. That simplifies the question: is it possible to use tf.matmul in Keras, and if so, how? – o_yeah Mar 21 '22 at 23:27
  • Also, the Keras backend dot method seems to handle the None dimension differently from tf.matmul. In Keras, (None, 4, 4)·(None, 4, 7) = (None, 4, None, 7). In tf.matmul, (None, 4, 4)·(None, 4, 7) = (None, 4, 7), which is what I need. – o_yeah Mar 21 '22 at 23:32

1 Answer

IIUC, you can either call tf.matmul directly on the Keras tensors with transpose_b=True (which transposes the last two axes of b), or explicitly wrap the same operation in a Lambda layer:

import tensorflow as tf

a = tf.keras.layers.Input((4, 7))
b = tf.keras.layers.Input((4, 7))
output = tf.matmul(a, b, transpose_b=True)
model = tf.keras.Model([a, b], output)
model.summary()
Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_15 (InputLayer)          [(None, 4, 7)]       0           []                               
                                                                                                  
 input_16 (InputLayer)          [(None, 4, 7)]       0           []                               
                                                                                                  
 tf.linalg.matmul_2 (TFOpLambda  (None, 4, 4)        0           ['input_15[0][0]',               
 )                                                                'input_16[0][0]']               
                                                                                                  
==================================================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
__________________________________________________________________________________________________

Or

import tensorflow as tf

a = tf.keras.layers.Input((4, 7))
b = tf.keras.layers.Input((4, 7))
output = tf.keras.layers.Lambda(lambda x: tf.matmul(x[0], x[1], transpose_b=True))([a, b])
model = tf.keras.Model([a, b], output)
model.summary()
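
As for why the original code produced (None, 4, None, 4): as far as I can tell, K.dot on two 3-D tensors follows the same convention as np.dot (contract the last axis of the first argument with the second-to-last axis of the second, and keep every other axis), while tf.matmul follows np.matmul and treats the leading axis as a batch. Here is a minimal NumPy sketch of the two conventions. It uses NumPy only as a stand-in for the backend ops, and the batch size of 2 and the random inputs are made up for illustration. It also shows that reshaping b to (-1, 7, 4) is not the same as transposing it.

```python
import numpy as np

batch = 2  # illustrative batch size standing in for None
rng = np.random.default_rng(0)
a = rng.standard_normal((batch, 4, 7))
b = rng.standard_normal((batch, 4, 7))

# Transpose the last two axes of b: (batch, 4, 7) -> (batch, 7, 4)
b_t = np.swapaxes(b, -1, -2)

# np.dot contracts a's last axis with b_t's second-to-last axis and keeps
# all remaining axes, so the batch axis shows up twice in the result.
dot_out = np.dot(a, b_t)

# np.matmul treats the leading axis as a batch: one (4,7)·(7,4) per element.
matmul_out = np.matmul(a, b_t)

print(dot_out.shape)     # (2, 4, 2, 4)
print(matmul_out.shape)  # (2, 4, 4)

# The batched product is the "diagonal" of the np.dot result,
# i.e. the entries where both batch indices agree.
diag = np.stack([dot_out[i, :, i, :] for i in range(batch)])
same = np.allclose(diag, matmul_out)

# Reshaping only reinterprets the memory layout; it does not transpose.
reshape_differs = not np.array_equal(b.reshape(batch, 7, 4), b_t)
print(same, reshape_differs)
```

So even with a batch-aware op, K.reshape(b, (-1, 7, 4)) would give the right shape but the wrong values; transpose_b=True in the snippets above takes care of the transpose.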
AloneTogether