in optimum/neuron/generation/utils.py
def new_fwd(*args, **kwargs):
    # Pad input to max length
    cur_len = None
    input_ids_string = "decoder_input_ids" if is_encoder_decoder else "input_ids"
    if input_ids_string in kwargs:
        current_input_ids = kwargs[input_ids_string]
        batch_size, cur_len = current_input_ids.shape
        num_padding_values = generation_config.max_length - cur_len
        kwargs[input_ids_string] = _pad_input_ids_for_general_sampling(
            current_input_ids, num_padding_values, generation_config.pad_token_id
        )
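        # Padding to the static max_length keeps tensor shapes constant across
        # decoding steps, avoiding Neuron/XLA graph recompilation per sequence length.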
        # For decoder-only models, pad the attention mask in addition to the prompts
        if "attention_mask" in kwargs and not is_encoder_decoder and num_padding_values > 0:
            kwargs["attention_mask"] = torch.cat(
                [
                    kwargs["attention_mask"],
                    torch.zeros((batch_size, num_padding_values))
                    .long()
                    .to(kwargs["attention_mask"].device),
                ],
                1,
            )
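            # Padded positions carry attention_mask == 0, so the model ignores them.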
            # Create position_ids on the fly for batch generation
            if "position_ids" in set(inspect.signature(current_fwd).parameters.keys()):
                position_ids = kwargs["attention_mask"].long().cumsum(-1) - 1
                position_ids.masked_fill_(kwargs["attention_mask"] == 0, 1)
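                # Masked-out positions get a dummy position id of 1, the same
                # convention transformers uses in prepare_inputs_for_generation.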
kwargs["position_ids"] = position_ids
    # Move inputs to the target device
    _move_dict_args_to_device(kwargs, main_device)
    # Normalize positional args into keyword args, then run the original forward
    kwargs = args_and_kwargs_to_kwargs_only(current_fwd, args, kwargs)
    outputs = current_fwd(**kwargs)
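    # Under NxD tensor parallelism the LM head output is sharded over the vocab
    # axis, so each rank may hold only vocab_size / tp_size logits until gathered.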
    # Gather outputs if NxD tensor parallelism is applied and the output logits have not been gathered.
    if (
        is_neuronx_distributed_available()
        and parallel_state.model_parallel_is_initialized()
        and parallel_state.get_tensor_model_parallel_size() > 1
        and outputs["logits"].shape[-1] != vocab_size
    ):
        outputs["logits"] = xm.all_gather(
            outputs["logits"],
            dim=-1,
            groups=parallel_state.get_tensor_model_parallel_group(as_list=True),
        )
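    # mark_step flushes the pending lazy-tensor graph so the forward pass (and the
    # gather above) actually execute on device before the outputs are read back.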
    xm.mark_step()
    # Move outputs back to CPU
    _move_dict_args_to_device(outputs, to_device)
    # Trim the padded logits back to the original sequence length and cast to the output dtype
    outputs["logits"] = outputs["logits"][:, :cur_len, ...].to(output_dtype)
    return outputs
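
For reference, _pad_input_ids_for_general_sampling is only used here to right-pad the prompt up to the static max_length. Below is a minimal sketch consistent with the call site (the signature is taken from the call above; the body is an assumption, not the actual implementation):

import torch

def _pad_input_ids_for_general_sampling(input_ids, num_padding_values, pad_token_id):
    # Hypothetical reconstruction: right-pad the prompt with pad_token_id so the
    # sequence length matches the static max_length the graph was compiled for.
    if num_padding_values <= 0:
        return input_ids
    padding = torch.full(
        (input_ids.shape[0], num_padding_values),
        pad_token_id,
        dtype=input_ids.dtype,
        device=input_ids.device,
    )
    return torch.cat([input_ids, padding], dim=1)

Right padding is the only layout consistent with the slicing at the end of new_fwd, which keeps the first cur_len logit positions.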