def forward()

in optimum/graphcore/generation/utils.py [0:0]


    def forward(self, t, beam_idx=None, **model_inputs):
        """
        Execute a single generation step on the wrapped `pipelined_model`.

        The step counter `t` is copied into every module registered under `_generation_step`, and the beam
        indices into every module registered under `_beam_idx`, before the model itself is called.

        Args:
            t : (`torch.Tensor(int)`) Tensor with single int representing the current length of the sequence being generated
            beam_idx: (`torch.LongTensor` of shape `(batch_size * num_beams,)`, *optional*):
                Beam indices indicating to which beam the tokens were added, required for reordering the on-device KV cache.
                When omitted, `pipelined_model.generation_strategy._cached_beam_idx` (converted with `.int()`) is used
                instead, provided that attribute exists.
            model_inputs : Regular model_inputs passed to the wrapped model.
        Returns:
            The wrapped model's outputs. A `ModelOutput` that is not an `OnDeviceGenerationModelOutput` is rebuilt
            keeping only its `logits` (intended to be the logits at position `t` only); any other result is
            returned untouched.
        Raises:
            ValueError: if at least one module registered a `_beam_idx` buffer but no beam indices are available.
        """
        registry = self._modules_with_attributes_in_buffers

        # Push the current step into every module that keeps a `_generation_step` buffer.
        for step_module in registry["_generation_step"]:
            step_module._generation_step.copy_(t)

        # Host-side generation supplies `beam_idx` as an argument; on-device generation leaves it
        # cached on the generation strategy, so fall back to that when nothing was passed in.
        if beam_idx is None:
            strategy = getattr(self.pipelined_model, "generation_strategy", None)
            if hasattr(strategy, "_cached_beam_idx"):
                beam_idx = strategy._cached_beam_idx.int()

        # Missing beam indices are only an error if some module actually registered a `_beam_idx` buffer.
        for beam_module in registry["_beam_idx"]:
            if beam_idx is None:
                raise ValueError(
                    "A module registered a `beam_idx` buffer, but the pipelined model is not called with such, "
                    "or the on device beam search did not register `_cached_beam_idx`. For the first case, "
                    "`beam_idx` can be provided to the model via `prepare_inputs_for_generation`."
                )
            beam_module._beam_idx.copy_(beam_idx)

        # Run the decoder with the regular inputs plus the extra kwargs from `_get_buffered_outputs()`.
        buffered_outputs = self._get_buffered_outputs()
        outputs = self.pipelined_model(**model_inputs, **buffered_outputs)

        strip_to_logits = isinstance(outputs, ModelOutput) and not isinstance(outputs, OnDeviceGenerationModelOutput)
        if not strip_to_logits:
            return outputs
        # Regular model outputs are rebuilt with `logits` as their only field.
        return type(outputs)(logits=outputs.logits)