in anticipation/anticipation/tools/train_recognizer.py [0:0]
def main():
    """Entry point: parse CLI options, build the recognizer and datasets,
    then hand everything off to ``train_network``.
    """
    opts = parse_args()
    config = Config.fromfile(opts.config)

    # Honor the optional cudnn_benchmark switch from the config file.
    if config.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # CLI options take precedence over the corresponding config entries.
    if opts.work_dir is not None:
        config.work_dir = opts.work_dir
    if opts.resume_from is not None:
        config.resume_from = opts.resume_from
    config.gpus = opts.gpus

    if config.checkpoint_config is not None:
        # Embed the mmaction version and the full config text as checkpoint
        # metadata so saved weights are traceable to their training setup.
        config.checkpoint_config.meta = dict(
            mmact_version=__version__, config=config.text)

    # The distributed environment must be set up before the logger is
    # created, since the logger depends on the dist info.
    distributed = opts.launcher != 'none'
    if distributed:
        init_dist(opts.launcher, **config.dist_params)

    logger = get_root_logger(config.log_level)
    logger.info('Distributed training: {}'.format(distributed))

    # Optionally fix the random seed for reproducibility.
    if opts.seed is not None:
        logger.info('Set random seed to {}'.format(opts.seed))
        set_random_seed(opts.seed)

    recognizer = build_recognizer(
        config.model, train_cfg=config.train_cfg, test_cfg=config.test_cfg)

    train_set = get_trimmed_dataset(config.data.train)
    val_set = get_trimmed_dataset(config.data.val)

    # One dataset per workflow stage, in declaration order; only the
    # 'train' and 'val' stage names are recognized.
    datasets = []
    for stage in config.workflow:
        assert stage[0] in ['train', 'val']
        datasets.append(train_set if stage[0] == 'train' else val_set)

    train_network(
        recognizer,
        datasets,
        config,
        distributed=distributed,
        validate=opts.validate,
        logger=logger)