in cm/train_util.py [0:0]
def _load_and_sync_teacher_parameters(self):
    """Load teacher-model weights from a resume checkpoint and sync all ranks.

    The teacher checkpoint path is derived from the (main model) resume
    checkpoint path by replacing "model" with "teacher_model" in the
    filename. Only rank 0 reads the state dict from disk; afterwards the
    parameters and buffers are broadcast so every process ends up with
    identical teacher weights.
    """
    resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
    if resume_checkpoint:
        path, name = os.path.split(resume_checkpoint)
        teacher_name = name.replace("model", "teacher_model")
        resume_teacher_checkpoint = os.path.join(path, teacher_name)

        if bf.exists(resume_teacher_checkpoint) and dist.get_rank() == 0:
            # BUG FIX: the original message was a plain string literal, so
            # "{resume_teacher_checkpoint}" was logged verbatim instead of
            # the actual checkpoint path. Made it an f-string.
            logger.log(
                f"loading model from checkpoint: {resume_teacher_checkpoint}..."
            )
            self.teacher_model.load_state_dict(
                dist_util.load_state_dict(
                    resume_teacher_checkpoint, map_location=dist_util.dev()
                ),
            )

    # Runs on every rank (not just rank 0) — the broadcast requires all
    # processes to participate so non-zero ranks receive the loaded weights.
    dist_util.sync_params(self.teacher_model.parameters())
    dist_util.sync_params(self.teacher_model.buffers())