tdhook.weights.pruning
Classes
PruningContext | Base class for hooking contexts.
Pruning | Relevance-based pruning [31] and circuit pruning [32].
Module Contents
- class tdhook.weights.pruning.PruningContext(*args, **kwargs)
Bases: tdhook.contexts.HookingContext
Base class for hooking contexts.
- class tdhook.weights.pruning.Pruning(importance_callback, amount_to_prune=None, modules_to_prune=None, skip_modules=None, relative_path=None)
Bases: tdhook.contexts.HookingContextFactory
Relevance-based pruning [31] and circuit pruning [32].
- Parameters:
importance_callback (Callable)
amount_to_prune (Optional[float | int])
modules_to_prune (Optional[Dict[str, Tuple[int, Optional[float]]]])
skip_modules (Optional[Callable[[str, torch.nn.Module], bool]])
relative_path (Optional[str])
- _prepare_module(module, in_keys, out_keys, extra_relative_path)
- Parameters:
module (tensordict.nn.TensorDictModuleBase)
in_keys (List[tdhook._types.UnraveledKey])
out_keys (List[tdhook._types.UnraveledKey])
extra_relative_path (str)
- Return type:
tensordict.nn.TensorDictModuleBase
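The Pruning constructor parameters listed above have types but no documented semantics. The pure-Python sketch below shows one plausible interpretation. It rests on several assumptions that this page does not confirm. First, importance_callback is assumed to produce one importance score per prunable entry. Second, amount_to_prune is assumed to follow the torch.nn.utils.prune convention, where a float in [0, 1] is a fraction and an int is an absolute count. Third, skip_modules is treated as the predicate its type Callable[[str, torch.nn.Module], bool] implies. The helpers num_to_prune, prune_mask and skip_embeddings are hypothetical and are not part of tdhook.

```python
from typing import List, Union


def num_to_prune(total: int, amount: Union[float, int]) -> int:
    """Hypothetical helper: a float in [0, 1] is a fraction of `total`,
    an int is an absolute count (assumed torch.nn.utils.prune convention)."""
    if isinstance(amount, bool):
        raise TypeError("amount must be a float or an int, not bool")
    if isinstance(amount, float):
        if not 0.0 <= amount <= 1.0:
            raise ValueError("a float amount must lie in [0, 1]")
        return round(amount * total)
    if not 0 <= amount <= total:
        raise ValueError("an int amount must lie in [0, total]")
    return amount


def prune_mask(importances: List[float], amount: Union[float, int]) -> List[bool]:
    """Hypothetical helper: return a keep-mask (True = kept) that drops
    the least important entries, as relevance-based pruning would."""
    k = num_to_prune(len(importances), amount)
    # Sort indices by ascending importance; ties are broken by position.
    order = sorted(range(len(importances)), key=lambda i: (importances[i], i))
    pruned = set(order[:k])
    return [i not in pruned for i in range(len(importances))]


def skip_embeddings(name: str, module: object) -> bool:
    """Example predicate with the documented skip_modules shape
    (module name, module) -> bool; this one only inspects the name."""
    return "embed" in name


scores = [0.9, 0.1, 0.5, 0.05]
print(prune_mask(scores, 0.5))  # fraction: drops the 2 lowest scores
print(prune_mask(scores, 1))    # count: drops only the single lowest score
```

With the float amount 0.5, the two lowest-scoring entries (indices 3 and 1) are masked out. With the int amount 1, only index 3 is masked out. Check tdhook's own implementation before relying on this reading of amount_to_prune.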