def main_worker()

in train_attentive_nas.py [0:0]


def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu  # local rank, local machine cuda id
    args.local_rank = args.gpu
    args.batch_size = args.batch_size_per_gpu
    args.batch_size_total = args.batch_size * args.world_size
    # rescale base lr with the linear scaling rule: multiply by (total batch size / 256)
    args.lr_scheduler.base_lr = args.lr_scheduler.base_lr * (max(1, args.batch_size_total // 256))
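    # e.g. 8 GPUs x 64 images/GPU gives batch_size_total = 512 and doubles base_lr;
    # totals below 256 leave base_lr unchanged thanks to the max(1, ...) guard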

    # set random seeds so that every process samples the same random subnetworks
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.gpu is not None:  # 'if args.gpu:' would skip CUDA seeding on GPU 0
        torch.cuda.manual_seed(args.seed)

    global_rank = args.gpu + args.machine_rank * ngpus_per_node
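    # e.g. machine_rank=1 with ngpus_per_node=8 and gpu=3 gives global_rank=11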
    dist.init_process_group(
        backend=args.dist_backend, 
        init_method=args.dist_url,
        world_size=args.world_size, 
        rank=global_rank
    )
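    # dist_backend is typically 'nccl' for GPU training; dist_url is an init endpoint
    # such as 'tcp://<master_ip>:<port>' or 'env://' (exact values come from the launch config)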

    # Setup logging format.
    logging.setup_logging(args.logging_save_path, 'w')

    logger.info(f"Use GPU: {args.gpu}, machine rank {args.machine_rank}, num_nodes {args.num_nodes}, \
                    gpu per node {ngpus_per_node}, world size {args.world_size}")

    # synchronize is needed here to prevent a possible timeout after calling
    # init_process_group
    # See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172
    comm.synchronize()

    args.rank = comm.get_rank() # global rank
    args.local_rank = args.gpu
    torch.cuda.set_device(args.gpu)

    # build model
    logger.info("=> creating model '{}'".format(args.arch))
    model = models.model_factory.create_model(args)
    model.cuda(args.gpu)

    # build arch sampler
    arch_sampler = None
    if getattr(args, 'sampler', None):
        arch_sampler = ArchSampler(
            args.sampler.arch_to_flops_map_file_path, args.sampler.discretize_step, model, None 
        )
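    # the sampler presumably draws candidate subnets guided by the FLOPs lookup table
    # (arch_to_flops_map_file_path, bucketed at discretize_step granularity); it is only
    # passed to train_epoch below, so the sampling logic lives in ArchSampler itself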

    # use sync batchnorm
    if getattr(args, 'sync_bn', False):
        model.apply(lambda m: setattr(m, 'need_sync', True))
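    # 'need_sync' is presumably checked by the supernet's batchnorm layers so that
    # BN statistics are synchronized across ranks during distributed training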

    model = comm.get_parallel_model(model, args.gpu) #local rank

    logger.info(model)

    criterion = loss_ops.CrossEntropyLossSmooth(args.label_smoothing).cuda(args.gpu)
    soft_criterion = loss_ops.KLLossSoft().cuda(args.gpu)
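    # the KL soft-target loss drives in-place distillation: sampled subnets are
    # presumably trained against the largest subnet's soft predictions (disabled
    # below when inplace_distill is turned off)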

    if not getattr(args, 'inplace_distill', True):
        soft_criterion = None

    # load dataset; train_sampler is a distributed sampler
    train_loader, val_loader, train_sampler = build_data_loader(args)
    args.n_iters_per_epoch = len(train_loader)
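    # iterations per epoch; presumably consumed by build_lr_scheduler to convert
    # epoch-based schedules into per-step updates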

    logger.info(
        f'building optimizer and lr scheduler, '
        f'local rank {args.gpu}, global rank {args.rank}, world_size {args.world_size}'
    )
    optimizer = build_optimizer(args, model)
    lr_scheduler = build_lr_scheduler(args, optimizer)
 
    # optionally resume from a checkpoint
    if args.resume:
        saver.load_checkpoints(args, model, optimizer, lr_scheduler, logger)
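        # load_checkpoints presumably restores the model/optimizer/lr_scheduler state
        # and sets args.start_epoch so the loop below resumes where training stopped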

    logger.info(args)

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
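            # set_epoch reseeds the distributed sampler so each epoch gets a different
            # shuffle that stays consistent across ranks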

        args.curr_epoch = epoch
        logger.info('Training lr {}'.format(lr_scheduler.get_lr()[0]))

        # train for one epoch
        acc1, acc5 = train_epoch(
            epoch, model, train_loader, optimizer, criterion, args,
            arch_sampler=arch_sampler, soft_criterion=soft_criterion, lr_scheduler=lr_scheduler,
        )

        if comm.is_master_process() or args.distributed:
            # validate the supernet (every rank runs validation in the distributed case)
            validate(
                train_loader, val_loader, model, criterion, args
            )
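        # checkpoints below are written by the master process only, which avoids
        # concurrent writes (the parallel wrapper presumably keeps all ranks' weights in sync)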

        if comm.is_master_process():
            # save checkpoints
            saver.save_checkpoint(
                args.checkpoint_save_path, 
                model,
                optimizer,
                lr_scheduler, 
                args,
                epoch,
            )