def _prepare_inputs_for_on_device_generation()

in optimum/graphcore/generation/utils.py [0:0]


    def _prepare_inputs_for_on_device_generation(self, model_inputs, on_device_generation_steps, batch_size):
        """
        A model-agnostic version of `prepare_inputs_for_generation` whose main purpose is to duplicate
        decoder inputs by `on_device_generation_steps=inference_device_iterations` and perform additional input validation.
        Since we are duplicating tensors, we restrict duplication to `torch.Tensor` and the exceptional case of
        `encoder_outputs.last_hidden_state`.

        Duplication tiles each tensor along dim 0 with `Tensor.repeat`. A tensor of shape `(batch_size, ...)`
        becomes `(on_device_generation_steps * batch_size, ...)` and holds `on_device_generation_steps`
        consecutive copies of the original batch.

        Args:
            model_inputs (`dict`):
                Decoder inputs keyed by name.
            on_device_generation_steps (`int`):
                Number of copies made of each duplicated tensor.
            batch_size (`int`):
                Expected size of dim 0 of every validated tensor input.

        Returns:
            `dict`: A new dict of adapted inputs.
            - `beam_idx` is always dropped.
            - When `self.encoder_output_buffer_enabled` is set, `attention_mask` and `encoder_outputs` are
              passed through unchanged.
            - `None` and scalar (`int`, `float`, `str`, `bool`) values are passed through unchanged.

        Raises:
            ValueError:
                If a tensor input is zero-dimensional or its dim 0 does not equal `batch_size`. This does not
                apply to `input_features` when `self.cond_encoder_enabled` is set, which is not validated.
                Also raised if a non-tensor `encoder_outputs` is not a `BaseModelOutput`.
            TypeError:
                If an input has any other unsupported type.
        """
        steps = on_device_generation_steps

        def _tile(tensor):
            # Repeat along dim 0 only; every trailing dim gets a repeat factor of 1.
            return tensor.repeat(steps, *([1] * (tensor.ndim - 1)))

        adapted_model_inputs = {}
        for k, v in model_inputs.items():
            if k in ("attention_mask", "encoder_outputs") and self.encoder_output_buffer_enabled:
                # These inputs will be copied onto device via buffers, so we don't need to duplicate them.
                adapted_model_inputs[k] = v
                continue
            if k == "beam_idx":
                # With on-device generation, beam_idx at each step is handled through buffers.
                continue

            if k == "input_features" and self.cond_encoder_enabled:
                # NOTE(review): `input_features` is not checked against `batch_size` here;
                # confirm this is intentional.
                v = _tile(v)
            elif torch.is_tensor(v):
                # A 0-d tensor has no dim 0 to validate. Reject it with the same ValueError instead of
                # letting `v.shape[0]` raise an IndexError.
                if v.ndim == 0 or v.shape[0] != batch_size:
                    raise ValueError(
                        f"Unexpected size in dim 0 for {k}, expected {batch_size}, got shape {tuple(v.shape)}."
                    )
                v = _tile(v)
            elif k == "encoder_outputs":
                v_type = type(v)
                if not isinstance(v, BaseModelOutput):
                    raise ValueError(
                        "Expected `encoder_outputs` to be an instance of `BaseModelOutput`, " f"received {v_type}."
                    )
                # Only `last_hidden_state` is duplicated. The rebuilt output carries no other fields
                # from the original.
                v = v_type(last_hidden_state=_tile(v.last_hidden_state))
            elif v is None or isinstance(v, (int, float, str, bool)):
                # Non-tensor flags/scalars need no duplication.
                pass
            else:
                raise TypeError(
                    f"Unexpected type {type(v)} received for decoder input {k}. On device generation enforces "
                    "stricter input validation to minimise unexpected errors. Improvements are always welcome."
                )
            adapted_model_inputs[k] = v
        return adapted_model_inputs