in picotron/checkpoint.py [0:0]
# Module-level imports this method relies on; the method itself is defined on a
# class in picotron/checkpoint.py (note the `self` parameter).
import torch

import picotron.process_group_manager as pgm

def adjust_tensor_size(self, tensor, name):
    """Resize tensor based on architecture changes and tensor parallelism."""
    tp_rank = pgm.process_group_manager.tp_rank
    tp_size = pgm.process_group_manager.tp_world_size
    hidden_size = self.model_config.hidden_size

    # Handle embedding and final projection layers: both are sharded along the
    # vocabulary dimension (dim 0), so each TP rank keeps a contiguous slice of rows.
    if 'embedding.weight' in name or 'final_proj.weight' in name:
        vocab_size = self.model_config.vocab_size
        vocab_per_rank = vocab_size // tp_size
        if tensor.shape[0] != vocab_per_rank:
            start_idx = tp_rank * vocab_per_rank
            end_idx = start_idx + vocab_per_rank
            tensor = tensor[start_idx:end_idx, :]
        return tensor
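    # Worked example (hypothetical numbers): with vocab_size=32000 and tp_size=4,
    # vocab_per_rank=8000, so rank r keeps rows [8000*r, 8000*(r+1)) of the full
    # (32000, hidden_size) matrix.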
    # Handle attention layers: q/k/v projections are column-parallel (sharded on
    # dim 0 by heads); out_proj is row-parallel (sharded on dim 1).
    if 'attention' in name:
        head_dim = hidden_size // self.model_config.num_attention_heads
        if 'q_proj.weight' in name:
            total_heads = self.model_config.num_attention_heads
            heads_per_rank = total_heads // tp_size
            target_dim = heads_per_rank * head_dim
        elif 'k_proj.weight' in name or 'v_proj.weight' in name:
            # k/v use num_key_value_heads, which may be smaller than
            # num_attention_heads under grouped-query attention.
            total_heads = self.model_config.num_key_value_heads
            heads_per_rank = total_heads // tp_size
            target_dim = heads_per_rank * head_dim
        elif 'out_proj.weight' in name:
            # For out_proj, split along the second dimension; the first
            # (output) dimension stays the same.
            cols_per_rank = hidden_size // tp_size
            if tensor.shape[1] != cols_per_rank:
                tensor = tensor[:, cols_per_rank * tp_rank:cols_per_rank * (tp_rank + 1)]
            return tensor
        else:
            # Anything else in the attention block is returned unchanged.
            return tensor
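        # Worked example (hypothetical numbers): with hidden_size=4096 and tp_size=4,
        # out_proj on rank r keeps columns [1024*r, 1024*(r+1)) of the full
        # (hidden_size, hidden_size) matrix.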
        # Resize q/k/v along dim 0 when the checkpoint shape differs from the
        # target: pad when the target is larger, truncate when it is smaller.
        # Padding uses zeros rather than torch.empty, which would splice
        # uninitialized values into the weights.
        if tensor.shape[0] != target_dim:
            if target_dim > tensor.shape[0]:
                pad_tensor = torch.zeros(target_dim - tensor.shape[0], tensor.shape[1],
                                         dtype=tensor.dtype, device=tensor.device)
                tensor = torch.cat([tensor, pad_tensor], dim=0)
            else:
                tensor = tensor[:target_dim, :]
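        # Worked example (hypothetical numbers): hidden_size=4096 with 32 attention
        # heads gives head_dim=128; with num_key_value_heads=8 and tp_size=4,
        # q_proj targets 8 heads * 128 = 1024 rows per rank while k_proj/v_proj
        # target 2 heads * 128 = 256 rows per rank.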
    # Handle MLP layers: up_proj/gate_proj are column-parallel (sharded on dim 0),
    # down_proj is row-parallel (sharded on dim 1).
    elif 'mlp' in name:
        intermediate_size = self.model_config.intermediate_size
        intermediate_size_per_rank = intermediate_size // tp_size
        if 'up_proj.weight' in name or 'gate_proj.weight' in name:
            if tensor.shape[0] != intermediate_size_per_rank:
                start_idx = tp_rank * intermediate_size_per_rank
                end_idx = start_idx + intermediate_size_per_rank
                tensor = tensor[start_idx:end_idx, :]
        elif 'down_proj.weight' in name:
            if tensor.shape[1] != intermediate_size_per_rank:
                start_idx = tp_rank * intermediate_size_per_rank
                end_idx = start_idx + intermediate_size_per_rank
                tensor = tensor[:, start_idx:end_idx]
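    # Worked example (hypothetical numbers): intermediate_size=11008 with tp_size=4
    # gives each rank 2752 rows of up_proj/gate_proj and 2752 columns of down_proj.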
    return tensor
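
# --- Usage sketch (illustrative; not part of checkpoint.py) ---
# A minimal, self-contained demonstration of the vocab-sharding arithmetic used
# above for embedding/final_proj weights; all names and shapes here are
# hypothetical and deliberately small.
def _demo_vocab_shard(vocab_size=128, hidden_size=16, tp_size=4, tp_rank=1):
    full = torch.randn(vocab_size, hidden_size)  # stand-in for a consolidated checkpoint tensor
    vocab_per_rank = vocab_size // tp_size
    shard = full[tp_rank * vocab_per_rank:(tp_rank + 1) * vocab_per_rank, :]
    assert shard.shape == (vocab_per_rank, hidden_size)  # (32, 16) on rank 1
    return shard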