in optimum/tpu/xla_model_parallel.py [0:0]
def quantize(self):
    """Swap this layer's weight for a quantized one and record its scale.

    The current weight data is copied, converted to float32 and handed to
    ``quantize_tensor`` with ``TensorQConfig(axis=0)``. The first returned
    value becomes the data of ``self.weight``; the second, cast back to the
    weight's original dtype, becomes the data of ``self.weight_scaler``.
    The third returned value (the zero point) is not stored.

    Must only be called once per layer: the method asserts that
    ``self.quant`` is falsy on entry and sets it to ``True`` on exit.

    Raises:
        AssertionError: if ``self.quant`` is already truthy.
    """
    assert not self.quant

    # Remember the dtype the layer was using so the scale can be stored in it.
    source_dtype = self.weight.data.dtype
    # Work on a float32 copy; the live parameter is replaced below, not edited.
    weight_fp32 = deepcopy(self.weight.data).to(torch.float32)

    # Re-declare the weight as a frozen int8 parameter; the placeholder
    # contents are uninitialized and overwritten once quantization completes.
    int8_placeholder = torch.empty(
        (self.out_features, self.input_size_per_partition),
        dtype=torch.int8,
    )
    self.weight = Parameter(int8_placeholder, requires_grad=False)
    # One scale entry per output feature (uninitialized until assigned below).
    self.weight_scaler = Parameter(torch.Tensor(self.out_features))

    # NOTE(review): axis=0 presumably means one scale per output row of the
    # weight -- confirm against TensorQConfig / quantize_tensor.
    axis0_config = TensorQConfig(axis=0)
    quantized, scale, _zero_point = quantize_tensor(weight_fp32, axis0_config)

    self.weight.data = quantized
    self.weight_scaler.data = scale.to(source_dtype)
    self.quant = True