in empchat/transformer_local.py [0:0]
def forward(self, input_, mask):
    """
    input_ is a LongTensor of shape [batch, seq_len], containing each
    word's index in the embeddings table.
    mask is a ByteTensor of shape [batch, seq_len], filled with 1 when
    inside the sequence and 0 outside.
    """
    seq_len = input_.size(1)
    # Build position indices [0, 1, ..., seq_len - 1] on the same device
    # as input_, then add a batch dimension for broadcasting.
    positions = input_.new(seq_len).long()
    positions = torch.arange(seq_len, out=positions).unsqueeze(0)
    # Sum word and position embeddings, then zero out padding positions.
    tensor = self.embeddings(input_)
    tensor = tensor + self.position_embeddings(positions).expand_as(tensor)
    tensor *= mask.unsqueeze(-1).float()
    # Each layer: residual self-attention, then a residual feed-forward
    # block, each followed by layer normalization (post-norm), with the
    # padding positions re-zeroed after every layer.
    for i in range(self.n_layers):
        tensor = tensor + self.attentions[i](tensor, mask)
        tensor = self.normalize(tensor, self.layer_norm1[i])
        tensor = tensor + self.ffns[i](tensor, mask)
        tensor = self.normalize(tensor, self.layer_norm2[i])
        tensor *= mask.unsqueeze(-1).float()
    # Pool over the sequence dimension to get one vector per example.
    if self.fix_mean:
        # Masked mean: divide by each example's true length, not seq_len.
        output = tensor.sum(dim=1) / mask.float().sum(dim=1).unsqueeze(-1)
    else:
        # Plain mean: padding positions (already zeroed) still count
        # in the denominator.
        output = tensor.mean(dim=1)
    return output
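
For reference, a minimal sketch of how the two arguments can be prepared from variable-length index sequences. The pad_idx value, the example word indices, and the encoder instance are hypothetical stand-ins for objects constructed elsewhere in the repo.

import torch

pad_idx = 0  # assumed padding index
batch = [[5, 12, 7], [3, 9]]  # two sequences of word indices

seq_len = max(len(seq) for seq in batch)

# Pad every sequence to seq_len with pad_idx.
input_ = torch.full((len(batch), seq_len), pad_idx, dtype=torch.long)
for i, seq in enumerate(batch):
    input_[i, : len(seq)] = torch.tensor(seq, dtype=torch.long)

# 1 inside each sequence, 0 on padding, as the docstring requires.
lengths = torch.tensor([len(seq) for seq in batch])
mask = (torch.arange(seq_len).unsqueeze(0) < lengths.unsqueeze(1)).byte()

# output = encoder(input_, mask)  # shape [batch, embedding_dim]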