def initialize_state()

in dpr_scale/optim/madgrad.py [0:0]


    def initialize_state(self):
        """Create the optimizer state buffers for parameters that lack them.

        Walks every parameter in ``self.param_groups`` and, for each one that
        has no entry in ``self.state`` yet, stores:

        * ``"grad_sum_sq"`` -- zeros with the parameter's shape and dtype,
        * ``"s"`` -- zeros with the parameter's shape and dtype,
        * ``"x0"`` -- a detached copy of the parameter's current data, only
          when ``self.momentum != 0``.

        Parameters that already have an entry are left untouched, so calling
        this method repeatedly does not reset existing buffers.

        Buffers are moved with ``.cuda()`` when CUDA is available (the
        original behaviour); on hosts without CUDA they stay on the
        parameter's own device instead of raising.
        """
        # Query CUDA availability once rather than once per tensor.
        use_cuda = torch.cuda.is_available()

        def _to_state_device(tensor):
            # Detach so the buffer never participates in autograd, then place
            # it on CUDA only when that is actually possible.
            tensor = tensor.detach()
            return tensor.cuda() if use_cuda else tensor

        for group in self.param_groups:
            for p in group["params"]:
                if p in self.state:
                    continue
                # NOTE(review): relies on ``self.state`` creating an empty
                # entry on first lookup (defaultdict-like, as in
                # torch.optim.Optimizer) -- confirm against the class.
                state = self.state[p]
                state["grad_sum_sq"] = _to_state_device(torch.zeros_like(p.data))
                state["s"] = _to_state_device(torch.zeros_like(p.data))
                if self.momentum != 0:
                    # clone() so later in-place parameter updates do not
                    # change the stored snapshot.
                    state["x0"] = _to_state_device(torch.clone(p.data))