in experiment_utils/cluster_manager.py
def save_checkpoint(self, epoch_id=None, requeue_on_signal=True):
    # Find out whether a signal was received in any process
    if requeue_on_signal and self.world_size > 1:
        dist.all_reduce(self.signal_tensor, group=self.process_group)
    self.logger.info('Saving checkpoint')
    if self.all_workers or self.rank == ClusterManager.MASTER_RANK:
        if epoch_id is None:
            checkpoint_fpath = self.checkpoint_fpath
        else:
            checkpoint_fpath = ClusterManager.CHECKPOINT_DIR \
                + 'ep' + str(epoch_id) + '_' \
                + self.model_tag + self.checkpoint_fname
        torch.save(self.state, checkpoint_fpath)
        if self.state['is_best']:
            shutil.copyfile(checkpoint_fpath, self.model_best_fpath)
            self.state['is_best'] = False
    if requeue_on_signal and self.signal_tensor[0] > 0:
        self.logger.info('At least 1 process received SIGUSR1. Terminating')
        # Relaunch the job on the cluster from the checkpoint, but only in
        # the main process of the rank-0 agent
        if self.rank == 0 and os.getpid() == self.main_pid:
            command = f'scontrol requeue {os.environ["SLURM_JOB_ID"]}'
            self.logger.info('Relaunching: ' + command)
            if os.system(command):
                raise RuntimeError('scontrol requeue failed')
            self.logger.info('New job submitted to the queue')
        self.logger.info('Terminating')
        sys.exit(0)
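
# --- Usage sketch (not part of cluster_manager.py) ---
# A minimal illustration of how save_checkpoint is typically wired up on a
# SLURM cluster: a SIGUSR1 handler marks the shared signal tensor, and the
# training loop calls save_checkpoint at the end of each epoch. The handler
# function, the ClusterManager constructor call, and the training-loop
# helpers below are assumptions for illustration; only save_checkpoint,
# signal_tensor, state['is_best'], and the SIGUSR1 convention come from the
# excerpt above.
import signal

def install_sigusr1_handler(manager):
    """Flag the manager's signal tensor when SLURM sends SIGUSR1 before preemption."""
    def _handler(signum, frame):
        # Setting any element > 0 is enough: save_checkpoint all-reduces the
        # tensor, so every rank learns that at least one process was signaled.
        manager.signal_tensor[0] = 1
    signal.signal(signal.SIGUSR1, _handler)

# manager = ClusterManager(...)                     # constructed elsewhere (assumed)
# install_sigusr1_handler(manager)
# for epoch in range(num_epochs):
#     train_one_epoch(model, loader)                # hypothetical training step
#     manager.state['is_best'] = validate(model)    # hypothetical validation step
#     manager.save_checkpoint(epoch_id=epoch)       # requeues the job if signaled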