def _get_gpus_owned()

in spark/spark-tensorflow-distributor/spark_tensorflow_distributor/mirrored_strategy_runner.py


    def _get_gpus_owned(resources, gpu_resource_name):
        """
        Gets the GPU addresses that Spark scheduled to the calling task.

        Args:
            resources: Dict mapping resource names to ResourceInformation.
            gpu_resource_name: Spark resource name for GPUs, e.g. 'gpu'.

        Returns:
            A list of GPU device IDs (as strings) that Spark scheduled to
            the calling task, in the format expected by CUDA_VISIBLE_DEVICES.
        """
        if gpu_resource_name in resources:
            addresses = resources[gpu_resource_name].addresses
            # Each address must be a plain, non-zero-padded integer, as
            # required by CUDA_VISIBLE_DEVICES. Group the alternation so the
            # anchors apply to both branches.
            pattern = re.compile(r'^([1-9][0-9]*|0)$')
            if any(not pattern.match(address) for address in addresses):
                raise ValueError(f'Found GPU addresses {addresses} which '
                                 'are not all in the correct format '
                                 'for CUDA_VISIBLE_DEVICES, which requires '
                                 'integers with no zero padding.')
            if 'CUDA_VISIBLE_DEVICES' in os.environ:
                # If CUDA_VISIBLE_DEVICES is already set, the Spark-provided
                # addresses are treated as indices into that list, so map them
                # back to the underlying device IDs. For example,
                # CUDA_VISIBLE_DEVICES='3,4,5' and addresses ['0', '2']
                # yield ['3', '5'].
                gpu_indices = list(map(int, addresses))
                gpu_list = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
                gpu_owned = [gpu_list[i] for i in gpu_indices]
                return gpu_owned
            return addresses
        raise ValueError(
            f'The provided GPU resource name `{gpu_resource_name}` was not '
            'found in the context resources. Contact your cluster '
            'administrator to make sure that the '
            f'spark.task.resource.{gpu_resource_name}, '
            f'spark.worker.resource.{gpu_resource_name}, '
            f'spark.executor.resource.{gpu_resource_name}, and '
            f'spark.driver.resource.{gpu_resource_name} confs are set and '
            f'that the GPU resource name `{gpu_resource_name}` matches '
            'those confs correctly.')
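
For context, a minimal sketch of how this helper might be invoked from inside a Spark barrier task. The resource name 'gpu', the wrapper name _set_gpu_visibility, and calling the helper as an attribute of MirroredStrategyRunner are illustrative assumptions, not part of the original file; BarrierTaskContext and ResourceInformation are standard PySpark (3.0+) APIs.

    import os

    from pyspark import BarrierTaskContext

    def _set_gpu_visibility(gpu_resource_name='gpu'):
        # BarrierTaskContext.get() is only valid inside a barrier-mode task.
        context = BarrierTaskContext.get()
        # context.resources() maps resource names to ResourceInformation
        # objects whose `addresses` attribute is read by _get_gpus_owned.
        gpus_owned = MirroredStrategyRunner._get_gpus_owned(
            context.resources(), gpu_resource_name)
        # Restrict CUDA (and hence TensorFlow) to the devices Spark assigned.
        os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(gpus_owned)
        return gpus_owned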