def main(cfg)

in func/train.py


def main(cfg):
    logger = logging.getLogger(__name__)
    dist_info, device, writer = initial_setup(cfg, logger)
    # Data loading code
    logger.info("Loading data")

    logger.info("\t Loading datasets")
    st = time.time()

    # TODO: separate these out into get-transforms helpers
    # TODO: This has gotten too complex; clean up and make the interface better
    transform_train = [
        T.ToTensorVideo(),
        T.Resize(_get_resize_shape(cfg.data_train)),
        T.RandomHorizontalFlipVideo(cfg.data_train.flip_p),
        T.ColorJitterVideo(brightness=cfg.data_train.color_jitter_brightness,
                           contrast=cfg.data_train.color_jitter_contrast,
                           saturation=cfg.data_train.color_jitter_saturation,
                           hue=cfg.data_train.color_jitter_hue),
        torchvision.transforms.Lambda(
            lambda x: x * cfg.data_train.scale_pix_val),
        (torchvision.transforms.Lambda(lambda x: x[[2, 1, 0], ...])
         if cfg.data_train.reverse_channels
         else torchvision.transforms.Compose([])),
        T.NormalizeVideo(**_get_pixel_mean_std(cfg.data_train)),
    ]
    if cfg.data_train.crop_size is not None:
        transform_train.append(
            T.RandomCropVideo(
                (cfg.data_train.crop_size, cfg.data_train.crop_size)), )
    transform_train = torchvision.transforms.Compose(transform_train)
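    # Eval transforms are deterministic: resize, pixel scaling, optional
    # channel reversal, and normalization; multi-crop evaluation is appended
    # below only when a crop size is configured.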
    transform_eval = [
        T.ToTensorVideo(),
        T.Resize(_get_resize_shape(cfg.data_eval)),
        torchvision.transforms.Lambda(
            lambda x: x * cfg.data_eval.scale_pix_val),
        (torchvision.transforms.Lambda(lambda x: x[[2, 1, 0], ...])
         if cfg.data_eval.reverse_channels
         else torchvision.transforms.Compose([])),
        T.NormalizeVideo(**_get_pixel_mean_std(cfg.data_eval)),
    ]
    if cfg.data_eval.crop_size is not None:
        transform_eval.append(
            T.MultiCropVideo(
                (cfg.data_eval.crop_size, cfg.data_eval.crop_size),
                cfg.data_eval.eval_num_crops, cfg.data_eval.eval_flip_crops))
    transform_eval = torchvision.transforms.Compose(transform_eval)

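    # Build one training dataset per config entry whose key starts with
    # DATASET_TRAIN_CFG_KEY; multiple training datasets are concatenated.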
    datasets_train = [
        get_dataset(getattr(cfg, el), cfg.data_train, transform_train, logger)
        for el in cfg.keys() if el.startswith(DATASET_TRAIN_CFG_KEY)
    ]
    if len(datasets_train) > 1:
        dataset = torch.utils.data.ConcatDataset(datasets_train)
    else:
        dataset = datasets_train[0]
    # could be multiple test datasets
    datasets_test = {
        el[len(DATASET_EVAL_CFG_KEY):]:
        get_dataset(getattr(cfg, el), cfg.data_eval, transform_eval, logger)
        for el in cfg.keys() if el.startswith(DATASET_EVAL_CFG_KEY)
    }

    logger.info("Took %d", time.time() - st)

    logger.info("Creating data loaders")
    train_sampler = None
    test_samplers = {key: None for key in datasets_test}
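    # Datasets exposing `video_clips` get clip-level samplers (random clips for
    # training, uniform clips for eval); otherwise, plain distributed samplers
    # are used when enabled in the config.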
    if hasattr(dataset, 'video_clips'):
        assert cfg.train.shuffle_data, 'TODO'
        train_sampler = RandomClipSampler(getattr(dataset, 'video_clips'),
                                          cfg.data_train.train_bs_multiplier)
        test_samplers = {
            key: UniformClipSampler(val.video_clips,
                                    cfg.data_eval.val_clips_per_video)
            for key, val in datasets_test.items()
        }
        if dist_info['distributed']:
            train_sampler = DistributedSampler(train_sampler)
            test_samplers = {k: DistributedSampler(v)
                             for k, v in test_samplers.items()}
    elif dist_info['distributed']:
        # Distributed, but doesn't have video_clips
        if cfg.data_train.use_dist_sampler:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                dataset,
                num_replicas=dist_info['world_size'],
                rank=dist_info['rank'],
                shuffle=cfg.train.shuffle_data)
        if cfg.data_eval.use_dist_sampler:
            test_samplers = {
                key: torch.utils.data.distributed.DistributedSampler(
                    val,
                    num_replicas=dist_info['world_size'],
                    rank=dist_info['rank'],
                    shuffle=False)
                for key, val in datasets_test.items()
            }

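    # The training loader shuffles only when no sampler is provided, since a
    # sampler defines its own iteration order.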
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=cfg.train.batch_size,
        sampler=train_sampler,
        num_workers=cfg.data_train.workers,
        pin_memory=False,  # usually hurts performance
        shuffle=(train_sampler is None and cfg.train.shuffle_data),
        collate_fn=collate_fn_remove_audio,
    )

    data_loaders_test = {
        key: torch.utils.data.DataLoader(
            val,
            # No backprop here, so a larger batch size can be used
            batch_size=cfg.eval.batch_size or cfg.train.batch_size * 4,
            sampler=test_samplers[key],
            num_workers=cfg.data_eval.workers,
            pin_memory=False,  # usually hurts performance
            shuffle=False,
            collate_fn=collate_fn_remove_audio,
        )
        for key, val in datasets_test.items()
    }

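    # The dataset exposes a dict of label spaces; the model is built with one
    # output size per label space.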
    num_classes = {key: len(val) for key, val in dataset.classes.items()}
    logger.info('Creating model with %s classes', num_classes)
    model = base_model.BaseModel(cfg.model,
                                 num_classes=num_classes,
                                 class_mappings=dataset.class_mappings)
    logger.debug('Model: %s', model)
    if dist_info['distributed'] and cfg.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    if cfg.train.init_from_model:
        # Each entry is a 1-, 2-, or 3-element sequence:
        #   (ckpt path,), (module name, ckpt path), or
        #   (module name, ckpt modules to keep, ckpt path)
        for module_ckpt in cfg.train.init_from_model:
            elts = module_ckpt
            if len(elts) == 1:
                model_to_init = model
                ckpt_modules_to_keep = None
                ckpt_path = elts[0]
            elif len(elts) == 2:
                model_to_init = operator.attrgetter(elts[0])(model)
                ckpt_modules_to_keep = None
                ckpt_path = elts[1]
            elif len(elts) == 3:
                model_to_init = operator.attrgetter(elts[0])(model)
                ckpt_modules_to_keep = elts[1]
                ckpt_path = elts[2]
            else:
                raise ValueError(f'Incorrect formatting {module_ckpt}')
            init_model(model_to_init, ckpt_path, ckpt_modules_to_keep, logger)

    model.to(device)

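    # When training only the classifier, BatchNorm layers are set to not track
    # running statistics (see _set_all_bn_to_not_track_running_mean).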
    if cfg.opt.classifier_only:
        assert len(cfg.opt.lr_wd) == 1
        assert cfg.opt.lr_wd[0][0] == 'classifier'
        model = _set_all_bn_to_not_track_running_mean(model)
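    # Build optimizer parameter groups: each (modules, lr, wd) entry in
    # cfg.opt.lr_wd yields two groups, with bias/BN parameters given a
    # separately scaled weight decay (cfg.opt.bias_bn_wd_scale).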
    params = []
    for this_module_names, this_lr, this_wd in cfg.opt.lr_wd:
        if OmegaConf.get_type(this_module_names) != list:
            this_module_names = [this_module_names]
        this_modules = [
            operator.attrgetter(el)(model) if el != '__all__' else model
            for el in this_module_names
        ]
        this_params_bias_bn = {}
        this_params_rest = {}
        for this_module_name, this_module in zip(this_module_names,
                                                 this_modules):
            for name, param in this_module.named_parameters():
                # ignore the param without grads
                if not param.requires_grad:
                    continue
                # Match 'bias' without the leading dot: the name may be just
                # 'bias' when the parameter has no enclosing module name
                if name.endswith('bias') or ('.bn' in name):
                    this_params_bias_bn[this_module_name + '.' + name] = param
                else:
                    this_params_rest[this_module_name + '.' + name] = param
        this_scaled_lr = this_lr * dist_info['world_size']
        if cfg.opt.scale_lr_by_bs:
            this_scaled_lr *= cfg.train.batch_size
        params.append({
            'params': this_params_rest.values(),
            'lr': this_scaled_lr,
            'weight_decay': this_wd,
        })
        logger.info('Using LR %f WD %f for parameters %s', params[-1]['lr'],
                    params[-1]['weight_decay'], this_params_rest.keys())
        params.append({
            'params': this_params_bias_bn.values(),
            'lr': this_scaled_lr,
            'weight_decay': this_wd * cfg.opt.bias_bn_wd_scale,
        })
        logger.info('Using LR %f WD %f for parameters %s', params[-1]['lr'],
                    params[-1]['weight_decay'], this_params_bias_bn.keys())
    # Remove any parameter groups with LR 0 and freeze their parameters;
    # this saves GPU memory
    params_final = []
    for param_lr in params:
        if param_lr['lr'] != 0.0:
            params_final.append(param_lr)
        else:
            for param in param_lr['params']:
                param.requires_grad = False

    optimizer = hydra.utils.instantiate(cfg.opt.optimizer, params_final)

    # Convert the scheduler to be per-iteration rather than per-epoch,
    # so that warmup can extend across epoch boundaries
    main_scheduler = hydra.utils.instantiate(
        cfg.opt.scheduler,
        optimizer,
        iters_per_epoch=len(data_loader),
        world_size=dist_info['world_size'])
    lr_scheduler = hydra.utils.instantiate(cfg.opt.warmup,
                                           optimizer,
                                           main_scheduler,
                                           iters_per_epoch=len(data_loader),
                                           world_size=dist_info['world_size'])

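    # Auto-resume: if a checkpoint exists in the working directory, restore the
    # model, optimizer, and scheduler state and continue from the saved epoch.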
    last_saved_ckpt = CKPT_FNAME
    start_epoch = 0
    if os.path.isfile(last_saved_ckpt):
        checkpoint = torch.load(last_saved_ckpt, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        start_epoch = checkpoint['epoch']
        logger.warning('Loaded model from %s (ep %f)', last_saved_ckpt,
                       start_epoch)

    if dist_info['distributed'] and not cfg.eval.eval_fn.only_run_featext:
        # If only extracting features, each GPU tests separately anyway, so
        # there is no need for communication between the model replicas
        logger.info('Wrapping model into DDP')
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist_info['gpu']],
            output_device=dist_info['gpu'])
    elif cfg.data_parallel:
        logger.info('Wrapping model into DP')
        device_ids = range(dist_info['world_size'])
        model = torch.nn.parallel.DataParallel(model, device_ids=device_ids)

    # TODO add an option here to support val mode training
    # Passing in the training dataset, since that will be used for computing
    # weights for classes etc.
    train_eval_op = hydra.utils.instantiate(cfg.train_eval_op,
                                            model,
                                            device,
                                            dataset,
                                            _recursive_=False)

    if cfg.test_only:
        logger.info("Starting test_only")
        hydra.utils.call(cfg.eval.eval_fn, train_eval_op, data_loaders_test,
                         writer, logger, start_epoch)
        return

    logger.info("Start training")
    start_time = time.time()

    # Get training metric logger
    stat_loggers = get_default_loggers(writer, start_epoch, logger)
    best_acc1 = 0.0
    partial_epoch = start_epoch - int(start_epoch)
    start_epoch = int(start_epoch)
    last_saved_time = datetime.datetime(1, 1, 1, 0, 0)
    # Initialized here since it is used for the final checkpoint write below
    epoch = 0
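    # Main loop: train one epoch, store a checkpoint, evaluate at the
    # configured frequency, and keep the best checkpoint if requested.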
    for epoch in range(start_epoch, cfg.train.num_epochs):
        if dist_info['distributed'] and train_sampler is not None:
            train_sampler.set_epoch(epoch)
        last_saved_time = hydra.utils.call(cfg.train.train_one_epoch_fn,
                                           train_eval_op, optimizer,
                                           lr_scheduler, data_loader, epoch,
                                           partial_epoch,
                                           stat_loggers["train"], logger,
                                           last_saved_time)
        partial_epoch = 0  # Reset, for future epochs
        store_checkpoint([CKPT_FNAME], model, optimizer, lr_scheduler,
                         epoch + 1)
        if cfg.train.eval_freq and epoch % cfg.train.eval_freq == 0:
            acc1 = hydra.utils.call(cfg.eval.eval_fn, train_eval_op,
                                    data_loaders_test, writer, logger,
                                    epoch + 1)
        else:
            acc1 = 0
        if cfg.train.store_best and acc1 >= best_acc1:
            store_checkpoint('checkpoint_best.pth', model, optimizer,
                             lr_scheduler, epoch + 1)
            best_acc1 = acc1

        if isinstance(lr_scheduler.base_scheduler,
                      scheduler.ReduceLROnPlateau):
            lr_scheduler.step(acc1)

        # reset all meters in the metric logger
        for log in stat_loggers:
            stat_loggers[log].reset_meters()
    # Store the final model to checkpoint
    store_checkpoint([CKPT_FNAME], model, optimizer, lr_scheduler, epoch + 1)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('Training time %s', total_time_str)
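
For reference, below is a minimal, hypothetical sketch of how a Hydra-driven entry point like main(cfg) is typically launched; the config path and config name are assumptions, not taken from this repository.

# Hypothetical launcher sketch; 'conf' and 'config' are assumed names.
import hydra
from omegaconf import DictConfig

from func.train import main


@hydra.main(config_path='conf', config_name='config')
def run(cfg: DictConfig) -> None:
    main(cfg)


if __name__ == '__main__':
    run()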