I'm working with the tensorflow.keras API and have come across a syntax I'm unfamiliar with: applying a layer to a sub-model's output, as shown in the following example from this tutorial:

from tensorflow.keras import Model, layers
from tensorflow.keras.applications import resnet


target_shape = (200, 200)


base_cnn = resnet.ResNet50(
    weights="imagenet", input_shape=target_shape + (3,), include_top=False
)

flatten = layers.Flatten()(base_cnn.output)
dense1 = layers.Dense(512, activation="relu")(flatten)
dense1 = layers.BatchNormalization()(dense1)
dense2 = layers.Dense(256, activation="relu")(dense1)
dense2 = layers.BatchNormalization()(dense2)
output = layers.Dense(256)(dense2)

embedding = Model(base_cnn.input, output, name="Embedding")

In the official reference for layers.Flatten, for example, I couldn't find an explanation of what applying it to another layer's output actually does. In the keras.Layer reference I found this explanation:

call(self, inputs, *args, **kwargs): Called in __call__ after making sure build() has been called. call() performs the logic of applying the layer to the input tensors (which should be passed in as argument).

So my question is:

What does flatten = layers.Flatten()(base_cnn.output) do?

AloneTogether
Michael
  • Yes, thank you. But it's not quite what I was looking for. I wanted to know what the application does under the hood, i.e. does it add the model and use it on the forward path in the call method? As this syntax is kind of new to me, I wanted some intuition into why it works. But thanks anyway. – Michael Nov 23 '21 at 15:00

1 Answer

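You are creating a model on top of a pre-trained model. Note that the layers of a Keras Applications model are trainable by default, so the pre-trained model will be trained together with the rest of your layers unless you explicitly freeze it, for example with base_cnn.trainable = False, which is what you want when you are only interested in extracting its useful features. A flattening operation converts a multidimensional output into a one-dimensional tensor per sample (the batch dimension is kept), and that is exactly what happens in this line: flatten = layers.Flatten()(base_cnn.output). Here layers.Flatten() creates the layer object, and calling that object on base_cnn.output applies it to the symbolic output tensor of the ResNet and returns a new symbolic tensor, which the Model constructor later uses to trace the graph from input to output. A flat feature vector per sample is often a desirable result, especially when it feeds Dense layers as it does here. The output of the pre-trained ResNet model has the shape (None, 7, 7, 2048), and you want to generate a 1D feature vector for each input and compare them, so you flatten that output, resulting in a tensor with the shape (None, 7 * 7 * 2048), i.e. (None, 100352).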
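To give some intuition for why the Layer()(tensor) syntax works, here is a toy sketch in plain Python. It is not Keras code and not how Keras is actually implemented: SymbolicTensor, ToyLayer and ToyFlatten are hypothetical stand-ins that only imitate the documented pattern in which calling a layer object goes through __call__, which makes sure build() has run and then delegates to call().

```python
import math


class SymbolicTensor:
    """Toy stand-in for a Keras symbolic tensor: it carries only a shape."""

    def __init__(self, shape):
        self.shape = tuple(shape)


class ToyLayer:
    """Hypothetical, heavily simplified imitation of keras.layers.Layer."""

    def __init__(self):
        self.built = False

    def build(self, input_shape):
        # Real layers create their weights here; this toy only records that it ran.
        self.built = True

    def call(self, inputs):
        raise NotImplementedError

    def __call__(self, inputs):
        # ToyLayer()(x) lands here: build once, then delegate to call().
        if not self.built:
            self.build(inputs.shape)
        return self.call(inputs)


class ToyFlatten(ToyLayer):
    def call(self, inputs):
        batch, *rest = inputs.shape
        # Keep the batch axis and collapse all remaining axes into one.
        return SymbolicTensor((batch, math.prod(rest)))


base_output = SymbolicTensor((None, 7, 7, 2048))  # shape quoted in the answer
flatten = ToyFlatten()(base_output)               # construct the layer, then call it
print(flatten.shape)  # (None, 100352)
```

The first pair of parentheses constructs the layer and the second pair calls it on a tensor, so the line is equivalent to creating the layer in one statement and applying it in the next.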

Alternatives to Flatten would be GlobalMaxPooling2D and GlobalAveragePooling2D, which downsample an input by taking the max or average value along the spatial dimensions. For more information on this topic check out this post.
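As a rough illustration of what global pooling computes, here is a plain-Python sketch on a tiny, made-up feature map stored as nested lists. The helper names global_pool and mean are hypothetical and this is not the Keras implementation. Applied to the ResNet output above, the same reduction would turn (None, 7, 7, 2048) into (None, 2048) rather than (None, 100352).

```python
def global_pool(feature_map, reduce):
    """Reduce an H x W x C nested list to a length-C list, one value per channel."""
    channels = len(feature_map[0][0])
    return [
        reduce([pixel[c] for row in feature_map for pixel in row])
        for c in range(channels)
    ]


def mean(values):
    return sum(values) / len(values)


# A tiny 2 x 2 feature map with 3 channels (made-up numbers).
fmap = [
    [[1.0, 10.0, 0.0], [2.0, 20.0, 0.0]],
    [[3.0, 30.0, 4.0], [6.0, 40.0, 0.0]],
]

print(global_pool(fmap, max))   # [6.0, 40.0, 4.0]
print(global_pool(fmap, mean))  # [3.0, 25.0, 1.0]
```

Each channel is reduced to a single number over all spatial positions, which is why the pooled vector is much shorter than the flattened one.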

AloneTogether