def __init__()

in optimum/quanto/tensor/weights/tinygemm/qbits.py


    def __init__(self, qtype, axis, group_size, size, stride, data, scale_shift, requires_grad=False):
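        # Only weights quantized along the first axis (out_features) are supported here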
        assert axis == 0
        if not isinstance(data, TinyGemmPackedTensor):
            assert type(data) is torch.Tensor
            assert isinstance(scale_shift, (tuple, list))
            # Format data, scale and shift for tinygemm
            ungrouped = ungroup(data, axis=0, orig_shape=size)
            self._data = TinyGemmPackedTensor.pack(ungrouped)
            out_features, in_features = size
            scale, shift = scale_shift
            scale = scale.reshape(out_features, in_features // group_size, 1)
            shift = shift.reshape(out_features, in_features // group_size, 1)
            if not shift.dtype.is_floating_point:
                # Integer shift must be scaled
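                # (i.e. convert the integer zero-point into a float offset, scale * zero_point)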
                shift = scale * shift
            # The tinygemm kernel actually uses the mid-point of the quantization range as shift
            min_range = -shift
            half_qrange = 2 ** (qtype.bits - 1) * scale
            # This operation is lossy for bfloat16, and the actual value of shift will be lost
            shift = min_range + half_qrange
            # Scale and shift are actually stored in the same tensor
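            # Resulting layout: (in_features // group_size, out_features, 2), groups first as the tinygemm kernel expects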
            self._scale_shift = torch.cat([scale, shift], 2).transpose(0, 1).contiguous()
        else:
            self._data = data
            self._scale_shift = scale_shift
        self._qtype = qtype
        self._axis = axis
        self._group_size = group_size
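

A minimal standalone sketch (not part of quanto) checking the conversion above, with illustrative values for bits, scale and shift: assuming the original affine dequantization is w = q * scale - shift for unsigned codes q (which is what min_range = -shift implies), dequantizing with the mid-point stored for tinygemm yields the same weights.

    import torch

    # Illustrative values only: bits, scale and shift are made up for this check
    bits = 4
    scale = torch.tensor(0.05)
    shift = torch.tensor(0.4)              # float shift, i.e. scale * integer zero-point
    q = torch.arange(0, 2**bits).float()   # every possible unsigned 4-bit code

    # Affine dequantization: w = q * scale - shift
    w_affine = q * scale - shift

    # Tinygemm-style dequantization with the mid-point computed in __init__
    half_qrange = 2 ** (bits - 1) * scale
    mid = -shift + half_qrange
    w_tinygemm = (q - 2 ** (bits - 1)) * scale + mid

    assert torch.allclose(w_affine, w_tinygemm)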