def init_model_with_materialized_weights()

in picotron/checkpoint.py [0:0]


def _read_checkpoint_tensors(initialization_manager, layer_names, save_dir):
    """Read the tensors named in ``layer_names`` from a safetensors checkpoint.

    Supports both layouts handled by the caller:

    * sharded: ``model.safetensors.index.json`` exists in ``save_dir`` and its
      ``weight_map`` maps each tensor name to the shard file that stores it;
    * single file: everything lives in ``save_dir/model.safetensors``.

    Each tensor is read on CPU, passed through
    ``initialization_manager.adjust_tensor_size`` and stored under the name
    returned by ``initialization_manager.convert_safetensors_to_hf_name``.

    Args:
        initialization_manager: Object providing
            ``convert_safetensors_to_hf_name`` and ``adjust_tensor_size``.
        layer_names: Safetensors-format tensor names to read.
        save_dir: Directory containing the checkpoint files.

    Returns:
        dict mapping converted names to the adjusted tensors.

    Raises:
        KeyError: If a requested name is missing from the index ``weight_map``.
    """
    state_dict = {}

    def _store(handle, sft_name):
        # Same call order as before: convert name, read tensor, adjust size.
        hf_name = initialization_manager.convert_safetensors_to_hf_name(sft_name)
        tensor = handle.get_tensor(sft_name)
        state_dict[hf_name] = initialization_manager.adjust_tensor_size(tensor, hf_name)

    index_path = os.path.join(save_dir, "model.safetensors.index.json")

    if os.path.exists(index_path):  # Handle sharded checkpoint
        with open(index_path, 'r', encoding="utf-8") as index_file:
            weight_map = json.load(index_file)['weight_map']

        # Group the requested tensors by shard so that every shard file is
        # opened exactly once instead of once per tensor.
        names_by_shard = {}
        for sft_name in layer_names:
            names_by_shard.setdefault(weight_map[sft_name], []).append(sft_name)

        for shard_file, sft_names in names_by_shard.items():
            shard_path = os.path.join(save_dir, shard_file)
            with safe_open(shard_path, framework="pytorch", device="cpu") as shard:
                for sft_name in sft_names:
                    _store(shard, sft_name)

    else:  # Handle single file checkpoint
        safetensors_path = os.path.join(save_dir, "model.safetensors")
        with safe_open(safetensors_path, framework="pytorch", device="cpu") as checkpoint:
            num_checkpoint_tensors = len(checkpoint.keys())
            if num_checkpoint_tensors > len(layer_names):
                print(f"rank {pgm.process_group_manager.global_rank}: Warning: Checkpoint has {num_checkpoint_tensors} layers but model only has {len(layer_names)} layers.")

            for sft_name in layer_names:
                _store(checkpoint, sft_name)

    return state_dict


def init_model_with_materialized_weights(model, model_config, save_dir):
    """Materialize ``model``'s parameters from a safetensors checkpoint, then initialize them.

    The tensors this rank is responsible for (as reported by
    ``InitializationManager.get_layer_names_in_sft_format``) are read from
    ``save_dir`` and assigned to the model via
    ``load_state_dict(strict=True, assign=True)``. A zero-filled
    ``final_proj.weight`` entry is added on the last pipeline stage (or when
    the model is not a ``PipelineParallel``). After asserting that no meta
    tensors remain, ``InitializationManager.init_model_parameters`` is called.
    NOTE(review): per the original author's comment the checkpoint is used for
    tensor shapes while the final values come from that initialization step --
    confirm against ``InitializationManager``.

    Args:
        model: Model to materialize; modified in place.
        model_config: Config object; ``vocab_size`` and ``hidden_size`` are read.
        save_dir: Directory holding ``model.safetensors`` or a sharded
            checkpoint with ``model.safetensors.index.json``.

    Returns:
        The same ``model`` object.

    Raises:
        Exception: If this rank has no layers assigned.
        KeyError: If a required tensor is missing from a sharded index.
    """
    initialization_manager = InitializationManager(model, model_config)
    layer_names = initialization_manager.get_layer_names_in_sft_format()

    # NOTE(review): a rank raising here never reaches the barriers below, so
    # other ranks may block on them -- confirm the launcher tears the job down.
    if len(layer_names) == 0:
        raise Exception("Some ranks has no layers. There are too many ranks and not enough layers to distribute.")

    state_dict = _read_checkpoint_tensors(initialization_manager, layer_names, save_dir)

    # Force creation of lm_head (even if it is tie_embedding)
    if pgm.process_group_manager.pp_is_last_stage or not isinstance(model, PipelineParallel):
        vocab_size = model_config.vocab_size
        hidden_size = model_config.hidden_size
        tp_world_size = pgm.process_group_manager.tp_world_size
        if tp_world_size > 1:
            # For TP>1 the original author notes final_proj is already wrapped
            # in a column-parallel layer, so only a placeholder of the
            # per-rank shape (output_size_per_partition, in_features) is needed.
            vocab_per_rank = vocab_size // tp_world_size
            state_dict['final_proj.weight'] = torch.zeros(vocab_per_rank, hidden_size)
        else:
            # For TP=1, create the full layer; weight shape is
            # (out_features, in_features). FinalProjection is used so that
            # .reset_parameters() can be called on it.
            model.final_proj = FinalProjection(hidden_size, vocab_size, bias=False)
            state_dict['final_proj.weight'] = torch.zeros(vocab_size, hidden_size)

    # Synchronize across distributed processes and load weights
    dist.barrier()
    model.load_state_dict(state_dict, strict=True, assign=True)
    dist.barrier()

    assert_no_meta_tensors(model)
    # Initialize model parameters
    initialization_manager.init_model_parameters()
    dist.barrier()
    return model