I am trying to wrap my head around skip connections in a sequential model. With the functional API I would do something as simple as this (quick example, maybe not 100% syntactically correct, but you should get the idea):
x1 = self.conv1(inp)  # keep the first conv output for the skip
x = self.conv2(x1)
x = self.conv3(x)
x = self.conv4(x)
x = self.deconv4(x)
x = self.deconv3(x)
x = self.deconv2(x)
x = torch.cat((x, x1), 1)  # skip: concatenate along the channel dimension
x = self.deconv1(x)
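To make the example above concrete, here is a minimal runnable sketch of the same forward pass. The class name `SkipNet`, the channel counts, and the kernel/stride/padding values are assumptions chosen only so that the shapes line up; they are not taken from any particular codebase.

```python
import torch
import torch.nn as nn


class SkipNet(nn.Module):
    """Hypothetical 4-down / 4-up network with one skip from conv1 to deconv1."""

    def __init__(self, in_ch=3, base=8):
        super().__init__()
        # Encoder: each conv halves the spatial size (kernel 4, stride 2, padding 1).
        self.conv1 = nn.Conv2d(in_ch, base, 4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(base, base * 2, 4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(base * 2, base * 4, 4, stride=2, padding=1)
        self.conv4 = nn.Conv2d(base * 4, base * 8, 4, stride=2, padding=1)
        # Decoder: each transposed conv doubles the spatial size.
        self.deconv4 = nn.ConvTranspose2d(base * 8, base * 4, 4, stride=2, padding=1)
        self.deconv3 = nn.ConvTranspose2d(base * 4, base * 2, 4, stride=2, padding=1)
        self.deconv2 = nn.ConvTranspose2d(base * 2, base, 4, stride=2, padding=1)
        # deconv1 sees `base` channels from deconv2 plus `base` from the skip.
        self.deconv1 = nn.ConvTranspose2d(base * 2, in_ch, 4, stride=2, padding=1)

    def forward(self, inp):
        x1 = self.conv1(inp)            # saved for the skip connection
        x = self.conv2(x1)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.deconv4(x)
        x = self.deconv3(x)
        x = self.deconv2(x)             # same spatial size as x1 again
        x = torch.cat((x, x1), 1)       # concatenate along channels
        return self.deconv1(x)


net = SkipNet()
out = net(torch.randn(1, 3, 64, 64))
print(out.shape)
```

The only thing the skip changes structurally is that `deconv1` needs twice as many input channels as `deconv2` produces.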
I am now using a sequential model and trying to do something similar: create a skip connection that brings the activations of the first conv layer all the way to the last ConvTranspose layer. I have looked at the U-Net architecture implemented here and find it a bit confusing. It does something like this:
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                            kernel_size=4, stride=2,
                            padding=1, bias=use_bias)
down = [downrelu, downconv, downnorm]
up = [uprelu, upconv, upnorm]
if use_dropout:
    model = down + [submodule] + up + [nn.Dropout(0.5)]
else:
    model = down + [submodule] + up
Isn't this just adding layers to the sequential model, well, sequentially? There is the down conv, which is followed by submodule (which recursively adds inner layers), and then concatenated with up, which is the upconv layer. I am probably missing something important about how the Sequential API works, but how does the code snippet from U-Net actually implement the skip?
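One possibility I can think of is that the skip does not live in the layer list at all, but in the forward method of the module that wraps it: the block would run its inner Sequential and then concatenate its own input onto the result. That would at least be consistent with upconv taking inner_nc * 2 input channels. Below is a minimal sketch of that idea; `SkipBlock` and all channel counts are hypothetical, and I have not confirmed that the linked code works this way.

```python
import torch
import torch.nn as nn


class SkipBlock(nn.Module):
    """Hypothetical wrapper: output = concat(input, inner(input)) along channels."""

    def __init__(self, inner):
        super().__init__()
        self.inner = inner

    def forward(self, x):
        # inner must return the same spatial size as x for the concat to work.
        return torch.cat([x, self.inner(x)], 1)


# Inner path: down by 2, then back up by 2, so spatial size is preserved.
inner = nn.Sequential(
    nn.Conv2d(8, 16, kernel_size=4, stride=2, padding=1),
    nn.ReLU(),
    nn.ConvTranspose2d(16, 8, kernel_size=4, stride=2, padding=1),
)

# The outer model is still a plain nn.Sequential; the skip is hidden in SkipBlock.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=4, stride=2, padding=1),
    SkipBlock(inner),  # emits 8 (input) + 8 (inner output) = 16 channels
    nn.ConvTranspose2d(16, 3, kernel_size=4, stride=2, padding=1),
)

y = model(torch.randn(1, 3, 32, 32))
print(y.shape)
```

With this pattern nn.Sequential itself still only chains modules one after another; the branching happens inside one of the modules it contains, and such blocks can be nested recursively.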