I defined the model shown in the code below and used batch normalization merging (folding) to collapse its three layers into a single linear layer.
- The first layer of the model is a linear layer with no bias.
- The second layer of the model is a batch normalization layer with no learnable weight or bias (affine=False).
- The third layer of the model is a linear layer with a bias.
The variables named new_weight and new_bias are the weight and bias of the newly created linear layer, respectively.
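For reference, the merge the code attempts can be written out as follows. This is a sketch under one assumption: the batch normalization layer normalizes with its stored running statistics, which is what it does in evaluation mode. The symbols are notation introduced here, not names from the code: W_1 and W_3 are the weights of the first and third layers, b_3 is the bias of the third layer, \mu and \sigma^2 are running_mean and running_var, and \epsilon is eps.

```latex
\begin{aligned}
D &= \operatorname{diag}\!\left(\frac{1}{\sqrt{\sigma^2 + \epsilon}}\right) \\
y &= W_3 \, D \, (W_1 x - \mu) + b_3 \\
  &= \underbrace{W_3 D W_1}_{\text{new\_weight}} \, x
     + \underbrace{b_3 - W_3 D \mu}_{\text{new\_bias}}
\end{aligned}
```

Term by term, this matches the code under the batch merge comment: w_bn is D, b_bn is -D\mu, and new_bias adds W_3 b_bn to the bias of the third layer.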
My question: why do the two print calls at the end produce different outputs, and which part of the code below the "# batch merge" comment is wrong?
import torch
import torch.nn as nn
import torch.optim as optim
learning_rate = 0.01
in_nodes = 20
internal_nodes = 8
out_nodes = 9
batch_size = 100
# model define
class M(nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.layer1 = nn.Linear(in_nodes, internal_nodes, bias=False)
        self.layer2 = nn.BatchNorm1d(internal_nodes, affine=False)
        self.layer3 = nn.Linear(internal_nodes, out_nodes)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x
# optimizer and criterion
model = M()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
criterion = nn.MSELoss()
# training
for batch_num in range(1000):
    model.train()
    optimizer.zero_grad()
    input = torch.randn(batch_size, in_nodes)
    target = torch.ones(batch_size, out_nodes)
    output = model(input)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
# batch merge
divider = torch.sqrt(model.layer2.eps + model.layer2.running_var)
w_bn = torch.diag(torch.ones(internal_nodes) / divider)
new_weight = torch.mm(w_bn, model.layer1.weight)
new_weight = torch.mm(model.layer3.weight, new_weight)
b_bn = - model.layer2.running_mean / divider
new_bias = model.layer3.bias + torch.squeeze(torch.mm(model.layer3.weight, b_bn.reshape(-1, 1)))
input = torch.randn(batch_size, in_nodes)
print(model(input))
print(torch.t(torch.mm(new_weight, torch.t(input))) + new_bias)
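One thing worth checking: after the training loop the model is still in training mode, and in that mode nn.BatchNorm1d normalizes each batch with that batch's own mean and variance rather than with running_mean and running_var, which are what the merged weights are built from. The sketch below is a minimal, self-contained comparison and not the original training setup: it rebuilds the same three layers with nn.Sequential, skips the optimizer, and uses a hypothetical helper named fold that is not part of PyTorch. It compares the merged layer against the model in both modes.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
in_nodes, internal_nodes, out_nodes, batch_size = 20, 8, 9, 100

# Same layer stack as in the question, written with nn.Sequential for brevity.
model = nn.Sequential(
    nn.Linear(in_nodes, internal_nodes, bias=False),
    nn.BatchNorm1d(internal_nodes, affine=False),
    nn.Linear(internal_nodes, out_nodes),
)

# A few training-mode forward passes so the running statistics get populated.
# (No optimizer here: only the batch norm buffers matter for this comparison.)
model.train()
with torch.no_grad():
    for _ in range(50):
        model(torch.randn(batch_size, in_nodes))


def fold(model):
    """Hypothetical helper: fold linear -> batch norm -> linear into (weight, bias)."""
    lin1, bn, lin3 = model[0], model[1], model[2]
    scale = 1.0 / torch.sqrt(bn.running_var + bn.eps)        # diagonal of D
    new_weight = lin3.weight @ (scale.unsqueeze(1) * lin1.weight)
    new_bias = lin3.bias - lin3.weight @ (scale * bn.running_mean)
    return new_weight, new_bias


x = torch.randn(batch_size, in_nodes)
with torch.no_grad():
    new_weight, new_bias = fold(model)
    merged = x @ new_weight.t() + new_bias

    model.eval()             # batch norm now uses running_mean / running_var
    out_eval = model(x)

    model.train()            # batch norm uses the statistics of x itself
    out_train = model(x)     # note: this call also updates the running buffers

print("max |merged - eval output| :", (merged - out_eval).abs().max().item())
print("max |merged - train output|:", (merged - out_train).abs().max().item())
```

If the first difference is at floating-point noise level and the second is not, the merge arithmetic itself is fine and calling model.eval() before the first print in the original code should make the two outputs agree.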