maga_transformer/distribute/gang_info.py (137 lines of code) (raw):
import json
import logging
import os
import socket
from typing import NamedTuple, List, Any, Dict, Optional
from maga_transformer.distribute.worker_info import g_worker_info, g_parallel_info, WorkerInfo
# Name of the environment variable that holds the path to a JSON gang
# configuration file (read by get_members_from_file).
CONFIG_FILE_ENV = 'DISTRIBUTE_CONFIG_FILE'
def members_from_json(gang_info_json: Dict[str, Any]) -> List[WorkerInfo]:
    """Convert a parsed gang-info mapping into a name-sorted member list.

    Every value of ``gang_info_json`` must provide ``name`` and ``ip``;
    ``port`` is optional and falls back to -1. All remaining ports of the
    produced ``WorkerInfo`` objects are -1 and both ranks are 0.

    If the ``ZONE_NAME`` environment variable is non-empty, only members whose
    second-to-last ``_``-separated name component equals it are kept.

    Raises:
        Exception: when the number of (zone-filtered) members whose name ends
            with ``part0`` is not exactly one.
    """
    # Upstream note: the ip recorded here is "only the fake ip".
    # NOTE(review): confirm what consumers may rely on for this address.
    candidates: List[WorkerInfo] = [
        WorkerInfo(
            server_port=entry.get('port', -1),
            gang_hb_port=-1,
            http_port=-1,
            rpc_server_port=-1,
            backend_server_port=-1,
            remote_rpc_server_port=-1,
            cache_store_listen_port=-1,
            cache_store_connect_port=-1,
            cache_store_rdma_connect_port=-1,
            cache_store_rdma_listen_port=-1,
            local_rank=0,
            world_rank=0,
            name=entry['name'],
            ip=entry['ip'],
            info=entry,
        )
        for entry in gang_info_json.values()
    ]

    zone = os.environ.get("ZONE_NAME", "")
    if zone:
        candidates = list(
            filter(lambda worker: worker.name.split('_')[-2] == zone, candidates))

    # Exactly one member may act as master; it is recognised by its name suffix.
    master_count = sum(1 for worker in candidates if worker.name.endswith('part0'))
    if master_count != 1:
        raise Exception(f"gang master should contains 1 but got {master_count}")

    candidates.sort(key=lambda worker: worker.name)
    return candidates
'''
test env example:
name:smoke_part0,ip:127.0.0.1,port:13045;name:smoke_part1,ip:127.0.0.1,port:12053
'''
def members_from_test_env(env_str: str) -> List[WorkerInfo]:
    """Parse a test gang description string into a name-sorted member list.

    Format: members are separated by ``;`` and each member is a
    ``,``-separated list of ``key:value`` items providing ``name``, ``ip``
    and ``port``, for example::

        name:smoke_part0,ip:127.0.0.1,port:13045;name:smoke_part1,ip:127.0.0.1,port:12053

    Only the first ``:`` of an item separates key from value, so values may
    themselves contain ``:`` (e.g. IPv6 addresses). Blank segments and blank
    items (trailing or doubled separators) are ignored, and surrounding
    whitespace around keys and values is dropped.

    Args:
        env_str: the raw description string.

    Returns:
        The members sorted by name; all ports except ``server_port`` are -1
        and both ranks are 0.

    Raises:
        ValueError: if an item has no ``:`` separator or ``port`` is not an
            integer.
        KeyError: if a member lacks ``name``, ``ip`` or ``port``.
        Exception: if there is not exactly one member whose name ends with
            ``part0``, or if that member does not sort first by name.
    """
    members: List[WorkerInfo] = []
    for member_str in env_str.split(';'):
        if not member_str.strip():
            # Tolerate a trailing ';' or an empty description.
            continue
        member_info: Dict[str, str] = {}
        for item in member_str.split(','):
            if not item.strip():
                # Tolerate a trailing or doubled ','.
                continue
            # partition() splits on the first ':' only, keeping any later
            # colons (IPv6 literals, for instance) inside the value.
            key, sep, value = item.partition(':')
            if not sep:
                raise ValueError(
                    f"invalid gang member item {item!r} in {member_str!r}, "
                    f"expected 'key:value'")
            member_info[key.strip()] = value.strip()
        members.append(WorkerInfo(
            server_port=int(member_info['port']),
            gang_hb_port=-1,
            http_port=-1,
            rpc_server_port=-1,
            backend_server_port=-1,
            remote_rpc_server_port=-1,
            cache_store_listen_port=-1,
            cache_store_connect_port=-1,
            cache_store_rdma_connect_port=-1,
            cache_store_rdma_listen_port=-1,
            local_rank=0,
            world_rank=0,
            name=member_info['name'], ip=member_info['ip'], info=member_info))

    masters = [member for member in members if member.name.endswith('part0')]
    if len(masters) != 1:
        raise Exception(f"gang master should contains 1 but got {len(masters)}")
    sorted_members = sorted(members, key=lambda x: x.name)
    # Rank assignment downstream treats the first sorted member as master.
    if masters[0].name != sorted_members[0].name:
        raise Exception(f"gang master should be the first one but got {sorted_members[0].name}")
    return sorted_members
'''
raw gang info example:
app.c2.io/biz-detail-ganginfo="{\"llama13B_2A10_PCIE_1_inference_part0\":{\"name\":\"llama13B_2A10_PCIE_1_inference_part0\",\"ip\":\"33.76.194.173\"},\"llama13B_2A10_PCIE_1_inference_part1\":{\"name\":\"llama13B_2A10_PCIE_1_inference_part1\",\"ip\":\"33.76.194.182\"}}"
'''
def get_c2_members():
    """Load gang members from the pod annotation file.

    The file path comes from the ``GANG_ANNOCATION_PATH`` environment variable
    and defaults to ``/etc/podinfo/annotations``. Exactly one line of the file
    must contain ``app.c2.io/biz-detail-ganginfo``. All backslashes are removed
    from that line, and the text starting two characters after the first ``=``
    and ending before the last character is parsed as JSON and handed to
    ``members_from_json``.

    Raises:
        Exception: if the file does not exist, or if the number of matching
            lines is not exactly one.
    """
    annotation_path = os.environ.get("GANG_ANNOCATION_PATH", "/etc/podinfo/annotations")
    if not os.path.exists(annotation_path):
        raise Exception(f"not found file: {annotation_path}")

    with open(annotation_path, 'r') as reader:
        lines = reader.read().split("\n")

    infos = [line for line in lines if "app.c2.io/biz-detail-ganginfo" in line]
    if len(infos) != 1:
        raise Exception(f"ganginfo length is not equal to 1, actual: {infos}")

    # The annotation value is an escaped, quoted JSON string: drop the escape
    # backslashes, then cut away `key="` in front and the closing quote behind.
    unescaped = infos[0].replace("\\", "")
    payload = unescaped[unescaped.index('=') + 2: -1]
    logging.info(f"gang info: {payload}")

    gang_info_json = json.loads(payload)
    logging.info(f"gang info json: {gang_info_json}")
    return members_from_json(gang_info_json)
def get_members_from_file():
    """Load gang members from the JSON file named by ``CONFIG_FILE_ENV``.

    The environment variable must be set (a missing variable raises
    ``KeyError``); the file content is parsed as JSON and converted with
    ``members_from_json``.
    """
    config_path = os.environ[CONFIG_FILE_ENV]
    with open(config_path, 'r') as reader:
        config_json = json.load(reader)
    return members_from_json(config_json)
class GangInfo(NamedTuple):
    """Immutable snapshot of the gang as seen by the current process."""

    # One entry per (node, local rank) pair in the gang.
    members: List[WorkerInfo]
    # The entry acting as gang master.
    master: WorkerInfo
    # The entry describing the current process.
    self: WorkerInfo
    # Number of nodes (not ranks) in the gang.
    num_nodes: int

    def workers(self) -> List[WorkerInfo]:
        """Return every member that does not equal the master, in order."""
        non_masters: List[WorkerInfo] = []
        for candidate in self.members:
            if candidate.equals(self.master):
                continue
            non_masters.append(candidate)
        return non_masters
def get_gang_info() -> GangInfo:
    """Discover the gang and expand it to one ``WorkerInfo`` per local rank.

    Member discovery:
      * When ``g_parallel_info.local_world_size`` is smaller than
        ``g_parallel_info.world_size`` the node-level members come from, in
        order of precedence: the JSON file named by ``CONFIG_FILE_ENV``, the
        ``GANG_CONFIG_STRING`` environment variable (distributed tests), or
        the pod annotation file.
      * Otherwise a single member named ``'local'`` is built from the
        resolved address of the local hostname.

    Each node-level member is then expanded into ``local_world_size`` entries
    whose ports are derived through the ``WorkerInfo.*_offset`` helpers. The
    ``remote_*``/``*_connect_*`` ports are derived from the
    ``REMOTE_SERVER_PORT`` environment variable (0 when unset).

    Returns:
        A ``GangInfo`` holding all expanded members, the master (part 0,
        local rank 0), the entry matching this process (same local rank, ip
        and server port as ``g_parallel_info``/``g_worker_info``) and the
        number of node-level members. ``master`` and ``self`` are not
        validated and may be ``None``.
    """
    if g_parallel_info.local_world_size < g_parallel_info.world_size:
        if os.environ.get(CONFIG_FILE_ENV):
            # from config file
            members = get_members_from_file()
        elif os.environ.get("GANG_CONFIG_STRING"):
            # for distributed test
            logging.info(f"use GANG_CONFIG_STRING: {os.environ['GANG_CONFIG_STRING']}")
            members = members_from_test_env(os.environ['GANG_CONFIG_STRING'])
        else:
            # from c2 annotation
            members = get_c2_members()
    else:
        members = [WorkerInfo(socket.gethostbyname(socket.gethostname()),
                              -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 'local', None)]

    # Loop-invariant: read and parse the remote base port once instead of
    # three times per expanded member.
    remote_server_port = int(os.environ.get("REMOTE_SERVER_PORT", 0))

    # Assumes GPUs are evenly distributed across nodes (world size divisible
    # by local world size); members are already sorted by part.
    self_info: Optional[WorkerInfo] = None
    master: Optional[WorkerInfo] = None
    all_members: List[WorkerInfo] = []
    for part_rank, member in enumerate(members):
        for local_rank in range(g_parallel_info.local_world_size):
            new_member = WorkerInfo(
                ip=member.ip,
                server_port=WorkerInfo.server_port_offset(local_rank, member.server_port),
                gang_hb_port=WorkerInfo.gang_hb_port_offset(local_rank, member.server_port),
                http_port=WorkerInfo.http_port_offset(local_rank, member.server_port),
                rpc_server_port=WorkerInfo.rpc_server_port_offset(local_rank, member.server_port),
                backend_server_port=WorkerInfo.backend_server_port_offset(local_rank, member.server_port),
                cache_store_listen_port=WorkerInfo.cache_store_listen_port_offset(local_rank, member.server_port),
                cache_store_rdma_listen_port=WorkerInfo.cache_store_rdma_listen_port_offset(local_rank, member.server_port),
                remote_rpc_server_port=WorkerInfo.rpc_server_port_offset(local_rank, remote_server_port),
                cache_store_connect_port=WorkerInfo.cache_store_listen_port_offset(local_rank, remote_server_port),
                cache_store_rdma_connect_port=WorkerInfo.cache_store_rdma_listen_port_offset(local_rank, remote_server_port),
                local_rank=local_rank,
                world_rank=part_rank * g_parallel_info.local_world_size + local_rank,
                name=member.name + '_' + str(local_rank),
                info=member.info)
            all_members.append(new_member)
            # Adjacent literals (not backslash continuations inside the
            # string) so source indentation never ends up in the message.
            logging.info(
                f"local rank {local_rank} vs {g_parallel_info.local_rank}, "
                f"new_member: {new_member.ip} vs {g_worker_info.ip}, "
                f"server port {new_member.server_port} vs {g_worker_info.server_port}")
            if (local_rank == g_parallel_info.local_rank and
                    new_member.ip == g_worker_info.ip and
                    new_member.server_port == g_worker_info.server_port):
                self_info = new_member
            if part_rank == 0 and local_rank == 0:
                master = new_member
    # master and self are deliberately not checked for None here (unit tests).
    return GangInfo(all_members, master, self_info, len(members))