def forward()

in optimum/graphcore/generation/on_device_generation.py [0:0]


    def forward(self, **kwargs):
        """Run a single on-device generation step.

        The current token ids are popped from ``kwargs``. ``decoder_input_ids`` is
        tried first, then ``input_ids``. An optional ``generation_step`` kwarg
        overrides the internal step counter.

        After that, the method:
          1. wraps the counter into ``[0, max_length)``;
          2. resets the strategy / counter state when the wrapped counter is 0;
          3. delegates to ``self.generation_strategy``;
          4. advances the counter by one.

        Args:
            **kwargs: Model inputs. Must contain ``decoder_input_ids`` or
                ``input_ids`` whose last dimension is 1. May contain
                ``generation_step``, which is copied into
                ``self._generation_step`` with ``copy_``. Every remaining kwarg is
                forwarded unchanged to the generation strategy.

        Returns:
            OnDeviceGenerationModelOutput: the output produced by the generation
            strategy.

        Raises:
            ValueError: if neither ``decoder_input_ids`` nor ``input_ids`` is
                given, or if ``input_ids.shape[-1] > 1``.
            TypeError: if the generation strategy returns anything other than an
                ``OnDeviceGenerationModelOutput``.
        """
        # Prefer the decoder-style key; fall back to plain ``input_ids``.
        input_ids_key = "decoder_input_ids"
        input_ids = kwargs.pop(input_ids_key, None)
        if input_ids is None:
            input_ids_key = "input_ids"
            input_ids = kwargs.pop(input_ids_key, None)
            if input_ids is None:
                raise ValueError(
                    "The on device generation model was called with kwargs that are missing both `decoder_input_ids` "
                    "and `input_ids`. Please provide one of these as inputs (default is `decoder_input_ids`)."
                )
        if input_ids.shape[-1] > 1:
            raise ValueError("Context length (input_ids.shape[-1]) > 1 is not supported yet.")

        # The parentheses are required. Without them ``:=`` binds the bool result of
        # ``kwargs.pop(...) is not None`` instead of the popped step value.
        if (generation_step := kwargs.pop("generation_step", None)) is not None:
            self._generation_step.copy_(generation_step)

        # NOTE(review): this is computed from the counter *before* the modulo wrap
        # below. Confirm this is intended when the counter reaches max_length.
        absolute_step = self._generation_step + self.context_length

        # Make sure generation_step does not go out of bounds.
        self._generation_step.copy_(self._generation_step % self.max_length)

        # Reset on-device state buffers when starting generation anew.
        # The flag is passed as an int (0/1) rather than a Python branch.
        begin_new_generation = (self._generation_step == 0).int()
        self.generation_strategy.reset_state(begin_new_generation)
        self._reset_generation_step(begin_new_generation)

        outputs = self.generation_strategy(input_ids, absolute_step, **kwargs)
        if not isinstance(outputs, OnDeviceGenerationModelOutput):
            raise TypeError(
                f"Unexpected type {type(outputs)} returned from {self.generation_strategy.__class__.__name__}."
            )

        # Advance to the next step for the following call.
        self._generation_step.copy_(self._generation_step + 1)

        return outputs