in c3dm/experiment.py [0:0]
# imports assumed by this excerpt; the project-local helpers referenced below
# (dataset_zoo, SceneBatchSampler, eval_zoo, init_model, init_optimizer,
# trainvalidate, run_eval, dump_config, get_visdom_env, get_checkpoint,
# save_model, purge_epoch, run_c3dpo_model_on_dset) come from elsewhere in
# the repo and are not reproduced here
import os

import numpy as np
import torch

def run_training(cfg):
    # run the training loop
    # make the experiment dir
    os.makedirs(cfg.exp_dir, exist_ok=True)
    # seed the RNG
    np.random.seed(cfg.seed)
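    # NOTE: only numpy's RNG is seeded here; any torch-level seeding happens
    # outside this excerpt (if at all)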
    # dump the experiment config to the exp dir
    dump_config(cfg)
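    # (dump_config presumably serializes the resolved cfg into exp_dir so the
    # run can later be reproduced or resumed with identical settings)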
    # setup the datasets
    dset_train, dset_val, dset_test = dataset_zoo(**cfg.DATASET)
    # init the training loader
    if cfg.batch_sampler == 'default':
        trainloader = torch.utils.data.DataLoader(
            dset_train,
            num_workers=cfg.num_workers,
            pin_memory=True,
            batch_size=cfg.batch_size,
            shuffle=False,
        )
    elif cfg.batch_sampler == 'sequence':
        trainloader = torch.utils.data.DataLoader(
            dset_train,
            num_workers=cfg.num_workers,
            pin_memory=True,
            batch_sampler=SceneBatchSampler(
                torch.utils.data.SequentialSampler(dset_train),
                cfg.batch_size,
                True,
            ),
        )
    else:
        raise ValueError("no such batch sampler: %s" % cfg.batch_sampler)
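    # 'default' draws unshuffled fixed-size batches straight from the dataset,
    # while 'sequence' groups every batch by scene via SceneBatchSampler (a
    # plausible sketch of such a sampler follows after this function)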
    # init the validation loader
    if dset_val is not None:
        if cfg.batch_sampler == 'default':
            valloader = torch.utils.data.DataLoader(
                dset_val,
                num_workers=cfg.num_workers,
                pin_memory=True,
                batch_size=cfg.batch_size,
                shuffle=False,
            )
        elif cfg.batch_sampler == 'sequence':
            valloader = torch.utils.data.DataLoader(
                dset_val,
                num_workers=cfg.num_workers,
                pin_memory=True,
                batch_sampler=SceneBatchSampler(
                    torch.utils.data.SequentialSampler(dset_val),
                    cfg.batch_size,
                    True,
                ),
            )
        else:
            raise ValueError("no such batch sampler: %s" % cfg.batch_sampler)
    else:
        valloader = None
    # init the test loader
    if dset_test is not None:
        testloader = torch.utils.data.DataLoader(
            dset_test,
            num_workers=cfg.num_workers,
            pin_memory=True,
            batch_size=cfg.batch_size,
            shuffle=False,
        )
        _, _, eval_vars = eval_zoo(cfg.DATASET.dataset_name)
    else:
        testloader = None
        eval_vars = None
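    # eval_zoo also returns the evaluation callables (discarded here); only
    # the names of the tracked eval metrics (eval_vars) are kept, presumably
    # so init_model can register them with the stats logger below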
    # init the model
    model, stats, optimizer_state = init_model(cfg, add_log_vars=eval_vars)
    start_epoch = stats.epoch + 1
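    # stats carries the last completed epoch when resuming from a checkpoint,
    # so training continues from the epoch right after it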
    # annotate the datasets with c3dpo outputs
    if cfg.annotate_with_c3dpo_outputs:
        for dset in (dset_train, dset_val, dset_test):
            if dset is not None:
                run_c3dpo_model_on_dset(dset, cfg.MODEL.nrsfm_exp_path)
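    # (this presumably precomputes and caches the predictions of the
    # pretrained C3DPO NR-SfM model at cfg.MODEL.nrsfm_exp_path for every
    # split, so they do not have to be recomputed on the fly during training)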
    # move the model to the gpu
    model.cuda(0)
    # init the optimizer and the lr scheduler
    optimizer, scheduler = init_optimizer(
        model, optimizer_state=optimizer_state, **cfg.SOLVER)
    # fast-forward the lr schedule so it matches the resumed epoch
    scheduler.last_epoch = start_epoch
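    # main epoch loop: train, optionally validate and evaluate, purge stale
    # checkpoints, then save the current one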
    for epoch in range(start_epoch, cfg.SOLVER.max_epochs):
        with stats:  # stats automatically starts a new epoch and plots at every epoch start
            print("scheduler lr = %1.2e" % float(scheduler.get_lr()[-1]))
            # training pass (the boolean flag selects validation mode:
            # False = train with gradient updates, True = validate only)
            trainvalidate(model, stats, epoch, trainloader, optimizer, False,
                          visdom_env_root=get_visdom_env(cfg), **cfg)
            # validation pass
            if valloader is not None:
                trainvalidate(model, stats, epoch, valloader, optimizer, True,
                              visdom_env_root=get_visdom_env(cfg), **cfg)
            # evaluation pass (optional):
            # eval_interval == 0 evaluates every epoch, k > 0 every k-th
            # epoch (skipping epoch 0), and a negative value disables it
            if testloader is not None:
                if cfg.eval_interval >= 0:
                    if (cfg.eval_interval == 0
                            or (epoch % cfg.eval_interval == 0 and epoch > 0)):
                        torch.cuda.empty_cache()  # the eval pass is memory-heavy
                        with torch.no_grad():
                            run_eval(cfg, model, stats, testloader)
            assert stats.epoch == epoch, "inconsistent stats!"
            # delete stale checkpoints if required
            if cfg.store_checkpoints_purge > 0 and cfg.store_checkpoints:
                for prev_epoch in range(epoch - cfg.store_checkpoints_purge):
                    period = cfg.store_checkpoints_purge_except_every
                    if period > 0 and prev_epoch % period == period - 1:
                        continue  # keep every period-th checkpoint
                    purge_epoch(cfg.exp_dir, prev_epoch)
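            # (e.g. with store_checkpoints_purge=3 and
            # store_checkpoints_purge_except_every=5, at epoch 10 the loop
            # above removes checkpoints 0-6 except epoch 4, since 4 % 5 == 4,
            # keeping the three most recent epochs 7-9 plus the current one)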
            # save the current checkpoint
            if cfg.store_checkpoints:
                outfile = get_checkpoint(cfg.exp_dir, epoch)
                save_model(model, stats, outfile, optimizer=optimizer)
            scheduler.step()
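
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of experiment.py): SceneBatchSampler is
# project-local and its implementation is not shown in this excerpt. The class
# below is a guess at its contract, inferred only from the call site above
# (a SequentialSampler, a batch size, and a boolean assumed here to mean
# "drop incomplete batches"): it yields fixed-size batches of dataset indices
# that all come from the same scene/sequence. The `seq_idx` dataset attribute
# is an assumption, not the repo's actual field name.
from collections import defaultdict

class SceneBatchSamplerSketch(torch.utils.data.Sampler):
    def __init__(self, sampler, batch_size, drop_last):
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        # bucket dataset indices by the scene they belong to
        per_scene = defaultdict(list)
        for i in self.sampler:
            per_scene[self.sampler.data_source.seq_idx[i]].append(i)
        # emit consecutive same-scene batches of at most batch_size indices
        for idxs in per_scene.values():
            for lo in range(0, len(idxs), self.batch_size):
                batch = idxs[lo:lo + self.batch_size]
                if self.drop_last and len(batch) < self.batch_size:
                    continue
                yield batch

# usage mirroring the 'sequence' branch above:
#   loader = torch.utils.data.DataLoader(
#       dset_train,
#       batch_sampler=SceneBatchSamplerSketch(
#           torch.utils.data.SequentialSampler(dset_train),
#           cfg.batch_size, True),
#   )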