maga_transformer/utils/util.py (163 lines of code) (raw):
import re
import torch
import os
import json
import logging
import shutil
import pynvml
import threading
import requests
from enum import Enum
from pathlib import Path
from typing import Optional, Union, Dict, Any, List, Set
from maga_transformer import _ft_pickler
class AtomicCounter:
    """Integer counter whose reads and updates all happen under one lock.

    ``initial`` remembers the starting value (used by ``reset``) and
    ``value`` holds the current count.
    """

    def __init__(self, initial: int=0):
        self._lock = threading.Lock()
        self.initial = initial
        self.value = initial

    def _add(self, delta: int) -> int:
        """Add ``delta`` to the count while holding the lock; return the new count."""
        with self._lock:
            self.value = self.value + delta
            return self.value

    def increment(self):
        """Raise the count by one and return the updated count."""
        return self._add(1)

    def decrement(self):
        """Lower the count by one (it may go negative) and return the updated count."""
        return self._add(-1)

    def decrement_if_gt_0(self):
        """Lower the count by one only when it is positive.

        Returns True when the decrement happened, False otherwise.
        """
        with self._lock:
            was_positive = self.value > 0
            if was_positive:
                self.value -= 1
            return was_positive

    def get(self):
        """Return the current count."""
        with self._lock:
            current = self.value
        return current

    def reset(self):
        """Set the count back to the value passed to the constructor."""
        with self._lock:
            self.value = self.initial
# Type alias for a filesystem path given either as a plain string or as a pathlib.Path.
PathLike = Union[str, Path]
def to_torch_dtype(maybe_str_dtype: Union[str, torch.dtype]) -> torch.dtype:
    """Convert a dtype name, or an existing ``torch.dtype``, into a ``torch.dtype``.

    Args:
        maybe_str_dtype: a ``torch.dtype`` (returned unchanged) or one of the
            case-insensitive names "bf16"/"bfloat16", "fp16"/"float16",
            "fp32"/"float32", "int8".

    Returns:
        The corresponding ``torch.dtype``.

    Raises:
        ValueError: if the argument is neither a ``torch.dtype`` nor a
            recognised name (this includes non-string values such as None).
    """
    if isinstance(maybe_str_dtype, torch.dtype):
        return maybe_str_dtype

    # A non-string has no ``.lower()``; report it with the same ValueError as an
    # unknown name instead of leaking an AttributeError to the caller.
    if not isinstance(maybe_str_dtype, str):
        raise ValueError(f"Cannot convert to torch data type, got {maybe_str_dtype}")

    name_to_dtype = {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
        "fp32": torch.float32,
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
        "int8": torch.int8,
    }
    try:
        return name_to_dtype[maybe_str_dtype.lower()]
    except KeyError:
        # The KeyError carries no extra information, so drop it from the traceback.
        raise ValueError(f"Cannot convert to torch data type, got {maybe_str_dtype}") from None
def check_get_config_from_path(ckpt_path: str) -> Dict[str, Any]:
    """Return the config found by ``get_config_from_path`` or fail loudly.

    Raises:
        Exception: when ``get_config_from_path(ckpt_path)`` returns None.
    """
    loaded = get_config_from_path(ckpt_path)
    if loaded is not None:
        return loaded
    raise Exception(f"Failed to get config.json from path: {ckpt_path}")
def get_config_from_path(ckpt_path: str) -> Optional[Dict[str, Any]]:
    """Read ``config.json`` from a checkpoint directory.

    Returns the parsed JSON content, or None when ``ckpt_path`` is not a
    directory or the directory holds no ``config.json`` file.
    """
    if not os.path.isdir(ckpt_path):
        return None

    # load from huggingface
    candidate = os.path.join(ckpt_path, 'config.json')
    if not os.path.isfile(candidate):
        return None

    with open(candidate, "r", encoding="utf-8") as config_file:
        return json.load(config_file)
def generate_pad_mask(input_lengths: torch.Tensor, memory_length: int, init_step: int=0):
    """ Build a boolean mask over the key/value cache slots.

    # Args.
        input_lengths: (batch_size * beam_width,), input lengths
        memory_length: the length of key/value cache memory.
        init_step: int, initial step.
    # Return
        BoolTensor, (batch_size * beam_width, memory_length). Column j holds
        the step t from [init_step, init_step + memory_length) with
        t % memory_length == j; the entry for row i is True when
        input_lengths[i] <= t < init_step + max(input_lengths).
    """
    longest_input = input_lengths.max()
    lengths_column = input_lengths.unsqueeze(1)
    row_count = lengths_column.shape[0]

    # Rotate the consecutive step range so each step lands in slot (step % memory_length).
    rotation = init_step % memory_length
    steps = torch.arange(init_step, init_step + memory_length, device=lengths_column.device)
    steps_per_row = steps.roll(rotation).unsqueeze(0).tile(row_count, 1)

    at_or_after_own_length = steps_per_row >= lengths_column
    before_longest_end = steps_per_row < init_step + longest_input
    return torch.logical_and(at_or_after_own_length, before_longest_end)
def get_ckpt_file_from_index(ckpt_path: str, model_index_file: str) -> List[str]:
    """List the distinct checkpoint files referenced by a weight index file.

    Args:
        ckpt_path: directory that the file names in the index are joined onto.
        model_index_file: path of a JSON file whose ``weight_map`` entry maps
            names to checkpoint file names.

    Returns:
        One joined path per distinct file name, sorted by file name so the
        result is the same on every run (iterating a set of strings is not).
    """
    with open(model_index_file, encoding="utf-8") as reader:
        index_json = json.load(reader)
    shard_names = sorted(set(index_json['weight_map'].values()))
    return [os.path.join(ckpt_path, shard_name) for shard_name in shard_names]
def load_ckpt(ckpt_path: str) -> Dict[str, Any]:
    """Load checkpoint contents from a single file or from a directory of files.

    A file is loaded directly. For a directory, the files listed in
    ``pytorch_model.bin.index.json`` are used when that index exists, otherwise
    every ``*.bin`` file in sorted order; the loaded dicts are merged, with
    later files overriding earlier ones on duplicate keys.

    Raises:
        NotImplementedError: if ``ckpt_path`` is neither a file nor a directory.
    """
    # NOTE(review): torch.load unpickles its input; only point this at trusted checkpoints.
    if os.path.isfile(ckpt_path):
        return torch.load(ckpt_path, map_location='cpu', pickle_module=_ft_pickler)

    if not os.path.isdir(ckpt_path):
        raise NotImplementedError(f"just support pt file or huggingface: ckpt_path:{ckpt_path}")

    # just support from huggingface
    index_path = os.path.join(ckpt_path, 'pytorch_model.bin.index.json')
    if os.path.exists(index_path):
        shard_files = get_ckpt_file_from_index(ckpt_path, index_path)
    else:
        shard_files = sorted(Path(ckpt_path).glob("*.bin"))

    merged: Dict[str, torch.Tensor] = {}
    for shard_file in shard_files:
        shard_content = torch.load(shard_file, map_location='cpu', pickle_module=_ft_pickler)
        merged.update(shard_content)
    return merged
def copy_gemm_config():
    """Copy ``gemm_config.in`` into the current working directory when available.

    The file is looked up inside the directory named by the
    ``HIPPO_APP_INST_ROOT`` environment variable. When the variable is unset or
    the file does not exist, nothing is copied and an info message is logged.
    """
    inst_root = os.environ.get('HIPPO_APP_INST_ROOT')
    if inst_root is not None:
        source_file = os.path.join(inst_root, 'gemm_config.in')
        if os.path.exists(source_file):
            logging.info("Found gemm_config, copy to current path")
            shutil.copy(source_file, '.')
            return
    logging.info("not found gemm_config in HIPPO_APP_INST_ROOT, not copy")
def get_dtype_size(dtype: torch.dtype) -> int:
    """Return the size in bytes of one element of ``dtype``.

    The four common dtypes are answered from a fixed table; any other
    ``torch.dtype`` is measured by asking an empty tensor of that dtype for its
    element size, so dtypes such as float64 or int32 are supported as well.

    Raises:
        KeyError: if ``dtype`` is not a ``torch.dtype`` at all.
    """
    known_sizes = {torch.int8: 1, torch.half: 2, torch.bfloat16: 2, torch.float: 4}
    try:
        return known_sizes[dtype]
    except KeyError:
        # Keep the historical KeyError for arguments that are not dtypes.
        if not isinstance(dtype, torch.dtype):
            raise
        return torch.empty((), dtype=dtype).element_size()
def check_with_info(condition: bool, error_msg: str):
    """Raise ``Exception(error_msg)`` unless ``condition`` is truthy."""
    if condition:
        return
    raise Exception(error_msg)
def str_to_bool(s: str) -> bool:
    """Parse a textual boolean.

    Matching is case-insensitive and ignores surrounding whitespace.

    Args:
        s: one of "yes"/"true"/"1" (True) or "no"/"false"/"0" (False).

    Returns:
        The parsed boolean.

    Raises:
        ValueError: if ``s`` is none of the accepted spellings.
    """
    true_values = ('yes', 'true', '1')
    false_values = ('no', 'false', '0')
    # Normalise once so both membership tests see the same text.
    normalized = s.strip().lower()
    if normalized in true_values:
        return True
    if normalized in false_values:
        return False
    raise ValueError("Cannot convert {} to a bool".format(s))
def closest_power_of_2(x):
    """Return the largest power of two that does not exceed ``x``.

    Despite the name the result is rounded down, not to the nearest power
    (e.g. 7 gives 4). Any ``x`` below 1 yields 1.
    """
    if x < 1:
        return 1
    current, doubled = 1, 2
    while doubled <= x:
        current, doubled = doubled, doubled * 2
    return current
# True when some non-empty suffix of a is also a prefix of b
def has_overlap(a: str, b: str) -> bool:
    """Check every candidate length from 1 up to the shorter string's length."""
    longest_candidate = min(len(a), len(b))
    return any(
        a.endswith(b[:size])
        for size in range(1, longest_candidate + 1)
    )
# a's suffix is equal to b's prefix
def has_overlap_kmp(a: str, b: str) -> bool:
    """Return True when a non-empty suffix of ``a`` equals a prefix of ``b``.

    Linear-time check based on the KMP failure function. ``b`` is the pattern
    and the tail of ``a`` is scanned against it directly; no separator
    character is spliced between the strings, so inputs may contain any
    character (including '#') without producing false matches.

    Args:
        a: string whose suffixes are considered.
        b: string whose prefixes are considered.

    Returns:
        True if such an overlap of length >= 1 exists, otherwise False.
    """
    if not a or not b:
        return False

    # Failure function of b: failure[i] is the length of the longest proper
    # prefix of b[:i + 1] that is also a suffix of it.
    failure = [0] * len(b)
    border = 0
    for i in range(1, len(b)):
        while border > 0 and b[i] != b[border]:
            border = failure[border - 1]
        if b[i] == b[border]:
            border += 1
        failure[i] = border

    # An overlap cannot be longer than b, so only the last len(b) characters
    # of a matter. Because at most len(b) characters are scanned, ``matched``
    # can reach len(b) only after the final character, so b[matched] stays in
    # range inside the loop.
    matched = 0
    for ch in a[-len(b):]:
        while matched > 0 and ch != b[matched]:
            matched = failure[matched - 1]
        if ch == b[matched]:
            matched += 1

    # matched is now the length of the longest prefix of b that ends a.
    return matched > 0
def request_server(method: str, server_port: int, uri: str = '',
                   req: Optional[Dict[str, Any]] = None,
                   timeout: Optional[float] = None):
    """Send a JSON request to a server on localhost and return the decoded JSON reply.

    Args:
        method: "get" or "post" (lower-case, matched exactly).
        server_port: port of the local server.
        uri: path appended after ``http://localhost:<port>/``.
        req: JSON-serialisable payload; an empty object is sent when omitted.
        timeout: seconds to wait before giving up; None (the default) waits
            indefinitely, which is the library default.

    Returns:
        The decoded JSON body on success, ``{"error": "error method"}`` for an
        unsupported method, or an ``{"error": ..., "details": ...}`` dict when
        the request raises ``requests.RequestException``.
    """
    if method not in ("get", "post"):
        return {"error": "error method"}

    # A fresh dict per call instead of a shared mutable default argument.
    payload = {} if req is None else req
    url = f'http://localhost:{server_port}/{uri}'
    try:
        send = requests.get if method == "get" else requests.post
        response = send(url, json=payload, timeout=timeout)
        return response.json()
    except requests.RequestException as e:
        return {"error": f"Failed to call {uri}", "details": str(e)}