
I would like to implement a custom TensorFlow layer that performs a mathematical operation involving the actual batch size of the input tensor:

import tensorflow as tf
from   tensorflow import keras

class MyLayer(keras.layers.Layer):

    def build(self, input_shape):
        self.batch_size = input_shape[0]
        super().build(input_shape)

    def call(self,input):
        self.batch_size + 1 # do something with the batch size
        return input

However, when the model graph is built from a symbolic keras.Input, input_shape[0] is None (the batch size is not known yet), which breaks MyLayer:

input = keras.Input(shape=(10,))
x     = MyLayer()(input)
TypeError: in user code:

    <ipython-input-41-98e23e82198d>:11 call  *
        self.batch_size + 1 # do something with the batch size

    TypeError: unsupported operand type(s) for +: 'NoneType' and 'int'

Is there any way to make such layers work after the model has been constructed?

John Titor

1 Answer


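Use tf.shape(x)[0] inside your layer's call method to get the batch size. input_shape[0] in build (like x.shape[0]) is the static shape, which is None while the graph is being built; tf.shape instead returns the dynamic shape as a tensor, so it holds the actual batch size whenever the layer runs.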
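To make the static/dynamic distinction concrete, here is a small sketch that traces a tf.function whose batch dimension is declared as unknown (None). The names dynamic_batch_size and traced_static are illustrative only. The Python-level x.shape[0] is read once, at trace time, and is None; tf.shape(x)[0] is a graph op and yields the real batch size on every call.

```python
import tensorflow as tf

traced_static = []  # static batch dims recorded while the function is traced

@tf.function(input_signature=[tf.TensorSpec(shape=[None, 10], dtype=tf.float32)])
def dynamic_batch_size(x):
    # Plain Python runs only during tracing, when the batch dim is unknown,
    # so this appends None (and only once, thanks to the fixed signature).
    traced_static.append(x.shape[0])
    # tf.shape is a TensorFlow op, so it is evaluated on every call.
    return tf.shape(x)[0]

print(dynamic_batch_size(tf.zeros([5, 10])))  # tf.Tensor(5, shape=(), dtype=int32)
print(dynamic_batch_size(tf.zeros([3, 10])))  # tf.Tensor(3, shape=(), dtype=int32)
print(traced_static)                          # [None]
```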

Example:

import tensorflow as tf


# custom layer
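class MyLayer(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, x):
        bs = tf.shape(x)[0]  # dynamic batch size: an int32 scalar tensor
        return x, tf.add(bs, 1)


# network
x_in = tf.keras.Input(shape=(10,))  # per-sample shape; Keras adds the batch dimension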
x = MyLayer()(x_in)

# model def
model = tf.keras.models.Model(x_in, x)

# forward pass
_, shp = model(tf.random.normal([5, 10]))

# shape value
print(shp)
# tf.Tensor(6, shape=(), dtype=int32)
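The same idea applies when the batch size feeds into actual arithmetic, which is what the question is after. The sketch below uses a hypothetical BatchMeanLayer that sums over the batch and divides by the batch size; the tf.cast is needed because tf.shape returns int32, while the input is float. It is called eagerly with one batch size and then inside a tf.function whose batch dimension is None.

```python
import tensorflow as tf

class BatchMeanLayer(tf.keras.layers.Layer):
    """Hypothetical layer: averages its input over the batch dimension."""

    def call(self, x):
        # Dynamic batch size, cast so it can be used in float arithmetic.
        batch_size = tf.cast(tf.shape(x)[0], x.dtype)
        return tf.reduce_sum(x, axis=0, keepdims=True) / batch_size

layer = BatchMeanLayer()

# Eager call with a batch of 4 rows.
print(layer(tf.fill([4, 3], 2.0)))

# Graph call where the batch dimension is unknown at trace time.
@tf.function(input_signature=[tf.TensorSpec([None, 3], tf.float32)])
def batch_mean(x):
    return layer(x)

print(batch_mean(tf.constant([[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]])))
```

Because the division uses the runtime value, the layer gives the correct mean for any batch size, including a smaller final batch.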

o-90
  • Upvoted. Is it possible to get the batch size inside the `train_step` function within `tf.keras.Model` class? Actually facing [this](https://stackoverflow.com/q/66472201/9215780) issue. – Innat Mar 05 '21 at 01:12
  • @M.Innat you can [override `train_step`](https://www.tensorflow.org/api_docs/python/tf/keras/Model#train_step). You could write a custom train_step that copies and pastes the [source code](https://github.com/tensorflow/tensorflow/blob/v2.4.1/tensorflow/python/keras/engine/training.py#L724-L759), and then adds whatever batch_size logic you need. – o-90 Mar 05 '21 at 01:23
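If you only need to read the batch size in train_step, a possibly lighter alternative to copying the source is to compute tf.shape(x)[0] and then delegate to super().train_step(data), which keeps the stock forward pass, gradient update, and metrics. A minimal sketch under that assumption follows; BatchSizeAwareModel and the "batch_size" log key are hypothetical names, and surfacing the value in the logs is just a way of showing that it is available.

```python
import numpy as np
import tensorflow as tf

class BatchSizeAwareModel(tf.keras.Model):
    """Hypothetical model that reads the batch size inside train_step."""

    def train_step(self, data):
        # fit() passes (x, y) or (x, y, sample_weight); x may also come alone.
        x = data[0] if isinstance(data, (tuple, list)) else data
        # Dynamic size of *this* batch (smaller for a final partial batch).
        batch_size = tf.shape(x)[0]
        # ... batch-size-dependent logic would go here ...
        logs = super().train_step(data)  # stock forward/backward pass and metrics
        logs["batch_size"] = batch_size  # expose the value in the fit() logs
        return logs

inputs = tf.keras.Input(shape=(3,))
outputs = tf.keras.layers.Dense(1)(inputs)
model = BatchSizeAwareModel(inputs, outputs)
model.compile(optimizer="sgd", loss="mse")

x = np.random.rand(10, 3).astype("float32")
y = np.random.rand(10, 1).astype("float32")
history = model.fit(x, y, batch_size=4, epochs=1, verbose=0)
# Epoch logs come from the last step; with 10 samples and batch_size=4,
# that final batch holds 10 - 2 * 4 = 2 samples.
print(history.history["batch_size"])
```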