in LaNAS/LaNAS_NASBench101/MCTS.py [0:0]
def search(self, samples_per_round=20):
    """Run the search loop until ``self.search_space`` is empty.

    Each outer iteration ("round") does the following:

    1. Rebuilds the tree. It calls ``populate_training_data`` and
       ``train_nodes``, clears the data in the nodes with
       ``reset_node_data``, then calls ``populate_prediction_data``,
       ``predict_nodes`` and ``check_leaf_bags``.
    2. Draws up to ``samples_per_round`` architectures. Each draw asks the
       node returned by ``self.select()`` for an architecture. If that node
       yields ``None``, the first leaf in ``self.nodes`` (in list order) that
       yields one is used instead. For every drawn architecture:
       - it is trained with ``self.net_trainer.train_net``;
       - its accuracy is stored in ``self.samples`` under
         ``json.dumps(arch)``;
       - the accuracy is passed to ``self.backpropogate`` together with the
         node the architecture came from;
       - the architecture is removed from ``self.search_space``.
       Drawing stops early once the search space is empty.
    3. Prints the tree and increments ``self.SEARCH_COUNTER``.

    Args:
        samples_per_round: Maximum number of architectures drawn per round
            before the tree is rebuilt. Defaults to 20, the previously
            hard-coded value.
    """

    def _evaluate(node, arch, announce):
        # Train one architecture, record its accuracy and fold it back into
        # the tree through the node it was sampled from.
        acc = self.net_trainer.train_net(arch)
        self.samples[json.dumps(arch)] = acc
        if announce:
            print("sampled architecture:", arch, acc)
        self.backpropogate(node, acc)
        # Drop it from the remaining space; the outer loop ends when it is empty.
        self.search_space.remove(arch)

    # NOTE(review): if no node can produce an architecture while the search
    # space is still non-empty, nothing is removed in a round and this loop
    # may not terminate -- confirm sample_arch() covers the whole space.
    while len(self.search_space) > 0:
        # Assemble the training data and train the tree on it.
        self.populate_training_data()
        print("-" * 20, "iteration:", self.SEARCH_COUNTER)
        print("populate training data")
        self.train_nodes()
        print("finishing training")
        # Clear the data in nodes, then redistribute it for prediction.
        print("reset training data")
        self.reset_node_data()
        print("populate prediction data")
        self.populate_prediction_data()
        print("predict:", len(self.samples))
        self.predict_nodes()
        self.check_leaf_bags()
        self.print_tree()

        for _ in range(samples_per_round):
            # Stop drawing once everything has been evaluated instead of
            # spinning through the remaining draws with nothing to sample.
            if len(self.search_space) == 0:
                break
            target_bin = self.select()
            sampled_arch = target_bin.sample_arch()
            if sampled_arch is not None:
                _evaluate(target_bin, sampled_arch, announce=True)
                continue
            # The selected node had nothing to offer: fall back to the first
            # leaf (in self.nodes order) that can still produce an architecture.
            for node in self.nodes:
                # Explicit comparison kept: is_leaf's type is defined elsewhere.
                if node.is_leaf == True:  # noqa: E712
                    sampled_arch = node.sample_arch()
                    if sampled_arch is not None:
                        print(sampled_arch)
                        _evaluate(node, sampled_arch, announce=False)
                        break

        self.print_tree()
        self.SEARCH_COUNTER += 1