def restore_checkpoint()

in src/markov/multi_agent_coach/multi_agent_graph_manager.py


    def restore_checkpoint(self):
        self.verify_graph_was_created()

        # TODO: find better way to load checkpoints that were saved with a global network into the online network
        if self.task_parameters.checkpoint_restore_path:
            restored_checkpoint_paths = []
            for agent_params in self.agents_params:
                # For a single agent the name is 'agent'; for multiple agents the names are 'agent_0', 'agent_1', ...
                agent_checkpoint_restore_path = os.path.join(self.task_parameters.checkpoint_restore_path, agent_params.name)
                if os.path.isdir(agent_checkpoint_restore_path):
                    # a checkpoint dir
                    if self.task_parameters.framework_type == Frameworks.tensorflow and\
                            'checkpoint' in os.listdir(agent_checkpoint_restore_path):
                        # TODO-fixme checkpointing
                        # MonitoredTrainingSession manages checkpoint save/restore autonomously. In doing so,
                        # it creates its own names for the saved checkpoints, which do not match the "{}_Step-{}.ckpt"
                        # filename pattern. The names used are maintained in a CheckpointState protobuf file named
                        # 'checkpoint'. Using Coach's '.coach_checkpoint' protobuf file instead results in an error
                        # when trying to restore the model, as the checkpoint names it defines do not match the
                        # actual checkpoint names.
                        raise NotImplementedError('Checkpointing not implemented for TF monitored training session')
                    else:
                        checkpoint = get_checkpoint_state(agent_checkpoint_restore_path, all_checkpoints=True)

                    if checkpoint is None:
                        raise ValueError("No checkpoint to restore in: {}".format(agent_checkpoint_restore_path))
                    model_checkpoint_path = checkpoint.model_checkpoint_path
                    checkpoint_restore_dir = self.task_parameters.checkpoint_restore_path
                    restored_checkpoint_paths.append(model_checkpoint_path)

                    # Set the last checkpoint ID - only in the case of the path being a dir
                    chkpt_state_reader = CheckpointStateReader(agent_checkpoint_restore_path,
                                                               checkpoint_state_optional=False)
                    self.checkpoint_id = chkpt_state_reader.get_latest().num + 1
                else:
                    # a checkpoint file
                    if self.task_parameters.framework_type == Frameworks.tensorflow:
                        model_checkpoint_path = agent_checkpoint_restore_path
                        checkpoint_restore_dir = os.path.dirname(model_checkpoint_path)
                        restored_checkpoint_paths.append(model_checkpoint_path)
                    else:
                        raise ValueError("Currently restoring a checkpoint using the --checkpoint_restore_file argument is"
                                         " only supported when with tensorflow.")

                try:
                    self.checkpoint_saver[agent_params.name].restore(self.sess[agent_params.name],
                                                                     model_checkpoint_path)
                except Exception as ex:
                    raise ValueError("Failed to restore {}'s checkpoint: {}".format(agent_params.name, ex))

                # Prune old checkpoints beyond the retention window. `checkpoint` is only populated
                # when restoring from a checkpoint dir, so skip the pruning for a checkpoint file.
                if os.path.isdir(agent_checkpoint_restore_path):
                    all_checkpoints = sorted(set(c.name for c in checkpoint.all_checkpoints))  # remove duplicates
                    if self.num_checkpoints_to_keep < len(all_checkpoints):
                        checkpoint_to_delete = all_checkpoints[-self.num_checkpoints_to_keep - 1]
                        agent_checkpoint_to_delete = os.path.join(agent_checkpoint_restore_path, checkpoint_to_delete)
                        for file in glob.glob("{}*".format(agent_checkpoint_to_delete)):
                            os.remove(file)

            # Restore each level manager from the same checkpoint dir, then run its post-training commands
            for manager in self.level_managers:
                manager.restore_checkpoint(checkpoint_restore_dir)
            for manager in self.level_managers:
                manager.post_training_commands()

            # OrderedDict collapses duplicate keys, so key each entry by agent name to log every restored path
            screen.log_dict(
                OrderedDict([
                    ("Restoring {} from path".format(agent_params.name), restored_checkpoint_path)
                    for agent_params, restored_checkpoint_path in zip(self.agents_params, restored_checkpoint_paths)
                ]),
                prefix="Checkpoint"
            )
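
For context, a minimal driver sketch of how this method is typically reached, assuming the usual rl_coach flow in which create_graph() runs before restore_checkpoint() (the ordering that verify_graph_was_created() enforces). The graph_manager object, the /opt/ml/checkpoints path, and the example checkpoint file names are hypothetical placeholders, not taken from this file; the per-agent sub-directory layout and the "{}_Step-{}.ckpt" / '.coach_checkpoint' names follow the comments in the method above.

    # Hypothetical driver sketch; not part of multi_agent_graph_manager.py.
    # Assumes `graph_manager` is an already-constructed MultiAgentGraphManager
    # (its construction is outside this excerpt) and that checkpoints were saved
    # with one sub-directory per agent, e.g.:
    #
    #   /opt/ml/checkpoints/              <- task_parameters.checkpoint_restore_path (hypothetical)
    #       agent_0/
    #           .coach_checkpoint         <- Coach's CheckpointState file
    #           0_Step-0.ckpt.*           <- "{}_Step-{}.ckpt" pattern noted in the comment above
    #       agent_1/
    #           ...
    from rl_coach.base_parameters import Frameworks, TaskParameters

    task_parameters = TaskParameters(framework_type=Frameworks.tensorflow)
    task_parameters.checkpoint_restore_path = "/opt/ml/checkpoints"  # hypothetical location

    # create_graph() builds self.sess, self.checkpoint_saver and self.level_managers,
    # all of which restore_checkpoint() dereferences.
    graph_manager.create_graph(task_parameters)
    graph_manager.restore_checkpoint()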