in LaNAS/Distributed_LaNAS/server/MCTS.py [0:0]
def search(self, address=('XXX.XX.XX.XXX', 8000), authkey=b'nasnet',
           samples_per_iteration=50):
    """Run the distributed MCTS search until the search space is exhausted.

    Opens a ``multiprocessing.connection.Listener``. Each iteration then:

    - passes the listener to ``dispatch_and_retrieve_jobs``;
    - rebuilds the training data and retrains the tree nodes;
    - resets node data and recomputes the predictions;
    - makes up to ``samples_per_iteration`` attempts to sample new
      architectures into ``self.TASK_QUEUE``.

    Every queued architecture is recorded in ``self.DISPATCHED_JOB``, keyed
    by its ``json.dumps`` encoding, and removed from ``self.search_space``.

    Args:
        address: ``(host, port)`` the listener binds to. The default host is
            a placeholder and must be replaced with the server's real address.
        authkey: Byte-string secret passed to ``Listener`` for client
            authentication.
        samples_per_iteration: Number of select/sample attempts per
            iteration. Each attempt enqueues at most one architecture.

    NOTE(review): the loop exits only once ``self.search_space`` is empty. If
    sampling stops yielding new architectures (and nothing else shrinks the
    space), it will spin forever.
    """

    def log_status(tag):
        # Progress line printed after each stage of an iteration.
        print(tag, "total samples:", len(self.samples),
              " trained:", len(self.DISPATCHED_JOB),
              " task queue:", len(self.TASK_QUEUE))

    def enqueue(arch):
        """Queue ``arch`` unless already dispatched; return True if queued."""
        key = json.dumps(arch)
        if key in self.DISPATCHED_JOB:
            return False
        self.TASK_QUEUE.append(arch)
        self.DISPATCHED_JOB[key] = 0
        self.search_space.remove(arch)
        return True

    # ``with`` guarantees the listening socket is closed when the search
    # finishes or raises.
    with Listener(address, authkey=authkey) as server:
        while len(self.search_space) > 0:
            self.dump_all_states()
            print("-"*20, "iteration:", self.ITERATION)
            # Dispatch & retrieve jobs.
            self.dispatch_and_retrieve_jobs(server)
            # Assemble the training data, then train the tree.
            self.populate_training_data()
            log_status("populate training data###")
            self.print_tree()
            self.train_nodes()
            log_status("finishing training###")
            self.print_tree()
            # Clear the data held in the nodes before predicting.
            log_status("reset training data###")
            self.reset_node_data()
            self.print_tree()
            log_status("populate prediction data###")
            self.populate_prediction_data()
            log_status("predict###")
            self.predict_nodes()
            self.check_leaf_bags()
            self.print_tree()

            for _ in range(samples_per_iteration):
                target_bin = self.select()
                # sample_arch() can return None.
                sampled_arch = target_bin.sample_arch()
                if sampled_arch is not None:
                    enqueue(sampled_arch)
                    continue
                # Fallback: take the first leaf that yields an architecture
                # that has not been dispatched yet.
                for node in self.nodes:
                    if not node.is_leaf:
                        continue
                    leaf_arch = node.sample_arch()
                    if leaf_arch is not None and enqueue(leaf_arch):
                        break

            self.print_task_queue()
            self.print_tree()
            self.ITERATION += 1