tdhook.weights.pruning

Classes

PruningContext

Hooking context used by the pruning module (a subclass of tdhook.contexts.HookingContext).

Pruning

Relevance-based pruning [31] and circuit pruning [32].

Module Contents

class tdhook.weights.pruning.PruningContext(*args, **kwargs)[source]

Bases: tdhook.contexts.HookingContext

Hooking context used by the pruning module (a subclass of tdhook.contexts.HookingContext).

_old_weights = None[source]
_enter(managed_by_context_manager=True)[source]
Parameters:
  • managed_by_context_manager (bool)

__enter__()[source]
__exit__(exc_type, exc_value, traceback)[source]
class tdhook.weights.pruning.Pruning(importance_callback, amount_to_prune=None, modules_to_prune=None, skip_modules=None, relative_path=None)[source]

Bases: tdhook.contexts.HookingContextFactory

Relevance-based pruning [31] and circuit pruning [32].

Parameters:
  • importance_callback (Callable)

  • amount_to_prune (Optional[float | int])

  • modules_to_prune (Optional[Dict[str, Tuple[int, Optional[float]]]])

  • skip_modules (Optional[Callable[[str, torch.nn.Module], bool]])

  • relative_path (Optional[str])

_hooking_context_class[source]
_importance_callback[source]
_amount_to_prune = None[source]
_modules_to_prune = None[source]
_skip_modules = None[source]
_relative_path = ''[source]
_prepare_module(module, in_keys, out_keys, extra_relative_path)[source]
Parameters:
  • module

  • in_keys

  • out_keys

  • extra_relative_path

Return type:

tensordict.nn.TensorDictModuleBase

static default_skip(name, module)[source]
Parameters:
  • name (str)

  • module (torch.nn.Module)

Return type:

bool