in optimum/quanto/tensor/weights/marlin/fp8/qbits.py
def __init__(self, qtype, axis, size, stride, data, scale, requires_grad=False):
    assert axis == 0
    assert data.ndim == 2

    out_features = size[0]
    self._workspace = torch.zeros(out_features // 64 * 16, dtype=torch.int, device=data.device)

    # TODO: Here we should use `not isinstance(data, MarlinF8PackedTensor)`, but `torch.compile` is bugged when using that.
    # Somewhere in the internals of torch.compile, `data` gets converted to a `torch._subclasses.fake_tensor.FakeTensor`
    # that does not inherit from `MarlinF8PackedTensor`, and torch then takes the wrong control flow.
    # Reference: https://pytorch.slack.com/archives/C033H6DJSJU/p1721837684035049
    if data.dtype != torch.int32:
        assert scale.shape == (out_features, 1)

        scale_perm_single = get_scale_perms()
        scale = scale.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
        scale = scale.reshape(-1, out_features).contiguous()

        data_packed = MarlinF8PackedTensor.pack(data)  # Pack the fp8 data to int32 and apply the Marlin re-ordering.
    else:
        # When freezing (`model.freeze()`), the data is already a MarlinF8PackedTensor and the scale is already repacked.
        data_packed = data

    super().__init__(
        qtype, axis, size, stride, data_packed, scale, activation_qtype=qfloat8_e4m3fn, requires_grad=requires_grad
    )
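
# Illustrative sketch (not part of the original file): the scale re-ordering done in
# the fp8 branch above, shown on a toy tensor. `toy_perm` is a hypothetical stand-in
# for the permutation returned by `get_scale_perms()` (the real one is Marlin-kernel
# specific); only the reshape -> column-index -> reshape pattern mirrors the code above.
import torch

out_features = 8
scale = torch.arange(out_features, dtype=torch.float16).reshape(out_features, 1)  # per-channel scales, shape (out_features, 1)

toy_perm = [2, 0, 3, 1]  # hypothetical permutation, NOT the real Marlin scale permutation
scale = scale.reshape((-1, len(toy_perm)))[:, toy_perm]  # group scales by len(toy_perm) and shuffle each group's columns
scale = scale.reshape(-1, out_features).contiguous()     # (1, out_features) row layout expected by the kernel

print(scale)  # tensor([[2., 0., 3., 1., 6., 4., 7., 5.]], dtype=torch.float16)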