tdhook.weights

Module for weight analysis methods.

Submodules

Classes

Adapters

ROME [28], sparse autoencoders [29] and transcoders [30].

Pruning

Relevance-based pruning [31] and circuit pruning [32].

TaskVectors

Task vectors [33].

Package Contents

class tdhook.weights.Adapters(adapters, cache_callback=None, relative=True, directions=None, cache=None, clear_cache=True)

Bases: tdhook.contexts.HookingContextFactory

ROME [28], sparse autoencoders [29] and transcoders [30].

Parameters:
  • adapters (Dict[str, Tuple[torch.nn.Module, str, str]])

  • cache_callback (Optional[Callable])

  • relative (bool)

  • directions (Optional[List[tdhook.hooks.HookDirection]])

  • cache (Optional[tensordict.TensorDict])

  • clear_cache (bool)

_hooked_module_class
_hooking_context_class
_adapters
_cache_callback = None
_relative = True
_directions = ['fwd']
_hook_module(module)
Parameters:

module (tdhook.modules.HookedModule)

Return type:

tdhook.hooks.MultiHookHandle
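
The reference above does not show how an adapter is attached, so the following is a minimal sketch of the underlying idea using plain PyTorch forward hooks rather than the `Adapters` class itself. The `SparseAutoencoder` module, the `splice` hook, and the toy model are all hypothetical names introduced for illustration; how tdhook interprets the `(module, str, str)` tuples in `adapters` is not documented here and is not assumed.

```python
import torch
from torch import nn


class SparseAutoencoder(nn.Module):
    """Hypothetical adapter: encode to a wider ReLU layer, then decode back."""

    def __init__(self, d_model: int, d_hidden: int) -> None:
        super().__init__()
        self.encoder = nn.Linear(d_model, d_hidden)
        self.decoder = nn.Linear(d_hidden, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.decoder(torch.relu(self.encoder(x)))


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
sae = SparseAutoencoder(d_model=8, d_hidden=32)


def splice(module: nn.Module, inputs: tuple, output: torch.Tensor) -> torch.Tensor:
    # A forward hook that returns a value replaces the module's output,
    # so downstream layers see the adapter's reconstruction instead.
    return sae(output)


x = torch.randn(3, 4)

# Attach the adapter after the ReLU, run the model, then detach it.
handle = model[1].register_forward_hook(splice)
with torch.no_grad():
    patched = model(x)
handle.remove()

# Reference computation: apply the same layers and the adapter by hand.
with torch.no_grad():
    expected = model[2](sae(model[1](model[0](x))))
```

Removing the handle restores the original forward pass, which is the same lifecycle a hooking context would be expected to manage on entry and exit.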

class tdhook.weights.Pruning(importance_callback, amount_to_prune=None, modules_to_prune=None, skip_modules=None, relative_path=None)

Bases: tdhook.contexts.HookingContextFactory

Relevance-based pruning [31] and circuit pruning [32].

Parameters:
  • importance_callback (Callable)

  • amount_to_prune (Optional[float | int])

  • modules_to_prune (Optional[Dict[str, Tuple[int, Optional[float]]]])

  • skip_modules (Optional[Callable[[str, torch.nn.Module], bool]])

  • relative_path (Optional[str])

_hooking_context_class
_importance_callback
_amount_to_prune = None
_modules_to_prune = None
_skip_modules = None
_relative_path = ''
_prepare_module(module, in_keys, out_keys, extra_relative_path)
Return type:

tensordict.nn.TensorDictModuleBase

static default_skip(name, module)
Parameters:
  • name (str)

  • module (torch.nn.Module)

Return type:

bool
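
Since `amount_to_prune` accepts either a float or an int, a short sketch may help show one plausible reading of that type. It assumes the convention used by `torch.nn.utils.prune` (a float is a fraction of units, an int is an absolute count); whether `Pruning` follows the same convention is an assumption, and `resolve_amount` and `prune_mask` are hypothetical helpers, not part of tdhook. The importance scores stand in for whatever `importance_callback` would return.

```python
from typing import List, Union


def resolve_amount(amount: Union[int, float], n_units: int) -> int:
    """Convert an amount into a number of units to prune.

    Assumed convention: int = absolute count, float = fraction in [0, 1].
    """
    if isinstance(amount, int):
        if not 0 <= amount <= n_units:
            raise ValueError("integer amount must lie in [0, n_units]")
        return amount
    if not 0.0 <= amount <= 1.0:
        raise ValueError("float amount must lie in [0.0, 1.0]")
    return round(amount * n_units)


def prune_mask(importance: List[float], amount: Union[int, float]) -> List[bool]:
    """Return a keep-mask that drops the least important units."""
    k = resolve_amount(amount, len(importance))
    # Indices sorted from least to most important; the first k are dropped.
    order = sorted(range(len(importance)), key=lambda i: importance[i])
    dropped = set(order[:k])
    return [i not in dropped for i in range(len(importance))]


scores = [0.9, 0.1, 0.5, 0.3]
half = prune_mask(scores, 0.5)  # drops the two lowest scores (0.1 and 0.3)
one = prune_mask(scores, 1)     # drops only the lowest score (0.1)
```

Ranking by importance and masking the bottom of the ranking is the generic shape of relevance-based pruning; the scoring function is where methods differ.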

class tdhook.weights.TaskVectors(alphas, get_test_accuracy, get_control_adequacy)

Bases: tdhook.contexts.HookingContextFactory

Task vectors [33].

Parameters:
  • alphas (Iterable[float])

  • get_test_accuracy (Callable[[torch.nn.Module], float])

  • get_control_adequacy (Callable[[torch.nn.Module], bool])

_hooking_context_class
_hooked_module_class
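
The three constructor arguments suggest a sweep over scaling coefficients, so here is a minimal sketch of task-vector arithmetic with such a sweep. It is written in plain Python with a dict of floats standing in for a model's weights (real code would operate on state-dict tensors, and the callbacks would receive a `torch.nn.Module`). The selection rule used here, picking the most accurate `alpha` among those that pass the control check, is an assumption about how the callbacks are combined; `task_vector`, `apply_task_vector`, and `select_alpha` are hypothetical helpers, not tdhook functions.

```python
from typing import Callable, Dict, Iterable, Optional

Weights = Dict[str, float]  # stand-in for a state dict of tensors


def task_vector(pretrained: Weights, finetuned: Weights) -> Weights:
    """Task vector: element-wise difference finetuned - pretrained."""
    return {name: finetuned[name] - pretrained[name] for name in pretrained}


def apply_task_vector(pretrained: Weights, vector: Weights, alpha: float) -> Weights:
    """Return pretrained + alpha * vector."""
    return {name: pretrained[name] + alpha * vector[name] for name in pretrained}


def select_alpha(
    pretrained: Weights,
    vector: Weights,
    alphas: Iterable[float],
    get_test_accuracy: Callable[[Weights], float],
    get_control_adequacy: Callable[[Weights], bool],
) -> Optional[float]:
    """Pick the alpha with the best accuracy among adequate candidates."""
    best_alpha, best_accuracy = None, float("-inf")
    for alpha in alphas:
        candidate = apply_task_vector(pretrained, vector, alpha)
        if not get_control_adequacy(candidate):
            continue  # this alpha degrades the control task too much
        accuracy = get_test_accuracy(candidate)
        if accuracy > best_accuracy:
            best_alpha, best_accuracy = alpha, accuracy
    return best_alpha


pretrained = {"w": 1.0}
finetuned = {"w": 3.0}
vector = task_vector(pretrained, finetuned)  # {"w": 2.0}

# Toy callbacks: accuracy peaks at w == 3.0, control task tolerates w <= 2.5.
best = select_alpha(
    pretrained,
    vector,
    alphas=[0.0, 0.5, 1.0, 1.5],
    get_test_accuracy=lambda weights: -abs(weights["w"] - 3.0),
    get_control_adequacy=lambda weights: weights["w"] <= 2.5,
)
```

In this toy setup `alpha = 1.0` would maximize accuracy on its own, but it fails the control check, so the sweep settles on `0.5`.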