
I know there are several posts about this error, but my problem is a bit different. The network I'm trying to import doesn't seem to have a different architecture from the one I currently have; I think only the names of the first layer (initial) differ.

RuntimeError: Error(s) in loading state_dict for Generator:
        Missing key(s) in state_dict: "initial.0.weight", "initial.0.bias".
        Unexpected key(s) in state_dict: "initial.weight", "initial.bias".

Initial is defined like so:

self.initial = nn.Sequential(
            nn.Linear(...),
            ...,
        )

I would rename self.initial to self.initial.0 if that were possible, but it is not. I must be missing something, but I'm confused as to what.
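To see where the extra .0 in the key names comes from, here is a minimal sketch (the layer sizes are made up for illustration). nn.Sequential registers its children by index, so their parameters get a numeric prefix in the state_dict, while a bare layer assigned directly to an attribute does not:

```python
import torch.nn as nn

class SeqGen(nn.Module):
    def __init__(self):
        super().__init__()
        # children of a Sequential are named "0", "1", ...
        self.initial = nn.Sequential(nn.Linear(4, 8))

class FlatGen(nn.Module):
    def __init__(self):
        super().__init__()
        # a bare layer contributes keys directly under "initial"
        self.initial = nn.Linear(4, 8)

print(list(SeqGen().state_dict()))   # ['initial.0.weight', 'initial.0.bias']
print(list(FlatGen().state_dict()))  # ['initial.weight', 'initial.bias']
```

So the checkpoint reporting "initial.weight" was saved from a model where initial was a single layer, not a Sequential.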

1 Answer

There's a mismatch between the implemented and saved network structures: your initial is an nn.Sequential() container, while the one in the checkpoint appears to be a single layer. You could try reducing your implementation to self.initial = nn.Linear(...) and see whether the checkpoint loads correctly. Assuming there's only one layer with learnable parameters in your container, you can refactor it from something like this:

# A typical container:
self.initial = nn.Sequential(
    nn.Linear(...),
    nn.InstanceNorm2d(...),
    nn.ReLU()
)

to the oldschool way:

self.initial = nn.Linear(...)
self.norm = nn.InstanceNorm2d(...)
self.relu = nn.ReLU()

with respective changes to your forward() implementation:

def forward(self, batch):
    # originally something like:
    batch = self.initial(batch)
    # normalization/nonlinearity layers usually have no parameters (unless affine=True),
    # so adding/removing such layers should not break model loading:
    batch = self.norm(batch)
    batch = self.relu(batch)
    return batch
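Putting the pieces together, here is a minimal sketch of the refactored module (sizes are hypothetical, standing in for the elided nn.Linear(...) arguments). The norm and relu layers add no entries to the state_dict, so a checkpoint with only the flat "initial.weight"/"initial.bias" keys loads cleanly:

```python
import torch
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.initial = nn.Linear(4, 8)        # contributes initial.weight / initial.bias
        self.norm = nn.InstanceNorm2d(8)      # affine=False by default: no parameters
        self.relu = nn.ReLU()                 # no parameters

    def forward(self, batch):
        batch = self.initial(batch)
        batch = self.norm(batch)
        batch = self.relu(batch)
        return batch

# a checkpoint saved from a flat layer would carry these key names:
ckpt = {"initial.weight": torch.zeros(8, 4), "initial.bias": torch.zeros(8)}
result = Generator().load_state_dict(ckpt)
print(result)  # no missing or unexpected keys
```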
dx2-66