in toolkits/model_checkpoints_convertor/falcon/checkpoint_reshaping_and_interoperability.py [0:0]
def convert_checkpoint_from_megatron_to_transformers(args):
"""
Convert NVIDIA Megatron-LM checkpoint to HuggingFace Transformers checkpoint. This handles Megatron checkpoints
with different tensor parallelism and pipeline parallelism sizes. It saves the converted checkpoint into shards
using HuggingFace Transformers checkpoint sharding functionality. This greatly extends the functionality of
`convert_megatron_gpt2_checkpoint.py`
Args:
args (argparse.Namespace): the arguments to the script
"""
    # Load Megatron-LM checkpoint arguments from the state dict
    sub_dirs = os.listdir(args.load_path)
    possible_sub_dirs = ["mp_rank_00", "mp_rank_00_000"]
    for sub_dir in possible_sub_dirs:
        if sub_dir in sub_dirs:
            rank0_checkpoint_name = [i for i in os.listdir(os.path.join(args.load_path, sub_dir)) if 'rng' in i][0]
            rank0_checkpoint_path = os.path.join(args.load_path, sub_dir, rank0_checkpoint_name)
            break
    else:
        raise FileNotFoundError(f"Could not find a rank-0 checkpoint (mp_rank_00*) under {args.load_path}")
print(f"Loading Megatron-LM checkpoint arguments from: {rank0_checkpoint_path}")
state_dict = torch.load(rank0_checkpoint_path, map_location="cpu")
megatron_args = state_dict.get("args", None)
if megatron_args is None:
raise ValueError(
"Megatron-LM checkpoint does not contain arguments. This utility only supports Megatron-LM checkpoints"
" containing all the megatron arguments. This is because it loads all config related to model"
" architecture, the tensor and pipeline model parallel size from the checkpoint insead of user having to"
" manually specify all the details. Please save Megatron-LM checkpoint along with all the megatron"
" arguments to use this utility."
)
    # Saving config and tokenizer files
    if args.load_path.endswith('/'):
        config_path = '/'.join(args.load_path[:-1].split('/')[:-1])
    else:
        config_path = '/'.join(args.load_path.split('/')[:-1])
    os.system("cp -rf " + config_path + "/*.json " + args.save_path)
    os.system("cp -rf " + config_path + "/*.py " + args.save_path)
    activation_function = "gelu"
    vocab_size = (
        megatron_args.padded_vocab_size
        if getattr(megatron_args, "orig_vocab_size", None) is None
        else megatron_args.orig_vocab_size
    )
    auto_map_dict = {
        "AutoConfig": "configuration_RW.RWConfig",
        "AutoModel": "modelling_RW.RWModel",
        "AutoModelForSequenceClassification": "modelling_RW.RWForSequenceClassification",
        "AutoModelForTokenClassification": "modelling_RW.RWForTokenClassification",
        "AutoModelForQuestionAnswering": "modelling_RW.RWForQuestionAnswering",
        "AutoModelForCausalLM": "modelling_RW.RWForCausalLM",
    }
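    # auto_map tells transformers which classes in the copied modelling_RW.py /
    # configuration_RW.py files back each Auto* entry point, so the converted model
    # resolves through the AutoModel machinery (with trust_remote_code) rather than
    # a built-in architecture.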
    if megatron_args.num_layers != 60:
        config = RWConfig(
            vocab_size=vocab_size,
            hidden_size=megatron_args.hidden_size,
            n_layer=megatron_args.num_layers,
            n_head=megatron_args.num_attention_heads,
            layer_norm_epsilon=1e-5,
            initializer_range=0.02,
            auto_map=auto_map_dict,
            alibi=False,
            multi_query=True,
            use_cache=True,
            parallel_attn=True,
            bos_token_id=getattr(megatron_args, "bos_token_id", 11),
            eos_token_id=getattr(megatron_args, "eos_token_id", 11),
            architectures=["RWForCausalLM"],
        )
    else:
        config = RWConfig_40b(
            vocab_size=vocab_size,
            hidden_size=megatron_args.hidden_size,
            n_layer=megatron_args.num_layers,
            n_head=megatron_args.num_attention_heads,
            n_head_kv=8,
            layer_norm_epsilon=1e-5,
            initializer_range=0.02,
            auto_map=auto_map_dict,
            alibi=False,
            multi_query=True,
            use_cache=True,
            parallel_attn=True,
            bias=False,
            bos_token_id=getattr(megatron_args, "bos_token_id", 11),
            eos_token_id=getattr(megatron_args, "eos_token_id", 11),
            architectures=["RWForCausalLM"],
        )
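    # NOTE: a 60-layer checkpoint is treated as a Falcon-40B-style model (grouped KV
    # heads via n_head_kv=8 and no linear biases); anything else falls through to the
    # 7B-style multi-query config above.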
    output_state_dict = {}
    config.save_pretrained(args.save_path)
    checkpoint_version = state_dict.get("checkpoint_version", 0.0)
    tp_size = megatron_args.tensor_model_parallel_size
    pp_size = megatron_args.pipeline_model_parallel_size
    # params dtype
    if args.target_params_dtype == 'fp16':
        dtype = torch.float16
    elif args.target_params_dtype == 'bf16':
        dtype = torch.bfloat16
    else:
        dtype = torch.float32
    # The regex to extract layer names.
    layer_re = re.compile(r"layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)")
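    # e.g. "layers.3.self_attention.query_key_value.weight" is split into the layer
    # index ("3"), the operation ("self_attention.query_key_value") and "weight"/"bias".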
    # Convert.
    print("Converting")
    # Embeddings
    print("Converting embeddings")
    tp_state_dicts = get_megatron_sharded_states(args, tp_size, pp_size, 0)
    # Convert and store the word embeddings.
    word_embeddings = []
    word_embeddings_layernorm_weight = []
    word_embeddings_layernorm_bias = []
    for tp_rank in range(tp_size):
        embeddings = get_element_from_dict_by_path(
            tp_state_dicts[tp_rank], "model.language_model.embedding"
        )
        if 'word_embeddings.weight' in embeddings:
            word_embeddings.append(embeddings['word_embeddings.weight'])
        else:
            word_embeddings.append(embeddings['word_embeddings']['weight'])
    word_embeddings = torch.cat(word_embeddings, dim=0)
    word_embeddings = word_embeddings.to(dtype)
    output_state_dict["transformer.word_embeddings.weight"] = word_embeddings
    output_state_dict["lm_head.weight"] = word_embeddings
    # Reset the vocab size
    config.vocab_size = word_embeddings.shape[0]
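    # The embedding matrix is sharded along the vocabulary dimension across tensor-parallel
    # ranks, so concatenating on dim 0 rebuilds the full (padded) vocabulary; the same
    # tensor is reused for lm_head.weight, i.e. input and output embeddings stay tied.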
    # Transformer Layers
    print("Converting transformer layers")
    # The number of heads.
    heads = config.n_head
    # The hidden_size per head.
    hidden_size_per_head = config.hidden_size // config.n_head
    num_layers = config.n_layer // pp_size
    for pp_rank in range(pp_size):
        if pp_size > 0:
            print(f"Converting pipeline parallel rank {pp_rank}")
            tp_state_dicts = get_megatron_sharded_states(args, tp_size, pp_size, pp_rank)
        # The transformer.
        path = (
            "model.language_model.transformer"
            if "transformer" in get_element_from_dict_by_path(tp_state_dicts[0], "model.language_model").keys()
            else "model.language_model.encoder"
        )
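        # Depending on the Megatron-LM version, the layer stack lives under
        # language_model.transformer (older) or language_model.encoder (newer),
        # so probe for whichever key is present.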
        # Extract the layers.
        for key, val in get_element_from_dict_by_path(tp_state_dicts[0], path).items():
            # Match the name.
            m = layer_re.match(key)
            # Stop if that's not a layer
            if m is None:
                break
            # The index of the layer.
            layer_idx = int(m.group(1)) + pp_rank * num_layers
            # The name of the operation.
            op_name = m.group(2)
            # Is it a weight or a bias?
            weight_or_bias = m.group(3)
            # The name of the layer.
            layer_name = f"transformer.h.{layer_idx}"
            # get merged weights on tps
            if op_name + "." + weight_or_bias not in tensor_parallel_params_mg:
                params = val.to(dtype)
            else:
                dim = 1 if op_name in ["self_attention.dense", "mlp.dense_4h_to_h"] else 0
                params = torch.cat(
                    [val]
                    + [
                        get_element_from_dict_by_path(tp_state_dicts[tp_rank], f"{path}")[key]
                        for tp_rank in range(1, tp_size)
                    ],
                    dim=dim,
                ).to(dtype)
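            # Row-parallel layers (the attention output projection and mlp.dense_4h_to_h)
            # are sharded along their input dimension, so their shards are concatenated on
            # dim 1; column-parallel layers are concatenated on dim 0.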
            hf_op_name = megatron_to_transformers_opmap[op_name]
            full_hf_op = f'{layer_name}.{hf_op_name}.{weight_or_bias}'
            if config.n_layer != 60:
                if 'ln_attn' in full_hf_op:
                    full_hf_op = full_hf_op.replace('ln_attn', 'input_layernorm')
            if 'key_value' not in op_name:
                output_state_dict[full_hf_op] = params
            else:
                if weight_or_bias == 'bias':
                    continue
                original = output_state_dict[full_hf_op]
                output_state_dict[full_hf_op] = torch.cat([original, params])
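            # For checkpoints that ship separate query and key_value shards (the 40B-style
            # path), the KV block is appended to the tensor already stored under the fused
            # query_key_value name; key_value biases are skipped, consistent with bias=False
            # in RWConfig_40b.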
    if config.n_layer != (layer_idx + 1):
        raise ValueError(f"Expected {config.n_layer} layers but found {layer_idx + 1}")
    # The final layernorm.
    print("Converting final layernorm")
    params = get_element_from_dict_by_path(tp_state_dicts[0], str(path))
    output_state_dict["transformer.ln_f.weight"] = params["final_layernorm.weight"].to(dtype)
    output_state_dict["transformer.ln_f.bias"] = params["final_layernorm.bias"].to(dtype)
    # It should be done!
    print("Conversion from Megatron-LM to Transformers is done!")
    # Print the structure of converted state dict.
    if args.print_checkpoint_structure:
        recursive_print(None, output_state_dict)
    # Store the config to file.
    print("Saving config")
    config.save_pretrained(args.save_path)
    # Store the state_dict to file.
    max_shard_size = int(args.max_shard_size) if args.max_shard_size.isdigit() else args.max_shard_size
    shards, index = shard_checkpoint(output_state_dict, max_shard_size=max_shard_size)
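    # shard_checkpoint (from transformers.modeling_utils) splits the state dict into files
    # no larger than max_shard_size and returns (shards, index); index is None when a
    # single shard is enough.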
    # Save the model
    if not os.path.exists(args.save_path):
        os.makedirs(args.save_path)
    for shard_file, shard in shards.items():
        torch.save(shard, os.path.join(args.save_path, shard_file))
    if index is None:
        print(f"Model weights saved in {os.path.join(args.save_path, WEIGHTS_NAME)}")
    else:
        save_index_file = os.path.join(args.save_path, WEIGHTS_INDEX_NAME)
        # Save the index as well
        with open(save_index_file, "w", encoding="utf-8") as f:
            content = json.dumps(index, indent=2, sort_keys=True) + "\n"
            f.write(content)
        print(
            f"The model is bigger than the maximum size per checkpoint ({args.max_shard_size}) and is going to be "
            f"split in {len(shards)} checkpoint shards. You can find where each parameter has been saved in the "
            f"index located at {save_index_file}."
        )
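

# Example invocation (hypothetical flag spellings -- this assumes the script's argparse
# options mirror the attribute names read above; check the parser in this file before copying):
#
#   python checkpoint_reshaping_and_interoperability.py \
#       --load_path /path/to/megatron_falcon_checkpoint/iter_0001000 \
#       --save_path /path/to/falcon_hf_output \
#       --target_params_dtype bf16 \
#       --max_shard_size 10GB \
#       --print_checkpoint_structure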