5

I'm currently trying to extend a model that is based on FairSeq/PyTorch. During training, I need to train two encoders: one with the target sample, and the original one with the source sample.

So the current forward function looks like this:

def forward(self, src_tokens=None, src_lengths=None, prev_output_tokens=None, **kwargs):
    encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
    decoder_out = self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)
    return decoder_out

Based on this idea, I want something like this:

def forward_test(self, src_tokens=None, src_lengths=None, prev_output_tokens=None, **kwargs):
    encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
    decoder_out = self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)
    return decoder_out

def forward_train(self, src_tokens=None, src_lengths=None, prev_output_tokens=None, tgt_tokens=None, **kwargs):
    encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
    autoencoder_out = self.encoder(tgt_tokens, src_lengths=src_lengths, **kwargs)
    concat = some_concatination_func(encoder_out, autoencoder_out)
    decoder_out = self.decoder(prev_output_tokens, encoder_out=concat, **kwargs)
    return decoder_out

Is there any way to do this?

Edit: These are the constraints that I have, since I need to extend FairseqEncoderDecoderModel:

@register_model('transformer_mass')
class TransformerMASSModel(FairseqEncoderDecoderModel):
    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder) 

Edit 2: The parameters passed to the forward function in Fairseq can be altered by implementing your own Criterion; see for example CrossEntropyCriterion, where sample['net_input'] is passed to the model's __call__ method, which in turn invokes forward.
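For illustration, here is a hedged sketch of such a criterion, modelled on CrossEntropyCriterion; the criterion name cross_entropy_with_target and the tgt_tokens keyword are assumptions for this question, not part of Fairseq:

import torch.nn.functional as F

from fairseq.criterions import FairseqCriterion, register_criterion


@register_criterion('cross_entropy_with_target')  # hypothetical name
class CrossEntropyWithTargetCriterion(FairseqCriterion):
    def forward(self, model, sample, reduce=True):
        # Pass the usual net inputs plus the target tokens, so the model's
        # forward() can run the second encoder pass (tgt_tokens is the
        # assumed extra keyword, see the question above)
        net_output = model(**sample['net_input'], tgt_tokens=sample['target'])
        lprobs = model.get_normalized_probs(net_output, log_probs=True)
        lprobs = lprobs.view(-1, lprobs.size(-1))
        target = model.get_targets(sample, net_output).view(-1)
        loss = F.nll_loss(
            lprobs, target,
            ignore_index=self.padding_idx,
            reduction='sum' if reduce else 'none',
        )
        sample_size = sample['ntokens']
        logging_output = {
            'loss': loss.data,
            'ntokens': sample['ntokens'],
            'sample_size': sample_size,
        }
        return loss, sample_size, logging_output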

qwertz
  • I don't see what the question is. Just put the 2 forward functions in your model class. Then, while training, use `model.forward_train` on the input. And while testing use `model.forward_test` on the input. Note that you can't do `model(input)` in that case since PyTorch considers that equivalent to `model.forward(input)`, so that would throw an error. – akshayk07 Nov 01 '19 at 07:09

2 Answers

11

First of all, you should always define and use forward, not some other method that you call on the torch.nn.Module instance.

Definitely do not overload eval() as shown by trsvchn, as it is an evaluation method defined by PyTorch (see here). That method puts the layers inside your model into evaluation mode (e.g. layer-specific changes such as inference behaviour for Dropout or BatchNorm).

Furthermore, you should call the model with the __call__ magic method (i.e. model(input) rather than model.forward(input)). Why? Because hooks and other PyTorch-specific machinery are registered and run properly that way.
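A minimal illustration of the difference (the Doubler module below is made up for demonstration):

import torch


class Doubler(torch.nn.Module):
    def forward(self, x):
        return 2 * x


model = Doubler()
# Forward hooks only fire when the module is invoked via __call__
model.register_forward_hook(lambda module, inp, out: print("hook ran"))

model(torch.ones(1))          # prints "hook ran"
model.forward(torch.ones(1))  # bypasses the hook machinery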

Secondly, do not use some external mode string variable as suggested by @Anant Mittal. That's what PyTorch's training attribute is for: it is the standard way to differentiate whether the model is in eval mode or train mode (it is toggled by model.eval() and model.train()).

That being said, you are best off doing it like this:

import torch


class Network(torch.nn.Module):
    def __init__(self):
        super().__init__()
        ...

    # You could split it into two functions but both should be called by forward
    def forward(
        self, src_tokens=None, src_lengths=None, prev_output_tokens=None,
        tgt_tokens=None, **kwargs
    ):
        encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
        if not self.training:
            # Evaluation path: plain encoder -> decoder
            return self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)
        # Training path: also encode the target tokens and feed the
        # combined representation to the decoder
        autoencoder_out = self.encoder(tgt_tokens, src_lengths=src_lengths, **kwargs)
        concat = some_concatination_func(encoder_out, autoencoder_out)
        return self.decoder(prev_output_tokens, encoder_out=concat, **kwargs)

You could (and arguably should) split the above into two separate methods, but that's not too bad, as the function is rather short and readable this way (see the sketch below). Just stick to PyTorch's way of handling things where easily possible, rather than some ad-hoc solution. And no, there will be no problem with backpropagation; why would there be one?
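If you do want the split, here is a minimal sketch inside the same Network class as above; the helper names _forward_train and _forward_eval are hypothetical, and some_concatination_func is the placeholder from the question:

def forward(
    self, src_tokens=None, src_lengths=None, prev_output_tokens=None,
    tgt_tokens=None, **kwargs
):
    # Dispatch on the built-in training attribute toggled by
    # model.train() / model.eval()
    if self.training:
        return self._forward_train(
            src_tokens, src_lengths, prev_output_tokens, tgt_tokens, **kwargs
        )
    return self._forward_eval(src_tokens, src_lengths, prev_output_tokens, **kwargs)

def _forward_train(self, src_tokens, src_lengths, prev_output_tokens, tgt_tokens, **kwargs):
    encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
    autoencoder_out = self.encoder(tgt_tokens, src_lengths=src_lengths, **kwargs)
    concat = some_concatination_func(encoder_out, autoencoder_out)
    return self.decoder(prev_output_tokens, encoder_out=concat, **kwargs)

def _forward_eval(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
    encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
    return self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)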

Szymon Maszke
  • Oh, yes, you're right, my bad -- `eval` is reserved, I have to change it, thanks! – trsvchn Nov 01 '19 at 14:00
  • Thanks @Szymon Maszke, I wasn't aware of this attribute. I just reasoned it through and came up with a way so that only the forward function has to be written. But thanks a lot. – Anant Mittal Nov 02 '19 at 02:33
  • @Szymon Maszke Thanks a lot for the answer! The training attribute is very helpful. I'm not sure, however, how to pass the target tokens to the forward method, since I train with the fairseq-train command, which only passes the source tokens. – qwertz Nov 05 '19 at 00:48
  • @qwertz I'm not sure I follow. You can pass any variables to either `forward` or `__init__`, at least if you don't want to export the code with JIT. And what is the `fairseq-train` command? Could you provide an example and a link to some source? How is it different from standard PyTorch? – Szymon Maszke Nov 05 '19 at 07:46
  • @Szymon Maszke I'm implementing the class "FairseqEncoderDecoderModel" from the Fairseq framework: https://github.com/pytorch/fairseq/blob/master/fairseq/models/fairseq_model.py#L184 And I train using this command: https://fairseq.readthedocs.io/en/latest/command_line_tools.html#fairseq-train So I do not actually handle the data loading myself and do not know where the forward method is ultimately called by the fairseq-train command. – qwertz Nov 05 '19 at 10:01
  • Can't help you either unfortunately – Szymon Maszke Nov 05 '19 at 16:32

0

By default, calling model() invokes the forward method, which would be your train forward here, so you just need to define a new method for your test/eval path inside your model class, something like this:

Code:

import torch.nn as nn


class FooBar(nn.Module):
    """Dummy Net for testing/debugging."""

    def __init__(self):
        super().__init__()
        ...

    def forward(self, x):
        # here will be train forward
        ...

    def evaltest(self, x):
        # here will be eval/test forward
        ...

Examples:

model = FooBar()  # initialize model 

# train time
pred = model(x)   # calls forward() method under the hood

# test/eval time
test_pred = model.evaltest(x)
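At test time you would typically also put the model into eval mode and disable gradient tracking; a short sketch using standard PyTorch calls (assuming torch is imported and model, x are as above):

model.eval()           # switch layers like Dropout/BatchNorm to eval behaviour
with torch.no_grad():  # no gradient bookkeeping needed at test time
    test_pred = model.evaltest(x)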

Comment: I would recommend splitting these two forward paths into two separate methods, because it is easier to debug and avoids possible problems when backpropagating.

trsvchn