def forward()

in optimum/neuron/models/inference/backend/modules/decoder/decoder_wrapper.py


    def forward(self, input_ids, attention_mask, position_ids, seq_ids, sampling_params):
        # Neuron graphs are compiled for int32 inputs and a fixed sequence length,
        # so cast the incoming int64 tensors and pad them to the compiled length.
        input_ids, attention_mask, position_ids, seq_ids = self.convert_int64_to_int32(
            input_ids, attention_mask, position_ids, seq_ids
        )
        input_ids, attention_mask, position_ids, seq_ids = self.pad_to_max_compiled_seq(
            input_ids, attention_mask, position_ids, seq_ids
        )

        input_batch_size = seq_ids.shape[0]

        if input_batch_size > self.neuron_config.max_batch_size:
            raise ValueError(
                f"Input batch size {input_batch_size} exceeds the maximum batch size {self.neuron_config.max_batch_size}."
            )
        elif input_batch_size == self.neuron_config.batch_size:
            # The input matches the compiled batch size exactly: run the graph directly.
            return self._forward(input_ids, attention_mask, position_ids, seq_ids, sampling_params)

        cur_batch = 0
        output_logits = []

        logging.debug(
            f"input batch size {input_batch_size} does not match compiled batch size "
            f"{self.neuron_config.batch_size}; splitting the input into sub-batches"
        )

        args = (input_ids, attention_mask, position_ids, seq_ids, sampling_params)
        while cur_batch < input_batch_size:
            if cur_batch + self.neuron_config.batch_size <= input_batch_size:
                # A full sub-batch is available: slice out exactly batch_size rows.
                logging.debug(f"running forward on batch {cur_batch}:{cur_batch + self.neuron_config.batch_size}")
                outputs = self._forward(*[arg[cur_batch : cur_batch + self.neuron_config.batch_size] for arg in args])
            else:
                # Fewer rows than the compiled batch size remain: pad them up before running.
                logging.debug(
                    f"running forward on batch {cur_batch}:{input_batch_size}, padded up to {self.neuron_config.batch_size}"
                )
                outputs = self._forward_with_pad(*[arg[cur_batch:input_batch_size] for arg in args])

            output_logits.append(outputs)
            cur_batch += self.neuron_config.batch_size

        if self.async_mode:
            # Each sub-batch call returned an async handle; block on all of them
            # here before concatenating the outputs.
            output_logits = [self._get_async_output(ranked_logits) for ranked_logits in output_logits]

        return torch.cat(output_logits, dim=0)
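
The two helpers called at the top are not shown in this excerpt. Minimal sketches, assuming they do nothing more than a dtype cast and right-padding of the sequence dimension (the names, the zero padding value, and the 2-D layout are assumptions, not taken from the source):

    import torch
    import torch.nn.functional as F

    def convert_int64_to_int32_sketch(*tensors):
        # Neuron compiled graphs generally expect int32 rather than int64 inputs.
        return tuple(t.to(torch.int32) if t.dtype == torch.int64 else t for t in tensors)

    def pad_to_max_compiled_seq_sketch(max_seq_len, *tensors):
        # Right-pad the sequence dimension (dim 1) of 2-D tensors to the compiled length.
        return tuple(
            F.pad(t, (0, max_seq_len - t.shape[1])) if t.dim() == 2 else t
            for t in tensors
        )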
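
For intuition, a standalone sketch of the sub-batching arithmetic in the loop above, with hypothetical sizes (a compiled batch size of 2 and an input batch of 5, neither taken from the source). Two full slices run through _forward and the final short slice goes through _forward_with_pad:

    compiled_batch_size = 2  # stand-in for neuron_config.batch_size (assumption)
    input_batch_size = 5     # hypothetical input batch

    cur_batch = 0
    while cur_batch < input_batch_size:
        end = cur_batch + compiled_batch_size
        if end <= input_batch_size:
            print(f"full sub-batch {cur_batch}:{end} -> _forward")
        else:
            print(f"partial sub-batch {cur_batch}:{input_batch_size} -> _forward_with_pad")
        cur_batch += compiled_batch_size
    # prints: 0:2 and 2:4 as full sub-batches, 4:5 as the padded remainder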
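
_forward_with_pad itself is not shown here. A plausible sketch, assuming it zero-pads every tensor along the batch dimension up to the compiled batch size and then discards the logits produced by the padding rows (forward_with_pad_sketch and model_forward are illustrative names, not library API):

    import torch.nn.functional as F

    def forward_with_pad_sketch(model_forward, compiled_batch_size,
                                input_ids, attention_mask, position_ids, seq_ids, sampling_params):
        real = input_ids.shape[0]
        pad = compiled_batch_size - real
        # F.pad fills the last dimensions first, so the final (0, pad) pair lands on
        # dim 0, the batch dimension. The real implementation likely pads seq_ids
        # with unused slot ids rather than zeros; zeros are only illustrative.
        padded = [
            F.pad(t, (0, 0) * (t.dim() - 1) + (0, pad))
            for t in (input_ids, attention_mask, position_ids, seq_ids, sampling_params)
        ]
        logits = model_forward(*padded)
        return logits[:real]  # drop rows that correspond to padding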