Source code for tdhook.modules

from torch.utils.hooks import RemovableHandle
from tensordict.nn import TensorDictModule, TensorDictModuleWrapper, TensorDictModuleBase
from tensordict import TensorDict, NonTensorData
from typing import Callable, Any, Optional, Tuple, TYPE_CHECKING, List
import torch
import warnings
import torch.nn as nn
from contextlib import contextmanager
from textwrap import indent

from tdhook.hooks import (
    register_hook_to_module,
    CacheProxy,
    HookFactory,
    EarlyStoppingException,
    resolve_submodule_path,
    HookDirection,
    DIRECTION_TO_TYPE,
    MutableWeakRef,
    TensorDictRef,
)
from tdhook._types import UnraveledKey

if TYPE_CHECKING:
    from tdhook.contexts import HookingContext


def get_best_device():
    """
    Return the preferred available torch device.

    CUDA is chosen when available, otherwise MPS, otherwise the CPU.

    Returns:
        torch.device: The selected device.
    """
    # The MPS check is only reached when CUDA is unavailable.
    if torch.cuda.is_available():
        device_type = "cuda"
    elif torch.backends.mps.is_available():
        device_type = "mps"
    else:
        device_type = "cpu"
    return torch.device(device_type)
def flatten_select_reshape_call(
    module: TensorDictModuleBase, td: TensorDict, flatten: bool = True, select: bool = True, reshape: bool = True
) -> TensorDict:
    """
    Call ``module`` on ``td`` with optional flatten, select and reshape steps.

    Args:
        module: Module to call; its ``out_keys`` are used by the select step.
        td: Input data.
        flatten: If True, ``module`` receives ``td.flatten()`` instead of ``td``.
        select: If True, only ``module.out_keys`` are kept from the module output.
        reshape: If True, the result is reshaped to ``td.shape``.

    Returns:
        The (optionally selected and reshaped) output of ``module``.
    """
    result = module(td.flatten() if flatten else td)
    if select:
        result = result.select(*module.out_keys)
    if reshape:
        # Restore the shape of the original (un-flattened) input.
        result = result.reshape(td.shape)
    return result
class FunctionModule(TensorDictModuleBase):
    """
    Wrapper for a function to be used as a module.

    Args:
        td_fn: Callable taking a TensorDict and returning a TensorDict; ``forward`` delegates to it.
        in_keys: Keys exposed as the module's ``in_keys``.
        out_keys: Keys exposed as the module's ``out_keys``.
    """

    def __init__(
        self, td_fn: Callable[[TensorDict], TensorDict], in_keys: List[UnraveledKey], out_keys: List[UnraveledKey]
    ):
        super().__init__()
        self.in_keys = in_keys
        self.out_keys = out_keys
        self._td_fn = td_fn

    def forward(self, td: TensorDict) -> TensorDict:
        """Apply the wrapped function to ``td`` and return its result."""
        return self._td_fn(td)

    def __repr__(self):
        """Return a multi-line representation listing the keys and the wrapped function."""
        parts = (
            f"in_keys={self.in_keys}",
            f"out_keys={self.out_keys}",
            f"td_fn={self._td_fn}",
        )
        body = indent(",\n".join(parts), " " * 4)
        return f"{type(self).__name__}(\n{body})"
class ModuleCall(TensorDictModuleBase):
    """
    Wrapper to manage module calls.

    Args:
        td_module: Module to call.
        in_key: If given, inputs are read from ``td[in_key]`` and the exposed ``in_keys`` are prefixed with it.
        out_key: If given, outputs are written under ``td[out_key]`` and the exposed ``out_keys`` are prefixed
            with it.
        flatten: Forwarded to ``flatten_select_reshape_call``.
    """

    def __init__(
        self,
        td_module: TensorDictModuleBase,
        in_key: Optional[UnraveledKey] = None,
        out_key: Optional[UnraveledKey] = None,
        flatten: bool = True,
    ):
        super().__init__()
        if in_key is None:
            self.in_keys = list(td_module.in_keys)
        else:
            self.in_keys = [(in_key, k) for k in td_module.in_keys]
        if out_key is None:
            self.out_keys = list(td_module.out_keys)
        else:
            self.out_keys = [(out_key, k) for k in td_module.out_keys]
        self._td_module = td_module
        self._in_key = in_key
        self._out_key = out_key
        self._flatten = flatten

    def forward(self, td: TensorDict) -> TensorDict:
        """
        Run the wrapped module and store its outputs in ``td``.

        Without ``out_key`` the outputs are merged into ``td`` directly. With ``out_key``, they are merged into
        the existing entry when that entry is a TensorDict, and assigned to ``td[out_key]`` otherwise.

        Returns:
            The same ``td``, updated in place.
        """
        inputs = td[self._in_key] if self._in_key is not None else td
        outputs = flatten_select_reshape_call(self._td_module, inputs, flatten=self._flatten)

        if self._out_key is None:
            td.update(outputs)
            return td

        # NOTE(review): the result of ``td.get`` for a missing key depends on the tensordict version - confirm.
        existing = td.get(self._out_key)
        if isinstance(existing, TensorDict):
            existing.update(outputs)
        else:
            td[self._out_key] = outputs
        return td

    def __repr__(self):
        """Return a multi-line representation listing the wrapped module and the exposed keys."""
        parts = (
            f"td_module={self._td_module}",
            f"in_keys={self.in_keys}",
            f"out_keys={self.out_keys}",
        )
        body = indent(",\n".join(parts), " " * 4)
        return f"{type(self).__name__}(\n{body})"
class ModuleCallWithCache(TensorDictModuleBase):
    """
    Wrapper to manage module calls with cache.

    Args:
        td_module: Module to call.
        stored_keys: Cache keys advertised in ``out_keys`` when ``cache_as_output`` is True.
        cache_key: If given, the cache is written under ``td[cache_key]`` and the advertised stored keys are
            prefixed with it.
        in_key: If given, inputs are read from ``td[in_key]``.
        out_key: If given, module outputs are written to ``td[out_key]``.
        cache_ref: Reference object that receives the fresh cache on every call; a new ``MutableWeakRef`` is
            created when the argument is falsy.
        flatten: Forwarded to ``flatten_select_reshape_call``.
        cache_as_output: If True, the cache is copied into ``td``; otherwise only the input shape is recorded in
            the cache under ``"_shape"``.
    """

    def __init__(
        self,
        td_module: TensorDictModuleBase,
        stored_keys: List[UnraveledKey],
        cache_key: Optional[UnraveledKey] = None,
        in_key: Optional[UnraveledKey] = None,
        out_key: Optional[UnraveledKey] = None,
        cache_ref: Optional[MutableWeakRef | TensorDictRef] = None,
        flatten: bool = True,
        cache_as_output: bool = True,
    ):
        super().__init__()
        self.in_keys = [k if in_key is None else (in_key, k) for k in td_module.in_keys]

        module_out_keys = [k if out_key is None else (out_key, k) for k in td_module.out_keys]
        if cache_as_output:
            cached_out_keys = [k if cache_key is None else (cache_key, k) for k in stored_keys]
            self.out_keys = module_out_keys + cached_out_keys
        else:
            self.out_keys = module_out_keys

        self._td_module = td_module
        self._cache_key = cache_key
        self._in_key = in_key
        self._out_key = out_key
        self._flatten = flatten
        self._cache_as_output = cache_as_output
        self._cache_ref = cache_ref or MutableWeakRef(TensorDict())

    @property
    def cache_ref(self) -> MutableWeakRef | TensorDictRef:
        """Reference object through which the current cache is published."""
        return self._cache_ref

    def forward(self, td: TensorDict) -> TensorDict:
        """
        Run the wrapped module with a fresh cache published through ``cache_ref``.

        Returns:
            The same ``td``, updated in place with the module outputs and, when ``cache_as_output`` is True,
            with the cache reshaped to the input shape.
        """
        inputs = td[self._in_key] if self._in_key is not None else td

        # A new flattened cache is created per call and handed to the reference before the module runs.
        # NOTE(review): presumably hooks holding the same reference fill it during the call - confirm.
        cache = TensorDict(batch_size=inputs.batch_size, device=inputs.device).flatten()
        self._cache_ref.set(cache)

        outputs = flatten_select_reshape_call(self._td_module, inputs, flatten=self._flatten)
        if self._out_key is None:
            td.update(outputs)
        else:
            td[self._out_key] = outputs

        if not self._cache_as_output:
            # The cache stays flat; keep the original shape alongside it.
            cache["_shape"] = NonTensorData(tuple(inputs.shape))
        elif self._cache_key is None:
            td.update(cache.reshape(inputs.shape))
        else:
            td[self._cache_key] = cache.reshape(inputs.shape)
        return td

    def __repr__(self):
        """Return a multi-line representation listing the wrapped module and the exposed keys."""
        parts = (
            f"td_module={self._td_module}",
            f"in_keys={self.in_keys}",
            f"out_keys={self.out_keys}",
        )
        body = indent(",\n".join(parts), " " * 4)
        return f"{type(self).__name__}(\n{body})"
class PGDModule(TensorDictModuleBase):
    """
    Wrapper to manage PGD module calls.

    Each step calls the wrapped module and then moves every entry that has a gradient by ``alpha`` times the
    (optionally negated, optionally sign-reduced) gradient, clamping the result to ``[min_value, max_value]``.

    Args:
        td_module: Module called at every step.
            NOTE(review): presumably it writes the gradients under ``grad_key`` - confirm.
        alpha: Step size.
        n_steps: Number of steps.
        min_value: Lower clamp bound.
        max_value: Upper clamp bound.
        grad_key: Key under which the gradients are read after each module call.
        working_key: If not None, the steps operate on ``td[working_key]`` and the result is stored back there.
        ascent: If True, the gradients are negated before the update.
        use_sign: If True, ``torch.sign`` is applied to the gradients before the update.
    """

    def __init__(
        self,
        td_module: TensorDictModuleBase,
        alpha: float = 0.1,
        n_steps: int = 10,
        min_value: float = -float("Inf"),
        max_value: float = float("Inf"),
        grad_key: UnraveledKey = "_grad",
        working_key: UnraveledKey = "_working",
        ascent: bool = False,
        use_sign: bool = True,
    ):
        super().__init__()
        self._td_module = td_module
        self.in_keys = td_module.in_keys
        if working_key is None:
            self.out_keys = list(td_module.out_keys)
        else:
            self.out_keys = [(working_key, k) for k in td_module.out_keys]
        self._alpha = alpha
        self._n_steps = n_steps
        self._min_value = min_value
        self._max_value = max_value
        self._grad_key = grad_key
        self._working_key = working_key
        self._ascent = ascent
        self._use_sign = use_sign

    def forward(self, td: TensorDict) -> TensorDict:
        """
        Run ``n_steps`` iterations of module call followed by an update step.

        Returns:
            The same ``td``, with the final working data stored under ``working_key`` (or merged into ``td``
            when ``working_key`` is None).
        """
        nested = self._working_key is not None
        working_td = td[self._working_key] if nested else td

        for _ in range(self._n_steps):
            working_td = self._pgd_step(self._td_module(working_td))

        if nested:
            td[self._working_key] = working_td
        else:
            td.update(working_td)
        return td

    def _pgd_step(self, td: TensorDict) -> TensorDict:
        """Apply one clamped update to every entry of ``td`` that has a gradient under ``grad_key``."""
        grads: TensorDict = td[self._grad_key]
        if self._ascent:
            grads = -grads
        if self._use_sign:
            grads = torch.sign(grads)

        # NOTE(review): ``keys(True, True)`` presumably means include_nested / leaves_only - confirm.
        for key in grads.keys(True, True):
            stepped = td[key] - self._alpha * grads[key]
            td[key] = torch.clamp(stepped, min=self._min_value, max=self._max_value)
        return td

    def __repr__(self):
        """Return a multi-line representation listing the wrapped module and the exposed keys."""
        parts = (
            f"td_module={self._td_module}",
            f"in_keys={self.in_keys}",
            f"out_keys={self.out_keys}",
        )
        # The trailing separator puts the closing parenthesis on its own line.
        body = indent(",\n".join(parts) + ",\n", " " * 4)
        return f"{type(self).__name__}(\n{body})"
class IntermediateKeysCleaner(TensorDictModuleBase):
    """
    Wrapper to clean intermediate keys.

    Args:
        intermediate_keys: Keys to exclude from the data; also exposed as the module's ``in_keys``.
    """

    def __init__(self, intermediate_keys: List[UnraveledKey]):
        super().__init__()
        self.in_keys = intermediate_keys
        self.out_keys = []
        self._intermediate_keys = intermediate_keys

    def forward(self, td: TensorDict) -> TensorDict:
        """Return the result of ``td.exclude`` called with the intermediate keys."""
        unwanted = self._intermediate_keys
        return td.exclude(*unwanted)

    def __repr__(self):
        """Return a multi-line representation listing the exposed keys."""
        parts = (f"in_keys={self.in_keys}", f"out_keys={self.out_keys}")
        body = indent(",\n".join(parts), " " * 4)
        return f"{type(self).__name__}(\n{body})"
class HookedModuleRun:
    """
    Context manager to execute module runs.

    Hooks are registered inside the ``with`` block through ``set`` / ``get`` / ``save`` / ``stop`` and their
    direction-specific variants. When the block exits normally, ``run_callback(module, data)`` is executed and
    every hook registered through this object is removed. When the block raises, the run is skipped, the hooks
    are still removed and the original exception propagates.

    Args:
        module: The hooked module to run.
        data: Data passed to ``run_callback``.
        cache: Optional outer cache; when given, ``save`` writes into it instead of the run cache.
        run_name: Prefix used in the default cache keys of ``save`` (defaults to ``"run"``).
        run_sep: Separator between the run name and the module key (defaults to ``"."``).
        run_cache: Cache used by ``get``; a new ``TensorDict`` is created when None.
        grad_enabled: Whether gradients are enabled during the run; the ``*_grad`` methods switch it on.
        run_callback: Callable ``(module, data)`` executed on exit; defaults to calling the module on the data.
    """

    def __init__(
        self,
        module: "HookedModule",
        data: TensorDict,
        cache: Optional[TensorDict] = None,
        run_name: Optional[str] = None,
        run_sep: Optional[str] = None,
        run_cache: Optional[TensorDict] = None,
        grad_enabled: bool = False,
        run_callback: Optional[Callable] = None,
    ):
        self._module = module
        self._data = data
        self._outer_cache = cache
        self._name = run_name or "run"
        self._sep = run_sep or "."
        self._cache = TensorDict() if run_cache is None else run_cache
        self._grad_enabled = grad_enabled
        self._run_callback = run_callback or (lambda module, data: module(data))
        self._save_cache = self._cache if self._outer_cache is None else self._outer_cache
        self._handles = []
        self._in_context = False

    @property
    def cache(self) -> TensorDict:
        """Cache used by ``get``."""
        return self._cache

    @cache.setter
    def cache(self, cache: TensorDict):
        self._cache = cache

    def __enter__(self):
        self._in_context = True
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """
        Execute the run (unless the ``with`` body raised) and remove every registered hook.

        An ``EarlyStoppingException`` raised by the run is swallowed; any other exception propagates.
        """
        try:
            # Do not run a half-configured module when the body already failed: it could raise again and
            # mask the original exception.
            if exc_type is None:
                with torch.set_grad_enabled(self._grad_enabled):
                    self._run_callback(self._module, self._data)
        except EarlyStoppingException:
            pass
        finally:
            for handle in self._handles:
                handle.remove()
            # Drop the removed handles so that re-entering this run starts from a clean list.
            self._handles.clear()
            self._in_context = False

    def _ensure_in_context(self, method: str):
        """Raise ``RuntimeError`` when ``method`` is called outside of the ``with`` block."""
        if not self._in_context:
            raise RuntimeError(f"Not in context, method {method} must be called in context or directly on the module")

    def set(
        self,
        key: str,
        value: Any,
        *,
        callback: Optional[Callable] = None,
        direction: HookDirection = "fwd",
        prepend: bool = False,
        relative: bool = True,
    ) -> None:
        """Register a setting hook on the module for the duration of the run."""
        self._ensure_in_context("set")
        handle = self._module.set(
            key, value, callback=callback, direction=direction, prepend=prepend, relative=relative
        )
        self._handles.append(handle)

    def get(
        self,
        key: str,
        *,
        cache_key: Optional[str] = None,
        callback: Optional[Callable] = None,
        direction: HookDirection = "fwd",
        prepend: bool = False,
        relative: bool = True,
    ) -> CacheProxy:
        """Register a caching hook writing into the run cache and return the proxy provided by the module."""
        self._ensure_in_context("get")
        handle, proxy = self._module.get(
            self._cache, key, cache_key, callback=callback, direction=direction, prepend=prepend, relative=relative
        )
        self._handles.append(handle)
        return proxy

    def save(
        self,
        key: str,
        *,
        cache_key: Optional[str] = None,
        callback: Optional[Callable] = None,
        direction: HookDirection = "fwd",
        prepend: bool = False,
        relative: bool = True,
    ) -> CacheProxy:
        """
        Register a caching hook writing into the save cache (the outer cache when one was given).

        The default cache key is built from the run name, the separator, the module key and the direction type.
        """
        self._ensure_in_context("save")
        cache_key = cache_key or f"{self._name + self._sep + key}_{DIRECTION_TO_TYPE[direction]}"
        handle, proxy = self._module.get(
            self._save_cache,
            key,
            cache_key=cache_key,
            callback=callback,
            direction=direction,
            prepend=prepend,
            relative=relative,
        )
        self._handles.append(handle)
        return proxy

    # TODO: rename grad_input
    def set_grad(self, *args, **kwargs):
        """Like ``set`` with ``direction="bwd"``; also enables gradients for the run."""
        self._ensure_in_context("set_grad")
        self._grad_enabled = True
        kwargs["direction"] = "bwd"
        self.set(*args, **kwargs)

    def get_grad(self, *args, **kwargs):
        """Like ``get`` with ``direction="bwd"``; also enables gradients for the run."""
        self._ensure_in_context("get_grad")
        self._grad_enabled = True
        kwargs["direction"] = "bwd"
        return self.get(*args, **kwargs)

    def save_grad(self, *args, **kwargs):
        """Like ``save`` with ``direction="bwd"``; also enables gradients for the run."""
        self._ensure_in_context("save_grad")
        self._grad_enabled = True
        kwargs["direction"] = "bwd"
        return self.save(*args, **kwargs)

    def set_input(self, *args, **kwargs):
        """Like ``set`` with ``direction="fwd_pre"``."""
        kwargs["direction"] = "fwd_pre"
        self.set(*args, **kwargs)

    def get_input(self, *args, **kwargs):
        """Like ``get`` with ``direction="fwd_pre"``."""
        kwargs["direction"] = "fwd_pre"
        return self.get(*args, **kwargs)

    def save_input(self, *args, **kwargs):
        """Like ``save`` with ``direction="fwd_pre"``."""
        kwargs["direction"] = "fwd_pre"
        return self.save(*args, **kwargs)

    def set_grad_output(self, *args, **kwargs):
        """Like ``set`` with ``direction="bwd_pre"``."""
        kwargs["direction"] = "bwd_pre"
        self.set(*args, **kwargs)

    def get_grad_output(self, *args, **kwargs):
        """Like ``get`` with ``direction="bwd_pre"``."""
        kwargs["direction"] = "bwd_pre"
        return self.get(*args, **kwargs)

    def save_grad_output(self, *args, **kwargs):
        """Like ``save`` with ``direction="bwd_pre"``."""
        kwargs["direction"] = "bwd_pre"
        return self.save(*args, **kwargs)

    def stop(self, key: str) -> None:
        """Register a stopping hook on the module for the duration of the run."""
        self._ensure_in_context("stop")
        handle = self._module.stop(key)
        self._handles.append(handle)
class HookedModule(TensorDictModuleWrapper):
    """
    Wrapper to enhance a module with hooking capabilities.

    Args:
        td_module: The module to wrap.
        hooking_context: Optional context this module belongs to; when set, the module can only be called while
            the context is active.
        relative_path: Path (resolved from this wrapper) of the submodule used as root for relative hook keys.
    """

    def __init__(
        self,
        td_module: TensorDictModule,
        hooking_context: Optional["HookingContext"] = None,
        relative_path: str = "td_module",
    ):
        super().__init__(td_module)
        self._hooking_context = hooking_context
        self._relative_path = relative_path

    @property
    def relative_path(self) -> str:
        """Path of the root submodule used to resolve relative hook keys."""
        return self._relative_path

    def __repr__(self):
        fields = indent(
            f"td_module={self.td_module},\nin_keys={self.in_keys},\nout_keys={self.out_keys}",
            4 * " ",
        )
        return f"{type(self).__name__}(\n{fields})"

    @property
    def hooking_context(self) -> Optional["HookingContext"]:
        """The hooking context attached to this module, if any."""
        return self._hooking_context

    @classmethod
    def from_module(
        cls,
        module: Callable,
        in_keys: List[str],
        out_keys: List[str],
        *,
        hooking_context: Optional["HookingContext"] = None,
        **kwargs,
    ) -> "HookedModule":
        """Wrap ``module`` in a ``TensorDictModule`` (extra ``kwargs`` are forwarded to it) and hook it."""
        td_module = TensorDictModule(module, in_keys, out_keys, **kwargs)
        return cls(td_module, hooking_context=hooking_context)

    def run(
        self,
        data: TensorDict,
        cache: Optional[TensorDict] = None,
        run_name: Optional[str] = None,
        run_sep: Optional[str] = None,
        run_cache: Optional[TensorDict] = None,
        grad_enabled: bool = False,
        run_callback: Optional[Callable] = None,
    ) -> HookedModuleRun:
        """Create a ``HookedModuleRun`` for this module; all arguments are forwarded to it."""
        return HookedModuleRun(self, data, cache, run_name, run_sep, run_cache, grad_enabled, run_callback)

    def register_submodule_hook(
        self,
        key: str,
        hook: Callable,
        direction: HookDirection,
        prepend: bool = False,
        relative: bool = True,
    ) -> RemovableHandle:
        """
        Register ``hook`` on the submodule designated by ``key``.

        Args:
            key: Path of the submodule, resolved from the ``relative_path`` root when ``relative`` is True and
                from this wrapper otherwise.
            hook: The hook to register.
            direction: Hook direction, forwarded to ``register_hook_to_module``.
            prepend: Forwarded to ``register_hook_to_module``.
            relative: Whether ``key`` is relative to the ``relative_path`` root.

        Returns:
            The value returned by ``register_hook_to_module``, used by callers as a removable handle.
        """
        root = resolve_submodule_path(self, self._relative_path) if relative else self
        submodule = resolve_submodule_path(root, key)
        if isinstance(submodule, nn.ModuleList):
            warnings.warn(f"You are hooking a ModuleList ({key}), which will never be executed.")
        return register_hook_to_module(submodule, hook, direction, prepend)

    def set(
        self,
        module_key: str,
        value: Any,
        callback: Optional[Callable] = None,
        direction: HookDirection = "fwd",
        prepend: bool = False,
        relative: bool = True,
    ) -> RemovableHandle:
        """Register a setting hook built by ``HookFactory.make_setting_hook`` and return its handle."""
        return self.register_submodule_hook(
            key=module_key,
            hook=HookFactory.make_setting_hook(value, callback=callback, direction=direction),
            direction=direction,
            prepend=prepend,
            relative=relative,
        )

    def get(
        self,
        cache: TensorDict,
        module_key: str,
        cache_key: Optional[str] = None,
        callback: Optional[Callable] = None,
        direction: HookDirection = "fwd",
        prepend: bool = False,
        relative: bool = True,
    ) -> Tuple[RemovableHandle, CacheProxy]:
        """
        Register a caching hook built by ``HookFactory.make_caching_hook``.

        The default cache key is ``module_key`` suffixed with the direction type.

        Returns:
            The hook handle and a ``CacheProxy`` for the cache key.
        """
        cache_key = cache_key or f"{module_key}_{DIRECTION_TO_TYPE[direction]}"
        proxy = CacheProxy(cache_key, cache)
        handle = self.register_submodule_hook(
            key=module_key,
            hook=HookFactory.make_caching_hook(cache_key, cache, callback=callback, direction=direction),
            direction=direction,
            prepend=prepend,
            relative=relative,
        )
        return handle, proxy

    def set_input(self, *args, **kwargs):
        """Like ``set`` with ``direction="fwd_pre"``."""
        kwargs["direction"] = "fwd_pre"
        return self.set(*args, **kwargs)

    def get_input(self, *args, **kwargs):
        """Like ``get`` with ``direction="fwd_pre"``."""
        kwargs["direction"] = "fwd_pre"
        return self.get(*args, **kwargs)

    def set_grad(self, *args, **kwargs):
        """Like ``set`` with ``direction="bwd"``."""
        kwargs["direction"] = "bwd"
        return self.set(*args, **kwargs)

    def get_grad(self, *args, **kwargs):
        """Like ``get`` with ``direction="bwd"``."""
        kwargs["direction"] = "bwd"
        return self.get(*args, **kwargs)

    def set_grad_output(self, *args, **kwargs):
        """Like ``set`` with ``direction="bwd_pre"``."""
        kwargs["direction"] = "bwd_pre"
        return self.set(*args, **kwargs)

    def get_grad_output(self, *args, **kwargs):
        """Like ``get`` with ``direction="bwd_pre"``."""
        kwargs["direction"] = "bwd_pre"
        return self.get(*args, **kwargs)

    def stop(self, key: str) -> RemovableHandle:
        """Register a forward stopping hook on the submodule at ``key`` and return its handle."""
        return self.register_submodule_hook(
            key=key,
            hook=HookFactory.make_stopping_hook(key),
            direction="fwd",
        )

    def forward(self, *args, **kwargs):
        """
        Call the wrapped module.

        Raises:
            RuntimeError: If a hooking context is attached and it is not currently active.
        """
        if self._hooking_context is not None and not self._hooking_context._in_context:
            raise RuntimeError("Contextual HookedModule must be called in context")
        return self.td_module(*args, **kwargs)

    @contextmanager
    def disable_context_hooks(self):
        """
        Context manager delegating to ``hooking_context.disable_hooks()``.

        Raises:
            RuntimeError: If no hooking context is attached.
        """
        if self._hooking_context is None:
            raise RuntimeError("No hooking context provided to this module")
        with self._hooking_context.disable_hooks():
            yield

    @contextmanager
    def disable_context(self):
        """
        Context manager delegating to ``hooking_context.disable()`` and yielding what it provides.

        Raises:
            RuntimeError: If no hooking context is attached.
        """
        if self._hooking_context is None:
            raise RuntimeError("No hooking context provided to this module")
        with self._hooking_context.disable() as raw_module:
            yield raw_module

    def restore(self):
        """
        Restore the module to its original state.

        This is useful when using prepare(return_context=False) instead of the context manager.

        Raises:
            RuntimeError: If no hooking context is attached, if the context is not active, or if the context
                is managed by a context manager.
        """
        if self._hooking_context is None:
            raise RuntimeError("No hooking context provided to this module")
        if not self._hooking_context._in_context:
            raise RuntimeError("Context is not active")
        if self._hooking_context._managed_by_context_manager:
            raise RuntimeError("Cannot call restore() when context is managed by a context manager. ")
        self._hooking_context.__exit__(None, None, None)