in step7_pipeline_parallel_afab/tensor_parallel.py
import torch
import torch.nn as nn

import process_group_manager as pgm  # assumption: module exposing the global tensor-parallel process group state

class ColumnParallelLinear(nn.Module):
    """Linear layer whose weight is sharded column-wise (along out_features) across tensor-parallel ranks."""
    def __init__(self, in_features: int, out_features: int, bias: bool, gather_output: bool = False):
        super(ColumnParallelLinear, self).__init__()
        self.tp_world_size = pgm.process_group_manager.tp_world_size
        self.tp_rank = pgm.process_group_manager.tp_rank
        self.in_features = in_features
        self.out_features = out_features
        assert out_features % self.tp_world_size == 0, "Output dimension must be divisible by the tensor parallel world size"
        self.output_size_per_partition = out_features // self.tp_world_size
        self.gather_output = gather_output
        # Note: torch.nn.functional.linear performs XW^T + b, so the weight is stored
        # as (out_features_per_partition, in_features), i.e. with the dimensions swapped.
        self.weight = nn.Parameter(torch.Tensor(self.output_size_per_partition, self.in_features))  # W_i: this rank's shard
        if bias:
            self.bias = nn.Parameter(torch.Tensor(self.output_size_per_partition))
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()  # weight initialization, defined elsewhere in the class
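
The column split is easy to sanity-check without any distributed setup: sharding `W` along its first (output) dimension and concatenating the per-shard results reproduces the full linear layer, which is exactly why each rank can hold only `out_features // tp_world_size` rows. The snippet below is an illustrative check, not part of the file; all tensor names and sizes in it are made up.

import torch
import torch.nn.functional as F

tp_world_size = 2
x = torch.randn(4, 8)   # (batch, in_features)
w = torch.randn(6, 8)   # (out_features, in_features), as stored for F.linear

# Each rank would hold one W_i of shape (out_features // tp_world_size, in_features).
shards = w.chunk(tp_world_size, dim=0)

# Per-rank partial outputs: each covers a disjoint slice of the output columns.
partials = [F.linear(x, w_i) for w_i in shards]

# Concatenating along the feature dimension (what gather_output=True would do
# via an all-gather) matches the unsharded computation.
assert torch.allclose(torch.cat(partials, dim=-1), F.linear(x, w))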