I want to understand how BatchNormalization works in Keras, so I wrote the following code:
import numpy as np
import keras
X_input = keras.Input((2,))
X = keras.layers.BatchNormalization(axis=1)(X_input)
model1 = keras.Model(inputs=X_input, outputs=X)
The input is a batch of two-dimensional vectors, and the layer normalizes it along axis=1. I then print the output:
a = np.arange(4).reshape((2,2))
print('a=')
print(a)
print('output=')
print(model1.predict(a,batch_size=2))
The output is:
a=
array([[0, 1],
[2, 3]])
output=
array([[ 0. , 0.99950039],
[ 1.99900079, 2.9985013 ]], dtype=float32)
I cannot make sense of these results. As far as I know, the mean of the batch should be ([0,1] + [2,3])/2 = [1,2], and the variance should be 1/2*(([0,1] - [1,2])^2 + ([2,3] - [1,2])^2) = [1,1]. Normalizing with (x - mean)/sqrt(var) should therefore give [-1,-1] and [1,1]. Where am I wrong?
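The hand calculation above can be checked with plain NumPy. The sketch below is not Keras code and makes several assumptions: that `predict` runs the layer in inference mode, that the layer's moving mean and moving variance are still at their initial values (0 and 1), that gamma and beta are at their initial values (1 and 0), and that epsilon is the default 1e-3. Under those assumptions the second result appears to match the printed output, while the first matches the expected [-1,-1] and [1,1] up to the epsilon term.

```python
import numpy as np

# Same batch as in the question: two samples, two features.
a = np.arange(4, dtype=np.float32).reshape((2, 2))
eps = 1e-3  # assumed: default epsilon of keras.layers.BatchNormalization

# Training-style normalization: statistics are computed over the batch (axis 0),
# one mean and one variance per feature.
batch_mean = a.mean(axis=0)
batch_var = a.var(axis=0)
train_style = (a - batch_mean) / np.sqrt(batch_var + eps)

# Inference-style normalization: assumed moving statistics at their initial
# values (mean 0, variance 1), since the model was never trained.
moving_mean = np.zeros(2, dtype=np.float32)
moving_var = np.ones(2, dtype=np.float32)
inference_style = (a - moving_mean) / np.sqrt(moving_var + eps)

print("batch mean:", batch_mean)
print("batch var:", batch_var)
print("training-style output:")
print(train_style)
print("inference-style output:")
print(inference_style)
```

If this reading is right, every input is simply scaled by 1/sqrt(1 + 0.001), which is about 0.9995. In recent Keras versions, calling the model directly with `training=True` (for example `model1(a.astype('float32'), training=True)`) should use the batch statistics instead.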