in optimum/graphcore/generation/on_device_generation.py [0:0]
def forward(self, **kwargs):
    """Run a single generation step through ``self.generation_strategy``.

    Args:
        **kwargs: Must contain ``decoder_input_ids`` (looked up first) or
            ``input_ids``; the chosen tensor's last dimension must be 1.
            An optional ``generation_step`` entry, when not ``None``, is
            copied into ``self._generation_step`` before the step runs.
            All remaining entries are forwarded unchanged to
            ``self.generation_strategy``.

    Returns:
        The ``OnDeviceGenerationModelOutput`` produced by
        ``self.generation_strategy``.

    Raises:
        ValueError: If neither ``decoder_input_ids`` nor ``input_ids`` is
            provided, or if the provided ids have a last dimension > 1.
        TypeError: If the generation strategy returns anything other than
            an ``OnDeviceGenerationModelOutput``.
    """
    # Prefer `decoder_input_ids`; `input_ids` is only popped when the former
    # is absent or None. The `else` branch runs when neither yields a value.
    for input_ids_key in ("decoder_input_ids", "input_ids"):
        input_ids = kwargs.pop(input_ids_key, None)
        if input_ids is not None:
            break
    else:
        raise ValueError(
            "The on device generation model was called with kwargs that are missing both `decoder_input_ids` "
            "and `input_ids`. Please provide one of these as inputs (default is `decoder_input_ids`)."
        )

    if input_ids.shape[-1] > 1:
        raise ValueError("Context length (input_ids.shape[-1]) > 1 is not supported yet.")

    # The parentheses are required: `:=` binds looser than `is not`, so
    # without them `generation_step` would be bound to a bool and that bool
    # would be copied into the step counter instead of the caller's value.
    if (generation_step := kwargs.pop("generation_step", None)) is not None:
        self._generation_step.copy_(generation_step)

    # Computed before the wrap-around below, i.e. from the un-wrapped step.
    absolute_step = self._generation_step + self.context_length

    # Make sure generation_step does not go out of bounds.
    self._generation_step.copy_(self._generation_step % self.max_length)

    # Reset on-device state buffers when starting generation anew.
    begin_new_generation = (self._generation_step == 0).int()
    self.generation_strategy.reset_state(begin_new_generation)
    self._reset_generation_step(begin_new_generation)

    outputs = self.generation_strategy(input_ids, absolute_step, **kwargs)
    if not isinstance(outputs, OnDeviceGenerationModelOutput):
        raise TypeError(
            f"Unexpected type {type(outputs)} returned from {self.generation_strategy.__class__.__name__}."
        )

    # Advance the counter in place for the next call.
    self._generation_step.copy_(self._generation_step + 1)
    return outputs