def init_state()

in bitsandbytes/optim/optimizer.py


    def init_state(self, group, p, gindex, pindex):
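        # Resolve the effective config (optim_bits, min_8bit_size, block_wise, ...) for this parameter group.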
        config = self.get_config(gindex, pindex, group)

        # Map the configured optimizer precision to a state dtype:
        # 32-bit keeps full fp32 state; 8-bit stores dynamically quantized uint8 state.
        if config["optim_bits"] == 32:
            dtype = torch.float32
        elif config["optim_bits"] == 8:
            dtype = torch.uint8
        else:
            raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}')

        # Parameters below min_8bit_size gain little from quantization; keep their state in fp32.
        if p.numel() < config["min_8bit_size"]:
            dtype = torch.float32

        # Per-parameter state; "step" counts the updates applied to this parameter.
        state = self.state[p]
        state["step"] = 0

        # fp32 mode, and 8-bit mode on tensors too small to quantize efficiently
        # (< 4096 elements), both allocate a full-precision state buffer.
        if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
            state["state1"] = self.get_state_buffer(p, dtype=torch.float32)
        elif dtype == torch.uint8:
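            # Lazily create the dynamic quantization map and move it to the parameter's device.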
            if state["step"] == 0:
                if "dynamic" not in self.name2qmap:
                    self.fill_qmap()
                self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device)

            state["state1"] = self.get_state_buffer(p, dtype=torch.uint8)
            state["qmap1"] = self.name2qmap["dynamic"]

            if config["block_wise"]:
                # Block-wise quantization: one absmax statistic per 2048-element block (ceil division).
                n = p.numel()
                blocks = n // 2048
                blocks += 1 if n % 2048 > 0 else 0

                state["absmax1"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
            else:
                # Global quantization: a single running max plus the newly computed max for the next step.
                state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
                state["new_max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device)

        # Percentile clipping tracks a rolling window of the last 100 gradient norms.
        if config["percentile_clipping"] < 100:
            state["gnorm_vec"] = torch.zeros((100,), device=p.device)

        # max_unorm > 0 enables update-norm clipping (used by LARS-style optimizers).
        if config["max_unorm"] > 0.0:
            state["unorm_vec"] = torch.zeros((1,), device=p.device)