0

I am trying to migrate from Keras to PyTorch and I am seeing some inconsistent behavior. First, the two models report a different number of parameters; second, the Keras training converges a lot faster and more smoothly than the PyTorch training. Have I done something wrong? Can anyone help? The relevant code is below.

I have the following simple architecture in keras:

def identity_block(input_tensor, units):
    """Residual block: three Dense+BatchNorm stages with an identity skip.

    The first two stages are each followed by a ReLU. The third stage is
    added element-wise to ``input_tensor`` and the ReLU is applied to the
    sum. Every Dense layer uses the module-level ``reg`` regularizer.

    NOTE(review): the element-wise add presumably requires ``input_tensor``
    to already have ``units`` features -- confirm against callers.
    """
    hidden = input_tensor

    # Two identical Dense -> BatchNorm -> ReLU stages.
    for _ in range(2):
        hidden = layers.Dense(units, kernel_regularizer=reg)(hidden)
        hidden = layers.BatchNormalization()(hidden)
        hidden = layers.Activation('relu')(hidden)

    # Final stage: the activation is deferred until after the skip connection.
    hidden = layers.Dense(units, kernel_regularizer=reg)(hidden)
    hidden = layers.BatchNormalization()(hidden)
    merged = layers.add([hidden, input_tensor])
    return layers.Activation('relu')(merged)


def dens_block(input_tensor, units, reps=2):
    """Residual block whose shortcut is a Dense+BatchNorm projection.

    Main path: ``reps`` Dense -> BatchNorm -> ReLU stages followed by one
    Dense -> BatchNorm stage without activation. Shortcut path: a single
    Dense -> BatchNorm applied to ``input_tensor``. The two paths are added
    and passed through a ReLU. Every Dense layer uses the module-level
    ``reg`` regularizer.
    """
    main_path = input_tensor
    stage_idx = 0
    while stage_idx < reps:
        main_path = layers.Dense(units, kernel_regularizer=reg)(main_path)
        main_path = layers.BatchNormalization()(main_path)
        main_path = layers.Activation('relu')(main_path)
        stage_idx += 1

    # Last main-path stage has no activation of its own.
    main_path = layers.Dense(units, kernel_regularizer=reg)(main_path)
    main_path = layers.BatchNormalization()(main_path)

    # Project the block input to `units` features so it can be added.
    projected = layers.Dense(units, kernel_regularizer=reg)(input_tensor)
    projected = layers.BatchNormalization()(projected)

    merged = layers.add([main_path, projected])
    return layers.Activation('relu')(merged)

def resnet_block(input_tensor, width=16):
    """Chain one ``dens_block`` and two ``identity_block`` stages of ``width`` units."""
    out = dens_block(input_tensor, width)
    for _ in range(2):
        out = identity_block(out, width)
    return out

def RegResNet(input_size=8, reps=3, initial_weights=None, width=16, num_gpus=0, lr=1e-4):
    """Build the Keras regression ResNet as a functional ``Model``.

    The network is ``reps`` consecutive ``resnet_block`` stages of ``width``
    units, then a BatchNormalization layer and a single-unit Dense output
    with no activation.

    ``initial_weights``, ``num_gpus`` and ``lr`` are accepted but not used
    anywhere in this function.
    """
    net_input = layers.Input(shape=(input_size,))

    features = net_input
    remaining = reps
    while remaining > 0:
        features = resnet_block(features, width)
        remaining -= 1

    features = layers.BatchNormalization()(features)
    prediction = layers.Dense(1, activation=None)(features)

    return models.Model(inputs=net_input, outputs=prediction)

This yields a summary that was too long to include here, and the following parameter count:

==================================================================================================
Total params: 9,905
Trainable params: 8,913
Non-trainable params: 992
==================================================================================================

Below is my pytorch translation of the same architecture:

class _DenseBlock(nn.Module):
    """Linear -> BatchNorm1d -> ReLU building block.

    Args:
        input_size: Number of input features of the linear layer.
        output_size: Number of output features of the linear layer; also the
            feature count given to the batch-norm layer.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.linear = nn.Linear(self.input_size, self.output_size)
        self.bnorm = nn.BatchNorm1d(self.output_size)

    def forward(self, x):
        """Return ``relu(bnorm(linear(x)))``."""
        x = self.linear(x)
        x = self.bnorm(x)
        # Functional ReLU avoids building a throwaway nn.ReLU module on every
        # call and leaves the registered sub-modules unchanged.
        return nn.functional.relu(x)

class DenseBlock(nn.Module):
    """Stack of ``_DenseBlock`` stages with a projected (Linear+BN) shortcut.

    Args:
        input_size: Feature count of the block input.
        output_size: Feature count produced by every stage and by the block.
        repetitions: The stack holds one ``input_size -> output_size`` stage
            plus ``repetitions - 1`` ``output_size -> output_size`` stages.
    """

    def __init__(self, input_size, output_size, repetitions=2):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.repetitions = repetitions

        # The first stage changes the width; the remaining ones keep it.
        stages = [_DenseBlock(self.input_size, self.output_size)]
        for _ in range(self.repetitions - 1):
            stages.append(_DenseBlock(self.output_size, self.output_size))
        self.dense_blocks = nn.ModuleList(stages)

        self.b_norm1 = nn.BatchNorm1d(self.output_size)
        self.b_norm2 = nn.BatchNorm1d(self.output_size)
        self.linear_1 = nn.Linear(self.output_size, self.output_size)
        self.linear_2 = nn.Linear(self.input_size, self.output_size)

    def forward(self, x):
        """Run the main path and the shortcut, add them, then apply ReLU."""
        block_input = x

        main = x
        for stage in self.dense_blocks:
            main = stage(main)
        main = self.b_norm1(self.linear_1(main))

        # The shortcut projects the raw block input to output_size features.
        projected = self.b_norm2(self.linear_2(block_input))

        activation = nn.ReLU()
        return activation(projected + main)

class IdentityBlock(nn.Module):
    """Residual block: three Linear+BatchNorm1d stages with an identity skip.

    The first two stages are each followed by a ReLU. The third stage is
    added to the block input and the ReLU is applied to the sum.

    Args:
        input_size: Feature count of the block input.
        output_size: Feature count produced by every stage. Because the block
            input is added to the third stage's output unchanged, the two
            shapes must be addable; in practice ``input_size == output_size``.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.linear_1 = nn.Linear(self.input_size, self.output_size)
        # Stages 2 and 3 consume the output of the previous stage, so their
        # input width is output_size (the original used input_size, which
        # only worked when the two sizes happened to be equal).
        self.linear_2 = nn.Linear(self.output_size, self.output_size)
        self.linear_3 = nn.Linear(self.output_size, self.output_size)
        # bnorm_1 normalises the output of linear_1, i.e. output_size features.
        self.bnorm_1 = nn.BatchNorm1d(self.output_size)
        self.bnorm_2 = nn.BatchNorm1d(self.output_size)
        self.bnorm_3 = nn.BatchNorm1d(self.output_size)

    def forward(self, x):
        """Return ``relu(stage3(stage2(stage1(x))) + x)``."""
        input_tensor = x

        x = self.linear_1(x)
        x = self.bnorm_1(x)
        x = nn.functional.relu(x)

        x = self.linear_2(x)
        x = self.bnorm_2(x)
        x = nn.functional.relu(x)

        x = self.linear_3(x)
        x = self.bnorm_3(x)
        # Activation is applied after the skip connection, not before it.
        x = x + input_tensor
        return nn.functional.relu(x)

class ResnetBlock(nn.Module):
    """One ``DenseBlock`` followed by two ``IdentityBlock`` stages.

    Args:
        input_size: Feature count fed to the ``DenseBlock``.
        output_size: Feature count produced by the ``DenseBlock`` and used as
            both input and output size of the two ``IdentityBlock`` stages.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size

        width = self.output_size
        self.dense_block = DenseBlock(self.input_size, width)
        self.identity_block_1 = IdentityBlock(width, width)
        self.identity_block_2 = IdentityBlock(width, width)

    def forward(self, x):
        """Apply the dense block, then both identity blocks, in order."""
        pipeline = (
            self.dense_block,
            self.identity_block_1,
            self.identity_block_2,
        )
        for stage in pipeline:
            x = stage(x)
        return x

class RegResNet(nn.Module):
    """Regression ResNet: stacked ``ResnetBlock`` stages, BatchNorm1d, Linear(.., 1).

    Args:
        input_size: Feature count of the network input.
        block_width: Feature count used inside and between the blocks.
        repititions: The stack holds one ``input_size -> block_width`` block
            plus ``repititions - 1`` ``block_width -> block_width`` blocks.
    """

    def __init__(self, input_size, block_width, repititions=3):
        super().__init__()
        self.input_size = input_size
        self.repititions = repititions
        self.block_width = block_width

        # Only the first block sees the raw input width.
        blocks = [ResnetBlock(self.input_size, self.block_width)]
        for _ in range(self.repititions - 1):
            blocks.append(ResnetBlock(self.block_width, self.block_width))
        self.resnet_blocks = nn.ModuleList(blocks)

        self.bnorm = nn.BatchNorm1d(self.block_width)
        self.out = nn.Linear(self.block_width, 1)

    def forward(self, x):
        """Run every residual block, then batch norm and the output layer."""
        features = x
        for block in self.resnet_blocks:
            features = block(features)
        return self.out(self.bnorm(features))

Using the torchsummary library, I got a summary (which was also too long to include) and the following parameter count:

================================================================
    Total params: 8,913
    Trainable params: 8,913
    Non-trainable params: 0
================================================================
ClimbingTheCurve
  • 323
  • 2
  • 14
  • What's important here is the trainable params, no? And you got the same number of trainable params – Rachel Shalom Aug 06 '20 at 17:21
  • Wow! Thank you, I can't believe I missed that! Do you know how it's determined which parameters are trainable? This is a solid indicator that I built the model correctly, right? – ClimbingTheCurve Aug 06 '20 at 17:25
  • I am more of a pytorch kind of girl, but I think that the difference here might be with the batch normalization layer. In pytorch, the default is that the parameters there are learnable unless you use affine=False. I think it's a different case in keras. Maybe this will help: https://stackoverflow.com/questions/47312219/what-is-the-definition-of-a-non-trainable-parameter#:~:text=In%20keras%2C%20non%2Dtrainable%20parameters,updated%20during%20training%20with%20backpropagation.&text=This%20means%20that%20keras%20won,like%20statistics%20in%20BatchNormalization%20layers. – Rachel Shalom Aug 06 '20 at 17:32
  • Thanks I'll check it out. I'm making the change, keras is dead and I'm sick of fighting for control of my training loops. – ClimbingTheCurve Aug 06 '20 at 17:35

0 Answers0