def __init__()

in optimum/quanto/tensor/weights/marlin/int4/qbits.py [0:0]


    def __init__(self, qtype, axis, group_size, size, stride, data, scale, shift, requires_grad=False):
        """Build an int4 weight tensor in the Marlin-optimized layout.

        When ``data`` is not already a ``MarlinInt4PackedTensor``, the raw
        quantized weights are repacked and the ``scale``/``shift`` tensors are
        reformatted (transposed, negated and permuted) into the layout the
        optimized Marlin CUDA gemm expects.
        """
        assert axis == 0
        out_features, in_features = size
        if not isinstance(data, MarlinInt4PackedTensor):
            assert type(data) is torch.Tensor
            # Repack the ungrouped quantized weights for the optimized CUDA gemm
            data = MarlinInt4PackedTensor.pack(ungroup(data, axis=0, orig_shape=size))
            n_groups = in_features // group_size
            scale = scale.reshape(out_features, n_groups).t().contiguous()
            shift = shift.reshape(out_features, n_groups).t()
            if not shift.dtype.is_floating_point:
                # An integer shift must be expressed in the scaled domain
                shift = scale * shift
            # The kernel expects the negated shift
            shift = -shift.contiguous()
            # Finally, apply the Marlin-specific permutation to both tensors
            scale, shift = marlin_permute(scale), marlin_permute(shift)
        super().__init__(qtype, axis, group_size, size, stride, data, scale, shift)
        # Scratch buffer required by the Marlin kernel, sized from out_features
        # (16 ints per 128 output features — per the kernel's convention)
        self._workspace = torch.zeros(out_features // 128 * 16, dtype=torch.int, device=data.device)