in bitsandbytes/optim/optimizer.py [0:0]
def init_state(self, group, p, gindex, pindex):
    """Populate ``self.state[p]`` with freshly allocated optimizer state.

    The configuration is looked up through
    ``self.get_config(gindex, pindex, group)``. Depending on it, the two
    state buffers (``state1``/``state2``) are allocated either as float32
    or as uint8. The uint8 layout additionally stores the ``dynamic`` and
    ``udynamic`` entries of ``self.name2qmap`` (presumably the quantization
    maps -- confirm against ``fill_qmap``) plus scaling tensors.

    Args:
        group: Parameter group, forwarded unchanged to ``get_config``.
        p: Parameter tensor; used as the key into ``self.state`` and as
            the source of the element count and target device.
        gindex: Group index, forwarded to ``get_config``.
        pindex: Parameter index, forwarded to ``get_config``.

    Raises:
        NotImplementedError: If ``config["optim_bits"]`` is neither 32
            nor 8.
    """
    config = self.get_config(gindex, pindex, group)

    optim_bits = config["optim_bits"]
    if optim_bits not in (32, 8):
        raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}')

    numel = p.numel()
    device = p.device

    # uint8 state is only used when 8 bits were requested AND the tensor is
    # neither below the configured minimum nor below the fixed 4096 cutoff.
    below_min_size = numel < config["min_8bit_size"]
    use_uint8 = optim_bits == 8 and not below_min_size and not numel < 4096

    state = self.state[p]
    state["step"] = 0

    if not use_uint8:
        for buffer_key in ("state1", "state2"):
            state[buffer_key] = self.get_state_buffer(p, dtype=torch.float32)
    else:
        # "step" was just reset to 0 above, so the qmap setup always runs
        # here: build the maps on first use, then move both to p's device.
        if "dynamic" not in self.name2qmap:
            self.fill_qmap()
        for qmap_name in ("dynamic", "udynamic"):
            self.name2qmap[qmap_name] = self.name2qmap[qmap_name].to(device)

        layout = (
            ("state1", "qmap1", "dynamic"),
            ("state2", "qmap2", "udynamic"),
        )
        for buffer_key, qmap_key, qmap_name in layout:
            state[buffer_key] = self.get_state_buffer(p, dtype=torch.uint8)
            state[qmap_key] = self.name2qmap[qmap_name]

        if config["block_wise"]:
            # One float32 absmax slot per 2048 elements, rounding up.
            full_blocks, remainder = divmod(numel, 2048)
            blocks = full_blocks + (1 if remainder else 0)
            for absmax_key in ("absmax1", "absmax2"):
                state[absmax_key] = torch.zeros((blocks,), dtype=torch.float32, device=device)
        else:
            # Single-element float32 max / new_max tensors for each buffer.
            for max_key in ("max1", "new_max1", "max2", "new_max2"):
                state[max_key] = torch.zeros((1,), dtype=torch.float32, device=device)

    if config["percentile_clipping"] < 100:
        state["gnorm_vec"] = torch.zeros((100,), device=device)

    if config["max_unorm"] > 0.0:
        state["unorm_vec"] = torch.zeros((1,), device=device)