in modules/SwissArmyTransformer/sat/generation/sampling_strategies/beam_search_strategy.py [0:0]
def forward(self, logits, tokens, mems):
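# one beam-search step: penalize/filter the logits, sample candidate continuations,
# then update the surviving beams, their cumulative scores, and the cached mems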
batch_size, vocab_size = logits.shape
seq_len = tokens.shape[-1]
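# remember the initial context length on the first call; the repetition penalty below only touches tokens generated after it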
if self.context_length is None:
self.context_length = seq_len
logits = logits.float()
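# build a per-position divisor: repetition_penalty at already-generated token ids, 1 elsewhere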
penalty_mat = torch.ones_like(logits)
if seq_len > self.context_length:
penalty_mat.scatter_(1,
tokens[:, self.context_length:], torch.ones_like(tokens[:, self.context_length:]).float() * self.repetition_penalty)
penalty_mat *= self.temperature  # fold the temperature into the same divisor
logits = logits / penalty_mat
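# hard-mask vocabulary slices that are never allowed (-65504 stands in for -inf and stays fp16-safe)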
for invalid_slice in self.invalid_slices:
logits[..., invalid_slice] = -65504
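# forbid end tokens until the minimum target length is reached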
if self.min_tgt_length > seq_len:
for end_token in self.end_tokens:
logits[..., end_token] = -65504
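# ban tokens that would repeat an n-gram this beam has already produced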
if self.ngram > 0 and seq_len > self.ngram:
for i in range(batch_size):
ngram_prefix = tokens[i, -(self.ngram-1):].tolist() # TODO ngram=1
for banned_index in self.cached_beam_ngram_bans[i].get(tuple(ngram_prefix), []):
logits[i, banned_index] = -65504
# temperature was already applied above through penalty_mat, so no extra division here
logits = top_k_logits(logits, self.top_k, self.top_p)
next_token_scores = F.log_softmax(logits, dim=-1) # [batch_size, vocab_size]
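# add each beam's running score (cached_beam_scores may still be a plain scalar on the first step)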
prev_scores = self.cached_beam_scores
if isinstance(self.cached_beam_scores, torch.Tensor):
prev_scores = prev_scores[:, None].expand_as(next_token_scores)
next_token_scores = next_token_scores + prev_scores
next_token_scores = next_token_scores.view(batch_size * vocab_size)
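# sample more candidates than beams so that beams ending in an end token can be replaced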
probs = F.softmax(logits.view(batch_size * vocab_size), dim=0)
next_tokens = torch.multinomial(probs,
num_samples=(max(1,len(self.end_tokens))+1) * self.num_beams) # [2*nb]
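# keep all model-parallel ranks on the same sampled candidates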
if get_model_parallel_world_size() > 1:
torch.distributed.broadcast(next_tokens, get_model_parallel_src_rank(), group=get_model_parallel_group())
next_token_scores = next_token_scores[next_tokens]
next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=0)
next_tokens = next_tokens[_indices]
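# recover the source beam index and the vocabulary id from the flattened [batch_size * vocab_size] index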
next_indices = torch.div(next_tokens, vocab_size, rounding_mode='trunc')
next_tokens = next_tokens % vocab_size
# select out end beams or continue beams
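# mems may still carry a single beam (e.g. right after the initial context); expand to one copy per beam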
if mems.shape[1] < batch_size:
mems = mems.expand(-1, batch_size, -1, -1)
beam_continue = []
scores_continue = []
bans_continue = []
mems_continue = []
end_beams_changed = False
for i in range(len(next_tokens)):
beam = torch.cat((tokens[next_indices[i]], next_tokens[i:i+1]))
if int(next_tokens[i]) in self.end_tokens:
changed = self._add_end_beams(next_token_scores[i], beam)
end_beams_changed = end_beams_changed or changed
elif len(beam_continue) < self.num_beams:
beam_continue.append(beam)
mems_continue.append(mems[:, next_indices[i]])
# update caches
scores_continue.append(next_token_scores[i])
if self.ngram > 0:
bans = self.cached_beam_ngram_bans[next_indices[i]].copy()
ngram_prefix = tuple(tokens[next_indices[i], -(self.ngram-1):].tolist()) # TODO ngram=1
bans[ngram_prefix] = bans.get(ngram_prefix, tuple()) + (next_tokens[i],)
bans_continue.append(bans)
else:
break
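# rebuild the beam state (tokens, mems, scores, n-gram bans) for the next step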
tokens = torch.stack(beam_continue)
mems = torch.stack(mems_continue, dim=1)
self.cached_beam_scores = torch.tensor(scores_continue, device=logits.device)
self.cached_beam_ngram_bans = bans_continue
# check if done; this is a heuristic rather than an official solution
if end_beams_changed:
self.end_beams_unchanged = 0
elif len(self.end_beams) > 0:
self.end_beams_unchanged += 1
if self.end_beams_unchanged >= self.stop_n_iter_unchanged:
self.is_done = True
return tokens, mems