arctic_inference/py_custom_ops.py (44 lines of code) (raw):
import torch
import os
import logging
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
def try_load_torch_library() -> bool:
    """Find the compiled ``custom_ops`` shared object and load it into torch.

    The ``arctic_inference`` package directory is scanned for the first entry
    whose name starts with ``custom_ops`` and ends with ``.so``; that file is
    then passed to ``torch.ops.load_library``.

    Returns:
        ``True`` if ``torch.ops.load_library`` completed without raising.
        ``False`` if no matching file exists in the package directory, or if
        loading raised any ``Exception`` (the failure is logged at INFO level
        and otherwise swallowed).

    Raises:
        Errors from importing the package or listing its directory are not
        caught here and propagate to the caller.
    """
    package_name = 'arctic_inference'
    module_basename = 'custom_ops'
    package_dir = __import__(package_name).__path__[0]

    # The extension file name carries an interpreter/ABI tag between the base
    # name and the ".so" suffix, so match on prefix + suffix and take the
    # first directory entry that fits.
    shared_object = next(
        (entry for entry in os.listdir(package_dir)
         if entry.startswith(module_basename) and entry.endswith('.so')),
        None,
    )
    if shared_object is None:
        logger.info("Could not find compiled custom_ops library in package.")
        return False
    library_path = os.path.join(package_dir, shared_object)

    try:
        logger.info(f"Attempting to load custom ops from {library_path}...")
        torch.ops.load_library(library_path)
    except RuntimeError as e:
        # Reported separately from the generic case so the log names the
        # exception family.
        logger.info(
            f"Unable to load custom library from {library_path}. RuntimeError: {e}. Falling back to original implementation."
        )
        return False
    except Exception as e:
        # Best-effort load: any other failure also means "not available".
        logger.info(
            f"Unable to load custom library from {library_path}. Exception: {e}. Falling back to original implementation."
        )
        return False
    return True
def reshape_and_cache_flash_bulk(
    keys: list[torch.Tensor],
    values: list[torch.Tensor],
    key_caches: list[torch.Tensor],
    value_caches: list[torch.Tensor],
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scales: list[torch.Tensor],
    v_scales: list[torch.Tensor],
    num_heads: int,
    head_size: int,
) -> None:
    """Call the ``arctic_inference.reshape_and_cache_flash_bulk`` torch op.

    Thin Python wrapper: every argument is forwarded positionally, unchanged
    and in declaration order, to
    ``torch.ops.arctic_inference.reshape_and_cache_flash_bulk``. Nothing is
    returned; the op's result (if any) is discarded.

    Args:
        keys: List of tensors, forwarded as-is.
        values: List of tensors, forwarded as-is.
        key_caches: List of tensors, forwarded as-is.
        value_caches: List of tensors, forwarded as-is.
        slot_mapping: Single tensor, forwarded as-is.
        kv_cache_dtype: String, forwarded as-is.
        k_scales: List of tensors, forwarded as-is.
        v_scales: List of tensors, forwarded as-is.
        num_heads: Integer, forwarded as-is.
        head_size: Integer, forwarded as-is.

    NOTE(review): tensor shapes/dtypes and whether the caches are updated in
    place are defined by the compiled op, not visible here -- confirm against
    the op's implementation. The op is presumably registered by loading the
    compiled custom_ops library (see ``try_load_torch_library``) -- verify
    that callers load it first.
    """
    bulk_op = torch.ops.arctic_inference.reshape_and_cache_flash_bulk
    bulk_op(
        keys,
        values,
        key_caches,
        value_caches,
        slot_mapping,
        kv_cache_dtype,
        k_scales,
        v_scales,
        num_heads,
        head_size,
    )