in LaNAS/LaNAS_NASBench101/MCTS.py [0:0]
def __init__(self, search_space, tree_height, arch_code_len):
    """Set up the search state and build the full binary tree of nodes.

    Args:
        search_space: Non-empty list whose elements are lists (only the
            first element is type-checked). Stored as ``self.search_space``.
        tree_height: Height of the tree; ``2**tree_height - 1`` nodes are
            created, so it must be at least 1 for a root node to exist.
        arch_code_len: Stored as ``self.ARCH_CODE_LEN`` and handed to every
            ``Node`` that is created.

    Raises:
        AssertionError: If ``search_space`` is not a non-empty list whose
            first element is a list.
    """
    # Kept as assertions so the exception type seen by callers is unchanged.
    assert isinstance(search_space, list)
    assert len(search_space) >= 1
    assert isinstance(search_space[0], list)

    self.ARCH_CODE_LEN = arch_code_len
    self.SEARCH_COUNTER = 0
    self.samples = {}
    self.nodes = []
    self.search_space = search_space
    self.Cp = 0.1
    # 49 is the length of the architecture encoding, 1 is for the predicted
    # accuracy (meta-DNN currently disabled):
    # self.metaDNN = LinearModel(49, 1)

    # Per the original note: used for querying the accuracy from nasbench.
    self.net_trainer = Net_Trainer()

    # Seed both RNGs from the wall clock. random.seed() only accepts
    # None/int/float/str/bytes/bytearray on Python 3.11+, so pass the
    # timestamp rather than the datetime object itself.
    np.random.seed(seed=int(time.time()))
    random.seed(datetime.now().timestamp())

    # Initialize a full binary tree, stored level by level: with 1-based
    # numbering, node i has node i // 2 as its parent (list index i // 2 - 1).
    total_nodes = 2**tree_height - 1
    for i in range(1, total_nodes + 1):
        # Even-numbered nodes are flagged as the "good" kid; the root
        # (i == 1) and all other odd-numbered nodes are not.
        is_good_kid = i % 2 == 0
        if i == 1:
            # The root has no parent; the last Node argument is True only here.
            parent, is_root = None, True
        else:
            parent, is_root = self.nodes[i // 2 - 1], False
        self.nodes.append(Node(parent, is_good_kid, self.ARCH_CODE_LEN, is_root))

    # self.loads_all_states()

    self.ROOT = self.nodes[0]
    self.CURT = self.ROOT

    print('='*10 + 'search space start' + '='*10)
    print("total architectures:", len(search_space))
    print('='*10 + 'search space end ' + '='*10)

    self.init_train()