in train.py [0:0]
def main(args):
    """Configure the run from ``args`` and dispatch to OOD test, eval, or training.

    Steps, in order:

    1. Build the config with ``setup_cfg(args)``.
    2. If ``cfg.SEED >= 0``, print it and seed via ``set_random_seed``.
    3. Set up logging under ``cfg.OUTPUT_DIR``.
    4. Enable ``torch.backends.cudnn.benchmark`` when CUDA is available and
       ``cfg.USE_CUDA`` is set.
    5. Build the trainer and run exactly one mode:

       * ``args.ood_test``: if ``args.model_dir`` is given, validate
         ``cfg.TRAINER.CATEX.CTX_INIT`` and load that checkpoint; then call
         ``trainer.test_ood``.
       * ``args.eval_only``: load a checkpoint only when
         ``cfg.TRAINER.CATEX.CTX_INIT`` is ``''``, then call ``trainer.test``.
       * otherwise, unless ``args.no_train``: load ``args.model_dir`` if
         given, swap in ``trainer.forward_backward_ood`` when
         ``cfg.TRAINER.OOD_TRAIN`` is set, then call ``trainer.train``.

    Args:
        args: Parsed command-line namespace. Attributes read directly here:
            ``ood_test``, ``eval_only``, ``no_train``, ``model_dir`` and
            ``load_epoch`` (``setup_cfg`` may read more).

    Raises:
        ValueError: In OOD-test mode when ``args.model_dir`` is given but
            ``cfg.TRAINER.CATEX.CTX_INIT`` is not one of ``''``, ``None``,
            ``'ensemble'`` or ``'ensemble_learned'``.
    """
    cfg = setup_cfg(args)
    if cfg.SEED >= 0:
        print("Setting fixed seed: {}".format(cfg.SEED))
        set_random_seed(cfg.SEED)
    setup_logger(cfg.OUTPUT_DIR)

    if torch.cuda.is_available() and cfg.USE_CUDA:
        # NOTE: cuDNN autotuning may select nondeterministic kernels, so a
        # fixed SEED alone does not guarantee bit-identical GPU runs.
        torch.backends.cudnn.benchmark = True

    trainer = build_trainer(cfg)

    # Truthiness rather than `!= ''` so that a None model_dir is also treated
    # as "no checkpoint" instead of being handed to load_model.
    has_checkpoint = bool(args.model_dir)

    if args.ood_test:
        if has_checkpoint:
            ctx_init = cfg.TRAINER.CATEX.CTX_INIT
            # Explicit raise instead of `assert`, which `python -O` strips.
            # Presumably checkpoints are only compatible with these
            # prompt-initialization modes -- confirm with the trainer.
            if ctx_init not in ('', None, 'ensemble', 'ensemble_learned'):
                raise ValueError(
                    "Unsupported TRAINER.CATEX.CTX_INIT {!r} for OOD testing "
                    "with a model directory".format(ctx_init)
                )
            trainer.load_model(args.model_dir, epoch=args.load_epoch)
        trainer.test_ood(model_directory=args.model_dir)
        return

    if args.eval_only:
        # Only load weights when CTX_INIT is empty; presumably a non-empty
        # CTX_INIT means the prompt is defined by the config -- confirm.
        if cfg.TRAINER.CATEX.CTX_INIT == '':
            trainer.load_model(args.model_dir, epoch=args.load_epoch)
        trainer.test()
        return

    if args.no_train:
        return

    if has_checkpoint:
        trainer.load_model(args.model_dir, epoch=args.load_epoch)
    if cfg.TRAINER.OOD_TRAIN:
        # Rebind the trainer's per-step hook to its OOD-aware variant.
        trainer.forward_backward = trainer.forward_backward_ood
    trainer.train()