in train_attentive_nas.py [0:0]
def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu  # local rank, local machine cuda id
    args.local_rank = args.gpu
    args.batch_size = args.batch_size_per_gpu
    args.batch_size_total = args.batch_size * args.world_size

    # rescale base lr
    args.lr_scheduler.base_lr = args.lr_scheduler.base_lr * (max(1, args.batch_size_total // 256))
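    # (linear scaling rule: the base lr grows proportionally with the total batch size,
    # using 256 as the reference batch size)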
    # set random seed, make sure all randomly generated subnetworks are the same across workers
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.gpu is not None:  # `is not None` so that GPU 0 is seeded as well
        torch.cuda.manual_seed(args.seed)
    global_rank = args.gpu + args.machine_rank * ngpus_per_node
    dist.init_process_group(
        backend=args.dist_backend,
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=global_rank,
    )
    # Setup logging format.
    logging.setup_logging(args.logging_save_path, 'w')

    logger.info(f"Use GPU: {args.gpu}, machine rank {args.machine_rank}, num_nodes {args.num_nodes}, "
                f"gpus per node {ngpus_per_node}, world size {args.world_size}")

    # synchronize is needed here to prevent a possible timeout after calling
    # init_process_group
    # See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172
    comm.synchronize()

    args.rank = comm.get_rank()  # global rank
    args.local_rank = args.gpu
    torch.cuda.set_device(args.gpu)
    # build model
    logger.info("=> creating model '{}'".format(args.arch))
    model = models.model_factory.create_model(args)
    model.cuda(args.gpu)

    # build arch sampler
    arch_sampler = None
    if getattr(args, 'sampler', None):
        arch_sampler = ArchSampler(
            args.sampler.arch_to_flops_map_file_path, args.sampler.discretize_step, model, None
        )
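        # ArchSampler presumably uses the precomputed arch-to-FLOPs map and the
        # discretize step to draw subnetworks whose FLOPs land in targeted buckets.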
    # use sync batchnorm
    if getattr(args, 'sync_bn', False):
        model.apply(lambda m: setattr(m, 'need_sync', True))
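    # comm.get_parallel_model presumably wraps the model for distributed training
    # (e.g., DistributedDataParallel pinned to the local GPU).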
    model = comm.get_parallel_model(model, args.gpu)  # local rank
    logger.info(model)
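    # hard loss: label-smoothed cross entropy; soft loss: KL divergence, presumably used
    # to distill the largest subnetwork's predictions into the smaller sampled ones.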
    criterion = loss_ops.CrossEntropyLossSmooth(args.label_smoothing).cuda(args.gpu)
    soft_criterion = loss_ops.KLLossSoft().cuda(args.gpu)
    if not getattr(args, 'inplace_distill', True):
        soft_criterion = None
    # load dataset; train_sampler is the distributed sampler
    train_loader, val_loader, train_sampler = build_data_loader(args)
    args.n_iters_per_epoch = len(train_loader)
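    # n_iters_per_epoch is presumably consumed by the lr scheduler for
    # per-iteration warmup/decay.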
    logger.info(f'building optimizer and lr scheduler, '
                f'local rank {args.gpu}, global rank {args.rank}, world_size {args.world_size}')
    optimizer = build_optimizer(args, model)
    lr_scheduler = build_lr_scheduler(args, optimizer)
    # optionally resume from a checkpoint
    if args.resume:
        saver.load_checkpoints(args, model, optimizer, lr_scheduler, logger)

    logger.info(args)
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
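            # the distributed sampler needs the epoch number to reshuffle the data
            # differently every epoch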
            train_sampler.set_epoch(epoch)

        args.curr_epoch = epoch
        logger.info('Training lr {}'.format(lr_scheduler.get_lr()[0]))

        # train for one epoch
        acc1, acc5 = train_epoch(epoch, model, train_loader, optimizer, criterion, args,
                                 arch_sampler=arch_sampler, soft_criterion=soft_criterion,
                                 lr_scheduler=lr_scheduler)
        if comm.is_master_process() or args.distributed:
            # validate supernet model
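            # (train_loader is presumably passed so batch-norm statistics can be
            # recalibrated on training data before evaluating sampled subnetworks)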
            validate(
                train_loader, val_loader, model, criterion, args
            )
        if comm.is_master_process():
            # save checkpoints
            saver.save_checkpoint(
                args.checkpoint_save_path,
                model,
                optimizer,
                lr_scheduler,
                args,
                epoch,
            )
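
For context, a minimal launch sketch (not part of the repo; `main` and its argument handling are illustrative only): a worker like main_worker is typically spawned once per local GPU via torch.multiprocessing.spawn, which supplies the process index as the first argument.

import torch
import torch.multiprocessing as mp

def main(args):
    # assumes args already carries world_size, machine_rank, dist_url, seed, etc.
    ngpus_per_node = torch.cuda.device_count()
    # spawn one process per GPU; each call runs main_worker(gpu_index, ngpus_per_node, args)
    mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))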