in src/transformers/models/megatron_gpt2/checkpoint_reshaping_and_interoperability.py [0:0]
def convert_checkpoint_from_megatron_to_transformers(args):
    """
    Convert NVIDIA Megatron-LM checkpoint to HuggingFace Transformers checkpoint. This handles Megatron checkpoints
    with different tensor parallelism and pipeline parallelism sizes. It saves the converted checkpoint into shards
    using HuggingFace Transformers checkpoint sharding functionality. This greatly extends the functionality of
    `convert_megatron_gpt2_checkpoint.py`

    Args:
        args (argparse.Namespace): the arguments to the script. Attributes read here: `load_path`, `save_path`,
            `tokenizer_name`, `max_shard_size` and `print_checkpoint_structure`.

    Raises:
        FileNotFoundError: if `args.load_path` contains neither an `mp_rank_00` nor an `mp_rank_00_000` directory.
        ValueError: if the rank-0 checkpoint stores no Megatron `args`, or if the number of converted transformer
            layers differs from the number of layers in the Megatron arguments.
    """
    # Load Megatron-LM checkpoint arguments from the state dict
    sub_dirs = os.listdir(args.load_path)
    possible_sub_dirs = ["mp_rank_00", "mp_rank_00_000"]
    for sub_dir in possible_sub_dirs:
        if sub_dir in sub_dirs:
            rank0_checkpoint_name = os.listdir(os.path.join(args.load_path, sub_dir))[0]
            rank0_checkpoint_path = os.path.join(args.load_path, sub_dir, rank0_checkpoint_name)
            break
    else:
        # Without this, `rank0_checkpoint_path` would be unbound below and the script would fail with a NameError.
        raise FileNotFoundError(
            f"Could not find a rank-0 checkpoint directory ({' or '.join(possible_sub_dirs)}) in {args.load_path}"
        )
    print(f"Loading Megatron-LM checkpoint arguments from: {rank0_checkpoint_path}")
    check_torch_load_is_safe()
    state_dict = torch.load(rank0_checkpoint_path, map_location="cpu", weights_only=True)
    megatron_args = state_dict.get("args", None)
    if megatron_args is None:
        raise ValueError(
            "Megatron-LM checkpoint does not contain arguments. This utility only supports Megatron-LM checkpoints"
            " containing all the megatron arguments. This is because it loads all config related to model"
            " architecture, the tensor and pipeline model parallel size from the checkpoint instead of user having to"
            " manually specify all the details. Please save Megatron-LM checkpoint along with all the megatron"
            " arguments to use this utility."
        )

    # Create Transformers GPT2 config from Megatron-LM arguments.
    # `megatron_args` is guaranteed to be set at this point (the None case raised above).
    if megatron_args.bias_gelu_fusion:
        activation_function = "gelu_fast"
    elif megatron_args.openai_gelu:
        activation_function = "gelu_new"
    else:
        activation_function = "gelu"
    # Prefer the original (unpadded) vocabulary size when the checkpoint recorded it.
    vocab_size = (
        megatron_args.padded_vocab_size
        if getattr(megatron_args, "orig_vocab_size", None) is None
        else megatron_args.orig_vocab_size
    )
    print(f"Vocabulary size: {vocab_size}")
    config = GPT2Config(
        vocab_size=vocab_size,
        n_positions=megatron_args.max_position_embeddings,
        n_embd=megatron_args.hidden_size,
        n_layer=megatron_args.num_layers,
        n_head=megatron_args.num_attention_heads,
        n_inner=megatron_args.ffn_hidden_size,
        activation_function=activation_function,
        resid_pdrop=0.1,
        embd_pdrop=0.1,
        attn_pdrop=0.1,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        summary_type="cls_index",
        summary_use_proj=True,
        summary_activation=None,
        summary_proj_to_labels=True,
        summary_first_dropout=0.1,
        scale_attn_weights=True,
        use_cache=True,
        bos_token_id=vocab_size - 1,
        eos_token_id=vocab_size - 1,
        architectures=["GPT2LMHeadModel"],
    )

    output_state_dict = {}

    checkpoint_version = state_dict.get("checkpoint_version", 0.0)
    tp_size = megatron_args.tensor_model_parallel_size
    pp_size = megatron_args.pipeline_model_parallel_size
    dtype = torch.float32
    # The regex to extract layer names: captures (layer index, operation name, "weight"/"bias").
    layer_re = re.compile(r"layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)")

    # Convert.
    print("Converting")

    # Embeddings
    print("Converting embeddings")
    tp_state_dicts = get_megatron_sharded_states(args, tp_size, pp_size, 0)

    # Convert and store the position embeddings (taken from tensor-parallel rank 0 only).
    position_embeddings = get_element_from_dict_by_path(
        tp_state_dicts[0], "model.language_model.embedding.position_embeddings.weight"
    )
    output_state_dict["transformer.wpe.weight"] = position_embeddings.to(dtype)

    # Convert and store the word embeddings: concatenate the vocabulary shards of every tensor-parallel rank
    # along dim 0, then drop any padding rows beyond `vocab_size`.
    word_embeddings = torch.cat(
        [
            get_element_from_dict_by_path(
                tp_state_dicts[tp_rank], "model.language_model.embedding.word_embeddings.weight"
            )
            for tp_rank in range(tp_size)
        ],
        dim=0,
    )
    word_embeddings = word_embeddings[:vocab_size].to(dtype)
    output_state_dict["transformer.wte.weight"] = word_embeddings

    # Transformer Layers
    print("Converting transformer layers")
    # The number of heads.
    heads = config.n_head
    # The hidden_size per head.
    hidden_size_per_head = config.n_embd // config.n_head
    n_positions = config.n_positions
    # Layers per pipeline stage; used to turn stage-local layer indices into global ones.
    num_layers = config.num_hidden_layers // pp_size
    # Stays -1 if no layer keys are found, so the layer-count check below raises its ValueError
    # instead of failing on an unbound name.
    layer_idx = -1

    for pp_rank in range(pp_size):
        if pp_size > 0:
            print(f"Converting pipeline parallel rank {pp_rank}")
            tp_state_dicts = get_megatron_sharded_states(args, tp_size, pp_size, pp_rank)

        # The transformer. Some checkpoints name it "transformer", others "encoder".
        path = (
            "model.language_model.transformer"
            if "transformer" in get_element_from_dict_by_path(tp_state_dicts[0], "model.language_model").keys()
            else "model.language_model.encoder"
        )
        # Extract the layers.
        for key, val in get_element_from_dict_by_path(tp_state_dicts[0], path).items():
            # Match the name.
            m = layer_re.match(key)
            # Stop at the first key that is not a layer (layer keys are expected to come first).
            if m is None:
                break

            # The index of the layer, offset by the layers held by earlier pipeline stages.
            layer_idx = int(m.group(1)) + pp_rank * num_layers
            # The name of the operation.
            op_name = m.group(2)
            # Is it a weight or a bias?
            weight_or_bias = m.group(3)

            # The name of the layer.
            layer_name = f"transformer.h.{layer_idx}"

            if op_name + "." + weight_or_bias not in tensor_parallel_params:
                params = val.to(dtype)
            else:
                # Row-parallel layers are split along dim 1; column-parallel ones along dim 0.
                dim = 1 if op_name in ["self_attention.dense", "mlp.dense_4h_to_h", "attention.dense"] else 0
                params = torch.cat(
                    [val]
                    + [
                        get_element_from_dict_by_path(tp_state_dicts[tp_rank], f"{path}")[key]
                        for tp_rank in range(1, tp_size)
                    ],
                    dim=dim,
                ).to(dtype)

            # For layernorm(s), simply store the layer norm.
            if op_name.endswith("layernorm"):
                ln_name = "ln_1" if op_name.startswith("input") else "ln_2"
                output_state_dict[layer_name + "." + ln_name + "." + weight_or_bias] = params

            # Transpose the QKV matrix.
            elif (
                op_name == "attention.query_key_value" or op_name == "self_attention.query_key_value"
            ) and weight_or_bias == "weight":
                # Insert the lower-triangular causal mask of shape 1 x 1 x n_positions x n_positions as `attn.bias`.
                causal_mask = torch.tril(torch.ones((n_positions, n_positions), dtype=dtype)).view(
                    1, 1, n_positions, n_positions
                )
                output_state_dict[layer_name + ".attn.bias"] = causal_mask

                # Insert a "dummy" tensor for masked_bias.
                masked_bias = torch.tensor(-1e4, dtype=dtype)
                output_state_dict[layer_name + ".attn.masked_bias"] = masked_bias

                out_val = megatron_to_transformers_fix_query_key_value_ordering(
                    params,
                    checkpoint_version,
                    3,
                    heads,
                    hidden_size_per_head,
                )
                # Megatron stores (3*D) x D but transformers-GPT2 expects D x 3*D.
                out_val = out_val.transpose(0, 1).contiguous()
                # Store.
                output_state_dict[layer_name + ".attn.c_attn.weight"] = out_val

            # Reorder the QKV bias to the transformers layout (no transpose needed for a vector).
            elif (
                op_name == "attention.query_key_value" or op_name == "self_attention.query_key_value"
            ) and weight_or_bias == "bias":
                out_val = megatron_to_transformers_fix_query_key_value_ordering(
                    params, checkpoint_version, 3, heads, hidden_size_per_head
                )
                # Store. No change of shape.
                output_state_dict[layer_name + ".attn.c_attn.bias"] = out_val

            # Transpose the weights (GPT-2 uses Conv1D-style in x out weights).
            elif weight_or_bias == "weight":
                out_name = megatron_to_transformers[op_name]
                output_state_dict[layer_name + out_name + "weight"] = params.transpose(0, 1)

            # Copy the bias.
            elif weight_or_bias == "bias":
                out_name = megatron_to_transformers[op_name]
                output_state_dict[layer_name + out_name + "bias"] = params

    if config.n_layer != (layer_idx + 1):
        raise ValueError(f"Expected {config.n_layer} layers but found {layer_idx + 1}")

    # The final layernorm (read from the last pipeline stage's rank-0 shard).
    print("Converting final layernorm")
    params = get_element_from_dict_by_path(tp_state_dicts[0], str(path))
    output_state_dict["transformer.ln_f.weight"] = params["final_layernorm.weight"].to(dtype)
    output_state_dict["transformer.ln_f.bias"] = params["final_layernorm.bias"].to(dtype)

    # The LM head reuses the (unpadded) word-embedding matrix.
    print("Converting LM head")
    output_state_dict["lm_head.weight"] = word_embeddings.to(dtype)

    # It should be done!
    print("Conversion from Megatron-LM to Transformers is done!")

    # Print the structure of converted state dict.
    if args.print_checkpoint_structure:
        recursive_print(None, output_state_dict)

    # Add tokenizer class info to config
    # see https://github.com/huggingface/transformers/issues/13906)
    if args.tokenizer_name is None:
        tokenizer_name = "openai-community/gpt2"
    else:
        tokenizer_name = args.tokenizer_name

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    tokenizer_class = type(tokenizer).__name__
    config.tokenizer_class = tokenizer_class

    # Store the config to file.
    print("Saving config")
    config.save_pretrained(args.save_path)

    # Save tokenizer based on args
    if args.tokenizer_name is not None:
        print(f"Adding {tokenizer_class} tokenizer files")
        tokenizer.save_pretrained(args.save_path)

    # Store the state_dict to file.
    max_shard_size = int(args.max_shard_size) if args.max_shard_size.isdigit() else args.max_shard_size
    # The shards are written with `torch.save`, not safetensors. Derive the file-name pattern from WEIGHTS_NAME
    # so the single-file case is saved under exactly the name reported below, and sharded files follow the same
    # base name. (Without this, the default "model{suffix}.safetensors" pattern would be used.)
    state_dict_split = split_torch_state_dict_into_shards(
        output_state_dict,
        filename_pattern=WEIGHTS_NAME.replace(".bin", "{suffix}.bin"),
        max_shard_size=max_shard_size,
    )
    # Map every shard file name to the converted tensors it must contain.
    shards = {
        filename: {tensor_name: output_state_dict[tensor_name] for tensor_name in tensor_names}
        for filename, tensor_names in state_dict_split.filename_to_tensors.items()
    }
    index = None
    if state_dict_split.is_sharded:
        index = {
            "metadata": state_dict_split.metadata,
            "weight_map": state_dict_split.tensor_to_filename,
        }

    # Save the model
    for shard_file, shard in shards.items():
        torch.save(shard, os.path.join(args.save_path, shard_file))

    if index is None:
        print(f"Model weights saved in {os.path.join(args.save_path, WEIGHTS_NAME)}")
    else:
        save_index_file = os.path.join(args.save_path, WEIGHTS_INDEX_NAME)
        # Save the index as well
        with open(save_index_file, "w", encoding="utf-8") as f:
            content = json.dumps(index, indent=2, sort_keys=True) + "\n"
            f.write(content)
        print(
            f"The model is bigger than the maximum size per checkpoint ({args.max_shard_size}) and is going to be "
            f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the "
            f"index located at {save_index_file}."
        )