in bitsandbytes/functional.py
def elementwise_func(func_name, A, B, value, prefetch=True):
    # torch, ctypes as ct, the compiled `lib` handle, get_ptr, and
    # prefetch_tensor are all defined earlier in functional.py.
    func = None
    if A.dtype == torch.float32:
        # C symbols follow the naming convention c{func_name}_{dtype}.
        func = getattr(lib, f"c{func_name}_fp32", None)
        cvalue = ct.c_float(value)
    elif A.dtype == torch.uint8:
        func = getattr(lib, f"c{func_name}_uint8", None)
        cvalue = ct.c_uint8(value)

    if func is None:
        raise NotImplementedError(f"Function not implemented: {func_name}")

    # For CUDA-managed tensors, optionally prefetch the data to the device
    # before launching the kernel.
    is_managed = getattr(A, "is_managed", False)
    if is_managed and prefetch:
        prefetch_tensor(A)
        if B is not None:
            prefetch_tensor(B)

    func(get_ptr(A), get_ptr(B), cvalue, ct.c_int64(A.numel()))
    # getattr guards against B being None (as allowed above) and against
    # plain tensors that carry no is_paged attribute.
    if getattr(A, "is_paged", False) or getattr(B, "is_paged", False):
        # Paged functions are fully asynchronous. If we return from this
        # function, we want the tensor to be in the correct state, that is,
        # the final state after the operation has occurred. So we synchronize.
        torch.cuda.synchronize()
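
The call site above implies a uniform C prototype of (pointer, pointer, value, int64 count) for every c{func_name}_{dtype} symbol. A minimal sketch of how those symbols could be given explicit ctypes signatures, assuming only the naming convention visible in the dispatch code (declare_elementwise is a hypothetical helper; the real prototypes live in the library's CUDA sources):

import ctypes as ct

def declare_elementwise(lib, func_name):
    # Attach argtypes/restype to each dtype variant, mirroring how
    # elementwise_func invokes them: (A_ptr, B_ptr, value, n).
    for suffix, value_ctype in (("fp32", ct.c_float), ("uint8", ct.c_uint8)):
        func = getattr(lib, f"c{func_name}_{suffix}", None)
        if func is not None:
            func.argtypes = [ct.c_void_p, ct.c_void_p, value_ctype, ct.c_int64]
            func.restype = None

Declaring argtypes up front lets ctypes reject mismatched arguments instead of silently coercing them, which is easy to get wrong when the same Python call feeds multiple dtype variants.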
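
For context, a hedged usage sketch. It assumes a fill wrapper that forwards to elementwise_func("fill", A, None, value) and a get_paged allocator that tags its result with is_paged=True; both appear elsewhere in functional.py, but the exact signatures here are inferred from the call pattern above, not verbatim API documentation:

import torch
import bitsandbytes.functional as F

# Assumed helper: allocates a CUDA unified-memory ("paged") tensor and
# marks it is_paged=True so elementwise_func synchronizes afterwards.
buf = F.get_paged(1024, 1024, dtype=torch.float32)

# Assumed to forward to elementwise_func("fill", buf, None, 4.0): the fp32
# branch resolves lib.cfill_fp32 and wraps 4.0 in ct.c_float; the trailing
# synchronize guarantees buf holds its final values on return.
F.fill(buf, 4.0)
assert bool((buf == 4.0).all())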