in lib/utils/checkpoints.py [0:0]
def save_model_params(model, params_file, model_iter):
    """Pickle the root GPU's parameters, momentum blobs and lr to a file.

    Blobs are fetched from the ``gpu_{cfg.ROOT_GPU_ID}`` scope and stored
    under the name returned by ``misc.unscope_name`` (presumably the blob
    name with its ``gpu_N/`` prefix stripped -- TODO confirm). The file
    holds ``dict(blobs=save_blobs)`` where ``save_blobs`` contains:

    * ``'model_iter'``: ``model_iter + 1``;
    * ``'lr'``: the value of the ``gpu_{root}/lr`` blob;
    * ``<param>_momentum`` for every param that also appears in
      ``model.TrainableParams()``;
    * every blob from ``model.GetParams`` and ``model.GetComputedParams``
      for the root GPU scope.

    When two blobs map to the same unscoped name, the first one inserted
    is kept.

    Saving is best-effort. Any exception raised while opening or pickling
    is logged as a warning (with traceback) and is not re-raised.

    Args:
        model: model helper exposing ``GetParams``, ``GetComputedParams``
            and ``TrainableParams``.
        params_file: destination path; overwritten, written in binary mode.
        model_iter: current iteration; ``model_iter + 1`` is what is saved.
    """
    logger.info("Saving model params to weights file {}".format(params_file))
    root_gpu_id = cfg.ROOT_GPU_ID
    scope = 'gpu_{}'.format(root_gpu_id)
    save_params = [str(param) for param in model.GetParams(scope)]
    save_computed_params = [
        str(param) for param in model.GetComputedParams(scope)]
    # Build the trainable-name lookup once instead of re-fetching and
    # linearly scanning model.TrainableParams() for every param.
    trainable_names = {str(param) for param in model.TrainableParams()}

    save_blobs = {}
    save_blobs['model_iter'] = model_iter + 1
    save_blobs['lr'] = workspace.FetchBlob('{}/lr'.format(scope))

    # Momentum blobs exist only for trainable params.
    for param in save_params:
        if param in trainable_names:
            scoped_blob_name = param + '_momentum'
            unscoped_blob_name = misc.unscope_name(scoped_blob_name)
            if unscoped_blob_name not in save_blobs:
                save_blobs[unscoped_blob_name] = workspace.FetchBlob(
                    scoped_blob_name)

    for scoped_blob_name in save_params + save_computed_params:
        unscoped_blob_name = misc.unscope_name(scoped_blob_name)
        if unscoped_blob_name not in save_blobs:
            save_blobs[unscoped_blob_name] = workspace.FetchBlob(
                scoped_blob_name)

    try:
        # HIGHEST_PROTOCOL is a binary pickle format, so the file must be
        # opened in binary mode; text mode 'w' makes pickle.dump raise
        # TypeError on Python 3.
        with open(params_file, 'wb') as fwrite:
            pickle.dump(
                dict(blobs=save_blobs),
                fwrite,
                pickle.HIGHEST_PROTOCOL
            )
    except Exception:
        # Best-effort save: do not abort the caller, but surface the cause.
        logger.warning(
            "save_model_params: dump parameters failed.", exc_info=True)