def collect_coref_hiddens_layer()

in modeling/model.py


	def collect_coref_hiddens_layer(self, meta_link_batch, hiddens, batch):
		'''
		Collect coreference hidden states from one layer of hiddens (B, T, H).
		For each example, gather the hidden states of the mention/reference token
		spans listed in meta_link_batch, concatenate them along the time axis, and
		pad across the batch. Returns {'hiddens': (B, T', H), 'mask': (B, T')}.
		'''
		B = hiddens.size(0)
		assert len(meta_link_batch) == B
		coref_hiddens_batch = []
		for b_idx, meta_link_example in enumerate(meta_link_batch):
			wordId2tokenId = batch['wordId2tokenId'][b_idx]
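			# wordId2tokenId maps each word index to the list of its subword token indices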
			start_end_list = []
			for start_end_link in meta_link_example:
				assert start_end_link[0]['mention_type'] == 'start'
				assert start_end_link[1]['mention_type'] == 'end'
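				# word-level indices of the mention span ('mention_idx') and the span it points to ('attention_idx')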
				m_word_idx_start = start_end_link[0]['mention_idx']
				m_word_idx_end   = start_end_link[1]['mention_idx']
				r_word_idx_start = start_end_link[0]['attention_idx']
				r_word_idx_end   = start_end_link[1]['attention_idx']

				# map word-level spans to token-level spans: first subword token of the
				# start word, last subword token of the end word
				m_token_idx_start = wordId2tokenId[m_word_idx_start][0]
				m_token_idx_end   = wordId2tokenId[m_word_idx_end][-1]
				r_token_idx_start = wordId2tokenId[r_word_idx_start][0]
				r_token_idx_end   = wordId2tokenId[r_word_idx_end][-1]

				# keep only well-formed spans (start strictly before end); mention spans only if coref_attn_mention is set
				if self.args.coref_attn_mention and m_token_idx_start < m_token_idx_end:
					start_end_list.append((m_token_idx_start, m_token_idx_end))
				if r_token_idx_start < r_token_idx_end:
					start_end_list.append((r_token_idx_start, r_token_idx_end))

			if len(start_end_list) > 0: # has at least one coref link
				start_end_list = sorted(start_end_list, key=lambda x: x[0]) # sort by start_idx
				coref_hiddens_example = []
				# always prepend a zero vector as a null/placeholder position
				coref_hiddens_example.append( torch.zeros(1, 1, self.config.n_embd).to(self.args.device) )
				for start_idx, end_idx in start_end_list:
					coref_hiddens_example.append( hiddens[b_idx, start_idx: end_idx, :].unsqueeze(0) ) # (1, T'', H)
				coref_hiddens_example = torch.cat(coref_hiddens_example, dim=1) # (1, T', H)
			else:
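				# no coref links in this example: fall back to a single zero vector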
				coref_hiddens_example = torch.zeros(1, 1, self.config.n_embd).to(self.args.device) # (1, 1, H)
			coref_hiddens_batch.append(coref_hiddens_example)

		assert len(coref_hiddens_batch) == B
		# pad every example to the batch's max coref length and build the corresponding attention mask
		coref_len_batch = [ x.size(1) for x in coref_hiddens_batch]
		max_coref_len = max(coref_len_batch)
		mask = []
		for b_idx in range(B):
			coref_len = coref_len_batch[b_idx]
			pad_len = max_coref_len - coref_len
			mask.append( [1]*coref_len + [0]*pad_len )
			pad = torch.zeros(1, pad_len, self.config.n_embd).to(self.args.device)
			coref_hiddens_batch[b_idx] = torch.cat([coref_hiddens_batch[b_idx], pad], dim=1)

		coref_hiddens_batch = torch.cat(coref_hiddens_batch, dim=0) # (B, T', H)
		mask = torch.tensor(mask).float().to(self.args.device) # (B, T')
		assert coref_hiddens_batch.size() == (B, max_coref_len, self.config.n_embd)
		assert mask.size() == (B, max_coref_len)
		coref_attn = {'hiddens': coref_hiddens_batch, 'mask': mask}
		return coref_attn
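
How the returned dict is consumed is not shown here. Below is a minimal, self-contained sketch (illustrative only, not code from this repo) of attending from decoder hidden states over the collected coref hiddens, using 'mask' to ignore padded positions; all tensor names and sizes in it are made-up assumptions, but the shapes mirror the return value above.

import torch
import torch.nn.functional as F

B, T, Tc, H = 2, 5, 4, 8                  # hypothetical batch / query / coref / hidden sizes
decoder_hiddens = torch.randn(B, T, H)    # (B, T, H)  queries
coref_hiddens = torch.randn(B, Tc, H)     # (B, T', H) zero-padded coref hiddens, like 'hiddens'
coref_mask = torch.tensor([[1., 1., 1., 0.],
                           [1., 1., 0., 0.]])  # (B, T')  1 = real position, 0 = padding, like 'mask'

scores = torch.matmul(decoder_hiddens, coref_hiddens.transpose(1, 2))      # (B, T, T')
scores = scores.masked_fill(coref_mask[:, None, :] == 0, float('-inf'))    # hide padded positions
weights = F.softmax(scores, dim=-1)                                        # (B, T, T')
context = torch.matmul(weights, coref_hiddens)                             # (B, T, H) coref-aware context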