Source code for tdhook.latent.activation_caching

from typing import Callable, Optional, List

from tensordict import TensorDict

from tdhook.modules import HookedModule
from tdhook.contexts import HookingContextFactory, HookingContextWithCache
from tdhook.hooks import MultiHookManager, HookFactory, HookDirection, MultiHookHandle


class ActivationCaching(HookingContextFactory):
    """
    Maximally activating samples :cite:`Chen2020ConceptWF` and attention visualisation :cite:`Abnar2020QuantifyingAF`.

    Registers caching hooks (built with ``HookFactory.make_caching_hook``) on
    the modules selected by ``key_pattern`` and stores what they capture in
    the cache of the module's hooking context.
    """

    _hooking_context_class = HookingContextWithCache

    def __init__(
        self,
        key_pattern: str,
        relative: bool = True,
        cache: Optional[TensorDict] = None,
        callback: Optional[Callable] = None,
        directions: Optional[List[HookDirection]] = None,
        use_nested_keys: bool = False,
        clear_cache: bool = True,
    ):
        """
        Args:
            key_pattern: Pattern handed to the ``MultiHookManager`` that
                selects which modules get a caching hook.
            relative: When true, ``module.relative_path`` is passed as
                ``relative_path`` at hook registration; otherwise ``None``.
            cache: Forwarded to the hooking context as its ``cache`` kwarg.
            callback: Forwarded to ``HookFactory.make_caching_hook``.
            directions: Hook directions to register; ``["fwd"]`` when
                omitted or empty.
            use_nested_keys: When true, cache keys are ``(direction, name)``
                tuples instead of plain names. Forced on whenever more than
                one direction is requested.
            clear_cache: Forwarded to the hooking context as its
                ``clear_cache`` kwarg.
        """
        super().__init__()
        self._hooking_context_kwargs["cache"] = cache
        self._hooking_context_kwargs["clear_cache"] = clear_cache

        self._key_pattern = key_pattern
        self._relative = relative
        self._hook_manager = MultiHookManager(key_pattern)
        self._callback = callback
        self._directions = directions or ["fwd"]
        # Several directions need distinct keys per direction, so nesting is
        # enabled regardless of the flag in that case.
        self._use_nested_keys = use_nested_keys or len(self._directions) > 1

    @property
    def key_pattern(self) -> str:
        """Pattern used to select the modules whose activations are cached."""
        return self._key_pattern

    @key_pattern.setter
    def key_pattern(self, key_pattern: str):
        # Keep the hook manager's pattern in sync with the stored one.
        self._key_pattern = key_pattern
        self._hook_manager.pattern = key_pattern

    def _hook_module(self, module: HookedModule) -> MultiHookHandle:
        """
        Register one set of caching hooks per configured direction.

        Args:
            module: Hooked module whose hooking context provides the cache.

        Returns:
            A ``MultiHookHandle`` wrapping the handle returned by each
            ``register_hook`` call, in the order of the configured directions.
        """
        cache = module.hooking_context.cache

        def hook_factory(name: str, direction: HookDirection) -> Callable:
            key = (direction, name) if self._use_nested_keys else name
            return HookFactory.make_caching_hook(key, cache, direction=direction, callback=self._callback)

        def bind_direction(direction: HookDirection) -> Callable:
            # Bind `direction` per call: a lambda written directly in the loop
            # would look the loop variable up when invoked, so a factory that
            # is called after the loop advanced would see the last direction.
            return lambda name: hook_factory(name, direction)

        relative_path = module.relative_path if self._relative else None
        handles = [
            self._hook_manager.register_hook(
                module,
                bind_direction(direction),
                direction=direction,
                relative_path=relative_path,
            )
            for direction in self._directions
        ]
        return MultiHookHandle(handles)