def forward()

in optimum/tpu/xla_model_parallel.py [0:0]


    def forward(self, input_: torch.Tensor) -> torch.Tensor:  # type:ignore
        # Set up backprop all-reduce.
        if self.input_is_parallel:
            input_parallel = input_
        else:
            input_parallel = scatter_to_model_parallel_region(input_, self.groups, self.world_size, self.rank)
        # Matrix multiply.
        if self.quant and USE_CUDA:
            # GPUs do not support mixed int8/bf16 computation, so scale the
            # int8 weights up to bf16 before the linear.
            scaled_weight = self.weight * self.weight_scaler
            output_parallel = F.linear(input_parallel, scaled_weight)
        elif self.quant:
            output_parallel = F.linear(input_parallel, self.weight)
            output_parallel = output_parallel * self.weight_scaler
        else:
            output_parallel = F.linear(input_parallel, self.weight)
        # All-reduce across all the partitions.
        output_ = reduce_from_model_parallel_region(output_parallel, self.groups, self.world_size, self.rank)
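        # Add the bias exactly once, after the all-reduce. Folding it into each
        # rank's F.linear above would add it world_size times to the reduced sum.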
        if self.bias is not None:
            output = output_ + self.bias
        else:
            output = output_
        return output
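
For intuition, here is a minimal single-process sketch of the row-parallel pattern this forward implements: the input is split along its feature dimension, each shard is multiplied by the matching shard of the weight, the partial outputs are summed (standing in for the all-reduce), and the bias is applied exactly once at the end. All names and shapes below are illustrative, not taken from the module.

    import torch
    import torch.nn.functional as F

    # Illustrative sizes: "world_size" ranks would each hold one shard.
    world_size, in_features, out_features = 4, 16, 8
    x = torch.randn(2, in_features)
    weight = torch.randn(out_features, in_features)
    bias = torch.randn(out_features)

    # Reference: the unsharded linear.
    expected = F.linear(x, weight, bias)

    # Row-parallel version: shard the input and the weight's in_features axis.
    x_shards = x.chunk(world_size, dim=-1)
    w_shards = weight.chunk(world_size, dim=-1)

    # Each "rank" computes a partial sum of the full output (no bias yet).
    partials = [F.linear(xs, ws) for xs, ws in zip(x_shards, w_shards)]

    # The sum plays the role of reduce_from_model_parallel_region; the bias
    # is added once after the reduction.
    output = torch.stack(partials).sum(dim=0) + bias

    torch.testing.assert_close(output, expected)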