in guided_diffusion/train_util.py [0:0]
def _load_ema_parameters(self, rate):
    """
    Produce the initial EMA parameters for one EMA rate.

    The result starts as a deep copy of the trainer's master parameters.
    If find_ema_checkpoint() reports an EMA checkpoint for this rate, rank 0
    replaces that copy with the parameters stored in the checkpoint. The
    result is then passed through dist_util.sync_params() on every rank.

    :param rate: the EMA rate, forwarded to find_ema_checkpoint() together
                 with the resume checkpoint and self.resume_step.
    :return: the EMA parameters after dist_util.sync_params() has run.
    """
    # Fallback used by every rank when nothing is loaded from disk.
    params = copy.deepcopy(self.mp_trainer.master_params)

    # A checkpoint reported by find_resume_checkpoint() takes precedence over
    # the one configured on the instance.
    resume_from = find_resume_checkpoint() or self.resume_checkpoint
    ema_path = find_ema_checkpoint(resume_from, self.resume_step, rate)

    # Only rank 0 reads the checkpoint; the rank is not queried at all when
    # there is no EMA checkpoint to load.
    if ema_path and dist.get_rank() == 0:
        logger.log(f"loading EMA from checkpoint: {ema_path}...")
        device = dist_util.dev()
        loaded_state = dist_util.load_state_dict(ema_path, map_location=device)
        params = self.mp_trainer.state_dict_to_master_params(loaded_state)

    # Runs on all ranks, whether or not a checkpoint was loaded.
    # NOTE(review): presumably this propagates rank 0's values to the other
    # ranks -- confirm against dist_util.sync_params.
    dist_util.sync_params(params)
    return params