in c3dm/experiment.py [0:0]
def init_model(cfg,force_load=False,clear_stats=False,add_log_vars=None):
    """Build the model and its stats tracker, optionally restoring a checkpoint.

    Args:
        cfg: Experiment configuration. This function reads `cfg.MODEL`
            (expanded as keyword arguments for `Model`), `cfg.visdom_server`,
            `cfg.visdom_port`, `cfg.resume_epoch`, `cfg.exp_dir`,
            `cfg.resume` and, when a download is needed,
            `cfg.DATASET.dataset_name`.
        force_load: If true, a found checkpoint is loaded even when
            `cfg.resume` is false; if no checkpoint is found, one is
            fetched from `C3DM_URLS` and unpacked into `cfg.exp_dir`.
        clear_stats: If true, the freshly created `Stats` object is kept
            instead of the one stored in the checkpoint.
        add_log_vars: Optional extra variable names appended (as a deep
            copy) to the list of logged variables.

    Returns:
        A `(model, stats, optimizer_state)` tuple. `optimizer_state` is
        `None` unless a checkpoint was actually loaded.
    """
    model = Model(**cfg.MODEL)

    # Names of the network outputs to log; fall back to the bare objective
    # when the model does not declare any.
    log_vars = copy.deepcopy(getattr(model, 'log_vars', ['objective']))
    if add_log_vars:
        log_vars.extend(copy.deepcopy(add_log_vars))

    charts_env = get_visdom_env(cfg) + "_charts"

    # Fresh stats container; may be replaced below by the checkpointed one.
    stats = Stats(
        log_vars,
        visdom_env=charts_env,
        verbose=False,
        visdom_server=cfg.visdom_server,
        visdom_port=cfg.visdom_port,
    )

    # Pick the checkpoint: an explicit epoch, or the latest one for -1.
    resume_epoch = cfg.resume_epoch
    if resume_epoch > 0:
        checkpoint_path = get_checkpoint(cfg.exp_dir, resume_epoch)
    elif resume_epoch == -1:
        checkpoint_path = find_last_checkpoint(cfg.exp_dir)
    else:
        checkpoint_path = None

    optimizer_state = None

    if checkpoint_path is None and force_load:
        # Nothing local to load from: download the archive registered for
        # this dataset, unpack it into exp_dir and look again.
        from dataset.dataset_configs import C3DM_URLS
        dataset_name = cfg.DATASET.dataset_name
        url = C3DM_URLS[dataset_name]
        print('Downloading C3DM model %s from %s' % (dataset_name, url))
        utils.untar_to_dir(url, cfg.exp_dir)
        checkpoint_path = find_last_checkpoint(cfg.exp_dir)

    if checkpoint_path is None:
        if force_load:
            print('!! CANNOT RESUME FROM A CHECKPOINT !!')
    else:
        print("found previous model %s" % checkpoint_path)
        if not (force_load or cfg.resume):
            print(" -> but not resuming -> starting from scratch")
        else:
            print(" -> resuming")
            state_dict, loaded_stats, optimizer_state = load_model(checkpoint_path)

            if clear_stats:
                print(" -> clearing stats")
            elif loaded_stats is None:
                print(" -> bad stats! -> clearing")
            else:
                stats = loaded_stats

            # Prefer an exact key match; fall back to a lenient load so a
            # checkpoint with differing keys can still be used.
            try:
                model.load_state_dict(state_dict, strict=True)
            except RuntimeError as err:
                print('!!!!! cant load state dict in strict mode:')
                print(err)
                print('loading in non-strict mode ...')
                model.load_state_dict(state_dict, strict=False)
            # NOTE(review): indentation was lost in the reviewed source; this
            # assignment is assumed to follow the try/except (run on every
            # resume), not to sit inside the except handler -- confirm.
            model.log_vars = log_vars

    # Re-apply the visdom settings in case a loaded stats object carries
    # different ones.
    stats.visdom_env = charts_env
    stats.visdom_server = cfg.visdom_server
    stats.visdom_port = cfg.visdom_port
    #stats.plot_file = os.path.join(cfg.exp_dir,'train_stats.pdf')
    stats.synchronize_logged_vars(log_vars)

    return model, stats, optimizer_state