def init_state()

in bitsandbytes/optim/optimizer.py

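Lazily allocates the per-parameter optimizer state: it selects the state dtype from the configured `optim_bits`, allocates the 32-bit or 8-bit `state1` buffer, attaches the quantization metadata for the 8-bit path (`qmap1`, plus either the block-wise `absmax1` or the tensor-wide `max1`/`new_max1` pair), and creates optional buffers for percentile clipping and update-norm clipping.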

    def init_state(self, group, p, gindex, pindex):
        config = self.get_config(gindex, pindex, group)

        if config['optim_bits'] == 32:
            dtype = torch.float32
        elif config['optim_bits'] == 8:
            dtype = torch.uint8
        else:
            raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}')

        # Tensors below min_8bit_size keep full-precision state regardless of
        # the requested optim_bits.
        if p.numel() < config['min_8bit_size']:
            dtype = torch.float32

        state = self.state[p]
        state['step'] = 0

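        # 32-bit state, also used for tensors under 4096 elements even when
        # 8-bit state was requested.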
        if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
            state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32, device=p.device)
        elif dtype == torch.uint8:
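            # 8-bit path: the state is stored as uint8 codes that are decoded
            # through the shared dynamic quantization map.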
            if state['step'] == 0:
                if 'dynamic' not in self.name2qmap:
                    self.fill_qmap()
                self.name2qmap['dynamic'] = self.name2qmap['dynamic'].to(p.device)

            state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8, device=p.device)
            state['qmap1'] = self.name2qmap['dynamic']

            if config['block_wise']:
                n = p.numel()
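                # One quantization block per 2048 elements (ceiling division),
                # each with its own float32 absolute-maximum scale.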
                blocks = n // 2048
                blocks += 1 if n % 2048 > 0 else 0

                state['absmax1'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
            else:
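                # Tensor-wide quantization: `max1` scales the current codes and
                # `new_max1` receives the maximum computed during the next update.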
                state['max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
                state['new_max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device)

        if config['percentile_clipping'] < 100:
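            # Rolling window of the last 100 gradient norms; the clipping
            # threshold is taken at the configured percentile of this history.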
            state['gnorm_vec'] = torch.zeros((100,), device=p.device)

        if config['max_unorm'] > 0.0:
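            # Scratch buffer for the update norm, used to cap the size of the
            # update when max_unorm is set.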
            state['unorm_vec'] = torch.zeros((1,), device=p.device)
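
`init_state` is not normally called by user code: in bitsandbytes the base optimizer's `step` initializes state lazily the first time it sees a parameter with a gradient, then hands off to `update_step`. A minimal sketch of that call pattern (simplified and paraphrased, not the library's exact code):

    @torch.no_grad()
    def step(self, closure=None):
        for gindex, group in enumerate(self.param_groups):
            for pindex, p in enumerate(group['params']):
                if p.grad is None:
                    continue
                if len(self.state[p]) == 0:
                    # First use of this parameter: allocate its state buffers.
                    self.init_state(group, p, gindex, pindex)
                self.update_step(group, p, gindex, pindex)

Conceptually, the `qmap1`/`absmax1` buffers allocated above are what turn the uint8 codes back into floats during the update: ignoring the library's fused CUDA kernels, block-wise dequantization of block `i` amounts to roughly `qmap1[codes.long()] * absmax1[i]` (a hedged paraphrase of the idea, not bitsandbytes code).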