in egg/zoo/simclr/game_callbacks.py [0:0]
def save_vision_model(self, epoch=""):
    """Write the vision encoder's ``state_dict`` into the trainer's checkpoint dir.

    The save is skipped when:
      * the trainer has no ``checkpoint_path`` attribute, or it is falsy;
      * training is distributed and this process is not the leader.

    The checkpoint directory (and any missing parents) is created on demand.

    Args:
        epoch: Suffix appended to ``vision_module`` in the file name. When it
            is falsy the file is called ``vision_module_final.pt``.
    """
    trainer = self.trainer
    dist_ctx = trainer.distributed_context
    is_distributed = dist_ctx.is_distributed
    is_leader = dist_ctx.is_leader

    # A missing attribute and a falsy value both mean "do not checkpoint".
    ckpt_dir = getattr(trainer, "checkpoint_path", None)
    if not ckpt_dir:
        return
    # In distributed runs only the leader process writes to disk.
    if is_distributed and not is_leader:
        return

    ckpt_dir.mkdir(exist_ok=True, parents=True)

    # Under distributed training the game is wrapped (per the original code,
    # in DistributedDataParallel), so the real module sits behind `.module`.
    game = trainer.game.module if is_distributed else trainer.game
    encoder = game.vision_module.encoder

    suffix = epoch or "_final"
    torch.save(encoder.state_dict(), ckpt_dir / f"vision_module{suffix}.pt")