in bitsandbytes/optim/optimizer.py [0:0]
def __init__(self, params, defaults, optim_bits=32, is_paged=False):
    """
    Base 8-bit optimizer class.

    Arguments:
        params (`torch.tensor`):
            The input parameters to optimize.
        defaults (`dict`):
            Default hyperparameter values, forwarded to the parent
            optimizer's constructor.
        optim_bits (`int`, defaults to 32):
            The number of bits of the optimizer state.
        is_paged (`bool`, defaults to `False`):
            Whether the optimizer is a paged optimizer or not.
    """
    super().__init__(params, defaults)

    # NOTE(review): presumably flipped once optimizer state is created
    # elsewhere — not modified again in this constructor.
    self.initialized = False
    self.name2qmap = {}

    # Paging support and the shared optimizer-override manager
    # (both are process-wide singletons).
    self.is_paged = is_paged
    self.page_mng = F.GlobalPageManager.get_instance()
    self.mng = GlobalOptimManager.get_instance()

    # State-dict keys whose tensors must keep their dtype when the rest
    # of the optimizer state is cast (quantization maps, per-tensor
    # statistics, raw state buffers), per the attribute name.
    self.non_castable_tensor_keys = set(
        (
            "qmap1", "qmap2",
            "max1", "max2",
            "new_max1", "new_max2",
            "state1", "state2",
            "gnorm_vec",
            "absmax1", "absmax2",
            "unorm_vec",
        )
    )

    # 8-bit optimizers need their quantization maps populated up front.
    if optim_bits == 8:
        self.fill_qmap()