def proc_coref_output()

in modeling/model.py


def proc_coref_output(batch, token_pred, attentions, token_ids, tokenizer, args, config):
	'''
		Process the model output: extract the start/end indices and the
		corresponding words of each predicted coreference link.
	'''
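	# the code below assumes batch size 1 (every batch field is indexed with [0])
	# and that `attentions` is a tuple of per-layer attention tensors, e.g. as
	# returned by a transformer encoder with output_attentions=True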
	assert isinstance(attentions, tuple)
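	# mapping from (subword) token index to word index in the whole input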
	tokenId2wordId = batch['tokenId2wordId'][0]

	# token index of current utterance
	curr_start_token_idx = batch['curr_start_token_idx'][0]
	curr_end_token_idx = batch['curr_end_token_idx'][0]
	curr_utt_token_len = curr_end_token_idx - curr_start_token_idx

	# word index of current utterance
	curr_utt_word = batch['curr_utt'][0]
	curr_utt_word_len = len(curr_utt_word.split())
	curr_start_word_idx = tokenId2wordId[curr_start_token_idx]
	
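	# keep only the predictions for the current utterance; the assertion below
	# only holds if the current utterance is the final segment of the input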
	token_pred = token_pred[0][curr_start_token_idx:].tolist()
	assert len(token_pred) == curr_utt_token_len

	whole_input = batch['whole_input'][0].split()
	recon_input = tokenizer.convert_ids_to_tokens(token_ids[0].tolist())
	recon_input = [token.replace('Ġ', '') for token in recon_input]
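	# reconstructed token strings ('Ġ' is the BPE word-boundary marker);
	# not used below, seemingly kept for debugging/inspection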

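	# decode the per-token tags: 1 marks a mention start, 2 marks a mention end;
	# at each boundary, get_ref_word_idx resolves the referent position the model
	# attends to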
	mention = False
	word_pred = [-1] * curr_utt_word_len
	links = []
	for local_token_idx, step_pred in enumerate(token_pred):
		global_token_idx = local_token_idx + curr_start_token_idx # token index in the whole input sequence

		# map mention prediction back to word sequence
		global_word_idx = tokenId2wordId[global_token_idx]
		local_word_idx = global_word_idx - curr_start_word_idx
		word_pred[local_word_idx] = step_pred

		# build links in the same format as the input data
		if not mention and step_pred == 1.:
			mention = True
			ref_start_word_idx, ref_start_token_idx, ref_start_dist = get_ref_word_idx(global_token_idx, attentions, tokenId2wordId, args, config)
			_start = {'mention_type': 'start', 'mention_idx': global_word_idx, 'mention_word': whole_input[global_word_idx],
						'attention_idx': ref_start_word_idx, 'attention_word': whole_input[ref_start_word_idx]}
		if mention and step_pred == 2.:
			mention = False
			ref_end_word_idx, ref_end_token_idx, ref_end_dist = get_ref_word_idx(global_token_idx, attentions, tokenId2wordId, args, config)
			_end = {'mention_type': 'end', 'mention_idx': global_word_idx, 'mention_word': whole_input[global_word_idx],
						'attention_idx': ref_end_word_idx, 'attention_word': whole_input[ref_end_word_idx]}

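			# get_valid_ref appears to validate/adjust the two reference endpoints
			# in place before the link is recorded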
			get_valid_ref(_start, _end, ref_start_dist, ref_end_dist, tokenId2wordId, whole_input)
			links.append([_start, _end])

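	# every word in the current utterance must have received a prediction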
	assert -1 not in word_pred
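	# wrap in a list to preserve the (size-1) batch dimension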
	return [links]
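
For reference, the tokenId2wordId mapping consumed above can be derived from a fast HuggingFace tokenizer's token_to_word lookup. The sketch below is a minimal, self-contained illustration under that assumption; the 'Ġ' replacement above suggests a GPT-2/RoBERTa-style BPE vocabulary, but the repository may build the mapping differently.

from transformers import RobertaTokenizerFast

tokenizer = RobertaTokenizerFast.from_pretrained('roberta-base')
enc = tokenizer('he said that it works')

# map every subword token index to the index of the word it belongs to;
# token_to_word returns None for special tokens such as <s> and </s>
tokenId2wordId = [enc.token_to_word(i) for i in range(len(enc['input_ids']))]
print(tokenId2wordId)  # [None, 0, 1, 2, 3, 4, None]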