in experiment.py [0:0]
def init_model(cfg, force_load=False, clear_stats=False, add_log_vars=None):
    """Build a C3DPO model and its Stats tracker, optionally resuming from disk.

    A fresh ``C3DPO`` is constructed from ``cfg.MODEL``. A checkpoint is then
    looked up in ``cfg.exp_dir``. When ``cfg.resume_epoch`` is positive, the
    checkpoint for that epoch is used; otherwise the last one is used. If a
    checkpoint is found and resuming is enabled, the model weights and the
    optimizer state are restored from it. The stored stats are restored too,
    unless ``clear_stats`` is set.

    Args:
        cfg: experiment config. Reads ``MODEL``, ``visdom_server``,
            ``visdom_port``, ``resume_epoch`` and ``exp_dir``. It also reads
            ``resume``, but only when a checkpoint was found and
            ``force_load`` is falsy.
        force_load: resume from a found checkpoint even if ``cfg.resume``
            is falsy.
        clear_stats: when resuming, keep the freshly created Stats instead
            of the ones returned by ``load_model``.
        add_log_vars: optional extra variable names to log, appended to the
            model's ``log_vars`` (or to ``['objective']`` if the model has
            no ``log_vars`` attribute).

    Returns:
        Tuple ``(model, stats, optimizer_state)``. ``optimizer_state`` is
        ``None`` unless a checkpoint was actually loaded.
    """
    model = C3DPO(**cfg.MODEL)

    # Names of the network outputs to log. The model's list is deep-copied
    # so that the extend below does not mutate model.log_vars.
    log_vars = (copy.deepcopy(model.log_vars) if hasattr(model, 'log_vars')
                else ['objective'])
    if add_log_vars is not None:
        log_vars.extend(copy.deepcopy(add_log_vars))

    chart_env = get_visdom_env(cfg) + "_charts"
    stats = Stats(log_vars, visdom_env=chart_env,
                  verbose=False, visdom_server=cfg.visdom_server,
                  visdom_port=cfg.visdom_port)

    # A positive resume_epoch selects that specific checkpoint; otherwise the
    # last checkpoint found in exp_dir (possibly None) is used.
    model_path = (get_checkpoint(cfg.exp_dir, cfg.resume_epoch)
                  if cfg.resume_epoch > 0
                  else find_last_checkpoint(cfg.exp_dir))

    optimizer_state = None
    if model_path is not None:
        print("found previous model %s" % model_path)
        if not (force_load or cfg.resume):
            print(" -> but not resuming -> starting from scratch")
        else:
            print(" -> resuming")
            state_dict, loaded_stats, optimizer_state = load_model(model_path)
            if clear_stats:
                print(" -> clearing stats")
            else:
                stats = loaded_stats
            model.load_state_dict(state_dict, strict=True)
            # NOTE(review): model.log_vars is only replaced with the extended
            # list on this resume path, not when starting from scratch --
            # confirm that asymmetry is intended.
            model.log_vars = log_vars

    # Re-apply settings in case they got lost during load (stats may have
    # come from the checkpoint), then align the logged vars with log_vars.
    for attr_name, attr_value in (
        ('visdom_env', chart_env),
        ('visdom_server', cfg.visdom_server),
        ('visdom_port', cfg.visdom_port),
        ('plot_file', os.path.join(cfg.exp_dir, 'train_stats.pdf')),
    ):
        setattr(stats, attr_name, attr_value)
    stats.synchronize_logged_vars(log_vars)

    return model, stats, optimizer_state