in LaNAS/one-shot_LaNAS/LaNAS/MCTS.py [0:0]
def __init__(self, search_space, trainer, tree_height):
    """Initialise search state and pre-build a complete binary tree of Nodes.

    Args:
        search_space: description of the linear constraint system A x <= b.
            Only ``search_space["b"]`` is read here; its length is taken to
            be twice the architecture code length (presumably a lower and an
            upper bound per code position -- confirm against the builder).
        trainer: stored as ``self.trainer``; not otherwise used here.
        tree_height: number of tree levels. ``2**tree_height - 1`` nodes are
            created and stored in ``self.nodes`` in array (heap) order, with
            ``self.nodes[0]`` as the root.

    Raises:
        ValueError: if ``tree_height`` is less than 1, since the tree would
            then have no root node.
    """
    if tree_height < 1:
        raise ValueError(f"tree_height must be >= 1, got {tree_height!r}")

    # Two entries of b per code position, hence the halving. The code length
    # is also what supernet masks are generated from.
    self.ARCH_CODE_LEN = len(search_space["b"]) // 2
    self.SEARCH_COUNTER = 0
    self.samples = {}
    self.nodes = []
    # search_space is a mapping describing A x <= b: the left side A
    # (presumably under key "A") and the right side under key "b".
    self.search_space = search_space
    # Exploration constant; presumably the UCB weight used during node
    # selection elsewhere in the class -- confirm.
    self.Cp = 10
    self.trainer = trainer
    print("architecture code length:", self.ARCH_CODE_LEN)

    # Seed both RNGs from the wall clock so separate runs explore differently.
    np.random.seed(seed=int(time.time()))
    # random.seed() rejects datetime objects since Python 3.11 (only None,
    # int, float, str, bytes and bytearray are accepted), so seed with the
    # POSIX timestamp instead.
    random.seed(datetime.now().timestamp())

    # Build a complete binary tree in array order: node k's children are
    # 2k+1 and 2k+2, so its parent is (k-1)//2. Odd ids (the first child of
    # each parent) get is_good_kid=True; even ids, including the root, False.
    total_nodes = 2 ** tree_height - 1
    for node_id in range(total_nodes):
        is_good_kid = node_id % 2 == 1
        if node_id == 0:
            # The last Node argument is True only for the root
            # (presumably Node's is_root flag).
            self.nodes.append(Node(None, is_good_kid, self.ARCH_CODE_LEN, True))
        else:
            parent = self.nodes[(node_id - 1) // 2]
            self.nodes.append(Node(parent, is_good_kid, self.ARCH_CODE_LEN, False))
    self.ROOT = self.nodes[0]
    self.CURT = self.ROOT

    print('=' * 10 + 'search space start' + '=' * 10)
    # Report the space size as 2^ARCH_CODE_LEN (assumes binary codes);
    # len(search_space) would only count the constraint parts, not positions.
    print("total architectures: 2^", self.ARCH_CODE_LEN)
    print('=' * 10 + 'search space end ' + '=' * 10)