in cm/train_util.py [0:0]
def _load_and_sync_target_parameters(self):
    """Restore the target model from a checkpoint (if any), then sync it.

    The resume checkpoint is ``find_resume_checkpoint()`` if that returns a
    truthy value, otherwise ``self.resume_checkpoint``. The matching target
    model checkpoint is looked up in the same directory. Its filename is the
    resume checkpoint's filename with every occurrence of "model" replaced by
    "target_model". If that file exists, only rank 0 loads it into
    ``self.target_model`` (on ``dist_util.dev()``). Otherwise the target
    model keeps its current weights.

    Whether or not anything was loaded, the target model's parameters and
    buffers are always passed to ``dist_util.sync_params``. Presumably this
    makes all ranks agree with rank 0 (TODO confirm against
    ``dist_util.sync_params``).
    """
    resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
    if resume_checkpoint:
        path, name = os.path.split(resume_checkpoint)
        target_name = name.replace("model", "target_model")
        resume_target_checkpoint = os.path.join(path, target_name)
        # Only rank 0 reads the file. Other ranks presumably pick up the
        # loaded weights through the sync_params calls below.
        if bf.exists(resume_target_checkpoint) and dist.get_rank() == 0:
            # f-string so the actual checkpoint path is logged.
            logger.log(
                f"loading model from checkpoint: {resume_target_checkpoint}..."
            )
            self.target_model.load_state_dict(
                dist_util.load_state_dict(
                    resume_target_checkpoint, map_location=dist_util.dev()
                ),
            )

    # Runs unconditionally, including when no checkpoint was found.
    dist_util.sync_params(self.target_model.parameters())
    dist_util.sync_params(self.target_model.buffers())