in rlkit/core/logging.py [0:0]
def save_itr_params(self, itr, params):
    """Write a snapshot of ``params`` for iteration ``itr``.

    Nothing is written when ``self._snapshot_dir`` is falsy. Otherwise the
    files created inside ``self._snapshot_dir`` depend on
    ``self._snapshot_mode``:

    * ``'all'``: ``itr_<itr>.pkl`` on every call.
    * ``'last'``: ``params.pkl``, overwritten on every call.
    * ``'gap'``: ``itr_<itr>.pkl`` only when ``itr`` is a multiple of
      ``self._snapshot_gap``.
    * ``'gap_and_last'``: the ``'gap'`` file (when due) followed by
      ``params.pkl``.
    * ``'none'``: nothing.

    Args:
        itr: Iteration number; formatted with ``%d`` into the file name.
        params: Object passed unchanged to ``torch.save``.

    Raises:
        NotImplementedError: If ``self._snapshot_mode`` is not one of the
            modes listed above. The mode is only checked when a snapshot
            directory is set.
    """
    if not self._snapshot_dir:
        return

    mode = self._snapshot_mode
    if mode == 'none':
        return
    if mode not in ('all', 'last', 'gap', 'gap_and_last'):
        raise NotImplementedError(
            "Unknown snapshot mode %r; expected one of 'all', 'last', "
            "'gap', 'gap_and_last' or 'none'." % (mode,)
        )

    file_names = []

    # Per-iteration snapshot: always in 'all' mode, and every
    # ``_snapshot_gap`` iterations in the two gap modes.
    if mode == 'all':
        numbered_due = True
    elif mode in ('gap', 'gap_and_last'):
        numbered_due = itr % self._snapshot_gap == 0
    else:
        numbered_due = False
    if numbered_due:
        file_names.append('itr_%d.pkl' % itr)

    # Rolling snapshot: overrides the previous params.pkl on each call.
    if mode in ('last', 'gap_and_last'):
        file_names.append('params.pkl')

    # The numbered file (if any) is written before params.pkl.
    for file_name in file_names:
        torch.save(params, osp.join(self._snapshot_dir, file_name))