tdhook.contexts

Classes

HookingContext

Base class for hooking contexts.

HookingContextWithCache

Hooking context with cache.

HookingContextFactory

Factory for creating hooking contexts.

CompositeHookingContextFactory

Composite hooking context factory.

Module Contents

class tdhook.contexts.HookingContext(factory, module, in_keys=None, out_keys=None, pre_factories=None)[source]

Base class for hooking contexts.

Parameters:
  • factory

  • module

  • in_keys

  • out_keys

  • pre_factories

_prepare[source]
_restore[source]
_spawn[source]
_hook[source]
_in_context = False[source]
_handle = None[source]
_hooked_module = None[source]
_pre_factories = [][source]
_stack = None[source]
_in_keys[source]
_out_keys[source]
_managed_by_context_manager = False[source]
_enter(managed_by_context_manager=True)[source]
Parameters:

managed_by_context_manager (bool)

__enter__()[source]
__exit__(exc_type, exc_value, traceback)[source]
disable_hooks()[source]
Return type:

Generator[None, None, None]

disable()[source]
Return type:

Generator[torch.nn.Module, None, None]

class tdhook.contexts.HookingContextWithCache(*args, cache=None, clear_cache=True, **kwargs)[source]

Bases: HookingContext

Hooking context with cache.

Parameters:
  • cache (Optional[tensordict.TensorDict])

  • clear_cache (bool)

_cache[source]
_clear_cache = True[source]
property cache: tensordict.TensorDict[source]
Return type:

tensordict.TensorDict

clear()[source]
_enter(managed_by_context_manager=True)[source]
Parameters:

managed_by_context_manager (bool)

__enter__()[source]

class tdhook.contexts.HookingContextFactory[source]

Factory for creating hooking contexts.

_hooked_module_class[source]
_hooking_context_class[source]
_hooking_context_kwargs[source]
_hooked_module_kwargs[source]
prepare(module: torch.nn.Module, in_keys: List[tdhook._types.UnraveledKey] | Dict[tdhook._types.UnraveledKey, str] | None = None, out_keys: List[tdhook._types.UnraveledKey] | None = None, *, return_context: Literal[True] = True) → HookingContext[source]
prepare(module: torch.nn.Module, in_keys: List[tdhook._types.UnraveledKey] | Dict[tdhook._types.UnraveledKey, str] | None = None, out_keys: List[tdhook._types.UnraveledKey] | None = None, *, return_context: Literal[False]) → tdhook.modules.HookedModule

Prepare the module for execution.

Parameters:
  • module – The module to prepare.

  • in_keys – Optional input keys.

  • out_keys – Optional output keys.

  • return_context – If True (default), returns a context manager. If False, returns the hooked module directly.

Returns:

If return_context is True, returns a HookingContext that can be used as a context manager. If return_context is False, returns the HookedModule directly (context is automatically entered).

_prepare_module(module, in_keys, out_keys, extra_relative_path)[source]
Parameters:
  • module

  • in_keys

  • out_keys

  • extra_relative_path

Return type:

tensordict.nn.TensorDictModuleBase

_restore_module(module, in_keys, out_keys, extra_relative_path)[source]
Parameters:
  • module

  • in_keys

  • out_keys

  • extra_relative_path

Return type:

tensordict.nn.TensorDictModuleBase

_spawn_hooked_module(prep_module, hooking_context, extra_relative_path)[source]
Parameters:
  • prep_module (tensordict.nn.TensorDictModuleBase)

  • hooking_context (HookingContext)

  • extra_relative_path (str)

Return type:

tdhook.modules.HookedModule

_hook_module(module)[source]
Parameters:

module (tdhook.modules.HookedModule)

Return type:

tdhook.hooks.MultiHookHandle

class tdhook.contexts.CompositeHookingContextFactory(*contexts)[source]

Bases: HookingContextFactory

Composite hooking context factory.

Parameters:

contexts (HookingContextFactory)

_contexts = ()[source]
_prepare_module(module, in_keys, out_keys, extra_relative_path)[source]
Parameters:
  • module

  • in_keys

  • out_keys

  • extra_relative_path

Return type:

tensordict.nn.TensorDictModuleBase

_restore_module(module, in_keys, out_keys, extra_relative_path)[source]
Parameters:
  • module

  • in_keys

  • out_keys

  • extra_relative_path

Return type:

tensordict.nn.TensorDictModuleBase

_hook_module(module)[source]
Parameters:

module (tdhook.modules.HookedModule)

Return type:

tdhook.hooks.MultiHookHandle