def search()

in LaNAS/LaNAS_NASBench101/MCTS.py [0:0]


    def search(self):
        """Run the search loop until ``self.search_space`` is exhausted.

        Each outer iteration:
          1. rebuilds the tree's training data (``populate_training_data``)
             and trains the nodes (``train_nodes``);
          2. clears node data (``reset_node_data``), refills it for
             prediction (``populate_prediction_data``), then runs
             ``predict_nodes`` and ``check_leaf_bags``;
          3. samples up to 20 architectures. Each one comes from the node
             returned by ``select()``. If that node yields nothing, the first
             leaf in ``self.nodes`` that can produce an architecture is used
             instead. Every sampled architecture is:
               - trained with ``self.net_trainer.train_net``;
               - recorded in ``self.samples`` under its ``json.dumps`` key;
               - back-propagated from the node it came from;
               - removed from ``self.search_space``.

        The 20-sample round ends early once ``self.search_space`` is empty.
        ``self.SEARCH_COUNTER`` is incremented after every outer iteration.

        NOTE(review): if ``search_space`` is non-empty but no node can sample
        an architecture, an iteration makes no progress and the outer loop
        may never terminate -- confirm ``sample_arch`` rules this out.
        """
        while len(self.search_space) > 0:
            # Assemble the training data from every sample gathered so far.
            self.populate_training_data()
            print("-" * 20, "iteration:", self.SEARCH_COUNTER)
            print("populate training data")

            self.train_nodes()
            print("finishing training")

            # Clear the training data before refilling the nodes for prediction.
            print("reset training data")
            self.reset_node_data()

            print("populate prediction data")
            self.populate_prediction_data()

            print("predict:", len(self.samples))
            self.predict_nodes()
            self.check_leaf_bags()
            self.print_tree()

            for _ in range(20):
                # Every successful sample is removed from search_space. Once it
                # is empty, any architecture drawn could not be removed, so the
                # remaining select()/sample_arch() calls are pure waste.
                if not self.search_space:
                    break

                target_bin = self.select()
                sampled_arch = target_bin.sample_arch()
                if sampled_arch is not None:
                    # TODO: back-propagate an architecture
                    sampled_acc = self.net_trainer.train_net(sampled_arch)
                    self.samples[json.dumps(sampled_arch)] = sampled_acc
                    print("sampled architecture:", sampled_arch, sampled_acc)
                    self.backpropogate(target_bin, sampled_acc)
                    self.search_space.remove(sampled_arch)
                    continue

                # The selected node produced nothing: fall back to the first
                # leaf (in self.nodes order) that can still yield an architecture.
                for node in self.nodes:
                    if not node.is_leaf:
                        continue
                    sampled_arch = node.sample_arch()
                    if sampled_arch is None:
                        continue
                    print(sampled_arch)
                    sampled_acc = self.net_trainer.train_net(sampled_arch)
                    self.samples[json.dumps(sampled_arch)] = sampled_acc
                    self.backpropogate(node, sampled_acc)
                    self.search_space.remove(sampled_arch)
                    break

            self.print_tree()
            self.SEARCH_COUNTER += 1