def __init__()

in optimum/quanto/tensor/weights/marlin/fp8/qbits.py


    def __init__(self, qtype, axis, size, stride, data, scale, requires_grad=False):
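        # The asserts below encode the Marlin fp8 constraints: 2D weights quantized per output channel (axis=0).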
        assert axis == 0
        assert data.ndim == 2

        out_features = size[0]
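        # The Marlin kernel expects a zero-initialized int32 workspace, used internally
        # for synchronization across thread blocks; its size is derived from out_features.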
        self._workspace = torch.zeros(out_features // 64 * 16, dtype=torch.int, device=data.device)

        # TODO: We should use `not isinstance(data, MarlinF8PackedTensor)` here, but `torch.compile` is bugged when doing so.
        # Somewhere in the internals of torch.compile, `data` gets converted to a `torch._subclasses.fake_tensor.FakeTensor`
        # that does not inherit from `MarlinF8PackedTensor`, and torch then goes down the wrong control flow.
        # Reference: https://pytorch.slack.com/archives/C033H6DJSJU/p1721837684035049
        if data.dtype != torch.int32:
            assert scale.shape == (out_features, 1)
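            # Permute the per-output-channel scales into the interleaved layout
            # expected by the Marlin kernel.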
            scale_perm_single = get_scale_perms()
            scale = scale.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
            scale = scale.reshape(-1, out_features).contiguous()

            data_packed = MarlinF8PackedTensor.pack(data)  # Pack fp8 data to int32 and apply the Marlin re-ordering.
        else:
            # When freezing (`model.freeze()`), the data is already a `MarlinF8PackedTensor` and the scale is already repacked.
            data_packed = data

        super().__init__(
            qtype, axis, size, stride, data_packed, scale, activation_qtype=qfloat8_e4m3fn, requires_grad=requires_grad
        )
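
For context, here is a minimal usage sketch of how this constructor might be reached on the first (unpacked) path. It assumes the enclosing class is `MarlinF8QBytesTensor` and that the module path mirrors the file path above; the dummy quantization of `weight` below is for illustration only and is not the library's own quantization routine.

    import torch
    from optimum.quanto import qfloat8_e4m3fn
    # Import path inferred from the file path above; treat it as an assumption.
    from optimum.quanto.tensor.weights.marlin.fp8.qbits import MarlinF8QBytesTensor

    out_features, in_features = 128, 256
    weight = torch.randn(out_features, in_features, dtype=torch.float16, device="cuda")

    # Per-output-channel (axis=0) scales with shape (out_features, 1), as asserted above.
    # 448.0 is the maximum representable value of float8_e4m3fn.
    scale = weight.abs().amax(dim=1, keepdim=True) / 448.0
    data = (weight / scale).to(torch.float8_e4m3fn)

    # `data` is fp8 here (dtype != torch.int32), so __init__ packs it to int32 and
    # permutes the scales; after `model.freeze()` it would already arrive packed.
    qweight = MarlinF8QBytesTensor(
        qtype=qfloat8_e4m3fn,
        axis=0,
        size=weight.size(),
        stride=weight.stride(),
        data=data,
        scale=scale,
        requires_grad=False,
    )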