tdhook.latent
Module for latent methods.
Submodules
Classes
Package Contents
- class tdhook.latent.ActivationCaching(key_pattern, relative=True, cache=None, callback=None, directions=None, use_nested_keys=False, clear_cache=True)
Bases: tdhook.contexts.HookingContextFactory
Maximally activating samples [14] and attention visualisation [15].
- Parameters:
key_pattern (str)
relative (bool)
cache (Optional[tensordict.TensorDict])
callback (Optional[Callable])
directions (Optional[List[tdhook.hooks.HookDirection]])
use_nested_keys (bool)
clear_cache (bool)
- _hooking_context_class
- _key_pattern
- _relative = True
- _hook_manager
- _callback = None
- _directions = ['fwd']
- _use_nested_keys
- property key_pattern: str
- Return type:
str
- _hook_module(module)
- Parameters:
module (tdhook.modules.HookedModule)
- Return type:
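To illustrate the technique behind ActivationCaching, the sketch below records module outputs with plain PyTorch forward hooks. It does not use tdhook's API: the toy model, the `cache` dict, and the `make_hook` helper are assumptions made for this example only.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Toy model; nn.Sequential names its children "0", "1", "2".
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

cache = {}  # maps module name -> detached output tensor


def make_hook(name):
    # A forward hook receives (module, inputs, output). Returning None
    # leaves the module output unchanged.
    def hook(module, inputs, output):
        cache[name] = output.detach()

    return hook


# Hook every Linear layer (a stand-in for selecting modules by a key pattern).
handles = [
    module.register_forward_hook(make_hook(name))
    for name, module in model.named_modules()
    if isinstance(module, nn.Linear)
]

x = torch.randn(3, 4)
with torch.no_grad():
    y = model(x)

# Remove the hooks so later forward passes are not recorded.
for handle in handles:
    handle.remove()
```

After the forward pass, `cache` holds one tensor per hooked layer, which is the raw material for analyses such as finding maximally activating samples.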
- class tdhook.latent.ActivationPatching(modules_to_patch, patch_key='patched', clean_intermediate_keys=True, patch_fn=None, cache_callback=None)
Bases: tdhook.contexts.HookingContextFactory
Causal mediation analysis [16] and latent editing [17, 18].
- Parameters:
modules_to_patch (List[str])
patch_key (tdhook._types.UnraveledKey)
clean_intermediate_keys (bool)
patch_fn (Optional[Callable])
cache_callback (Optional[Callable])
- _modules_to_patch
- _patch_key = 'patched'
- _clean_intermediate_keys = True
- _patch_fn = None
- _cache_callback = None
- _prepare_module(module, in_keys, out_keys, extra_relative_path)
- Parameters:
module (tensordict.nn.TensorDictModuleBase)
in_keys (List[tdhook._types.UnraveledKey])
out_keys (List[tdhook._types.UnraveledKey])
extra_relative_path (str)
- Return type:
tensordict.nn.TensorDictModuleBase
- _hook_module(module)
- Parameters:
module (tdhook.modules.HookedModule)
- Return type:
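The general idea of activation patching can be sketched with plain PyTorch forward hooks: cache an activation from a "clean" run, then substitute it into a "corrupted" run. This is not tdhook's API. The toy model, the `stored` dict, and both hook functions are assumptions made for the example.

```python
import torch
from torch import nn

torch.manual_seed(0)

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
target = model[0]  # the module whose output gets patched
clean, corrupted = torch.randn(1, 4), torch.randn(1, 4)

stored = {}


def save_hook(module, inputs, output):
    # Record the clean activation; returning None keeps the output as is.
    stored["clean"] = output.detach()


def patch_hook(module, inputs, output):
    # Returning a tensor from a forward hook replaces the module output.
    return stored["clean"]


with torch.no_grad():
    # 1. Clean run: cache the activation of the target module.
    handle = target.register_forward_hook(save_hook)
    clean_out = model(clean)
    handle.remove()

    # 2. Corrupted run without intervention, as a baseline.
    corrupted_out = model(corrupted)

    # 3. Corrupted run with the clean activation patched in.
    handle = target.register_forward_hook(patch_hook)
    patched_out = model(corrupted)
    handle.remove()
```

Because this toy model is purely sequential and the whole layer output is replaced, `patched_out` matches `clean_out`. In a model with several paths or with partial patches, the gap between the patched and clean outputs indicates how much the patched module mediates the effect.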
- class tdhook.latent.Probing(key_pattern, probe_factory, relative=True, directions=None, additional_keys=None, classes_to_hook=None, classes_to_skip=None)
Bases: tdhook.contexts.HookingContextFactory
Linear probing [19] and concept activation vectors [20].
- Parameters:
key_pattern (str)
probe_factory (Callable[[str, str], Probe])
relative (bool)
directions (Optional[List[tdhook.hooks.HookDirection]])
additional_keys (Optional[List[str]])
classes_to_hook (Optional[List[Type[torch.nn.Module]]])
classes_to_skip (Optional[List[Type[torch.nn.Module]]])
- default_classes_to_hook
- default_classes_to_skip
- _key_pattern
- _hook_manager
- _relative = True
- _probe_factory
- _directions = ['fwd']
- _additional_keys = None
- property key_pattern: str
- Return type:
str
- _hook_module(module)
- Parameters:
module (tdhook.modules.HookedModule)
- Return type:
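As a rough illustration of linear probing, the sketch below caches a hidden activation with a plain PyTorch forward hook and fits a least-squares probe on it. It does not use tdhook's `Probe` or `probe_factory` interfaces. The toy model, the chosen target (the first input feature), and the least-squares probe are assumptions made for the example.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Toy model in float64 so the least-squares fit is numerically clean.
model = nn.Sequential(nn.Linear(4, 4), nn.Tanh(), nn.Linear(4, 2)).double()
x = torch.randn(64, 4, dtype=torch.float64)
labels = x[:, :1]  # concept to decode: the first input feature, shape (64, 1)

# Cache the output of the first layer during one forward pass.
acts = {}
handle = model[0].register_forward_hook(
    lambda module, inputs, output: acts.update(hidden=output.detach())
)
with torch.no_grad():
    model(x)
handle.remove()

# Linear probe: append a bias column and solve min ||features @ w - labels||.
bias = torch.ones(64, 1, dtype=torch.float64)
features = torch.cat([acts["hidden"], bias], dim=1)
weights = torch.linalg.lstsq(features, labels).solution

# Mean squared error of the probe on the data it was fitted to.
mse = ((features @ weights - labels) ** 2).mean().item()
```

The error is close to zero here only because the probed layer is an affine function of the input, so the feature is linearly decodable by construction. On real models the probe would be evaluated on held-out data.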
- class tdhook.latent.SteeringVectors(modules_to_steer, steer_fn)
Bases: tdhook.contexts.HookingContextFactory
Steering vectors [21].
- Parameters:
modules_to_steer (List[str])
steer_fn (Callable)
- _modules_to_steer
- _steer_fn
- _hook_module(module)
- Parameters:
module (tdhook.modules.HookedModule)
- Return type:
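Steering with a fixed vector can be sketched in plain PyTorch as a forward hook that shifts a module's output. This is not tdhook's API, and the signature tdhook expects for `steer_fn` is not documented above. The toy model, `steering_vector`, `scale`, and `steer_hook` are all hypothetical names for this example.

```python
import torch
from torch import nn

torch.manual_seed(0)

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
x = torch.randn(3, 4)
steering_vector = torch.randn(8)  # hypothetical direction in the hidden space
scale = 2.0


def steer_hook(module, inputs, output):
    # Returning a value from a forward hook replaces the module output.
    return output + scale * steering_vector


with torch.no_grad():
    baseline = model(x)

    # Steered run: the hook shifts the hidden activation of the first layer.
    handle = model[0].register_forward_hook(steer_hook)
    steered = model(x)
    handle.remove()

    # Reference computation without hooks, for comparison.
    expected = model[2](torch.relu(model[0](x) + scale * steering_vector))
```

Removing the handle restores the unmodified model, so the intervention stays scoped to the steered forward pass.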
- class tdhook.latent.ActivationAddition(modules_to_steer, positive_key='positive', negative_key='negative', steer_key='steer', clean_intermediate_keys=True, cache_callback=None)
Bases: tdhook.contexts.HookingContextFactory
Factory for creating hooking contexts.
- Parameters:
modules_to_steer (List[str])
positive_key (tdhook._types.UnraveledKey)
negative_key (tdhook._types.UnraveledKey)
steer_key (tdhook._types.UnraveledKey)
clean_intermediate_keys (bool)
cache_callback (Optional[Callable])
- _modules_to_steer
- _positive_key = 'positive'
- _negative_key = 'negative'
- _steer_key = 'steer'
- _clean_intermediate_keys = True
- _cache_callback = None
- _prepare_module(module, in_keys, out_keys, extra_relative_path)
- Parameters:
module (tensordict.nn.TensorDictModuleBase)
in_keys (List[tdhook._types.UnraveledKey])
out_keys (List[tdhook._types.UnraveledKey])
extra_relative_path (str)
- Return type:
tensordict.nn.TensorDictModuleBase
- _hook_module(module)
- Parameters:
module (tdhook.modules.HookedModule)
- Return type:
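The class description above does not say how the positive, negative, and steer inputs are combined. Judging only by the parameter names, one plausible reading is the usual activation-addition recipe: take the difference between activations on a positive and a negative input, then add it during a third run. The plain PyTorch sketch below follows that assumption and does not use tdhook's API. The toy model and all helper names are hypothetical.

```python
import torch
from torch import nn

torch.manual_seed(0)

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
layer = model[0]  # the module being steered
positive, negative, prompt = (torch.randn(1, 4) for _ in range(3))

captured = {}


def capture(name):
    # Build a hook that stores the layer output under the given name.
    def hook(module, inputs, output):
        captured[name] = output.detach()

    return hook


def add_difference(module, inputs, output):
    # Assumed recipe: shift the activation by (positive - negative).
    return output + (captured["positive"] - captured["negative"])


with torch.no_grad():
    # 1. Cache the layer output for the positive and negative inputs.
    for name, batch in (("positive", positive), ("negative", negative)):
        handle = layer.register_forward_hook(capture(name))
        model(batch)
        handle.remove()

    # 2. Run the prompt with the difference vector added at the layer.
    handle = layer.register_forward_hook(add_difference)
    steered = model(prompt)
    handle.remove()

    # Reference computation without hooks, for comparison.
    shift = layer(positive) - layer(negative)
    expected = model[2](torch.relu(layer(prompt) + shift))
```

Unlike a hand-picked steering vector, the shift here is derived from the model's own activations on a contrasting pair of inputs.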