def forward()

in empchat/transformer_local.py


    def forward(self, input_, mask):
        # Input is [B, seq_len, dim]
        # Mask is [B, seq_len]
        batch_size, seq_len, dim = input_.size()
        assert (
            dim == self.dim
        ), f"Dimensions do not match: {dim} input vs {self.dim} configured"
        n_heads = self.n_heads
        dim_per_head = dim // n_heads

        def prepare_head(tensor):
            # input is [batch_size, seq_len, n_heads * dim_per_head]
            # output is [batch_size * n_heads, seq_len, dim_per_head]
            tensor = tensor.view(batch_size, seq_len, n_heads, dim_per_head)
            tensor = (
                tensor.transpose(1, 2)
                .contiguous()
                .view(batch_size * n_heads, seq_len, dim_per_head)
            )
            return tensor

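        # Apply dropout to the input, then compute per-head queries, keys, and values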
        in_droped = self.in_dropout(input_)
        query = prepare_head(self.q_lin(in_droped))
        keys = prepare_head(self.k_lin(in_droped))
        values = prepare_head(self.v_lin(in_droped))
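        # Scores are divided by sqrt(dim_per_head) before the softmax (scaled dot-product attention)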
        scale = math.sqrt(dim_per_head)
        dot_prod = query.bmm(keys.transpose(1, 2))
        # [B * n_heads, seq_len, seq_len]
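        # True at key positions where mask == 0 (padding), expanded to every head and query row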
        attn_mask = (
            (mask == 0)
            .view(batch_size, 1, 1, seq_len)
            .repeat(1, n_heads, seq_len, 1)
            .view(batch_size * n_heads, seq_len, seq_len)
        )
        dot_prod.masked_fill_(attn_mask, -float("inf"))
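        # Masked positions hold -inf, so they receive zero weight after the softmax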
        attn_weights = F.softmax(dot_prod / scale, dim=-1)
        attentioned = attn_weights.bmm(values)
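        # Merge the heads back: [B * n_heads, seq_len, dim_per_head] -> [B, seq_len, dim]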
        attentioned = (
            attentioned.view(batch_size, n_heads, seq_len, dim_per_head)
            .transpose(1, 2)
            .contiguous()
            .view(batch_size, seq_len, dim)
        )
        return self.out_lin(attentioned)
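
The padding mask is the only subtle part of this routine: positions where mask == 0 must end up with zero attention weight. Below is a minimal, self-contained sketch of that masking pattern on a toy score tensor; the shapes and values are made up for illustration, and the method above additionally assumes module-level imports of math and torch.nn.functional as F.

    import torch
    import torch.nn.functional as F

    batch_size, n_heads, seq_len = 1, 2, 4
    # Same convention as the forward above: 1 = real token, 0 = padding
    mask = torch.tensor([[1, 1, 1, 0]])

    # Toy attention scores, as produced by query.bmm(keys.transpose(1, 2))
    scores = torch.randn(batch_size * n_heads, seq_len, seq_len)

    attn_mask = (
        (mask == 0)
        .view(batch_size, 1, 1, seq_len)
        .repeat(1, n_heads, seq_len, 1)
        .view(batch_size * n_heads, seq_len, seq_len)
    )
    scores.masked_fill_(attn_mask, -float("inf"))
    weights = F.softmax(scores, dim=-1)

    # Every query gives the padded key position exactly zero weight
    assert torch.all(weights[..., -1] == 0)

Filling the scores with -inf before the softmax, rather than zeroing weights afterwards, keeps each row of attention weights normalized over the non-padded positions only.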