def main()

in experiments/train_ghn.py

# NOTE: the imports below are reconstructed assuming the ppuda package layout
# that this script belongs to.
import os
import torch
from torch.optim.lr_scheduler import MultiStepLR

from ppuda.config import init_config
from ppuda.vision.loader import image_loader
from ppuda.deepnets1m.loader import DeepNets1M
from ppuda.deepnets1m.net import Network
from ppuda.ghn.nn import GHN, ghn_parallel
from ppuda.utils import capacity, Trainer

def main():

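    # Parse the command-line configuration (dataset, GHN and optimization hyperparameters).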
    args = init_config(mode='train_ghn')

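    # Build the train/validation loaders for the image classification dataset.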
    train_queue, val_queue, num_classes = image_loader(args.dataset,
                                                       args.data_dir,
                                                       test=False,
                                                       batch_size=args.batch_size,
                                                       test_batch_size=args.test_batch_size,
                                                       num_workers=args.num_workers,
                                                       seed=args.seed)

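    # DeepNets-1M loader: yields meta-batches of training architectures represented as graphs.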
    is_imagenet = args.dataset == 'imagenet'
    graphs_queue = DeepNets1M.loader(args.meta_batch_size,
                                     split=args.split,
                                     nets_dir=args.data_dir,
                                     virtual_edges=args.virtual_edges,
                                     num_nets=args.num_nets,
                                     large_images=is_imagenet)

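    # Graph HyperNetwork (GHN) that predicts parameters for the sampled architectures.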
    ghn = GHN(max_shape=args.max_shape,
              num_classes=num_classes,
              hypernet=args.hypernet,
              decoder=args.decoder,
              weight_norm=args.weight_norm,
              ve=args.virtual_edges > 1,
              layernorm=args.ln,
              hid=args.hid,
              debug_level=args.debug).to(args.device)
    if args.multigpu:
        ghn = ghn_parallel(ghn)


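    # Adam optimizer with a multi-step learning-rate decay schedule.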
    optimizer = torch.optim.Adam(ghn.parameters(), args.lr, weight_decay=args.wd)
    scheduler = MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.gamma)

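    # The Trainer wraps the forward/backward passes, gradient clipping, AMP and logging.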
    trainer = Trainer(optimizer,
                      num_classes,
                      is_imagenet,
                      n_batches=len(train_queue),
                      grad_clip=args.grad_clip,
                      device=ghn.device_ids if args.multigpu else args.device,
                      log_interval=args.log_interval,
                      amp=args.amp)


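    # Indices of the unique training architectures seen so far (reported each epoch).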
    seen_nets = set()

    print('\nStarting to train the GHN with {} parameters!'.format(capacity(ghn)[1]))

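    # Main loop: train for args.epochs epochs, optionally checkpointing and decaying the LR after each one.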
    for epoch in range(args.epochs):

        print('\nepoch={:03d}/{:03d}, lr={:e}'.format(epoch + 1, args.epochs, scheduler.get_last_lr()[0]))

        trainer.reset()
        ghn.train()
        failed_batches = 0

        for step, (images, targets) in enumerate(train_queue):

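            # Keep sampling fresh meta-batches of graphs until this image batch produces a successful update.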
            upd, loss = False, torch.zeros(1, device=args.device)
            while not upd:
                try:
                    graphs = next(graphs_queue)

                    nets_torch = []

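                    # Instantiate a PyTorch network for each sampled architecture specification.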
                    for nets_args in graphs.net_args:
                        net = Network(is_imagenet_input=is_imagenet,
                                      num_classes=num_classes,
                                      compress_params=True,
                                      **nets_args)
                        nets_torch.append(net)

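                    # The GHN predicts parameters for all nets; the resulting loss updates the GHN.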
                    loss = trainer.update(nets_torch, images, targets, ghn=ghn, graphs=graphs)
                    trainer.log()

                    for ind in graphs.net_inds:
                        seen_nets.add(ind)
                    upd = True

                except RuntimeError as e:
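                    # Recover from CUDA OOM or a NaN loss by retrying; give up after too many failed batches.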
                    print('error', type(e), e)
                    oom = str(e).find('out of memory') >= 0
                    is_nan = torch.isnan(loss) or str(e).find('the loss is') >= 0
                    if oom or is_nan:
                        if failed_batches > len(train_queue) // 50:
                            print('Out of patience (after %d attempts to continue), '
                                  'please restart the job with another seed!!!' % failed_batches)
                            raise

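                        # Move the GHN to CPU first so that empty_cache() can release its GPU memory.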
                        if oom:
                            print('CUDA out of memory: attempt #%d to free up memory' % failed_batches)
                            if args.multigpu:
                                ghn = ghn.module
                            ghn.to('cpu')
                            torch.cuda.empty_cache()
                            ghn.to(args.device)
                            if args.multigpu:
                                ghn = ghn_parallel(ghn)

                        failed_batches += 1

                    else:
                        raise

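            # Drop references from this step; empty the CUDA cache every 10 steps.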
            del images, targets, graphs, nets_torch, loss
            if step % 10 == 0:
                torch.cuda.empty_cache()

        if args.save:
            ghn_module = ghn.module if args.multigpu else ghn
            # Save the config necessary to restore the GHN configuration when evaluating it
            config = {'max_shape': args.max_shape,
                      'num_classes': num_classes,
                      'hypernet': args.hypernet,
                      'decoder': args.decoder,
                      'weight_norm': args.weight_norm,
                      've': ghn_module.ve,
                      'layernorm': args.ln,
                      'hid': args.hid}

            checkpoint_path = os.path.join(args.save, 'ghn.pt')
            torch.save({'state_dict': ghn_module.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'epoch': epoch,
                        'config': config}, checkpoint_path)
            print('\nsaved the checkpoint to {}'.format(checkpoint_path))

        print('{} unique architectures seen'.format(len(seen_nets)))

        scheduler.step()
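

# Example invocation (hypothetical: the exact CLI flags are defined in init_config,
# which is not shown in this file):
#   python experiments/train_ghn.py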