# picotron/tensor_parallel/tensor_parallel.py [0:0]
def reset_parameters(self):
    """Re-initialize this rank's weight shard.

    Materializes the full (out_features, in_features) master weight, fills it
    with the same uniform(-1/sqrt(fan_in), 1/sqrt(fan_in)) scheme that
    nn.Linear uses by default, then keeps only the block of output rows owned
    by this tensor-parallel rank.
    """
    full_weight = torch.empty(
        self.out_features,
        self.in_features,
        dtype=self.weight.dtype,
        device=self.weight.device,
        requires_grad=False,
    )
    # nn.Linear's default bound: 1 / sqrt(fan_in) of the *master* weight,
    # so every rank draws from the same distribution as the unsharded layer.
    bound = math.sqrt(1.0 / full_weight.size(1))
    torch.nn.init.uniform_(full_weight, -bound, bound)
    # Carve the master weight into per-rank row blocks and keep ours.
    shards = torch.split(full_weight, self.output_size_per_partition, dim=0)
    self.weight.data = shards[self.tp_rank].contiguous()