in toolkits/model_checkpoints_convertor/starcoder/checkpoint_reshaping_and_interoperability.py [0:0]
def convert_checkpoint_from_transformers_to_megatron(args):
    """
    Convert a checkpoint from HuggingFace Transformers to Megatron-LM.

    The HuggingFace checkpoint in ``args.load_path`` (a single
    ``pytorch_model.bin`` or several shards) is re-split into
    ``args.target_tensor_model_parallel_size`` tensor-parallel shards and
    ``args.target_pipeline_model_parallel_size`` pipeline stages. The result is
    written under ``<args.save_path>/release/mp_rank_XX[_YYY]/model_optim_rng.pt``
    together with a ``latest_checkpointed_iteration.txt`` tracker file and a
    copy of the ``*.json`` files found in ``args.load_path``.

    Args:
        args (argparse.Namespace): the arguments to the script. Fields read
            here: ``load_path``, ``save_path``, ``megatron_path``,
            ``target_tensor_model_parallel_size``,
            ``target_pipeline_model_parallel_size``,
            ``target_data_parallel_size``, ``target_params_dtype``,
            ``make_vocab_size_divisible_by`` and ``print_checkpoint_structure``.

    Raises:
        ValueError: if the number of layers is not divisible by the target
            pipeline parallel size.
        SystemExit: if Megatron cannot be imported.
    """
    # Function-scope imports: only needed for copying the config/tokenizer files.
    import glob
    import shutil

    os.makedirs(args.save_path, exist_ok=True)
    # Search in directory above this
    sys.path.append(
        os.path.abspath(
            os.path.join(os.path.dirname(__file__), os.path.pardir)))
    if args.megatron_path is not None:
        sys.path.insert(0, args.megatron_path)
    try:
        from megatron.tokenizer.tokenizer import _vocab_size_with_padding
    except ModuleNotFoundError:
        print(
            'Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.'
        )
        # sys.exit rather than the interactive-only `exit` helper from `site`.
        sys.exit(1)

    tp_size = args.target_tensor_model_parallel_size
    pp_size = args.target_pipeline_model_parallel_size

    # load the transformers model state dict and config
    sub_dirs = [
        x for x in os.listdir(args.load_path) if x.startswith('pytorch_model')
    ]
    if len(sub_dirs) == 1:
        checkpoint_name = 'pytorch_model.bin'
        state_dict = torch.load(os.path.join(args.load_path, checkpoint_name),
                                map_location='cpu')
    else:
        # One matching entry is not counted as a shard — presumably the shard
        # index file; TODO confirm against merge_transformers_sharded_states.
        num_checkpoints = len(sub_dirs) - 1
        state_dict = merge_transformers_sharded_states(args.load_path,
                                                       num_checkpoints)
    config = GPTBigCodeConfig.from_pretrained(args.load_path)

    # Saving config and tokenizer files. Done without a shell so that paths
    # containing spaces or shell metacharacters are handled correctly.
    json_pattern = os.path.join(glob.escape(args.load_path), '*.json')
    for json_path in glob.glob(json_pattern):
        if not os.path.isfile(json_path):
            continue
        try:
            shutil.copy(json_path, args.save_path)
        except shutil.SameFileError:
            # load_path and save_path are the same directory: nothing to copy.
            pass

    # Saving the tracker file
    tracker_filepath = os.path.join(args.save_path,
                                    'latest_checkpointed_iteration.txt')
    with open(tracker_filepath, 'w') as f:
        f.write('release')

    # create `release` dir in args.save_path
    release_dir = os.path.join(args.save_path, 'release')
    os.makedirs(release_dir, exist_ok=True)

    # Drop every 'transformer.' occurrence from the parameter names.
    for k in list(state_dict.keys()):
        stripped = k.replace('transformer.', '')
        if stripped != k:
            state_dict[stripped] = state_dict.pop(k)

    # megatron args
    megatron_args = {
        'pad_token_id': config.pad_token_id,
        'bos_token_id': config.bos_token_id,
        'eos_token_id': config.eos_token_id,
        'orig_vocab_size': config.vocab_size,
        'hidden_size': config.hidden_size,
        'num_layers': config.n_layer,
        'num_attention_heads': config.n_head,
        'max_position_embeddings': config.n_positions,
        'ffn_hidden_size': config.n_inner,
        'tensor_model_parallel_size': tp_size,
        'pipeline_model_parallel_size': pp_size,
        'data_parallel_size': args.target_data_parallel_size,
        'make_vocab_size_divisible_by': args.make_vocab_size_divisible_by,
        'rank': 0,
        'tokenizer_type': 'StarcoderTokenizer',
    }
    margs = types.SimpleNamespace(**megatron_args)

    # params dtype (anything other than 'fp16'/'bf16' falls back to fp32)
    if args.target_params_dtype == 'fp16':
        dtype = torch.float16
    elif args.target_params_dtype == 'bf16':
        dtype = torch.bfloat16
    else:
        dtype = torch.float32
    margs.params_dtype = dtype

    # Convert.
    print('Converting')
    output_state_dict = [{} for _ in range(tp_size)]

    # Embedding layer
    print('converting embedding layer')
    word_embedding = state_dict['wte.weight'].to(dtype)
    word_embedding_position = state_dict['wpe.weight'].to(dtype)
    orig_vocab_size = config.vocab_size
    padded_vocab_size = _vocab_size_with_padding(orig_vocab_size, margs)
    margs.padded_vocab_size = padded_vocab_size
    if orig_vocab_size > padded_vocab_size:
        # Cut out extra padding we don't need
        full_word_embed = word_embedding[:padded_vocab_size, :]
    elif orig_vocab_size < padded_vocab_size:
        # Expanding embedding to larger size by replicating final entry
        padding_size = padded_vocab_size - orig_vocab_size
        full_word_embed = torch.cat(
            (word_embedding,
             word_embedding[-1].unsqueeze(0).expand(padding_size, -1)))
    else:
        # Same size!
        full_word_embed = word_embedding
    config.vocab_size = full_word_embed.shape[0]
    print(f'New vocab size: {config.vocab_size}')

    # Split into new tensor model parallel sizes
    out_word_embed = torch.chunk(full_word_embed, tp_size, dim=0)
    for i in range(tp_size):
        word_emb_dict = get_element_from_dict_by_path(
            output_state_dict[i], 'model.language_model.embedding')
        # torch.chunk returns views; torch.save would serialize the whole
        # shared storage for each of them, so store an independent copy.
        word_emb_dict['word_embeddings.weight'] = out_word_embed[i].clone()
        word_emb_dict['position_embeddings.weight'] = word_embedding_position

    # Transformer layers
    print('converting transformer layers')
    if config.n_layer % pp_size != 0:
        raise ValueError(
            f'Number of layers ({config.n_layer}) must be divisible by number of pipeline parallelism'
            f' ({args.target_pipeline_model_parallel_size})')
    num_layers = config.n_layer // pp_size
    # Raw string: '\.' and '\d' are regex escapes, not string escapes.
    layer_re = re.compile(r'h\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)')

    for pp_rank in range(pp_size):
        print(pp_rank)
        layer_offset = pp_rank * num_layers
        if pp_rank > 0:
            # Stage 0 reuses the dicts that already hold the embeddings;
            # every later stage starts from empty dicts.
            output_state_dict = [{} for _ in range(tp_size)]

        for layer in range(num_layers):
            pp_layer_id = layer + layer_offset
            layer_prefix = f'h.{pp_layer_id}.'
            layers_to_copy = [
                name for name in state_dict if name.startswith(layer_prefix)
            ]
            for hf_name in layers_to_copy:
                m = layer_re.match(hf_name)
                # Stop if that's not a layer
                # NOTE(review): `break` also skips the remaining parameters of
                # this layer; kept as in the original — confirm it is intended.
                if m is None:
                    break
                # The name of the operation.
                op_name = m.group(2)
                # Is it a weight or a bias?
                weight_or_bias = m.group(3)

                # Resolve the Megatron-side parameter name.
                if op_name.startswith('ln_'):
                    # handle layernorm
                    out_name = ('input_layernorm' if op_name == 'ln_1' else
                                'post_attention_layernorm')
                    megatron_name = f'layers.{layer}.{out_name}.{weight_or_bias}'
                elif weight_or_bias not in ('weight', 'bias'):
                    # skip
                    continue
                elif op_name.startswith('attn.c_attn'):
                    # handle attention K, V, Q weights and biases; the
                    # key/value part is stored separately further down.
                    megatron_name = f'layers.{layer}.self_attention.query.{weight_or_bias}'
                else:
                    # handle attention and mlp weights and biases
                    out_name = transformers_to_megatron.get(op_name, None)
                    if out_name is None:
                        continue
                    megatron_name = f'layers.{layer}.{out_name}.{weight_or_bias}'

                params = state_dict[hf_name].to(dtype)
                has_c_attn = 'c_attn' in op_name
                is_tp_param = (f'{op_name}.{weight_or_bias}'
                               in tensor_parallel_params_hf)

                if is_tp_param:
                    # The two c_proj ops are split along dim 1, the rest
                    # along dim 0.
                    dim = 1 if op_name in ('attn.c_proj', 'mlp.c_proj') else 0
                    if not has_c_attn:
                        params = torch.chunk(params, tp_size, dim=dim)
                    else:
                        # The leading rows are chunked across TP ranks (stored
                        # under the `query` name); the remaining rows are kept
                        # whole and replicated on every rank (`key_value`).
                        if weight_or_bias == 'weight':
                            query_rows = params.shape[1]
                        else:
                            query_rows = config.hidden_size
                        params = (torch.chunk(params[:query_rows],
                                              tp_size,
                                              dim=dim),
                                  params[query_rows:].clone())

                for i in range(tp_size):
                    params_dict = get_element_from_dict_by_path(
                        output_state_dict[i], 'model.language_model.encoder')
                    if not has_c_attn:
                        # Clone chunk views so each shard file only contains
                        # its own slice instead of the full shared storage.
                        params_dict[megatron_name] = (params[i].clone()
                                                      if is_tp_param else
                                                      params)
                    else:
                        # NOTE(review): if c_attn were not listed in
                        # tensor_parallel_params_hf, `params` would still be
                        # the unsplit tensor and params[0]/params[1] would
                        # index rows. Presumably it is always listed — confirm.
                        params_dict[megatron_name] = (params[0][i].clone()
                                                      if is_tp_param else
                                                      params[0])
                        params_dict[megatron_name.replace(
                            'query', 'key_value')] = params[1]

        if pp_rank == pp_size - 1:
            # handle final layernorm (only the last pipeline stage stores it)
            for weight_or_bias in ['weight', 'bias']:
                params = state_dict[f'ln_f.{weight_or_bias}'].to(dtype)
                final_name = f'final_layernorm.{weight_or_bias}'
                for i in range(tp_size):
                    params_dict = get_element_from_dict_by_path(
                        output_state_dict[i], 'model.language_model.encoder')
                    params_dict[final_name] = params

        # saving the state dict as per the tp_rank and pp_rank
        for tp_rank in range(tp_size):
            output_state_dict[tp_rank]['checkpoint_version'] = 3.0
            output_state_dict[tp_rank]['args'] = margs
            checkpoint_dir = (f'mp_rank_{tp_rank:02d}' if pp_size == 1 else
                              f'mp_rank_{tp_rank:02d}_{pp_rank:03d}')
            checkpoint_name = 'model_optim_rng.pt'
            checkpoint_dir = os.path.join(release_dir, checkpoint_dir)
            os.makedirs(checkpoint_dir, exist_ok=True)
            checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)
            if args.print_checkpoint_structure:
                print(
                    f'Checkpoint structure of model state dict shard belonging to TP rank {tp_rank} and PP rank'
                    f' {pp_rank}:')
                recursive_print(None, output_state_dict[tp_rank])
            torch.save(output_state_dict[tp_rank], checkpoint_path)