def forward()

in optimum/quanto/tensor/weights/marlin/fp8/qbits.py [0:0]


    def forward(ctx, input, other, bias=None):
        """Run the Marlin fp16 x fp8 GEMM: ``input @ other`` (plus optional bias).

        Args:
            ctx: autograd context; ``input`` and ``other`` are stashed for backward.
            input: activation tensor. May have more than 2 dims; leading dims are
                flattened to a 2D matrix before the kernel call and restored after.
            other: project-specific Marlin FP8 quantized weight wrapper exposing
                ``_data._data`` (packed weights), ``_scale`` and ``_workspace``.
            bias: optional tensor added to the GEMM output (broadcast over rows).

        Returns:
            Output tensor with shape ``input.shape[:-1] + (out_features,)`` where
            ``out_features == other._scale.shape[1]``.
        """
        ctx.save_for_backward(input, other)
        input_shape = input.shape

        # The Marlin kernel expects a 2D activation matrix: flatten leading dims.
        if input.ndim > 2:
            input = input.reshape(-1, input_shape[-1])

        output = torch.ops.quanto.gemm_f16f8_marlin(
            input,
            b_q_weight=other._data._data,
            b_scales=other._scale,  # .to(input.dtype)
            workspace=other._workspace,
            num_bits=8,
            size_m=input.shape[0],
            size_n=other._scale.shape[1],
            size_k=input.shape[1],
        )

        # Restore the original leading dims that were flattened above.
        if len(input_shape) > 2:
            output = output.reshape(input_shape[:-1] + (other._scale.shape[1],))

        # Fix: `bias` was accepted in the signature but previously never applied,
        # silently dropping it for any caller that passed one.
        if bias is not None:
            output = output + bias

        return output