in modeling/model.py [0:0]
def proc_coref_output(batch, token_pred, attentions, token_ids, tokenizer, args, config):
    '''
    Post-process the model output for one example: recover the start/end
    mention labels for the current utterance and the words each mention
    refers to, yielding a list of coreference links.
    '''
    assert isinstance(attentions, tuple)
    tokenId2wordId = batch['tokenId2wordId'][0]
    # token index of current utterance
    curr_start_token_idx = batch['curr_start_token_idx'][0]
    curr_end_token_idx = batch['curr_end_token_idx'][0]
    curr_utt_token_len = curr_end_token_idx - curr_start_token_idx
    # word index of current utterance
    curr_utt_word = batch['curr_utt'][0]
    curr_utt_word_len = len(curr_utt_word.split())
    curr_start_word_idx = tokenId2wordId[curr_start_token_idx]
    # keep the predictions from the start of the current utterance onward;
    # the assert checks that the utterance runs to the end of the sequence
    token_pred = token_pred[0][curr_start_token_idx:].tolist()
    assert len(token_pred) == curr_utt_token_len
    whole_input = batch['whole_input'][0].split()
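    # rebuild human-readable tokens from the ids; 'Ġ' is the subword space
    # marker of GPT-2/RoBERTa-style BPE tokenizers. recon_input is not used
    # below and appears to be kept for debugging alignment with whole_input.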
    recon_input = tokenizer.convert_ids_to_tokens(token_ids[0].tolist())
    recon_input = [token.replace('Ġ', '') for token in recon_input]
    mention = False
    word_pred = [-1] * curr_utt_word_len
    links = []
    for local_token_idx, step_pred in enumerate(token_pred):
        global_token_idx = local_token_idx + curr_start_token_idx  # token index in the whole input sequence
        # map the mention prediction back to the word sequence
        global_word_idx = tokenId2wordId[global_token_idx]
        local_word_idx = global_word_idx - curr_start_word_idx
        word_pred[local_word_idx] = step_pred
        # emit links in the same format as the input data
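        # label scheme inferred from the checks below: 1 marks the first token
        # of a mention, 2 marks its last token, anything else is outside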
        if not mention and step_pred == 1.:
            mention = True
            ref_start_word_idx, ref_start_token_idx, ref_start_dist = get_ref_word_idx(global_token_idx, attentions, tokenId2wordId, args, config)
            _start = {'mention_type': 'start', 'mention_idx': global_word_idx, 'mention_word': whole_input[global_word_idx],
                      'attention_idx': ref_start_word_idx, 'attention_word': whole_input[ref_start_word_idx]}
        if mention and step_pred == 2.:
            mention = False
            ref_end_word_idx, ref_end_token_idx, ref_end_dist = get_ref_word_idx(global_token_idx, attentions, tokenId2wordId, args, config)
            _end = {'mention_type': 'end', 'mention_idx': global_word_idx, 'mention_word': whole_input[global_word_idx],
                    'attention_idx': ref_end_word_idx, 'attention_word': whole_input[ref_end_word_idx]}
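            # get_valid_ref (defined elsewhere in this module) presumably
            # reconciles the start/end references in place, since its return
            # value is discarded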
            get_valid_ref(_start, _end, ref_start_dist, ref_end_dist, tokenId2wordId, whole_input)
            links.append([_start, _end])
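    # every word in the current utterance must have received a prediction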
    assert -1 not in word_pred
    return [links]
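

# ---------------------------------------------------------------------------
# Hedged sketch (not part of the original file). get_ref_word_idx is defined
# elsewhere in this module; based on its call sites above, it plausibly picks
# the input position a mention token attends to most and maps it back to a
# word index. Everything below is an illustration under assumptions: the
# layer/head choice (last layer, heads averaged) and the meaning of the
# returned "dist" (here, the winning attention score) may differ from the
# actual implementation, which presumably also consults args/config.
# ---------------------------------------------------------------------------
import torch

def _sketch_get_ref_word_idx(query_token_idx, attentions, tokenId2wordId, args, config):
    # attentions: tuple with one tensor per layer, each of shape
    # (batch, num_heads, seq_len, seq_len), e.g. from a HuggingFace model
    # called with output_attentions=True
    attn = attentions[-1][0].mean(dim=0)   # (seq_len, seq_len), heads averaged
    scores = attn[query_token_idx].clone()
    scores[query_token_idx] = 0.0          # a mention should not refer to itself
    ref_token_idx = int(torch.argmax(scores).item())
    ref_word_idx = tokenId2wordId[ref_token_idx]
    ref_dist = scores[ref_token_idx].item()
    return ref_word_idx, ref_token_idx, ref_dist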