in optimum/graphcore/generation/on_device_generation.py [0:0]
def forward(self, input_ids: torch.Tensor, absolute_step: torch.Tensor, **kwargs) -> torch.Tensor:
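# One step of on-device beam search. All state (sequences, running_sequences, log_probs,
# running_log_probs, is_finished) appears to live in fixed-shape buffers so the whole loop can
# stay on the IPU; the structure of steps 0-9 below follows a static-shape beam search with
# separate "running" (still growing) and "finished" beam sets, similar to the Flax beam search
# in transformers.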
# Workaround for generic slice assignment (self.running_sequences[:, absolute_step] = tokens),
# which is not supported here; the new tokens are merged in via input_ids_mask instead.
assert input_ids.shape[-1] == 1
input_ids_mask = self.input_ids_mask.to(input_ids.device)
padded_input_ids = torch.nn.functional.pad(input_ids, (0, self.max_length - 1))
padded_input_ids = self._unflatten_beam_dim(padded_input_ids, self.num_beams).int()
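# Pad the single new token column out to max_length so it matches the shape of running_sequences,
# then reshape from (batch * num_beams, max_length) to (batch, num_beams, max_length).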
# Since we are constrained to keeping buffers in float, including the ones that hold integer
# token ids, we do most of the processing as int and cast back to float just before writing.
sequences = self.sequences.int()
running_sequences = input_ids_mask * padded_input_ids + (1 - input_ids_mask) * self.running_sequences.int()
is_finished = self.is_finished.int()
# 0. Check for termination of previous step
not_max_length_yet = absolute_step < self.max_length
worst_finished_score = torch.where(
torch.any(is_finished, axis=1),
torch.amin(self.log_probs, axis=1),
torch.ones(self.batch_size) * LARGE_NEGATIVE_CONST,
)
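# worst_finished_score: per batch item, the lowest score among already finished beams, or a
# large negative constant if no beam has finished yet (so any running beam can still improve on it).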
if self.early_stopping == "never" and self.length_penalty > 0.0:
best_running_score = self.running_log_probs[:, 0] / (self.max_length**self.length_penalty)
else:
best_running_score = self.running_log_probs[:, 0] / (absolute_step**self.length_penalty)
improvement_still_possible = torch.any(best_running_score > worst_finished_score)
still_open_beam = ~(torch.all(is_finished) & (self.early_stopping is True))
continue_search = not_max_length_yet & still_open_beam & improvement_still_possible
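# The search continues while max_length has not been reached, not every beam is finished
# (when early_stopping is True), and the best running beam can still beat the worst finished one.
# ~continue_search is returned as `done` below so the host loop knows when to stop.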
# 1. Return the best beam for each batch item (and the beam indices from the previous step)
# Account for the edge-case where there are no finished sequences for a
# particular batch item. If so, return running sequences for that batch item.
none_finished = torch.any(is_finished, axis=1)
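# Note: despite its name, none_finished is True for batch items where at least one beam has finished.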
return_sequences = torch.where(none_finished[:, None, None], sequences, running_sequences)
return_sequences = return_sequences[:, 0]
# 2. Get logits
model_input_ids = torch.index_select(running_sequences, 2, absolute_step - 1)
model_input_ids = self._flatten_beam_dim(model_input_ids, self.num_beams)
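# Only the token at the previous position (absolute_step - 1) is fed to the model; earlier
# positions are expected to be covered by the decoder's on-device KV cache.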
logits = self.model(decoder_input_ids=model_input_ids, **kwargs)
if hasattr(logits, "logits"):
logits = logits.logits
logits = logits.squeeze(1).float()
# 3. Compute log probs
vocab_size = logits.shape[-1]
log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
log_probs = self.logits_processor(running_sequences, log_probs, absolute_step=absolute_step)
log_probs = self._unflatten_beam_dim(log_probs, self.num_beams)
log_probs = log_probs + self.running_log_probs.unsqueeze(-1)
log_probs = log_probs.view(self.batch_size, self.num_beams * vocab_size)
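# Flatten the beam and vocab dimensions so that top-k below ranks all (beam, token)
# continuations of a batch item jointly.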
# 4. Retrieve the top 2*K candidates per batch item
beams_to_keep = 2 * self.num_beams
topk_log_probs, topk_indices = torch.topk(log_probs, k=beams_to_keep)
topk_beam_indices = torch.div(topk_indices, vocab_size).int()
topk_running_sequences = self._gather_beams(
running_sequences, topk_beam_indices, self.batch_size, self.num_beams, beams_to_keep
)
topk_ids = topk_indices % vocab_size
topk_sequences = poptorch.dynamic_update(
topk_running_sequences, topk_ids.unsqueeze(-1).int(), 2, absolute_step, 1
)
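# poptorch.dynamic_update writes the chosen token ids at index absolute_step along dim 2,
# i.e. the on-device equivalent of topk_running_sequences[:, :, absolute_step] = topk_ids.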
# 5. Check which sequences have ended
did_topk_just_finish = topk_ids == self.eos_token_id
running_topk_log_probs = topk_log_probs + did_topk_just_finish * LARGE_NEGATIVE_CONST
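# Candidates that just emitted EOS are pushed out of the running set with a large negative
# score; they only compete for a place in the finished set below.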
# 6. Get the running sequences and scores for the next step
next_topk_indices = torch.topk(running_topk_log_probs, k=self.num_beams)[1]
next_running_sequences = self._gather_beams(
topk_sequences, next_topk_indices, self.batch_size, beams_to_keep, self.num_beams
)
next_running_log_probs = self._gather_beams(
running_topk_log_probs, next_topk_indices, self.batch_size, beams_to_keep, self.num_beams
)
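# Of the 2*num_beams candidates, keep the num_beams highest-scoring ones that are still running.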
# 7. Process topk logits
topk_log_probs = topk_log_probs / (absolute_step**self.length_penalty)
beams_in_batch_are_full = is_finished.all(axis=-1, keepdims=True).repeat(1, did_topk_just_finish.shape[-1]) & (
self.early_stopping is True
)
add_penalty = ~did_topk_just_finish | beams_in_batch_are_full
topk_log_probs += add_penalty * LARGE_NEGATIVE_CONST
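# Length-penalised scores for the finished set: candidates that did not just finish, or whose
# batch item already has all beams finished (with early stopping), are masked out so they
# cannot displace existing finished beams.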
# 8. Get the scores, sequences, and finished flags for the next step.
merged_sequences = torch.cat([sequences, topk_sequences], axis=1)
merged_log_probs = torch.cat([self.log_probs, topk_log_probs], axis=1)
merged_is_finished = torch.cat([is_finished, did_topk_just_finish], axis=1)
topk_merged_indices = torch.topk(merged_log_probs, k=self.num_beams)[1]
next_sequences = self._gather_beams(
merged_sequences, topk_merged_indices, self.batch_size, 3 * self.num_beams, self.num_beams
)
next_log_probs = self._gather_beams(
merged_log_probs, topk_merged_indices, self.batch_size, 3 * self.num_beams, self.num_beams
)
next_is_finished = self._gather_beams(
merged_is_finished, topk_merged_indices, self.batch_size, 3 * self.num_beams, self.num_beams
)
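# The existing finished beams and the just-finished candidates (num_beams + 2*num_beams in total)
# are merged, and the num_beams best by penalised score are kept.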
# 9. Determine the top k beam indices from the original set of all beams.
next_running_indices = self._gather_beams(
topk_beam_indices, next_topk_indices, self.batch_size, 2 * self.num_beams, self.num_beams
)
flat_batch_indices = torch.arange(self.batch_size * self.num_beams) // self.num_beams
flat_batch_indices = flat_batch_indices * self.num_beams
beam_indices = self._flatten_beam_dim(next_running_indices, self.num_beams)
beam_indices = beam_indices + flat_batch_indices
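# Convert per-batch beam indices into flat indices over the (batch * num_beams) dimension;
# _cached_beam_idx is presumably consumed elsewhere to reorder the on-device KV cache before
# the next step.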
self.sequences.copy_(next_sequences.float())
self.running_sequences.copy_(next_running_sequences.float())
self.log_probs.copy_(next_log_probs)
self.running_log_probs.copy_(next_running_log_probs)
self.is_finished.copy_(next_is_finished.float())
self._cached_beam_idx.copy_(beam_indices.float())
return OnDeviceGenerationModelOutput(
generated_tokens=return_sequences,
done=~continue_search,
)
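# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the file above): the 2 * num_beams candidate
# trick used in steps 4-6, reproduced with plain host-side PyTorch. All names
# below (toy_vocab_size, step_log_probs, ...) are invented for the example;
# only the selection logic mirrors the on-device code.
import torch

batch_size, num_beams, toy_vocab_size, eos_token_id = 2, 3, 11, 0
running_log_probs = torch.randn(batch_size, num_beams)
step_log_probs = torch.log_softmax(torch.randn(batch_size, num_beams, toy_vocab_size), dim=-1)

# Joint scores over (beam, token) pairs, flattened so top-k ranks them together (cf. steps 3-4).
scores = (running_log_probs.unsqueeze(-1) + step_log_probs).view(batch_size, -1)

# Keep 2 * num_beams candidates: even if num_beams of them emit EOS and move to the
# finished set, at least num_beams running candidates remain (cf. step 4).
topk_scores, topk_indices = torch.topk(scores, k=2 * num_beams)
topk_beams = torch.div(topk_indices, toy_vocab_size, rounding_mode="floor")
topk_tokens = topk_indices % toy_vocab_size

# Candidates that just emitted EOS drop out of the running set via a large negative score (cf. step 5).
just_finished = topk_tokens == eos_token_id
running_scores = topk_scores + just_finished * -1e9
next_running = torch.topk(running_scores, k=num_beams).indices

# Beams carried over to the next step (cf. step 6).
print(torch.gather(topk_beams, 1, next_running))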