in spark/spark-tensorflow-distributor/spark_tensorflow_distributor/mirrored_strategy_runner.py [0:0]
def run(self, train_fn, **kwargs):
"""
Args:
train_fn: Function that contains TensorFlow training code.
If it constructs its own tensorflow.distribute.Strategy
object, then construct MirroredStrategyRunner with
use_custom_strategy set to `True`.
kwargs: keyword arguments passed to the training function
at invocation time; train_fn is invoked as train_fn(**kwargs).
Returns:
Return value of the training function
from the chief training worker (partition ID 0) in
distributed mode, or the direct return value of train_fn in
local mode.
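Example:
    A minimal sketch of local, CPU-only use. The constructor
    arguments and the empty training body below are illustrative
    assumptions, not requirements of this API:

        def train():
            # Build and fit a TensorFlow model here. Per the
            # use_custom_strategy note above, the runner supplies
            # the distribution strategy unless that flag is True.
            ...

        runner = MirroredStrategyRunner(
            num_slots=1, local_mode=True, use_gpu=False)
        result = runner.run(train)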
"""
spark_task_program = self._get_spark_task_program(train_fn, **kwargs)
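# Build the function that each Spark barrier task will execute if we
# run in distributed mode below.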
# Run in local mode
if self._local_mode:
# Remember the caller's CUDA_VISIBLE_DEVICES setting so it can be
# restored once local training finishes.
old_cuda_visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES', '')
cuda_state_was_set = 'CUDA_VISIBLE_DEVICES' in os.environ
try:
if self._use_gpu:
# TODO: handle the case where the driver's GPU resources
# are not properly configured
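# Look up the GPU addresses Spark assigned to the driver for the
# configured resource name.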
gpus_owned = MirroredStrategyRunner._get_gpus_owned(
self.sc.resources, self._gpu_resource_name)
num_gpus_owned = len(gpus_owned)
if self._num_slots > num_gpus_owned:
    raise ValueError(
        f'{self._num_slots} slots were requested for local '
        f'GPU training, but only {num_gpus_owned} GPUs '
        'were available.')
# TODO: Check GPU utilization to avoid resource contention
# Expose only a random sample of the driver's GPUs to
# TensorFlow via CUDA_VISIBLE_DEVICES.
os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
    str(e) for e in random.sample(gpus_owned, self._num_slots))
else:
if self._num_slots > 1:
    raise ValueError(
        'Cannot run with more than one CPU slot in local mode. '
        'Try setting num_slots to -1.')
os.environ['CUDA_VISIBLE_DEVICES'] = ''
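# Invoke train_fn in this (driver) process. Per the docstring, the
# runner provides the tf.distribute strategy unless it was created
# with use_custom_strategy=True, in which case train_fn is expected
# to build its own.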
result = MirroredStrategyRunner._run_tensorflow_program(
train_fn, self._use_custom_strategy, **kwargs)
finally:
if cuda_state_was_set:
os.environ['CUDA_VISIBLE_DEVICES'] = old_cuda_visible_devices
else:
del os.environ['CUDA_VISIBLE_DEVICES']
return result
# Run in distributed mode
self._check_encryption()
self._logger.info('Distributed training in progress...')
self._logger.info(
'View Spark executor stderr logs to inspect training...')
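# Launch one barrier-mode Spark task per partition; each task runs
# spark_task_program, and only the chief's result (partition ID 0)
# is returned.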
result = self.sc.parallelize(range(self._num_tasks), self._num_tasks) \
.barrier() \
.mapPartitions(spark_task_program) \
.collect()[0]
self._logger.info(f'Training with {self._num_slots} slots is complete!')
return result