in LaNAS/Distributed_LaNAS/clientX/client.py [0:0]
def train(self, server_address=('100.97.66.131', 8000), authkey=b'nasnet',
          retry_delay=1.0):
    """Run the worker loop forever: fetch a network, evaluate it, report back.

    Each cycle has three phases:

    1. Receive: connect to the server and wait up to 2 seconds for a
       one-element list ``[network]``; reconnect and retry until one arrives.
    2. Evaluate: reuse the accuracy from ``self.accuracy_trace`` if this
       network was seen before, otherwise train it with ``train_client.run``
       and cache the result (persisted via ``self.dump_acc_trace()``).
    3. Send: report ``[self.client_name, network_json, acc]`` to the server,
       retrying until the send succeeds.

    Client state is persisted with ``self.dump_client()`` after each
    successful receive and send.

    Args:
        server_address: ``(host, port)`` of the server, used for both
            receiving networks and sending results.
        authkey: Authentication key passed to ``Client``.
        retry_delay: Seconds to wait after a failed connect/receive/send
            before retrying, so an unreachable server does not cause a tight
            busy loop. Pass 0 to retry immediately.

    This method never returns.
    """
    import time

    # Cell size shared by the code generator and the genotype translator.
    node_num = 7

    while True:
        # Phase 1: keep asking the server until it hands out a network.
        while not self.received:
            try:
                # The context manager closes the connection on every path,
                # including a poll() timeout or a failing recv().
                with Client(server_address, authkey=authkey) as conn:
                    got_message = conn.poll(2)
                    if got_message:
                        # Expects a one-element list; any other shape raises
                        # and is handled by the except below.
                        [self.network] = conn.recv()
                if not got_message:
                    continue
                self.total_recv += 1
                self.received = True
                self.dump_client()
                print("RECEIEVE:=>", self.network)
                print("RECEIEVE:=>", " total_send:", self.total_send, " total_recv:", self.total_recv)
                self.print_client_status()
            except Exception as e:
                print(e)
                print(traceback.format_exc())
                print("client recv error")
                time.sleep(retry_delay)

        # Phase 2: evaluate the network, reusing cached accuracies.
        if self.received:
            print("prepare training the network:", self.network)
            network = np.array(self.network, dtype='int').tolist()
            net = gen_code_from_list(network, node_num=node_num)
            net_str = json.dumps(network)
            if net_str in self.accuracy_trace:
                self.acc = self.accuracy_trace[net_str]
            else:
                genotype_net = translator([net, net], max_node=node_num)
                print("--" * 15)
                print(genotype_net)
                print("training the above network")
                print("--" * 15)
                self.acc = train_client.run(genotype_net, epochs=600, batch_size=200)
                self.accuracy_trace[net_str] = self.acc
                self.dump_acc_trace()

        # Phase 3: report the result, retrying until the send succeeds.
        while self.received:
            try:
                with Client(server_address, authkey=authkey) as conn:
                    conn.send([self.client_name, net_str, self.acc])
                self.total_send += 1
                print("SEND:=>", self.network, self.acc)
                self.network = []
                self.acc = 0
                self.received = False
                self.dump_client()
                print("SEND:=>", " total_send:", self.total_send, " total_recv:", self.total_recv)
            except Exception as e:
                print(e)
                print(traceback.format_exc())
                print("client send error, reconnecting")
                time.sleep(retry_delay)