tdhook.weights
Module for weight analysis methods.
Package Contents
- class tdhook.weights.Adapters(adapters, cache_callback=None, relative=True, directions=None, cache=None, clear_cache=True)
Bases: tdhook.contexts.HookingContextFactory
ROME [28], sparse autoencoders [29] and transcoders [30].
- Parameters:
adapters (Dict[str, Tuple[torch.nn.Module, str, str]])
cache_callback (Optional[Callable])
relative (bool)
directions (Optional[List[tdhook.hooks.HookDirection]])
cache (Optional[tensordict.TensorDict])
clear_cache (bool)
- _hooked_module_class
- _hooking_context_class
- _adapters
- _cache_callback = None
- _relative = True
- _directions = ['fwd']
- _hook_module(module)
- Parameters:
module (tdhook.modules.HookedModule)
- Return type:
- class tdhook.weights.Pruning(importance_callback, amount_to_prune=None, modules_to_prune=None, skip_modules=None, relative_path=None)
Bases: tdhook.contexts.HookingContextFactory
Relevance-based pruning [31] and circuit pruning [32].
- Parameters:
importance_callback (Callable)
amount_to_prune (Optional[float | int])
modules_to_prune (Optional[Dict[str, Tuple[int, Optional[float]]]])
skip_modules (Optional[Callable[[str, torch.nn.Module], bool]])
relative_path (Optional[str])
- _hooking_context_class
- _importance_callback
- _amount_to_prune = None
- _modules_to_prune = None
- _skip_modules = None
- _relative_path = ''
- _prepare_module(module, in_keys, out_keys, extra_relative_path)
- Parameters:
module (tensordict.nn.TensorDictModuleBase)
in_keys (List[tdhook._types.UnraveledKey])
out_keys (List[tdhook._types.UnraveledKey])
extra_relative_path (str)
- Return type:
tensordict.nn.TensorDictModuleBase
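The Pruning parameters suggest a workflow in which an importance score is computed per unit and the lowest-scoring units are masked out. The sketch below illustrates that idea in plain Python, with lists standing in for tensors; it does not call tdhook. The helper names (`resolve_amount`, `prune_mask`, `apply_mask`) are hypothetical, and reading `amount_to_prune` as a fraction when it is a float and as a count when it is an int is an assumption borrowed from the convention of `torch.nn.utils.prune`, not confirmed behaviour of this class.

```python
from typing import List, Union


def resolve_amount(amount: Union[int, float], n_units: int) -> int:
    """Turn an amount into a number of units to prune.

    Assumption: a float is a fraction in [0, 1], an int is an absolute count.
    """
    if isinstance(amount, float):
        if not 0.0 <= amount <= 1.0:
            raise ValueError("fractional amount must lie in [0, 1]")
        return round(amount * n_units)
    if not 0 <= amount <= n_units:
        raise ValueError("integer amount must lie in [0, n_units]")
    return amount


def prune_mask(importance: List[float], amount: Union[int, float]) -> List[bool]:
    """Return a keep-mask that drops the `amount` least important units."""
    k = resolve_amount(amount, len(importance))
    # Indices sorted from least to most important.
    order = sorted(range(len(importance)), key=lambda i: importance[i])
    pruned = set(order[:k])
    return [i not in pruned for i in range(len(importance))]


def apply_mask(values: List[float], mask: List[bool]) -> List[float]:
    """Zero out the pruned units, keep the others unchanged."""
    return [v if keep else 0.0 for v, keep in zip(values, mask)]


# Toy importance scores, e.g. produced by some importance callback.
scores = [0.9, 0.1, 0.5, 0.3]
mask = prune_mask(scores, 0.5)  # prune half of the units
pruned_values = apply_mask([2.0, 4.0, 6.0, 8.0], mask)
```

In this sketch the score list plays the role of whatever `importance_callback` returns; the actual callback signature and the granularity at which tdhook prunes are not specified in this reference.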
- class tdhook.weights.TaskVectors(alphas, get_test_accuracy, get_control_adequacy)
Bases: tdhook.contexts.HookingContextFactory
Task vectors [33].
- Parameters:
alphas (Iterable[float])
get_test_accuracy (Callable[[torch.nn.Module], float])
get_control_adequacy (Callable[[torch.nn.Module], bool])
- _hooking_context_class
- _hooked_module_class
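As background for the TaskVectors parameters, the arithmetic behind task vectors can be sketched in plain Python: a task vector is the per-parameter difference between fine-tuned and pretrained weights, and it is added back to the pretrained weights scaled by a coefficient alpha. The code below does not call tdhook. The function names and the selection rule (keep the alpha with the highest test accuracy among those that pass the control check) are assumptions made for illustration; how the class actually combines `alphas`, `get_test_accuracy` and `get_control_adequacy` is not stated in this reference. The callbacks here take a weight dictionary instead of a `torch.nn.Module` to keep the example self-contained.

```python
from typing import Callable, Dict, Iterable, List, Optional, Tuple

Weights = Dict[str, List[float]]  # stand-in for a module's state dict


def task_vector(pretrained: Weights, finetuned: Weights) -> Weights:
    """tau = theta_finetuned - theta_pretrained, per parameter."""
    return {
        name: [f - p for f, p in zip(finetuned[name], pretrained[name])]
        for name in pretrained
    }


def apply_task_vector(pretrained: Weights, tau: Weights, alpha: float) -> Weights:
    """theta = theta_pretrained + alpha * tau, per parameter."""
    return {
        name: [p + alpha * t for p, t in zip(pretrained[name], tau[name])]
        for name in pretrained
    }


def select_alpha(
    pretrained: Weights,
    tau: Weights,
    alphas: Iterable[float],
    get_test_accuracy: Callable[[Weights], float],
    get_control_adequacy: Callable[[Weights], bool],
) -> Optional[Tuple[float, float]]:
    """Assumed rule: best test accuracy among alphas that pass the control check."""
    best: Optional[Tuple[float, float]] = None
    for alpha in alphas:
        candidate = apply_task_vector(pretrained, tau, alpha)
        if not get_control_adequacy(candidate):
            continue  # this alpha degrades the control task too much
        accuracy = get_test_accuracy(candidate)
        if best is None or accuracy > best[1]:
            best = (alpha, accuracy)
    return best


# Toy weights and toy callbacks, purely illustrative.
pretrained = {"w": [0.0, 1.0]}
finetuned = {"w": [1.0, 3.0]}
tau = task_vector(pretrained, finetuned)
best = select_alpha(
    pretrained,
    tau,
    alphas=[0.0, 0.5, 1.0],
    get_test_accuracy=lambda w: 1.0 - abs(w["w"][0] - 0.5),
    get_control_adequacy=lambda w: w["w"][1] <= 2.5,
)
```

With these toy callbacks, alpha = 1.0 is rejected by the control check and alpha = 0.5 scores highest among the remaining values.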