in submission_code/tools.py [0:0]
def forward(self, x, *, pad_token_id=0):
    """Run the wrapped seq2seq model and return its output with dims 1 and 2 swapped.

    Args:
        x: A ``(src, tgt)`` pair of token-id tensors. They are passed to
            ``self.tr`` as ``input_ids`` and ``decoder_input_ids``.
        pad_token_id: Token id treated as padding when building both
            attention masks. Defaults to ``0``, the value this method
            previously hard-coded.

    Returns:
        The first element of ``self.tr``'s output, permuted with
        ``permute(0, 2, 1)``. NOTE(review): presumably this turns
        ``(batch, tgt_len, vocab)`` logits into ``(batch, vocab, tgt_len)``
        for ``nn.CrossEntropyLoss``; confirm against the caller.
    """
    src, tgt = x
    # 1.0 for real tokens, 0.0 for padding positions.
    src_attn = (src != pad_token_id).float()
    tgt_attn = (tgt != pad_token_id).float()
    # NOTE(review): if the decoder start token shares the pad id (e.g. T5,
    # where both are 0), the start position is masked out too. Verify how
    # tgt is laid out by the caller.
    outputs = self.tr(
        input_ids=src,
        attention_mask=src_attn,
        decoder_input_ids=tgt,
        decoder_attention_mask=tgt_attn,
    )
    return outputs[0].permute(0, 2, 1)