tdhook.contexts
Classes
HookingContext | Base class for hooking contexts.
HookingContextWithCache | Hooking context with cache.
HookingContextFactory | Factory for creating hooking contexts.
CompositeHookingContextFactory | Composite hooking context factory.
Module Contents
- class tdhook.contexts.HookingContext(factory, module, in_keys=None, out_keys=None, pre_factories=None)
Base class for hooking contexts.
- Parameters:
factory (HookingContextFactory)
module (torch.nn.Module)
in_keys (Optional[List[tdhook._types.UnraveledKey] | Dict[tdhook._types.UnraveledKey, str]])
out_keys (Optional[List[tdhook._types.UnraveledKey]])
pre_factories (Optional[List[HookingContextFactory]])
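The `in_keys` annotation accepts either a list of keys or a dict. Assuming tdhook follows the `tensordict.nn.TensorDictModule` convention (dict keys are the TensorDict entries to read, dict values are the keyword-argument names they are passed as), both forms can be normalized into the same shape. `normalize_in_keys` and the `UnraveledKey` alias below are hypothetical stand-ins, not tdhook functions.

```python
from typing import Dict, List, Optional, Tuple, Union

# Assumption: an "unraveled" key is either a plain string or a tuple of
# strings addressing a nested TensorDict entry.
UnraveledKey = Union[str, Tuple[str, ...]]


def normalize_in_keys(
    in_keys: Optional[Union[List[UnraveledKey], Dict[UnraveledKey, str]]],
) -> List[Tuple[UnraveledKey, Optional[str]]]:
    """Hypothetical helper: turn either accepted form of ``in_keys`` into
    (tensordict key, keyword-argument name) pairs.

    A list means values are passed positionally (no keyword name); a dict
    maps each tensordict key to the keyword argument it should feed.
    """
    if in_keys is None:
        return []
    if isinstance(in_keys, dict):
        return list(in_keys.items())
    return [(key, None) for key in in_keys]


print(normalize_in_keys(["x", ("block", "h")]))
# [('x', None), (('block', 'h'), None)]
print(normalize_in_keys({"x": "input"}))
# [('x', 'input')]
```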
- class tdhook.contexts.HookingContextWithCache(*args, cache=None, clear_cache=True, **kwargs)
Bases: HookingContext
Hooking context with cache.
- Parameters:
cache (Optional[tensordict.TensorDict])
clear_cache (bool)
- class tdhook.contexts.HookingContextFactory
Factory for creating hooking contexts.
- prepare(module: torch.nn.Module, in_keys: List[tdhook._types.UnraveledKey] | Dict[tdhook._types.UnraveledKey, str] | None = None, out_keys: List[tdhook._types.UnraveledKey] | None = None, *, return_context: Literal[True] = True) → HookingContext
- prepare(module: torch.nn.Module, in_keys: List[tdhook._types.UnraveledKey] | Dict[tdhook._types.UnraveledKey, str] | None = None, out_keys: List[tdhook._types.UnraveledKey] | None = None, *, return_context: Literal[False]) → tdhook.modules.HookedModule
Prepare the module for execution.
- Parameters:
module – The module to prepare.
in_keys – Optional input keys.
out_keys – Optional output keys.
return_context – If True (default), returns a context manager. If False, returns the hooked module directly.
- Returns:
If return_context is True, returns a HookingContext that can be used as a context manager. If return_context is False, returns the HookedModule directly (context is automatically entered).
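The two `prepare` signatures are `typing.overload` variants keyed on `Literal[True]` and `Literal[False]`, which lets a static type checker infer the return type from the value of `return_context`. The sketch below reproduces that pattern with hypothetical `Context` and `Hooked` classes standing in for `HookingContext` and `HookedModule`; it assumes, as the Returns description suggests, that entering the context yields the hooked module.

```python
from typing import Literal, Union, overload


class Context:
    """Hypothetical stand-in for HookingContext."""

    def __init__(self, module: str) -> None:
        self.module = module
        self.entered = False

    def __enter__(self) -> "Hooked":
        self.entered = True
        return Hooked(self.module, self)

    def __exit__(self, *exc: object) -> None:
        self.entered = False


class Hooked:
    """Hypothetical stand-in for tdhook.modules.HookedModule."""

    def __init__(self, module: str, context: Context) -> None:
        self.module = module
        self.context = context


@overload
def prepare(module: str, *, return_context: Literal[True] = True) -> Context: ...


@overload
def prepare(module: str, *, return_context: Literal[False]) -> Hooked: ...


def prepare(module: str, *, return_context: bool = True) -> Union[Context, Hooked]:
    context = Context(module)
    if return_context:
        return context  # caller enters it with a `with` statement
    return context.__enter__()  # context entered on the caller's behalf


with prepare("model") as hooked:
    print(type(hooked).__name__, hooked.context.entered)  # Hooked True

hooked = prepare("model", return_context=False)
print(type(hooked).__name__, hooked.context.entered)  # Hooked True
```

With `return_context=False` nothing calls `__exit__` automatically in this sketch, so the caller is responsible for tearing the context down.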
- _prepare_module(module, in_keys, out_keys, extra_relative_path)
- Parameters:
module (tensordict.nn.TensorDictModuleBase)
in_keys (List[tdhook._types.UnraveledKey])
out_keys (List[tdhook._types.UnraveledKey])
extra_relative_path (str)
- Return type:
tensordict.nn.TensorDictModuleBase
- _restore_module(module, in_keys, out_keys, extra_relative_path)
- Parameters:
module (tensordict.nn.TensorDictModuleBase)
in_keys (List[tdhook._types.UnraveledKey])
out_keys (List[tdhook._types.UnraveledKey])
extra_relative_path (str)
- Return type:
tensordict.nn.TensorDictModuleBase
- _spawn_hooked_module(prep_module, hooking_context, extra_relative_path)
- Parameters:
prep_module (tensordict.nn.TensorDictModuleBase)
hooking_context (HookingContext)
extra_relative_path (str)
- Return type:
- _hook_module(module)
- Parameters:
module (tdhook.modules.HookedModule)
- Return type:
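The private methods `_prepare_module`, `_spawn_hooked_module`, `_hook_module` and `_restore_module` are the points a concrete factory overrides. Their call order is not stated on this page; the sketch below assumes that entering a context runs prepare, spawn and hook in that order and that exiting runs restore. `SketchFactory` and `SketchContext` are hypothetical, operate on plain strings instead of `torch.nn.Module` objects, and pass an empty `extra_relative_path` because that argument is not described here.

```python
from typing import List


class SketchFactory:
    """Hypothetical factory mirroring the private hook points of
    HookingContextFactory; it only records the calls it receives."""

    def __init__(self) -> None:
        self.events: List[str] = []

    def _prepare_module(self, module, in_keys, out_keys, extra_relative_path):
        self.events.append(f"prepare {module}")
        return module

    def _restore_module(self, module, in_keys, out_keys, extra_relative_path):
        self.events.append(f"restore {module}")
        return module

    def _spawn_hooked_module(self, prep_module, hooking_context, extra_relative_path):
        self.events.append(f"spawn {prep_module}")
        return f"hooked({prep_module})"

    def _hook_module(self, module):
        self.events.append(f"hook {module}")


class SketchContext:
    """Hypothetical context: assumes __enter__ runs prepare -> spawn -> hook
    and __exit__ runs restore."""

    def __init__(self, factory, module, in_keys=None, out_keys=None):
        self.factory = factory
        self.module = module
        self.in_keys = in_keys or []
        self.out_keys = out_keys or []

    def __enter__(self):
        f = self.factory
        prepared = f._prepare_module(self.module, self.in_keys, self.out_keys, "")
        hooked = f._spawn_hooked_module(prepared, self, "")
        f._hook_module(hooked)
        return hooked

    def __exit__(self, *exc):
        f = self.factory
        f._restore_module(self.module, self.in_keys, self.out_keys, "")
        return False  # do not suppress exceptions


factory = SketchFactory()
with SketchContext(factory, "net") as hooked:
    pass
print(factory.events)
# ['prepare net', 'spawn net', 'hook hooked(net)', 'restore net']
```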
- class tdhook.contexts.CompositeHookingContextFactory(*contexts)
Bases: HookingContextFactory
Composite hooking context factory.
- Parameters:
contexts (HookingContextFactory)
- _prepare_module(module, in_keys, out_keys, extra_relative_path)
- Parameters:
module (tensordict.nn.TensorDictModuleBase)
in_keys (List[tdhook._types.UnraveledKey])
out_keys (List[tdhook._types.UnraveledKey])
extra_relative_path (str)
- Return type:
tensordict.nn.TensorDictModuleBase
- _restore_module(module, in_keys, out_keys, extra_relative_path)
- Parameters:
module (tensordict.nn.TensorDictModuleBase)
in_keys (List[tdhook._types.UnraveledKey])
out_keys (List[tdhook._types.UnraveledKey])
extra_relative_path (str)
- Return type:
tensordict.nn.TensorDictModuleBase
- _hook_module(module)
- Parameters:
module (tdhook.modules.HookedModule)
- Return type:
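`CompositeHookingContextFactory(*contexts)` combines several factories behind the same private hooks. The sketch below shows one plausible way a composite could chain them: each sub-factory's `_prepare_module` runs in the order given and `_restore_module` runs in reverse, so every restore sees the module exactly as its own prepare left it. That ordering is an assumption of this sketch, not documented tdhook behavior; `Recorder` and `SketchComposite` are hypothetical and wrap strings instead of modules.

```python
from typing import List


class Recorder:
    """Hypothetical sub-factory: wraps the module name on prepare,
    unwraps it on restore, and logs both calls."""

    def __init__(self, name: str, log: List[str]) -> None:
        self.name = name
        self.log = log

    def _prepare_module(self, module, in_keys, out_keys, extra_relative_path):
        self.log.append(f"prepare {self.name}")
        return f"{self.name}({module})"

    def _restore_module(self, module, in_keys, out_keys, extra_relative_path):
        self.log.append(f"restore {self.name}")
        prefix = f"{self.name}("
        if not (module.startswith(prefix) and module.endswith(")")):
            raise ValueError(f"{self.name} cannot restore {module!r}")
        return module[len(prefix):-1]


class SketchComposite:
    """Hypothetical composite: prepares in order, restores in reverse
    (last in, first out)."""

    def __init__(self, *contexts) -> None:
        self.contexts = contexts

    def _prepare_module(self, module, in_keys, out_keys, extra_relative_path):
        for ctx in self.contexts:
            module = ctx._prepare_module(module, in_keys, out_keys, extra_relative_path)
        return module

    def _restore_module(self, module, in_keys, out_keys, extra_relative_path):
        for ctx in reversed(self.contexts):
            module = ctx._restore_module(module, in_keys, out_keys, extra_relative_path)
        return module


log: List[str] = []
composite = SketchComposite(Recorder("a", log), Recorder("b", log))
prepared = composite._prepare_module("net", [], [], "")
restored = composite._restore_module(prepared, [], [], "")
print(prepared)  # b(a(net))
print(restored)  # net
print(log)  # ['prepare a', 'prepare b', 'restore b', 'restore a']
```

Restoring in forward order here would make `Recorder("a")` see `b(a(net))` and raise, which is why the sketch unwinds in reverse.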