in picotron/checkpoint.py [0:0]
def init_model_with_materialized_weights(model, model_config, save_dir):
    """Initialize model with correct tensor shapes but random weights."""
    initialization_manager = InitializationManager(model, model_config)
    layer_names = initialization_manager.get_layer_names_in_sft_format()
    # print(f"Rank {pgm.process_group_manager.global_rank} responsible for {len(layer_names)} layers")

    if len(layer_names) == 0:
        raise Exception("Some ranks have no layers: there are too many ranks and not enough layers to distribute.")

    state_dict = {}
    index_path = os.path.join(save_dir, "model.safetensors.index.json")

    if os.path.exists(index_path):  # Handle sharded checkpoint
        with open(index_path, 'r') as f:
            index = json.load(f)

        for sft_name in layer_names:
            shard_path = os.path.join(save_dir, index['weight_map'][sft_name])
            with safe_open(shard_path, framework="pytorch", device="cpu") as f:
                hf_name = initialization_manager.convert_safetensors_to_hf_name(sft_name)
                tensor = f.get_tensor(sft_name)
                tensor = initialization_manager.adjust_tensor_size(tensor, hf_name)
                state_dict[hf_name] = tensor
    else:  # Handle single-file checkpoint
        safetensors_path = os.path.join(save_dir, "model.safetensors")
        with safe_open(safetensors_path, framework="pytorch", device="cpu") as f:
            if len(f.keys()) > len(layer_names):
                print(f"rank {pgm.process_group_manager.global_rank}: Warning: Checkpoint has {len(f.keys())} layers but model only has {len(layer_names)} layers.")

            for sft_name in layer_names:
                hf_name = initialization_manager.convert_safetensors_to_hf_name(sft_name)
                tensor = f.get_tensor(sft_name)
                tensor = initialization_manager.adjust_tensor_size(tensor, hf_name)
                state_dict[hf_name] = tensor
    # Force creation of the lm_head (even when embeddings are tied)
    if pgm.process_group_manager.pp_is_last_stage or not isinstance(model, PipelineParallel):
        vocab_size = model_config.vocab_size
        if pgm.process_group_manager.tp_world_size > 1:
            # For TP > 1, final_proj is already wrapped in ColumnParallelLinear, so we only
            # need to put a tensor of the correct sharded size into the state_dict.
            # For ColumnParallelLinear, the weight shape is (output_size_per_partition, in_features).
            vocab_per_rank = vocab_size // pgm.process_group_manager.tp_world_size
            state_dict['final_proj.weight'] = torch.zeros(vocab_per_rank, model_config.hidden_size)
        else:
            # For TP = 1, create the full layer. FinalProjection expects a weight of shape
            # (out_features, in_features) and is needed so that we can call .reset_parameters() on it.
            model.final_proj = FinalProjection(model_config.hidden_size, vocab_size, bias=False)
            state_dict['final_proj.weight'] = torch.zeros(vocab_size, model_config.hidden_size)
    # Synchronize across distributed processes and load weights
    dist.barrier()
    model.load_state_dict(state_dict, strict=True, assign=True)
    dist.barrier()

    assert_no_meta_tensors(model)

    # Initialize model parameters
    initialization_manager.init_model_parameters()
    dist.barrier()

    return model
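
For context, a minimal usage sketch of the pattern this function supports: build the model on the meta device so no real storage is allocated up front, then let init_model_with_materialized_weights materialize only the tensors this rank owns. The build_model helper is a hypothetical stand-in, not part of picotron's API, and the picotron process-group manager is assumed to be set up already.

import torch
from picotron.checkpoint import init_model_with_materialized_weights

def materialize_model(model_config, save_dir, build_model):
    # build_model is a hypothetical callable that returns this rank's
    # (pipeline-/tensor-parallel) model; on the meta device it allocates no memory.
    with torch.device("meta"):
        model = build_model(model_config)

    # Swap the meta tensors for correctly shaped weights for the layers this rank
    # is responsible for, loaded from the safetensors checkpoint in save_dir and
    # then re-initialized by init_model_parameters().
    model = init_model_with_materialized_weights(model, model_config, save_dir)
    return model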