def run()

in torch_xla/distributed/xla_dist.py


  def run(self, cmd):
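    """Runs `cmd` on every cluster worker and supervises it, retrying on
    TPU maintenance or VM host errors (up to MAX_TPU_RETRY restarts)."""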
    self.trials = 0
    while self.trials <= self.MAX_TPU_RETRY:
      try:
        self.logger.info(
            'Command to distribute: {}'.format(concat_cmd_list(cmd)),
            extra={
                'clientip': '',
                'ordinal': ''
            })
        self.logger.info(
            f'Cluster configuration: {self._cluster}',
            extra={
                'clientip': '',
                'ordinal': ''
            })

        script_map = self._prepare_scripts(cmd)
        proc = multiprocessing.Process(target=self._run_cmd, args=(script_map,))
        proc.start()
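        # Supervise the child process: exit when it finishes on its own, or
        # break out to restart it when the TPU enters maintenance or a
        # worker VM reports a host error.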
        while True:
          if not proc.is_alive():
            sys.exit(proc.exitcode)
          if len(self._cluster.list_tpus_with_health(
              'UNHEALTHY_MAINTENANCE')) != 0:
            # TPU Maintenance: kill all training, wait for healthy, and restart
            break
          if not self._error_queue.empty():
            # Potential HostError on GCE VM: kill all, wait, and restart
            self.logger.warning(
                self._error_queue.get(), extra={
                    'clientip': '',
                    'ordinal': ''
                })
            break

          proc.join(10)  # Block for at most 10 seconds, then poll again.

        # First wait for the client VMs to come back, then clean up the
        # workers; retry once the TPU service is healthy again.
        self._cluster.wait_for_healthy_client(self)
        self._cleanup(script_map)
        proc.terminate()
        self._cluster.wait_for_healthy_service()
        self.trials += 1
      except KeyboardInterrupt:
        self.logger.info(
            'Cleaning up processes (takes a couple of seconds)',
            extra={
                'clientip': '',
                'ordinal': ''
            })
        self._cleanup(script_map)
        sys.exit(128 + signal.SIGINT)

    # Reached only after the retry budget (MAX_TPU_RETRY) is exhausted.
    self.logger.info(
        'Max number of retries reached.', extra={
            'clientip': '',
            'ordinal': ''
        })