Source code for tdhook.weights.adapters

from typing import Callable, Optional, List, Dict, Tuple
from torch import nn
from tensordict import TensorDict

from tdhook.contexts import HookingContextFactory, HookingContextWithCache
from tdhook.modules import HookedModule
from tdhook.hooks import DIRECTION_TO_RETURN, MultiHookHandle, HookDirection


class HookedModuleWithAdapters(HookedModule):
    """Hooked module that additionally owns a set of named adapter modules.

    The adapters are stored in an ``nn.ModuleDict`` under ``self.adapters``.
    """

    def __init__(self, *args, adapters: Dict[str, nn.Module], **kwargs) -> None:
        """Build the hooked module, then attach the adapters.

        Args:
            *args: Positional arguments forwarded unchanged to ``HookedModule``.
            adapters: Mapping from adapter name to adapter module.
            **kwargs: Keyword arguments forwarded unchanged to ``HookedModule``.
        """
        # The parent initialiser runs first, exactly as in the original order,
        # before any submodule container is assigned on ``self``.
        super().__init__(*args, **kwargs)
        self.adapters = nn.ModuleDict(adapters)
class Adapters(HookingContextFactory):
    """
    Hooking-context factory that wires adapter modules into a hooked module.

    Each entry of ``adapters`` is a ``(adapter, in_module_key, out_module_key)``
    tuple. For every requested direction, the adapter is called from a ``set``
    hook registered on ``out_module_key``. Its input is either the value
    travelling through that same hook (when both keys are equal) or the value
    resolved from a ``get`` hook registered on ``in_module_key``.

    ROME :cite:`Meng2022LocatingAE`, sparse autoencoders :cite:`Cunningham2023SparseAF` and
    transcoders :cite:`Dunefsky2024TranscodersFI`.
    """

    _hooked_module_class = HookedModuleWithAdapters
    _hooking_context_class = HookingContextWithCache

    def __init__(
        self,
        adapters: Dict[str, Tuple[nn.Module, str, str]],
        cache_callback: Optional[Callable] = None,
        relative: bool = True,
        directions: Optional[List[HookDirection]] = None,
        cache: Optional[TensorDict] = None,
        clear_cache: bool = True,
    ):
        """Store the adapter wiring and the options forwarded to the hooks.

        Args:
            adapters: Mapping from adapter name to a
                ``(adapter, in_module_key, out_module_key)`` tuple.
            cache_callback: Passed as ``callback`` to ``module.get`` when an
                adapter reads from a module key different from the one it
                writes to.
            relative: Passed as ``relative`` to both ``module.get`` and
                ``module.set``.
            directions: Hook directions to register; ``["fwd"]`` is used when
                this is ``None`` or empty.
            cache: Passed to the hooking context as its ``cache`` keyword.
            clear_cache: Passed to the hooking context as its ``clear_cache``
                keyword.
        """
        super().__init__()

        # Only the adapter modules themselves (first tuple element) are handed
        # to the hooked module; the module keys stay on the factory.
        self._hooked_module_kwargs["adapters"] = {k: v[0] for k, v in adapters.items()}
        self._hooking_context_kwargs["clear_cache"] = clear_cache
        self._hooking_context_kwargs["cache"] = cache

        self._adapters = adapters
        self._cache_callback = cache_callback
        self._relative = relative
        self._directions = directions or ["fwd"]

    def _hook_module(self, module: HookedModule) -> MultiHookHandle:
        """Register the ``get``/``set`` hooks for every adapter and direction.

        Args:
            module: The hooked module to register hooks on; its hooking
                context's cache is passed to every ``get`` hook.

        Returns:
            A ``MultiHookHandle`` wrapping every handle created here.
        """
        cache = module.hooking_context.cache

        def callback_factory(adapter, cache_proxy=None):
            import inspect

            # The accepted parameter names depend only on the adapter, so they
            # are computed once here instead of on every hook invocation.
            # NOTE(review): for an ``nn.Module`` instance ``inspect.signature``
            # reflects ``__call__`` rather than ``forward`` - confirm this is
            # the intended filter.
            adapter_params = inspect.signature(adapter).parameters

            def callback(**kwargs):
                if cache_proxy is not None:
                    # Input comes from the value captured at ``in_module_key``.
                    adapter_input = cache_proxy.resolve()
                else:
                    # Same module for input and output: take the value carried
                    # by this hook and remove it from the forwarded kwargs.
                    adapter_input = kwargs.pop(DIRECTION_TO_RETURN[kwargs["direction"]])

                # Filter kwargs to only those accepted by the adapter
                filtered_kwargs = {k: v for k, v in kwargs.items() if k in adapter_params}
                return adapter(adapter_input, **filtered_kwargs)

            return callback

        handles = []
        for direction in self._directions:
            for adapter, in_module_key, out_module_key in self._adapters.values():
                if in_module_key == out_module_key:
                    cache_proxy = None
                else:
                    handle, cache_proxy = module.get(
                        cache=cache,
                        module_key=in_module_key,
                        callback=self._cache_callback,
                        direction=direction,
                        relative=self._relative,
                    )
                    handles.append(handle)

                handle = module.set(
                    module_key=out_module_key,
                    value=None,
                    callback=callback_factory(adapter, cache_proxy),
                    direction=direction,
                    relative=self._relative,
                )
                handles.append(handle)

        return MultiHookHandle(handles)