def main_worker()

in scripts/action_recognition/train.py [0:0]


def _is_scheduled_epoch(epoch, freq, num_epochs):
    """Return True when ``epoch`` is a multiple of ``freq`` or is the final epoch.

    Shared by the checkpoint and validation schedules so the last epoch is
    always covered regardless of ``freq``.
    """
    return epoch % freq == 0 or epoch == num_epochs - 1


def main_worker(cfg):
    """Train the configured classification model and validate it periodically.

    Builds the model, dataloaders, cross-entropy criterion, SGD optimizer and
    MultiStepLR scheduler from ``cfg``. It then runs epochs
    ``[cfg.CONFIG.TRAIN.START_EPOCH, cfg.CONFIG.TRAIN.EPOCH_NUM)``.

    Rank 0 (``cfg.DDP_CONFIG.GPU_WORLD_RANK == 0``) owns the TensorBoard
    writer and writes checkpoints. Every rank runs training and validation.

    Args:
        cfg: config node supporting ``defrost()``/``freeze()`` (presumably a
            yacs ``CfgNode`` -- confirm). It is mutated in place:
            ``cfg.CONFIG.TEST.MULTI_VIEW_TEST`` is forced to False.
    """
    # create tensorboard and logs (rank 0 only; other ranks pass writer=None)
    if cfg.DDP_CONFIG.GPU_WORLD_RANK == 0:
        tb_logdir = build_log_dir(cfg)
        writer = SummaryWriter(log_dir=tb_logdir)
    else:
        writer = None

    # Everything after the writer exists runs under try/finally so the writer
    # is closed (and its pending events flushed) even if training fails.
    try:
        # disable multi-view testing during training
        cfg.defrost()
        cfg.CONFIG.TEST.MULTI_VIEW_TEST = False
        cfg.freeze()

        # create model
        print('Creating model: %s' % (cfg.CONFIG.MODEL.NAME))
        model = get_model(cfg)

        # create dataset and dataloader
        train_loader, val_loader, train_sampler, _, _ = build_dataloader(cfg)

        # create criterion
        criterion = torch.nn.CrossEntropyLoss().cuda()

        # create optimizer
        optimizer = torch.optim.SGD(model.parameters(), lr=cfg.CONFIG.TRAIN.LR,
                                    momentum=cfg.CONFIG.TRAIN.OPTIMIZER.MOMENTUM,
                                    weight_decay=cfg.CONFIG.TRAIN.WEIGHT_DECAY)

        model, optimizer, model_ema = deploy_model(model, optimizer, cfg)
        # deploy_model presumably wraps the model in (Distributed)DataParallel;
        # checkpoints store the inner module. Fall back to the model itself
        # when it was not wrapped (no ``.module`` attribute).
        model_without_ddp = getattr(model, 'module', model)
        num_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print('Number of parameters in the model: %6.2fM' % (num_parameters / 1000000))

        # create lr scheduler (decays LR by DECAY_RATE at each milestone epoch)
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                            milestones=cfg.CONFIG.TRAIN.LR_SCHEDULER.LR_MILESTONE,
                                                            gamma=cfg.CONFIG.TRAIN.LR_SCHEDULER.DECAY_RATE)

        print('Start training...')
        start_time = time.time()
        max_accuracy = 0.0
        mixup_fn = None
        num_epochs = cfg.CONFIG.TRAIN.EPOCH_NUM
        for epoch in range(cfg.CONFIG.TRAIN.START_EPOCH, num_epochs):
            if cfg.DDP_CONFIG.DISTRIBUTED:
                # presumably a DistributedSampler: set_epoch varies its shuffling per epoch
                train_sampler.set_epoch(epoch)

            train_classification(cfg, model, model_ema, criterion, train_loader, optimizer, epoch, mixup_fn, lr_scheduler, writer)
            # NOTE(review): lr_scheduler is also passed to train_classification;
            # confirm it is not stepped there too, or the LR would decay twice.
            lr_scheduler.step()

            # NOTE: the checkpoint is written before this epoch's validation, so
            # the saved max_accuracy only reflects earlier validations.
            if cfg.DDP_CONFIG.GPU_WORLD_RANK == 0 and _is_scheduled_epoch(
                    epoch, cfg.CONFIG.LOG.SAVE_FREQ, num_epochs):
                save_checkpoint(cfg, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler)

            if _is_scheduled_epoch(epoch, cfg.CONFIG.VAL.FREQ, num_epochs):
                acc1, _, _ = validate_classification(cfg, val_loader, model, criterion, epoch, writer)
                max_accuracy = max(max_accuracy, acc1)
                print(f"Accuracy of the network: {acc1:.1f}%")
                print(f'Max accuracy: {max_accuracy:.2f}%')
    finally:
        if writer is not None:
            writer.close()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))