def build_cluster()

in modeling/utils/coref_utils.py [0:0]


def build_cluster(context, curr_utt, coref_utt, ground_truth=True):
	'''
		Return the cluster label of each token in the context + current utterance.

		Args:
			context: list of previous utterances (strings; each is split on whitespace).
			curr_utt: the current utterance (str).
			coref_utt: the annotated utterance (str) in which mention/reference pairs
				are marked as `<M> mention </M> ... <R> reference </R>`.
			ground_truth: if True, assert that the tags in `coref_utt` are balanced.
				If False, no checks are made and malformed tag sequences are tolerated:
				a pair closed by `</R>` with an empty mention or reference is dropped
				instead of raising.

		Returns:
			A list with one label list per sentence (every element of `context`,
			then `curr_utt`), holding one label per whitespace token. Labels start
			as '-'; for each mention/reference pair, `fill_in_cluster_info` is called
			with the index returned by `align_cluster` (presumably writing it into
			the label list in place, since its return value is unused here).

		The caller's `context` list is not modified.
	'''
	assert isinstance(context, list) and isinstance(curr_utt, str) and isinstance(coref_utt, str)
	# Build a new list instead of appending to `context`, so the caller's list is
	# untouched; the strings are immutable, so no deep copy is required.
	sentences = [s.split() for s in context] + [curr_utt.split()]
	coref_tokens = coref_utt.split()

	cluster_info = [['-'] * len(tokens) for tokens in sentences]
	if ground_truth:
		assert coref_tokens.count('<M>') == coref_tokens.count('</M>')
		assert coref_tokens.count('<R>') == coref_tokens.count('</R>')
		assert coref_tokens.count('<M>') == coref_tokens.count('<R>')

	is_mention, is_reference = False, False
	# Bind the buffers up front so a sequence that closes a pair it never
	# opened (possible when ground_truth=False) cannot hit an unbound name.
	mention, reference = [], []
	all_MR = []
	for token in coref_tokens:
		if token == '<M>':
			is_mention = True
			mention = []
		elif token == '</M>':
			is_mention = False
		elif token == '<R>':
			is_reference = True
			reference = []
		elif token == '</R>':
			is_reference = False
			all_MR.append((' '.join(mention), ' '.join(reference)))
			# Reset to fresh lists (never rebind to the joined strings) so any
			# later token or tag still operates on a list.
			mention, reference = [], []
		elif is_mention:
			mention.append(token)
		elif is_reference:
			reference.append(token)

	# Fill in cluster indexes starting from the shortest reference, since
	# references may overlap. list.sort is stable, so ties keep their order.
	clusterIdx2spanList = {}
	all_MR.sort(key=lambda pair: len(pair[1].split()))
	for mention_str, reference_str in all_MR:
		# Incomplete pairs (empty mention or empty reference) are skipped.
		if not mention_str or not reference_str:
			continue
		cluster_idx = align_cluster(clusterIdx2spanList, mention_str, reference_str)

		# The mention is only looked up in the current utterance (last sentence).
		fill_in_cluster_info(cluster_idx, cluster_info[-1], sentences[-1], mention_str.split())
		# The reference is looked up in every sentence, including the current utterance.
		for sent_idx, sent in enumerate(sentences):
			fill_in_cluster_info(cluster_idx, cluster_info[sent_idx], sent, reference_str.split())

	return cluster_info