optimum_benchmark/launchers/base.py (71 lines of code) (raw):

import os
import shutil
import sys
import tempfile
from abc import ABC
from contextlib import contextmanager
from logging import getLogger
from multiprocessing import Process, set_executable
from typing import Any, Callable, ClassVar, Generic, List, Optional

from ..benchmark.report import BenchmarkReport
from ..system_utils import is_nvidia_system, is_rocm_system
from .config import LauncherConfigT
from .device_isolation_utils import assert_device_isolation

# Bash wrapper template: logs the resolved paths/arguments, then runs the
# Python interpreter under numactl, forwarding every argument it received.
NUMA_EXECUTABLE_CONTENT = """#!/bin/bash
echo "Running with numactl wrapper"
echo "numactl path: {numactl_path}"
echo "numactl args: {numactl_args}"
echo "python path: {python_path}"
echo "python args: $@"
{numactl_path} {numactl_args} {python_path} "$@"
"""


class Launcher(Generic[LauncherConfigT], ABC):
    """Base class for launchers that run a benchmark worker and return its report.

    Subclasses set ``NAME`` (used as the logger name) and override ``launch``.
    """

    NAME: ClassVar[str]

    config: LauncherConfigT

    def __init__(self, config: LauncherConfigT):
        """Store the launcher configuration and create a logger named after ``NAME``."""
        self.config = config
        self.logger = getLogger(self.NAME)
        self.logger.info(f"Allocated {self.NAME} launcher")

    def launch(self, worker: Callable[..., BenchmarkReport], worker_args: List[Any]) -> BenchmarkReport:
        """Run ``worker`` with ``worker_args`` and return its report.

        Raises:
            NotImplementedError: always, in this base class; subclasses must override.
        """
        raise NotImplementedError("Launcher must implement launch method")

    @contextmanager
    def device_isolation(self, pid: int, device_ids: Optional[str] = None):
        """Run ``assert_device_isolation`` in a side process for the duration of the block.

        Args:
            pid: Process id handed to ``assert_device_isolation``.
            device_ids: Device ids handed to ``assert_device_isolation``. When
                ``None``, it is read from ``ROCR_VISIBLE_DEVICES`` on ROCm systems or
                ``CUDA_VISIBLE_DEVICES`` on NVIDIA systems, and stays ``None`` otherwise.

        The side process is terminated, joined and closed when the block exits,
        including when the block raises.
        """
        if device_ids is None:
            if is_rocm_system():
                device_ids = os.environ.get("ROCR_VISIBLE_DEVICES", None)
            elif is_nvidia_system():
                device_ids = os.environ.get("CUDA_VISIBLE_DEVICES", None)

        self.device_isolation_process = Process(
            target=assert_device_isolation,
            kwargs={"action": self.config.device_isolation_action, "device_ids": device_ids, "pid": pid},
            daemon=True,
        )
        self.device_isolation_process.start()
        self.logger.info(f"\t+ Isolating device(s) [{device_ids}] for process [{pid}] and its children")
        self.logger.info(f"\t+ Executing action [{self.config.device_isolation_action}] in case of violation")

        try:
            yield
        finally:
            # Always tear the watcher down, even if the managed block raised.
            self.logger.info("\t+ Stopping device isolation process")
            self.device_isolation_process.terminate()
            self.device_isolation_process.join()
            self.device_isolation_process.close()

    @contextmanager
    def numactl_executable(self):
        """Make multiprocessing start its children through a numactl wrapper script.

        A temporary bash script is generated from ``NUMA_EXECUTABLE_CONTENT`` using
        ``config.numactl_kwargs`` (each entry rendered as ``--key=value``) and
        installed with ``multiprocessing.set_executable``. On exit, including on
        error, the executable is reset to ``sys.executable`` and the script is deleted.

        Raises:
            RuntimeError: if no ``numactl`` binary is found on ``PATH``.
        """
        # Start and join a throwaway process before swapping the executable.
        # NOTE(review): presumably so multiprocessing's own helper processes are
        # created with the plain interpreter rather than the wrapper - confirm.
        self.logger.info("\t+ Warming up multiprocessing context")
        dummy_process = Process(target=dummy_target, daemon=False)
        dummy_process.start()
        dummy_process.join()
        dummy_process.close()

        self.logger.info("\t+ Creating numactl wrapper executable for multiprocessing")
        python_path = sys.executable
        numactl_path = shutil.which("numactl")
        if numactl_path is None:
            raise RuntimeError("Could not find numactl executable. Please install numactl and try again.")

        numactl_args = " ".join(f"--{key}={value}" for key, value in self.config.numactl_kwargs.items())
        numa_executable_content = NUMA_EXECUTABLE_CONTENT.format(
            numactl_path=numactl_path, numactl_args=numactl_args, python_path=python_path
        )

        # delete=False: the script must outlive this handle so child processes can run it.
        with tempfile.NamedTemporaryFile(delete=False, prefix="numa_executable_", suffix=".sh") as numa_executable:
            numa_executable.write(numa_executable_content.encode())
        numa_executable_path = numa_executable.name

        try:
            # Owner-only permissions: the script gets executed, so it must not be
            # writable by other users of a shared temporary directory.
            os.chmod(numa_executable_path, 0o700)
            self.logger.info("\t+ Setting multiprocessing executable to numactl wrapper")
            set_executable(numa_executable_path)
            yield
        finally:
            self.logger.info("\t+ Resetting default multiprocessing executable")
            # Restore the interpreter first so that a failing unlink cannot leave
            # multiprocessing pointing at the wrapper.
            set_executable(sys.executable)
            os.unlink(numa_executable_path)


def dummy_target() -> None:
    """No-op process target that exits with status 0; used for the warm-up process."""
    sys.exit(0)