tinynn/llm_quant/util.py (41 lines of code) (raw):
import ctypes
import os
import platform
import sys
import importlib
import torch.nn as nn
def _init_patch_easyquant():
    """Make the native libraries shipped inside ``easyquant`` loadable.

    The installed ``easyquant`` package is located (without importing it) and
    its directory is then either registered as a DLL search directory or, on
    interpreters where that path is not taken, every library listed in the
    ``.load-order-easyquant-0.0.1`` file of that directory is preloaded.

    The function uses Windows-only APIs (``os.add_dll_directory`` and
    ``ctypes.windll``), so it is only meaningful when called on Windows.

    Raises:
        ImportError: if the ``easyquant`` package cannot be located.
        OSError: if a library named in the load-order file fails to load.
    """
    # Function-local import: guarantees the ``importlib.util`` submodule is
    # actually loaded, which a bare ``import importlib`` does not promise.
    from importlib.util import find_spec

    # ``PathFinder.find_module`` was removed in Python 3.12; ``find_spec`` is
    # the supported way to locate a package without importing it.
    spec = find_spec("easyquant")
    if spec is None or not spec.origin:
        # Fail with a clear import error instead of an AttributeError on None.
        raise ImportError("No module named 'easyquant'", name="easyquant")

    # ``spec.origin`` is the package's ``__init__`` file; the libraries are
    # looked up in the directory that contains it.
    libs_dir = os.path.abspath(os.path.dirname(os.path.realpath(spec.origin)))

    # Detect conda builds of CPython.
    # NOTE(review): presumably ``os.add_dll_directory`` is not reliable on
    # conda CPython before 3.10, hence the preload fallback - confirm.
    is_conda_cpython = platform.python_implementation() == 'CPython' and (
        hasattr(ctypes.pythonapi, 'Anaconda_GetVersion') or 'packaged by conda-forge' in sys.version
    )

    py_version = sys.version_info[:2]
    # Explicit parentheses: ``and`` binds tighter than ``or``, so this is the
    # same condition as before, just easier to read.
    if (py_version >= (3, 8) and not is_conda_cpython) or py_version >= (3, 10):
        if os.path.isdir(libs_dir):
            os.add_dll_directory(libs_dir)
        return

    # Fallback: preload each library in the order recorded next to the package.
    load_order_filepath = os.path.join(libs_dir, '.load-order-easyquant-0.0.1')
    if not os.path.isfile(load_order_filepath):
        return

    with open(load_order_filepath, 'r', encoding='utf-8') as file:
        load_order = file.read().split()

    # Win32 LOAD_WITH_ALTERED_SEARCH_PATH flag for LoadLibraryExW.
    load_with_altered_search_path = 0x00000008
    for lib in load_order:
        lib_path = os.path.join(libs_dir, lib)
        # Entries that do not exist on disk are skipped silently; a zero
        # (falsy) handle from LoadLibraryExW means the load failed.
        if os.path.isfile(lib_path) and not ctypes.windll.kernel32.LoadLibraryExW(
            ctypes.c_wchar_p(lib_path), None, load_with_altered_search_path
        ):
            raise OSError('Error loading {}; {}'.format(lib, ctypes.FormatError()))
def _step_into(container, key):
    """Return the child of ``container`` addressed by the path component ``key``."""
    kind = type(container)
    if kind is nn.ModuleList:
        # Path components are strings; a ModuleList is indexed by integer.
        return container[int(key)]
    if kind is nn.ModuleDict:
        return container[key]
    return getattr(container, key)


def get_submodule_with_parent_from_name(model, module_name):
    """Resolve a dotted module path starting at ``model``.

    Args:
        model: The object at which the lookup starts.
        module_name: Dot-separated path, e.g. ``"layers.0.attn"``.

    Returns:
        A ``(submodule, parent, sub_name)`` tuple: the object the path leads
        to, the object it was looked up on, and the last path component.
    """
    path = module_name.split('.')
    parent, node = None, model
    for component in path:
        parent, node = node, _step_into(node, component)
    return node, parent, path[-1]