contrib/python/ofi_nccl/tuner/cli/wrapper.py (193 lines of code) (raw):

#!/usr/bin/env python3
"""ctypes wrapper around the OFI NCCL tuner plugin (``ncclTunerPlugin_v3``).

The module loads the tuner shared object, initialises it for a given job
shape (ranks / nodes / platform), and exposes helpers that query which
(algorithm, protocol) pair the tuner selects for a collective and message
size, including a bisection search for the message sizes where that
selection changes.
"""
import ctypes
import functools
import logging
import os
import pathlib
from enum import Enum

import numpy as np
import pandas as pd


class NCCLDebugLogLevel(int, Enum):
    """Log levels passed by the plugin to the debug-logger callback."""

    NONE = 0
    ERROR = 1
    WARN = 2
    INFO = 3
    DEBUG = 4
    TRACE = 5


class NCCLFunc(int, Enum):
    """Collective identifiers passed to ``getCollInfo``."""

    Broadcast = 0
    Reduce = 1
    AllGather = 2
    ReduceScatter = 3
    AllReduce = 4
    SendRecv = 5
    Send = 6
    # FIXME: tuner does not handle this right.
    # Recv = 7


class NCCLAlgo(int, Enum):
    """Algorithm indices; the value is the row index into the cost table."""

    TREE = 0
    RING = 1
    COLLNET_DIRECT = 2
    COLLNET_CHAIN = 3
    NVLS = 4
    NVLS_TREE = 5
    PAT = 6


class NCCLProto(int, Enum):
    """Protocol indices; the value is the column index into the cost table."""

    LL = 0
    LL128 = 1
    SIMPLE = 2


class TunerPlatform(str, Enum):
    """Instance types that can be forced via ``OFI_NCCL_FORCE_PRODUCT_NAME``."""

    P5en = "p5en.48xlarge"
    P5 = "p5.48xlarge"


class Tuner:
    """Handle to an initialised tuner plugin instance.

    Attributes set by ``__init__``: ``nranks``, ``nnodes``, ``platform``,
    ``logger``, ``ncclDebugLogger_t``, ``ncclTuner_v3_t``, ``context``,
    ``lib``, ``libc`` and ``tuner``.
    """

    # Value every cost-table slot is pre-filled with before calling the
    # plugin. Any slot holding a different value afterwards was written by
    # the tuner. 1337 is exactly representable as a C float.
    _UNSET_COST = float(1337)

    def _debug_logger_callback(self, level, flags, file, line, fmt, args):
        """Forward a plugin log call to ``self.logger``.

        ``fmt`` is logged verbatim (it is not expanded against ``args``).
        ctypes delivers a NULL ``char *`` as ``None``, so both ``file`` and
        ``fmt`` are tolerated as ``None`` instead of raising inside the
        callback.
        """
        level_map = {
            NCCLDebugLogLevel.NONE: logging.NOTSET,
            NCCLDebugLogLevel.ERROR: logging.ERROR,
            NCCLDebugLogLevel.WARN: logging.WARNING,
            NCCLDebugLogLevel.INFO: logging.INFO,
            NCCLDebugLogLevel.DEBUG: logging.DEBUG,
            NCCLDebugLogLevel.TRACE: logging.DEBUG,  # Python doesn't have a TRACE level
        }
        py_level = level_map.get(level, logging.DEBUG)
        file_text = (file or b"").decode("utf-8", errors="replace")
        fmt_text = (fmt or b"").decode("utf-8", errors="replace")
        message = f"NCCL [{file_text}:{line}]: {fmt_text} {args}"
        self.logger.log(py_level, message)

    def __init__(
        self,
        tuner_dso: pathlib.Path,
        nranks: int,
        nnodes: int,
        platform: TunerPlatform,
        log_level=logging.DEBUG,
    ):
        """Load ``tuner_dso`` and initialise the plugin.

        Args:
            tuner_dso: Path of the shared object exporting
                ``ncclTunerPlugin_v3``.
            nranks: Rank count handed to the plugin's ``init``.
            nnodes: Node count handed to the plugin's ``init``.
            platform: Platform name exported through
                ``OFI_NCCL_FORCE_PRODUCT_NAME`` before the library is loaded.
            log_level: Level applied to the ``NCCLTuner`` logger.

        Raises:
            RuntimeError: If the plugin's ``init`` returns non-zero.
            OSError: If a shared library cannot be loaded.
        """
        self.nranks = nranks
        self.nnodes = nnodes
        self.platform = platform
        # Per-instance memo for get_coll_info; see that method.
        self._coll_info_cache = {}

        os.environ["OFI_NCCL_FORCE_PRODUCT_NAME"] = str(platform.value)

        self.logger = logging.getLogger("NCCLTuner")
        self.logger.setLevel(log_level)

        self.ncclDebugLogger_t = ctypes.CFUNCTYPE(
            None,
            ctypes.c_int,
            ctypes.c_ulong,
            ctypes.c_char_p,
            ctypes.c_int,
            ctypes.c_char_p,
            ctypes.c_void_p,
        )

        num_entries = len(NCCLAlgo) * len(NCCLProto)
        init_t = ctypes.CFUNCTYPE(
            ctypes.c_int,
            ctypes.c_size_t,
            ctypes.c_size_t,
            ctypes.c_void_p,
            ctypes.POINTER(ctypes.c_void_p),
        )
        get_coll_info_t = ctypes.CFUNCTYPE(
            ctypes.c_int,
            ctypes.c_void_p,
            ctypes.c_int,
            ctypes.c_size_t,
            ctypes.c_int,
            ctypes.POINTER(ctypes.c_float * num_entries),
            ctypes.c_int,
            ctypes.c_int,
            ctypes.POINTER(ctypes.c_int),
        )
        destroy_t = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_void_p)
        self.ncclTuner_v3_t = type(
            "ncclTuner_v3_t",
            (ctypes.Structure,),
            {
                "_fields_": [
                    ("name", ctypes.c_char_p),
                    ("init", init_t),
                    ("getCollInfo", get_coll_info_t),
                    ("destroy", destroy_t),
                ]
            },
        )

        callback_func = self.ncclDebugLogger_t(self._debug_logger_callback)
        # Keep a reference for the lifetime of the instance: ctypes callback
        # objects must outlive every use from C code.
        self._callback_func = callback_func

        if not self.logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter(
                "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
            )
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)

        self.context = ctypes.c_void_p()
        self.lib = ctypes.CDLL(os.fspath(tuner_dso))
        self.libc = ctypes.CDLL("libc.so.6")
        self.tuner = self.ncclTuner_v3_t.in_dll(self.lib, "ncclTunerPlugin_v3")

        result = self.tuner.init(
            nranks, nnodes, callback_func, ctypes.byref(self.context)
        )
        if result != 0:
            raise RuntimeError(f"Failed to initialize NCCL tuner. Error code: {result}")

    def get_coll_info(self, coll_type: NCCLFunc, msg_size: int, num_pipe_ops: int = 1):
        """Return the tuner's selection for one collective and message size.

        The cost table is pre-filled with a sentinel, handed to the plugin's
        ``getCollInfo``, and then scanned: the last (algo, proto) slot whose
        value differs from the sentinel is reported.

        Results are memoised per instance, keyed on the three arguments. A
        per-instance dict is used instead of ``functools.lru_cache`` because
        a cache on the method would hold a strong reference to ``self`` and
        keep the plugin from ever being destroyed.

        Args:
            coll_type: Collective to query.
            msg_size: Message size in bytes.
            num_pipe_ops: Value passed as the plugin's ``numPipeOps``.

        Returns:
            A new dict ``{"algo": ..., "proto": ...}`` holding ``NCCLAlgo`` /
            ``NCCLProto`` members, or ``numpy.nan`` for both when the plugin
            left every slot untouched.

        Raises:
            RuntimeError: If the tuner is not initialised or ``getCollInfo``
                returns non-zero.
        """
        if getattr(self, "context", None) is None:
            raise RuntimeError("NCCL tuner not initialized. Call initialize() first.")

        # Created lazily so instances built without __init__ still work.
        cache = self.__dict__.setdefault("_coll_info_cache", {})
        key = (coll_type, msg_size, num_pipe_ops)
        cached = cache.get(key)
        if cached is not None:
            # Hand out a copy so callers cannot corrupt the memo.
            return dict(cached)

        nBytes = ctypes.c_size_t(msg_size)
        numPipeOps = ctypes.c_int(num_pipe_ops)

        num_protos = len(NCCLProto)
        num_entries = len(NCCLAlgo) * num_protos
        cost_table_array = (ctypes.c_float * num_entries)(
            *([self._UNSET_COST] * num_entries)
        )
        nChannels = ctypes.c_int(0)

        result = self.tuner.getCollInfo(
            self.context,
            coll_type.value,
            nBytes,
            numPipeOps,
            ctypes.byref(cost_table_array),
            len(NCCLAlgo),
            num_protos,
            ctypes.byref(nChannels),
        )
        if result != 0:
            raise RuntimeError(f"Failed to get collective info. Error code: {result}")

        decision = (np.nan, np.nan)
        for algo in NCCLAlgo:
            for proto in NCCLProto:
                cost = cost_table_array[algo.value * num_protos + proto.value]
                if cost != self._UNSET_COST:
                    # No early exit: the last modified slot wins.
                    decision = (algo, proto)

        info = {
            "algo": decision[0],
            "proto": decision[1],
        }
        cache[key] = info
        return dict(info)

    def analyze_message_range(self, coll_type: NCCLFunc, min_size: int, max_size: int):
        """Locate message sizes in a range where the tuner's selection changes.

        The range is bisected recursively. A sub-range whose two endpoints
        yield the same selection is not explored further, so a selection that
        appears only strictly inside such a sub-range is not detected.

        Args:
            coll_type: Collective to analyse.
            min_size: Smallest message size (bytes) to query.
            max_size: Largest message size (bytes) to query.

        Returns:
            A ``pandas.DataFrame`` with one row per reported edge and the
            columns ``collective``, ``message_size``, ``algo``, ``proto``,
            ``ranks``, ``nodes`` and ``platform``.
        """

        def bisect(lo, hi):
            lo_decision = self.get_coll_info(coll_type, lo)
            hi_decision = self.get_coll_info(coll_type, hi)

            if lo_decision == hi_decision or hi - lo <= 1:
                return [(lo, lo_decision), (hi, hi_decision)]

            mid = (lo + hi) // 2
            left_results = bisect(lo, mid)
            right_results = bisect(mid, hi)

            # Both halves contain ``mid``; drop the duplicate, then keep only
            # the entries whose selection differs from their predecessor.
            combined = left_results + right_results[1:]
            return [combined[0]] + [
                current
                for previous, current in zip(combined, combined[1:])
                if current[1] != previous[1]
            ]

        edges = bisect(min_size, max_size)
        return pd.DataFrame(
            [
                {
                    "collective": coll_type.name,
                    "message_size": size,
                    "algo": decision["algo"],
                    "proto": decision["proto"],
                    "ranks": self.nranks,
                    "nodes": self.nnodes,
                    "platform": self.platform,
                }
                for size, decision in edges
            ]
        )

    def analyze_all(self, min_size: int = 32, max_size: int = 32 * 1024 * 1024 * 1024):
        """Run ``analyze_message_range`` for every ``NCCLFunc`` member.

        Returns:
            The per-collective frames concatenated with ``pandas.concat``
            (each frame's own index is kept).
        """
        return pd.concat(
            [self.analyze_message_range(c, min_size, max_size) for c in NCCLFunc]
        )

    def __del__(self):
        """Destroy the plugin context, if one was ever created.

        Safe to call on a partially constructed instance (for example when
        ``__init__`` failed while loading the library) and safe to call more
        than once: the attributes are removed before ``destroy`` is invoked.
        """
        tuner = self.__dict__.pop("tuner", None)
        context = self.__dict__.pop("context", None)
        if tuner is not None and context is not None:
            tuner.destroy(context)