# Pretraining entry point (main_gdt.py): def main(args)

def main(args):
    """Run GDT self-supervised pretraining.

    Sets up optional apex mixed precision, distributed training, TensorBoard
    logging (master rank only), builds the model (audio-visual or video-text
    GDT), the SGD optimizer and optional LR scheduler, resumes from a
    checkpoint if one exists, constructs the pretraining dataloader, and runs
    the epoch training loop.

    Args:
        args: parsed argparse namespace carrying all training configuration
            (model/arch choices, optimization hyperparameters, dataset paths,
            distributed settings, etc.).
    """
    # Set up mixed precision training
    if args.apex:
        if sys.version_info < (3, 0):
            raise RuntimeError("Apex currently only supports Python 3. Aborting.")
        if amp is None:
            raise RuntimeError(
                "Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
                "to enable mixed-precision training."
            )

    # Make output dir
    if args.output_dir:
        makedir(args.output_dir)

    # Init distributed mode (only meaningful when CUDA is available)
    if torch.cuda.is_available():
        init_distributed_mode(args)

    # init signal handler
    init_signal_handler()

    # Set up tensorboard: only the master (global rank 0) process writes;
    # all other ranks get writer=None so downstream `if writer:` checks skip.
    tbx_path = os.path.join(args.output_dir, 'tensorboard')
    global_rank = args.rank if args.distributed else 0
    is_master = global_rank == 0
    if is_master:
        writer = SummaryWriter(tbx_path)
        writer.add_text(
            'args',
            " \n".join(['%s : %s' % (arg, getattr(args, arg)) for arg in vars(args)]), 
            0
        )
    else:
        writer = None

    # Log version information
    logger.info(args)
    logger.info(f"torch version: {torch.__version__}")

    # Set distributed mode
    device = torch.device(args.device)

    # Set CudNN benchmark
    torch.backends.cudnn.benchmark = True

    # Create model
    logger.info("Creating model")
    if args.model == 'av_gdt':
        # Audio-visual GDT encoder
        model = GDT(
            vid_base_arch=args.vid_base_arch, 
            aud_base_arch=args.aud_base_arch,
            pretrained=False, 
            norm_feat=args.norm_feat, 
            use_mlp=args.use_mlp,
            num_classes=256, 
        )
    else:
        # Video-Text GDT encoder for pretraining
        model = TextVid_GDT(
            vid_base_arch=args.vid_base_arch,
            text_base_arch='word2vec',
            pretrained=False,
            norm_feat=args.norm_feat,
            use_mlp=args.use_mlp,
            num_classes=256,
        )
    model.to(device)
    if args.distributed and args.sync_bn:
        logger.info("Sync BN on model")
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    # Keep a handle to the bare module for checkpoint (de)serialization,
    # since DDP wraps the model and prefixes its state-dict keys.
    model_without_ddp = model
    if args.distributed:
        ngpus_per_node = torch.cuda.device_count()
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False
        )
        model_without_ddp = model.module

    # Translate the named audio augmentation preset into its parameter list.
    if args.aug_audio:
        if args.audio_augtype == 'mild':
            args.aug_audio = [1, 1, 2, 5]
        elif args.audio_augtype == 'medium':
            args.aug_audio = [1, 1, 3, 6]
        elif args.audio_augtype == 'heavy':
            args.aug_audio = [2, 2, 3, 6]

    # Set up training optimizer
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )

    # For Mixed Precision training
    if args.apex:
        model, optimizer = amp.initialize(
            model,
            optimizer,
            opt_level=args.apex_opt_level
        )

    # Set up LR scheduler. Milestones are shifted left by the warmup length
    # because the warmup wrapper delays the after_scheduler's epoch counter.
    milestones = [int(lr) - args.lr_warmup_epochs for lr in args.lr_milestones.split(',')]
    lr_scheduler = None
    if args.use_scheduler:
        if args.lr_warmup_epochs > 0:
            if args.scheduler_type == 'multi_step':
                logger.info(f'Using Multi-Step LR scheduler')
                scheduler_step = torch.optim.lr_scheduler.MultiStepLR(
                    optimizer,
                    milestones=milestones,
                    gamma=args.lr_gamma
                )
            else:
                logger.info(f'Using Cosine Annealing LR scheduler')
                scheduler_step = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)
            lr_scheduler = GradualWarmupScheduler(
                optimizer,
                multiplier=args.world_size,
                total_epoch=args.lr_warmup_epochs,
                after_scheduler=scheduler_step
            )
        else:
            if args.scheduler_type == 'multi_step':
                logger.info(f'Using Multi-Step LR scheduler')
                lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
                    optimizer,
                    milestones=milestones,
                    gamma=args.lr_gamma
                )
            else:
                logger.info(f'Using Cosine Annealing LR scheduler')
                lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)

    # Checkpointing restart
    ckp_path = os.path.join(args.output_dir, 'checkpoints', 'checkpoint.pth')
    if os.path.isfile(ckp_path):
        logger.info(f'Loading checkpoint')
        checkpoint = torch.load(ckp_path, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # Bug fix: restore scheduler state only when a scheduler exists.
        # Previously this line ran unconditionally and crashed with
        # AttributeError when use_scheduler=False (lr_scheduler is None).
        if lr_scheduler is not None and checkpoint.get('lr_scheduler') is not None:
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch']
        logger.info(f'Restarting at epoch {args.start_epoch}')

    # Create dataloader
    if args.dataset == "ht100m":
        # HowTo100M video-text pretraining dataset
        ds = HT100M_Dataset(
            csv_file='data/howto.csv',
            video_root=args.root_dir,
            caption_root=args.ht100m_caption_root,
            token_to_word_path='data/dict.npy',
            fps=32/int(args.sample_rate),
            num_frames=args.clip_len,
            size=args.train_crop_size,
            center_crop=args.center_crop, # True
        )
    else:
        # Audio-Visual datasets: Kinetics-400/600, Audioset, VGG-Sound
        ds = GDTPretrainDataset(
            ds_name=args.dataset,
            root_dir=args.root_dir,
            mode='train',
            args=args
        )

    print("Creating data loaders", flush=True)
    train_sampler = None
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(ds)

    data_loader = torch.utils.data.DataLoader(
        ds,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        pin_memory=True,
        collate_fn=None,
        drop_last=True
    )

    # Main training loop
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            # Reshuffle shards differently each epoch
            train_sampler.set_epoch(epoch)
        if writer:
            writer.add_scalar('train/epoch', epoch, epoch)
        logger.info(f'Start training epoch: {epoch}')
        loss = train_one_epoch(
            args,
            data_loader,
            model,
            optimizer,
            device,
            epoch,
            args.print_freq,
            lr_scheduler,
            args.apex,
            writer=writer,
        )
        if lr_scheduler:
            lr_scheduler.step()
        if args.output_dir:
            save_checkpoint(args, epoch, model, optimizer, lr_scheduler)
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info(f'Training time {total_time_str}')