# src/kernels/layer.py
import inspect
import os
import warnings
from contextvars import ContextVar
from copy import deepcopy
from dataclasses import dataclass, field
from types import MethodType
from typing import TYPE_CHECKING, Dict, Optional, Type, Union
from .utils import get_kernel
if TYPE_CHECKING:
import torch
from torch import nn
_DISABLE_KERNEL_MAPPING: bool = bool(int(os.environ.get("DISABLE_KERNEL_MAPPING", "0")))
@dataclass(frozen=True)
class Device:
type: str
# In the future we might add compute capabilities, etc.
def __eq__(self, other):
return isinstance(other, Device) and self.type == other.type
def __hash__(self):
return hash(self.type)
@dataclass
class LayerRepository:
"""
Repository and name of a layer.
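    Example usage (the repository and layer names are illustrative):
    ```python
    from kernels import LayerRepository
    repo = LayerRepository(
        repo_id="kernels-community/activation",
        layer_name="RmsNorm",
        revision="layers",
    )
    ```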
"""
layer_name: str = field(
metadata={"help": "The name of the layer in the kernel repository."}
)
repo_id: str = field(metadata={"help": "The kernel hub repository with the layer."})
revision: str = field(
default="main", metadata={"help": "The revision of the layer."}
)
def __eq__(self, other):
return (
isinstance(other, LayerRepository)
and self.layer_name == other.layer_name
and self.repo_id == other.repo_id
and self.revision == other.revision
)
def __hash__(self):
return hash((self.layer_name, self.repo_id, self.revision))
_CACHED_LAYER: Dict[LayerRepository, Type["nn.Module"]] = {}
_KERNEL_MAPPING: ContextVar[Dict[str, Dict[Device, LayerRepository]]] = ContextVar(
"_KERNEL_MAPPING", default={}
)
def use_kernel_mapping(
mapping: Dict[str, Dict[Union[Device, str], LayerRepository]],
*,
inherit_mapping: bool = True,
):
"""
    Context manager that sets a kernel mapping for the duration of the context.
    When `inherit_mapping` is set to `True`, the current mapping will be
    extended by `mapping` inside the context. If it is `False`, only
    `mapping` is used inside the context.
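    Example usage (illustrative; uses the same mapping format as
    `register_kernel_mapping`):
    ```python
    from kernels import LayerRepository, use_kernel_mapping
    mapping = {
        "LlamaRMSNorm": {
            "cuda": LayerRepository(
                repo_id="kernels-community/activation",
                layer_name="RmsNorm",
                revision="layers",
            ),
        },
    }
    with use_kernel_mapping(mapping):
        # The mapping is only active inside this block.
        ...
    ```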
"""
class ContextManager:
def __enter__(self):
            # When inheriting, stack the new mapping on top of a copy of the
            # current mapping; otherwise start from an empty mapping.
if inherit_mapping:
self.token = _KERNEL_MAPPING.set(deepcopy(_KERNEL_MAPPING.get()))
else:
self.token = _KERNEL_MAPPING.set({})
register_kernel_mapping(mapping)
def __exit__(self, exc_type, exc_value, traceback):
_KERNEL_MAPPING.reset(self.token)
return ContextManager()
def register_kernel_mapping(
mapping: Dict[str, Dict[Union[Device, str], LayerRepository]],
):
"""
    Allows one to register a mapping between a layer name and the corresponding
    kernel to use, depending on the device. This should be used in conjunction
    with `kernelize`.
    Example usage:
```python
from kernels import LayerRepository, register_kernel_mapping
kernel_layer_mapping = {
"LlamaRMSNorm": {
"cuda": LayerRepository(
repo_id="kernels-community/activation",
layer_name="RmsNorm",
revision="layers",
),
},
}
register_kernel_mapping(kernel_layer_mapping)
```
"""
# Merge with existing mappings.
for new_kernel, new_device_repos in mapping.items():
device_repo = _KERNEL_MAPPING.get().setdefault(new_kernel, {})
for new_device, new_repo in new_device_repos.items():
if isinstance(new_device, str):
device_repo[Device(type=new_device)] = new_repo
else:
device_repo[new_device] = new_repo
def replace_kernel_forward_from_hub(
cls,
layer_name: str,
):
"""
    Prepare a layer class to use a kernel from the Hugging Face Hub.
    This function stores the layer name on the class; `kernelize` later uses it
    to replace the forward implementation with the appropriate kernel from the
    Hub.
    Args:
        cls: The layer class to prepare
layer_name: The name of the layer to use for kernel lookup
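    Example usage (a sketch; `MyRMSNorm` is a placeholder class and the import
    path is assumed to be the public `kernels` package):
    ```python
    from torch import nn
    from kernels import replace_kernel_forward_from_hub
    class MyRMSNorm(nn.Module):
        def forward(self, hidden_states):
            ...
    replace_kernel_forward_from_hub(MyRMSNorm, "LlamaRMSNorm")
    ```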
"""
cls.kernel_layer_name = layer_name
def kernelize(
model: "nn.Module",
device: Optional[Union[str, "torch.device"]] = None,
needs_torch_compile: bool = False,
use_fallback: bool = True,
):
"""
Iterate over all modules in the model and replace the `forward` method of
extensible layers for which kernels are registered using `register_kernel_mapping`
or `use_kernel_mapping`.
Args:
model: The PyTorch model to kernelize
device: The device type to load kernels for. The device type will be inferred
from the parameters of the model when not provided.
        needs_torch_compile: When set to `True`, only kernels that support
`torch.compile` will be loaded.
use_fallback: Whether to use the original forward method of modules when no
compatible kernel could be found. If set to `False`, an exception will
be raised in such cases.
Returns:
The kernelized model
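    Example usage (a minimal sketch; `MyModel` is a placeholder for a model
    that contains extensible layers):
    ```python
    from kernels import LayerRepository, kernelize, register_kernel_mapping
    register_kernel_mapping(
        {
            "LlamaRMSNorm": {
                "cuda": LayerRepository(
                    repo_id="kernels-community/activation",
                    layer_name="RmsNorm",
                    revision="layers",
                ),
            },
        }
    )
    model = MyModel().to("cuda")  # Placeholder model with extensible layers.
    model = kernelize(model)
    ```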
"""
import torch
if device is None:
device_type = _find_device(model)
elif isinstance(device, str):
device_type = Device(type=torch.device(device).type)
else:
device_type = Device(device.type)
assert isinstance(device_type, Device)
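    # Walk all submodules and replace the `forward` of extensible layers,
    # i.e. classes that were marked with a `kernel_layer_name` attribute.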
for _, module in model.named_modules():
module_class = type(module)
if not hasattr(module_class, "kernel_layer_name"):
continue
layer_name = module_class.kernel_layer_name
if _DISABLE_KERNEL_MAPPING:
_replace_forward(module, module_class)
continue
kernel = _KERNEL_MAPPING.get().get(str(layer_name))
if kernel is None:
warnings.warn(
"\n"
f"No kernel mapping found for layer `{layer_name}`. "
f"Check if the layer name matches one of the kernels in the mapping or add the kernel "
f"you want to use to the mapping. Defaulting to original forward implementation."
)
if not use_fallback:
raise ValueError(f"No layer mapping for `{layer_name}`")
_replace_forward(module, module_class)
continue
        # Look up the kernel repository registered for this device type.
repo = kernel.get(device_type)
if repo is None:
if not use_fallback:
raise ValueError(
f"No layer mapping for `{layer_name}` with device type `{device_type}`"
)
_replace_forward(module, module_class)
continue
# Short-circuit if we already loaded the layer.
layer = _CACHED_LAYER.get(repo, None)
if layer is not None:
_conditionally_replace_forward(
module=module,
layer=layer,
needs_torch_compile=needs_torch_compile,
use_fallback=use_fallback,
)
continue
layer = _get_kernel_layer(
repo_id=repo.repo_id,
layer_name=repo.layer_name,
revision=repo.revision,
)
        # Validate the replacement layer against the original layer class.
_validate_layer(check_cls=module_class, cls=layer)
_CACHED_LAYER[repo] = layer
_conditionally_replace_forward(
module=module,
layer=layer,
needs_torch_compile=needs_torch_compile,
use_fallback=use_fallback,
)
return model
def use_kernel_forward_from_hub(layer_name: str):
"""
Make a layer extensible using the name `layer_name`.
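    Example usage (a sketch; the layer body is elided):
    ```python
    from torch import nn
    from kernels import use_kernel_forward_from_hub
    @use_kernel_forward_from_hub("LlamaRMSNorm")
    class LlamaRMSNorm(nn.Module):
        def forward(self, hidden_states):
            ...
    ```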
"""
def decorator(cls):
replace_kernel_forward_from_hub(cls, layer_name)
return cls
return decorator
def _get_kernel_layer(
*, repo_id: str, layer_name: str, revision: str
) -> Type["nn.Module"]:
"""Get a layer from a kernel."""
kernel = get_kernel(repo_id, revision=revision)
if getattr(kernel, "layers", None) is None:
raise ValueError(
f"Kernel `{repo_id}` at revision `{revision}` does not define any layers."
)
layer = getattr(kernel.layers, layer_name, None)
if layer is None:
raise ValueError(f"Layer `{layer_name}` not found in kernel `{repo_id}`.")
return layer
def _validate_layer(*, check_cls, cls):
import torch.nn as nn
    # The layer must have at least the following properties: (1) it
# must be stateless; (2) the forward signature should correspond to
# the signature it is replacing; (3) forward should not call other
# methods.
if not issubclass(cls, nn.Module):
raise TypeError(f"Layer `{cls}` is not a Torch layer.")
    # We verify statelessness by checking that the layer does not have its own
# constructor (since the constructor could add member variables)...
if cls.__init__ is not nn.Module.__init__:
raise TypeError("Layer must not override nn.Module constructor.")
# ... or predefined member variables.
torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
cls_members = {name for name, _ in inspect.getmembers(cls)}
difference = cls_members - torch_module_members
    # Allow only the optional metadata attributes `can_torch_compile` and
    # `has_backward`; any other extra member means the layer is not stateless.
if not difference <= {"can_torch_compile", "has_backward"}:
raise TypeError("Layer must not contain additional members.")
# Check whether the forward signatures are similar.
params = inspect.signature(cls.forward).parameters
ref_params = inspect.signature(check_cls.forward).parameters
if len(params) != len(ref_params):
raise TypeError(
"Forward signature does not match: different number of arguments."
)
for param, ref_param in zip(params.values(), ref_params.values()):
if param.kind != ref_param.kind:
            raise TypeError(
                f"Forward signature does not match: different kind of arguments ({param} ({param.kind}) and {ref_param} ({ref_param.kind}))"
            )
def _find_device(model: "nn.Module") -> Device:
try:
param = next(model.parameters())
except StopIteration:
        raise ValueError(
            "Cannot determine the model device; provide it as the `device` argument to `kernelize`."
        )
return Device(type=param.device.type)
def _conditionally_replace_forward(
*,
module: "nn.Module",
layer: Type["nn.Module"],
needs_torch_compile: bool,
use_fallback: bool,
):
module_class = type(module)
    # Fall back to the original forward when `torch.compile` support is
    # required but the kernel layer does not advertise it. (Backward support
    # is handled separately in `_replace_forward`.)
needs_fallback = needs_torch_compile and not getattr(
layer, "can_torch_compile", False
)
if needs_fallback:
if use_fallback:
_replace_forward(module, module_class)
else:
raise ValueError(
f"Available kernel does not fulfill requirements: needs_torch_compile={needs_torch_compile}"
)
else:
_replace_forward(module, layer)
def _replace_forward(module: "nn.Module", layer: Type["nn.Module"]):
import torch.nn as nn
module_class = type(module)
layer_with_backward = (
layer if getattr(layer, "has_backward", True) else module_class
)
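    # Rebind `forward` whenever the train/eval mode changes: in training mode
    # the kernel forward is only used when it supports backward (otherwise the
    # original forward is restored), while in eval mode the kernel forward is
    # always used.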
def train(self, mode: bool = True) -> nn.Module:
super(type(self), self).train(mode)
if mode:
self.forward = MethodType(layer_with_backward.forward, self)
else:
self.forward = MethodType(layer.forward, self)
return self
module.train = MethodType(train, module) # type: ignore[method-assign]
# Trigger setting correct forward for the current state.
module.train(module.training)