I'm currently trying to extend a model that is based on FairSeq/PyTorch. During training I need to train two encoders: one with the target sample, and the original one with the source sample.
So the current forward function looks like this:
def forward(self, src_tokens=None, src_lengths=None, prev_output_tokens=None, **kwargs):
    encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
    decoder_out = self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)
    return decoder_out
Based on this idea, I want something like this:
def forward_test(self, src_tokens=None, src_lengths=None, prev_output_tokens=None, **kwargs):
    encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
    decoder_out = self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)
    return decoder_out

def forward_train(self, src_tokens=None, src_lengths=None, prev_output_tokens=None, tgt_tokens=None, **kwargs):
    # tgt_tokens has to reach the model somehow; see Edit 2 below.
    # NOTE: src_lengths is reused for the target pass here; the target
    # lengths may be needed instead.
    encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
    autoencoder_out = self.encoder(tgt_tokens, src_lengths=src_lengths, **kwargs)
    concat = some_concatenation_func(encoder_out, autoencoder_out)
    decoder_out = self.decoder(prev_output_tokens, encoder_out=concat, **kwargs)
    return decoder_out
Is there any way to do this?
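For what it's worth, here is one shape I imagine some_concatenation_func taking. This is only a sketch: it assumes the encoder returns fairseq's EncoderOut named tuple with encoder_out shaped (T, B, C) and encoder_padding_mask shaped (B, T), and that concatenating the two encoder states along the time dimension is acceptable for the decoder's attention. Field names and the return type (named tuple vs. plain dict) vary across fairseq versions.

import torch

def some_concatenation_func(encoder_out, autoencoder_out):
    # Sketch: merge two encoder states along the time axis. Adapt the
    # field access if your fairseq version returns a dict instead of a
    # named tuple.
    combined = torch.cat(
        [encoder_out.encoder_out, autoencoder_out.encoder_out], dim=0
    )
    if (encoder_out.encoder_padding_mask is not None
            and autoencoder_out.encoder_padding_mask is not None):
        mask = torch.cat(
            [encoder_out.encoder_padding_mask,
             autoencoder_out.encoder_padding_mask], dim=1
        )
    else:
        # Sketch simplification: mixed None/non-None masks would need a
        # zeros mask built for the unpadded side.
        mask = None
    return encoder_out._replace(encoder_out=combined,
                                encoder_padding_mask=mask)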
Edit: These are the constraints that I have, since I need to extend FairseqEncoderDecoderModel:
@register_model('transformer_mass')
class TransformerMASSModel(FairseqEncoderDecoderModel):
    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)
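Given that constraint, I imagine folding both paths into the single forward that fairseq calls, dispatching on whether tgt_tokens is present. Again only a sketch: build_model and the rest of the model are elided, and tgt_tokens only arrives if the criterion supplies it (see Edit 2 below).

@register_model('transformer_mass')
class TransformerMASSModel(FairseqEncoderDecoderModel):
    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    def forward(self, src_tokens=None, src_lengths=None,
                prev_output_tokens=None, tgt_tokens=None, **kwargs):
        encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
        # Training batches carry tgt_tokens (injected by the criterion);
        # inference batches do not, so this reduces to the stock forward.
        if tgt_tokens is not None:
            autoencoder_out = self.encoder(tgt_tokens, src_lengths=src_lengths, **kwargs)
            encoder_out = some_concatenation_func(encoder_out, autoencoder_out)
        return self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)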
Edit 2: The parameters passed to the forward function in Fairseq can be altered by implementing your own Criterion; see for example CrossEntropyCriterion, where sample['net_input'] is passed to the model's __call__ function, which invokes the forward method.
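A minimal sketch of such a criterion: it assumes subclassing CrossEntropyCriterion works for my setup, the registered name ('cross_entropy_with_autoencoder') is made up here, and it passes sample['target'] as tgt_tokens for the second encoder pass. Details such as self.args.sentence_avg differ across fairseq versions.

from fairseq.criterions import register_criterion
from fairseq.criterions.cross_entropy import CrossEntropyCriterion

@register_criterion('cross_entropy_with_autoencoder')  # hypothetical name
class CrossEntropyWithAutoencoderCriterion(CrossEntropyCriterion):

    def forward(self, model, sample, reduce=True):
        # Same as CrossEntropyCriterion.forward, except the target
        # tokens are handed to the model alongside the usual net_input.
        net_output = model(**sample['net_input'], tgt_tokens=sample['target'])
        loss, _ = self.compute_loss(model, net_output, sample, reduce=reduce)
        sample_size = (sample['target'].size(0) if self.args.sentence_avg
                       else sample['ntokens'])
        logging_output = {
            'loss': loss.data,
            'ntokens': sample['ntokens'],
            'nsentences': sample['target'].size(0),
            'sample_size': sample_size,
        }
        return loss, sample_size, logging_output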