in optimum/amd/brevitas/data_utils.py [0:0]

# Imports required by this function (module-level in the original file);
# `recursive_to_device` is a helper defined elsewhere in the same module.
import random
from typing import Any, Dict, List

import numpy as np
import torch
from torch import nn
from tqdm import tqdm
@torch.no_grad()
def compute_perplexity(model: torch.nn.Module, data: List[Dict], context_length: int, tokenizer: Any, seed: int = 0):
    # Seed every RNG involved so the evaluation is reproducible.
    random.seed(seed)
    np.random.seed(seed)
    torch.random.manual_seed(seed)

    model = model.eval()

    cross_entropy_loss = nn.CrossEntropyLoss()

    nlls = []
    for sample in tqdm(data, desc="Computing perplexity..."):
        sample_length = sample["input_ids"].shape[1]
        # Slide over the sample in strides of 2 * context_length tokens.
        for start_index in range(0, sample_length, context_length * 2):
            # Each window spans at most 2 * context_length tokens.
            end_index = min(start_index + context_length * 2, sample_length) - 1
            # Trailing windows shorter than context_length + 1 tokens have nothing
            # to score, so stop early instead of feeding empty labels to the loss.
            if end_index - start_index + 1 <= context_length:
                break

            subsample = {
                "input_ids": sample["input_ids"][:, start_index : end_index + 1],
                "attention_mask": sample["attention_mask"][:, start_index : end_index + 1],
            }
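
            # The first `context_length` tokens of the window act purely as context;
            # only the remaining tokens are scored below.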

            # When the model is a torch.fx GraphModule, optional inputs are not supported:
            # the model was traced with `past_key_values` inputs, so they must be provided
            # here as well.
            if "past_key_values" in sample and isinstance(model, torch.fx.GraphModule):
                subsample["past_key_values"] = sample["past_key_values"]

            # Force a BOS token at the start of each window, when the tokenizer defines one.
            if tokenizer.bos_token_id is not None:
                subsample["input_ids"][:, 0] = tokenizer.bos_token_id

            # Figure out which device the inputs must live on.
            use_accelerate = hasattr(model, "hf_device_map")
            if not use_accelerate or not hasattr(model, "_hf_hook"):
                device = next(model.parameters()).device
            else:
                # With accelerate hooks `io_same_device=True` by default, so placing the
                # inputs on the execution device keeps the model output there as well.
                device = model._hf_hook.execution_device
            for name, val in subsample.items():
                subsample[name] = recursive_to_device(val, device)

            lm_logits = model(**subsample)["logits"]

            # Logits at position t predict token t + 1, so logits [context_length - 1 : -1]
            # line up with labels [context_length :].
            reference_labels = subsample["input_ids"][:, context_length:]
            shift_logits = lm_logits[:, context_length - 1 : -1]

            # Fuse batch and sequence length dimensions.
            reference_labels = reference_labels.reshape(-1)
            shift_logits = shift_logits.reshape(-1, shift_logits.shape[-1])

            loss = cross_entropy_loss(shift_logits, reference_labels)
            nlls.append(loss)
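
    # Perplexity is the exponential of the average negative log-likelihood. Note that
    # `nn.CrossEntropyLoss` already averages over the tokens of a window (its default
    # reduction is "mean"), so the mean below weights every window equally, regardless
    # of how many tokens each one actually scored.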
    ppl = torch.exp(torch.stack(nlls).mean())

    return ppl
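
A minimal usage sketch, assuming the function above plus a Hugging Face causal LM, and that the module-level helper `recursive_to_device` is available; the model name, the toy text, and `context_length=4` are illustrative assumptions, not part of the original file.

# Usage sketch (hypothetical model and data; any causal LM works).
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "facebook/opt-125m"  # assumption: any small causal LM will do
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# `compute_perplexity` expects each sample to be a dict with batch-first
# "input_ids" and "attention_mask" tensors of shape (1, sequence_length).
texts = ["The quick brown fox jumps over the lazy dog and keeps on running."]
data = [tokenizer(text, return_tensors="pt") for text in texts]

ppl = compute_perplexity(model, data, context_length=4, tokenizer=tokenizer, seed=0)
print(f"Perplexity: {ppl.item():.2f}")

In practice the text samples would come from a held-out corpus (e.g. a few hundred sequences), since a single toy sentence gives a very noisy perplexity estimate.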