in src/markov/multi_agent_coach/multi_agent_graph_manager.py [0:0]
def restore_checkpoint(self):
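    """Restore per-agent checkpoints from task_parameters.checkpoint_restore_path.

    The restore path may be a checkpoint directory containing one sub-directory per agent
    ('agent' for single-agent runs, 'agent_0' ... for multi-agent runs) or a single TensorFlow
    checkpoint file. After restoring, a checkpoint outside the num_checkpoints_to_keep window
    is cleaned up and the restore is propagated to all level managers.
    """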
    self.verify_graph_was_created()
    # TODO: find better way to load checkpoints that were saved with a global network into the online network
    if self.task_parameters.checkpoint_restore_path:
        restored_checkpoint_paths = []
        for agent_params in self.agents_params:
            # for a single agent the name is 'agent'; for multi-agent it is 'agent_0' ...
            agent_checkpoint_restore_path = os.path.join(self.task_parameters.checkpoint_restore_path,
                                                         agent_params.name)
            if os.path.isdir(agent_checkpoint_restore_path):
                # a checkpoint dir
                if self.task_parameters.framework_type == Frameworks.tensorflow and \
                        'checkpoint' in os.listdir(agent_checkpoint_restore_path):
                    # TODO-fixme checkpointing
                    # MonitoredTrainingSession manages save/restore checkpoints autonomously. In doing so,
                    # it creates its own names for the saved checkpoints, which do not match the "{}_Step-{}.ckpt"
                    # filename pattern. The names used are maintained in a CheckpointState protobuf file named
                    # 'checkpoint'. Using Coach's '.coach_checkpoint' protobuf file results in an error when trying
                    # to restore the model, as the checkpoint names defined there do not match the actual
                    # checkpoint names.
                    raise NotImplementedError('Checkpointing not implemented for TF monitored training session')
                else:
                    checkpoint = get_checkpoint_state(agent_checkpoint_restore_path, all_checkpoints=True)
                    if checkpoint is None:
                        raise ValueError("No checkpoint to restore in: {}".format(agent_checkpoint_restore_path))
                    model_checkpoint_path = checkpoint.model_checkpoint_path
                    checkpoint_restore_dir = self.task_parameters.checkpoint_restore_path
                    restored_checkpoint_paths.append(model_checkpoint_path)

                    # Set the last checkpoint ID - only in the case of the path being a dir
                    chkpt_state_reader = CheckpointStateReader(agent_checkpoint_restore_path,
                                                               checkpoint_state_optional=False)
                    self.checkpoint_id = chkpt_state_reader.get_latest().num + 1
            else:
                # a checkpoint file
                if self.task_parameters.framework_type == Frameworks.tensorflow:
                    model_checkpoint_path = agent_checkpoint_restore_path
                    checkpoint_restore_dir = os.path.dirname(model_checkpoint_path)
                    restored_checkpoint_paths.append(model_checkpoint_path)
                else:
                    raise ValueError("Currently restoring a checkpoint using the --checkpoint_restore_file argument"
                                     " is only supported with TensorFlow.")
            try:
                self.checkpoint_saver[agent_params.name].restore(self.sess[agent_params.name],
                                                                 model_checkpoint_path)
            except Exception as ex:
                raise ValueError("Failed to restore {}'s checkpoint: {}".format(agent_params.name, ex))
            # 'checkpoint' is only defined for the checkpoint-dir case, so skip the cleanup for file restores
            if os.path.isdir(agent_checkpoint_restore_path):
                all_checkpoints = sorted(list(set([c.name for c in checkpoint.all_checkpoints])))  # remove duplicates :-(
                if self.num_checkpoints_to_keep < len(all_checkpoints):
                    checkpoint_to_delete = all_checkpoints[-self.num_checkpoints_to_keep - 1]
                    agent_checkpoint_to_delete = os.path.join(agent_checkpoint_restore_path, checkpoint_to_delete)
                    for file in glob.glob("{}*".format(agent_checkpoint_to_delete)):
                        os.remove(file)
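        # Propagate the restore to every level manager, then run their post-training commands.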
        [manager.restore_checkpoint(checkpoint_restore_dir) for manager in self.level_managers]
        [manager.post_training_commands() for manager in self.level_managers]

        screen.log_dict(
            OrderedDict([
                ("Restoring from path", restored_checkpoint_path)
                for restored_checkpoint_path in restored_checkpoint_paths
            ]),
            prefix="Checkpoint"
        )