
https://github.com/HuiZeng/Image-Adaptive-3DLUT/blob/master/models_x.py#L69

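import torch.nn as nn

# discriminator_block() is a helper defined elsewhere in models_x.py (not shown);
# it returns a list of layers rather than a single module.
class Classifier_unpaired(nn.Module):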
    def __init__(self, in_channels=3):
        super(Classifier_unpaired, self).__init__()

        self.model = nn.Sequential(
            nn.Upsample(size=(256,256),mode='bilinear'),
            nn.Conv2d(3, 16, 3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.InstanceNorm2d(16, affine=True),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
            *discriminator_block(128, 128),
            #*discriminator_block(128, 128),
            nn.Conv2d(128, 3, 8, padding=0),
        )

    def forward(self, img_input):
        return self.model(img_input)

There is a `*` before `discriminator_block()`, which is a function. What does it do?

    The important part isn’t that it’s before a function, but that it’s inside an argument list (`nn.Sequential(…)`). Are you familiar with a call like `print(*foo)`? – Ry- Jan 03 '22 at 03:18

1 Answer


The `*`, when applied to an argument in a function call, unpacks that sequence (or any iterable) into multiple positional arguments. So these two calls are equivalent:

def foo(a, b):
    return a + b

x = [1, 2]
foo(*x)   # unpacks x, same as the call below
foo(1, 2)
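The same rule holds when `*` is mixed with ordinary positional arguments and used more than once in a single call, which is the pattern in the `nn.Sequential(...)` call in the question (multiple unpackings in one call need Python 3.5+). A minimal sketch, with a hypothetical `collect` function standing in for any function that accepts `*args`:

```python
def collect(*args):
    # Hypothetical helper: *args gathers every positional argument into a tuple.
    return args

block_a = [1, 2]   # any iterable works: list, tuple, generator, ...
block_b = (3, 4)

# Each *iterable is expanded in place, left to right, among the
# ordinary positional arguments.
result = collect(0, *block_a, *block_b, 5)
print(result)  # (0, 1, 2, 3, 4, 5)

# Without *, each iterable is passed as one single argument.
nested = collect(0, block_a, block_b, 5)
print(nested)  # (0, [1, 2], (3, 4), 5)
```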

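In your case, it's applied to arguments of `nn.Sequential`: `discriminator_block` returns a list of layers, and `*` expands that list so each layer becomes a separate argument to `nn.Sequential`. Without the `*`, `nn.Sequential` would receive the list itself as a single argument, and since a list is not a `Module`, it would raise a `TypeError`.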
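To make the difference concrete without needing PyTorch installed, here is a rough sketch. `Layer`, `toy_block` and `ToySequential` are hypothetical stand-ins, not PyTorch APIs: `toy_block` returns a list of layers the way `discriminator_block` does, and `ToySequential` takes its layers as separate positional arguments and rejects anything that is not a single layer, loosely mimicking `nn.Sequential(*args)`.

```python
class Layer:
    """Hypothetical stand-in for a single nn.Module layer."""
    def __init__(self, name):
        self.name = name

def toy_block(in_f, out_f):
    # Hypothetical helper: returns a *list* of layers, not one layer.
    return [Layer(f"conv{in_f}->{out_f}"), Layer("lrelu")]

class ToySequential:
    def __init__(self, *layers):
        # Expect every positional argument to be exactly one layer.
        for layer in layers:
            if not isinstance(layer, Layer):
                raise TypeError(f"{type(layer).__name__} is not a Layer")
        self.layers = list(layers)

    def names(self):
        return [layer.name for layer in self.layers]

# With *: each returned list is unpacked into individual layers.
model = ToySequential(Layer("input"), *toy_block(16, 32), *toy_block(32, 64))
print(model.names())
# ['input', 'conv16->32', 'lrelu', 'conv32->64', 'lrelu']

# Without *: the list is passed as one argument and rejected.
error = None
try:
    ToySequential(Layer("input"), toy_block(16, 32))
except TypeError as e:
    error = str(e)
print(error)  # list is not a Layer
```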

Mark Ransom