I would like to implement a custom TensorFlow layer that performs a mathematical operation involving the actual batch size of the input tensor:
```python
import tensorflow as tf
from tensorflow import keras


class MyLayer(keras.layers.Layer):
    def build(self, input_shape):
        # Store the batch dimension of the input shape
        self.batch_size = input_shape[0]
        super().build(input_shape)

    def call(self, inputs):
        self.batch_size + 1  # do something with the batch size
        return inputs
```
However, when the graph is built from a symbolic input, `input_shape[0]` is initially `None` because the batch size is not yet known. This breaks the computation in `MyLayer`:
```python
inputs = keras.Input(shape=(10,))
x = MyLayer()(inputs)
```
```
TypeError: in user code:

    <ipython-input-41-98e23e82198d>:11 call  *
        self.batch_size + 1 # do something with the batch size

    TypeError: unsupported operand type(s) for +: 'NoneType' and 'int'
```
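The `None` comes from the *static* shape of the symbolic input. A minimal check, assuming TensorFlow 2.x with its bundled Keras (the names `inp` and `static_batch` are just for illustration), shows that only the feature dimension is fixed when the input is declared:

```python
import tensorflow as tf
from tensorflow import keras

# A symbolic Keras input: only the feature dimension (10) is specified.
inp = keras.Input(shape=(10,))

# The static batch dimension is unknown at graph-construction time,
# so it is reported as None. This is what build() receives as input_shape[0].
static_batch = inp.shape[0]
print(inp.shape)
```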
Is there any way to make such layers work after the model has been constructed?
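One commonly used approach, sketched here under the assumption of TensorFlow 2.x Keras, is to read the batch size inside `call` from the *dynamic* shape `tf.shape(inputs)[0]` instead of storing the static `input_shape[0]` in `build`. The dynamic shape is a tensor that is evaluated when the model runs, so it holds the actual batch size. The scaling by `batch_size + 1` below is a hypothetical stand-in for the real computation:

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras


class MyLayer(keras.layers.Layer):
    def call(self, inputs):
        # tf.shape(inputs) is evaluated at run time, so its first entry is
        # the actual batch size even though the static shape has None there.
        batch_size = tf.shape(inputs)[0]
        # Hypothetical stand-in for the real computation: scale the inputs
        # by (batch_size + 1). batch_size is an int32 tensor, so cast it
        # to the input dtype before multiplying.
        scale = tf.cast(batch_size + 1, inputs.dtype)
        return inputs * scale


# Building the functional model no longer fails on the None batch dimension.
inp = keras.Input(shape=(10,))
out = MyLayer()(inp)
model = keras.Model(inp, out)

# Run the model on a concrete batch of 4 samples.
x = np.ones((4, 10), dtype="float32")
y = model(x)
```

Because `batch_size` is a tensor rather than a Python `int`, any arithmetic on it has to go through TensorFlow ops (Python operators such as `+` are overloaded for tensors), and it will differ from call to call if batches of different sizes are fed in.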