def forward()

in optimum/graphcore/generation/on_device_generation.py


    def forward(self, input_ids: torch.Tensor, absolute_step: torch.Tensor, **kwargs) -> torch.Tensor:
        # Workaround for the generic slice assignment self.running_sequences[:, absolute_step] = tokens,
        # which cannot be expressed directly in a graph compiled for the IPU.
        assert input_ids.shape[-1] == 1
        input_ids_mask = self.input_ids_mask.to(input_ids.device)
        padded_input_ids = torch.nn.functional.pad(input_ids, (0, self.max_length - 1))
        padded_input_ids = self._unflatten_beam_dim(padded_input_ids, self.num_beams).int()

        # Since buffers are constrained to float, including those that hold integer token ids,
        # we do most of the processing in int and cast back just before writing.
        sequences = self.sequences.int()
        running_sequences = input_ids_mask * padded_input_ids + (1 - input_ids_mask) * self.running_sequences.int()
        is_finished = self.is_finished.int()

        # 0. Check for termination of previous step
        not_max_length_yet = absolute_step < self.max_length

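        # Worst finished score per batch item; LARGE_NEGATIVE_CONST when no beam has finished yet.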
        worst_finished_score = torch.where(
            torch.any(is_finished, axis=1),
            torch.amin(self.log_probs, axis=1),
            torch.ones(self.batch_size) * LARGE_NEGATIVE_CONST,
        )
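        # Best score a running beam could still achieve. With early_stopping == "never" and a
        # positive length penalty, a longer sequence always scores better, so bound with
        # max_length; otherwise use the current length.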
        if self.early_stopping == "never" and self.length_penalty > 0.0:
            best_running_score = self.running_log_probs[:, 0] / (self.max_length**self.length_penalty)
        else:
            best_running_score = self.running_log_probs[:, 0] / (absolute_step**self.length_penalty)
        improvement_still_possible = torch.any(best_running_score > worst_finished_score)

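        # early_stopping can be True, False or "never"; only True stops the search as soon as
        # every beam has finished, hence the identity check against True.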
        still_open_beam = ~(torch.all(is_finished) & (self.early_stopping is True))

        continue_search = not_max_length_yet & still_open_beam & improvement_still_possible

        # 1. Return best beam for each batch and beam indices from previous step
        # Account for the edge-case where there are no finished sequences for a
        # particular batch item. If so, return running sequences for that batch item.
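        # NOTE: despite its name, none_finished is True where at least one beam has finished.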
        none_finished = torch.any(is_finished, axis=1)
        return_sequences = torch.where(none_finished[:, None, None], sequences, running_sequences)
        return_sequences = return_sequences[:, 0]

        # 2. Get logits
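        # Feed only the most recently written token (position absolute_step - 1) to the decoder.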
        model_input_ids = torch.index_select(running_sequences, 2, absolute_step - 1)
        model_input_ids = self._flatten_beam_dim(model_input_ids, self.num_beams)

        logits = self.model(decoder_input_ids=model_input_ids, **kwargs)
        if hasattr(logits, "logits"):
            logits = logits.logits
        logits = logits.squeeze(1).float()

        # 3. Compute log probs
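        # Cumulative score of every (beam, token) continuation, flattened to num_beams * vocab_size
        # per batch item so a single top-k can rank all continuations jointly.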
        vocab_size = logits.shape[-1]
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
        log_probs = self.logits_processor(running_sequences, log_probs, absolute_step=absolute_step)
        log_probs = self._unflatten_beam_dim(log_probs, self.num_beams)
        log_probs = log_probs + self.running_log_probs.unsqueeze(-1)
        log_probs = log_probs.view(self.batch_size, self.num_beams * vocab_size)

        # 4. Retrieve top-2*K
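        # Keep 2 * num_beams candidates: even if num_beams of them just emitted EOS, enough
        # running (unfinished) beams survive the selection in step 6.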
        beams_to_keep = 2 * self.num_beams
        topk_log_probs, topk_indices = torch.topk(log_probs, k=beams_to_keep)
        topk_beam_indices = torch.div(topk_indices, vocab_size).int()
        topk_running_sequences = self._gather_beams(
            running_sequences, topk_beam_indices, self.batch_size, self.num_beams, beams_to_keep
        )
        topk_ids = topk_indices % vocab_size
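        # On-device equivalent of topk_running_sequences[:, :, absolute_step] = topk_ids
        # (see the slice-assignment workaround noted at the top).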
        topk_sequences = poptorch.dynamic_update(
            topk_running_sequences, topk_ids.unsqueeze(-1).int(), 2, absolute_step, 1
        )

        # 5. Check which sequences have ended
        did_topk_just_finish = topk_ids == self.eos_token_id
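        # Push just-finished candidates to LARGE_NEGATIVE_CONST so they drop out of the running set.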
        running_topk_log_probs = topk_log_probs + did_topk_just_finish * LARGE_NEGATIVE_CONST

        # 6. Select the num_beams best running sequences and their scores for the next step
        next_topk_indices = torch.topk(running_topk_log_probs, k=self.num_beams)[1]

        next_running_sequences = self._gather_beams(
            topk_sequences, next_topk_indices, self.batch_size, beams_to_keep, self.num_beams
        )
        next_running_log_probs = self._gather_beams(
            running_topk_log_probs, next_topk_indices, self.batch_size, beams_to_keep, self.num_beams
        )

        # 7. Process the top-k log probs
        topk_log_probs = topk_log_probs / (absolute_step**self.length_penalty)
        beams_in_batch_are_full = is_finished.all(axis=-1, keepdims=True).repeat(1, did_topk_just_finish.shape[-1]) & (
            self.early_stopping is True
        )
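        # Only candidates that just emitted EOS may enter the finished set; mask the rest, and,
        # under early_stopping=True, everything from batch items whose beams are already full.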
        add_penalty = ~did_topk_just_finish | beams_in_batch_are_full
        topk_log_probs += add_penalty * LARGE_NEGATIVE_CONST

        # 8. Select the sequences, scores and finished flags to carry over to the next step.
        merged_sequences = torch.cat([sequences, topk_sequences], axis=1)
        merged_log_probs = torch.cat([self.log_probs, topk_log_probs], axis=1)
        merged_is_finished = torch.cat([is_finished, did_topk_just_finish], axis=1)
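        # The merged pool holds 3 * num_beams candidates per batch item: num_beams previously
        # finished plus 2 * num_beams new ones.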
        topk_merged_indices = torch.topk(merged_log_probs, k=self.num_beams)[1]
        next_sequences = self._gather_beams(
            merged_sequences, topk_merged_indices, self.batch_size, 3 * self.num_beams, self.num_beams
        )
        next_log_probs = self._gather_beams(
            merged_log_probs, topk_merged_indices, self.batch_size, 3 * self.num_beams, self.num_beams
        )
        next_is_finished = self._gather_beams(
            merged_is_finished, topk_merged_indices, self.batch_size, 3 * self.num_beams, self.num_beams
        )

        # 9. Determine the top k beam indices from the original set of all beams.
        next_running_indices = self._gather_beams(
            topk_beam_indices, next_topk_indices, self.batch_size, 2 * self.num_beams, self.num_beams
        )

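        # Convert per-batch beam indices into flat indices over the (batch * beams) dimension
        # for _cached_beam_idx, used to reorder per-beam state (e.g. a KV cache) between steps.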
        flat_batch_indices = torch.arange(self.batch_size * self.num_beams) // self.num_beams
        flat_batch_indices = flat_batch_indices * self.num_beams
        beam_indices = self._flatten_beam_dim(next_running_indices, self.num_beams)
        beam_indices = beam_indices + flat_batch_indices

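        # Cast back to float before writing, per the buffer dtype constraint noted above.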
        self.sequences.copy_(next_sequences.float())
        self.running_sequences.copy_(next_running_sequences.float())
        self.log_probs.copy_(next_log_probs)
        self.running_log_probs.copy_(next_running_log_probs)
        self.is_finished.copy_(next_is_finished.float())
        self._cached_beam_idx.copy_(beam_indices.float())

        return OnDeviceGenerationModelOutput(
            generated_tokens=return_sequences,
            done=~continue_search,
        )
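
The masked write at the top of forward() is the standard substitute for tensor-indexed slice assignment. Below is a minimal CPU-only sketch of that identity; the shapes, the position p and the one_hot mask are illustrative assumptions, not the actual layout of input_ids_mask:

import torch

# Minimal sketch, not the library's code: with a 0/1 one-hot mask,
#   out = mask * new + (1 - mask) * old
# reproduces the slice assignment old[p] = value without tensor indexing.
max_length, p = 6, 3                     # illustrative sizes (assumption)
old = torch.arange(max_length)           # stands in for a row of running_sequences
value = torch.tensor(99)                 # stands in for the incoming token

mask = torch.nn.functional.one_hot(torch.tensor(p), max_length)  # 1 only at position p
new = mask * value                       # token placed at position p, zeros elsewhere

out = mask * new + (1 - mask) * old

expected = old.clone()
expected[p] = value
assert torch.equal(out, expected)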