in jukebox/utils/ema.py [0:0]
def __init__(self, params, mu=0.999):
    """Partition trainable parameters by dtype and snapshot their initial state.

    Args:
        params: iterable of parameter objects exposing ``requires_grad`` and
            ``.data.dtype`` (e.g. ``torch.nn.Parameter``). It is iterated
            exactly once, so a generator such as ``model.parameters()`` works.
        mu: smoothing coefficient, only stored on ``self.mu`` here
            (presumably the EMA decay used by the update step -- confirm
            against the rest of the class).

    Sets:
        self.params: ``{'fp16': [...], 'fp32': [...]}``. Both keys are always
            present. ``'fp32'`` collects every trainable parameter whose dtype
            is not ``torch.float16``, so it is not limited to float32.
        self.groups: names of the non-empty groups, in ``'fp16', 'fp32'`` order.
        self.state: per-group result of ``self.get_model_state(group)``.
    """
    self.mu = mu

    # Parameters with requires_grad == False are never tracked.
    trainable = [p for p in params if p.requires_grad]
    half = torch.float16

    # The insertion order (fp16 first, then fp32) determines the order of self.groups.
    self.params = {
        'fp16': [p for p in trainable if p.data.dtype == half],
        'fp32': [p for p in trainable if p.data.dtype != half],
    }

    self.groups = [name for name, members in self.params.items() if members]

    # self.state exists before any get_model_state call and is filled one
    # group at a time, so each call can see the groups recorded before it.
    self.state = {}
    self.state.update((name, self.get_model_state(name)) for name in self.groups)