def __init__()

in optimum/quanto/tensor/weights/awq/qbits.py [0:0]


    def __init__(self, qtype, axis, group_size, size, stride, data, scale, shift, requires_grad=False):
        # The optimized AWQ kernel only supports quantization along the first axis (output features)
        assert axis == 0
        if not isinstance(data, AWQPackedTensor):
            assert type(data) is torch.Tensor
            # Format data, scale and shift for the optimized CUDA gemm
            # Revert the group-wise layout to the original (out_features, in_features) shape before packing
            ungrouped = ungroup(data, axis=0, orig_shape=size)
            data = AWQPackedTensor.pack(ungrouped, packing=AWQPacking.V2)
            out_features, in_features = size
            # The kernel expects scale and shift transposed to (in_features // group_size, out_features)
            scale = scale.reshape(out_features, in_features // group_size).t().contiguous()
            shift = shift.reshape(out_features, in_features // group_size).t()
            if not shift.dtype.is_floating_point:
                # Integer shift must be scaled
                shift = scale * shift
            # Shift must be negated
            shift = -shift.contiguous()
        super().__init__(qtype, axis, group_size, size, stride, data, scale, shift)
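
For axis 0, grouping amounts to storing the weight as rows of group_size elements, so ungroup here is essentially a reshape back to the original (out_features, in_features) shape. A minimal sketch of that round trip, assuming this grouping convention (it mimics the behavior of the library's group/ungroup helpers rather than calling them):

    import torch

    out_features, in_features, group_size = 4, 8, 4
    weight = torch.arange(out_features * in_features).reshape(out_features, in_features)

    # Group-wise storage: one row per (output channel, input group) pair
    grouped = weight.reshape(-1, group_size)  # shape (8, 4)

    # What ungroup(data, axis=0, orig_shape=size) recovers before packing
    ungrouped = grouped.reshape(out_features, in_features)
    assert torch.equal(ungrouped, weight)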
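
The reshape-and-transpose of scale (and shift) can be traced on a toy example. The sizes below are illustrative, and the incoming shape assumes one scale per (output channel, input group) pair, as group-wise quantization produces:

    import torch

    out_features, in_features, group_size = 4, 8, 4
    scale = torch.rand(out_features * in_features // group_size, 1)

    # Same transform as in __init__: one row per input group, one column per channel
    kernel_scale = scale.reshape(out_features, in_features // group_size).t().contiguous()
    print(kernel_scale.shape)  # torch.Size([2, 4]): (in_features // group_size, out_features)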
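
The scaling and negation of an integer shift follow from the dequantization identity scale * (q - z) = scale * q + (-(scale * z)): pre-computing the scaled, negated shift lets the kernel dequantize with a single multiply-add. A minimal check of that identity, with made-up tensors standing in for the quantized codes and zero-points:

    import torch

    q = torch.randint(0, 16, (4, 8)).float()     # fake 4-bit codes
    scale = torch.rand(4, 1)
    zero = torch.randint(0, 16, (4, 1)).float()  # integer zero-point ("shift")

    # What __init__ stores for the kernel: the scaled, negated shift
    stored_shift = -(scale * zero)

    # Dequantization becomes a fused multiply-add
    assert torch.allclose(scale * (q - zero), scale * q + stored_shift)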