def forward()

in optimum/tpu/xla_model_parallel.py

Forward pass of a model-parallel linear layer: the input is first copied into the model-parallel region (which sets up the backward all-reduce), then multiplied against this rank's weight shard, with the int8 weight scaler applied before or after the matmul depending on the backend, and finally all-gathered across the partitions when gather_output is set.


    def forward(self, input_: torch.Tensor) -> torch.Tensor:  # type: ignore
        # Set up the backprop all-reduce: the copy is an identity in the
        # forward pass, while the backward pass all-reduces the input
        # gradient across the model-parallel group (sketched below).
        input_parallel = copy_to_model_parallel_region(input_, self.groups, self.world_size, self.rank)
        # Matrix multiply.
        if self.quant and USE_CUDA:
            # GPUs do not support mixed int8/bf16 computation, so dequantize
            # by scaling the int8 weights up to bf16 before the linear.
            scaled_weight = self.weight * self.weight_scaler
            output_parallel = F.linear(input_parallel, scaled_weight, self.bias)
        elif self.quant:
            # XLA handles the mixed-dtype matmul directly, so keep the int8
            # weights and apply the scale to the output instead (see the
            # sketch below for the equivalence of the two orderings).
            output_parallel = F.linear(input_parallel, self.weight, self.bias)
            output_parallel = output_parallel * self.weight_scaler
        else:
            output_parallel = F.linear(input_parallel, self.weight, self.bias)
        if self.gather_output:
            # All-gather across the partitions.
            output = gather_from_model_parallel_region(output_parallel, self.groups, self.world_size, self.rank)
        else:
            output = output_parallel
        return output
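
The "backprop all-reduce" set up on the first line follows the usual Megatron-style pattern: copy_to_model_parallel_region is typically an identity in the forward pass and an all-reduce of the gradient in the backward pass, so every rank sees the full input gradient even though each holds only a shard of the weight. A minimal sketch of that pattern, using torch.distributed's default process group rather than the groups/world_size/rank arguments the optimum-tpu helper actually threads through:

    import torch
    import torch.distributed as dist

    class _CopyToModelParallelRegion(torch.autograd.Function):
        # Identity forward; all-reduce of the gradient in backward.

        @staticmethod
        def forward(ctx, input_: torch.Tensor) -> torch.Tensor:
            return input_

        @staticmethod
        def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
            dist.all_reduce(grad_output)  # defaults to a SUM reduction
            return grad_output

    def copy_to_model_parallel_region(input_: torch.Tensor) -> torch.Tensor:
        return _CopyToModelParallelRegion.apply(input_)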
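
The two quantized branches compute the same thing in a different order: scaling per-output-channel int8 weights before the matmul equals scaling the matmul output afterwards. A runnable sketch of that identity, assuming a hypothetical (out_features, 1) scaler layout for broadcasting (the real layer's scaler shape may differ) and no bias, since post-scaling the output would also scale any bias term:

    import torch
    import torch.nn.functional as F

    out_features, in_features = 4, 8
    # Stand-ins for the layer's int8 weight shard and bf16 scale factors.
    weight_q = torch.randint(-128, 128, (out_features, in_features)).float()
    weight_scaler = torch.rand(out_features, 1)  # one scale per output row
    x = torch.randn(2, in_features)

    # CUDA branch: dequantize the weights first, then a single matmul.
    out_pre = F.linear(x, weight_q * weight_scaler)

    # XLA branch: matmul first, then scale the output channels.
    out_post = F.linear(x, weight_q) * weight_scaler.squeeze(-1)

    torch.testing.assert_close(out_pre, out_post)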