in picotron/tensor_parallel/tensor_parallel.py [0:0]
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Run the layer's linear projection, then optionally gather the result.

    The projection uses ``self.weight`` and ``self.bias``. The helper that
    performs it is chosen by ``self.async_all_reduce``:
    ``linear_with_async_all_reduce`` when it is truthy, and
    ``linear_with_all_reduce`` otherwise.

    When ``self.gather_output`` is truthy, the projection result is passed
    through ``GatherFromModelParallelRegion.apply``. Presumably this collects
    the per-rank output shards across the tensor-parallel group (TODO confirm
    against its definition). Otherwise the projection result is returned as-is.

    Args:
        x: Input tensor for the projection.

    Returns:
        The projected tensor, gathered when ``self.gather_output`` is set.
    """
    # Pick the projection helper once; both share the (x, weight, bias) signature.
    project = (
        linear_with_async_all_reduce
        if self.async_all_reduce
        else linear_with_all_reduce
    )
    projected = project(x, self.weight, self.bias)

    # Guard clause: without gathering, the local result is the final output.
    if not self.gather_output:
        return projected
    return GatherFromModelParallelRegion.apply(projected)