def convert_checkpoint_from_transformers_to_megatron()

in toolkits/model_checkpoints_convertor/bloom/checkpoint_reshaping_and_interoperability.py [0:0]


def convert_checkpoint_from_transformers_to_megatron(args):
    """
    Convert a checkpoint from HuggingFace Transformers to Megatron-LM. The converted checkpoint can use arbitrary
    tensor model parallel and pipeline model parallel sizes. The input HuggingFace Transformers checkpoint may
    consist of multiple shards.
    Args:
        args (argparse.Namespace): the arguments to the script
    """
    os.makedirs(args.save_path, exist_ok=True)
    # Search in the directory above this one as well
    sys.path.append(
        os.path.abspath(os.path.join(os.path.dirname(__file__),
                                     os.path.pardir)))
    if args.megatron_path is not None:
        sys.path.insert(0, args.megatron_path)

    try:
        from megatron.tokenizer.tokenizer import _vocab_size_with_padding
    except ModuleNotFoundError:
        print(
            'Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.'
        )
        exit(1)

    # load the transformers model state dict and config
    sub_dirs = [
        x for x in os.listdir(args.load_path) if x.startswith('pytorch_model')
    ]
    if len(sub_dirs) == 1:
        checkpoint_name = 'pytorch_model.bin'
        state_dict = torch.load(os.path.join(args.load_path, checkpoint_name),
                                map_location='cpu')
    else:
        # One of the matches is the index file (pytorch_model.bin.index.json),
        # so the actual number of shards is len(sub_dirs) - 1.
        num_checkpoints = len(sub_dirs) - 1
        state_dict = merge_transformers_sharded_states(args.load_path,
                                                       num_checkpoints)

    config = BloomConfig.from_pretrained(args.load_path)

    # Saving config and tokenizer files
    os.system(f'cp -rf {args.load_path}/*.json {args.save_path}')

    # Saving the tracker file
    tracker_filepath = os.path.join(args.save_path,
                                    'latest_checkpointed_iteration.txt')
    with open(tracker_filepath, 'w') as f:
        f.write('release')

    # Create the `release` dir in args.save_path
    release_dir = os.path.join(args.save_path, 'release')
    os.makedirs(release_dir, exist_ok=True)

    for k in list(state_dict.keys()):
        if k.replace('transformer.', '') != k:
            state_dict[k.replace('transformer.', '')] = state_dict[k]
            state_dict.pop(k)

    # megatron args
    megatron_args = {
        'unk_token_id': config.unk_token_id,
        'pad_token_id': config.pad_token_id,
        'bos_token_id': config.bos_token_id,
        'eos_token_id': config.eos_token_id,
        'orig_vocab_size': config.vocab_size,
        'hidden_size': config.hidden_size,
        'num_layers': config.n_layer,
        'num_attention_heads': config.n_head,
        'max_position_embeddings': 1024,
        'ffn_hidden_size': config.hidden_size * 4,
        'tensor_model_parallel_size': args.target_tensor_model_parallel_size,
        'pipeline_model_parallel_size': args.target_pipeline_model_parallel_size,
        'data_parallel_size': args.target_data_parallel_size,
        'make_vocab_size_divisible_by': args.make_vocab_size_divisible_by,
        'rank': 0,
        'tokenizer_type': 'BloomTokenizer',
    }

    margs = types.SimpleNamespace()
    for k, v in megatron_args.items():
        setattr(margs, k, v)
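
    # Note: `margs` stands in for Megatron's argument namespace; it is read by
    # _vocab_size_with_padding() below and stored verbatim under the 'args'
    # key of every saved checkpoint shard.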

    # params dtype
    if args.target_params_dtype == 'fp16':
        dtype = torch.float16
    elif args.target_params_dtype == 'bf16':
        dtype = torch.bfloat16
    else:
        dtype = torch.float32
    setattr(margs, 'params_dtype', dtype)

    # Convert.
    print('Converting')
    output_state_dict = []
    for i in range(args.target_tensor_model_parallel_size):
        output_state_dict.append({})

    # Embedding layer
    print('converting embedding layer')
    word_embedding = state_dict['word_embeddings.weight'].to(dtype)
    word_embedding_layernorm_weight = state_dict[
        'word_embeddings_layernorm.weight'].to(dtype)
    word_embedding_layernorm_bias = state_dict[
        'word_embeddings_layernorm.bias'].to(dtype)
    orig_vocab_size = config.vocab_size
    padded_vocab_size = _vocab_size_with_padding(orig_vocab_size, margs)
    setattr(margs, 'padded_vocab_size', padded_vocab_size)
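    # In Megatron, _vocab_size_with_padding rounds the vocab up to a multiple
    # of make_vocab_size_divisible_by * tensor_model_parallel_size. For
    # example, an original vocab of 250680 with divisor 128 and TP size 8
    # (multiple of 1024) is padded to 250880 = 245 * 1024.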

    # Cut out extra padding we don't need
    if orig_vocab_size > padded_vocab_size:
        full_word_embed = word_embedding[:padded_vocab_size, :]
        full_word_embed_layernorm_weight = word_embedding_layernorm_weight
        full_word_embed_layernorm_bias = word_embedding_layernorm_bias

    # Expanding embedding to larger size by replicating final entry
    elif orig_vocab_size < padded_vocab_size:
        padding_size = padded_vocab_size - orig_vocab_size
        full_word_embed = torch.cat(
            (word_embedding,
             word_embedding[-1].unsqueeze(0).expand(padding_size, -1)))

        full_word_embed_layernorm_weight = word_embedding_layernorm_weight
        full_word_embed_layernorm_bias = word_embedding_layernorm_bias

    # Same size!
    else:
        full_word_embed = word_embedding
        full_word_embed_layernorm_weight = word_embedding_layernorm_weight
        full_word_embed_layernorm_bias = word_embedding_layernorm_bias

    config.vocab_size = full_word_embed.shape[0]
    print(f'New vocab size: {config.vocab_size}')
    # Split into new tensor model parallel sizes
    out_word_embed = torch.chunk(full_word_embed,
                                 args.target_tensor_model_parallel_size,
                                 dim=0)
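    # full_word_embed has shape [padded_vocab_size, hidden_size]; chunking it
    # along dim 0 gives each tensor parallel rank a
    # [padded_vocab_size // TP, hidden_size] shard.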

    for i in range(args.target_tensor_model_parallel_size):
        word_emb_dict = get_element_from_dict_by_path(
            output_state_dict[i], 'model.language_model.embedding')
        word_emb_dict['word_embeddings.weight'] = out_word_embed[i]
        word_emb_dict[
            'word_embeddings.norm.weight'] = full_word_embed_layernorm_weight
        word_emb_dict[
            'word_embeddings.norm.bias'] = full_word_embed_layernorm_bias

    # Transformer layers
    print('converting transformer layers')
    if config.n_layer % args.target_pipeline_model_parallel_size != 0:
        raise ValueError(
            f'Number of layers ({config.n_layer}) must be divisible by the pipeline model parallel size'
            f' ({args.target_pipeline_model_parallel_size})')
    num_layers = config.n_layer // args.target_pipeline_model_parallel_size
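    # Layers are assigned to pipeline stages contiguously: e.g. 30 layers with
    # a target pipeline parallel size of 2 gives 15 layers per stage, so stage
    # 0 converts h.0-h.14 and stage 1 converts h.15-h.29.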

    layer_re = re.compile(r'h\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)')
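    # Example: layer_re splits 'h.3.self_attention.query_key_value.weight'
    # into ('3', 'self_attention.query_key_value', 'weight').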
    # The number of heads.
    heads = config.n_head
    # The hidden_size per head.
    hidden_size_per_head = config.hidden_size // config.n_head
    for pp_rank in range(args.target_pipeline_model_parallel_size):
        print(f'Converting pipeline parallel rank {pp_rank}')
        layer_offset = pp_rank * num_layers
        if pp_rank > 0:
            output_state_dict = []
            for i in range(args.target_tensor_model_parallel_size):
                output_state_dict.append({})

        for layer in range(num_layers):
            pp_layer_id = layer + layer_offset
            layers_to_copy = [
                layer_name for layer_name in state_dict.keys()
                if layer_name.startswith(f'h.{pp_layer_id}.')
            ]

            for layer_name in layers_to_copy:
                m = layer_re.match(layer_name)
                # Stop if that's not a layer
                if m is None:
                    break

                # The index of the layer.
                _ = int(m.group(1))
                # The name of the operation.
                op_name = m.group(2)
                # Is it a weight or a bias?
                weight_or_bias = m.group(3)
                params = state_dict[layer_name].to(dtype)

                # handle layernorm
                if op_name.startswith('input_layernorm') or op_name.startswith(
                        'post_attention_layernorm'):
                    out_name = 'input_layernorm' if op_name.endswith(
                        'input_layernorm') else 'post_attention_layernorm'
                    layer_name = f'layers.{layer}.{out_name}.{weight_or_bias}'

                # handle attention Q, K, V weights and biases
                elif op_name.startswith('self_attention.query_key_value') \
                        and weight_or_bias in ('weight', 'bias'):
                    layer_name = f'layers.{layer}.self_attention.query_key_value.{weight_or_bias}'
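                # Note (not derived from this function): HF BLOOM stores the
                # fused QKV weight with its output dimension laid out as
                # [num_heads, 3, head_dim], so the plain chunk along dim 0
                # applied below keeps whole heads on each tensor parallel rank
                # and no GPT-2-style reordering is needed.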

                # handle attention and mlp weights
                elif weight_or_bias == 'weight':
                    out_name = transformers_to_megatron.get(op_name, None)
                    if out_name is None:
                        continue
                    layer_name = f'layers.{layer}.{out_name}.{weight_or_bias}'

                # handle attention and mlp bias
                elif weight_or_bias == 'bias':
                    out_name = transformers_to_megatron.get(op_name, None)
                    if out_name is None:
                        continue
                    layer_name = f'layers.{layer}.{out_name}.{weight_or_bias}'

                # skip
                else:
                    continue

                if op_name + '.' + weight_or_bias in tensor_parallel_params:
                    dim = 1 if op_name in [
                        'self_attention.dense', 'mlp.dense_4h_to_h'
                    ] else 0
                    params = torch.chunk(
                        params,
                        args.target_tensor_model_parallel_size,
                        dim=dim)
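                # dim=1 splits the input dimension of the row-parallel layers
                # (self_attention.dense, mlp.dense_4h_to_h); the remaining
                # entries of tensor_parallel_params (e.g. query_key_value,
                # mlp.dense_h_to_4h) are column-parallel and are split along
                # their output dimension, dim=0.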

                for i in range(args.target_tensor_model_parallel_size):
                    params_dict = get_element_from_dict_by_path(
                        output_state_dict[i], 'model.language_model.encoder')
                    params_dict[layer_name] = (params[i] if (
                        op_name + '.' +
                        weight_or_bias in tensor_parallel_params) else params)

        if pp_rank == args.target_pipeline_model_parallel_size - 1:
            # handle final layernorm
            for weight_or_bias in ['weight', 'bias']:
                params = state_dict[f'ln_f.{weight_or_bias}'].to(dtype)
                layer_name = f'final_layernorm.{weight_or_bias}'
                for i in range(args.target_tensor_model_parallel_size):
                    params_dict = get_element_from_dict_by_path(
                        output_state_dict[i], 'model.language_model.encoder')
                    params_dict[layer_name] = params

            # add the LM head
            for i in range(args.target_tensor_model_parallel_size):
                params_dict = get_element_from_dict_by_path(
                    output_state_dict[i], 'model.word_embeddings_for_head')
                params_dict['weight'] = out_word_embed[i]

        # Save the state dict shards for this pp_rank, one per tp_rank
        for tp_rank in range(args.target_tensor_model_parallel_size):
            output_state_dict[tp_rank]['checkpoint_version'] = 3.0
            output_state_dict[tp_rank]['args'] = margs
            checkpoint_dir = (f'mp_rank_{tp_rank:02d}'
                              if args.target_pipeline_model_parallel_size == 1
                              else f'mp_rank_{tp_rank:02d}_{pp_rank:03d}')

            checkpoint_name = 'model_optim_rng.pt'
            checkpoint_dir = os.path.join(release_dir, checkpoint_dir)
            os.makedirs(checkpoint_dir, exist_ok=True)
            checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)
            if args.print_checkpoint_structure:
                print(
                    f'Checkpoint structure of model state dict shard belonging to TP rank {tp_rank} and PP rank'
                    f' {pp_rank}:')
                recursive_print(None, output_state_dict[tp_rank])
            torch.save(output_state_dict[tp_rank], checkpoint_path)
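

# Resulting layout under args.save_path (illustrative, for TP=2 and PP=1):
#   latest_checkpointed_iteration.txt      # contains the string 'release'
#   *.json                                 # config/tokenizer files copied above
#   release/
#       mp_rank_00/model_optim_rng.pt
#       mp_rank_01/model_optim_rng.pt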