in metaicl/data.py [0:0]
# Requires (at module level): `import numpy as np`, `import torch`, and
# `prepro_sentence_pair_single`, which packs a (prompt, target) token-id pair
# into fixed-length input_ids / attention_mask / token_type_ids lists.
def _tensorize_for_training(self, train_data):
    for dp in train_data:
        assert type(dp) == dict, ("Each example should be a dictionary", dp)
        assert "input" in dp and "output" in dp, \
            ("Training example should contain input and output", dp)

    # each datapoint: passage, question, options, output
    bos_token_id = self.tokenizer.bos_token_id
    eos_token_id = self.tokenizer.eos_token_id

    input_ids, attention_mask, token_type_ids = [], [], []
    n_answers = []  # not populated in this method
    if self.use_demonstrations:
        # Tokenize every example twice: once formatted as the first
        # demonstration in a concatenation, once as a later one.
        first_tokenized = []
        nonfirst_tokenized = []
        for dp in train_data:
            first_tokenized.append(self._prepro_each_datapoint(
                dp, is_first=True, is_training=True))
            nonfirst_tokenized.append(self._prepro_each_datapoint(
                dp, is_first=False, is_training=True))

        N = 1  # number of demonstration sets sampled per training example

        def _draw_random(tot, n, exclude_indices):
            # Recursively sample n distinct indices from range(tot), skipping
            # exclude_indices (assumes more than n candidates remain).
            r = np.random.choice([i for i in range(tot) if i not in exclude_indices])
            if n == 1:
                return [r]
            return [r] + _draw_random(tot, n - 1, exclude_indices | {r})
        for dp_idx, dp in enumerate(train_data):
            for _ in range(N):
                # Sample k demonstrations, excluding the current example.
                demon_indices = _draw_random(len(train_data), self.k, {dp_idx})
                inputs = []
                for demon_idx, index in enumerate(demon_indices):
                    # Each demonstration contributes its input followed by its output.
                    if demon_idx == 0:
                        inputs += first_tokenized[index][0] + first_tokenized[index][1]
                    else:
                        inputs += nonfirst_tokenized[index][0] + nonfirst_tokenized[index][1]
                    assert index != dp_idx
                # The current example supplies the final prompt; its output is the target.
                inputs += nonfirst_tokenized[dp_idx][0]
                outputs = nonfirst_tokenized[dp_idx][1]
                encoded = prepro_sentence_pair_single(
                    inputs, outputs, self.max_length, bos_token_id, eos_token_id,
                    allow_truncation=True)
                input_ids.append(encoded[0])
                attention_mask.append(encoded[1])
                token_type_ids.append(encoded[2])
    else:
        # No demonstrations: encode each example on its own.
        for dp in train_data:
            inputs, outputs = self._prepro_each_datapoint(
                dp, is_first=True, is_training=True)
            encoded = prepro_sentence_pair_single(
                inputs, outputs, self.max_length, bos_token_id, eos_token_id)
            input_ids.append(encoded[0])
            attention_mask.append(encoded[1])
            token_type_ids.append(encoded[2])
    return dict(input_ids=torch.LongTensor(input_ids),
                attention_mask=torch.LongTensor(attention_mask),
                token_type_ids=torch.LongTensor(token_type_ids))
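
The helper prepro_sentence_pair_single is not shown in this excerpt. Based on how it is consumed above, a minimal sketch of its assumed behavior follows: prepend/append the special tokens, left-truncate the prompt when allowed, right-pad to max_length, and mark the target span via token_type_ids. The names and the exact padding scheme are assumptions, not the repository's definitive implementation.

def prepro_sentence_pair_single(ids1, ids2, max_length,
                                bos_token_id, eos_token_id,
                                allow_truncation=False):
    # Sketch: ids1 is the (demonstrations +) prompt, ids2 the target output.
    if bos_token_id is not None:
        ids1 = [bos_token_id] + ids1
    if eos_token_id is not None:
        ids2 = ids2 + [eos_token_id]
    if allow_truncation and len(ids1) + len(ids2) > max_length:
        # Drop tokens from the left of the prompt so the target stays intact.
        ids1 = ids1[len(ids1) + len(ids2) - max_length:]
    n_pad = max_length - len(ids1) - len(ids2)
    assert n_pad >= 0, (max_length, len(ids1), len(ids2))
    input_ids = ids1 + ids2 + [0] * n_pad
    attention_mask = [1] * (len(ids1) + len(ids2)) + [0] * n_pad
    # token_type_ids flag the target tokens (where the training loss applies).
    token_type_ids = [0] * len(ids1) + [1] * len(ids2) + [0] * n_pad
    return input_ids, attention_mask, token_type_ids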
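
For context, a hypothetical call site. The setup is illustrative and not from the repo: `data` is assumed to be an instance of the class this method belongs to, already configured with a tokenizer, use_demonstrations=True, k=2, and max_length; the example dicts follow the input/output schema asserted at the top of the method.

train_data = [
    {"input": "Review: a gripping thriller. Sentiment?", "output": "positive"},
    {"input": "Review: two dull hours. Sentiment?", "output": "negative"},
    {"input": "Review: sharp and funny. Sentiment?", "output": "positive"},
]
batch = data._tensorize_for_training(train_data)
# Each tensor has shape (len(train_data) * N, max_length).
print(batch["input_ids"].shape)

Note that with use_demonstrations=True, len(train_data) must exceed self.k, since _draw_random needs k candidate indices after excluding the current example.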