# optimum/habana/distributed/serialization.py
# Copyright 2024 The Foundation Model Stack Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This file has been modified from its original version.
# The original version can be found at https://github.com/foundation-model-stack/foundation-model-stack
import collections
import os
from collections import ChainMap
from collections.abc import Iterable
from pathlib import Path
from typing import Any, Callable, List, Mapping, MutableMapping, Optional, Union
import torch
from .tp import TPModule
__adapters: MutableMapping[str, MutableMapping[str, Callable[[Mapping], Mapping]]] = {}
def register_adapter(
architecture: str,
source: str,
adapter: Callable[[Mapping], Mapping],
):
"""
Registers a state dict adapter to be available to the (de) serialization
API.
Args:
architecture: The name of the model architecture, e.g. 'llama'
source: A label representing the format of the weights to be converted.
E.g. 'hf'
adapter: the class of the adapter. The class must accept one constructor
parameter, which will be a state dict (`OrderedDict`)
"""
sources: MutableMapping[str, Callable[[Mapping], Mapping]] = {}
if architecture in __adapters:
sources = __adapters[architecture]
if source in sources:
raise KeyError(f"Variant {source} already registered for architecture {architecture}")
sources[source] = adapter
__adapters[architecture] = sources
def list_sources(architecture: str):
"""
Lists available sources (attribute formats) of a model architecture.
    E.g. `serialization.list_sources('llama')` -> ['meta', 'fms', 'hf']
Args:
architecture: one of the registered architectures returned by
`models.list_models()`.
"""
if architecture not in __adapters:
return []
return list(__adapters[architecture].keys())
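
# Illustrative registry usage. The "llama"/"hf" names and the _hf_to_fms helper
# below are hypothetical, shown only to sketch how register_adapter and
# list_sources fit together:
#
#     def _hf_to_fms(sd):
#         # rename checkpoint keys into the names the FMS-style model expects
#         return {k.removeprefix("model."): v for k, v in sd.items()}
#
#     register_adapter("llama", "hf", _hf_to_fms)
#     list_sources("llama")  # -> ['hf']
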
def _get_adapter(architecture: str, source: Optional[str]) -> Callable[[Mapping[str, Any]], Mapping[str, Any]]:
if source is None or architecture not in __adapters or source not in __adapters[architecture]:
# if no adapter is registered, assume the attributes are already in
# fms format.
# should we raise an error here instead?
return lambda x: x
else:
return __adapters[architecture][source]
def get_adapted(architecture: str, source: Optional[str], state_dict: Mapping[str, Any]) -> Mapping[str, Any]:
"""
Convert a state dict to FMS format, using an adapter specified by name.
Args:
architecture: one of the architectures from `models.list_models()`.
E.g. llama.
source: A reference to an attribute format
state_dict: the model.state_dict() to be converted/adapted.
"""
# sometimes we only load onto rank 0 so may not have a state_dict here.
if not len(state_dict):
return state_dict
adapter = _get_adapter(architecture, source)
adapted = adapter(state_dict)
return adapted
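
# Sketch of get_adapted, assuming an 'hf' adapter was registered for 'llama'
# as in the example above (tensor shapes and key names are made up):
#
#     hf_sd = {"model.embed_tokens.weight": torch.zeros(8, 4)}
#     fms_sd = get_adapted("llama", "hf", hf_sd)
#     # with no adapter registered (or source=None) the dict passes through as-is
#     assert get_adapted("llama", None, hf_sd) is hf_sd
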
def _get_safetensors_item(key, file: Path, device: torch.device) -> torch.Tensor:
from safetensors import safe_open # type: ignore[import-untyped]
with torch.no_grad():
with safe_open(file, framework="pt", device=str(device)) as model_weights: # type: ignore[attr-defined]
return model_weights.get_tensor(key)
class LazySafetensorsDict(collections.UserDict):
def set_lazy_tensor(self, key, file, device):
super().__setitem__(key, lambda: _get_safetensors_item(key, file, device))
def __getitem__(self, key):
lazy_tensor = super().__getitem__(key)
if callable(lazy_tensor):
lazy_tensor = lazy_tensor()
super().__setitem__(key, lazy_tensor)
return lazy_tensor
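
# Behavioural sketch of LazySafetensorsDict (the file path is a placeholder):
# each key initially maps to a zero-argument closure, and the first __getitem__
# call reads the tensor from the .safetensors file and caches it in place of
# the closure, so repeated accesses hit the cache.
#
#     sd = LazySafetensorsDict()
#     sd.set_lazy_tensor("wte.weight", Path("model.safetensors"), torch.device("cpu"))
#     tensor = sd["wte.weight"]  # loaded from disk on first access only
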
def load_state_dict(
model_path: Union[str, Path],
*,
source: Optional[str] = None,
distributed_strategy: Optional[str] = None,
checkpoint_sharding: Optional[str] = None,
initial_device: torch.device = torch.device("cpu"),
rank: int = 0,
world_size: int = 1,
) -> MutableMapping[str, Any]:
"""
Validates that the file(s) found at a checkpoint path are compatible with
the intended (possibly distributed) use-case, and returns a lazy loading
state dict if possible (some formats may not support that).
If model_path is a directory, it'll try to load models based on the source
(e.g. .bin for HF, .pth for Meta), and, if no source is specified or hasn't
been registered, it'll try .safetensors, .pth, and .bin.
Args:
    model_path: the path to find the weights. If None, an empty dict is returned.
source: If the weights in the state dict didn't come from an FMS model,
`source` specifies which conversion function might be needed.
See `serialization.list_sources(architecture)`
distributed_strategy: the kind of possibly-distributed model in which we
intend to load these weights. E.g. tp, fsdp, None. Used for
validation.
checkpoint_sharding: the sharding format of the checkpoint.
E.g. layer, tp, fsdp.
    initial_device: where the state dict will be loaded if not lazy.
        If meta, an empty dict is returned.
    rank: the rank of the current process, used to select the shard of a
        tp- or fsdp-sharded checkpoint.
    world_size: the total number of processes; for sharded checkpoints it
        must match the number of checkpoint files.
"""
if model_path is None or initial_device.type == "meta":
return {}
if checkpoint_sharding == "fsdp" and distributed_strategy not in ["fsdp", "hsdp"]:
raise ValueError("FSDP checkpoints can only be loaded into an FSDP model")
if checkpoint_sharding == "tp" and distributed_strategy != "tp":
raise ValueError("TP checkpoints can only be loaded into a TP model")
# Before creating the Path object, check if model_path has a glob pattern
if isinstance(model_path, str):
model_path, sep, glob_pattern = model_path.partition("*")
else:
sep = ""
glob_pattern = ""
glob_pattern = sep + glob_pattern
model_path = Path(os.path.expanduser(model_path))
checkpoints = []
if model_path.is_dir():
if glob_pattern != "":
glob_pattern_list = [glob_pattern]
elif source == "meta":
glob_pattern_list = ["*.pth", "*.safetensors"]
elif source == "hf":
glob_pattern_list = ["*.bin", "*.safetensors"]
else:
glob_pattern_list = ["*.safetensors", "*.pth", "*.bin"]
for glob_pattern_possibility in glob_pattern_list:
file_list = list(model_path.glob(glob_pattern_possibility))
if len(file_list) > 0:
checkpoints = sorted(file_list)
break
if model_path.is_file():
checkpoints = [model_path]
# Check if we found some files
assert len(checkpoints) > 0, f"Can't find the requested checkpoint data at {model_path}"
if checkpoint_sharding is not None and checkpoint_sharding != "layer":
assert world_size == len(checkpoints), (
f"Loading a {checkpoint_sharding}-sharded checkpoint with len={len(checkpoints)} but world size is {world_size}"
)
checkpoints = [checkpoints[rank]]
# if there's only one checkpoint for fsdp/hsdp, load it only into rank zero
# and it will be distributed by the FSDP `sync_module_states` parameter
if checkpoint_sharding is None and distributed_strategy in {"hsdp", "fsdp"}:
if rank == 0:
checkpoints = [checkpoints[0]]
else:
return {}
checkpoint_sds = []
if checkpoints[0].suffix == ".safetensors":
for ckp in checkpoints:
checkpoint_sds.append(
_load_safetensors_state_dict(
ckp,
initial_device,
)
)
else:
with torch.no_grad():
checkpoint_sds = [
torch.load(str(ckpt_path), map_location=initial_device, mmap=True) for ckpt_path in checkpoints
]
return ChainMap(*checkpoint_sds)
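
# Typical call, e.g. preparing a tensor-parallel load from a Hugging Face
# checkpoint directory (the path is illustrative; rank/world_size come from
# the caller's process group):
#
#     sd = load_state_dict(
#         "~/checkpoints/llama-7b",
#         source="hf",
#         distributed_strategy="tp",
#         initial_device=torch.device("cpu"),
#         rank=rank,
#         world_size=world_size,
#     )
#
# The result is a ChainMap over one state dict per checkpoint file, each either
# lazily loaded (safetensors) or memory-mapped (torch.load with mmap=True).
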
def _load_safetensors_state_dict(
checkpoint: Path,
device: torch.device,
):
sd = LazySafetensorsDict()
from safetensors import safe_open
with safe_open(checkpoint, framework="pt", device=str(device)) as model_weights: # type: ignore[attr-defined]
sd_keys = list(model_weights.keys())
for key in sd_keys:
sd.set_lazy_tensor(key, checkpoint, device)
return sd
class FusableWeightsMissingError(Exception):
missing_weights: List[str] = []
def __init__(self, missing_weights):
self.missing_weights = missing_weights
super().__init__()
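
# Adapters that fuse several checkpoint tensors into a single model tensor can
# raise this error to request the companion keys they still need. A minimal
# sketch with hypothetical key names:
#
#     def fuse_qkv_adapter(sd):
#         needed = ["attn.q.weight", "attn.k.weight", "attn.v.weight"]
#         missing = [k for k in needed if k not in sd]
#         if missing:
#             raise FusableWeightsMissingError(missing)
#         return {"attn.qkv.weight": torch.cat([sd[k] for k in needed], dim=0)}
#
# load_state_dict_into_model (below) catches the error, pulls the missing keys
# into the partial state dict and calls the adapter again.
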
def load_state_dict_into_model(
model: torch.nn.Module,
state_dict: MutableMapping[str, Any],
architecture: str,
source: str,
distributed_strategy: Optional[str] = None,
checkpoint_sharding: Optional[str] = None,
initial_device: torch.device = torch.device("cpu"),
rank: int = 0,
world_size: int = 0,
) -> None:
"""
    This function loads state_dict into model in the most efficient way possible,
    removing from state_dict every weight that has been loaded into the model
    in order to conserve memory.
Args:
model: The model where the weights are being loaded.
    state_dict: The dictionary with all the weights. If it has been memory-mapped
        (torch.load with mmap=True) or it is an instance of LazySafetensorsDict,
        the weights are loaded lazily from disk.
architecture: the model architecture, e.g. llama. See `models.list_models()`.
source: If the weights in the state dict didn't come from an FMS model,
`source` specifies which conversion function might be needed.
See `serialization.list_sources(architecture)`
distributed_strategy: the kind of possibly-distributed model in which we
intend to load these weights. E.g. tp, fsdp, None. Used for weight
sharding.
checkpoint_sharding: the sharding format of the checkpoint.
E.g. layer, tp, fsdp. Used for weight sharding.
    initial_device: where the weights will be loaded from disk.
    rank: the rank of the current TP process.
    world_size: the total number of TP processes.
"""
# 1. Get the adapter from checkpoint sd to fms sd
adapter = _get_adapter(architecture, source)
# 2. Decide if model needs sharding and how (for now only TP)
needs_tp_sharding = checkpoint_sharding != "tp" and distributed_strategy == "tp"
# 3. Iterate over the weights and load them into the model
used_keys = set()
sd_keys = list(state_dict.keys())
with torch.no_grad():
for key in sd_keys:
if key in used_keys:
continue
used_keys.add(key)
try:
partial_sd = {key: state_dict[key]}
if partial_sd[key].device != initial_device:
partial_sd[key] = partial_sd[key].to(device=initial_device)
fms_partial_sd = adapter(partial_sd)
except FusableWeightsMissingError as e:
for weight in e.missing_weights:
used_keys.add(weight)
partial_sd[weight] = state_dict[weight]
if partial_sd[weight].device != initial_device:
partial_sd[weight] = partial_sd[weight].to(device=initial_device)
fms_partial_sd = adapter(partial_sd)
_load_partial_state_dict(model, fms_partial_sd, needs_tp_sharding, rank, world_size)
for p_key in partial_sd.keys():
if isinstance(state_dict, ChainMap):
for child_sd in state_dict.maps:
child_sd.pop(p_key, None)
else:
state_dict.pop(p_key)
del partial_sd
del fms_partial_sd
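
# End-to-end sketch: load a (possibly sharded) checkpoint and stream it into an
# already-constructed model. The model, rank and world_size variables are
# assumed to exist in the caller; names are illustrative.
#
#     state_dict = load_state_dict(
#         "~/checkpoints/llama-7b",
#         source="hf",
#         distributed_strategy="tp",
#         rank=rank,
#         world_size=world_size,
#     )
#     load_state_dict_into_model(
#         model,
#         state_dict,
#         architecture="llama",
#         source="hf",
#         distributed_strategy="tp",
#         rank=rank,
#         world_size=world_size,
#     )
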
def _copy_colwise(param: torch.nn.Parameter, tensor_value, is_bias, rank, world_size):
"""
This function copies the correct shard of the weights for a colwise-TP'd module
according to the rank of the process and the world_size.
Args
====
param: torch.nn.Parameter
Parameter that has had TP applied
    tensor_value: torch.Tensor
        tensor that needs sharding
    is_bias: bool
        whether the tensor is a 1-D bias rather than a 2-D weight
rank: int
Rank of the current process
world_size: int
Total number of TP processes
"""
# Divide the weight matrix along the first dimension.
output_size_per_partition = param.shape[0]
if not is_bias:
tensor = tensor_value[
(rank * output_size_per_partition) : ((rank + 1) * output_size_per_partition),
:,
]
else:
tensor = tensor_value[(rank * output_size_per_partition) : ((rank + 1) * output_size_per_partition)]
param.copy_(tensor, non_blocking=True)
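
# Worked example of the colwise shard arithmetic: with a full weight of shape
# (4096, 11008) split over world_size=4, each rank's param has shape
# (1024, 11008), so rank 1 copies rows 1024:2048 of the checkpoint tensor.
# For a bias, the same slice is taken over its single dimension.
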
def _copy_rowwise(param: torch.nn.Parameter, tensor_value, is_bias, rank, world_size):
"""
This function copies the correct shard of the weights for a rowwise-TP'd module
according to the rank of the process and the world_size.
Args
====
param: torch.nn.Parameter
Parameter that has had TP applied
    tensor_value: torch.Tensor
        tensor that needs sharding
    is_bias: bool
        whether the tensor is a 1-D bias rather than a 2-D weight
rank: int
Rank of the current process
world_size: int
Total number of TP processes
"""
# Divide the weight matrix along the last dimension.
if not is_bias:
output_size_per_partition = param.shape[1]
tensor = tensor_value[
:,
(rank * output_size_per_partition) : ((rank + 1) * output_size_per_partition),
]
param.copy_(tensor, non_blocking=True)
else:
if rank == 0:
_copy_if_present(param, tensor_value)
else:
param.zero_()
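
# Worked example of the rowwise shard arithmetic: with a full weight of shape
# (4096, 11008) split over world_size=4 along the last dimension, each rank's
# param has shape (4096, 2752), so rank 2 copies columns 5504:8256. The bias is
# not sharded: rank 0 keeps the full bias and the other ranks zero theirs, so
# the bias is only counted once after the ranks' partial outputs are reduced.
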
def _copy_embedding(param: torch.nn.Parameter, tensor_value, rank, world_size):
"""
This function copies the correct shard of the weights for a TP'd embedding module
according to the rank of the process and the world_size.
Args
====
param: torch.nn.Parameter
Parameter that has had TP applied
tensor_value: torch.Tensor
tensor that needs sharding
rank: int
Rank of the current process
world_size: int
Total number of TP processes
"""
# Divide the weight matrix along the last dimension.
output_size_per_partition = param.shape[1]
tensor = tensor_value[
:,
(rank * output_size_per_partition) : ((rank + 1) * output_size_per_partition),
]
param.copy_(tensor, non_blocking=True)
def _copy_if_present(parameter, tensor_value):
parameter.copy_(tensor_value, non_blocking=True)
def _load_partial_state_dict(
model: torch.nn.Module,
state_dict,
needs_tp_sharding: bool,
rank=0,
world_size=1,
):
unused_params = []
for key, tensor_value in state_dict.items():
target_module = model
# Find where to put the weight and decide whether it needs TP'ing
key_steps = key.split(".")
prefix = ""
key_step = 0
tp_module = None
# Navigate the model tree to find the module where the parameter is
# located and whether there is a TPModule in the way in case the
# parameter requires sharding
while key_step < len(key_steps) - 1:
try:
target_module = getattr(target_module, key_steps[key_step])
if key_step > 0:
prefix += "."
prefix += key_steps[key_step]
key_step += 1
if isinstance(target_module, Iterable):
target_module = target_module[int(key_steps[key_step])] # type: ignore[index]
prefix += "." + key_steps[key_step]
key_step += 1
if isinstance(target_module, TPModule):
tp_module = target_module
except AttributeError:
unused_params.append(key)
break
# Check if target_module has the Parameter/buffer
try:
param = getattr(target_module, key_steps[-1])
# If TP sharding is not needed, copy the parameter
# into the model
if not needs_tp_sharding or tp_module is None:
_copy_if_present(param, tensor_value)
            else:
# Handle TP sharding
if key_steps[-2] in tp_module.colwise_param_names():
_copy_colwise(
param,
tensor_value,
key_steps[-1] == "bias",
rank,
world_size,
)
if key_steps[-2] in tp_module.rowwise_param_names():
_copy_rowwise(
param,
tensor_value,
key_steps[-1] == "bias",
rank,
world_size,
)
if key_steps[-2] in tp_module.embedding_param_names():
_copy_embedding(
param,
tensor_value,
rank,
world_size,
)
except AttributeError:
unused_params.append(key)
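
# Key-navigation sketch for the loop above: for a (hypothetical) key such as
# "decoder.layers.0.attn.dense.weight", the walk visits model.decoder, indexes
# layers[0], then attn and dense, remembering the deepest TPModule encountered
# on the way; the final component "weight" names the parameter that receives
# the (possibly sharded) checkpoint tensor. Keys that do not resolve to a
# module attribute are collected in unused_params.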