import torch
import torch.nn as nn
data = torch.ones(3,3,6,6)
conv = nn.Conv2d(3, 16, kernel_size = 3, padding = 1)
print(data[0].unsqueeze(0).shape)
for i in range(3):
print((conv(data)[i] == conv(data[i].unsqueeze(0))).all())
Results:
torch.Size([1, 3, 6, 6])
tensor(False)
tensor(False)
tensor(False)
I thought it would print True but ended up printing False instead. Any idea why?