def _find_qparams()

in optimum/tpu/xla_model_parallel.py [0:0]


def _find_qparams(x: torch.Tensor, qconfig: TensorQConfig):
    """Compute symmetric quantization scales for ``x``.

    Only per-channel symmetric quantization to ``torch.int8`` is supported.

    Args:
        x: Tensor whose quantization parameters are computed.
        qconfig: Quantization settings. The fields ``axis``, ``dtype``,
            ``symmetric_quant``, ``quant_min`` and ``quant_max`` are read.

    Returns:
        A ``(scale, None)`` tuple. ``scale`` has the shape of the tensors
        returned by ``_find_per_channel_min_max(x, axis)`` and is floored
        at ``EPS`` so that it is never smaller than that value. The second
        element is always ``None``.

    Raises:
        AssertionError: If ``axis`` is outside ``[0, x.dim())``, if
            ``dtype`` is not ``torch.int8``, or if ``symmetric_quant`` is
            falsy (asymmetric quantization is not implemented).
    """
    axis = qconfig.axis
    dtype = qconfig.dtype
    symmetric_quant = qconfig.symmetric_quant
    quant_min = qconfig.quant_min
    quant_max = qconfig.quant_max
    assert 0 <= axis < len(x.shape), "qconfig.axis is out of range for x"
    assert dtype == torch.int8, "only torch.int8 quantization is supported"
    if not symmetric_quant:
        # Raised explicitly rather than via `assert` so that the unsupported
        # case cannot fall through and return None when Python runs with -O.
        # AssertionError is kept so callers observe the same exception type.
        raise AssertionError("only symmetric quantization is supported")

    min_val, max_val = _find_per_channel_min_max(x, axis)
    # Extend the observed range so that it always contains zero.
    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
    # Symmetric range: use the larger magnitude of the two bounds.
    max_abs = torch.max(-min_val_neg, max_val_pos)
    scale = max_abs / (float(quant_max - quant_min) / 2)
    # Floor the scale at EPS (e.g. when a channel is all zeros).
    scale = torch.max(scale, torch.full_like(scale, EPS))
    return scale, None