Source code for tdhook.latent.activation_patching
from typing import Callable, Optional, List
from tensordict.nn import TensorDictModuleBase, TensorDictSequential
from tdhook.contexts import HookingContextFactory
from tdhook.hooks import MultiHookHandle
from tdhook.modules import HookedModule, ModuleCallWithCache, IntermediateKeysCleaner, ModuleCall
from tdhook._types import UnraveledKey
[docs]
class ActivationPatching(HookingContextFactory):
    """
    Causal mediation analysis :cite:`Vig2020InvestigatingGB` and latent editing :cite:`belrose2023leace,Dreyer2023FromHT`.

    The prepared module calls the wrapped module twice in sequence: once
    through ``ModuleCallWithCache`` (cache stored under ``"_cache"``) and once
    through ``ModuleCall`` reading from and writing to ``patch_key``. For every
    name in ``modules_to_patch`` a getter hook and a forward setter hook are
    registered; the setter decides, via ``patch_fn``, which output the module
    finally returns.

    Args:
        modules_to_patch: Names of the submodules whose outputs are cached
            and patched.
        patch_key: Key used as both input and output of the second
            (patched) module call.
        clean_intermediate_keys: When true, an ``IntermediateKeysCleaner`` for
            the ``"_cache"`` key is appended to the prepared module.
        patch_fn: Optional callable invoked with keyword arguments
            ``module_key``, ``output`` and ``output_to_patch``. If it returns
            ``None`` the cached value is used as the module output.
        cache_callback: Optional callback forwarded to ``module.get`` when the
            caching hooks are registered.
    """

    def __init__(
        self,
        modules_to_patch: List[str],
        patch_key: UnraveledKey = "patched",
        clean_intermediate_keys: bool = True,
        patch_fn: Optional[Callable] = None,
        cache_callback: Optional[Callable] = None,
    ):
        super().__init__()
        # The base constructor is called without arguments, so the flag has to
        # be stored here; otherwise the caller's choice never reaches
        # `_prepare_module`, which reads `self._clean_intermediate_keys`.
        self._clean_intermediate_keys = clean_intermediate_keys
        self._modules_to_patch = modules_to_patch
        self._patch_key = patch_key
        self._patch_fn = patch_fn
        self._cache_callback = cache_callback
        self._hooked_module_kwargs["relative_path"] = "td_module.module[0]._td_module"

    def _prepare_module(
        self,
        module: TensorDictModuleBase,
        in_keys: List[UnraveledKey],
        out_keys: List[UnraveledKey],
        extra_relative_path: str,
    ) -> TensorDictModuleBase:
        """Wrap ``module`` in a sequence of a caching call and a patched call.

        ``in_keys``, ``out_keys`` and ``extra_relative_path`` are accepted for
        signature compatibility and are not used by this implementation.

        Returns:
            A ``TensorDictSequential`` made of a ``ModuleCallWithCache``, a
            ``ModuleCall`` on ``patch_key`` and, when intermediate-key cleaning
            is enabled, an ``IntermediateKeysCleaner`` for ``"_cache"``.
        """
        stored_keys = [f"{m}_output" for m in self._modules_to_patch]
        modules = [
            ModuleCallWithCache(
                module,
                cache_key="_cache",
                out_key=None,
                stored_keys=stored_keys,
            ),
            ModuleCall(
                module,
                in_key=self._patch_key,
                out_key=self._patch_key,
            ),
        ]
        if self._clean_intermediate_keys:
            modules.append(IntermediateKeysCleaner(intermediate_keys=["_cache"]))
        return TensorDictSequential(*modules)

    def _hook_module(self, module: HookedModule) -> MultiHookHandle:
        """Register one getter and one forward setter hook per patched module.

        Returns:
            A ``MultiHookHandle`` grouping every handle that was created, in
            registration order (getter then setter for each module key).
        """
        # The cache reference is taken from the first module of the prepared
        # sequence (the ModuleCallWithCache built in `_prepare_module`).
        cache_ref = module.td_module[0].cache_ref
        handles = []
        for module_key in self._modules_to_patch:
            get_handle, proxy = module.get(
                cache=cache_ref,
                cache_key=module_key,
                module_key=module_key,
                callback=self._cache_callback,
            )
            handles.append(get_handle)

            set_handle = module.set(
                module_key=module_key,
                value=proxy,
                # Built by a factory so each callback keeps its own module key;
                # a closure over the loop variable would only ever see the
                # last key by the time the hooks fire.
                callback=self._make_patch_callback(module_key),
                direction="fwd",
                prepend=True,
            )
            handles.append(set_handle)
        return MultiHookHandle(handles)

    def _make_patch_callback(self, module_key: str) -> Callable:
        """Build the setter callback bound to a single ``module_key``."""

        def callback(**kwargs):
            value = kwargs["value"]
            output = kwargs["output"]
            if value is None:  # clean run
                return output
            # `_patch_fn` is looked up at call time, as in the original code.
            if self._patch_fn is None:
                return value
            patched_output = self._patch_fn(module_key=module_key, output=output, output_to_patch=value)
            # A patch function returning None means "use the cached value".
            return value if patched_output is None else patched_output

        return callback