in low_rank_comparisons/src/gpt2_beam.py [0:0]
def beam(model, data_iter, args):
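    """Beam-search decoding over `data_iter`; rank 0 writes JSON-line predictions.

    Assumes `args` provides: device, beam, length_penalty, eval_len, min_length,
    repetition_penalty, no_repeat_ngram_size, eos_token_id, rank, work_dir and
    output_file, and that `model(...)` returns (logits, past) and accepts
    `past`/`len_past` for incremental decoding.
    """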
    model.eval()
    total_loss = 0.
    start_time = time.time()

    all_predictions = {}
    with torch.no_grad():
        for idx, data in enumerate(data_iter):
            data = {key: value for key, value in data.items()}
            _id = data['id'].to(args.device)
            _query = data['query'].to(args.device)
            _query_len = data['query_len'].to(args.device)

            ## local adaptation start.
            ## local adaptation end.
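            # Per-batch placeholders and beam-search settings (beam width, length penalty).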
            output = None
            score = None

            batch_size = _id.size(0)
            num_beams = args.beam
            length_penalty = args.length_penalty

            _batch = torch.arange(0, _id.size(0), device=args.device, dtype=torch.long)

            past = None
            len_past = None
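            # Tile the query and its length num_beams times so each beam starts from its own copy:
            # (batch_size, seq_len) -> (batch_size * num_beams, seq_len).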
            _query = _query.repeat(1, num_beams).view(batch_size * num_beams, -1)
            _query_len = _query_len.unsqueeze(-1).repeat(1, num_beams).view(-1)

            # scores for each sentence in the beam
            beam_scores = torch.zeros(
                (batch_size, num_beams), dtype=torch.float, device=_query.device
            )
            # At step 0 only beam 0's scores are used (see the i == 0 branch below),
            # so the search is not seeded with num_beams identical hypotheses.
            best_sequence = torch.zeros(
                (batch_size, args.eval_len), dtype=torch.long, device=_query.device
            )
            best_score = {}

            history = None
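            # Decode up to args.eval_len tokens autoregressively, feeding back only the newly
            # selected token each step and reusing the key/value cache in `past` via `len_past`.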
            with torch.no_grad():
                for i in range(0, args.eval_len):
                    if i == 0:
                        logits, past = model(_query)
                        logits = logits[_batch, (_query_len - 1).long(), :]  # batch_size * beam, vocab
                    else:
                        logits, past = model(token_id, past=past, len_past=len_past)
                        logits = logits[:, -1, :]  # batch_size * beam, vocab
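                    # Penalise repetitions, block repeated n-grams and enforce the minimum
                    # length / EOS constraints before the next-token scores are used.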
                    logits = _postprocess_next_token_scores(
                        logits,
                        history,
                        i,
                        batch_size,
                        num_beams,
                        repetition_penalty=args.repetition_penalty,
                        no_repeat_ngram_size=args.no_repeat_ngram_size,
                        min_length=args.min_length,
                        eos_token_id=args.eos_token_id,
                    )
                    softmax_probs = F.softmax(logits, dim=-1)
                    vocab_size = softmax_probs.shape[-1]

                    _logprob = torch.log(softmax_probs)  # batch_size * beam, vocab
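                    # Accumulate log-probabilities: at step 0 only beam 0 is scored; afterwards
                    # each beam's running score is added to every candidate token's log-prob.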
                    if i == 0:
                        next_scores = _logprob.view(batch_size, num_beams, -1)[:, 0, :]  # batch_size, vocab
                    else:
                        next_scores = beam_scores.unsqueeze(-1) + _logprob.view(batch_size, num_beams, -1)
                        next_scores = next_scores.view(batch_size, -1)  # batch_size, beam * vocab
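                    # Keep the num_beams highest-scoring (beam, token) continuations per example.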
                    next_scores, next_tokens = torch.topk(
                        next_scores, num_beams, dim=1, largest=True, sorted=True
                    )  # batch_size, num_beams
                    beam_id = (next_tokens // vocab_size).view(-1)                # batch_size * num_beams
                    token_id = (next_tokens % vocab_size).view(-1).unsqueeze(-1)  # batch_size * num_beams, 1

                    beam_idx = beam_id.view(batch_size, num_beams) + (_batch * num_beams).unsqueeze(-1)
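                    # Reorder the cached key/values so each surviving beam continues from the
                    # state of the parent beam it was expanded from.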
                    # past: per-layer key/value cache, shape (2, batch_size * num_beams, ...)
                    past = _reorder_cache(past, beam_idx.view(-1))

                    beam_scores = next_scores  # batch_size, num_beams
                    len_past = (_query_len + i).long()
                    if history is None:
                        history = token_id.detach()
                    else:
                        history = torch.cat((history[beam_idx.view(-1)], token_id.detach()), dim=1).detach()
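                    # Record beams that have just produced the EOS token as completed candidates.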
                    _add_beam_candidate(
                        best_score, best_sequence, batch_size, num_beams, beam_scores, history,
                        eos_token_id=args.eos_token_id,
                    )
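                # After the final decoding step, add the surviving beams as candidates as well.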
                _add_beam_candidate(
                    best_score, best_sequence, batch_size, num_beams, beam_scores, history
                )
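            # Gather example ids and best sequences from every rank; only rank 0 assembles the predictions.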
            with torch.no_grad():
                _id = distributed_gather(args, _id)
                output = distributed_gather(args, best_sequence)
                #score = distributed_gather(args, score)
                distributed_sync(args)
            if args.rank == 0:
                _id = _id.view(-1).cpu()
                output = output.view(-1, output.shape[-1]).cpu()
                #score = score.view(-1, score.shape[-1]).cpu()
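                # Build a JSON-serialisable record per example, keyed by its dataset id.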
                for _b in range(0, _id.shape[-1]):
                    _i = int(_id[_b].item())
                    all_predictions[_i] = {}
                    all_predictions[_i]['id'] = _i
                    all_predictions[_i]['predict'] = output[_b].tolist()
                    #all_predictions[_i]['score'] = score[_b].tolist()

                if idx % 10 == 0:
                    print('inference samples', idx)
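    # Rank 0 writes one JSON line per example to the configured output file.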
    if args.rank == 0:
        pred_file = os.path.join(args.work_dir, args.output_file)
        print('saving prediction file', pred_file)
        with open(pred_file, 'w') as writer:
            for _i in all_predictions:
                writer.write(json.dumps(all_predictions[_i]) + '\n')