# def main()
#
# in experiments/sgd/train_net.py [0:0]


def main():
    """Train a single architecture with SGD, or only evaluate it.

    The run is configured entirely through ``init_config(mode='train_net')``.
    ``args.arch`` selects the network: it is first looked up as an attribute
    of the ``genotypes`` module; if that lookup fails, it is passed to
    ``DeepNets1M`` which must then yield exactly one architecture.

    When ``args.epochs`` is 0 no training is performed and the model is
    evaluated once on the validation/test queue. Otherwise, after every
    training epoch the model is (optionally) checkpointed to
    ``args.save/checkpoint.pt`` and evaluated.

    Raises:
        AssertionError: if ``args.arch`` is None, or if ``DeepNets1M`` does
            not resolve ``args.arch`` to exactly one architecture.
    """

    args = init_config(mode='train_net')

    is_imagenet = args.dataset == 'imagenet'
    train_queue, valid_queue, num_classes = image_loader(dataset=args.dataset,
                                                         data_dir=args.data_dir,
                                                         test=True,
                                                         load_train_anyway=True,
                                                         batch_size=args.batch_size,
                                                         test_batch_size=args.test_batch_size,
                                                         num_workers=args.num_workers,
                                                         cutout=args.cutout,
                                                         cutout_length=args.cutout_length,
                                                         seed=args.seed,
                                                         noise=args.noise,
                                                         n_shots=args.n_shots)

    assert args.arch is not None, 'architecture genotype/index must be specified'

    # Resolve the architecture. Only the attribute lookup is guarded, so that
    # genuine errors (and Ctrl-C) are not silently turned into the fallback.
    try:
        # TypeError covers a non-string args.arch (e.g. an integer index).
        genotype = getattr(genotypes, args.arch)
    except (AttributeError, TypeError):
        # Not a named genotype: let DeepNets1M resolve args.arch instead.
        deepnets = DeepNets1M(split=args.split,
                              nets_dir=args.data_dir,
                              large_images=is_imagenet,
                              arch=args.arch)
        assert len(deepnets) == 1, 'one architecture must be chosen to train'
        net_args = deepnets[0].net_args
        if net_args.get('norm') == 'bn':
            net_args['norm'] = 'bn-track'
    else:
        net_args = {'C': args.init_channels,
                    'genotype': genotype,
                    'n_cells': args.layers,
                    'C_mult': int(genotype != ViT) + 1,  # assume either ViT or DARTS-style architecture
                    'preproc': genotype != ViT,
                    'stem_type': 1}  # assume that the ImageNet-style stem is used by default

    if isinstance(net_args['genotype'], str):
        # A string genotype names a torchvision model constructor.
        model_fn = getattr(torchvision.models, net_args['genotype'])
        model = adjust_net(model_fn(pretrained=int(args.pretrained)), is_imagenet)
    else:
        model = Network(num_classes=num_classes,
                        is_imagenet_input=is_imagenet,
                        auxiliary=args.auxiliary,
                        **net_args)

    if args.ckpt is not None or isinstance(model, torchvision.models.ResNet):
        model = pretrained_model(model, args.ckpt, num_classes, args.debug, GHN)

    model = model.train().to(args.device)

    print('\nTraining arch={} with {} parameters'.format(args.arch, capacity(model)[1]))

    optimizer = torch.optim.SGD(
        model.parameters(),
        args.lr,
        momentum=args.momentum,
        weight_decay=args.wd
    )

    if is_imagenet:
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, 0.97)
    else:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)

    trainer = Trainer(optimizer,
                      num_classes,
                      is_imagenet,
                      n_batches=len(train_queue),
                      grad_clip=args.grad_clip,
                      auxiliary=args.auxiliary,
                      auxiliary_weight=args.auxiliary_weight,
                      device=args.device,
                      log_interval=args.log_interval,
                      amp=args.amp)

    training = args.epochs > 0

    for epoch in range(max(1, args.epochs)):  # if args.epochs=0, then just evaluate the model

        if training:
            print('\nepoch={:03d}/{:03d}, lr={:e}'.format(epoch + 1, args.epochs, scheduler.get_last_lr()[0]))
            # Drop-path probability is ramped linearly with the epoch index.
            model.drop_path_prob = args.drop_path_prob * epoch / args.epochs

            trainer.reset()
            model.train()
            for images, targets in train_queue:
                trainer.update(model, images, targets)
                trainer.log()

            if args.save:
                checkpoint_path = os.path.join(args.save, 'checkpoint.pt')
                torch.save({'state_dict': model.state_dict(), 'epoch': epoch}, checkpoint_path)
                print('\nsaved the checkpoint to {}'.format(checkpoint_path))

        infer(model.eval(), valid_queue, verbose=True)

        if training:
            # Step the schedule only after a real training epoch. In
            # evaluation-only mode no optimizer step was taken, and the cosine
            # schedule was built with T_max=0, which can fail when stepped.
            scheduler.step()