# optimum/neuron/models/training/checkpointing.py
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, List, Literal, Optional, Union
import torch
from huggingface_hub import split_torch_state_dict_into_shards
from transformers.utils import (
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
)
from ...utils.import_utils import is_peft_available
from ...utils.require_utils import requires_neuronx_distributed, requires_safetensors, requires_torch_xla
from .modeling_utils import MODEL_PARALLEL_SHARDS_DIR_NAME
from .transformations_utils import ModelWeightTransformationSpecs, to_original_weights
# PEFT adapter weight filenames are only importable when `peft` is installed. When it is not,
# fall back to empty strings: the PEFT-specific branches below are then never taken, but the
# module still imports cleanly.
if is_peft_available():
    from peft.utils.constants import (
        SAFETENSORS_WEIGHTS_NAME as PEFT_SAFETENSORS_WEIGHTS_NAME,
    )
    from peft.utils.constants import (
        WEIGHTS_NAME as PEFT_WEIGHTS_NAME,
    )
else:
    PEFT_SAFETENSORS_WEIGHTS_NAME = PEFT_WEIGHTS_NAME = ""
@requires_torch_xla
def xser_load_on_cpu(path: str):
    """
    Load an xser-serialized checkpoint while keeping every tensor on CPU.

    Modified version from neuronx_distributed `_xser_load` function load located at:
    https://github.com/aws-neuron/neuronx-distributed/blob/e83494557cb4c5b7e185ccf6c9240bfed9a1993d/src/neuronx_distributed/parallel_layers/checkpointing.py#L252
    Instead of moving the loaded tensors to the XLA device, it keeps them on CPU.
    """
    import torch_xla.core.xla_model as xm
    import torch_xla.utils.serialization as xser

    # The reference file holds the object structure with `TensorReference` placeholders; the
    # actual tensor payloads live in the companion `<path>.tensors` directory.
    ref_data = torch.load(path)
    tensors_dir = path + ".tensors"

    def materialize(references):
        # Replace each placeholder with its tensor, loaded straight to CPU (no XLA transfer).
        return [torch.load(os.path.join(tensors_dir, f"tensor_{ref.tid}.pt")) for ref in references]

    def is_reference(value):
        return type(value) is xser.TensorReference

    return xm.ToXlaTensorArena(materialize, is_reference).transform(ref_data)
def consolidate_tensor_parallel_checkpoints(
    sharded_checkpoints: List[Path],
    load_function: Callable[[Union[str, Path]], Dict[str, Any]],
    metadata: Dict[str, Any],
    adapter_name: Optional[str] = None,
) -> Dict[str, "torch.Tensor"]:
    """
    Consolidates tensor-parallel sharded state dicts into a single state dict with the original
    (unsharded) parameter layout.

    Args:
        sharded_checkpoints: Paths of the per-tp-rank checkpoint files for one pipeline rank.
        load_function: Function used to load one shard file into a state dict.
        metadata: Checkpoint metadata; must contain the "parameters" and
            "model_weight_transformation_specs" entries.
        adapter_name: Optional PEFT adapter name, forwarded to `to_original_weights`.

    Returns:
        The consolidated state dict mapping original parameter names to tensors.

    Raises:
        ValueError: If none of `sharded_checkpoints` points to an existing file.
    """
    # Sort so shards are loaded in tp-rank order; consolidation assumes rank-ordered shard lists.
    state_dicts = []
    for sharded_checkpoint in sorted(sharded_checkpoints):
        if not sharded_checkpoint.is_file():
            continue
        state_dicts.append(load_function(sharded_checkpoint.as_posix()))

    # Robustness fix: without this guard, `state_dicts[0]` below raises an opaque IndexError when
    # every provided path was filtered out (e.g. the files were deleted or the glob matched dirs).
    if not state_dicts:
        raise ValueError("Could not find any file to load among the provided sharded checkpoints.")

    parameters_metadata = metadata["parameters"]
    transformation_specs_metadata = metadata["model_weight_transformation_specs"]

    # We recreate the transformation specs from the metadata.
    transformations_specs = [
        ModelWeightTransformationSpecs.from_metadata(specs_metadata)
        for specs_metadata in transformation_specs_metadata
    ]

    # We transform the sharded state dicts as follows:
    # [state_dict_tp_rank_0, state_dict_tp_rank_1, ...]
    #   -> {
    #       key: [state_dict_tp_rank_0[key], state_dict_tp_rank_1[key], ...],
    #       for key in state_dict_tp_rank_0.keys()
    #      }
    # NOTE(review): this assumes every tp rank's shard has the same key set — appears to be the
    # saving convention, but is not checked here.
    parameter_names = state_dicts[0].keys()
    sharded_state_dicts = {name: [state_dict[name] for state_dict in state_dicts] for name in parameter_names}

    return to_original_weights(
        transformations_specs, sharded_state_dicts, parameters_metadata, adapter_name=adapter_name
    )
@requires_neuronx_distributed
def consolidate_model_parallel_checkpoints(
    checkpoint_dir: Path, adapter_name: Optional[str] = None
) -> Dict[str, "torch.Tensor"]:
    """
    Consolidates the model-parallel (tensor- and pipeline-parallel) shards found under
    `checkpoint_dir / "model"` into a single full state dict.

    Args:
        checkpoint_dir: Directory containing the `model` sub-directory with the `dp_rank_*` shard
            files and the per-pipeline-rank `mp_metadata_pp_rank_*` metadata files.
        adapter_name: Optional PEFT adapter name, forwarded to the per-pp-rank consolidation.

    Returns:
        The consolidated state dict.

    Raises:
        ValueError: If no sharded checkpoint file can be found.
    """
    model_checkpoint_dir = checkpoint_dir / "model"

    # Case 1: the checkpoint was saved with xser — each shard has a companion `*.tensors` directory
    # holding the tensor payloads, so the shard files themselves are the non-suffixed `dp_rank_*`.
    sharded_checkpoints = list(model_checkpoint_dir.glob("dp_rank*.tensors"))
    if sharded_checkpoints:
        sharded_checkpoints = model_checkpoint_dir.glob("dp_rank_*")
        sharded_checkpoints = [
            p for p in sharded_checkpoints if not (p.name.endswith("info.pt") or p.name.endswith("tensors"))
        ]
        load_function = xser_load_on_cpu

    # Case 2: If no file was found, maybe the checkpoint was saved without xser.
    if not sharded_checkpoints:
        sharded_checkpoints = list(model_checkpoint_dir.glob("dp_rank_*.pt"))
        load_function = partial(torch.load, weights_only=True)

    if not sharded_checkpoints:
        raise ValueError(f"Could not find any sharded checkpoint in {model_checkpoint_dir.as_posix()}")

    # Shard stems end with the zero-padded pipeline-parallel rank (last two characters), so the
    # highest value + 1 gives the pipeline-parallel world size.
    pp_size = max((int(checkpoint_path.stem[-2:]) for checkpoint_path in sharded_checkpoints)) + 1
    checkpoints_grouped_by_pp_ranks = [[] for _ in range(pp_size)]
    metadatas = []
    for pp_rank in range(pp_size):
        for checkpoint_path in sharded_checkpoints:
            checkpoint_name = checkpoint_path.stem
            if int(checkpoint_name[-2:]) == pp_rank:
                checkpoints_grouped_by_pp_ranks[pp_rank].append(checkpoint_path)
        # Metadata can be stored either as a pickled `.pt` file (legacy) or as JSON.
        if (checkpoint_dir / f"mp_metadata_pp_rank_{pp_rank}.pt").is_file():
            metadatas.append(torch.load(checkpoint_dir / f"mp_metadata_pp_rank_{pp_rank}.pt"))
        else:
            with open(checkpoint_dir / f"mp_metadata_pp_rank_{pp_rank}.json") as fp:
                metadatas.append(json.load(fp))

    # Consolidate each pipeline rank's tensor-parallel shards, then merge across pipeline ranks
    # (each pp rank owns a disjoint set of parameters).
    consolidated_state_dict = {}
    for pp_rank, checkpoint_group_for_pp_rank in enumerate(checkpoints_grouped_by_pp_ranks):
        consolidated_for_pp_rank = consolidate_tensor_parallel_checkpoints(
            checkpoint_group_for_pp_rank,
            load_function,
            metadatas[pp_rank],
            adapter_name=adapter_name,
        )
        consolidated_state_dict.update(**consolidated_for_pp_rank)
    # Bug fix: the original ended with `for key, tensor in ...: consolidated_state_dict[key] = tensor`,
    # which re-assigned every entry to itself — dead code, removed.
    return consolidated_state_dict
@requires_safetensors
def consolidate_model_parallel_checkpoints_to_unified_checkpoint(
    checkpoint_dir: Union[str, Path],
    output_dir: Union[str, Path],
    save_format: Literal["pytorch", "safetensors"] = "safetensors",
):
    """
    Consolidates model-parallel shards into a standard (possibly sharded) Transformers/PEFT
    checkpoint written to `output_dir`.

    Args:
        checkpoint_dir: Either the shards directory itself, a directory containing the
            `MODEL_PARALLEL_SHARDS_DIR_NAME` sub-directory, or a directory containing one or more
            `adapter_*` sub-directories (PEFT case).
        output_dir: Where the unified checkpoint is written. Non-default adapters go to
            `output_dir / <adapter_name>`.
        save_format: Either "safetensors" (default) or "pytorch" (`torch.save`).

    Raises:
        ValueError: If no tensor-parallel shards can be located under `checkpoint_dir`.
    """
    from safetensors.torch import save_file

    # We import here to avoid circular import.
    from ...peft.peft_model import ADAPTER_MODEL_PARALLEL_SHARDS_DIR_NAME

    if not isinstance(checkpoint_dir, Path):
        checkpoint_dir = Path(checkpoint_dir)

    directories_to_consolidate = []
    if checkpoint_dir.name != MODEL_PARALLEL_SHARDS_DIR_NAME:
        if (checkpoint_dir / MODEL_PARALLEL_SHARDS_DIR_NAME).is_dir():
            directories_to_consolidate = [checkpoint_dir / MODEL_PARALLEL_SHARDS_DIR_NAME]
        else:
            # PEFT case: each `adapter_*` sub-directory holds its own shards.
            for subdir in checkpoint_dir.iterdir():
                if subdir.is_dir() and subdir.name.startswith("adapter_"):
                    directories_to_consolidate.append(subdir / ADAPTER_MODEL_PARALLEL_SHARDS_DIR_NAME)
            if not directories_to_consolidate:
                raise ValueError(f"Could not find the tensor parallel shards from {checkpoint_dir}")
    else:
        directories_to_consolidate = [checkpoint_dir]

    if not isinstance(output_dir, Path):
        output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    for shards_dir in directories_to_consolidate:
        # We need to go one level up because the checkpoint directory is at the shards level here.
        parent_dir = shards_dir.parent
        current_output_dir = output_dir
        is_adapter_model = parent_dir.name.startswith("adapter_")
        adapter_name = None
        if is_adapter_model:
            safe_weights_name = PEFT_SAFETENSORS_WEIGHTS_NAME
            weights_name = PEFT_WEIGHTS_NAME
            if parent_dir.name != "adapter_default":
                adapter_name = parent_dir.name.split("_", maxsplit=1)[-1]
                current_output_dir = output_dir / adapter_name
            else:
                adapter_name = "default"
        else:
            safe_weights_name = SAFE_WEIGHTS_NAME
            weights_name = WEIGHTS_NAME

        current_output_dir.mkdir(parents=True, exist_ok=True)

        state_dict = consolidate_model_parallel_checkpoints(shards_dir, adapter_name=adapter_name)

        # Bug fix: `split_torch_state_dict_into_shards` expects a "{suffix}" placeholder in
        # `filename_pattern` (it renders `filename_pattern.format(suffix="-00001-of-00002")` etc.).
        # Passing the bare weights name made every shard render to the same filename, so shards
        # overwrote each other whenever the state dict exceeded the shard size.
        if save_format == "safetensors":
            filename_pattern = safe_weights_name.replace(".safetensors", "{suffix}.safetensors")
        else:
            filename_pattern = weights_name.replace(".bin", "{suffix}.bin")
        state_dict_split = split_torch_state_dict_into_shards(state_dict, filename_pattern=filename_pattern)

        # Save index if sharded
        if state_dict_split.is_sharded:
            index = {
                "metadata": state_dict_split.metadata,
                "weight_map": state_dict_split.tensor_to_filename,
            }
            save_index_file = SAFE_WEIGHTS_INDEX_NAME if save_format == "safetensors" else WEIGHTS_INDEX_NAME
            with open(current_output_dir / save_index_file, "w") as fp:
                content = json.dumps(index, indent=2, sort_keys=True) + "\n"
                fp.write(content)

        # Save the model, dropping tensors from `state_dict` as they are written to bound peak memory.
        for shard_file, tensors in state_dict_split.filename_to_tensors.items():
            shard = {}
            for tensor in tensors:
                shard[tensor] = state_dict[tensor].contiguous()
                del state_dict[tensor]
            if save_format == "safetensors":
                save_file(shard, current_output_dir / shard_file, metadata={"format": "pt"})
            else:
                torch.save(shard, current_output_dir / shard_file)