# optimum/amd/brevitas/data_utils.py
# Copyright 2023 The HuggingFace Team. All rights reserved.
# Licensed under the MIT License.
import random
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import numpy as np
import torch
import torch.nn as nn
from datasets import load_dataset
from tqdm import tqdm
from optimum.utils.normalized_config import NormalizedConfigManager
from transformers import AutoConfig
if TYPE_CHECKING:
from .configuration import BrevitasQuantizationConfig
HIDDEN_SIZE_KEYS = ["d_model", "hidden_size"]
NUM_HEADS_KEYS = ["num_attention_heads"]
@torch.no_grad()
def recursive_to_device(tensor_or_iterable: Union[Iterable, torch.Tensor], device) -> Union[Iterable, torch.Tensor]:
if isinstance(tensor_or_iterable, torch.Tensor):
return tensor_or_iterable.to(device)
elif isinstance(tensor_or_iterable, tuple): # Special handling of tuples, since they are immutable
tmp_list = []
for i in tensor_or_iterable:
tmp_list.append(recursive_to_device(i, device))
return tuple(tmp_list)
    elif isinstance(tensor_or_iterable, Iterable):
        # Mutable, index-addressable iterables (e.g. lists) are updated in place.
        for i, item in enumerate(tensor_or_iterable):
            tensor_or_iterable[i] = recursive_to_device(item, device)
        return tensor_or_iterable
else:
raise ValueError(f"Cannot move {type(tensor_or_iterable)} to {device}")
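
# A minimal sketch of how `recursive_to_device` handles nested structures (illustrative only;
# assumes a CUDA device is available):
#
#     past_key_values = tuple(
#         (torch.zeros(1, 4, 0, 64), torch.zeros(1, 4, 0, 64)) for _ in range(2)
#     )
#     moved = recursive_to_device(past_key_values, "cuda:0")
#     # `moved` mirrors the input: a tuple of per-layer (key, value) pairs, now on cuda:0.
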
@torch.no_grad()
def compute_perplexity(model: torch.nn.Module, data: List[Dict], context_length: int, tokenizer: Any, seed: int = 0):
random.seed(seed)
np.random.seed(seed)
torch.random.manual_seed(seed)
model = model.eval()
cross_entropy_loss = nn.CrossEntropyLoss()
nlls = []
for sample in tqdm(data, desc="Computing perplexity..."):
sample_length = sample["input_ids"].shape[1]
        # Slide a window of 2 * context_length tokens across the sample: the first context_length
        # tokens serve as context, the remaining ones are scored.
        for start_index in range(0, sample_length, context_length * 2):
            end_index = min(start_index + context_length * 2, sample_length - 1)
subsample = {
"input_ids": sample["input_ids"][:, start_index : end_index + 1],
"attention_mask": sample["attention_mask"][:, start_index : end_index + 1],
}
            # When using torch.fx, optional inputs are not allowed, and since the model was traced
            # with past_key_values inputs, we need to provide them here as well.
if "past_key_values" in sample and isinstance(model, torch.fx.GraphModule):
subsample["past_key_values"] = sample["past_key_values"]
# Add BOS token.
if tokenizer.bos_token_id is not None:
subsample["input_ids"][:, 0] = tokenizer.bos_token_id
            use_accelerate = hasattr(model, "hf_device_map")
            if not use_accelerate or not hasattr(model, "_hf_hook"):
                device = next(model.parameters()).device
            else:
                # With accelerate, `io_same_device=True` by default, so placing the inputs on the
                # execution device also ensures the model output ends up on that same device.
                device = model._hf_hook.execution_device
            for name, val in subsample.items():
                subsample[name] = recursive_to_device(val, device)
lm_logits = model(**subsample)["logits"]
reference_labels = subsample["input_ids"][:, context_length:]
shift_logits = lm_logits[:, context_length - 1 : -1]
            # Fuse batch and sequence length dimensions.
            reference_labels = reference_labels.view(-1)
            shift_logits = shift_logits.view(-1, shift_logits.shape[-1])
loss = cross_entropy_loss(shift_logits, reference_labels)
nlls.append(loss)
ppl = torch.exp(torch.stack(nlls).mean())
return ppl
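
# A minimal usage sketch for `compute_perplexity`, assuming a causal LM and a dataset built by
# `get_wikitext2` below (the model name and context length are illustrative choices):
#
#     from transformers import AutoModelForCausalLM, AutoTokenizer
#     tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
#     model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
#     data = get_wikitext2(tokenizer, seqlen=2048, nsamples=16, split="validation")
#     ppl = compute_perplexity(model, data, context_length=1024, tokenizer=tokenizer)
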
def get_wikitext2(
tokenizer: Any, seqlen: int, nsamples: int, split: str = "train", fuse_sequences: bool = True, seed: int = 42
):
random.seed(seed)
if split == "train":
data = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
elif split == "validation":
data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
if fuse_sequences:
data = data.shuffle(seed=seed)
        # Only tokenize the first 100,000 characters; the full wikitext2 corpus is too big.
tokenized_data = tokenizer("\n\n".join(data["text"])[:100000], return_tensors="pt")
dataset = []
for _ in range(nsamples):
i = random.randint(0, tokenized_data.input_ids.shape[1] - seqlen - 1)
j = i + seqlen
inp = tokenized_data.input_ids[:, i:j]
attention_mask = torch.ones((1, seqlen), dtype=torch.int64)
dataset.append({"input_ids": inp, "attention_mask": attention_mask})
else:
dataset = []
with tqdm(total=nsamples) as pbar:
while len(dataset) < nsamples:
data_index = random.randint(0, len(data) - 1)
enc = tokenizer(data[data_index]["text"], return_tensors="pt")
if enc["input_ids"].shape[1] < seqlen:
continue
start_idx = random.randint(0, enc["input_ids"].shape[1] - seqlen)
end_idx = start_idx + seqlen - 1
attention_mask = torch.ones((1, seqlen), dtype=torch.int64)
input_ids = enc["input_ids"][:, start_idx : end_idx + 1]
# Add BOS token.
if tokenizer.bos_token_id is not None:
input_ids[:, 0] = tokenizer.bos_token_id
dataset.append({"input_ids": input_ids, "attention_mask": attention_mask})
pbar.update(1)
return dataset
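
# Each sample returned by `get_wikitext2` (and `get_c4` below) is a dict with a batch dimension
# of 1, e.g. for seqlen=2048:
#
#     {"input_ids": LongTensor of shape (1, 2048), "attention_mask": LongTensor of ones, shape (1, 2048)}
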
def get_c4(
tokenizer: Any, seqlen: int, nsamples: int, split: str = "train", fuse_sequences: bool = True, seed: int = 42
):
random.seed(seed)
if split == "train":
data = load_dataset("allenai/c4", split="train", data_files={"train": "en/c4-train.00000-of-01024.json.gz"})
elif split == "validation":
data = load_dataset(
"allenai/c4",
split="validation",
data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
)
if fuse_sequences:
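        # Note: slicing a `datasets.Dataset` returns a plain dict of columns (not a `Dataset`),
        # which is enough here since only `data["text"]` is accessed afterwards.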
data = data.shuffle(seed=seed)[:10000] # c4 is too big.
full_text = "\n\n".join(data["text"])
tokenized_data = tokenizer(full_text, return_tensors="pt")
dataset = []
for _ in range(nsamples):
i = random.randint(0, tokenized_data.input_ids.shape[1] - seqlen - 1)
j = i + seqlen
inp = tokenized_data.input_ids[:, i:j]
attention_mask = torch.ones((1, seqlen), dtype=torch.int64)
dataset.append({"input_ids": inp, "attention_mask": attention_mask})
else:
dataset = []
with tqdm(total=nsamples) as pbar:
while len(dataset) < nsamples:
data_index = random.randint(0, len(data) - 1)
enc = tokenizer(data[data_index]["text"], return_tensors="pt")
if enc["input_ids"].shape[1] < seqlen:
continue
start_idx = random.randint(0, enc["input_ids"].shape[1] - seqlen)
end_idx = start_idx + seqlen - 1
attention_mask = torch.ones((1, seqlen), dtype=torch.int64)
input_ids = enc["input_ids"][:, start_idx : end_idx + 1]
# Add BOS token.
                if tokenizer.bos_token_id is not None:
                    input_ids[:, 0] = tokenizer.bos_token_id
dataset.append({"input_ids": input_ids, "attention_mask": attention_mask})
pbar.update(1)
return dataset
class DatasetToDevice(torch.utils.data.Dataset):
def __init__(self, data: List, device: Optional[Union[str, torch.device]]):
super().__init__()
self.data = data
self.device = device
def __getitem__(self, idx):
if self.device is not None:
return {name: recursive_to_device(val, self.device) for name, val in self.data[idx].items()}
else:
return self.data[idx]
def __len__(self):
return len(self.data)
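
# `DatasetToDevice` moves each sample to the target device lazily, on access (illustrative):
#
#     ds = DatasetToDevice([{"input_ids": torch.ones(1, 8, dtype=torch.int64)}], device="cpu")
#     sample = ds[0]  # tensors are moved via `recursive_to_device` at lookup time
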
def get_dataset_for_model(
model_name_or_path: str,
qconfig: "BrevitasQuantizationConfig",
dataset_name: str,
tokenizer: Any,
nsamples: int = 128,
seqlen: int = 2048,
seed: int = 0,
split: str = "train",
fuse_sequences: bool = True,
device: Optional[Union[str, torch.device]] = None,
):
"""
Get a dataset.
Args:
model_name_or_path (`str`):
A local folder containing the model or the model hosted on the Hugging Face Hub.
dataset_name (`str`):
Dataset name. Available options are `['wikitext2', 'c4']`.
tokenizer (`Any`):
Tokenizer of the model
nsamples (`int`, defaults to `128`):
Number of samples
seqlen (`int`, defaults to `2048`):
The sequence length of the model
seed (`int`, defaults to `0`):
Seed
split (`str`, defaults to `train`):
Split of the dataset. Can be either "train" or "validation"
Returns:
`List[Dict[str,torch.LongTensor]]`: The tokenized dataset.
"""
random.seed(seed)
np.random.seed(seed)
torch.random.manual_seed(seed)
get_dataset_map = {
"wikitext2": get_wikitext2,
"c4": get_c4,
}
    if split not in ["train", "validation"]:
        raise ValueError(f"The split needs to be either 'train' or 'validation' but found {split}")
    if dataset_name not in get_dataset_map:
        raise ValueError(f"Expected `dataset_name` to be one of {list(get_dataset_map.keys())} but found {dataset_name}")
get_dataset_fn = get_dataset_map[dataset_name]
data = get_dataset_fn(
tokenizer=tokenizer, nsamples=nsamples, seqlen=seqlen, split=split, fuse_sequences=fuse_sequences, seed=seed
)
    # In case the dataset is to be used with a torch.fx GraphModule, we need to add empty past_key_values inputs to the dataset.
if qconfig.requires_fx_graph():
config = AutoConfig.from_pretrained(model_name_or_path)
normalized_config_class = NormalizedConfigManager.get_normalized_config_class(config.model_type)
normalized_config = normalized_config_class(config)
num_heads = normalized_config.num_attention_heads
if hasattr(normalized_config, "num_key_value_heads"):
num_kv_heads = normalized_config.num_key_value_heads
else:
num_kv_heads = num_heads
head_dim = normalized_config.hidden_size // num_heads
num_layers = normalized_config.num_layers
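        # Each layer contributes a (key, value) pair of shape (batch_size, num_kv_heads, seq_len, head_dim);
        # seq_len=0 yields an empty cache that still satisfies the traced graph's input signature.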
for sample in data:
sample["past_key_values"] = tuple(
(
torch.zeros(1, num_kv_heads, 0, head_dim, device=sample["input_ids"].device),
torch.zeros(1, num_kv_heads, 0, head_dim, device=sample["input_ids"].device),
)
for _ in range(num_layers)
)
data = DatasetToDevice(data, device=device)
return data
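
# End-to-end sketch of the intended flow, assuming a default-constructed quantization config from
# `.configuration` (model name and sizes are illustrative, not prescriptive):
#
#     from transformers import AutoModelForCausalLM, AutoTokenizer
#     from optimum.amd.brevitas.configuration import BrevitasQuantizationConfig
#
#     model_id = "facebook/opt-125m"
#     tokenizer = AutoTokenizer.from_pretrained(model_id)
#     model = AutoModelForCausalLM.from_pretrained(model_id)
#     qconfig = BrevitasQuantizationConfig()
#     calibration_data = get_dataset_for_model(
#         model_id, qconfig=qconfig, dataset_name="c4", tokenizer=tokenizer,
#         nsamples=128, seqlen=2048, split="train",
#     )
#     ppl = compute_perplexity(model, calibration_data, context_length=1024, tokenizer=tokenizer)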