def __init__()

in jukebox/utils/ema.py [0:0]


    def __init__(self, params, mu=0.999):
        """Set up moving-average tracking for the trainable parameters in ``params``.

        Parameters whose ``requires_grad`` is true are bucketed by dtype:
        ``torch.float16`` tensors go to the ``'fp16'`` group and every other
        dtype goes to the ``'fp32'`` group. Parameters that do not require
        gradients are skipped entirely.

        Args:
            params: iterable of parameters; it is materialized into a list once.
            mu: decay factor, stored unchanged on ``self.mu``.
        """
        self.mu = mu
        half_precision, other_precision = [], []
        for tensor in list(params):
            if not tensor.requires_grad:
                continue
            target = half_precision if tensor.data.dtype == torch.float16 else other_precision
            target.append(tensor)
        # Insertion order ('fp16' before 'fp32') determines the order of self.groups.
        self.params = {'fp16': half_precision, 'fp32': other_precision}
        # Only non-empty buckets are treated as active groups.
        self.groups = [name for name, members in self.params.items() if members]
        # self.state exists (empty) before any get_model_state call, as in the original.
        self.state = {}
        for name in self.groups:
            self.state[name] = self.get_model_state(name)