# maga_transformer/utils/database.py
from typing import Any, Dict, List, Optional, Tuple
from pathlib import Path
import json
import os
import logging
import re
import torch
from maga_transformer.utils.ckpt_file_info import CkptFileInfo, FinetuneType
from maga_transformer.lora.lora_file import LoraCkpt, LoraConfig
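
# This module provides a thin abstraction over checkpoint directories:
# BaseDatabase defines the lookup interface; CkptDatabase implements it for
# HuggingFace-style checkpoints (*.safetensors / *.bin / *.pth / *.pt), plus
# optional p-tuning checkpoints and LoRA adapters delegated to LoraCkpt.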


class BaseDatabase:
    def get_pretrain_tensor_names(self) -> List[str]:
        raise NotImplementedError

    def get_lora_tensor_names(self, name: str) -> List[str]:
        raise NotImplementedError

    def load_tensor(self, name: str, datatype: Optional[torch.dtype] = torch.float16) -> List[torch.Tensor]:
        raise NotImplementedError

    def get_tensor_order(self, name: str) -> List[Tuple[str, Any]]:
        raise NotImplementedError

    def get_tensor_type(self, name: str) -> torch.dtype:
        raise NotImplementedError


class CkptDatabase(BaseDatabase):
    PretrainFileList: List[CkptFileInfo]
    FinetuneFileList: List[CkptFileInfo]
    LoraCkpt: LoraCkpt
    finetune_type: FinetuneType

    def __init__(self, path: Optional[str], ptuning_path: Optional[str] = None) -> None:
        # Initialize the containers before the early return so the object is
        # still usable (just empty) when no checkpoint path is given.
        self.PretrainFileList = []
        self.FinetuneFileList = []
        self.LoraCkpt = LoraCkpt()
        if path is None:
            return
        if os.path.isfile(path):
            raise Exception(f"CkptDatabase expects a directory containing checkpoint files, got a file: {path}")
        self.load_hf_meta(path)
        self.load_ptuning_meta(ptuning_path)
        logging.debug(f"CkptDatabase all tensor names = {self.get_pretrain_tensor_names()}")

    def load_hf_meta(self, path: str):
        # Prefer the safetensors index if present; this also avoids picking up
        # consolidated.safetensors in checkpoints such as Mistral-Nemo-Instruct-2407.
        index = os.path.join(path, 'model.safetensors.index.json')
        if os.path.exists(index):
            with open(index) as index_file:
                files = set(json.load(index_file)['weight_map'].values())
            for f in files:
                ckpt = CkptFileInfo(file_name=os.path.join(path, f))
                self.PretrainFileList.append(ckpt)
            return
        # Standard HF layout: use the first pattern that matches any files,
        # preferring safetensors over torch binary formats.
        patterns = ["*.safetensors", "*.bin", "*.pth", "*.pt"]
        glob_files = {pattern: list(Path(path).glob(pattern)) for pattern in patterns}
        for _, value in glob_files.items():
            if len(value) != 0:
                # Skip adapter and trainer artifacts that are not model weights.
                exclude_pattern: re.Pattern[str] = re.compile(r'.*adapter_model\.bin.*|.*training_args\.bin.*')
                for f in value:
                    if not exclude_pattern.match(f.name):
                        ckpt = CkptFileInfo(file_name=str(f))
                        self.PretrainFileList.append(ckpt)
                break

    def load_ptuning_meta(self, ptuning_path: Optional[str]):
        if ptuning_path is None or not os.path.exists(ptuning_path):
            return
        for f in Path(ptuning_path).glob("pytorch_model.bin"):
            if not self._contains(f):
                ckpt = CkptFileInfo(file_name=str(f), finetune_type=FinetuneType.ptuning)
                self.FinetuneFileList.append(ckpt)

    def _contains(self, path: Path) -> bool:
        # A file is already registered if it resolves to the same path as any
        # known pretrain or finetune checkpoint file.
        for info in self.PretrainFileList + self.FinetuneFileList:
            if Path(info.file_name).resolve() == path.resolve():
                return True
        return False

    def get_pretrain_tensor_names(self) -> List[str]:
        tensor_names = []
        for ckptfile in self.PretrainFileList:
            tensor_names.extend(ckptfile.get_tensor_names())
        for ckptfile in self.FinetuneFileList:
            tensor_names.extend(ckptfile.get_tensor_names())
        return tensor_names

    def load_tensor(self, name: str, datatype: Optional[torch.dtype] = torch.float16) -> List[torch.Tensor]:
        tensors = []
        for ckpt_file in self.PretrainFileList:
            if name in ckpt_file.get_tensor_names():
                tensors.append(ckpt_file.load_tensor(name, datatype))
        logging.debug(f"self.FinetuneFileList: {self.FinetuneFileList}, PretrainFileList: {self.PretrainFileList}")
        for ckpt_file in self.FinetuneFileList:
            logging.debug(f"load tensor {name} from {ckpt_file.file_name}")
            if name in ckpt_file.get_tensor_names():
                tensors.append(ckpt_file.load_tensor(name, datatype))
        return tensors

    def get_tensor_type(self, name: str) -> torch.dtype:
        return self.PretrainFileList[0].get_tensor_type(name)

    def get_tensor_order(self, name: str) -> List[Tuple[str, Any]]:
        # Returns (file_name, read_order) pairs for every checkpoint file that
        # contains the tensor.
        orders = []
        for ckpt_file in self.PretrainFileList:
            if name in ckpt_file.get_tensor_names():
                orders.append((ckpt_file.file_name, ckpt_file.get_tensor_read_order(name)))
        for ckpt_file in self.FinetuneFileList:
            if name in ckpt_file.get_tensor_names():
                orders.append((ckpt_file.file_name, ckpt_file.get_tensor_read_order(name)))
        return orders

    def load_tensors_by_prefix(self, prefix_list: List[str], device: str, direct_io: bool) -> Dict[str, List[torch.Tensor]]:
        # str.startswith expects a str or a tuple of str, so convert the prefix list once.
        prefixes = tuple(prefix_list)
        res: Dict[str, List[torch.Tensor]] = {}
        for ckptfile in self.PretrainFileList:
            if any(tensor.startswith(prefixes) for tensor in ckptfile.get_tensor_names()):
                tensors = ckptfile.load_tensors(device, direct_io)
                for k, v in tensors.items():
                    if not k.startswith(prefixes):
                        continue
                    if k not in res:
                        res[k] = [v]
                    else:
                        res[k].append(v)
        return res

    def get_lora_tensor_names(self, config_name: str) -> List[str]:
        return self.LoraCkpt.get_lora_tensor_names(config_name)

    def load_lora_tensor(self, lora_name: str, tensor_name: str) -> List[torch.Tensor]:
        return self.LoraCkpt.load_lora_tensor(lora_name, tensor_name)

    def load_lora(self, config_name: str, lora_path: str):
        self.LoraCkpt.load_lora(config_name, lora_path)

    def remove_lora(self, name: str):
        return self.LoraCkpt.remove_lora(name)

    def get_lora_config(self, config_name: str):
        return self.LoraCkpt.get_lora_config(config_name)

    def has_lora(self):
        return self.LoraCkpt.has_lora()

    def get_first_lora_name(self):
        return self.LoraCkpt.get_first_lora_name()

    def dump_lora_info(self) -> None:
        self.LoraCkpt.dump_lora_info()
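

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): the checkpoint directory below
    # is a placeholder, and the tensor names depend on whatever model was saved
    # into that directory.
    import sys

    logging.basicConfig(level=logging.DEBUG)
    ckpt_dir = sys.argv[1] if len(sys.argv) > 1 else "/path/to/hf_checkpoint"  # placeholder path
    db = CkptDatabase(ckpt_dir)
    names = db.get_pretrain_tensor_names()
    print(f"found {len(names)} tensors")
    if names:
        # load_tensor returns one tensor per checkpoint file that contains the name
        print(db.load_tensor(names[0])[0].shape)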