in dpr_scale/optim/madgrad.py [0:0]
def initialize_state(self):
    """Create the per-parameter state tensors for parameters that lack them.

    For every parameter in ``self.param_groups`` that has no entry in
    ``self.state`` yet, this stores:

    * ``"grad_sum_sq"`` -- zeros shaped like the parameter,
    * ``"s"`` -- zeros shaped like the parameter,
    * ``"x0"`` -- a detached copy of the parameter's current values
      (only when ``self.momentum != 0``).

    Parameters that already have state are left untouched, so calling this
    repeatedly is safe.

    Device placement: a parameter that already lives on a CUDA device keeps
    its state on that same device; a CPU parameter has its state placed on
    the current CUDA device when CUDA is available, and stays on the CPU
    otherwise. Unlike an unconditional ``.cuda()``, this does not raise on
    CPU-only machines and never moves the state of a parameter on a
    non-current GPU away from its parameter.

    NOTE(review): assumes ``self.state`` auto-creates an empty dict on first
    access (as ``torch.optim.Optimizer.state`` does) -- confirm against the
    enclosing class.
    """
    cuda_available = torch.cuda.is_available()
    for group in self.param_groups:
        for p in group["params"]:
            if p in self.state:
                continue

            # Keep state next to an already-on-GPU parameter; promote
            # CPU parameters to CUDA only when CUDA can actually be used.
            if p.is_cuda or not cuda_available:
                device = p.device
            else:
                device = torch.device("cuda")

            state = self.state[p]
            state["grad_sum_sq"] = torch.zeros_like(p.data, device=device)
            state["s"] = torch.zeros_like(p.data, device=device)
            if self.momentum != 0:
                # clone() first so "x0" never aliases the live parameter,
                # even when .to() is a no-op because devices already match.
                state["x0"] = torch.clone(p.data).detach().to(device)