in optimum/quanto/tensor/weights/marlin/fp8/qbits.py [0:0]
def forward(ctx, input, other, bias=None):
    """Compute ``input @ W^T (+ bias)`` using the Marlin fp16 x fp8 GEMM kernel.

    Args:
        ctx: autograd context; ``(input, other)`` are saved for the backward pass.
        input: activation tensor whose last dimension is the reduction size
            (``size_k``). Leading dimensions are flattened for the kernel and
            restored on the output.
        other: Marlin-packed fp8 weight wrapper exposing ``_data._data`` (packed
            quantized weights), ``_scale`` (scales; dim 1 is the output-feature
            count, used as ``size_n``), and ``_workspace`` (kernel scratch buffer).
        bias: optional tensor added to the GEMM output (broadcast by torch).

    Returns:
        Tensor of shape ``input.shape[:-1] + (other._scale.shape[1],)``.
    """
    ctx.save_for_backward(input, other)
    input_shape = input.shape
    # The Marlin kernel expects a 2D activation matrix: flatten leading dims.
    if input.ndim > 2:
        input = input.reshape(-1, input_shape[-1])
    output = torch.ops.quanto.gemm_f16f8_marlin(
        input,
        b_q_weight=other._data._data,
        b_scales=other._scale,  # .to(input.dtype)
        workspace=other._workspace,
        num_bits=8,
        size_m=input.shape[0],
        size_n=other._scale.shape[1],
        size_k=input.shape[1],
    )
    # Restore the original leading dimensions.
    if len(input_shape) > 2:
        output = output.reshape(input_shape[:-1] + (other._scale.shape[1],))
    # Fix: `bias` was previously accepted in the signature but silently ignored.
    if bias is not None:
        output = output + bias
    return output