in src/optimum/nvidia/pipelines/text_generation.py [0:0]
def _forward(self, model_inputs, **generate_kwargs):
    """Generate continuations for one pre-processed batch via ``self._runtime``.

    Args:
        model_inputs: Mapping holding ``"input_ids"``, ``"prompt_text"`` and an
            optional ``"attention_mask"``. ``"prompt_text"`` is popped, so the
            caller's mapping is mutated in place.
        **generate_kwargs: Generation options. Only the keys below are read and
            forwarded to ``self._runtime.generate``; a missing key falls back to
            the listed default:

            - ``max_new_tokens`` (``None``)
            - ``min_length`` (``-1``)
            - ``num_beams`` (``1``)
            - ``temperature`` (``1.0``)
            - ``top_k`` (``50``)
            - ``top_p`` (``1.0``)
            - ``repetition_penalty`` (``1.0``)
            - ``length_penalty`` (``1.0``)
            - ``seed`` (``2017``)

            Any other key (including ``prefix_length``) is ignored.

    Returns:
        dict with ``"generated_sequence"`` and ``"lengths"`` exactly as returned
        by ``self._runtime.generate``, plus the ``"input_ids"`` and
        ``"prompt_text"`` taken from ``model_inputs``.
    """
    input_ids = model_inputs["input_ids"]
    # pop() removes the prompt text from model_inputs (in-place mutation).
    prompt_text = model_inputs.pop("prompt_text")
    attention_mask = model_inputs.get("attention_mask")

    # TODO: prefix handling is not implemented. When a prefix is prepended to the
    # prompt, the max/min generation lengths should be extended by the prefix
    # length, without permanently modifying generate_kwargs, since some values may
    # come from the pipeline's initialization. Until then a "prefix_length" entry
    # in generate_kwargs is ignored.
    max_new_tokens = generate_kwargs.pop("max_new_tokens", None)
    min_length = generate_kwargs.pop("min_length", -1)
    num_beams = generate_kwargs.pop("num_beams", 1)
    temperature = generate_kwargs.pop("temperature", 1.0)
    top_k = generate_kwargs.pop("top_k", 50)
    top_p = generate_kwargs.pop("top_p", 1.0)
    repetition_penalty = generate_kwargs.pop("repetition_penalty", 1.0)
    length_penalty = generate_kwargs.pop("length_penalty", 1.0)
    seed = generate_kwargs.pop("seed", 2017)
    # NOTE(review): any keys still left in generate_kwargs are not forwarded.

    # Original annotation: "BS x BEAMS x SL" — presumably the generated sequences
    # are (batch, beams, seq_len); confirm against the runtime implementation.
    generated_sequence, lengths = self._runtime.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        min_length=min_length,
        num_beams=num_beams,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        length_penalty=length_penalty,
        seed=seed,
    )
    return {
        "generated_sequence": generated_sequence,
        "lengths": lengths,
        "input_ids": input_ids,
        "prompt_text": prompt_text,
    }