in optimum/tpu/xla_model_parallel.py [0:0]
def quantize(self):
    """Replace this layer's weight with an int8-quantized version.

    The current weight is copied, upcast to float32 and passed to
    ``quantize_tensor``. The resulting tensor becomes the data of a new,
    non-trainable ``self.weight``, and the returned scale, cast back to the
    weight's original dtype, becomes the data of ``self.weight_scaler``.
    Finally ``self.quant`` is set to ``True``.

    Must only be called on a layer that is not yet quantized: the method
    asserts that ``self.quant`` is falsy.
    """
    assert not self.quant

    # Take a private copy of the weight so the float32 upcast never aliases
    # the tensor currently held by the module.
    weight_copy = deepcopy(self.weight.data)
    source_dtype = weight_copy.dtype
    weight_fp32 = weight_copy.to(torch.float32)

    # Install an int8, gradient-free placeholder first; its data is swapped
    # for the real quantized tensor below.
    placeholder_shape = (self.output_size_per_partition, self.in_features)
    self.weight = Parameter(
        torch.empty(placeholder_shape, dtype=torch.int8),
        requires_grad=False,
    )
    # Uninitialized holder with one entry per output row of this partition.
    self.weight_scaler = Parameter(torch.Tensor(self.output_size_per_partition))

    # NOTE(review): axis=0 presumably requests one scale per output row --
    # confirm against TensorQConfig / quantize_tensor.
    quantized_weight, scales, _zero_point = quantize_tensor(
        weight_fp32, TensorQConfig(axis=0)
    )
    # The zero point is not stored (presumably the scheme is symmetric --
    # TODO confirm).
    self.weight.data = quantized_weight
    self.weight_scaler.data = scales.to(source_dtype)

    self.quant = True