in optimum/tpu/xla_model_parallel.py
def forward(self, input_: torch.Tensor) -> torch.Tensor:  # type: ignore
    # Set up backprop all-reduce: identity in the forward pass,
    # all-reduce of the gradient in the backward pass.
    input_parallel = copy_to_model_parallel_region(input_, self.groups, self.world_size, self.rank)
    # Matrix multiply.
    if self.quant and USE_CUDA:
        # CUDA GPUs do not support mixed int8/bf16 matmuls, so dequantize
        # the int8 weights to bf16 before the linear.
        scaled_weight = self.weight * self.weight_scaler
        output_parallel = F.linear(input_parallel, scaled_weight, self.bias)
    elif self.quant:
        # On TPU the matmul can consume the int8 weights directly; apply
        # the per-channel scaler to the output afterwards.
        output_parallel = F.linear(input_parallel, self.weight, self.bias)
        output_parallel = output_parallel * self.weight_scaler
    else:
        output_parallel = F.linear(input_parallel, self.weight, self.bias)
    if self.gather_output:
        # All-gather across the model-parallel partitions.
        output = gather_from_model_parallel_region(output_parallel, self.groups, self.world_size, self.rank)
    else:
        output = output_parallel
    return output
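
A minimal sketch of why the two quantized branches agree (hypothetical names, not part of the file above): with a per-output-channel scaler s and no bias, F.linear(x, w) * s equals F.linear(x, w * s[:, None]), so the TPU branch can defer dequantization until after the matmul. The weights are cast to bf16 here so the check runs on CPU; on TPU the matmul can take the int8 weights directly.

import torch
import torch.nn.functional as F

x = torch.randn(2, 8, dtype=torch.bfloat16)                   # activations
w_int8 = torch.randint(-128, 128, (4, 8), dtype=torch.int8)   # quantized weight
scaler = torch.rand(4, dtype=torch.bfloat16)                  # per-output-channel scale

w = w_int8.to(torch.bfloat16)
# CUDA-style path: dequantize the weight first, then run the linear.
out_scale_first = F.linear(x, w * scaler[:, None])
# TPU-style path: run the linear on the raw weight, then scale the output.
out_scale_after = F.linear(x, w) * scaler

torch.testing.assert_close(out_scale_first, out_scale_after)

The equivalence assumes bias is None: scaling the output also scales an added bias, while scaling the weight does not.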