def _get_spark_task_program()

in spark/spark-tensorflow-distributor/spark_tensorflow_distributor/mirrored_strategy_runner.py [0:0]


    def _get_spark_task_program(self, train_fn, **kwargs):
        num_slots = self._num_slots
        use_custom_strategy = self._use_custom_strategy
        gpu_resource_name = self._gpu_resource_name
        num_tasks = self._num_tasks
        use_gpu = self._use_gpu
        run_tensorflow_program = MirroredStrategyRunner._run_tensorflow_program

        # Spark task program
        def wrapped_train_fn(_):
            import json
            import logging
            import os
            import random  # used by random.sample in set_gpus below
            import socket
            from contextlib import closing
            from pyspark import BarrierTaskContext

            # Sets the TF_CONFIG env var so TF servers
            # can communicate with each other
            def set_tf_config(context):
                addrs = [
                    e.address.split(':')[0] for e in context.getTaskInfos()
                ]
                my_addr = addrs[context.partitionId()]
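                # Bind to port 0 so the OS assigns a free ephemeral port,
                # then release it so the TF server can bind to it later.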
                with closing(socket.socket(socket.AF_INET,
                                           socket.SOCK_STREAM)) as my_sock:
                    my_sock.bind(('', 0))
                    _, my_port = my_sock.getsockname()
                    my_endpoint = "{}:{}".format(my_addr, my_port)
                    worker_endpoints = context.allGather(my_endpoint)
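                # allGather returns every task's endpoint ordered by partition
                # id, so each task builds an identical cluster spec and its
                # partition id doubles as its worker index below.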
                cluster = {'worker': worker_endpoints}
                tf_config = {
                    'cluster': cluster,
                    'task': {
                        'type': 'worker',
                        'index': context.partitionId()
                    }
                }
                os.environ['TF_CONFIG'] = json.dumps(tf_config)
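                # Example resulting value (hypothetical two-worker cluster):
                # {"cluster": {"worker": ["10.0.0.1:46723", "10.0.0.2:38041"]},
                #  "task": {"type": "worker", "index": 0}}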

            # Sets the CUDA_VISIBLE_DEVICES env var so only
            # the appropriate GPUs are used
            def set_gpus(context):
                gpus_owned = MirroredStrategyRunner._get_gpus_owned(
                    context.resources(), gpu_resource_name)

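                # Split num_slots GPUs as evenly as possible across tasks:
                # every task gets num_slots // num_tasks, and the first
                # num_slots % num_tasks partitions take one extra.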
                my_num_gpus = (num_slots //
                               num_tasks) + (context.partitionId() <
                                             (num_slots % num_tasks))
                gpu_addresses = [
                    str(e) for e in random.sample(gpus_owned, my_num_gpus)
                ]
                logging.info(f'Using GPU addresses: {gpu_addresses}')
                os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(gpu_addresses)

            context = BarrierTaskContext.get()
            if use_gpu:
                set_gpus(context)
            else:
                os.environ['CUDA_VISIBLE_DEVICES'] = ''
            set_tf_config(context)
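            # Run the user's training function via the runner's helper; with
            # use_custom_strategy, train_fn is expected to set up its own
            # tf.distribute strategy.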
            result = run_tensorflow_program(train_fn, use_custom_strategy,
                                            **kwargs)
            if context.partitionId() == 0:
                return [result]
            return [None]

        return wrapped_train_fn
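
The returned function is intended to be executed as a Spark barrier-mode job, since BarrierTaskContext.get() and allGather() above only work inside barrier tasks. Below is a minimal driver-side sketch of how it might be launched; `sc` (the active SparkContext), `runner` (a MirroredStrategyRunner instance), and `train_fn` are assumed here for illustration and are not part of this function.

    # Hypothetical driver-side usage (sketch, not the library's exact code)
    spark_task_program = runner._get_spark_task_program(train_fn)
    results = (
        sc.parallelize(range(runner._num_tasks), runner._num_tasks)
        .barrier()
        .mapPartitions(spark_task_program)
        .collect()
    )
    # Each partition contributes one element; only partition 0's is non-None.
    trained_result = results[0]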