in clutrr/template_mturk.py [0:0]
def _proof_edge_order(row):
    """Return the edges of ``row`` in the order used to number its entities.

    If ``row['proof_state']`` evaluates to a non-empty sequence of
    single-key dicts ``{lhs_edge: [rhs_edges...]}``, the proof is replayed.
    Each step's ``lhs`` edge, if present at the top level of the working
    list, is replaced in place by the step's right-hand side. Otherwise the
    right-hand side is appended. Nested lists are then flattened one level.

    If the proof is empty, the single edge ``(query[0], '', query[-1])``
    built from ``row['query']`` is returned instead.
    """
    # NOTE(review): ``eval`` executes arbitrary code. It is only acceptable
    # if the AMT CSVs are trusted project artifacts; consider
    # ``ast.literal_eval`` if these columns only ever hold literals.
    proof = eval(row['proof_state'])
    if not proof:
        # No proof state: fall back to the order given by the query.
        query = eval(row['query'])
        return [(query[0], '', query[-1])]

    built = []
    for step in proof:
        lhs = next(iter(step))
        rhs = step[lhs]
        # NOTE(review): only the top level of ``built`` is searched. An edge
        # already folded into a nested list is not found, so its expansion is
        # appended at the end instead. Confirm this is intended for
        # multi-level proofs.
        if lhs in built:
            built[built.index(lhs)] = rhs
        else:
            built.extend(rhs)

    edges = []
    for item in built:
        if isinstance(item, list):
            edges.extend(item)
        else:
            edges.append(item)
    return edges


def extract_placeholder(df):
    """Replace entity names in AMT paraphrases with ordered placeholders.

    Given the AMT annotated datasets, extract the placeholders. Each entity
    is replaced by ``ENT_<k>_<gender>``, where ``k`` is the entity's position
    of first appearance along the proof state (or the query when the proof
    state is empty). That order matters: for example, to replace a proof
    state (2,3),(3,4) the entities must be numbered consistently.

    For the paper, we provide the set of cleaned train and test splits for
    the placeholders. See `Clutrr.setup()` for download locations.

    :param df: DataFrame of AMT annotations. Each row is read for the
        ``paraphrase``, ``genders`` (``"name:gender,name:gender,..."``),
        ``proof_state`` and, when the proof state is empty, ``query`` columns.
    :return: ``(df, skipped)``. ``df`` is the same frame, mutated in place
        with ``template`` and ``template_gender`` columns. ``skipped`` lists
        the index labels of rows that were not templated.
    :raises ValueError: if a row's proof/query does not mention exactly the
        annotated entities, or an entity name cannot be found in the
        tokenized paraphrase.
    """
    # Row 109 (the Jose - Richard row, approved by mistake) used to be
    # excluded here by hand; it is no longer skipped.
    skipped = []
    for idx, row in df.iterrows():
        if idx in skipped:
            continue
        story = row['paraphrase']
        ents_gender = {dd.split(':')[0]: dd.split(':')[1]
                       for dd in row['genders'].split(',')}
        # Skip a problematic row where two names are very similar and cannot
        # be told apart by fuzzy matching.
        # TODO: remove this from the AMT study as well
        if 'Micheal' in ents_gender and 'Michael' in ents_gender:
            skipped.append(idx)
            continue
        words = word_tokenize(story)

        # Number the entities by first appearance along the ordered edges.
        ent_keys = {}
        ordered_ents = []
        for edge in _proof_edge_order(row):
            for ent in (edge[0], edge[-1]):
                if ent not in ent_keys:
                    ent_keys[ent] = 'ENT_{}_{}'.format(len(ent_keys), ents_gender[ent])
                    ordered_ents.append(ent)
        # Every key in ent_keys came from ents_gender (else KeyError above),
        # so equal lengths imply the same entity set.
        if len(ordered_ents) != len(ents_gender):
            raise ValueError(
                'Row {}: proof/query entities {} do not match annotated '
                'entities {}'.format(idx, ordered_ents, list(ents_gender)))

        for ent in ents_gender:
            # De-duplicate candidates so repeated exact occurrences do not
            # fill get_close_matches' top-n and hide spelling variants.
            matches = difflib.get_close_matches(ent, set(words), cutoff=0.9)
            if not matches:
                raise ValueError(
                    'Row {}: entity {!r} not found in paraphrase: {}'.format(
                        idx, ent, story))
            match_set = set(matches)
            placeholder = ent_keys[ent]
            words = [placeholder if w in match_set else w for w in words]

        df.at[idx, 'template'] = detokenizer.detokenize(words, return_str=True)
        df.at[idx, 'template_gender'] = '-'.join(ents_gender[ent] for ent in ordered_ents)
    print('Skipped', skipped)
    return df, skipped