in data.py
def __getitem__(self, item):
"""
Return single instance
:param item:
:return:
"""
X_hat = None
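# NCE training contrasts the true response against multiple sampled
# false responses; every other train_mode uses a single one.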
multiple_false_responses = self.args.train_mode == "nce"
context, true_response, false_responses, corrupt_context = self.get_sents(
item, multiple_false_responses=multiple_false_responses
)
# Tokenize the context, producing one token-id sequence per utterance.
X = [self.data.tokenizer.tokenize(sent) for sent in context]
X = [self.data.tokenizer.convert_tokens_to_ids(sent) for sent in X]
if corrupt_context:
X_hat = [self.data.tokenizer.tokenize(sent) for sent in corrupt_context]
X_hat = [self.data.tokenizer.convert_tokens_to_ids(sent) for sent in X_hat]
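# The true response is a single utterance, so it maps to one id sequence.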
Y = self.data.tokenizer.convert_tokens_to_ids(
self.data.tokenizer.tokenize(true_response)
)
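# get_sents always returns the false responses as a list, even when
# there is only a single negative sample.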
assert isinstance(false_responses, list)
Y_hats = [
self.data.tokenizer.convert_tokens_to_ids(self.data.tokenizer.tokenize(fr))
for fr in false_responses
]
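# For BERT-style input, wrap each id sequence with the tokenizer's
# special tokens (e.g. [CLS] ... [SEP]). With is_transition_fn, every
# utterance keeps its own sequence; otherwise the context is flattened
# into a single sequence before the special tokens are added.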
if self.bert_input:
if self.is_transition_fn:
X = [
self.data.tokenizer.build_inputs_with_special_tokens(sent)
for sent in X
]
if corrupt_context:
X_hat = [
self.data.tokenizer.build_inputs_with_special_tokens(sent)
for sent in X_hat
]
else:
# Flatten the per-utterance sequences into one context sequence
# before adding special tokens.
X = [word for sent in X for word in sent]
X = self.data.tokenizer.build_inputs_with_special_tokens(X)
if corrupt_context:
X_hat = [word for sent in X_hat for word in sent]
X_hat = self.data.tokenizer.build_inputs_with_special_tokens(X_hat)
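# The responses get the same special-token treatment as the context.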
Y = self.data.tokenizer.build_inputs_with_special_tokens(Y)
Y_hats = [
self.data.tokenizer.build_inputs_with_special_tokens(Y_hat)
for Y_hat in Y_hats
]
else:
# Non-BERT input: just flatten the context into a single id sequence.
X = [word for sent in X for word in sent]
if corrupt_context:
X_hat = [word for sent in X_hat for word in sent]
return X, Y, Y_hats, X_hat
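
# Minimal usage sketch (hypothetical names: `DialogDataset`, `args`, and
# `collate_batch` are assumptions, not confirmed by this file). The
# returned sequences have variable length, so batching them needs a
# custom collate_fn that pads, e.g.:
#
#   dataset = DialogDataset(args)
#   X, Y, Y_hats, X_hat = dataset[0]  # a single instance
#   loader = torch.utils.data.DataLoader(
#       dataset, batch_size=32, collate_fn=collate_batch
#   )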