in modeling/utils/coref_utils.py [0:0]
def build_cluster(context, curr_utt, coref_utt, ground_truth=True):
    '''
    Return the cluster index of each token in the context + current utterance.

    Args:
        context: list of previous utterances; each one is tokenised with
            ``str.split()``. The caller's list is not modified.
        curr_utt: the current utterance (str), appended after ``context``.
        coref_utt: whitespace-separated annotation string in which spans are
            marked as ``<M> mention </M> <R> reference </R>``.
        ground_truth: when true, assert that the ``<M>``/``</M>``/``<R>``/``</R>``
            tag counts in ``coref_utt`` are balanced. Pass False for input
            that may be malformed (presumably model predictions).

    Returns:
        A list with one inner list per utterance (context first, current
        utterance last) holding one entry per whitespace token. Every entry
        starts as ``'-'``; entries are then updated through
        ``fill_in_cluster_info`` for each non-empty (mention, reference) pair.

    Raises:
        AssertionError: on wrong argument types, or on unbalanced tags when
            ``ground_truth`` is true.
    '''
    assert isinstance(context, list) and isinstance(curr_utt, str) and isinstance(coref_utt, str)

    # Tokenise into a fresh list so the caller's `context` is left untouched.
    sentences = [s.split() for s in context]
    sentences.append(curr_utt.split())
    coref_tokens = coref_utt.split()
    cluster_info = [['-'] * len(sent) for sent in sentences]

    if ground_truth:
        # Only counts are checked here, not tag order; the parser below
        # therefore still has to tolerate out-of-order tags.
        assert coref_tokens.count('<M>') == coref_tokens.count('</M>')
        assert coref_tokens.count('<R>') == coref_tokens.count('</R>')
        assert coref_tokens.count('<M>') == coref_tokens.count('<R>')

    # Collect (mention, reference) string pairs; a pair is emitted on '</R>'.
    # The accumulators always stay lists and are reset after each emitted
    # pair, so a stray '</R>' (no fresh '<M> ... </M>' before it) produces an
    # empty mention that is skipped below instead of raising or re-using a
    # stale mention.
    in_mention, in_reference = False, False
    mention_tokens, reference_tokens = [], []
    all_MR = []
    for token in coref_tokens:
        if token == '<M>':
            in_mention = True
            mention_tokens = []
        elif token == '</M>':
            in_mention = False
        elif token == '<R>':
            in_reference = True
            reference_tokens = []
        elif token == '</R>':
            in_reference = False
            all_MR.append((' '.join(mention_tokens), ' '.join(reference_tokens)))
            mention_tokens, reference_tokens = [], []
        elif in_mention:
            mention_tokens.append(token)
        elif in_reference:
            reference_tokens.append(token)

    # fill in cluster indexes starting by the short reference since there
    # might be overlapping between references
    cluster_idx_to_spans = {}
    all_MR.sort(key=lambda pair: len(pair[1].split()))
    for mention, reference in all_MR:
        if not mention or not reference:
            continue
        cluster_idx = align_cluster(cluster_idx_to_spans, mention, reference)
        # align cluster index as long as the span can be found in context or
        # current utterance; the mention is only looked up in the current
        # utterance
        fill_in_cluster_info(cluster_idx, cluster_info[-1], sentences[-1], mention.split())
        # consider reference in current utterance as well
        for sent_clusters, sent in zip(cluster_info, sentences):
            fill_in_cluster_info(cluster_idx, sent_clusters, sent, reference.split())
    return cluster_info