def run()

in LaNAS/one-shot_LaNAS/Evaluate/super_individual_train.py
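
The body below also relies on module-level imports and globals defined
elsewhere in super_individual_train.py. A plausible preamble, inferred from
the names used in the function rather than copied from the repo:

import os
import sys
import logging

import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision.datasets as dset

# Repo-local pieces assumed by run(): utils, translator, the Network class,
# the train/infer helpers, the parsed `args` namespace, CIFAR_CLASSES (10 for
# CIFAR-10), and a module-level `continue_train` flag that selects the
# resume-from-checkpoint path.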


def run():
  if not torch.cuda.is_available():
    logging.info('no gpu device available')
    sys.exit(1)

  # Seed all RNGs for repeatability; note that cudnn.benchmark trades exact
  # determinism for speed
  np.random.seed(args.seed)
  torch.manual_seed(args.seed)
  torch.cuda.manual_seed(args.seed)

  torch.cuda.set_device(args.gpu)
  cudnn.enabled = True
  cudnn.benchmark = True
  logging.info('gpu device = %d', args.gpu)
  logging.info("args = %s", args)

  cur_epoch = 0  # starting epoch; overwritten when resuming from a checkpoint

  # Candidate operations for each supernet edge
  layers_type = [
      'max_pool_3x3',
      'skip_connect',
      'sep_conv_3x3',
      'sep_conv_5x5'
  ]
  # supernet_normal = eval(args.supernet_normal)
  # supernet_reduce = eval(args.supernet_reduce)

  # Decode the masked architecture code into normal/reduction cell masks and
  # expand each to a full supernet encoding. Note: eval() trusts
  # args.masked_code; ast.literal_eval would be safer for untrusted input.
  supernet_normal, supernet_reduce = translator.encoding_to_masks(eval(args.masked_code))
  supernet_normal = translator.expend_to_supernet_code(supernet_normal)
  supernet_reduce = translator.expend_to_supernet_code(supernet_reduce)

  if not continue_train:
    print('train from scratch!')

    model = Network(supernet_normal, supernet_reduce, layers_type, args.init_ch, CIFAR_CLASSES, args.layers,
                    args.auxiliary,
                    steps=len(supernet_normal), multiplier=len(supernet_normal))
    # model = Network(args.init_channels, CIFAR_CLASSES, args.layers, args.auxiliary, genotype)
    model = model.cuda()
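    # model.drop_path_prob is (re)assigned every epoch in the training loop below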

    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    optimizer = torch.optim.SGD(
        model.parameters(),
        args.lr,
        momentum=args.momentum,
        weight_decay=args.wd
    )

    # Anneal the learning rate from args.lr toward 0 over args.epochs (cosine)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(args.epochs))

  else:
    print('train from checkpoints')
    # model = Network(args.init_channels, CIFAR_CLASSES, args.layers, args.auxiliary, genotype)

    # Keep steps/multiplier consistent with the from-scratch branch so the
    # checkpoint's state dict matches the constructed model
    model = Network(supernet_normal, supernet_reduce, layers_type, args.init_ch, CIFAR_CLASSES, args.layers,
                    args.auxiliary,
                    steps=len(supernet_normal), multiplier=len(supernet_normal))
    model = model.cuda()
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    optimizer = torch.optim.SGD(
      model.parameters(),
      args.lr,
      momentum=args.momentum,
      weight_decay=args.wd
    )

    # Restore weights, optimizer state, the pickled scheduler, and the epoch
    # to resume from
    checkpoint = torch.load(os.path.join(args.save, 'model.pt'))
    model.load_state_dict(checkpoint['model_state_dict'])
    cur_epoch = checkpoint['epoch']
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler = checkpoint['scheduler']

  # Build the CIFAR-10 pipelines; the helper is assumed to apply the usual
  # DARTS recipe (random crop + flip + normalize, plus cutout of
  # args.cutout_length on the training split)
  train_transform, valid_transform = utils._data_transforms_cifar10(args, args.cutout_length)
  train_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)
  valid_data = dset.CIFAR10(root=args.data, train=False, download=True, transform=valid_transform)

  train_queue = torch.utils.data.DataLoader(
      train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=2)

  # Note: this "valid" queue is the CIFAR-10 test split (train=False), so the
  # accuracy reported below is test-set accuracy.
  valid_queue = torch.utils.data.DataLoader(
      valid_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=2)

  best_acc = 0.0

  for epoch in range(cur_epoch, args.epochs):
    print("=====> current epoch:", epoch)
    logging.info('=====> current epoch: %d', epoch)

    logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])  # get_last_lr() on newer PyTorch
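    # Ramp the drop-path probability linearly from 0 toward args.drop_path_prob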
    model.drop_path_prob = args.drop_path_prob * epoch / args.epochs

    # train_acc, train_obj = train(train_queue, model, criterion, optimizer)
    train_acc, train_obj = train(train_queue, model, criterion, optimizer, args.grad_clip, args.report_freq)

    scheduler.step()  # advance the cosine schedule once per epoch
    logging.info('train_acc %f', train_acc)

    # valid_acc, valid_obj = infer(valid_queue, model, criterion)
    valid_acc, valid_obj = infer(valid_queue, model, criterion, args.report_freq)
    logging.info('valid_acc %f', valid_acc)

    if valid_acc > best_acc:
      best_acc = valid_acc
      logging.info('this model is the best so far')
      torch.save({'epoch': epoch + 1, 'model_state_dict': model.state_dict(), 'scheduler': scheduler,
                  'optimizer_state_dict': optimizer.state_dict()}, os.path.join(args.save, 'top_1.pt'))
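    # 'top_1.pt' above holds the best-so-far weights; the rolling checkpoint
    # 'model.pt' below is what the resume branch loads.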

    torch.save({'epoch': epoch + 1, 'model_state_dict': model.state_dict(), 'scheduler': scheduler,
                'optimizer_state_dict': optimizer.state_dict()}, os.path.join(args.save, 'model.pt'))

    logging.info('best_acc: %f', best_acc)
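
The per-epoch work is delegated to train() and infer(), called as
train(train_queue, model, criterion, optimizer, args.grad_clip, args.report_freq)
and infer(valid_queue, model, criterion, args.report_freq), each expected to
return (top-1 accuracy, average loss). A minimal sketch consistent with those
signatures, assuming DARTS-style utils.AvgrageMeter/utils.accuracy helpers and
a single-tensor forward pass (the repo's actual helpers, e.g. any
auxiliary-head loss handling, may differ):

def train(train_queue, model, criterion, optimizer, grad_clip, report_freq):
  objs, top1 = utils.AvgrageMeter(), utils.AvgrageMeter()  # running averages
  model.train()
  for step, (x, target) in enumerate(train_queue):
    x, target = x.cuda(), target.cuda(non_blocking=True)
    optimizer.zero_grad()
    logits = model(x)  # the real Network may return (logits, logits_aux)
    loss = criterion(logits, target)
    loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), grad_clip)  # clip gradients
    optimizer.step()
    prec1, = utils.accuracy(logits, target, topk=(1,))
    objs.update(loss.item(), x.size(0))
    top1.update(prec1.item(), x.size(0))
    if step % report_freq == 0:
      logging.info('train %03d %e %f', step, objs.avg, top1.avg)
  return top1.avg, objs.avg

def infer(valid_queue, model, criterion, report_freq):
  objs, top1 = utils.AvgrageMeter(), utils.AvgrageMeter()
  model.eval()
  with torch.no_grad():  # no gradients needed for evaluation
    for step, (x, target) in enumerate(valid_queue):
      x, target = x.cuda(), target.cuda(non_blocking=True)
      logits = model(x)
      loss = criterion(logits, target)
      prec1, = utils.accuracy(logits, target, topk=(1,))
      objs.update(loss.item(), x.size(0))
      top1.update(prec1.item(), x.size(0))
      if step % report_freq == 0:
        logging.info('valid %03d %e %f', step, objs.avg, top1.avg)
  return top1.avg, objs.avg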