def __init__()

in LaNAS/one-shot_LaNAS/supernet/supernet_train.py [0:0]
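
For context, the method below relies on imports and helpers declared at the top of supernet_train.py, roughly the following (a sketch in the DARTS style this repo follows; the exact list in the file may differ):

    import os, sys, glob
    import numpy as np
    import torch
    import torch.nn as nn
    import torch.backends.cudnn as cudnn
    import torchvision.datasets as dset
    # plus repo-local helpers: Network, create_exp_dir, count_parameters_in_MB,
    # _data_transforms_cifar10, zero_supernet_generator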


    def __init__( self, 
        data='../data', batch_size=64, 
        learning_rate=0.025, learning_rate_min=0.001, 
        momentum=0.9, weight_decay=3e-4, 
        report_freq=50, gpu=0, 
        epochs=50, 
        init_channels=16, layers=8, 
        cutout=False, cutout_length=16, 
        drop_path_prob=0.3, seed=2, 
        grad_clip=5, save_prefix='EXP',
        init_masks = []):
        
        #assert len(init_masks) > 0
        
        if not torch.cuda.is_available():
            print('no gpu device available')
            sys.exit(1)
        
        #device-level hyperparameters
        np.random.seed( seed )
        torch.manual_seed( seed )  # also seed the CPU RNG (DataLoader shuffling uses it)
        torch.cuda.manual_seed( seed )
        torch.cuda.set_device(gpu)
        cudnn.enabled=True
        cudnn.benchmark = True
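        # benchmark=True lets cuDNN autotune convolution algorithms for the
        # fixed input shapes used here, at the cost of strict run-to-run determinism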
        print('gpu device = %d' % gpu)
        print('data=', data, 'batch_size=', batch_size, 'learning_rate=', learning_rate, 'learning_rate_min=',
              learning_rate_min, 'momentum=', momentum, 'weight_decay=', weight_decay, 'report_freq=', report_freq, 
              'gpu=', gpu, 'epochs=', epochs, 'init_channels=', init_channels, 'layers=', layers, 'cutout=', cutout,
              'cutout_length=', cutout_length, 'drop_path_prob=', drop_path_prob, 'seed=', seed, 'grad_clip=', grad_clip )
        
        savedirs = "supernet-logs"

        continue_train   = False
        if os.path.exists(savedirs + '/model.pt'):
            continue_train = True

        #prepare logging
        if not continue_train:
            create_exp_dir(savedirs, scripts_to_save=glob.glob('*.py'))
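        # create_exp_dir creates the log directory and (per the DARTS-style
        # helper) snapshots the current *.py scripts into it for reproducibility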
        
        #training hyperparameters
        self.data              = data
        self.batch_size        = batch_size
        self.learning_rate     = learning_rate
        self.learning_rate_min = learning_rate_min
        self.momentum          = momentum
        self.weight_decay      = weight_decay
        self.report_freq       = report_freq
        self.epochs            = epochs
        self.init_channels     = init_channels
        self.layers            = layers
        self.cutout            = cutout
        self.cutout_length     = cutout_length
        self.drop_path_prob    = drop_path_prob
        self.grad_clip         = grad_clip
        self.save_prefix       = savedirs
        self.start_epoch       = 0
        self.mask_to_train     = init_masks #masks drive the iterations
        CIFAR_CLASSES          = 10
          
        #setup network
        self.criterion  = nn.CrossEntropyLoss()
        self.criterion  = self.criterion.cuda()
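        # supernet_normal, supernet_reduce, layer_type, and node (used at the
        # end of this method) are presumably module-level definitions in
        # supernet_train.py; they are not shown in this excerpt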
        self.supernet   = Network(supernet_normal, supernet_reduce, layer_type,
                                  init_channels, CIFAR_CLASSES, layers,
                                  self.criterion, steps=len(supernet_normal))
        self.supernet   = self.supernet.cuda()
        self.optimizer  = torch.optim.SGD( self.supernet.parameters(), self.learning_rate, momentum = self.momentum, weight_decay = weight_decay)
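        # CosineAnnealingLR anneals the LR from learning_rate down to
        # learning_rate_min over `epochs` steps:
        # lr_t = eta_min + (lr_0 - eta_min) * (1 + cos(pi * t / T_max)) / 2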
        self.scheduler  = torch.optim.lr_scheduler.CosineAnnealingLR( self.optimizer, float(epochs), eta_min = learning_rate_min )
        
        #setup training&test data
        train_transform, valid_transform = _data_transforms_cifar10(cutout, cutout_length)
        train_data = dset.CIFAR10(root=data, train=True, download=True, transform=train_transform)
        valid_data = dset.CIFAR10(root=data, train=False, download=True, transform=valid_transform)
        self.train_queue = torch.utils.data.DataLoader(
            train_data, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=2)
        self.valid_queue = torch.utils.data.DataLoader(
            valid_data, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=2)
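        # note: the "valid" queue is the CIFAR-10 test split (train=False);
        # no held-out portion of the training set is used for validation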
        
        print("length of training & valid queue:", len(self.train_queue), len(self.valid_queue) )
        print("param size = %fMB"% count_parameters_in_MB(self.supernet) )
        
        if continue_train:
            print('continue train from checkpoint')
            checkpoint       = torch.load(self.save_prefix + '/model.pt')
            self.supernet.load_state_dict(checkpoint['model_state_dict'])
            self.start_epoch = checkpoint['epoch']
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
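            # the scheduler was checkpointed as a whole object, so it is
            # restored by direct assignment rather than via load_state_dict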
            self.scheduler   = checkpoint['scheduler']
        
        self.curt_epoch = self.start_epoch
        self.curt_step  = 0
        
        self.base_net   = zero_supernet_generator(node, layer_type )
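
The resume branch above implies a matching save routine elsewhere in the trainer. A minimal sketch of that counterpart, using only the dictionary keys this __init__ reads back (the function name save_checkpoint is an assumption, not taken from the repo):

    import torch

    def save_checkpoint(supernet, optimizer, scheduler, epoch, savedirs="supernet-logs"):
        # keys must match what __init__ expects when continue_train is True
        torch.save({
            'model_state_dict':     supernet.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler':            scheduler,  # saved whole; restored by assignment
            'epoch':                epoch,
        }, savedirs + '/model.pt')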