evaluation/model_utils.py

import tqdm
import torch
from transformers import StoppingCriteria, StoppingCriteriaList
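
# NOTE: this excerpt references KeywordsStoppingCriteria without showing its
# definition. The sketch below is a plausible reconstruction, assuming the stop
# sequences are plain strings matched against the decoded continuation; the
# actual implementation in the repository may differ.
class KeywordsStoppingCriteria(StoppingCriteria):
    def __init__(self, keywords, tokenizer):
        self.keywords = keywords or []
        self.tokenizer = tokenizer
        self.start_len = None

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        if self.start_len is None:
            # The criteria is first invoked after one decoding step, so treat
            # everything except the last token as prompt.
            self.start_len = input_ids.shape[1] - 1
        # generate() only halts when the criteria fires for the whole batch, so
        # require every sequence to contain at least one keyword (see the note
        # inside generate_completions below).
        texts = self.tokenizer.batch_decode(input_ids[:, self.start_len:], skip_special_tokens=True)
        return all(any(keyword in text for keyword in self.keywords) for text in texts)
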
def generate_completions(model, tokenizer, prompts, batch_size=1, stop_id_sequences=None, add_special_tokens=True, disable_tqdm=False, **generation_kwargs):
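    """Generate completions for a list of prompts using batched HF generation.

    Note: despite the name, stop_id_sequences is used here as a list of stop
    strings; they drive the stopping criteria during generation and are also
    stripped from the decoded outputs.
    """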
    generations = []
    if not disable_tqdm:
        progress = tqdm.tqdm(total=len(prompts), desc="Generating Completions")
    num_return_sequences = generation_kwargs.get("num_return_sequences", 1)
    for i in range(0, len(prompts), batch_size):
        batch_prompts = prompts[i:i + batch_size]
        tokenized_prompts = tokenizer(batch_prompts, padding="longest", return_tensors="pt", add_special_tokens=add_special_tokens)
        # move inputs onto the model's device (works for cpu, cuda, mps, etc.)
        batch_input_ids = tokenized_prompts.input_ids.to(model.device)
        attention_mask = tokenized_prompts.attention_mask.to(model.device)
        # only build stopping criteria when stop sequences are actually given
        stopping_criteria = None
        if stop_id_sequences:
            stopping_criteria = StoppingCriteriaList([KeywordsStoppingCriteria(stop_id_sequences, tokenizer)])
        batch_outputs = model.generate(
            input_ids=batch_input_ids,
            attention_mask=attention_mask,
            stopping_criteria=stopping_criteria,
            **generation_kwargs
        )
        # The stopping criteria are applied at the batch level, so if some examples
        # in the batch have not stopped yet, the entire batch keeps generating.
        # As a result, some outputs may still contain the stop sequence, which we
        # strip from the decoded text below.

        # Remove the prompt from the output.
        # We re-decode the prompt because we need to make sure the special tokens
        # are treated the same way as in the outputs. We changed our previous way
        # of truncating the output token ids directly because some tokenizers
        # (e.g., llama) won't add a space token before the first token, and that
        # space is important for some tasks (e.g., code completion).
        batch_outputs = tokenizer.batch_decode(batch_outputs, skip_special_tokens=True)
        batch_prompts = tokenizer.batch_decode(batch_input_ids, skip_special_tokens=True)
        # duplicate the prompts to match the number of return sequences
        batch_prompts = [prompt for prompt in batch_prompts for _ in range(num_return_sequences)]
        batch_generations = [
            output[len(prompt):] for prompt, output in zip(batch_prompts, batch_outputs)
        ]
        # remove any remaining stop sequence from the output
        if stop_id_sequences:
            for idx, prediction in enumerate(batch_generations):
                # apply every stop sequence cumulatively, not just the last one
                for stop_sequence in stop_id_sequences:
                    prediction = prediction.split(stop_sequence)[0]
                batch_generations[idx] = prediction
        generations += batch_generations
        if not disable_tqdm:
            # batch_prompts was duplicated num_return_sequences times above,
            # so divide to recover the number of original prompts processed.
            progress.update(len(batch_prompts) // num_return_sequences)
    assert len(generations) == len(prompts) * num_return_sequences, "number of generations should be equal to number of prompts * num_return_sequences"
    return generations
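
# --- illustrative usage ---
# A minimal sketch of how generate_completions might be invoked, assuming a
# small HF checkpoint; the model name, prompts, and stop strings below are
# placeholders for illustration, not values from the original repository.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token
    tokenizer.padding_side = "left"  # decoder-only models should be left-padded for batched generation

    completions = generate_completions(
        model,
        tokenizer,
        prompts=["def fibonacci(n):", "def is_prime(n):"],
        batch_size=2,
        stop_id_sequences=["\n\n"],
        max_new_tokens=64,
        do_sample=False,
    )
    print(completions)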