in data.py [0:0]
def get_sents(self, item, multiple_false_responses=False, use_backtranslate=False):
    """Build the (context, true response, false responses) sample for one item.

    The row is looked up in ``self.data.dialogs["true_response"]`` via
    ``self.indices[item]``. Which false responses are gathered depends on
    ``self.args.corrupt_type``:

    * ``"only_syntax"`` / ``"only_semantics"`` / ``"all"`` / ``"all_context"``:
      one ``self.get_response`` call per suffix in the module-level
      ``only_syntax`` / ``only_semantics`` / ``all_corrupt`` list.
      ``"all_context"`` additionally fetches a corrupted context and the
      ``"corrupt_context"`` response.
    * anything else: up to ``self.args.num_nce`` items taken from
      ``self.get_next_response(corrupt_type, dialog_id, context_id)``.

    Args:
        item: Position into ``self.indices``.
        multiple_false_responses: If true, return every collected false
            response; otherwise return a one-element list holding a
            randomly chosen one.
        use_backtranslate: If true, the true response is replaced by the
            ``"backtranslate"`` response when a uniform draw exceeds 0.5.

    Returns:
        A 4-tuple ``(context, true_response, false_response,
        corrupt_context)`` where ``context`` is the row's context split on
        newlines, ``false_response`` is a list, and ``corrupt_context`` is
        ``None`` unless ``corrupt_type`` is ``"all_context"``.

    Raises:
        IndexError: If no false response was collected and
            ``multiple_false_responses`` is false (``random.choice`` on an
            empty list).
    """
    item_id = self.indices[item]
    row = self.data.dialogs["true_response"].loc[item_id]

    context = row["context"].split("\n")
    true_response = row["true_response"]
    dialog_id = row["dialog_id"]
    context_id = row["context_id"]
    corrupt_context = None

    # The random draw only happens when back-translation is requested, so the
    # RNG stream is untouched otherwise.
    if use_backtranslate and random.uniform(0, 1) > 0.5:
        true_response = self.get_response(
            "backtranslate", dialog_id, context_id, variable=False
        )

    corrupt_type = self.args.corrupt_type

    # Pick the list of corruption suffixes for the "named group" modes; None
    # means corrupt_type is itself the single corruption to sample from.
    if corrupt_type == "only_syntax":
        suffixes = only_syntax
    elif corrupt_type == "only_semantics":
        suffixes = only_semantics
    elif corrupt_type in ("all", "all_context"):
        suffixes = all_corrupt
    else:
        suffixes = None

    if suffixes is not None:
        frs = [
            self.get_response(
                fs, dialog_id, context_id, variable=fs in variable_suffixes
            )
            for fs in suffixes
        ]
        if corrupt_type == "all_context":
            corrupt_context = self.get_context(
                "corrupt_context", dialog_id, context_id, variable=True
            )
            frs.append(
                self.get_response(
                    "corrupt_context", dialog_id, context_id, variable=True
                )
            )
    else:
        frs = []
        candidates = self.get_next_response(corrupt_type, dialog_id, context_id)
        for ri, response in enumerate(candidates):
            # Stop once num_nce responses have been collected.
            if ri > self.args.num_nce - 1:
                break
            frs.append(response)

    if multiple_false_responses:
        false_response = frs
    else:
        false_response = [random.choice(frs)]
    return context, true_response, false_response, corrupt_context