in torch_xla/distributed/xla_dist.py [0:0]
def run(self, cmd):
  """Distribute `cmd` across the cluster, restarting on TPU/VM disruptions.

  Spawns the distributed command in a child process and polls it. If the
  child exits on its own, this process exits with the same exit code. If
  any TPU reports 'UNHEALTHY_MAINTENANCE', or a worker pushes an error onto
  self._error_queue (potential HostError on a GCE VM), all training is
  killed, the cluster is waited on until healthy again, and the command is
  restarted — up to self.MAX_TPU_RETRY attempts.

  Args:
    cmd: the command to distribute, as a list of arguments (consumed by
      concat_cmd_list and self._prepare_scripts).

  Raises:
    SystemExit: with the child's exit code on normal termination, or with
      128 + SIGINT if interrupted by Ctrl-C.
  """
  # Initialize before the retry loop so the KeyboardInterrupt handler can
  # always reference it: previously an early Ctrl-C (before _prepare_scripts
  # ran) raised UnboundLocalError at the cleanup call instead of exiting
  # cleanly with 128 + SIGINT.
  script_map = None
  self.trials = 0
  while self.trials <= self.MAX_TPU_RETRY:
    try:
      self.logger.info(
          f'Command to distribute: {concat_cmd_list(cmd)}',
          extra={
              'clientip': '',
              'ordinal': ''
          })
      self.logger.info(
          f'Cluster configuration: {self._cluster}',
          extra={
              'clientip': '',
              'ordinal': ''
          })
      script_map = self._prepare_scripts(cmd)
      proc = multiprocessing.Process(target=self._run_cmd, args=(script_map,))
      proc.start()
      while True:
        if not proc.is_alive():
          # Training finished (or died) on its own; propagate its exit code.
          sys.exit(proc.exitcode)
        if len(self._cluster.list_tpus_with_health(
            'UNHEALTHY_MAINTENANCE')) != 0:
          # TPU Maintenance: kill all training, wait for healthy, and restart
          break
        if not self._error_queue.empty():
          # Potential HostError on GCE VM: kill all, wait, and restart
          self.logger.warning(
              self._error_queue.get(), extra={
                  'clientip': '',
                  'ordinal': ''
              })
          break
        # Poll roughly every 10 seconds.
        proc.join(10)
      # First wait for VMs to come back then cleanup all others
      self._cluster.wait_for_healthy_client(self)
      self._cleanup(script_map)
      proc.terminate()
      self._cluster.wait_for_healthy_service()
      self.trials += 1
    except KeyboardInterrupt:
      self.logger.info(
          'Cleaning up processes (takes a couple of seconds)',
          extra={
              'clientip': '',
              'ordinal': ''
          })
      # Nothing to clean up if we were interrupted before scripts were
      # prepared.
      if script_map is not None:
        self._cleanup(script_map)
      sys.exit(128 + signal.SIGINT)
  self.logger.info(
      'Max number of retries reached.', extra={
          'clientip': '',
          'ordinal': ''
      })