toolkits/model_checkpoints_convertor/bloom/deepspeed_to_megatron.py (149 lines of code) (raw):

#!/usr/bin/env python
"""Convert a DeepSpeed checkpoint into Megatron checkpoint files.

Reading the DeepSpeed checkpoint and splitting it for the requested
tensor/pipeline parallel degrees is delegated to ``DeepSpeedCheckpoint``;
this script only reassembles each (tp, pp) rank into the Megatron state-dict
layout and writes it under ``<output>/iter_XXXXXXX/mp_rank_.../``.
"""
import argparse
import os
from collections import OrderedDict

import torch
from deepspeed.checkpoint.deepspeed_checkpoint import (ARGS_KEY,
                                                        DeepSpeedCheckpoint)

MODEL_KEY = 'model'
LANGUAGE_MODEL_KEY = 'language_model'
# Misspelled legacy name, kept so existing imports of it keep working.
LANGUGAGE_MODEL_KEY = LANGUAGE_MODEL_KEY
EMBEDDING_KEY = 'embedding'
ENCODER_KEY = 'encoder'
WORD_EMBEDDINGS_FOR_HEAD_KEY = 'word_embeddings_for_head'
WORD_EMBEDDINGS_KEY = 'word_embeddings'
FINAL_LAYER_NORM_KEY = 'final_layernorm'
CHECKPOINT_VERSION_KEY = 'checkpoint_version'
CHECKPOINT_VERSION_VALUE = 3.0
ITERATION_KEY = 'iteration'


def parse_arguments():
    """Parse the converter's command-line options.

    Returns:
        argparse.Namespace with ``input_folder``, ``output_folder``,
        ``target_tp``, ``target_pp`` and ``for_release``.
    """
    parser = argparse.ArgumentParser()
    # Both folders are mandatory. With ``default=None`` a missing flag only
    # failed later with an unrelated error (e.g. os.path.join(None, ...)
    # raising TypeError); argparse now reports it up front.
    parser.add_argument('--input_folder',
                        required=True,
                        type=str,
                        help='Input DeepSpeed Checkpoint folder')
    parser.add_argument('--output_folder',
                        required=True,
                        type=str,
                        help='Output Megatron checkpoint folder')
    parser.add_argument('--target_tp',
                        default=1,
                        type=int,
                        help='Target TP degree')
    parser.add_argument('--target_pp',
                        default=1,
                        type=int,
                        help='Target PP degree')
    parser.add_argument(
        '--for_release',
        action='store_true',
        help='Convert for release purpose, reset some (progress) counters.')
    args = parser.parse_args()
    print(f'args = {args}')
    return args


def _convert_ds_transformer_state(sd_list):
    """Merge a list of per-layer state dicts into one flat state dict.

    The i-th dict in ``sd_list`` has every key prefixed with ``layers.{i}.``;
    ``i`` is the position within ``sd_list``, so indices restart at 0 for each
    list passed in (presumably one pipeline stage's layers -- TODO confirm).
    """
    new_sd = OrderedDict()
    for layer_idx, layer_sd in enumerate(sd_list):
        for key, value in layer_sd.items():
            new_sd[f'layers.{layer_idx}.{key}'] = value
    return new_sd


def _create_checkpoint_paths(base_folder, iteration, tp_degree, pp_degree):
    """Return a ``[tp_index][pp_index]`` nested list of output file paths.

    Each path is ``<base>/iter_<iteration:07d>/<rank>/model_optim_rng.pt``
    where ``<rank>`` is ``mp_rank_TT`` when ``pp_degree == 1`` and
    ``mp_rank_TT_PPP`` otherwise.
    """
    iter_folder = f'iter_{iteration:07d}'
    path_list = []
    for i in range(tp_degree):
        tp_paths = []
        for j in range(pp_degree):
            if pp_degree == 1:
                rank_folder = f'mp_rank_{i:02d}'
            else:
                rank_folder = f'mp_rank_{i:02d}_{j:03d}'
            tp_paths.append(
                os.path.join(base_folder, iter_folder, rank_folder,
                             'model_optim_rng.pt'))
        path_list.append(tp_paths)
    return path_list


def _create_megatron_dict():
    """Return an empty Megatron checkpoint skeleton.

    The skeleton holds empty embedding/encoder dicts under
    ``model.language_model`` plus the checkpoint version.
    """
    language_model_dict = {EMBEDDING_KEY: {}, ENCODER_KEY: {}}
    return {
        MODEL_KEY: {
            LANGUAGE_MODEL_KEY: language_model_dict
        },
        CHECKPOINT_VERSION_KEY: CHECKPOINT_VERSION_VALUE,
    }


def _save_checkpoint(file_path, chkpt_sd):
    """Save ``chkpt_sd`` to ``file_path`` with torch.save.

    The parent directory is created first if needed.
    """
    dir_name = os.path.dirname(file_path)
    # os.makedirs('') raises FileNotFoundError, so skip it for bare filenames.
    if dir_name:
        os.makedirs(dir_name, exist_ok=True)
    torch.save(chkpt_sd, file_path)


def _renest_sd(sd):
    """Return a shallow ``OrderedDict`` copy of ``sd``.

    Despite the name, keys are copied unchanged; no re-nesting happens.
    """
    return OrderedDict(sd.items())


def _create_rank_checkpoint(ds_checkpoint,
                            tp_index,
                            pp_index,
                            for_release=False):
    """Build the Megatron state dict for one (tp_index, pp_index) rank.

    Every stage stores its transformer layers under the encoder. The first
    stage (``pp_index == 0``) also stores the embedding state. The last stage
    (``pp_index == pp_degree - 1``) also stores the ``word_embeddings.*``
    entries under ``model.word_embeddings_for_head`` and the final layer norm
    under the encoder.

    Args:
        ds_checkpoint: the ``DeepSpeedCheckpoint`` to read from.
        tp_index: tensor-parallel rank to build.
        pp_index: pipeline-parallel rank to build.
        for_release: if True, reset ``consumed_train_samples`` and
            ``consumed_valid_samples`` in the saved args to 0.

    Returns:
        The checkpoint dict, ready for ``torch.save``.
    """
    is_first_stage = pp_index == 0
    is_last_stage = pp_index == ds_checkpoint.pp_degree - 1

    meg_encoder_sd = OrderedDict()
    meg_embedding_sd = OrderedDict()
    meg_embedding_for_head_sd = OrderedDict()

    transformer_sd = ds_checkpoint.get_transformer_state(tp_index, pp_index)
    meg_encoder_sd.update(_convert_ds_transformer_state(transformer_sd))

    if is_first_stage or is_last_stage:
        embedding_sd = ds_checkpoint.get_embedding_state(tp_index)
        if is_first_stage:
            meg_embedding_sd.update(_renest_sd(embedding_sd))
        if is_last_stage:
            # Strip the first dotted component:
            # 'word_embeddings.<rest>' -> '<rest>'.
            for key, value in embedding_sd.items():
                if key.startswith(WORD_EMBEDDINGS_KEY):
                    meg_embedding_for_head_sd[key.partition('.')[2]] = value
            final_norm_sd = ds_checkpoint.get_final_norm_state(tp_index)
            meg_encoder_sd.update({
                f'{FINAL_LAYER_NORM_KEY}.{key}': value
                for key, value in final_norm_sd.items()
            })

    checkpoint_sd = _create_megatron_dict()
    checkpoint_sd[ITERATION_KEY] = ds_checkpoint.get_iteration()

    language_model_sd = checkpoint_sd[MODEL_KEY][LANGUAGE_MODEL_KEY]
    # Non-first stages keep the empty embedding dict from the skeleton.
    if is_first_stage:
        language_model_sd[EMBEDDING_KEY] = meg_embedding_sd
    language_model_sd[ENCODER_KEY] = meg_encoder_sd
    if is_last_stage:
        checkpoint_sd[MODEL_KEY][
            WORD_EMBEDDINGS_FOR_HEAD_KEY] = meg_embedding_for_head_sd

    # Rewrite the stored parallel sizes to this checkpoint's degrees
    # (presumably the target TP/PP passed to DeepSpeedCheckpoint in main --
    # confirm). NOTE(review): get_args() may hand back the same object for
    # every rank; the assignments below are identical on each call.
    ckpt_args = ds_checkpoint.get_args()
    ckpt_args.tensor_model_parallel_size = ds_checkpoint.tp_degree
    ckpt_args.pipeline_model_parallel_size = ds_checkpoint.pp_degree
    if for_release:
        ckpt_args.consumed_train_samples = 0
        ckpt_args.consumed_valid_samples = 0
    checkpoint_sd[ARGS_KEY] = ckpt_args

    return checkpoint_sd


def _create_latest_file(base_folder, iteration):
    """Write ``iteration`` into ``<base_folder>/latest_checkpointed_iteration.txt``."""
    os.makedirs(base_folder, exist_ok=True)
    file_path = os.path.join(base_folder, 'latest_checkpointed_iteration.txt')
    with open(file_path, 'w', encoding='utf-8') as f:
        f.write(str(iteration))


def main():
    """Convert every (tp, pp) rank, then write the iteration tracker file."""
    print('Convert DeepSpeed Checkpoint to Megatron Checkpoint')

    args = parse_arguments()
    print(f'Converting DeepSpeed checkpoint'
          f' in {args.input_folder} to Megatron'
          f' checkpoint in {args.output_folder}')

    ds_checkpoint = DeepSpeedCheckpoint(args.input_folder, args.target_tp,
                                        args.target_pp)
    iteration = ds_checkpoint.get_iteration()
    checkpoint_paths = _create_checkpoint_paths(args.output_folder, iteration,
                                                ds_checkpoint.tp_degree,
                                                ds_checkpoint.pp_degree)
    for i in range(ds_checkpoint.tp_degree):
        for j in range(ds_checkpoint.pp_degree):
            sd = _create_rank_checkpoint(ds_checkpoint, i, j, args.for_release)
            _save_checkpoint(checkpoint_paths[i][j], sd)

    # Write the tracker only after every shard has been saved. If conversion
    # fails part-way, no tracker naming the incomplete iteration is written.
    _create_latest_file(args.output_folder, iteration)


if __name__ == '__main__':
    main()