in optimum/tpu/xla_model_parallel.py [0:0]
def forward(self, input_: torch.Tensor) -> torch.Tensor:  # type: ignore
    # Set up backprop all-reduce.
    if self.input_is_parallel:
        input_parallel = input_
    else:
        input_parallel = scatter_to_model_parallel_region(input_, self.groups, self.world_size, self.rank)
    # Matrix multiply. The weight is sharded along its input dimension
    # (row-parallel), so each rank computes a partial product; the bias is
    # therefore added only once, after the all-reduce below.
    if self.quant and USE_CUDA:
        # GPUs do not support mixed int8/bf16 computation. Scale the int8
        # weights to bf16 before the linear.
        scaled_weight = self.weight * self.weight_scaler
        output_parallel = F.linear(input_parallel, scaled_weight)
    elif self.quant:
        # On XLA devices, run the linear on the int8 weights and rescale
        # the output afterwards.
        output_parallel = F.linear(input_parallel, self.weight)
        output_parallel = output_parallel * self.weight_scaler
    else:
        output_parallel = F.linear(input_parallel, self.weight)
    # All-reduce the partial products across all model-parallel partitions.
    output_ = reduce_from_model_parallel_region(output_parallel, self.groups, self.world_size, self.rank)
    if self.bias is not None:
        output = output_ + self.bias
    else:
        output = output_
    return output
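
For reference, a minimal single-process sketch of the row-parallel math above, using plain torch in place of the XLA collectives and hypothetical shapes: the weight is split along its input dimension, each shard yields a partial product, and summing the partials (standing in for reduce_from_model_parallel_region) recovers the full linear. This is also why the bias must be added only once, after the reduce.

import torch
import torch.nn.functional as F

# Hypothetical sizes; no distributed runtime involved.
world_size, in_features, out_features = 2, 8, 4
x = torch.randn(3, in_features)
w = torch.randn(out_features, in_features)
b = torch.randn(out_features)

# scatter_to_model_parallel_region: each rank sees one slice of the input.
x_shards = x.chunk(world_size, dim=-1)
# Matching weight shards along the input (row-parallel) dimension.
w_shards = w.chunk(world_size, dim=-1)

# Per-rank partial products, computed without the bias.
partials = [F.linear(xs, ws) for xs, ws in zip(x_shards, w_shards)]
# Summing stands in for the all-reduce; the bias is added exactly once.
output = sum(partials) + b

assert torch.allclose(output, F.linear(x, w, b), atol=1e-5)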