def forward()

in empchat/transformer_local.py [0:0]


    def forward(self, input_, mask):
        """
        Encode a batch of token sequences into one pooled vector per sequence.

        input_ is a LongTensor of shape [batch, seq_len], containing each
        word's index in the embeddings table.
        mask is a ByteTensor of shape [batch, seq_len], filled with 1 when
        inside the sequence and 0 outside.

        Token embeddings and position embeddings are summed and padding
        positions are zeroed. The result then goes through ``self.n_layers``
        residual attention / feed-forward sub-layers, each followed by
        ``self.normalize``. Finally the per-position outputs are pooled over
        the sequence dimension:

        * if ``self.fix_mean`` is set, the sum is divided by the number of
          unmasked positions in each row (a row with no unmasked positions
          yields zeros rather than NaN);
        * otherwise ``tensor.mean(dim=1)`` is used, which divides by
          ``seq_len`` and therefore counts padding positions in the
          denominator.

        Returns a tensor with the sequence dimension reduced; presumably
        [batch, dim], assuming ``self.embeddings`` maps [batch, seq_len] to
        [batch, seq_len, dim] (TODO confirm against the embedding module).
        """
        seq_len = input_.size(1)
        # Position ids 0..seq_len-1 (int64), created on input_'s device and
        # given a leading batch axis of size 1.
        positions = torch.arange(seq_len, device=input_.device).unsqueeze(0)
        # Loop-invariant float mask with a trailing axis so it broadcasts
        # over the embedding dimension; computed once instead of per layer.
        float_mask = mask.unsqueeze(-1).float()

        tensor = self.embeddings(input_)
        tensor = tensor + self.position_embeddings(positions).expand_as(tensor)
        tensor *= float_mask
        for i in range(self.n_layers):
            tensor = tensor + self.attentions[i](tensor, mask)
            tensor = self.normalize(tensor, self.layer_norm1[i])
            tensor = tensor + self.ffns[i](tensor, mask)
            tensor = self.normalize(tensor, self.layer_norm2[i])
            # Re-zero padding positions after each layer.
            tensor *= float_mask
        if self.fix_mean:
            # Number of real tokens per row, shape [batch, 1]. Clamp so an
            # all-padding row computes 0 / 1 instead of 0 / 0 (NaN).
            lengths = float_mask.sum(dim=1).clamp(min=1)
            output = tensor.sum(dim=1) / lengths
        else:
            output = tensor.mean(dim=1)
        return output