
I was reading the Generative Adversarial Nets implementation at https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/gan/gan.py and would like to know what the * sign means here. I searched Google and Stack Overflow but could not find a clear explanation.

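import argparse

import numpy as np
import torch.nn as nn

# In gan.py, `opt` is parsed from the command line and img_shape is built from it.
# The values below are assumed stand-ins so the snippet runs on its own.
opt = argparse.Namespace(latent_dim=100)
img_shape = (1, 28, 28)

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            # Returns a list of layers, not a single module
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # The second positional argument of BatchNorm1d is eps, so this sets eps=0.8
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(opt.latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )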
  • The * turns Sequential([x1, x2, x3], [y1, y2, y3]) into Sequential(x1, x2, x3, y1, y2, y3). – matt Mar 24 '21 at 12:12

1 Answer


*x is iterable unpacking notation in Python. See this related answer.
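A minimal plain-Python sketch of what * does at a call site. `show_args` is a hypothetical helper that just returns the positional arguments it received, standing in for a function like nn.Sequential(*args); the string "layers" are placeholders, not real modules.

```python
def show_args(*args):
    # Collects whatever positional arguments it receives into a tuple
    return args

layers = ["linear", "relu"]

print(show_args(layers))   # (['linear', 'relu'],)  -> one argument: the list itself
print(show_args(*layers))  # ('linear', 'relu')     -> two separate arguments

# Several unpackings can be mixed with ordinary arguments in one call
# (Python 3.5+, PEP 448), which is what the repeated *block(...) calls rely on
print(show_args(*layers, *("bn",), "tanh"))  # ('linear', 'relu', 'bn', 'tanh')

# Any iterable can be unpacked, not only lists
print(show_args(*range(3)))  # (0, 1, 2)

# The same * syntax also works inside list literals
combined = [*layers, "tanh"]
print(combined)  # ['linear', 'relu', 'tanh']
```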

block returns a list of layers, and *block(...) unpacks that list into separate positional arguments of the nn.Sequential call.
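To see why the * matters with PyTorch specifically, here is a small sketch (requires PyTorch; layer sizes are arbitrary, and `block` is a simplified copy of the question's helper using default BatchNorm1d settings). nn.Sequential registers each positional argument as a child module and rejects anything that is not a Module, so passing the list itself raises a TypeError, while unpacking it registers each layer individually.

```python
import torch
import torch.nn as nn


def block(in_feat, out_feat, normalize=True):
    # Same shape as the question's helper: returns a plain Python list of layers
    layers = [nn.Linear(in_feat, out_feat)]
    if normalize:
        layers.append(nn.BatchNorm1d(out_feat))
    layers.append(nn.LeakyReLU(0.2, inplace=True))
    return layers


# Without *: the whole list becomes one positional argument, which is not a Module
try:
    nn.Sequential(block(8, 16))
    list_call_failed = False
except TypeError:
    list_call_failed = True

# With *: each layer in each list becomes its own positional argument
model = nn.Sequential(
    *block(8, 16, normalize=False),  # Linear, LeakyReLU
    *block(16, 32),                  # Linear, BatchNorm1d, LeakyReLU
    nn.Linear(32, 4),
)

print(list_call_failed)  # True
print(len(model))        # 6
out = model(torch.randn(5, 8))
print(out.shape)         # torch.Size([5, 4])
```

The same counting applied to the question's Generator gives 2 + 3 + 3 + 3 layers from the four blocks, plus the final Linear and Tanh, for 13 child modules in one flat nn.Sequential.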

Here's a simpler example:

import torch.nn as nn

def block(in_feat, out_feat):
    # Returns a tuple of two layers; any iterable can be unpacked with *
    return (nn.Linear(in_feat, out_feat), nn.LeakyReLU(0.2, inplace=True))

model = nn.Sequential(
    *block(128, 256),
)

# Equivalent to:
# layers = block(128, 256)
# model = nn.Sequential(layers[0], layers[1])

# Also equivalent to:
# layers = block(128, 256)
# model = nn.Sequential(*layers)
Dan Zheng