toolkits/model_checkpoints_convertor/utils/__init__.py

import os
import numpy as np
import torch
import json
import logging
import gc
from transformers.modeling_utils import (
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
)

try:
    from transformers.modeling_utils import shard_checkpoint
    USE_TRANSFORMERS_SAVE = True
except ImportError:
    # shard_checkpoint is gone in newer Transformers; fall back to the huggingface_hub sharding helpers.
    from huggingface_hub.constants import (
        SAFETENSORS_WEIGHTS_FILE_PATTERN,
        SAFETENSORS_INDEX_FILE,
    )
    from huggingface_hub import split_torch_state_dict_into_shards
    USE_TRANSFORMERS_SAVE = False

from safetensors.torch import save_file
from collections.abc import Mapping, Sequence


def save_hfmodel(args, model, max_shard_size='10GB'):
    output_state_dict = model
    if not isinstance(model, dict):
        output_state_dict = model.state_dict()
    save_safetensors = (not USE_TRANSFORMERS_SAVE) or args.save_safetensors
    os.makedirs(args.save, exist_ok=True)

    # NOTE: remove all old index files
    if os.path.exists(os.path.join(args.save, SAFE_WEIGHTS_INDEX_NAME)):
        os.remove(os.path.join(args.save, SAFE_WEIGHTS_INDEX_NAME))
    if os.path.exists(os.path.join(args.save, WEIGHTS_INDEX_NAME)):
        os.remove(os.path.join(args.save, WEIGHTS_INDEX_NAME))

    index = None
    if USE_TRANSFORMERS_SAVE:
        weight_file = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
        index_file = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
        shards, index = shard_checkpoint(output_state_dict, max_shard_size=max_shard_size, weights_name=weight_file)
    else:
        if not args.save_safetensors:
            logging.warning("Since Transformers v4.47.0, the HF checkpoint can only be saved with safetensors by these scripts.")
        weight_file = SAFETENSORS_WEIGHTS_FILE_PATTERN
        index_file = SAFETENSORS_INDEX_FILE
        state_dict_split = split_torch_state_dict_into_shards(output_state_dict, max_shard_size=max_shard_size, filename_pattern=weight_file)
        shards = {}
        for filename, tensors in state_dict_split.filename_to_tensors.items():
            shards[filename] = {tensor: output_state_dict[tensor] for tensor in tensors}
        if state_dict_split.is_sharded:
            index = {
                "metadata": state_dict_split.metadata,
                "weight_map": state_dict_split.tensor_to_filename,
            }

    for shard_file, shard in shards.items():
        target_file = os.path.join(args.save, shard_file)
        print(f'huggingface model is saved to {target_file}')
        if save_safetensors:
            save_file(clone_state_dict(shard), target_file, metadata={"format": "pt"})
        else:
            torch.save(clone_state_dict(shard), target_file)

    if index is not None:
        save_index_file = os.path.join(args.save, index_file)
        # Save the index as well
        with open(save_index_file, "w", encoding="utf-8") as f:
            content = json.dumps(index, indent=2, sort_keys=True) + "\n"
            f.write(content)
        print(
            f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
            f"split into {len(shards)} checkpoint shards. You can find where each parameter has been saved in the "
            f"index located at {save_index_file}."
        )
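

# --- Usage sketch (hypothetical paths and argument values, for illustration only) ---
# `save_hfmodel` accepts either an `nn.Module` or a plain state-dict mapping; the
# `args` namespace only needs `save` (output directory) and `save_safetensors`.
# A minimal, commented-out example assuming a throwaway output directory:
#
#   from types import SimpleNamespace
#   _demo_args = SimpleNamespace(save='/tmp/hf_ckpt_demo', save_safetensors=True)
#   _demo_state_dict = {'embed_tokens.weight': torch.zeros(8, 16)}
#   save_hfmodel(_demo_args, _demo_state_dict, max_shard_size='10GB')
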
""" elem_type = type(elem) if isinstance(elem, torch.Tensor): elem = elem.clone() elif isinstance(elem, (np.ndarray, str)): pass elif isinstance(elem, Mapping): elem = dict(elem) for k, v in elem.items(): elem[k] = clone_state_dict(v) elem = elem_type(elem) elif isinstance(elem, Sequence): elem = list(elem) for i in range(len(elem)): elem[i] = clone_state_dict(elem[i]) elem = elem_type(elem) return elem def build_layer_id_mapping(args): """ global layer id <--> local layer id """ ltog, gtol = dict(), dict() assert args.target_decoder_first_pipeline_num_layers is None or args.target_num_layers_per_virtual_pipeline_stage is None, "Currently uneven VPP not supported" if args.target_decoder_first_pipeline_num_layers is not None: remained_layers = args.num_layers - args.target_decoder_first_pipeline_num_layers remained_stages = args.pipeline_model_parallel_size - 1 assert remained_layers % remained_stages == 0 pp_layers_per_stage = [args.target_decoder_first_pipeline_num_layers] +([remained_layers // remained_stages] * remained_stages) for pp_id, num_layers in enumerate(pp_layers_per_stage): for global_layer_id in range(offset, offset + num_layers): # NOTE: map a global transformer layer to a local pp rank # global_id <--> (pp_id, vpp_id, local_id) local_layer_id = global_layer_id - offset ltog[(pp_id, 0, local_layer_id)] = global_layer_id gtol[global_layer_id] = (pp_id, 0, local_layer_id) offset += num_layers else: n_chunks = args.pipeline_model_parallel_size pp_size = args.pipeline_model_parallel_size if args.target_num_layers_per_virtual_pipeline_stage is not None: assert args.num_layers % (args.target_num_layers_per_virtual_pipeline_stage * args.pipeline_model_parallel_size) == 0 n_chunks = args.num_layers // args.target_num_layers_per_virtual_pipeline_stage num_layer_per_chunk = args.num_layers // n_chunks pp_layers_per_stage = [num_layer_per_chunk] * n_chunks offset = 0 for chunk_id, num_layers in enumerate(pp_layers_per_stage): for global_layer_id in range(offset, offset + num_layers): # NOTE: map a global transformer layer to a local pp rank # global_id <--> (pp_id, vpp_id, local_id) pp_id = chunk_id % pp_size vpp_id = chunk_id // pp_size local_layer_id = global_layer_id - offset ltog[(pp_id, vpp_id, local_layer_id)] = global_layer_id gtol[global_layer_id] = (pp_id, vpp_id, local_layer_id) offset += num_layers return ltog, gtol def safe_copy( src_tensor: torch.Tensor, dst_tensor: torch.Tensor, skip_dtype_assert: bool = False, ): if not skip_dtype_assert: if src_tensor.dtype != dst_tensor.dtype: raise ValueError(f"Get source dtype {src_tensor.dtype}, but target dtype {dst_tensor.dtype}") assert src_tensor.shape == dst_tensor.shape dst_tensor.data.copy_(src_tensor.data) return src_tensor.numel() def save_state_dict(args, model_chunks, checkpoint_name, has_vpp: bool=False, save_args: bool=True): """ Save some model chunks to a megatron checkpoint file """ state_dict = {} if save_args: state_dict['args'] = args state_dict['checkpoint_version'] = 3.0 state_dict['iteration'] = 0 if not has_vpp: state_dict['model'] = model_chunks[0] else: for vpp_id in range(len(model_chunks)): state_dict[f"model{vpp_id}"] = model_chunks[vpp_id] os.makedirs(os.path.dirname(checkpoint_name), exist_ok=True) torch.save(clone_state_dict(state_dict), checkpoint_name) del state_dict gc.collect()