# maga_transformer/distribute/worker_info.py
from __future__ import annotations
import os
import json
import socket
import torch
import logging
from typing import Any, Dict
from dataclasses import dataclass
DEFAULT_START_PORT = 8088
MASTER_INFO_PORT_NUM = 11
MIN_WORKER_INFO_PORT_NUM = 7
WORKER_INFO_PORT_NUM = MIN_WORKER_INFO_PORT_NUM
def get_worker_port_num():
    global WORKER_INFO_PORT_NUM
    WORKER_INFO_PORT_NUM = int(os.environ.get('WORKER_INFO_PORT_NUM', MIN_WORKER_INFO_PORT_NUM))
    logging.info(f'env WORKER_INFO_PORT_NUM: {WORKER_INFO_PORT_NUM}')
    if WORKER_INFO_PORT_NUM < MIN_WORKER_INFO_PORT_NUM:
        raise Exception(f"env worker info port num {WORKER_INFO_PORT_NUM} "
                        f"is smaller than min worker info port num {MIN_WORKER_INFO_PORT_NUM}")
get_worker_port_num()
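# Example (illustrative): launching with WORKER_INFO_PORT_NUM=9 in the environment
# makes each worker reserve a block of 9 consecutive ports instead of the minimum 7;
# values below 7 are rejected at import time.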
class FrontendServerInfo(object):
def __init__(self, frontend_server_id: int):
self.frontend_server_id = frontend_server_id
@staticmethod
def from_env() -> FrontendServerInfo:
return FrontendServerInfo.from_params(dict(os.environ))
@staticmethod
def from_params(params: Dict[str, str]) -> FrontendServerInfo:
info = FrontendServerInfo(
frontend_server_id=int(params.get('FRONTEND_SERVER_ID', '0')))
return info
def reload(self):
new_info = self.from_env()
self.frontend_server_id = new_info.frontend_server_id
def __str__(self):
return f"FrontendServerInfo:[ frontend_server_id={self.frontend_server_id} ]"
g_frontend_server_info = FrontendServerInfo.from_env()
class ParallelInfo(object):
    # EP is partitioned from within TP
def __init__(
self, tp_size: int, ep_size: int,
pp_size: int, dp_size: int, ffn_sp_size: int,
world_size: int, world_rank: int,
local_world_size: int
):
self.tp_size = tp_size
self.ep_size = ep_size
self.pp_size = pp_size
self.dp_size = dp_size
self.ffn_sp_size = ffn_sp_size
self.ffn_tp_size = self.tp_size // self.ffn_sp_size
self.world_size = world_size
self.world_rank = world_rank
self.local_world_size = local_world_size
logging.info(f"ParallelInfo:[ tp_size={self.tp_size} ep_size={self.ep_size} pp_size={self.pp_size} world_size={self.world_size} world_rank={self.world_rank} local_world_size={self.local_world_size} ffn_sp_size={self.ffn_sp_size} ffn_tp_size={self.ffn_tp_size}]")
assert ep_size <= world_size and world_size % ep_size == 0
assert self.world_size == self.tp_size * self.dp_size * self.pp_size
if torch.cuda.is_available():
self.device = 'cuda:' + str(self.world_rank % self.local_world_size)
else:
self.device = 'cpu'
@property
def tp_rank(self) -> int:
return self.world_rank % self.tp_size
@property
def dp_rank(self) -> int:
return self.world_rank // self.tp_size
    # ep_rank only takes effect in the MoE plugin
@property
def ep_rank(self) -> int:
return self.world_rank % self.ep_size
@property
def ffn_tp_rank(self) -> int:
return self.tp_rank % self.ffn_tp_size
@property
def local_rank(self) -> int:
return self.world_rank % self.local_world_size
@property
def is_master(self):
return self.world_rank == 0
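    # Worked example (illustrative): with TP_SIZE=4, DP_SIZE=2, WORLD_SIZE=8 and
    # local_world_size=8, world_rank=5 yields tp_rank = 5 % 4 = 1,
    # dp_rank = 5 // 4 = 1, local_rank = 5 and device = 'cuda:5'.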
@staticmethod
def from_env() -> ParallelInfo:
return ParallelInfo.from_params(dict(os.environ))
@staticmethod
def from_params(params: Dict[str, str]) -> ParallelInfo:
world_size = int(params.get('WORLD_SIZE', '1'))
if 'LOCAL_WORLD_SIZE' in params:
local_world_size = int(params['LOCAL_WORLD_SIZE'])
else:
local_world_size = min(torch.cuda.device_count(), world_size)
local_world_size = max(local_world_size, 1) # make sure local_world_size >= 1
info = ParallelInfo(
tp_size=int(params.get('TP_SIZE', '1')),
ep_size=int(params.get('EP_SIZE', params.get('WORLD_SIZE', '1'))),
pp_size=int(params.get('PP_SIZE', '1')),
            dp_size=int(params.get('DP_SIZE', '1')),
            ffn_sp_size=int(params.get('FFN_SP_SIZE', '1')),
world_size=world_size,
world_rank=int(params.get('WORLD_RANK', '0')),
local_world_size=local_world_size)
if (torch.cuda.is_available() and (info.local_world_size > torch.cuda.device_count())):
raise Exception(f'local_world_size:{info.local_world_size} > cuda device count:{torch.cuda.device_count()}')
if (info.tp_size * info.pp_size * info.dp_size != info.world_size or
info.world_rank >= info.world_size or (info.tp_size % info.ffn_sp_size != 0)):
raise Exception(f'tp_size:{info.tp_size}, ep_size:{info.ep_size}, pp_size:{info.pp_size}, world_size:{info.world_size}, world_rank:{info.world_rank} ffn_sp_size: {info.ffn_sp_size} invalid world config')
        # Assume GPUs are evenly distributed across nodes, so world_size divides
        # evenly by local_world_size.
        if info.world_size % info.local_world_size != 0:
            raise Exception(f"unsupported: info.world_size:[{info.world_size}] mod info.local_world_size:[{info.local_world_size}] != 0")
if torch.cuda.is_available():
torch.cuda.set_device(info.local_rank)
if os.environ.get("ACCL_SELECT_PATH") == "1":
select_port = str(info.local_rank % 2)
os.environ["ACCL_SELECT_PORT"] = select_port
logging.info(f"local rank {info.local_rank} set accl select port to {select_port} ")
if os.environ.get("ACCL_USE_NICS") == None and os.environ.get("ACCL_NIC_GPU_AFFINITY") != None:
content = os.environ.get("ACCL_NIC_GPU_AFFINITY")
try:
gpu_nic_affinity = json.loads(content) # 验证内容是否为合法 JSON
if str(info.local_rank) in gpu_nic_affinity:
affinity_nic = gpu_nic_affinity[str(info.local_rank)]
os.environ["ACCL_USE_NICS"] = affinity_nic
logging.info(f"local rank {info.local_rank} use cuda device {info.local_rank} set ACCL_USE_NICS to {affinity_nic}")
else:
logging.info(f"local rank {info.local_rank} use cuda device {info.local_rank} get affinity nic failed, content is {content}")
except json.JSONDecodeError:
logging.info(f"try decode ACCL_NIC_GPU_AFFINITY failed, content is {content}")
return info
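    # Illustrative ACCL_NIC_GPU_AFFINITY value (assumed format): a JSON object keyed by
    # local rank, e.g. '{"0": "eth0", "1": "eth1"}', which would pin local rank 0 to
    # NIC eth0; the NIC names here are hypothetical.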
    # used for unit tests
    def reload(self):
        new_info = self.from_env()
        self.tp_size = new_info.tp_size
        self.ep_size = new_info.ep_size
        self.pp_size = new_info.pp_size
        self.dp_size = new_info.dp_size
        self.ffn_sp_size = new_info.ffn_sp_size
        self.ffn_tp_size = new_info.ffn_tp_size
        self.world_size = new_info.world_size
        self.world_rank = new_info.world_rank
        self.local_world_size = new_info.local_world_size
        self.device = new_info.device
def __str__(self):
return f"ParallelInfo:[ tp_size={self.tp_size} pp_size={self.pp_size} world_size={self.world_size} world_rank={self.world_rank} local_world_size={self.local_world_size} tp_rank={self.tp_rank} dp_rank={self.dp_rank} ep_size={self.ep_size} dp_size={self.dp_size} ep_rank={self.ep_rank} local_rank={self.local_rank} ffn_sp_size={self.ffn_sp_size} ]"
g_parallel_info = ParallelInfo.from_env()
class WorkerInfo(object):
def __init__(self, ip: str, server_port: int, gang_hb_port: int,
http_port: int, rpc_server_port: int, remote_rpc_server_port: int,
cache_store_listen_port: int, cache_store_connect_port: int, cache_store_rdma_listen_port: int,
cache_store_rdma_connect_port: int, backend_server_port: int,
local_rank: int, world_rank: int, name: str, info: Any):
self.ip = ip
self.server_port = server_port
self.gang_hb_port = gang_hb_port
self.http_port = http_port
        self.rpc_server_port = rpc_server_port
self.remote_rpc_server_port = remote_rpc_server_port
self.cache_store_listen_port = cache_store_listen_port
self.cache_store_connect_port = cache_store_connect_port
self.cache_store_rdma_listen_port = cache_store_rdma_listen_port
self.cache_store_rdma_connect_port = cache_store_rdma_connect_port
self.backend_server_port = backend_server_port
self.local_rank: int = local_rank
self.world_rank: int = world_rank
self.name = name
self.info = info
def equals(self, other: 'WorkerInfo') -> bool:
return self.ip == other.ip and self.server_port == other.server_port
    @staticmethod
    def from_env():
        local_rank = g_parallel_info.local_rank
        # Base port of the remote peer worker; 0 when REMOTE_SERVER_PORT is unset.
        remote_base_port = int(os.environ.get("REMOTE_SERVER_PORT", 0))
        info = WorkerInfo(
            ip=socket.gethostbyname(socket.gethostname()),
            server_port=WorkerInfo.server_port_offset(local_rank),
            gang_hb_port=WorkerInfo.gang_hb_port_offset(local_rank),
            http_port=WorkerInfo.http_port_offset(local_rank),
            rpc_server_port=WorkerInfo.rpc_server_port_offset(local_rank),
            remote_rpc_server_port=WorkerInfo.rpc_server_port_offset(local_rank, remote_base_port),
            cache_store_listen_port=WorkerInfo.cache_store_listen_port_offset(local_rank),
            cache_store_connect_port=WorkerInfo.cache_store_listen_port_offset(local_rank, remote_base_port),
            cache_store_rdma_listen_port=WorkerInfo.cache_store_rdma_listen_port_offset(local_rank),
            cache_store_rdma_connect_port=WorkerInfo.cache_store_rdma_listen_port_offset(local_rank, remote_base_port),
            backend_server_port=WorkerInfo.backend_server_port_offset(local_rank),
            local_rank=local_rank,
            world_rank=g_parallel_info.world_rank,
            name='', info=None)
        return info
@staticmethod
def self_server_port():
return int(os.environ.get('START_PORT', DEFAULT_START_PORT))
    @staticmethod
    def server_port_offset(local_rank: int, server_port: int = -1) -> int:
        # server_port == -1 means "use the local START_PORT"; any other value is
        # treated as the base port of a (possibly remote) worker.
        if server_port != -1:
            base_port = server_port
        else:
            base_port = WorkerInfo.self_server_port()
        return base_port + local_rank * WORKER_INFO_PORT_NUM
@staticmethod
def rpc_server_port_offset(local_rank: int, server_port: int = -1) -> int:
return WorkerInfo.server_port_offset(local_rank, server_port) + 1
@staticmethod
def cache_store_listen_port_offset(local_rank: int, server_port: int = -1) -> int:
return WorkerInfo.server_port_offset(local_rank, server_port) + 2
@staticmethod
def gang_hb_port_offset(local_rank: int, server_port: int = -1) -> int:
return WorkerInfo.server_port_offset(local_rank, server_port) + 3
@staticmethod
def cache_store_rdma_listen_port_offset(local_rank: int, server_port: int = -1) -> int:
return WorkerInfo.server_port_offset(local_rank, server_port) + 4
@staticmethod
def http_port_offset(local_rank: int, server_port: int = -1) -> int:
return WorkerInfo.server_port_offset(local_rank, server_port) + 5
@staticmethod
def backend_server_port_offset(local_rank: int, server_port: int = -1) -> int:
return WorkerInfo.server_port_offset(local_rank, server_port) + 6
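    # Illustrative port block (assuming the default START_PORT=8088 and the minimum
    # WORKER_INFO_PORT_NUM=7) for local_rank=1: server=8095, rpc=8096,
    # cache_store=8097, gang_hb=8098, cache_store_rdma=8099, http=8100, backend=8101.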
    # used for unit tests
    def reload(self):
        new_info = self.from_env()
        self.ip = new_info.ip
        self.server_port = new_info.server_port
        self.gang_hb_port = new_info.gang_hb_port
        self.http_port = new_info.http_port
        self.rpc_server_port = new_info.rpc_server_port
        self.remote_rpc_server_port = new_info.remote_rpc_server_port
        self.cache_store_listen_port = new_info.cache_store_listen_port
        self.cache_store_connect_port = new_info.cache_store_connect_port
        self.cache_store_rdma_listen_port = new_info.cache_store_rdma_listen_port
        self.cache_store_rdma_connect_port = new_info.cache_store_rdma_connect_port
        self.backend_server_port = new_info.backend_server_port
        self.local_rank = new_info.local_rank
        self.world_rank = new_info.world_rank
        self.name = new_info.name
        self.info = new_info.info
def __str__(self):
return f"""
WorkerInfo: [ip={self.ip} server_port={self.server_port} gang_hb_port={self.gang_hb_port}
http_port={self.http_port} rpc_port={self.rpc_server_port} backend_server_port={self.backend_server_port}
cache_store_listen_port={self.cache_store_listen_port} cache_store_connect_port={self.cache_store_connect_port} remote_rpc_server_port={self.remote_rpc_server_port}
local_rank={self.local_rank} world_rank={self.world_rank} name={self.name} info={self.info} ]
"""
g_worker_info = WorkerInfo.from_env()
@dataclass
class MasterInfo:
ip: str
th_nccl_port: int
tp_nccl_port: int
nccl_op_port: int
sp_gpt_nccl_port: int
dp_tp_nccl_port: int
ffn_tp_nccl_port: int
g_master_info = MasterInfo(
ip='',
th_nccl_port=0,
    tp_nccl_port=0,
nccl_op_port=0,
sp_gpt_nccl_port=0,
dp_tp_nccl_port=0,
ffn_tp_nccl_port=0,
)
def update_master_info(ip: str, base_port: int):
g_master_info.ip = ip
g_master_info.dp_tp_nccl_port = base_port - 10
base_port -= g_parallel_info.dp_rank * MASTER_INFO_PORT_NUM
g_master_info.th_nccl_port = base_port - 1
g_master_info.tp_nccl_port = base_port - 2
g_master_info.nccl_op_port = base_port - 3
g_master_info.sp_gpt_nccl_port = base_port - 4
    # note: reserve 4 ports for ffn_tp_nccl_port
    g_master_info.ffn_tp_nccl_port = base_port - 5
    # base_port is a local variable here, so this final adjustment has no effect
    # outside this function.
    if g_parallel_info.ffn_sp_size != g_parallel_info.tp_size:
        base_port -= g_parallel_info.ffn_sp_size
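# Illustrative assignment: update_master_info('10.0.0.1', 8088) on dp_rank=1 with
# MASTER_INFO_PORT_NUM=11 sets dp_tp_nccl_port=8078, then shifts the base to 8077,
# giving th_nccl_port=8076, tp_nccl_port=8075, nccl_op_port=8074,
# sp_gpt_nccl_port=8073 and ffn_tp_nccl_port=8072. The IP here is hypothetical.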
def total_need_port_num() -> int:
return MASTER_INFO_PORT_NUM * g_parallel_info.dp_size + WORKER_INFO_PORT_NUM * g_parallel_info.tp_size
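# Example count (illustrative): with dp_size=2, tp_size=4 and the default constants
# above, total_need_port_num() = 11 * 2 + 7 * 4 = 50 ports.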