in step7_pipeline_parallel_afab/tensor_parallel.py [0:0]
def reset_parameters(self):
    """Initialize ``self.weight`` with the default ``nn.Linear``-style uniform init.

    Values are drawn from U(-sqrt(k), sqrt(k)) with k = 1 / (input dimension).

    * ``tp_world_size == 1``: ``self.weight`` is filled in place.
    * ``tp_world_size > 1``: a full ``(out_features, in_features)`` master
      weight is sampled, split along dim 0 into chunks of
      ``output_size_per_partition`` rows, and the chunk at index ``tp_rank``
      becomes this rank's weight. The chunk is moved to the device
      ``self.weight`` already lives on, so calling this method never
      relocates the parameter to another device.
    """
    if self.tp_world_size == 1:
        # No sharding: initialize the local weight directly.
        # U(-sqrt(k), sqrt(k)) with k = 1 / in_features (second dim of weight).
        k = 1 / self.weight.size(1)
        bound = math.sqrt(k)
        torch.nn.init.uniform_(self.weight, -bound, bound)
        return

    # When TP > 1, sample the full (unsharded) master weight first. It is
    # created on the default device so the whole matrix is not allocated on
    # the parameter's device; only the local shard is moved there below.
    master_weight = torch.empty(
        self.out_features,
        self.in_features,
        dtype=self.weight.dtype,
        requires_grad=False,
    )
    # The bound uses the master weight's input dimension, so every shard has
    # the same distribution as the unsharded layer would.
    k = 1 / master_weight.size(1)
    bound = math.sqrt(k)
    torch.nn.init.uniform_(master_weight, -bound, bound)

    # NOTE(review): shards on different ranks only form one consistent master
    # weight if every rank draws the same random numbers here - confirm the
    # caller seeds the RNG identically across ranks.

    # Split the master weight into chunks of self.output_size_per_partition
    # rows and keep the chunk that belongs to this rank.
    weight_list = torch.split(master_weight, self.output_size_per_partition, dim=0)
    shard = weight_list[self.tp_rank].contiguous()
    # Keep the parameter on its original device instead of silently replacing
    # it with a tensor on the default device.
    self.weight.data = shard.to(self.weight.device)