in metaicl/data.py [0:0]
def _prepro_each_datapoint(self, dp, is_first=True, is_training=False, for_demonstrations=False,
                           add_newlines=True):
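    # Formats one datapoint for in-context learning and tokenizes it. Under
    # the "direct" method the model scores the output conditioned on the
    # input; under "channel" it scores the input conditioned on the output,
    # which is why the two return orders below are swapped. (`np` is assumed
    # to be numpy, imported at the top of data.py.)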
    dp = dp.copy()
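
    # With newlines, "\n\n\n" separates consecutive examples and "\n"
    # separates the input from the output within one example.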
    if add_newlines:
        if self.method=="direct":
            if not is_first:
                dp["input"] = "\n\n\n" + dp["input"]
            dp["output"] = "\n" + dp["output"]
            if "options" in dp:
                dp["options"] = ["\n" + opt for opt in dp["options"]]
        elif self.method=="channel":
            if not is_first:
                dp["output"] = "\n\n\n" + dp["output"]
                if "options" in dp:
                    dp["options"] = ["\n\n\n" + opt for opt in dp["options"]]
            dp["input"] = "\n" + dp["input"]
        else:
            raise NotImplementedError()
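
    # Without newlines, a single space plays both separator roles.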
    else:
        if not is_first:
            if self.method=="direct":
                dp["input"] = " " + dp["input"]
            elif self.method=="channel":
                dp["output"] = " " + dp["output"]
                if "options" in dp:
                    dp["options"] = [" "+opt for opt in dp["options"]]
            else:
                raise NotImplementedError()
        if self.method=="direct":
            dp["output"] = " " + dp["output"]
            if "options" in dp:
                dp["options"] = [" " + opt for opt in dp["options"]]
        elif self.method=="channel":
            dp["input"] = " " + dp["input"]
        else:
            raise NotImplementedError()

    input_tokens = self.tokenizer(dp["input"])["input_ids"]
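
    # Training/demonstration path: return a single (input, output) token
    # pair, truncated to fit within max_length_per_example.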
    if is_training or for_demonstrations:
        output_tokens = self.tokenizer(dp["output"])["input_ids"]
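
        # The +2 below presumably reserves room for separator tokens. Two
        # known-long "inst:" tasks are split half-and-half so neither side is
        # truncated away entirely; otherwise the input is truncated, except
        # that "inst:" tasks whose output is longer than the input truncate
        # the output instead.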
if "task" in dp:
if (dp["task"].startswith("inst:piqa") or dp["task"].startswith("inst:yahoo_answers_topics")) and \
len(input_tokens)+len(output_tokens)+2>self.max_length_per_example:
input_tokens = input_tokens[:self.max_length_per_example // 2]
output_tokens = output_tokens[:self.max_length_per_example // 2 - 2]
elif len(input_tokens)>=self.max_length_per_example - 2 - len(output_tokens):
if dp["task"].startswith("inst:") and len(input_tokens)<len(output_tokens):
output_tokens = output_tokens[:self.max_length_per_example - 2 - len(input_tokens)]
else:
input_tokens = input_tokens[:self.max_length_per_example - 2 - len(output_tokens)]
assert len(input_tokens)+len(output_tokens)+2<=self.max_length_per_example, \
(dp.get("task", None), len(input_tokens), len(output_tokens), self.max_length_per_example)
if self.method=="direct":
return input_tokens, output_tokens
elif self.method=="channel":
return output_tokens, input_tokens
else:
raise NotImplementedError()
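
    # Inference path: tokenize every answer option so the caller can score
    # each (input, option) pair; dp["output"] is the gold answer, used only
    # to record the index of the correct option.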
    else:
        assert len(dp["options"])>=2, dp
        assert dp["output"] in dp["options"]
        option_tokens = [self.tokenizer(option)["input_ids"] for option in dp["options"]]
        option_length = np.max([len(option) for option in option_tokens])

        if len(input_tokens)>=self.max_length_per_example - 2 - option_length:
            input_tokens = input_tokens[:self.max_length_per_example - 2 - option_length]
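
        # One copy of the (possibly truncated) input per option; note that
        # option_tokens is then reused to carry the gold option's index.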
        input_tokens = [input_tokens for _ in option_tokens]
        output_tokens = option_tokens
        option_tokens = [dp["options"].index(dp["output"])]

        if self.method=="direct":
            return input_tokens, output_tokens, option_tokens
        elif self.method=="channel":
            return output_tokens, input_tokens, option_tokens
        else:
            raise NotImplementedError()
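
# A minimal usage sketch (hypothetical example; the surrounding MetaICLData
# object and its constructor arguments are not shown here). At inference
# time, under the direct method, the call returns one input copy per option
# plus the gold option's index:
#
#     dp = {"task": "glue-sst2", "input": "A great movie.",
#           "output": "positive", "options": ["negative", "positive"]}
#     inputs, outputs, gold_idx = data._prepro_each_datapoint(
#         dp, is_first=True, is_training=False, add_newlines=True)
#     assert len(inputs) == len(outputs) == 2
#     assert gold_idx == [1]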