in empchat/transformer_local.py [0:0]
def forward(self, input_, mask):
    # Input is [B, seq_len, dim]
    # Mask is [B, seq_len]
    batch_size, seq_len, dim = input_.size()
    assert (
        dim == self.dim
    ), f"Dimensions do not match: {dim} input vs {self.dim} configured"
    n_heads = self.n_heads
    dim_per_head = dim // n_heads

    def prepare_head(tensor):
        # input is [batch_size, seq_len, n_heads * dim_per_head]
        # output is [batch_size * n_heads, seq_len, dim_per_head]
        tensor = tensor.view(batch_size, seq_len, n_heads, dim_per_head)
        tensor = (
            tensor.transpose(1, 2)
            .contiguous()
            .view(batch_size * n_heads, seq_len, dim_per_head)
        )
        return tensor

    in_droped = self.in_dropout(input_)
    query = prepare_head(self.q_lin(in_droped))
    keys = prepare_head(self.k_lin(in_droped))
    values = prepare_head(self.v_lin(in_droped))
    scale = math.sqrt(dim_per_head)
    dot_prod = query.bmm(keys.transpose(1, 2))
    # [B * n_heads, seq_len, seq_len]
    # Positions where mask == 0 (padding) are excluded by setting their
    # attention scores to -inf before the softmax.
    attn_mask = (
        (mask == 0)
        .view(batch_size, 1, 1, seq_len)
        .repeat(1, n_heads, seq_len, 1)
        .view(batch_size * n_heads, seq_len, seq_len)
    )
    dot_prod.masked_fill_(attn_mask, -float("inf"))
    attn_weights = F.softmax(dot_prod / scale, dim=-1)
    attentioned = attn_weights.bmm(values)
    # Recombine heads: [B * n_heads, seq_len, dim_per_head] -> [B, seq_len, dim]
    attentioned = (
        attentioned.view(batch_size, n_heads, seq_len, dim_per_head)
        .transpose(1, 2)
        .contiguous()
        .view(batch_size, seq_len, dim)
    )
    return self.out_lin(attentioned)
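
Below is a minimal, hypothetical sketch of the module wiring this forward() assumes: q_lin, k_lin, v_lin, and out_lin as nn.Linear(dim, dim) projections, in_dropout as nn.Dropout, and n_heads/dim stored on the module. The class name MultiHeadAttentionSketch, its constructor signature, and the example shapes are illustrative assumptions, not the repository's actual definitions.

import math                      # forward() above expects math and F in scope
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttentionSketch(nn.Module):
    # Hypothetical constructor exposing the attributes that forward() reads.
    def __init__(self, n_heads, dim, dropout=0.0):
        super().__init__()
        assert dim % n_heads == 0, "dim must be divisible by n_heads"
        self.n_heads = n_heads
        self.dim = dim
        self.in_dropout = nn.Dropout(p=dropout)
        self.q_lin = nn.Linear(dim, dim)
        self.k_lin = nn.Linear(dim, dim)
        self.v_lin = nn.Linear(dim, dim)
        self.out_lin = nn.Linear(dim, dim)


# Attach the forward() shown above as this sketch's forward method (illustrative).
MultiHeadAttentionSketch.forward = forward

# Usage: mask entries equal to 0 mark padding positions that must not be attended to.
layer = MultiHeadAttentionSketch(n_heads=2, dim=8, dropout=0.1)
x = torch.randn(2, 5, 8)                  # [B, seq_len, dim]
mask = torch.tensor([[1, 1, 1, 1, 1],
                     [1, 1, 1, 0, 0]])    # [B, seq_len]
out = layer(x, mask)                      # -> [2, 5, 8]

With the last two positions of the second sequence masked out, their columns receive -inf scores and therefore zero attention weight for every query in that sequence.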