in src_code/modules/rnn_interactive_agent.py [0:0]
def forward(self, q, k, v, mask=None):
    d_k, d_v, n_head = self.d_k, self.d_v, self.n_head

    # First attention layer: multi-head self-attention over the inputs.
    sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)

    # Pre-attention projections (b x lq x (n*dk)), then separate the heads: b x lq x n x dk
    q_ = self.w_qs_1(q).view(sz_b, len_q, n_head, d_k)
    k_ = self.w_ks_1(k).view(sz_b, len_k, n_head, d_k)
    v_ = self.w_vs_1(v).view(sz_b, len_v, n_head, d_v)
    residual1 = q_

    # Normalize, then transpose for the attention dot product: b x n x lq x dk
    q_ = self.layer_norm_q_1(q_).transpose(1, 2)
    k_ = self.layer_norm_k_1(k_).transpose(1, 2)
    v_ = self.layer_norm_v_1(v_).transpose(1, 2)
    if mask is not None:
        mask = mask.unsqueeze(1)  # Add the head axis for broadcasting.
    # Per-head attention (a sketch of the assumed attention computation follows this method).
    q_, attn1 = self.attention_1(q_, k_, v_, mask=mask)
    # Move the head dimension back (b x lq x n x dv) and
    # concatenate the heads along the last dimension: b x lq x (n*dv)
    q_ = q_.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
    q_ = self.fc_1(q_)
    # Second attention layer: queries come from the first layer's output,
    # while keys and values are re-projected from the original k and v.
    sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)

    # Pre-attention projections (b x lq x (n*dk)), then separate the heads: b x lq x n x dk
    q_ = self.w_qs_2(q_).view(sz_b, len_q, n_head, d_k)
    k_ = self.w_ks_2(k).view(sz_b, len_k, n_head, d_k)
    v_ = self.w_vs_2(v).view(sz_b, len_v, n_head, d_v)
    residual2 = q_

    # Normalize, then transpose for the attention dot product: b x n x lq x dk
    q_ = self.layer_norm_q_2(q_).transpose(1, 2)
    k_ = self.layer_norm_k_2(k_).transpose(1, 2)
    v_ = self.layer_norm_v_2(v_).transpose(1, 2)
    # The mask already carries the head axis added before the first layer,
    # so it is not unsqueezed a second time here.
    q_, attn2 = self.attention_2(q_, k_, v_, mask=mask)
    # Move the head dimension back (b x lq x n x dv) and
    # concatenate the heads along the last dimension: b x lq x (n*dv)
    q_ = q_.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
    q_ = self.fc_2(q_)

    # Return the second-layer output, the concatenated pre-attention residuals of
    # both layers, and the second layer's attention weights (singleton dims removed).
    return q_, th.cat((residual1, residual2), dim=-1), attn2.squeeze()
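
# Note: attention_1 and attention_2 are not defined in this excerpt. Assuming
# they follow the standard scaled dot-product attention used in common
# Transformer implementations, a minimal sketch of the assumed computation over
# the b x n x len x d tensors prepared above is given below. The class name,
# temperature, and dropout argument are illustrative assumptions, not taken
# from this repository.
import torch as th
import torch.nn as nn
import torch.nn.functional as F


class ScaledDotProductAttention(nn.Module):
    """Illustrative sketch of the assumed per-head attention block."""

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature      # typically sqrt(d_k)
        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, q, k, v, mask=None):
        # q: b x n x lq x dk, k: b x n x lk x dk, v: b x n x lv x dv (lk == lv)
        attn = th.matmul(q / self.temperature, k.transpose(2, 3))  # b x n x lq x lk
        if mask is not None:
            # mask is broadcastable over the head axis after the unsqueeze in forward()
            attn = attn.masked_fill(mask == 0, -1e9)
        attn = self.dropout(F.softmax(attn, dim=-1))
        output = th.matmul(attn, v)                                # b x n x lq x dv
        return output, attn

# Under this assumption, attention_1(q_, k_, v_, mask=mask) returns the per-head
# context (b x n x lq x dv) and the attention weights (b x n x lq x lk); the
# second layer's weights are what the method above returns as attn2.squeeze().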