def main()

in train_vclr.py


def main(args, writer):
    # distributed data loader over the pre-training videos
    train_loader = get_loader(args)
    n_data = len(train_loader.dataset)
    logger.info("length of training dataset: {}".format(n_data))

    # query encoder and its momentum (EMA) key encoder
    model, model_ema = build_model(args)
    logger.info('{}'.format(model))
    # memory banks for the SeCo (inter-/intra-frame) and VCLR (TSN) contrastive objectives
    contrast = MemorySeCo(128, args.nce_k, args.nce_t, args.nce_t_intra).cuda()
    contrast_tsn = MemoryVCLR(128, args.nce_k, args.nce_t).cuda()
    criterion = NCESoftmaxLoss().cuda()
    # SGD with the linear scaling rule: lr = base_lr * global_batch_size / 256
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.batch_size * dist.get_world_size() / 256 * args.base_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    scheduler = get_scheduler(optimizer, len(train_loader), args)
    # only the query encoder is wrapped in DDP; the EMA key encoder stays unwrapped
    model = DistributedDataParallel(model, device_ids=[args.local_rank], broadcast_buffers=args.broadcast_buffer)
    logger.info('Distributed Enabled')

    # optionally resume from a checkpoint
    if args.resume:
        assert os.path.isfile(args.resume), 'checkpoint not found: {}'.format(args.resume)
        load_checkpoint(args, model, model_ema, contrast, contrast_tsn, optimizer, scheduler, logger.info)

    # routine
    logger.info('Training')
    timer = mmcv.Timer()
    for epoch in range(args.start_epoch, args.epochs + 1):
        # reseed the distributed sampler so each epoch gets a different shuffle
        train_loader.sampler.set_epoch(epoch)
        loss = train_vclr(epoch, train_loader, model, model_ema, contrast, contrast_tsn, criterion, optimizer,
                          scheduler, writer, args)
        logger.info('epoch {}, epoch time {:.2f}, loss={}'.format(epoch, timer.since_last_check(), loss))
        # only rank 0 writes checkpoints; the barrier keeps the other ranks in step
        if dist.get_rank() == 0:
            save_checkpoint(args, epoch, model, model_ema, contrast, contrast_tsn, optimizer, scheduler, logger.info)
        dist.barrier()
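
For context, a minimal sketch of how such a main(args, writer) could be driven from the script's entry point, assuming a torch.distributed launch with one process per GPU. parse_args and args.output_dir are illustrative placeholders, not taken from the repository, and the rank-0-only SummaryWriter is an assumption about how logging is split across ranks.

# Minimal sketch of the surrounding entry point (not part of the listing above).
if __name__ == '__main__':
    import torch
    import torch.distributed as dist
    from torch.utils.tensorboard import SummaryWriter

    args = parse_args()                      # hypothetical option parser for this script
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend='nccl')  # one process per GPU

    # create the TensorBoard writer on rank 0 only; args.output_dir is assumed
    writer = SummaryWriter(args.output_dir) if dist.get_rank() == 0 else None
    main(args, writer)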