# forward() — excerpt from optimum/quanto/tensor/weights/tinygemm/qbits.py

    def forward(ctx, input, other, bias):
        """Compute ``input @ other.T + bias`` using the tinygemm int4 packed kernel.

        Args:
            ctx: autograd context; the (possibly still quantized) input and the
                packed weight are stashed for the backward pass.
            input: activation tensor; if it is not a plain ``torch.Tensor`` it is
                dequantized before the matmul.
            other: packed int4 weight wrapper exposing ``_data._data``,
                ``_group_size`` and ``_scale_shift`` (project quantized type —
                assumed layout matches ``torch._weight_int4pack_mm``).
            bias: optional bias added to the matmul result, or ``None``.

        Returns:
            Tensor with shape ``input.shape[:-1] + (out_features,)``.
        """
        # Save before any dequantization so backward sees the original operands.
        ctx.save_for_backward(input, other)
        # Exact-type check on purpose: quantized activations subclass
        # torch.Tensor, so isinstance() would wrongly skip dequantization here.
        if type(input) is not torch.Tensor:
            input = input.dequantize()
        in_features = input.shape[-1]
        out_features = other.shape[0]
        # Pick the CPU-specific kernel when running on CPU; both kernels share
        # the same argument layout.
        kernel = (
            torch._weight_int4pack_mm_for_cpu
            if input.device.type == "cpu"
            else torch._weight_int4pack_mm
        )
        flat_out = kernel(
            input.reshape(-1, in_features), other._data._data, other._group_size, other._scale_shift
        )
        # Restore the leading (batch) dimensions dropped by the 2D matmul.
        result = flat_out.reshape(input.shape[:-1] + (out_features,))
        return result if bias is None else result + bias