def reset_parameters()

in picotron/tensor_parallel/tensor_parallel.py [0:0]


    def reset_parameters(self):
        """Initialize this rank's shard of the embedding weight.

        A full ``(self.num_embeddings, self.embedding_dim)`` tensor is drawn
        from N(0, 1) with the weight's dtype and device. It is split along
        dim 0 into chunks of ``self.num_embeddings_per_partition`` rows, and
        the chunk at index ``self.tp_rank`` becomes ``self.weight``'s data.
        If ``num_embeddings`` is not a multiple of the partition size,
        ``torch.split`` makes the last chunk shorter.

        NOTE(review): the shards only form one coherent master matrix if every
        tensor-parallel rank draws from the same RNG state here. That
        presumably relies on seeds being synchronized across TP ranks by the
        caller — confirm.
        """
        master_weight = torch.empty(
            self.num_embeddings,
            self.embedding_dim,
            dtype=self.weight.dtype,
            device=self.weight.device,
            requires_grad=False,
        )
        torch.nn.init.normal_(master_weight, mean=0.0, std=1.0)
        # Split the rows into chunks of self.num_embeddings_per_partition.
        weight_list = torch.split(master_weight, self.num_embeddings_per_partition, dim=0)
        # A dim-0 slice of a contiguous tensor is already contiguous, so
        # .contiguous() would return a view that pins the whole master
        # storage. clone() gives the shard its own compact storage, which
        # lets master_weight be freed.
        self.weight.data = weight_list[self.tp_rank].clone(
            memory_format=torch.contiguous_format
        )