in picotron/tensor_parallel/tensor_parallel.py [0:0]
def reset_parameters(self):
    """Re-initialize this rank's weight shard from a full master weight.

    Builds the complete (num_embeddings, embedding_dim) embedding matrix,
    fills it with N(0, 1) samples, then keeps only the row-slice owned by
    this tensor-parallel rank.

    NOTE(review): each rank samples its own master weight, so shards are
    consistent across ranks only if the RNG state is synchronized before
    this call — confirm the caller seeds identically on all TP ranks.
    """
    # Allocate the un-sharded master weight with the same dtype/device as
    # the local shard; no autograd tracking is needed for initialization.
    full_shape = (self.num_embeddings, self.embedding_dim)
    master_weight = torch.empty(
        *full_shape,
        dtype=self.weight.dtype,
        device=self.weight.device,
        requires_grad=False,
    )
    torch.nn.init.normal_(master_weight, mean=0.0, std=1.0)
    # Partition along the vocabulary (row) dimension and keep our slice,
    # made contiguous since torch.split returns views of the master tensor.
    shards = torch.split(master_weight, self.num_embeddings_per_partition, dim=0)
    self.weight.data = shards[self.tp_rank].contiguous()