horovod/runner/elastic/registration.py (111 lines of code) (raw):
# Copyright 2020 Uber Technologies, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import logging
import threading
from collections import defaultdict
from horovod.runner.elastic import constants
READY = 'READY'
SUCCESS = 'SUCCESS'
FAILURE = 'FAILURE'


class WorkerStateRegistry(object):
    """Record the state reported by each worker slot and synchronize the workers on a barrier.

    Each call to :meth:`reset` starts a new rendezvous round sized for ``size`` workers. Every
    ``record_*`` call registers ``(host, slot)`` with a state and then blocks on a
    ``threading.Barrier`` until ``size`` registrations have arrived. The barrier action
    (:meth:`_on_workers_recorded`) then does one of the following:

    - stops the driver on any success, on total failure, when every recorded host is
      blacklisted, or when the reset limit is reached;
    - otherwise blacklists the failed hosts and resumes the driver.
    """

    def __init__(self, driver, host_manager, reset_limit=None, verbose=False):
        """
        :param driver: object providing ``finished()``, ``stop(...)`` and ``resume()``.
        :param host_manager: object providing ``is_blacklisted(host)`` and ``blacklist(host)``.
        :param reset_limit: maximum number of times the driver may be resumed; ``None`` means no limit.
        :param verbose: stored on the instance; not read elsewhere in this class.
        """
        self._driver = driver
        self._host_manager = host_manager
        self._reset_limit = reset_limit
        self._reset_count = 0
        self._lock = threading.Lock()
        self._states = {}                  # (host, slot) -> most recently recorded state
        self._workers = defaultdict(set)   # state -> set of (host, slot) currently in that state
        self._barrier = None               # created by reset()
        self._rendezvous_id = 0            # incremented on every reset()
        self._verbose = verbose
        self._size = 0

    def get_recorded_slots(self):
        """Return the ``(host, slot)`` keys recorded this round (a live view of the internal dict)."""
        return self._states.keys()

    def get(self, state):
        """Return the (live) set of ``(host, slot)`` keys currently recorded with ``state``."""
        return self._workers[state]

    def count(self, state):
        """Return how many ``(host, slot)`` keys are currently recorded with ``state``."""
        return len(self._workers[state])

    def reset(self, size):
        """Start a new rendezvous round that expects ``size`` registrations.

        Clears all recorded states, installs a fresh barrier with ``size`` parties and
        increments the rendezvous id.
        """
        with self._lock:
            logging.info('reset workers: {}'.format(size))
            self._states.clear()
            self._workers.clear()
            self._barrier = threading.Barrier(parties=size, action=self._action)
            self._rendezvous_id += 1
            self._size = size

    def size(self):
        """Return the number of registrations expected in the current round."""
        return self._size

    def last_rendezvous(self):
        """Return the id of the current rendezvous round."""
        return self._rendezvous_id

    def record_ready(self, host, slot):
        """Record ``READY`` for ``(host, slot)``, block until the round completes, and return the rendezvous id."""
        return self._record_state(host, slot, READY)

    def record_success(self, host, slot):
        """Record ``SUCCESS`` for ``(host, slot)``, block until the round completes, and return the rendezvous id."""
        return self._record_state(host, slot, SUCCESS)

    def record_failure(self, host, slot):
        """Record ``FAILURE`` for ``(host, slot)``, block until the round completes, and return the rendezvous id."""
        return self._record_state(host, slot, FAILURE)

    def _record_state(self, host, slot, state):
        """Register ``state`` for ``(host, slot)`` and wait on the barrier.

        Returns immediately, without waiting, when the driver has finished or the host is
        blacklisted. Otherwise returns the rendezvous id once the barrier completes.
        """
        if self._driver.finished():
            logging.info('driver finished, ignoring registration: {}[{}] = {}'.format(host, slot, state))
            return self._rendezvous_id

        if self._host_manager.is_blacklisted(host):
            logging.warning('host registers state %s but is already blacklisted, ignoring: %s', state, host)
            return self._rendezvous_id

        key = (host, slot)
        with self._lock:
            previous_state = self._states.get(key)
            if previous_state is not None:
                if state == FAILURE:
                    # The worker originally recorded itself as READY, but it failed while waiting at the
                    # barrier. We need to update the state to FAILURE, and we don't want two threads from
                    # the same worker waiting at the barrier.
                    #
                    # So that the failing thread can still record results when the whole job fails, it
                    # must also block on the barrier. That requires resetting the barrier; otherwise this
                    # worker would be counted twice (once for the READY thread, once for FAILURE) and the
                    # barrier would complete too early.
                    logging.info('key exists, reset barrier: {}[{}] = {} -> {}'
                                 .format(host, slot, previous_state, state))
                    self._barrier.reset()
                else:
                    # NOTE(review): this duplicate registration is not recorded, but the calling thread
                    # still waits on the barrier below. The worker may therefore be counted twice and
                    # the barrier may trip early. Left as-is pending confirmation of caller semantics.
                    logging.error('key exists and new state %s not FAILURE, '
                                  'ignoring (current state is %s)', state, previous_state)

            if previous_state is None or state == FAILURE:
                logging.info('record state: {}[{}] = {}'.format(host, slot, state))
                if previous_state is not None:
                    # Remove the superseded state so each key belongs to exactly one worker set;
                    # otherwise get()/count() would report the worker in both states.
                    self._workers[previous_state].discard(key)
                self._states[key] = state
                self._workers[state].add(key)

            rendezvous_id = self._rendezvous_id

        return self._wait(key, state, rendezvous_id)

    def _wait(self, key, state, rendezvous_id):
        """Block on the current barrier until it completes, then return the rendezvous id.

        If the barrier is reset (rather than broken), wait again with the latest rendezvous id,
        unless this key's recorded state has changed in the meantime.

        :raises threading.BrokenBarrierError: if the barrier is left in the broken state.
        :raises RuntimeError: if another registration overrode ``state`` for ``key``.
        """
        while True:
            try:
                self._barrier.wait()
                return rendezvous_id
            except threading.BrokenBarrierError:
                if self._barrier.broken:
                    # Timeout or other non-recoverable error, so exit
                    raise

                # The barrier was reset.
                with self._lock:
                    # Make sure the reset was not caused by a state change for this key. A missing
                    # key (states cleared by reset()) counts as unchanged.
                    rendezvous_id = self._rendezvous_id
                    saved_state = self._states.get(key, state)
                    if saved_state != state:
                        # This worker changed its state; waiting again would double-count it.
                        raise RuntimeError('State {} overridden by {}'.format(state, saved_state))

    def _action(self):
        """Barrier action: called once by one thread when all parties have arrived."""
        self._on_workers_recorded()

    def _on_workers_recorded(self):
        """Decide whether to stop or resume the driver after all workers have registered."""
        logging.info('all {} workers recorded'.format(self.size()))

        # Check for success state, if any process succeeded, shutdown all other processes
        if self.count(SUCCESS) > 0:
            logging.info('success count == {} -> stop running'.format(self.count(SUCCESS)))
            self._driver.stop()
            return

        # Check that all processes failed, indicating that processing should stop
        if self.count(FAILURE) == self._size:
            logging.error('failure count == {} -> stop running'.format(self._size))
            self._driver.stop()
            return

        # NOTE: in CPython the barrier action runs while the barrier's internal lock is held, and
        # _record_state() calls barrier.reset() while holding self._lock, so taking self._lock here
        # could deadlock. Instead, snapshot the containers before calling out to the host manager,
        # so that concurrent registrations cannot change them mid-iteration.

        # Check for failures, and add them to the blacklisted hosts list
        failures = list(self.get(FAILURE))
        for host, _ in failures:
            self._host_manager.blacklist(host)

        # If every active host is blacklisted, then treat this as job failure
        recorded_slots = list(self.get_recorded_slots())
        if all(self._host_manager.is_blacklisted(host) for host, _ in recorded_slots):
            logging.error('blacklisted slots count == {} -> stop running'.format(self._size))
            self._driver.stop()
            return

        # Check that we have already reset the maximum number of allowed times
        if self._reset_limit is not None and self._reset_count >= self._reset_limit:
            logging.error('reset count {} has exceeded limit {} -> stop running'
                          .format(self._reset_count, self._reset_limit))
            self._driver.stop(error_message=constants.RESET_LIMIT_EXCEEDED_MESSAGE.format(self._reset_limit))
            return

        try:
            self._reset_count += 1
            self._driver.resume()
        except Exception:
            logging.exception('failed to activate new hosts -> stop running')
            self._driver.stop()