in optimum/tpu/xla_model_parallel.py [0:0]
def _find_qparams(x: torch.Tensor, qconfig: TensorQConfig):
    """Compute per-channel symmetric quantization parameters for ``x``.

    Only per-channel symmetric quantization to ``torch.int8`` is supported.

    Args:
        x: Tensor to be quantized.
        qconfig: Quantization config; its ``axis``, ``dtype``,
            ``symmetric_quant``, ``quant_min`` and ``quant_max`` fields are read.

    Returns:
        A ``(scale, None)`` tuple. ``scale`` is derived from the min/max values
        returned by ``_find_per_channel_min_max(x, axis)`` (presumably one entry
        per slice along ``axis`` -- confirm against that helper) and is floored
        at ``EPS``. The second element is always ``None`` (presumably the zero
        point, which symmetric quantization does not need -- verify against
        callers).

    Raises:
        ValueError: If ``qconfig.axis`` is not a valid dimension index of ``x``.
        NotImplementedError: If ``qconfig.dtype`` is not ``torch.int8`` or
            ``qconfig.symmetric_quant`` is false.
    """
    axis = qconfig.axis
    # Explicit raises instead of ``assert``: asserts are stripped under
    # ``python -O``, which previously let the asymmetric path silently return
    # ``None`` and let invalid configs through unchecked.
    if not 0 <= axis < len(x.shape):
        raise ValueError(
            f"Quantization axis {axis} is out of range for tensor of shape {tuple(x.shape)}"
        )
    if qconfig.dtype != torch.int8:
        raise NotImplementedError(
            f"Only torch.int8 quantization is supported, got {qconfig.dtype}"
        )
    # Reject unsupported asymmetric quantization before doing any tensor work.
    if not qconfig.symmetric_quant:
        raise NotImplementedError("Only symmetric quantization is supported")

    min_val, max_val = _find_per_channel_min_max(x, axis)
    # Extend each channel's range to include zero, then take the larger
    # magnitude so the quantized range is symmetric around zero.
    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
    max_abs = torch.max(-min_val_neg, max_val_pos)

    # Map [-max_abs, max_abs] onto half of the integer range
    # (e.g. a divisor of 127.5 if quant_min/quant_max are -128/127).
    scale = max_abs / (float(qconfig.quant_max - qconfig.quant_min) / 2)
    # Floor the scale at EPS so all-zero channels do not get a zero scale
    # (presumably to avoid dividing by zero when x is later divided by scale).
    scale = torch.max(scale, torch.full_like(scale, EPS))
    return scale, None