tdhook.latent#

Module for latent methods: activation caching, activation patching, probing, and steering.

Submodules#

Classes#

ActivationCaching

Maximally activating samples [14] and attention visualisation [15].

ActivationPatching

Causal mediation analysis [16] and latent editing [17, 18].

Probing

Linear probing [19] and concept activation vectors [20].

SteeringVectors

Steering vectors [21].

ActivationAddition

Factory for creating hooking contexts for activation addition (steering the modules listed in modules_to_steer).

Package Contents#

class tdhook.latent.ActivationCaching(key_pattern, relative=True, cache=None, callback=None, directions=None, use_nested_keys=False, clear_cache=True)[source]#

Bases: tdhook.contexts.HookingContextFactory

Maximally activating samples [14] and attention visualisation [15].

Parameters:
  • key_pattern (str)

  • relative (bool)

  • cache (Optional[tensordict.TensorDict])

  • callback (Optional[Callable])

  • directions (Optional[List[tdhook.hooks.HookDirection]])

  • use_nested_keys (bool)

  • clear_cache (bool)

_hooking_context_class#
_key_pattern#
_relative = True#
_hook_manager#
_callback = None#
_directions = ['fwd']#
_use_nested_keys#
property key_pattern: str#
Return type:

str

_hook_module(module)[source]#
Parameters:

module (tdhook.modules.HookedModule)

Return type:

tdhook.hooks.MultiHookHandle

class tdhook.latent.ActivationPatching(modules_to_patch, patch_key='patched', clean_intermediate_keys=True, patch_fn=None, cache_callback=None)[source]#

Bases: tdhook.contexts.HookingContextFactory

Causal mediation analysis [16] and latent editing [17, 18].

Parameters:
  • modules_to_patch (List[str])

  • patch_key (tdhook._types.UnraveledKey)

  • clean_intermediate_keys (bool)

  • patch_fn (Optional[Callable])

  • cache_callback (Optional[Callable])

_modules_to_patch#
_patch_key = 'patched'#
_clean_intermediate_keys = True#
_patch_fn = None#
_cache_callback = None#
_prepare_module(module, in_keys, out_keys, extra_relative_path)[source]#
Parameters:
  • module

  • in_keys

  • out_keys

  • extra_relative_path

Return type:

tensordict.nn.TensorDictModuleBase

_hook_module(module)[source]#
Parameters:

module (tdhook.modules.HookedModule)

Return type:

tdhook.hooks.MultiHookHandle

class tdhook.latent.Probing(key_pattern, probe_factory, relative=True, directions=None, additional_keys=None, classes_to_hook=None, classes_to_skip=None)[source]#

Bases: tdhook.contexts.HookingContextFactory

Linear probing [19] and concept activation vectors [20].

Parameters:
  • key_pattern (str)

  • probe_factory (Callable[[str, str], Probe])

  • relative (bool)

  • directions (Optional[List[tdhook.hooks.HookDirection]])

  • additional_keys (Optional[List[str]])

  • classes_to_hook (Optional[List[Type[torch.nn.Module]]])

  • classes_to_skip (Optional[List[Type[torch.nn.Module]]])

default_classes_to_hook#
default_classes_to_skip#
_key_pattern#
_hook_manager#
_relative = True#
_probe_factory#
_directions = ['fwd']#
_additional_keys = None#
property key_pattern: str#
Return type:

str

_hook_module(module)[source]#
Parameters:

module (tdhook.modules.HookedModule)

Return type:

tdhook.hooks.MultiHookHandle

class tdhook.latent.SteeringVectors(modules_to_steer, steer_fn)[source]#

Bases: tdhook.contexts.HookingContextFactory

Steering vectors [21].

Parameters:
  • modules_to_steer (List[str])

  • steer_fn (Callable)

_modules_to_steer#
_steer_fn#
_hook_module(module)[source]#
Parameters:

module (tdhook.modules.HookedModule)

Return type:

tdhook.hooks.MultiHookHandle

class tdhook.latent.ActivationAddition(modules_to_steer, positive_key='positive', negative_key='negative', steer_key='steer', clean_intermediate_keys=True, cache_callback=None)[source]#

Bases: tdhook.contexts.HookingContextFactory

Factory for creating hooking contexts for activation addition (steering the modules listed in modules_to_steer).

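Parameters:
  • modules_to_steer

  • positive_key

  • negative_key

  • steer_key

  • clean_intermediate_keys

  • cache_callback
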
_modules_to_steer#
_positive_key = 'positive'#
_negative_key = 'negative'#
_steer_key = 'steer'#
_clean_intermediate_keys = True#
_cache_callback = None#
_prepare_module(module, in_keys, out_keys, extra_relative_path)[source]#
Parameters:
  • module

  • in_keys

  • out_keys

  • extra_relative_path

Return type:

tensordict.nn.TensorDictModuleBase

_hook_module(module)[source]#
Parameters:

module (tdhook.modules.HookedModule)

Return type:

tdhook.hooks.MultiHookHandle

_compute_steering_vectors(td)[source]#
Parameters:

td (tensordict.TensorDict)

Return type:

tensordict.TensorDict