def new_fwd()

in optimum/neuron/generation/utils.py [0:0]
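
Wrapper installed around the model's forward pass during generation on Neuron devices. It pads the (decoder) input ids and attention mask up to generation_config.max_length so the compiled graph always sees a static shape, moves the inputs to the Neuron device, runs the wrapped forward, all-gathers the vocab-sharded logits when NxD tensor parallelism is active, then moves the outputs back and truncates the logits to the true sequence length.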


    def new_fwd(*args, **kwargs):
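        # Note: current_fwd, generation_config, is_encoder_decoder, main_device,
        # to_device, vocab_size and output_dtype are captured from the enclosing scope.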
        # Pad inputs to max length: Neuron-compiled graphs require static shapes,
        # so every forward must see sequences of length generation_config.max_length
        cur_len = None
        input_ids_string = "decoder_input_ids" if is_encoder_decoder else "input_ids"
        if input_ids_string in kwargs:
            current_input_ids = kwargs[input_ids_string]
            batch_size, cur_len = current_input_ids.shape
            num_padding_values = generation_config.max_length - cur_len
            kwargs[input_ids_string] = _pad_input_ids_for_general_sampling(
                current_input_ids, num_padding_values, generation_config.pad_token_id
            )

            # For decoder-only models, pad the decoder attention mask in addition to the prompt
            if "attention_mask" in kwargs and not is_encoder_decoder and num_padding_values > 0:
                kwargs["attention_mask"] = torch.cat(
                    [
                        kwargs["attention_mask"],
                        torch.zeros((batch_size, (generation_config.max_length - cur_len)))
                        .long()
                        .to(kwargs["attention_mask"].device),
                    ],
                    1,
                )
                # Create position_ids on the fly for batched generation
                if "position_ids" in set(inspect.signature(current_fwd).parameters.keys()):
                    position_ids = kwargs["attention_mask"].long().cumsum(-1) - 1
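                    # Padded positions get a dummy index of 1; the attention mask
                    # zeroes them out, so the exact value is irrelevant.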
                    position_ids.masked_fill_(kwargs["attention_mask"] == 0, 1)
                    kwargs["position_ids"] = position_ids

        # Move inputs to device
        _move_dict_args_to_device(kwargs, main_device)

        # Forward
        kwargs = args_and_kwargs_to_kwargs_only(current_fwd, args, kwargs)
        outputs = current_fwd(**kwargs)
        # Gather outputs if NxD tensor parallelism is applied and the output logits have not been gathered.
        if (
            is_neuronx_distributed_available()
            and parallel_state.model_parallel_is_initialized()
            and parallel_state.get_tensor_model_parallel_size() > 1
            and outputs["logits"].shape[-1] != vocab_size
        ):
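            # Each tensor-parallel rank holds a vocab shard of the logits; gather
            # along the last dim to reassemble the full vocabulary.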
            outputs["logits"] = xm.all_gather(
                outputs["logits"],
                dim=-1,
                groups=parallel_state.get_tensor_model_parallel_group(as_list=True),
            )
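        # Cut the lazy XLA graph here so the forward actually executes.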
        xm.mark_step()

        # Move outputs back to to_device (typically CPU)
        _move_dict_args_to_device(outputs, to_device)

        # Post-process the output as a function of cur_len: drop the logits for
        # padded positions and cast back to the expected output dtype.
        outputs["logits"] = outputs["logits"][:, :cur_len, ...].to(output_dtype)

        return outputs
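
The padding helper _pad_input_ids_for_general_sampling is defined elsewhere in the same module and is not shown here. A minimal sketch of the behavior implied by the call site above (the helper's name comes from the source; the body below is an assumption):

    import torch

    def _pad_input_ids_for_general_sampling(input_ids, num_padding_values, pad_token_id):
        # Sketch only: right-pad each sequence with pad_token_id so the batch
        # reaches the static max length expected by the compiled graph.
        if num_padding_values <= 0:
            return input_ids
        padding = torch.full(
            (input_ids.shape[0], num_padding_values),
            pad_token_id,
            dtype=input_ids.dtype,
            device=input_ids.device,
        )
        return torch.cat([input_ids, padding], dim=1)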
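
The position_ids computation above follows the standard transformers recipe: a cumulative sum over the attention mask, shifted by one, with padded positions filled with a dummy index (the attention mask keeps them from being attended to, so the value does not matter). A tiny standalone illustration:

    import torch

    attention_mask = torch.tensor([[1, 1, 1, 0, 0]])
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    # position_ids is now tensor([[0, 1, 2, 1, 1]])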