def run()

in contrib/action_recognition/i3d/train.py [0:0]


def run(*options, cfg=None):
    """Run training and validation of the I3D model.

    Notes:
        Options can be passed in via the ``options`` argument and loaded from the cfg file.
        Options loaded from default.py will be overridden by options loaded from the cfg file.
        Options passed in through the ``options`` argument will override options loaded
        from the cfg file.

    Args:
        *options (str, int, optional): Options used to override what is loaded from the
            config. To see what options are available consult default.py.
        cfg (str, optional): Location of config file to load. Defaults to None.

    Raises:
        ValueError: If ``config.TRAIN.MODALITY`` is neither ``"RGB"`` nor ``"flow"``.
    """
    update_config(config, options=options, config_file=cfg)

    print("Training ", config.TRAIN.MODALITY, " model.")
    print("Batch size:", config.TRAIN.BATCH_SIZE, " Gradient accumulation steps:", config.TRAIN.GRAD_ACCUM_STEPS)

    torch.backends.cudnn.benchmark = config.CUDNN.BENCHMARK

    # Seed every RNG the pipeline touches (torch CPU/GPU and numpy).
    torch.manual_seed(config.SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(config.SEED)
    np.random.seed(seed=config.SEED)

    # Log to tensorboard
    writer = SummaryWriter(log_dir=config.LOG_DIR)

    # Everything after the writer is created runs inside try/finally so the
    # writer is always closed, even if setup or training raises.
    try:
        # Setup dataloaders: random crop + horizontal flip for training ...
        train_loader = torch.utils.data.DataLoader(
            I3DDataSet(
                data_root=config.DATASET.DIR,
                split=config.DATASET.SPLIT,
                sample_frames=config.TRAIN.SAMPLE_FRAMES,
                modality=config.TRAIN.MODALITY,
                transform=torchvision.transforms.Compose([
                    GroupScale(config.TRAIN.RESIZE_MIN),
                    GroupRandomCrop(config.TRAIN.INPUT_SIZE),
                    GroupRandomHorizontalFlip(),
                    GroupNormalize(modality=config.TRAIN.MODALITY),
                    Stack(),
                ])
            ),
            batch_size=config.TRAIN.BATCH_SIZE,
            shuffle=True,
            num_workers=config.WORKERS,
            pin_memory=config.PIN_MEMORY
        )

        # ... and a deterministic center crop for validation.
        val_loader = torch.utils.data.DataLoader(
            I3DDataSet(
                data_root=config.DATASET.DIR,
                split=config.DATASET.SPLIT,
                modality=config.TRAIN.MODALITY,
                train_mode=False,
                transform=torchvision.transforms.Compose([
                    GroupScale(config.TRAIN.RESIZE_MIN),
                    GroupCenterCrop(config.TRAIN.INPUT_SIZE),
                    GroupNormalize(modality=config.TRAIN.MODALITY),
                    Stack(),
                ]),
            ),
            batch_size=config.TEST.BATCH_SIZE,
            shuffle=False,
            num_workers=config.WORKERS,
            pin_memory=config.PIN_MEMORY
        )

        # Setup model: the input channel count and pretrained weights depend on modality.
        if config.TRAIN.MODALITY == "RGB":
            channels = 3
            checkpoint = config.MODEL.PRETRAINED_RGB
        elif config.TRAIN.MODALITY == "flow":
            channels = 2
            checkpoint = config.MODEL.PRETRAINED_FLOW
        else:
            raise ValueError("Modality must be RGB or flow")

        # The model is constructed with 400 output classes before its logits are replaced.
        i3d_model = InceptionI3d(400, in_channels=channels)
        # The model is still on CPU here, so load the weights onto CPU as well
        # instead of onto whatever device they were saved from.
        i3d_model.load_state_dict(torch.load(checkpoint, map_location="cpu"))

        # Replace final FC layer to match dataset
        i3d_model.replace_logits(config.DATASET.NUM_CLASSES)

        # Move the model to the GPU *before* constructing the optimizer, as
        # recommended by the torch.optim documentation.
        i3d_model = i3d_model.cuda()
        criterion = torch.nn.CrossEntropyLoss().cuda()

        optimizer = optim.SGD(
            i3d_model.parameters(),
            lr=0.1,
            momentum=0.9,
            weight_decay=0.0000001
        )

        # Divide the LR by 10 when validation loss stops improving.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            factor=0.1,
            patience=2,
            verbose=True,
            threshold=1e-4,
            min_lr=1e-4
        )

        # Data-parallel (only wraps the model when more than one GPU is visible)
        devices_lst = list(range(torch.cuda.device_count()))
        print("Devices {}".format(devices_lst))
        if len(devices_lst) > 1:
            i3d_model = torch.nn.DataParallel(i3d_model)

        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(config.MODEL.CHECKPOINT_DIR, exist_ok=True)

        for epoch in range(config.TRAIN.MAX_EPOCHS):

            train(train_loader,
                  i3d_model,
                  criterion,
                  optimizer,
                  epoch,
                  writer
                  )

            is_last_epoch = epoch == config.TRAIN.MAX_EPOCHS - 1
            if (epoch + 1) % config.TEST.EVAL_FREQ == 0 or is_last_epoch:
                val_loss = validate(val_loader, i3d_model, criterion, epoch, writer)
                scheduler.step(val_loss)

                # Only a DataParallel wrapper has `.module`; on a single GPU the
                # model is unwrapped and accessing `.module` would raise.
                model_to_save = (
                    i3d_model.module
                    if isinstance(i3d_model, torch.nn.DataParallel)
                    else i3d_model
                )
                checkpoint_name = (
                    config.MODEL.NAME
                    + '_split' + str(config.DATASET.SPLIT)
                    + '_epoch' + str(epoch).zfill(3)
                    + '.pt'
                )
                torch.save(
                    model_to_save.state_dict(),
                    os.path.join(config.MODEL.CHECKPOINT_DIR, checkpoint_name)
                )
    finally:
        writer.close()