def run_training()

in c3dm/experiment.py


def run_training(cfg):
	# run the training, validation and (optional) evaluation loops
	
	# make the experiment directory
	os.makedirs(cfg.exp_dir, exist_ok=True)

	# set the random seed
	np.random.seed(cfg.seed)
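	# note: this seeds numpy only; torch.manual_seed(cfg.seed) would also be
	# needed for fully reproducible torch ops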

	# dump the exp config to the exp dir
	dump_config(cfg)

	# setup the train / val / test datasets
	dset_train, dset_val, dset_test = dataset_zoo(**cfg.DATASET)

	# init the training loader
	if cfg.batch_sampler == 'default':
		trainloader = torch.utils.data.DataLoader(
			dset_train,
			num_workers=cfg.num_workers, pin_memory=True,
			batch_size=cfg.batch_size, shuffle=False,
		)
	elif cfg.batch_sampler == 'sequence':
		# batches are formed by SceneBatchSampler rather than by a fixed batch size
		trainloader = torch.utils.data.DataLoader(
			dset_train,
			num_workers=cfg.num_workers, pin_memory=True,
			batch_sampler=SceneBatchSampler(
				torch.utils.data.SequentialSampler(dset_train),
				cfg.batch_size,
				True,
			),
		)
	else:
		raise ValueError("no such batch sampler: '%s'" % cfg.batch_sampler)

	# init the validation loader
	if dset_val is not None:
		if cfg.batch_sampler == 'default':
			valloader = torch.utils.data.DataLoader(
				dset_val,
				num_workers=cfg.num_workers, pin_memory=True,
				batch_size=cfg.batch_size, shuffle=False,
			)
		elif cfg.batch_sampler == 'sequence':
			valloader = torch.utils.data.DataLoader(
				dset_val,
				num_workers=cfg.num_workers, pin_memory=True,
				batch_sampler=SceneBatchSampler(
					torch.utils.data.SequentialSampler(dset_val),
					cfg.batch_size,
					True,
				),
			)
		else:
			raise ValueError("no such batch sampler: '%s'" % cfg.batch_sampler)
	else:
		valloader = None

	# init the test loader and fetch the evaluation variables
	if dset_test is not None:
		testloader = torch.utils.data.DataLoader(
			dset_test,
			num_workers=cfg.num_workers, pin_memory=True,
			batch_size=cfg.batch_size, shuffle=False,
		)
		# eval_vars presumably names the evaluation metrics to be logged
		_, _, eval_vars = eval_zoo(cfg.DATASET.dataset_name)
	else:
		testloader = None
		eval_vars = None

	# init the model; stats and optimizer_state are restored when resuming
	model, stats, optimizer_state = init_model(cfg, add_log_vars=eval_vars)
	start_epoch = stats.epoch + 1

	# annotate the datasets with c3dpo outputs
	if cfg.annotate_with_c3dpo_outputs:
		for dset in (dset_train, dset_val, dset_test):
			if dset is not None:
				run_c3dpo_model_on_dset(dset, cfg.MODEL.nrsfm_exp_path)

	# move the model to gpu 0 (single-gpu training is assumed)
	model.cuda(0)

	# init the optimizer and the lr scheduler
	optimizer, scheduler = init_optimizer(
		model, optimizer_state=optimizer_state, **cfg.SOLVER)

	# loop through epochs; align the scheduler with the (possibly resumed) start epoch
	scheduler.last_epoch = start_epoch
	for epoch in range(start_epoch, cfg.SOLVER.max_epochs):
		with stats: # automatic new_epoch and plotting of stats at every epoch start
			
			print("scheduler lr = %1.2e" % float(scheduler.get_lr()[-1]))
			
			# train loop
			trainvalidate(model, stats, epoch, trainloader, optimizer, False,
							visdom_env_root=get_visdom_env(cfg), **cfg)

			# validation loop (the boolean flag switches trainvalidate to validation mode)
			if valloader is not None:
				trainvalidate(model, stats, epoch, valloader, optimizer, True,
								visdom_env_root=get_visdom_env(cfg), **cfg)

			# eval loop (optional)
			if testloader is not None and cfg.eval_interval >= 0:
				if (cfg.eval_interval == 0
						or (epoch % cfg.eval_interval == 0 and epoch > 0)):
					torch.cuda.empty_cache()  # the evaluation is memory-heavy
					with torch.no_grad():
						run_eval(cfg, model, stats, testloader)

			assert stats.epoch == epoch, "inconsistent stats!"

			# delete old checkpoints if requested
			if cfg.store_checkpoints_purge > 0 and cfg.store_checkpoints:
				for prev_epoch in range(epoch - cfg.store_checkpoints_purge):
					period = cfg.store_checkpoints_purge_except_every
					if period > 0 and prev_epoch % period == period - 1:
						continue  # keep every period-th checkpoint
					purge_epoch(cfg.exp_dir, prev_epoch)

			# save the model checkpoint
			if cfg.store_checkpoints:
				outfile = get_checkpoint(cfg.exp_dir, epoch)
				save_model(model, stats, outfile, optimizer=optimizer)

			scheduler.step()
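
Usage: a minimal, hypothetical sketch of driving run_training from a script. The get_default_args helper and the config fields set below are illustrative assumptions, not the repo's actual entry point:

if __name__ == '__main__':
	cfg = get_default_args()  # hypothetical: builds an attribute-style config
	cfg.exp_dir = './exp/c3dm_test'  # hypothetical experiment directory
	cfg.seed = 0
	run_training(cfg)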