def save_vision_model()

in egg/zoo/simclr/game_callbacks.py [0:0]


    def save_vision_model(self, epoch=""):
        """Save the vision encoder's weights to the trainer's checkpoint directory.

        Only the ``state_dict`` of ``vision_module.encoder`` is written, not the
        whole game. Nothing is saved in two cases:

        - the trainer has no ``checkpoint_path`` attribute, or it is falsy;
        - training is distributed and this process is not the leader, so the
          file is written by a single process.

        Args:
            epoch: Suffix identifying the checkpoint. The file is named
                ``vision_module{epoch}.pt``. When ``epoch`` is ``""`` (the
                default) or ``None``, the file is named
                ``vision_module_final.pt``. ``0`` is treated as a real epoch
                number, not as "final".
        """
        context = self.trainer.distributed_context
        is_distributed = context.is_distributed
        is_leader = context.is_leader
        # In distributed runs only the leader process writes the checkpoint.
        if is_distributed and not is_leader:
            return

        checkpoint_path = getattr(self.trainer, "checkpoint_path", None)
        if not checkpoint_path:
            return
        checkpoint_path.mkdir(exist_ok=True, parents=True)

        game = self.trainer.game
        # If distributed training the model is an instance of
        # DistributedDataParallel and we need to unpack it from it.
        if is_distributed:
            game = game.module
        vision_module = game.vision_module

        # `epoch` may legitimately be 0, so check for the "no epoch" sentinel
        # explicitly instead of relying on truthiness.
        suffix = "_final" if epoch is None or epoch == "" else epoch
        torch.save(
            vision_module.encoder.state_dict(),
            checkpoint_path / f"vision_module{suffix}.pt",
        )