in bitsandbytes/functional.py [0:0]
def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=None, new_col_stats=None, bias=None):
    """Invoke the native ``cdequant_mm_int32_fp16`` routine on an int32 tensor.

    The output shape is read from ``quant_state[0]``. A 3-D shape is collapsed
    to 2-D by multiplying its first two dimensions; the last dimension is kept.

    Args:
        A: Input tensor; must have dtype ``torch.int32``.
        quant_state: Indexable whose first element is the output shape.
        row_stats: Tensor whose ``shape[0]`` must equal that of
            ``new_row_stats``. Presumably per-row scaling statistics from the
            quantization step -- confirm against the native kernel.
        col_stats: Tensor whose ``shape[0]`` must equal that of
            ``new_col_stats``. Presumably per-column scaling statistics --
            confirm against the native kernel.
        out: Optional destination tensor. When ``None``, a float16 tensor of
            the (flattened) output shape is allocated on ``A.device``.
        new_row_stats: Optional buffer handed to the kernel. When ``None``, a
            float32 tensor of length ``out_shape[0]`` is allocated on
            ``A.device``.
        new_col_stats: Optional buffer handed to the kernel. When ``None``, a
            float32 tensor of length ``out_shape[1]`` is allocated on
            ``A.device``.
        bias: Optional tensor; must have dtype ``torch.float16`` when given.

    Returns:
        ``out`` -- either the tensor passed in or the freshly allocated one.

    Raises:
        AssertionError: If ``A`` is not int32, ``bias`` is not float16, or the
            leading dimension of the stats buffers does not match
            ``row_stats`` / ``col_stats``.
    """
    assert A.dtype == torch.int32
    if bias is not None:
        assert bias.dtype == torch.float16

    out_shape = quant_state[0]
    if len(out_shape) == 3:
        # Collapse the two leading dimensions so the kernel sees a 2-D matrix.
        out_shape = (out_shape[0] * out_shape[1], out_shape[2])

    if out is None:
        out = torch.empty(out_shape, dtype=torch.float16, device=A.device)
    if new_row_stats is None:
        new_row_stats = torch.empty(out_shape[0], dtype=torch.float32, device=A.device)
    if new_col_stats is None:
        new_col_stats = torch.empty(out_shape[1], dtype=torch.float32, device=A.device)
    assert new_row_stats.shape[0] == row_stats.shape[0], f"{new_row_stats.shape} vs {row_stats.shape}"
    assert new_col_stats.shape[0] == col_stats.shape[0], f"{new_col_stats.shape} vs {col_stats.shape}"

    prev_device = pre_call(A.device)
    try:
        ptrA = get_ptr(A)
        ptrOut = get_ptr(out)
        ptrRowStats = get_ptr(row_stats)
        ptrColStats = get_ptr(col_stats)
        ptrNewRowStats = get_ptr(new_row_stats)
        ptrNewColStats = get_ptr(new_col_stats)
        # bias may be None here; get_ptr / is_on_gpu are handed it unchanged,
        # exactly as before (presumably they tolerate None -- verify).
        ptrBias = get_ptr(bias)
        numRows = ct.c_int32(out_shape[0])
        numCols = ct.c_int32(out_shape[1])

        is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats, bias])
        lib.cdequant_mm_int32_fp16(
            ptrA,
            ptrRowStats,
            ptrColStats,
            ptrOut,
            ptrNewRowStats,
            ptrNewColStats,
            ptrBias,
            numRows,
            numCols,
        )
    finally:
        # Always undo pre_call, even when the device check or the native call
        # raises; otherwise the caller is left with whatever pre_call changed.
        post_call(prev_device)

    return out