in low_rank_comparisons/src/model.py [0:0]
def forward(self, w, r, attn_mask=None, mems=None):
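    # Transformer-XL-style relative multi-head attention (as suggested by the
    # r_w_bias / r_r_bias / _rel_shift machinery used below).
    # Time-major shapes, per the comments in this function:
    #   w:    [qlen x bsz x dim]  hidden states of the current segment
    #   r:    [rlen x 1 x dim]    relative position embeddings, or None
    #   mems: cached states from previous segments, or the cross-attention
    #         context when self.cross_attn is set
    #   attn_mask: [klen x bsz] padding mask or [qlen x klen x 1] causal mask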
    ###################### run through the memory component.
    qlen = w.size(0)
    if r is not None:
        rlen = r.size(0)
    else:
        rlen = 0
    bsz = w.size(1)
    if self.cross_attn:
        if self.pre_cnorm:
            mems = self.c_layer_norm(mems)
        # queries from w [qlen x bsz x dim]; keys/values from the context in mems.
        w_head_q = self.q_net(w)
        w_head_k, w_head_v = torch.chunk(self.kv_net(mems), 2, dim=-1)
        if r is not None:
            r_head_k = self.r_net(r)
            if self.rel_pos == 'full':
                r_head_q = self.r_q_net(r)
    else:
        # self-attention over w [qlen x bsz x dim].
        if mems is not None:
            if mems.dtype != w.dtype:
                mems = mems.half()
            # prepend the cached segment so keys/values cover mems + current tokens.
            cat = torch.cat([mems, w], 0)
            if self.pre_lnorm:
                w_heads = self.qkv_net(self.layer_norm(cat))
            else:
                w_heads = self.qkv_net(cat)
            r_head_k = self.r_net(r)
            if self.rel_pos == 'full':
                r_head_q = self.r_q_net(r)
            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
            # queries are only needed for the current segment.
            w_head_q = w_head_q[-qlen:]
        else:
            if self.pre_lnorm:
                w_heads = self.qkv_net(self.layer_norm(w))
            else:
                w_heads = self.qkv_net(w)
            # position embedding: [rlen x 1 x dim]
            r_head_k = self.r_net(r)
            if self.rel_pos == 'full':
                r_head_q = self.r_q_net(r)
            # q, k, v
            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)

    klen = w_head_k.size(0)  # number of key positions (includes mems/context length when present)
    w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)  # qlen x bsz x n_head x d_head
    w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)  # klen x bsz x n_head x d_head
    w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)  # klen x bsz x n_head x d_head

    if r is not None:
        r_head_k = r_head_k.view(rlen, self.n_head, self.d_head)   # rlen x n_head x d_head
        if self.rel_pos == 'full':
            r_head_q = r_head_q.view(rlen, self.n_head, self.d_head)
        assert rlen == klen or rlen == klen + klen - 1 or rlen == qlen + klen
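    # Interpretation (not stated explicitly in this file): rlen == klen is the usual
    # Transformer-XL layout with one embedding per key offset, while the 2*klen - 1
    # and qlen + klen layouts are extended position ranges handled by the
    # _rel_shift_bias / _rel_shift_trans_bias helpers used below.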
    #### compute attention score
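    # Score decomposition in the style of Transformer-XL (Dai et al., 2019):
    #   AC = (q + r_w_bias) . k    content-to-content plus a global content bias
    #   BD = (q + r_r_bias) . r_k  content-to-position plus a global position bias
    # When self.rel_pos == 'full', an additional position-to-content term
    # BC = r_q . k is added; that extra term appears to be an extension specific to
    # this codebase rather than part of the original Transformer-XL formulation.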
    rw_head_q = w_head_q + self.r_w_bias                                            # qlen x bsz x n_head x d_head
    AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q.float(), w_head_k.float()))     # qlen x klen x bsz x n_head

    if r is not None:
        rr_head_q = w_head_q + self.r_r_bias                                        # qlen x bsz x n_head x d_head
        BD = torch.einsum('ibnd,jnd->ijbn', (rr_head_q.float(), r_head_k.float()))  # qlen x rlen x bsz x n_head
        if rlen == klen:
            BD = self._rel_shift(BD)
        elif rlen == klen + klen - 1:
            assert klen == qlen
            BD = self._rel_shift_bias(BD, klen)
        elif rlen == qlen + klen:
            BD = self._rel_shift_bias(BD, klen)

        if self.rel_pos == 'full':
            # BC: position-to-content term (r_q . k)                                # rlen x klen x bsz x n_head
            BC = torch.einsum('ind,jbnd->ijbn', (r_head_q.float(), w_head_k.float()))
            if rlen == klen:
                BC = self._rel_shift_trans(BC, qlen)
                BC = self._rel_shift(BC)
            elif rlen == klen + klen - 1:
                BC = torch.einsum('ind,jbnd->ijbn', (r_head_q.float(), w_head_k.float()))
                BC = self._rel_shift_trans(BC, qlen)
                BC = self._rel_shift(BC)[:, :klen]
            elif rlen == klen + qlen:
                BC = self._rel_shift_trans_bias(BC, qlen)
            attn_score = AC + BD + BC
        else:
            attn_score = AC + BD   # qlen x klen x bsz x n_head
    else:
        attn_score = AC

    attn_score.mul_(self.scale)
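    # self.scale is applied once, after the AC/BD(/BC) terms are summed. It is set
    # in __init__ elsewhere in this file; presumably 1 / sqrt(d_head) as in standard
    # scaled dot-product attention (an assumption, not visible in this snippet).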
    #### compute attention probability
    if attn_mask is not None and attn_mask.any().item():
        # attn_mask: [klen x bsz] (padding mask used for masked LM).
        if attn_mask.dim() == 2:
            attn_score = attn_score.float().masked_fill(
                attn_mask[None, :, :, None].bool(), -float('inf')).type_as(attn_score)
        # attn_mask: [qlen x klen x 1] (causal mask used for LM).
        elif attn_mask.dim() == 3:
            attn_score = attn_score.float().masked_fill(
                attn_mask[:, :, :, None].bool(), -float('inf')).type_as(attn_score)
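    # Softmax over dim=1 normalizes each query's weights across the klen key
    # positions; any entries filled with -inf above receive zero probability.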
    # [qlen x klen x bsz x n_head]
    attn_prob = F.softmax(attn_score, dim=1)
    attn_prob = self.dropatt(attn_prob)
    #### compute attention vector
    # [qlen x bsz x n_head x d_head]
    attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v.type_as(attn_prob))).type_as(w)
    # [qlen x bsz x n_head * d_head]
    attn_vec = attn_vec.contiguous().view(
        attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)

    ##### linear projection
    attn_out = self.o_net(attn_vec)
    attn_out = self.drop(attn_out)

    if self.pre_lnorm:
        ##### residual connection
        output = w + attn_out
    else:
        ##### residual connection + layer normalization
        output = self.layer_norm(w + attn_out)

    return output
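    # Usage sketch (hypothetical names and shapes; the attention module itself is
    # constructed elsewhere in model.py, so nothing below is code from this repo):
    #
    #   w = torch.randn(qlen, bsz, dim)      # current-segment hidden states
    #   r = pos_emb[:rlen]                   # relative position embeddings
    #   out = attn(w, r, attn_mask=mask, mems=mems)   # -> [qlen x bsz x dim]
    #
    # The _rel_shift* helpers called above realign the position axis of the score
    # tensors so that column j of query row i corresponds to relative offset i - j;
    # their exact implementations live elsewhere in this file.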