in modeling/model.py [0:0]
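# Note on the expected inputs (inferred from the accesses inside the method below; the
# concrete values are illustrative only, not taken from the repository):
#   - meta_link_batch: one list of coref links per example; each link is a (start, end)
#     pair of dicts, e.g.
#       [{'mention_type': 'start', 'mention_idx': 7, 'attention_idx': 2},
#        {'mention_type': 'end',   'mention_idx': 9, 'attention_idx': 3}]
#     where 'mention_idx' / 'attention_idx' are word-level positions of the mention and
#     of the referent it attends to.
#   - batch['wordId2tokenId'][b_idx]: maps each word index to the sub-word token indices
#     that word was split into, e.g. {0: [0], 1: [1, 2], 2: [3]} (illustrative).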
def collect_coref_hiddens_layer(self, meta_link_batch, hiddens, batch):
    '''
    Collect coreference hidden states from the hidden states of one layer, `hiddens` (B, T, H).

    For each example, gather the hidden states of the token spans given by the predicted
    coreference links: the referent span of every link, plus the mention span when
    `self.args.coref_attn_mention` is set. Only well-formed spans (start before end) are
    kept; examples without any valid span get a single zero vector.

    Returns a dict with:
        'hiddens': (B, T', H) coreference hidden states, padded across the batch
        'mask':    (B, T')    1 for real positions, 0 for padding
    '''
    B = hiddens.size(0)
    assert len(meta_link_batch) == B
    coref_hiddens_batch = []
    for b_idx, meta_link_example in enumerate(meta_link_batch):
        # map from word index to the list of sub-word token indices of that word
        wordId2tokenId = batch['wordId2tokenId'][b_idx]
        start_end_list = []
        for start_end_link in meta_link_example:
            assert start_end_link[0]['mention_type'] == 'start'
            assert start_end_link[1]['mention_type'] == 'end'
            # word-level boundaries of the mention and of the referent it points to
            m_word_idx_start = start_end_link[0]['mention_idx']
            m_word_idx_end = start_end_link[1]['mention_idx']
            r_word_idx_start = start_end_link[0]['attention_idx']
            r_word_idx_end = start_end_link[1]['attention_idx']
            # convert word-level boundaries to token-level boundaries:
            # first sub-token of the start word, last sub-token of the end word
            m_token_idx_start = wordId2tokenId[m_word_idx_start][0]
            m_token_idx_end = wordId2tokenId[m_word_idx_end][-1]
            r_token_idx_start = wordId2tokenId[r_word_idx_start][0]
            r_token_idx_end = wordId2tokenId[r_word_idx_end][-1]
            # only keep well-formed predicted spans (token start strictly before token end)
            if self.args.coref_attn_mention and m_token_idx_start < m_token_idx_end:
                start_end_list.append((m_token_idx_start, m_token_idx_end))
            if r_token_idx_start < r_token_idx_end:
                start_end_list.append((r_token_idx_start, r_token_idx_end))
        if len(start_end_list) > 0:  # at least one valid coref span
            start_end_list = sorted(start_end_list, key=lambda x: x[0])  # sort by start_idx
            coref_hiddens_example = []
            # always prepend one zero vector so each example has at least one slot to attend to
            coref_hiddens_example.append(torch.zeros(1, 1, self.config.n_embd).to(self.args.device))
            for start_idx, end_idx in start_end_list:
                coref_hiddens_example.append(hiddens[b_idx, start_idx: end_idx, :].unsqueeze(0))  # (1, T'', H)
            coref_hiddens_example = torch.cat(coref_hiddens_example, dim=1)  # (1, T', H)
        else:
            # no valid span in this example: fall back to a single zero vector
            coref_hiddens_example = torch.zeros(1, 1, self.config.n_embd).to(self.args.device)  # (1, 1, H)
        coref_hiddens_batch.append(coref_hiddens_example)
    assert len(coref_hiddens_batch) == B

    # padding
    coref_len_batch = [x.size(1) for x in coref_hiddens_batch]
    max_coref_len = max(coref_len_batch)
    mask = []
    for b_idx in range(B):
        coref_len = coref_len_batch[b_idx]
        pad_len = max_coref_len - coref_len
        mask.append([1] * coref_len + [0] * pad_len)
        coref_hiddens_batch[b_idx] = torch.cat(
            [coref_hiddens_batch[b_idx], torch.zeros(1, pad_len, self.config.n_embd).to(self.args.device)],
            dim=1,
        )
    coref_hiddens_batch = torch.cat(coref_hiddens_batch, dim=0)  # (B, T', H)
    mask = torch.tensor(mask).float().to(self.args.device)  # (B, T')
    assert coref_hiddens_batch.size() == (B, max_coref_len, self.config.n_embd)
    assert mask.size() == (B, max_coref_len)
    coref_attn = {'hiddens': coref_hiddens_batch, 'mask': mask}
    return coref_attn
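

# A minimal, hypothetical sketch of how the returned dict could be consumed: attend from a
# layer's hidden states over the collected coref hiddens, using 'mask' to ignore padded
# positions. `attend_over_coref_hiddens` and `decoder_hiddens` are illustrative names that
# do not exist in this repository; the model's actual downstream use may differ.

import torch  # torch is presumably already imported at module level; repeated here for self-containment
import torch.nn.functional as F


def attend_over_coref_hiddens(decoder_hiddens, coref_attn):
    '''
    decoder_hiddens: (B, T, H) queries
    coref_attn:      {'hiddens': (B, T', H), 'mask': (B, T')} as returned by
                     collect_coref_hiddens_layer
    returns:         (B, T, H) context vectors pooled from the coref hiddens
    '''
    keys = values = coref_attn['hiddens']                       # (B, T', H)
    mask = coref_attn['mask']                                   # (B, T'), 1 = real, 0 = pad
    scores = torch.bmm(decoder_hiddens, keys.transpose(1, 2))   # (B, T, T')
    scores = scores / keys.size(-1) ** 0.5                      # scaled dot-product scores
    # mask out padded coref positions before the softmax; the prepended zero vector
    # guarantees at least one unmasked position per example
    scores = scores.masked_fill(mask.unsqueeze(1) == 0, float('-inf'))
    weights = F.softmax(scores, dim=-1)                         # (B, T, T')
    return torch.bmm(weights, values)                           # (B, T, H)


# Illustrative call (shapes only):
#   coref_attn = self.collect_coref_hiddens_layer(meta_link_batch, hiddens, batch)
#   context = attend_over_coref_hiddens(hiddens, coref_attn)    # (B, T, H)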