# Source code for tdhook.contexts

from contextlib import contextmanager
from contextlib import ExitStack
from typing import List, Optional, Generator, Dict, overload, Literal
from torch import nn
from tensordict.nn import TensorDictModuleBase, TensorDictModule
from tensordict import TensorDict

from tdhook.modules import HookedModule
from tdhook.hooks import MultiHookHandle, merge_paths
from tdhook._types import UnraveledKey


class HookingContext:
    """
    Base class for hooking contexts.

    Entering the context runs the optional pre-factory contexts, prepares the
    resulting module with the factory, wraps it in a hooked module and
    registers hooks on it. Exiting removes the hooks, restores the module and
    exits the pre-factory contexts.

    Args:
        factory: Factory supplying the prepare / restore / spawn / hook callbacks.
        module: Module to hook. Anything that is not a ``TensorDictModuleBase``
            is wrapped in a ``TensorDictModule`` built from ``in_keys`` /
            ``out_keys`` (defaulting to ``["input"]`` / ``["output"]``).
        in_keys: Input keys; only used when ``module`` has to be wrapped.
        out_keys: Output keys; only used when ``module`` has to be wrapped.
        pre_factories: Factories whose contexts are entered, in order, before
            this one; each receives the module produced by the previous one.
    """

    def __init__(
        self,
        factory: "HookingContextFactory",
        module: nn.Module,
        in_keys: Optional[List[UnraveledKey] | Dict[UnraveledKey, str]] = None,
        out_keys: Optional[List[UnraveledKey]] = None,
        pre_factories: Optional[List["HookingContextFactory"]] = None,
    ):
        self._prepare = factory._prepare_module
        self._restore = factory._restore_module
        self._spawn = factory._spawn_hooked_module
        self._hook = factory._hook_module
        self._in_context = False
        self._handle = None
        self._hooked_module = None
        self._pre_factories = pre_factories or []
        # ExitStack owning the entered pre-factory contexts while inside the context.
        self._stack = None
        # Module that ``_enter`` handed to ``self._prepare`` (i.e. after the
        # pre-factories ran); ``disable`` re-prepares this same module.
        self._working_module = None
        if isinstance(module, TensorDictModuleBase):
            self._module = module
            self._extra_relative_path = ""
        else:
            self._module = TensorDictModule(module, in_keys or ["input"], out_keys or ["output"])
            # Presumably the wrapped nn.Module sits at ``.module`` of the
            # TensorDictModule, hence the extra path segment — verify with tensordict.
            self._extra_relative_path = "module"
        self._in_keys = self._module.in_keys
        self._out_keys = self._module.out_keys
        # Recorded by ``_enter``; not read anywhere in this file.
        self._managed_by_context_manager = False

    def _reset_state(self):
        """Drop all per-entry state so the context can be entered again."""
        self._in_context = False
        self._hooked_module = None
        self._handle = None
        self._working_module = None
        self._stack = None

    def _enter(self, managed_by_context_manager: bool = True):
        """
        Enter the context and return the hooked module.

        If any step fails, the already-entered pre-factory contexts are exited,
        the module is restored if it had already been prepared, and the context
        is left in a re-enterable state before the exception propagates.

        Args:
            managed_by_context_manager: Whether entry happens through a ``with``
                statement (``False`` when the factory returns the module directly).

        Raises:
            RuntimeError: If the context is already entered.
        """
        if self._in_context:
            raise RuntimeError("Cannot enter context twice")
        self._in_context = True
        self._managed_by_context_manager = managed_by_context_manager
        try:
            with ExitStack() as stack:
                working_module = self._module
                for factory in self._pre_factories:
                    working_module = stack.enter_context(
                        factory.prepare(working_module, self._in_keys, self._out_keys)
                    )
                prep_module = self._prepare(working_module, self._in_keys, self._out_keys, self._extra_relative_path)
                try:
                    self._hooked_module = self._spawn(prep_module, self, self._extra_relative_path)
                    self._handle = self._hook(self._hooked_module)
                except BaseException:
                    # The module is already prepared: undo it the same way ``__exit__`` does.
                    self._restore(self._module, self._in_keys, self._out_keys, self._extra_relative_path)
                    raise
                self._working_module = working_module
                # Success: transfer ownership of the pre-factory contexts to ``__exit__``.
                self._stack = stack.pop_all()
        except BaseException:
            self._reset_state()
            raise
        return self._hooked_module

    def __enter__(self):
        """Enter the context through a ``with`` statement and return the hooked module."""
        return self._enter(managed_by_context_manager=True)

    def __exit__(self, exc_type, exc_value, traceback):
        """
        Remove the hooks, restore the module and exit the pre-factory contexts.

        The pre-factory contexts are exited and the state is reset even when
        removing the hooks or restoring the module raises. Exceptions from the
        ``with`` body are never suppressed.

        Raises:
            RuntimeError: If the context is not currently entered.
        """
        if not self._in_context:
            raise RuntimeError("Cannot exit context outside of context")
        stack = self._stack
        try:
            self._handle.remove()
            self._restore(self._module, self._in_keys, self._out_keys, self._extra_relative_path)
        finally:
            self._reset_state()
            stack.__exit__(exc_type, exc_value, traceback)

    @contextmanager
    def disable_hooks(self) -> Generator[None, None, None]:
        """
        Temporarily remove the registered hooks, re-registering them afterwards.

        Raises:
            RuntimeError: If used outside of the context.
        """
        if not self._in_context:
            raise RuntimeError("Cannot disable hooks outside of context")
        self._handle.remove()
        try:
            yield
        finally:
            self._handle = self._hook(self._hooked_module)

    @contextmanager
    def disable(self) -> Generator[nn.Module, None, None]:
        """
        Temporarily remove the hooks and restore the module, yielding the restored module.

        On exit the module that ``_enter`` prepared (including any pre-factory
        wrapping) is prepared again, installed on the hooked module, and the
        hooks are re-registered.

        Raises:
            RuntimeError: If used outside of the context.
        """
        if not self._in_context:
            raise RuntimeError("Cannot disable context outside of context")
        with self.disable_hooks():
            try:
                yield self._restore(
                    self._hooked_module.module, self._in_keys, self._out_keys, self._extra_relative_path
                )
            finally:
                # Re-prepare the same module ``_enter`` prepared; using ``self._module``
                # here would drop the pre-factory layer.
                self._hooked_module.module = self._prepare(
                    self._working_module, self._in_keys, self._out_keys, self._extra_relative_path
                )
class HookingContextWithCache(HookingContext):
    """
    Hooking context that carries a ``TensorDict`` cache.

    Args:
        *args: Positional arguments forwarded to ``HookingContext``.
        cache: Cache to use; when omitted a new empty ``TensorDict`` is created.
        clear_cache: If true, the cache is emptied each time the context is entered.
        **kwargs: Keyword arguments forwarded to ``HookingContext``.
    """

    def __init__(self, *args, cache: Optional[TensorDict] = None, clear_cache: bool = True, **kwargs):
        super().__init__(*args, **kwargs)
        self._cache = cache if cache is not None else TensorDict()
        self._clear_cache = clear_cache

    @property
    def cache(self) -> TensorDict:
        """The ``TensorDict`` owned by this context."""
        return self._cache

    def clear(self):
        """Remove every entry from the cache."""
        self._cache.clear()

    def _enter(self, managed_by_context_manager: bool = True):
        """Optionally empty the cache, then enter the context like the base class."""
        should_clear = self._clear_cache
        if should_clear:
            self.clear()
        return super()._enter(managed_by_context_manager=managed_by_context_manager)

    def __enter__(self):
        """Enter the context through a ``with`` statement and return the hooked module."""
        return self._enter(managed_by_context_manager=True)
class HookingContextFactory:
    """
    Factory for creating hooking contexts.

    The default callbacks are no-ops: ``_prepare_module`` / ``_restore_module``
    return the module unchanged and ``_hook_module`` registers nothing.
    Subclasses customise behaviour by overriding these methods or the
    ``_hooked_module_class`` / ``_hooking_context_class`` attributes.
    """

    _hooked_module_class = HookedModule
    _hooking_context_class = HookingContext

    def __init__(self):
        # Extra keyword arguments passed to the hooking-context constructor.
        self._hooking_context_kwargs = {}
        # Extra keyword arguments passed to the hooked-module constructor.
        self._hooked_module_kwargs = {}

    @overload
    def prepare(
        self,
        module: nn.Module,
        in_keys: Optional[List[UnraveledKey] | Dict[UnraveledKey, str]] = None,
        out_keys: Optional[List[UnraveledKey]] = None,
        *,
        return_context: Literal[True] = True,
    ) -> "HookingContext": ...

    @overload
    def prepare(
        self,
        module: nn.Module,
        in_keys: Optional[List[UnraveledKey] | Dict[UnraveledKey, str]] = None,
        out_keys: Optional[List[UnraveledKey]] = None,
        *,
        return_context: Literal[False],
    ) -> HookedModule: ...

    def prepare(
        self,
        module: nn.Module,
        in_keys: Optional[List[UnraveledKey] | Dict[UnraveledKey, str]] = None,
        out_keys: Optional[List[UnraveledKey]] = None,
        *,
        return_context: bool = True,
    ) -> "HookingContext | HookedModule":
        """
        Prepare the module for execution.

        Args:
            module: The module to prepare.
            in_keys: Optional input keys. When ``module`` is a
                ``TensorDictModuleBase``, each key must be an ``UnraveledKey``
                present in ``module.in_keys``.
            out_keys: Optional output keys, validated like ``in_keys`` against
                ``module.out_keys``.
            return_context: If True (default), returns a context manager.
                If False, returns the hooked module directly.

        Returns:
            If return_context is True, returns a HookingContext that can be used as a context manager.
            If return_context is False, returns the HookedModule directly (context is automatically entered).

        Raises:
            ValueError: If a key is not an ``UnraveledKey`` or is missing from the module's keys.
        """
        if isinstance(module, TensorDictModuleBase):
            for attr_name, requested in (("in_keys", in_keys), ("out_keys", out_keys)):
                if requested is None:
                    continue
                for key in requested:
                    if not isinstance(key, UnraveledKey):
                        raise ValueError(f"{attr_name} must be unraveled, got {type(key)}")
                    if key not in getattr(module, attr_name):
                        raise ValueError(f"Key {key} not in module.{attr_name}")
        context = self._hooking_context_class(self, module, in_keys, out_keys, **self._hooking_context_kwargs)
        return context if return_context else context._enter(managed_by_context_manager=False)

    def _prepare_module(
        self,
        module: TensorDictModuleBase,
        in_keys: List[UnraveledKey],
        out_keys: List[UnraveledKey],
        extra_relative_path: str,
    ) -> TensorDictModuleBase:
        """Return the module to wrap; the default implementation returns it unchanged."""
        return module

    def _restore_module(
        self,
        module: TensorDictModuleBase,
        in_keys: List[UnraveledKey],
        out_keys: List[UnraveledKey],
        extra_relative_path: str,
    ) -> TensorDictModuleBase:
        """Undo ``_prepare_module``; the default implementation returns the module unchanged."""
        return module

    def _spawn_hooked_module(
        self, prep_module: TensorDictModuleBase, hooking_context: "HookingContext", extra_relative_path: str
    ) -> HookedModule:
        """
        Wrap the prepared module in ``_hooked_module_class``.

        The ``relative_path`` keyword (default ``"td_module"``) is merged with
        ``extra_relative_path`` before being forwarded with the other hooked-module kwargs.
        """
        kwargs = dict(self._hooked_module_kwargs)
        kwargs["relative_path"] = merge_paths(kwargs.get("relative_path", "td_module"), extra_relative_path)
        return self._hooked_module_class(prep_module, hooking_context=hooking_context, **kwargs)

    def _hook_module(self, module: HookedModule) -> MultiHookHandle:
        """Register hooks on ``module``; the default registers none and returns an empty handle."""
        return MultiHookHandle()
class CompositeHookingContextFactory(HookingContextFactory):
    """
    Hooking context factory that composes several factories.

    Modules are prepared by each factory in order and restored in reverse
    order. The hooks of all factories are registered and grouped into a
    single ``MultiHookHandle``.

    Args:
        *contexts: The factories to compose.

    Raises:
        ValueError: If a composed factory's class overrides
            ``_spawn_hooked_module``, ``_hooking_context_class`` or
            ``_hooked_module_class`` while this composite's class does not.
    """

    def __init__(self, *contexts: HookingContextFactory):
        super().__init__()
        self._contexts = contexts
        attributes = ("_spawn_hooked_module", "_hooking_context_class", "_hooked_module_class")
        # Only one implementation of each attribute can be used, so a sub-factory
        # override is accepted only when the composite class overrides it as well.
        composite_overridden = {
            attr: getattr(type(self), attr) != getattr(HookingContextFactory, attr) for attr in attributes
        }
        for context in contexts:
            for attr in attributes:
                if (
                    getattr(type(context), attr) != getattr(HookingContextFactory, attr)
                    and not composite_overridden[attr]
                ):
                    raise ValueError(
                        f"Cannot compose factories that override {attr}, consider subclassing this factory to override {attr}"
                    )

    def _prepare_module(
        self,
        module: TensorDictModuleBase,
        in_keys: List[UnraveledKey],
        out_keys: List[UnraveledKey],
        extra_relative_path: str,
    ) -> TensorDictModuleBase:
        """Apply each composed factory's ``_prepare_module`` in order."""
        for context in self._contexts:
            module = context._prepare_module(module, in_keys, out_keys, extra_relative_path)
        return module

    def _restore_module(
        self,
        module: TensorDictModuleBase,
        in_keys: List[UnraveledKey],
        out_keys: List[UnraveledKey],
        extra_relative_path: str,
    ) -> TensorDictModuleBase:
        """Apply each composed factory's ``_restore_module`` in reverse order."""
        for context in reversed(self._contexts):
            module = context._restore_module(module, in_keys, out_keys, extra_relative_path)
        return module

    def _hook_module(self, module: HookedModule) -> MultiHookHandle:
        """
        Register every composed factory's hooks and return them as one handle.

        If a factory fails while registering, the hooks already registered by
        the earlier factories are removed (latest first) before re-raising, so
        no hooks are leaked.
        """
        handles = []
        try:
            for context in self._contexts:
                handles.append(context._hook_module(module))
        except BaseException:
            for handle in reversed(handles):
                handle.remove()
            raise
        return MultiHookHandle(handles)