I am in an unusual setting where I should not use running statistics, since that would be considered cheating (e.g. in meta-learning). However, I often run a forward pass on a set of points (5, in fact) and then want to evaluate on only 1 point using the statistics of that previous batch, but batch norm forgets the batch statistics it just used. I've tried to hard-code the values it should use, but I get strange errors (even when I comment out parts of the PyTorch code itself, such as the input-dimension check).
How do I hard-code the statistics of the previous batch so that batch norm can normalize a new single data point with them, and then reset them for the next, fresh batch?
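One way to reuse the statistics of the 5-point batch is to compute them yourself and pass them to `torch.nn.functional.batch_norm` with `training=False`, together with the layer's own `weight`, `bias` and `eps`. This is a minimal sketch assuming a `BatchNorm1d` layer and inputs of shape `(N, num_features)`; the helper names `batch_stats` and `normalize_with` are hypothetical. The variance uses `unbiased=False` because that is the variance batch norm normalizes with in training mode.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def batch_stats(x):
    # Hypothetical helper: per-feature mean and biased variance over the
    # batch dimension, matching what BatchNorm uses in training mode.
    return x.mean(dim=0), x.var(dim=0, unbiased=False)


def normalize_with(bn, x, mean, var):
    # Hypothetical helper: training=False makes F.batch_norm use the supplied
    # statistics as-is instead of recomputing them from x, so a batch of one works.
    return F.batch_norm(x, mean, var, weight=bn.weight, bias=bn.bias,
                        training=False, eps=bn.eps)


torch.manual_seed(0)
bn = nn.BatchNorm1d(3)
support = torch.randn(5, 3)   # the 5-point batch
query = torch.randn(1, 3)     # a single point, kept 2D: shape (1, 3)

mean, var = batch_stats(support)
out = normalize_with(bn, query, mean, var)
```

Nothing is stored on the layer, so there is nothing to reset: the next batch simply gets its own `mean, var`.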
Note: I don't want to change the batch norm layer type.
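A trick that keeps the existing `BatchNorm` layer is to use `momentum=1.0`. PyTorch updates running statistics as `running = (1 - momentum) * running + momentum * batch_stat`, so with `momentum=1.0` the running buffers become exactly the statistics of the last training-mode batch, and `eval()` mode then normalizes a single point with them. `momentum` is a plain attribute, so it can also be set on an existing layer with `bn.momentum = 1.0`. This sketch assumes `track_running_stats=True`. One caveat: the stored `running_var` is the unbiased batch variance, while training-mode normalization uses the biased one, so with 5 points the two differ by a factor of 5/4. No explicit reset is needed, because the next training-mode pass overwrites the buffers.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# momentum=1.0: running stats are replaced by the stats of the last batch
bn = nn.BatchNorm1d(3, momentum=1.0)

support = torch.randn(5, 3)   # the 5-point batch
query = torch.randn(1, 3)     # single point, shape (1, 3), not (3,)

bn.train()
_ = bn(support)               # overwrites running_mean / running_var

bn.eval()
out = bn(query)               # normalized with the support batch's stats
```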
Sample code I tried:
import torch
import torch.nn as nn

def set_tracking_running_stats(model):
    # Visit every submodule (including nested ones) and pick out the batch norm layers
    for module in model.modules():
        if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            module.track_running_stats = True
            # Running stats are buffers, not trainable parameters, so assign plain tensors
            module.running_mean = torch.zeros(module.num_features)
            module.running_var = torch.ones(module.num_features)
            module.num_batches_tracked = torch.tensor(0, dtype=torch.long)
            # Equivalent once the buffers hold tensors: module.reset_running_stats()
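The commented-out `reset_running_stats()` call only has an effect once `track_running_stats` is `True` and the buffers hold tensors; on a layer built with `track_running_stats=False` they are registered as `None`. Wrapping the new tensors in `nn.Parameter` causes a different problem: `nn.Parameter` defaults to `requires_grad=True` whatever the inner tensor's flag is, and the statistics then move into `model.parameters()`, where an optimizer would train them. A minimal check, assuming a standalone `BatchNorm1d(3)` layer, that assigning plain tensors keeps them as buffers and that `reset_running_stats()` then restores the initial values:

```python
import torch
import torch.nn as nn

# A layer built without running statistics: its buffers are registered as None.
bn = nn.BatchNorm1d(3, track_running_stats=False)

# Turn tracking on by assigning plain tensors; they fill the existing
# buffer slots, so they stay buffers rather than becoming parameters.
bn.track_running_stats = True
bn.running_mean = torch.zeros(3)
bn.running_var = torch.ones(3)
bn.num_batches_tracked = torch.tensor(0, dtype=torch.long)

bn.train()
_ = bn(torch.randn(5, 3))      # updates the running buffers

buffer_names = {name for name, _ in bn.named_buffers()}
param_names = {name for name, _ in bn.named_parameters()}

# Reset for the next, fresh batch: zero mean, unit variance, counter back to 0.
bn.reset_running_stats()
```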
My most common errors:
raise ValueError('expected 2D or 3D input (got {}D input)'
ValueError: expected 2D or 3D input (got 1D input)
and
IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
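Both messages are consistent with feeding the single point as a 1D tensor of shape `(num_features,)`. `BatchNorm1d` only accepts `(N, C)` or `(N, C, L)` input, so one point needs an explicit batch dimension of size 1, and the `IndexError` is what PyTorch raises when code asks for dimension 1 of a 1D tensor (e.g. `x.size(1)`). A minimal sketch, assuming a `BatchNorm1d(3)` layer in `eval()` mode; in training mode a `(1, C)` batch is rejected anyway, since there is only one value per channel to compute statistics from.

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(3)
bn.eval()                      # use stored running stats, not batch stats

point = torch.randn(3)         # shape (3,): 1D, rejected by BatchNorm1d

error_message = None
try:
    bn(point)
except ValueError as err:
    error_message = str(err)   # "expected 2D or 3D input (got 1D input)"

index_message = None
try:
    point.size(1)              # a 1D tensor has no dimension 1
except IndexError as err:
    index_message = str(err)

batched = point.unsqueeze(0)   # shape (1, 3): a batch containing one point
out = bn(batched)
```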