tdhook.modules
Classes
FunctionModule | Wrapper for a function to be used as a module.
ModuleCall | Wrapper to manage module calls.
ModuleCallWithCache | Wrapper to manage module calls with cache.
PGDModule | Wrapper to manage PGD module calls.
IntermediateKeysCleaner | Wrapper to clean intermediate keys.
HookedModuleRun | Context manager to execute module runs.
HookedModule | Wrapper to enhance a module with hooking capabilities.
Functions
flatten_select_reshape_call
Module Contents
- tdhook.modules.flatten_select_reshape_call(module, td, flatten=True, select=True, reshape=True)
- Parameters:
module (tensordict.nn.TensorDictModuleBase)
td (tensordict.TensorDict)
flatten (bool)
select (bool)
reshape (bool)
- Return type:
tensordict.TensorDict
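This reference gives no description of the function's behavior. Judging only from its name and boolean flags, one plausible reading is "flatten the batch dimensions, select the module's input keys, call the module, then reshape the result back". The pure-Python sketch below illustrates that reading on plain dicts of nested lists with two batch dimensions. The names `flatten_select_reshape_sketch`, `fn`, and `in_keys` are hypothetical, and the actual tdhook implementation may differ.

```python
from typing import Callable, Dict, List


def flatten_select_reshape_sketch(
    fn: Callable[[Dict[str, list]], Dict[str, list]],
    data: Dict[str, List[list]],
    in_keys: List[str],
    flatten: bool = True,
    select: bool = True,
    reshape: bool = True,
) -> Dict[str, list]:
    """Illustrative stand-in only, NOT tdhook's implementation.

    `data` maps keys to 2-level nested lists, i.e. a (rows, cols) batch.
    """
    first = next(iter(data.values()))
    rows, cols = len(first), len(first[0])

    work = dict(data)  # shallow copy so the caller's dict is untouched
    if flatten:
        # (rows, cols) -> (rows * cols,)
        work = {k: [v for row in vals for v in row] for k, vals in work.items()}
    if select:
        # keep only the keys the wrapped callable declares as inputs
        work = {k: work[k] for k in in_keys}

    out = fn(work)

    if flatten and reshape:
        # (rows * cols,) -> (rows, cols)
        out = {
            k: [vals[r * cols:(r + 1) * cols] for r in range(rows)]
            for k, vals in out.items()
        }
    return out
```

With `reshape=False` the sketch returns the flat result as produced by `fn`, which is the only effect that flag has here.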
- class tdhook.modules.FunctionModule(td_fn, in_keys, out_keys)
Bases:
tensordict.nn.TensorDictModuleBase
Wrapper for a function to be used as a module.
- Parameters:
td_fn (Callable[[tensordict.TensorDict], tensordict.TensorDict])
in_keys (List[tdhook._types.UnraveledKey])
out_keys (List[tdhook._types.UnraveledKey])
- class tdhook.modules.ModuleCall(td_module, in_key=None, out_key=None, flatten=True)
Bases:
tensordict.nn.TensorDictModuleBase
Wrapper to manage module calls.
- Parameters:
td_module (tensordict.nn.TensorDictModuleBase)
in_key (Optional[tdhook._types.UnraveledKey])
out_key (Optional[tdhook._types.UnraveledKey])
flatten (bool)
- class tdhook.modules.ModuleCallWithCache(td_module, stored_keys, cache_key=None, in_key=None, out_key=None, cache_ref=None, flatten=True, cache_as_output=True)
Bases:
tensordict.nn.TensorDictModuleBase
Wrapper to manage module calls with cache.
- Parameters:
td_module (tensordict.nn.TensorDictModuleBase)
stored_keys (List[tdhook._types.UnraveledKey])
cache_key (Optional[tdhook._types.UnraveledKey])
in_key (Optional[tdhook._types.UnraveledKey])
out_key (Optional[tdhook._types.UnraveledKey])
cache_ref (Optional[tdhook.hooks.MutableWeakRef | tdhook.hooks.TensorDictRef])
flatten (bool)
cache_as_output (bool)
- property cache_ref: tdhook.hooks.MutableWeakRef | tdhook.hooks.TensorDictRef
- Return type:
tdhook.hooks.MutableWeakRef | tdhook.hooks.TensorDictRef
- class tdhook.modules.PGDModule(td_module, alpha=0.1, n_steps=10, min_value=-float('Inf'), max_value=float('Inf'), grad_key='_grad', working_key='_working', ascent=False, use_sign=True)
Bases:
tensordict.nn.TensorDictModuleBase
Wrapper to manage PGD module calls.
- Parameters:
td_module (tensordict.nn.TensorDictModuleBase)
alpha (float)
n_steps (int)
min_value (float)
max_value (float)
grad_key (tdhook._types.UnraveledKey)
working_key (tdhook._types.UnraveledKey)
ascent (bool)
use_sign (bool)
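The parameter names suggest projected gradient descent (PGD). As a rough illustration of how such parameters typically interact, the pure-Python sketch below runs a generic PGD loop on a list of floats. It is not tdhook's implementation: `pgd_sketch` and `grad_fn` are hypothetical names, `grad_key` and `working_key` are not modeled, and the mapping of each argument to tdhook's behavior is an assumption.

```python
from typing import Callable, List


def _sign(x: float) -> float:
    """Return -1.0, 0.0 or 1.0 depending on the sign of x."""
    return float((x > 0) - (x < 0))


def pgd_sketch(
    grad_fn: Callable[[List[float]], List[float]],
    x0: List[float],
    alpha: float = 0.1,
    n_steps: int = 10,
    min_value: float = -float("inf"),
    max_value: float = float("inf"),
    ascent: bool = False,
    use_sign: bool = True,
) -> List[float]:
    """Generic PGD loop (illustrative; assumed semantics, not tdhook's code)."""
    x = list(x0)
    direction = 1.0 if ascent else -1.0  # ascent follows the gradient, descent opposes it
    for _ in range(n_steps):
        grads = grad_fn(x)
        if use_sign:
            # keep only the sign of each gradient component
            grads = [_sign(g) for g in grads]
        # gradient step of size alpha
        x = [xi + direction * alpha * g for xi, g in zip(x, grads)]
        # projection: clamp every component into [min_value, max_value]
        x = [min(max(xi, min_value), max_value) for xi in x]
    return x


# Minimise (x - 3)^2 from x = 0, but keep x inside (-inf, 2].
result = pgd_sketch(lambda xs: [2.0 * (x - 3.0) for x in xs], [0.0],
                    alpha=0.5, n_steps=10, max_value=2.0)
print(result)  # [2.0]
```

The projection step is what distinguishes PGD from plain gradient descent: the iterate moves toward the minimum at 3 but is clamped at the upper bound 2 after every step.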
- class tdhook.modules.IntermediateKeysCleaner(intermediate_keys)
Bases:
tensordict.nn.TensorDictModuleBase
Wrapper to clean intermediate keys.
- Parameters:
intermediate_keys (List[tdhook._types.UnraveledKey])
- class tdhook.modules.HookedModuleRun(module, data, cache=None, run_name=None, run_sep=None, run_cache=None, grad_enabled=False, run_callback=None)
Context manager to execute module runs.
- Parameters:
module (HookedModule)
data (tensordict.TensorDict)
cache (Optional[tensordict.TensorDict])
run_name (Optional[str])
run_sep (Optional[str])
run_cache (Optional[tensordict.TensorDict])
grad_enabled (bool)
run_callback (Optional[Callable])
- set(key, value, *, callback=None, direction='fwd', prepend=False, relative=True)
- Parameters:
key (str)
value (Any)
callback (Optional[Callable])
direction (tdhook.hooks.HookDirection)
prepend (bool)
relative (bool)
- Return type:
None
- get(key, *, cache_key=None, callback=None, direction='fwd', prepend=False, relative=True)
- Parameters:
key (str)
cache_key (Optional[str])
callback (Optional[Callable])
direction (tdhook.hooks.HookDirection)
prepend (bool)
relative (bool)
- Return type:
- class tdhook.modules.HookedModule(td_module, hooking_context=None, relative_path='td_module')
Bases:
tensordict.nn.TensorDictModuleWrapper
Wrapper to enhance a module with hooking capabilities.
- Parameters:
td_module (tensordict.nn.TensorDictModule)
hooking_context (Optional[tdhook.contexts.HookingContext])
relative_path (str)
- property hooking_context: tdhook.contexts.HookingContext | None
- Return type:
Optional[tdhook.contexts.HookingContext]
- classmethod from_module(module, in_keys, out_keys, *, hooking_context=None, **kwargs)
- Parameters:
module (Callable)
in_keys (List[str])
out_keys (List[str])
hooking_context (Optional[tdhook.contexts.HookingContext])
- Return type:
- run(data, cache=None, run_name=None, run_sep=None, run_cache=None, grad_enabled=False, run_callback=None)
- Parameters:
data (tensordict.TensorDict)
cache (Optional[tensordict.TensorDict])
run_name (Optional[str])
run_sep (Optional[str])
run_cache (Optional[tensordict.TensorDict])
grad_enabled (bool)
run_callback (Optional[Callable])
- Return type:
- register_submodule_hook(key, hook, direction, prepend=False, relative=True)
- Parameters:
key (str)
hook (Callable)
direction (tdhook.hooks.HookDirection)
prepend (bool)
relative (bool)
- set(module_key, value, callback=None, direction='fwd', prepend=False, relative=True)
- Parameters:
module_key (str)
value (Any)
callback (Optional[Callable])
direction (tdhook.hooks.HookDirection)
prepend (bool)
relative (bool)
- Return type:
torch.utils.hooks.RemovableHandle
- get(cache, module_key, cache_key=None, callback=None, direction='fwd', prepend=False, relative=True)
- Parameters:
cache (tensordict.TensorDict)
module_key (str)
cache_key (Optional[str])
callback (Optional[Callable])
direction (tdhook.hooks.HookDirection)
prepend (bool)
relative (bool)
- Return type:
Tuple[torch.utils.hooks.RemovableHandle, tdhook.hooks.CacheProxy]
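The `set` and `get` signatures above return removable handles, which points to the usual hook pattern: a hook is registered on a submodule, observes or overrides its output, and is detached through its handle. The pure-Python sketch below illustrates that general pattern only. `Handle`, `Layer`, `get_output`, and `set_output` are hypothetical stand-ins, and the assumed semantics ("get" stores an output in a cache, "set" replaces an output) are not confirmed by this reference.

```python
from typing import Any, Callable, Dict, List, Optional

Hook = Callable[[Any], Optional[Any]]


class Handle:
    """Stand-in for a removable hook handle (hypothetical)."""

    def __init__(self, hooks: List[Hook], hook: Hook) -> None:
        self._hooks, self._hook = hooks, hook

    def remove(self) -> None:
        # detach the hook; calling remove() twice is harmless
        if self._hook in self._hooks:
            self._hooks.remove(self._hook)


class Layer:
    """Toy callable with forward hooks (hypothetical, not a torch module)."""

    def __init__(self, fn: Callable[[Any], Any]) -> None:
        self.fn = fn
        self.hooks: List[Hook] = []

    def register_hook(self, hook: Hook, prepend: bool = False) -> Handle:
        if prepend:
            self.hooks.insert(0, hook)
        else:
            self.hooks.append(hook)
        return Handle(self.hooks, hook)

    def __call__(self, x: Any) -> Any:
        out = self.fn(x)
        for hook in list(self.hooks):
            result = hook(out)
            if result is not None:  # a non-None return value overrides the output
                out = result
        return out


def get_output(layer: Layer, cache: Dict[str, Any], cache_key: str) -> Handle:
    """Assumed 'get' semantics: store the layer output under cache_key."""
    def hook(out: Any) -> None:
        cache[cache_key] = out
    return layer.register_hook(hook)


def set_output(layer: Layer, value: Any) -> Handle:
    """Assumed 'set' semantics: replace the layer output with value."""
    return layer.register_hook(lambda out: value)
```

In this sketch, hooks run in registration order unless `prepend=True`, so a "get" hook registered before a "set" hook caches the original output rather than the overridden one.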