Source code for tdhook.hooks

import weakref
from typing import Callable, Any, Optional, List, Literal, Protocol, Generic, TypeVar, Type, Tuple
import inspect
from tensordict import TensorDict
import re
from torch.utils.hooks import RemovableHandle
from torch import nn
import torch

from tdhook._types import UnraveledKey


# The hook kinds understood by this module.
HookDirection = Literal["fwd", "bwd", "fwd_pre", "bwd_pre", "fwd_kwargs", "fwd_pre_kwargs"]

# Generic placeholder used by the reference wrappers defined in this module.
T = TypeVar("T")

# Positional parameter names received by a hook of each direction, in call order.
DIRECTION_TO_PARAMS = {
    "fwd": ("module", "args", "output"),
    "bwd": ("module", "grad_input", "grad_output"),
    "fwd_pre": ("module", "args"),
    "bwd_pre": ("module", "grad_output"),
    "fwd_kwargs": ("module", "args", "kwargs", "output"),
    "fwd_pre_kwargs": ("module", "args", "kwargs"),
}

# Name of the hook parameter that corresponds to the hook's return value.
DIRECTION_TO_RETURN = {
    "fwd": "output",
    "bwd": "grad_input",
    "fwd_pre": "args",
    "bwd_pre": "grad_output",
    "fwd_kwargs": "output",
    "fwd_pre_kwargs": "args",
}

# Position of that parameter inside DIRECTION_TO_PARAMS[direction].
DIRECTION_TO_RETURN_INDEX = {
    direction: param_names.index(DIRECTION_TO_RETURN[direction])
    for direction, param_names in DIRECTION_TO_PARAMS.items()
}

# Label describing the kind of value each direction is associated with.
DIRECTION_TO_TYPE = {
    "fwd": "output",
    "bwd": "grad_input",
    "fwd_pre": "input",
    "bwd_pre": "grad_output",
    "fwd_kwargs": "output",
    "fwd_pre_kwargs": "input",
}
def _check_hook_signature(hook: Callable, direction: HookDirection):
    """Validate that ``hook`` declares a parameter list compatible with ``direction``.

    Parameters that have a default value, and a ``**kwargs`` parameter, are counted
    as "optional" and are allowed on top of the expected parameters.

    Args:
        hook: The callable about to be registered as a hook.
        direction: One of the keys of ``DIRECTION_TO_PARAMS``.

    Raises:
        ValueError: If ``direction`` is unknown, or if the number of declared
            parameters does not fit the expected parameters for ``direction``.
    """
    if direction not in DIRECTION_TO_PARAMS:
        raise ValueError(f"Invalid direction: {direction}")

    expected_params = DIRECTION_TO_PARAMS[direction]
    declared = list(inspect.signature(hook).parameters.values())

    accepts_varargs = False
    optional_count = 0
    for param in declared:
        if param.kind == inspect.Parameter.VAR_POSITIONAL:
            accepts_varargs = True
        if param.default is not inspect.Parameter.empty or param.kind == inspect.Parameter.VAR_KEYWORD:
            optional_count += 1

    if accepts_varargs:
        # ``*args`` can absorb any missing positional, so only an upper bound applies
        # (the extra 1 accounts for the ``*args`` parameter itself).
        upper_bound = len(expected_params) + 1 + optional_count
        if len(declared) > upper_bound:
            raise ValueError(f"Hook ({direction}) must have at most {upper_bound} positional parameters")
        return

    if len(declared) != len(expected_params) + optional_count:
        raise ValueError(f"Hook ({direction}) must have the signature {expected_params}")
def merge_paths(*paths: str) -> str:
    """Join the given path fragments with dots, skipping the empty ones.

    Args:
        *paths: Path fragments; falsy fragments (e.g. ``""``) are ignored.

    Returns:
        The dot-joined path, or ``""`` when no non-empty fragment was given.
    """
    kept_fragments = [fragment for fragment in paths if fragment]
    return ".".join(kept_fragments)
def resolve_submodule_path(root: nn.Module, path: str):
    """
    Resolve a submodule path that may contain indexing expressions.

    Supports any valid Python attribute access and indexing:
    - "[0]" -> root[0]
    - "layers[-1]" -> root.layers[-1]
    - "layers['attr']" -> root.layers['attr']
    - "layers.attention" -> root.layers.attention
    - "layers[1:3]" -> root.layers[1:3]
    - "fn(0)" -> root.fn(0)

    Supports custom attributes:
    - "<block0/module>" -> getattr(root, "block0/module")
    - "<block0/module>.layers.attention[0]" -> getattr(root, "block0/module").layers.attention[0]
    - "m1.<block0/module>.layers.<module>.linear[0]" -> getattr(getattr(root.m1, "block0/module").layers, "module").linear[0]
    - "m1.<0>.layers" -> getattr(root.m1, "0").layers

    Attributes starting with a digit are looked up with ``getattr`` as well, both
    after a dot and at the very start of the path:
    - "m1.0.layers" -> getattr(root.m1, "0").layers
    - "0.linear" -> getattr(root, "0").linear

    Args:
        root: Module the path is relative to.
        path: Path expression; an empty path resolves to ``root`` itself.

    Returns:
        Whatever object the path evaluates to.

    Raises:
        ValueError: If a ``<`` has no matching ``>``, or if evaluating the
            expression raises ``AttributeError``, ``IndexError``, ``KeyError`` or
            ``SyntaxError``. An ``AttributeError`` raised while looking up a
            ``<...>`` attribute is not converted and propagates unchanged.
    """
    if not path:
        return root

    # Attributes starting with a number are not valid identifiers, so they are rewritten
    # to the "<name>" form. The "^" alternative also covers a path that *begins* with such
    # an attribute ("0.linear"), which would otherwise become the invalid "root.0.linear".
    path = re.sub(r"(?:^|\.)(\d[a-zA-Z0-9_]*)", r"<\1>", path)
    path = re.sub(r"(\.+)", ".", path)
    path = path.strip(".")

    start_key, *rest = path.split("<", maxsplit=1)
    if rest:
        # Resolve what precedes the "<...>" marker, getattr the marked name, then recurse
        # on whatever follows the closing ">".
        start_root = resolve_submodule_path(root, start_key)
        attr, *rest = rest[0].split(">", maxsplit=1)
        if not rest:
            raise ValueError(f"Invalid submodule path '{path}', missing closing '>'")
        return resolve_submodule_path(getattr(start_root, attr), rest[0])

    # SECURITY NOTE: the expression is evaluated with ``eval``. Removing the builtins
    # limits the names in scope but is NOT a sandbox; only pass paths from trusted code.
    safe_dict = {"root": root}
    # Leading dots were stripped above, so only an index expression attaches directly.
    expression = f"root{path}" if path.startswith("[") else f"root.{path}"
    try:
        return eval(expression, {"__builtins__": {}}, safe_dict)
    except (AttributeError, IndexError, KeyError, SyntaxError) as e:
        raise ValueError(f"Invalid submodule path '{path}': {e}") from e
def submodule_path_to_name(path: str) -> str:
    """Turn a submodule path expression into a dot-separated name.

    Paths containing a negative index (``[-``), a colon, or parentheses are
    returned unchanged. Otherwise quotes are removed, angle and square brackets
    become dots, runs of dots are collapsed and surrounding dots are stripped.

    Args:
        path: The submodule path expression.

    Returns:
        The dotted name, or ``path`` itself for the unchanged cases above.
    """
    if re.search(r"(\[\-)|\:|\(|\)", path) is not None:
        return path

    rewrite_steps = (
        (r"[\"\']", ""),  # drop quotes around string keys
        (r"[<>\[\]]", "."),  # every bracket acts as a separator
        (r"\.+", "."),  # collapse consecutive separators
    )
    name = path
    for pattern, replacement in rewrite_steps:
        name = re.sub(pattern, replacement, name)
    return name.strip(".")
def register_hook_to_module(
    module: nn.Module,
    hook: Callable,
    direction: HookDirection,
    prepend: bool = False,
) -> RemovableHandle:
    """Validate ``hook`` and attach it to ``module`` for the given direction.

    Args:
        module: Module receiving the hook.
        hook: The hook callable; its signature is checked before registration.
        direction: Which registration method to use. The ``*_kwargs`` variants
            register the forward (pre-)hook with ``with_kwargs=True``.
        prepend: Forwarded as ``prepend`` to the registration method.

    Returns:
        The handle returned by the registration method.

    Raises:
        ValueError: Propagated from the signature check (unknown direction or
            incompatible hook signature).
    """
    _check_hook_signature(hook, direction)

    if direction == "bwd":
        return module.register_full_backward_hook(hook, prepend=prepend)
    if direction in ("fwd", "fwd_kwargs"):
        return module.register_forward_hook(hook, prepend=prepend, with_kwargs=direction == "fwd_kwargs")
    if direction in ("fwd_pre", "fwd_pre_kwargs"):
        return module.register_forward_pre_hook(hook, prepend=prepend, with_kwargs=direction == "fwd_pre_kwargs")
    # Every remaining direction goes through the backward pre-hook registration.
    return module.register_full_backward_pre_hook(hook, prepend=prepend)
class RemovableHandleProtocol(Protocol):
    """Structural type for any object exposing a ``remove()`` method."""

    def remove(self):
        """Remove whatever this handle stands for."""
        ...
class MultiHookHandle:
    """
    Handle for multiple hooks.

    Groups several removable handles so they can be removed together, used as a
    context manager (removal happens on exit), and concatenated with ``+``.
    """

    def __init__(self, handles: Optional[List[RemovableHandleProtocol]] = None):
        # A missing or empty argument results in a fresh empty list.
        self._handles = handles if handles else []

    def remove(self):
        """Call ``remove()`` on every wrapped handle, in order."""
        for wrapped in self._handles:
            wrapped.remove()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.remove()

    def __add__(self, other: Any):
        """Return a new handle wrapping the handles of both operands.

        Raises:
            TypeError: If ``other`` is not a ``MultiHookHandle``.
        """
        if isinstance(other, MultiHookHandle):
            return MultiHookHandle(self._handles + other._handles)
        raise TypeError(f"MultiHookHandle cannot be added to {type(other).__name__}")
class MultiHookManager:
    """
    Manager for multiple hooks.

    Selects submodules by name (regular expression, applied with ``re.match``) and
    by class, and registers one hook per selected submodule.
    """

    def __init__(
        self,
        pattern: Optional[str] = None,
        classes_to_hook: Tuple[Type[nn.Module], ...] = (nn.Module,),
        classes_to_skip: Tuple[Type[nn.Module], ...] = (),
    ):
        """
        Args:
            pattern: Regular expression for submodule names; ``None`` selects nothing.
            classes_to_hook: Only instances of these classes are hooked.
            classes_to_skip: Instances of these classes are never hooked.
        """
        self._pattern = r"a^" if pattern is None else pattern  # match nothing by default
        self._classes_to_hook = classes_to_hook
        self._classes_to_skip = classes_to_skip
        self._reg_exp = re.compile(self._pattern)

    @property
    def pattern(self) -> str:
        """The pattern to match the modules."""
        return self._pattern

    @pattern.setter
    def pattern(self, pattern: str):
        # Keep the compiled expression in sync with the stored pattern.
        self._pattern = pattern
        self._reg_exp = re.compile(pattern)

    def register_hook(
        self,
        module: nn.Module,
        hook_factory: Callable[[str], Callable],
        *,
        direction: HookDirection = "fwd",
        prepend: bool = False,
        relative_path: Optional[str] = None,
    ):
        """Register a hook on every selected submodule.

        Args:
            module: Module whose submodules are inspected.
            hook_factory: Called with the submodule name (as yielded by
                ``named_modules()`` on the traversal root) to build its hook.
            direction: Hook direction used for every registration.
            prepend: Forwarded to each registration.
            relative_path: When truthy, the traversal starts from the object this
                path resolves to on ``module`` instead of ``module`` itself.

        Returns:
            A ``MultiHookHandle`` wrapping every created handle.
        """
        root_module = resolve_submodule_path(module, relative_path) if relative_path else module
        handles = [
            register_hook_to_module(submodule, hook_factory(name), direction, prepend)
            for name, submodule in root_module.named_modules()
            if name != ""  # the empty name is the traversal root itself
            and isinstance(submodule, self._classes_to_hook)
            and not isinstance(submodule, self._classes_to_skip)
            and self._reg_exp.match(name)
        ]
        return MultiHookHandle(handles)
class MutableWeakRef(Generic[T]):
    """
    Weak reference to a mutable object.

    The referee can be swapped with ``set``; only a ``weakref.ref`` to it is stored.
    """

    def __init__(self, referee: T):
        self._ref = weakref.ref(referee)

    def resolve(self) -> T:
        """Return the result of calling the stored weak reference."""
        current_ref = self._ref
        return current_ref()

    def set(self, referee: T):
        """Point this reference at ``referee`` (stored weakly)."""
        self._ref = weakref.ref(referee)
class TensorDictRef:
    """
    Reference to a TensorDict.

    Stores the given object directly and lets it be replaced later.
    """

    def __init__(self, td: Optional[TensorDict]):
        self._td = td

    def resolve(self) -> TensorDict:
        """Return the currently stored object."""
        return self._td

    def set(self, td: TensorDict):
        """Replace the stored object with ``td``."""
        self._td = td
class CacheProxy:
    """
    Proxy for a cache.

    Remembers a key and a weak reference to a cache (or to a reference wrapper
    around one); the entry is looked up when ``resolve`` is called.
    """

    def __init__(self, key: str, cache: TensorDict | MutableWeakRef[TensorDict] | TensorDictRef):
        self._key = key
        self._cache = weakref.ref(cache)

    def resolve(self) -> Any:
        """Return ``cache.get(key)`` from the referenced cache.

        Raises:
            ValueError: If the weak reference, or the wrapper it points to,
                resolves to ``None``.
        """
        target = self._cache()
        # Reference wrappers need one more hop to reach the actual cache.
        if isinstance(target, (MutableWeakRef, TensorDictRef)):
            target = target.resolve()
        if target is None:
            raise ValueError("Dead reference to cache")
        return target.get(self._key)
class EarlyStoppingException(Exception):
    """
    Exception for early stopping.

    Carries the key that triggered the stop in ``_key``.
    """

    def __init__(self, key: str):
        message = f"Early stopping triggered for key {key}"
        super().__init__(message)
        self._key = key
class HookFactory:
    """
    Factory for creating hooks.

    Every factory returns a plain callable suitable for registration as a module
    hook. Callbacks are always invoked with keyword arguments only.
    """

    @staticmethod
    def _check_callback_signature(callback: Callable, expected_param_names: set[str]):
        """Check callback signature matches expected parameter names.

        Args:
            callback: The user callback, or ``None`` (nothing is checked then).
            expected_param_names: Every keyword name the hook will pass to the callback.

        Raises:
            ValueError: If the callback has positional-only parameters, or if it
                has no ``**kwargs`` and does not declare all expected names.
        """
        if callback is None:
            return

        sig = inspect.signature(callback)
        param_names = set(sig.parameters.keys())

        has_positional_only = any(param.kind == inspect.Parameter.POSITIONAL_ONLY for param in sig.parameters.values())
        if has_positional_only:
            raise ValueError("Callback cannot have positional-only parameters since we only pass named arguments")

        # A ``**kwargs`` parameter absorbs any keyword, so no name can be missing.
        has_kwargs = any(param.kind == inspect.Parameter.VAR_KEYWORD for param in sig.parameters.values())
        if has_kwargs:
            return

        missing_params = expected_param_names - param_names
        if missing_params:
            raise ValueError(f"Callback missing required parameters: {missing_params}")

    @staticmethod
    def make_caching_hook(
        key: UnraveledKey,
        cache: TensorDict | MutableWeakRef | TensorDictRef,
        *,
        callback: Optional[Callable] = None,
        direction: HookDirection = "fwd",
    ) -> Callable:
        """
        Make a caching hook.

        The hook stores a value under ``key`` in ``cache``. Without a callback the
        value is one of the hook's own arguments; with a callback it is whatever
        the callback returns.

        Args:
            key: Key under which the value is written.
            cache: The cache itself, or a reference wrapper resolved on every call.
            callback: Optional function called with the hook arguments by name plus
                ``key`` and ``direction``.
            direction: Hook direction the returned hook is meant for.

        Raises:
            ValueError: If ``direction`` is unknown or the callback signature does
                not accept the keywords it will receive. The hook itself raises
                ``RuntimeError`` for values that are neither a tensor nor a
                TensorDict, and ``ValueError`` when a cache reference is dead.
        """
        if direction not in DIRECTION_TO_PARAMS:
            raise ValueError(f"Invalid direction: {direction}")
        params = DIRECTION_TO_PARAMS[direction]
        # Default value: the last hook argument, except for "fwd_pre_kwargs" where the
        # last argument is ``kwargs`` and the one before it (``args``) is used instead.
        value_index = -2 if direction == "fwd_pre_kwargs" else -1

        # The callback also receives ``key`` and ``direction``; validate them up front so a
        # mismatching callback is rejected here instead of failing when the hook first fires.
        HookFactory._check_callback_signature(callback, set(params) | {"key", "direction"})

        def hook(*args):
            if callback is not None:
                value = callback(**dict(zip(params, args)), key=key, direction=direction)
            else:
                value = args[value_index]

            if not isinstance(value, (torch.Tensor, TensorDict)):
                raise RuntimeError(
                    f"{type(value).__name__} values are not supported for caching, use a `callback` to return a tensor or a tensordict"
                )

            if isinstance(cache, (MutableWeakRef, TensorDictRef)):
                _cache = cache.resolve()
                if _cache is None:
                    raise ValueError("Dead reference to cache")
            else:
                _cache = cache
            _cache[key] = value

        return hook

    @staticmethod
    def make_setting_hook(
        value: Any, *, callback: Optional[Callable] = None, direction: HookDirection = "fwd"
    ) -> Callable:
        """
        Make a setting hook.

        The hook returns ``value`` (resolved first when it is a ``CacheProxy``),
        optionally transformed by ``callback``.

        Args:
            value: Value to return from the hook, or a ``CacheProxy`` to resolve lazily.
            callback: Optional function called with the hook arguments by name plus
                ``value`` and ``direction``; its result replaces the value.
            direction: Hook direction the returned hook is meant for.

        Raises:
            ValueError: If ``direction`` is unknown or the callback signature does
                not accept the keywords it will receive. The hook itself raises
                ``RuntimeError`` when the final value is not ``None`` and its exact
                type differs from the type of the argument it corresponds to.
        """
        if direction not in DIRECTION_TO_PARAMS:
            raise ValueError(f"Invalid direction: {direction}")
        params = DIRECTION_TO_PARAMS[direction]
        return_index = DIRECTION_TO_RETURN_INDEX[direction]

        # The callback also receives ``value`` and ``direction``.
        HookFactory._check_callback_signature(callback, set(params) | {"value", "direction"})

        def hook(*args):
            original_type = type(args[return_index])
            _value = value.resolve() if isinstance(value, CacheProxy) else value
            if callback is not None:
                _value = callback(**dict(zip(params, args)), value=_value, direction=direction)
            if _value is not None and type(_value) is not original_type:
                raise RuntimeError(
                    f"Callback returned a value of type {type(_value).__name__} but the original value was of type {original_type.__name__}"
                )
            return _value

        return hook

    @staticmethod
    def make_reading_hook(*, callback: Callable, direction: HookDirection = "fwd") -> Callable:
        """
        Make a reading hook.

        The hook calls ``callback`` with the hook arguments by name plus
        ``direction`` and discards the result.

        Raises:
            ValueError: If ``direction`` is unknown or the callback signature does
                not accept the keywords it will receive.
        """
        if direction not in DIRECTION_TO_PARAMS:
            raise ValueError(f"Invalid direction: {direction}")
        params = DIRECTION_TO_PARAMS[direction]

        # The callback also receives ``direction``.
        HookFactory._check_callback_signature(callback, set(params) | {"direction"})

        def hook(*args):
            callback(**dict(zip(params, args)), direction=direction)

        return hook

    @staticmethod
    def make_stopping_hook(key: str) -> Callable:
        """Make a hook taking ``(module, args, output)`` that always raises ``EarlyStoppingException(key)``."""

        def hook(module, args, output):
            raise EarlyStoppingException(key)

        return hook