def quantize()

in optimum/tpu/xla_model_parallel.py [0:0]


    def quantize(self):
        """Replace this layer's weight with an int8 version plus a scale tensor.

        A float32 copy of the current weight is handed to ``quantize_tensor``
        together with ``TensorQConfig(axis=0)``. The first returned value
        becomes the data of a new, non-trainable int8 ``self.weight``; the
        returned scale, cast back to the weight's original dtype, becomes the
        data of ``self.weight_scaler``. The returned zero point is discarded.
        ``self.quant`` is set to ``True`` once everything has been stored.

        The layer must not already be quantized; this is checked with an
        ``assert``.
        """
        assert not self.quant

        # Remember the dtype first so the scale can be stored in it later.
        source_dtype = self.weight.data.dtype
        # deepcopy keeps the float32 tensor independent of the current weight
        # even when ``.to`` would otherwise return the same tensor unchanged.
        float_weight = deepcopy(self.weight.data).to(torch.float32)

        # Swap in placeholder parameters; their data is overwritten below.
        int8_shape = (self.output_size_per_partition, self.in_features)
        self.weight = Parameter(torch.empty(int8_shape, dtype=torch.int8), requires_grad=False)
        self.weight_scaler = Parameter(torch.Tensor(self.output_size_per_partition))

        quantized, scale, _zero_point = quantize_tensor(float_weight, TensorQConfig(axis=0))
        self.weight.data = quantized
        self.weight_scaler.data = scale.to(source_dtype)
        self.quant = True