tdhook.modules#

Classes#

FunctionModule

Wrapper for a function to be used as a module.

ModuleCall

Wrapper to manage module calls.

ModuleCallWithCache

Wrapper to manage module calls with cache.

PGDModule

Wrapper to manage PGD module calls.

IntermediateKeysCleaner

Wrapper to clean intermediate keys.

HookedModuleRun

Context manager to execute module runs.

HookedModule

Wrapper to enhance a module with hooking capabilities.

Functions#

get_best_device()

flatten_select_reshape_call(module, td[, flatten, ...])

Module Contents#

tdhook.modules.get_best_device()[source]#
tdhook.modules.flatten_select_reshape_call(module, td, flatten=True, select=True, reshape=True)[source]#
Parameters:
  • module (tensordict.nn.TensorDictModuleBase)

  • td (tensordict.TensorDict)

  • flatten (bool)

  • select (bool)

  • reshape (bool)

Return type:

tensordict.TensorDict

class tdhook.modules.FunctionModule(td_fn, in_keys, out_keys)[source]#

Bases: tensordict.nn.TensorDictModuleBase

Wrapper for a function to be used as a module.

Parameters:
  • td_fn

  • in_keys

  • out_keys

in_keys[source]#
out_keys[source]#
_td_fn[source]#
forward(td)[source]#
Parameters:

td (tensordict.TensorDict)

Return type:

tensordict.TensorDict

__repr__()[source]#
class tdhook.modules.ModuleCall(td_module, in_key=None, out_key=None, flatten=True)[source]#

Bases: tensordict.nn.TensorDictModuleBase

Wrapper to manage module calls.

Parameters:
  • td_module

  • in_key

  • out_key

  • flatten

in_keys[source]#
out_keys[source]#
_td_module[source]#
_in_key = None[source]#
_out_key = None[source]#
_flatten = True[source]#
forward(td)[source]#
Parameters:

td (tensordict.TensorDict)

Return type:

tensordict.TensorDict

__repr__()[source]#
class tdhook.modules.ModuleCallWithCache(td_module, stored_keys, cache_key=None, in_key=None, out_key=None, cache_ref=None, flatten=True, cache_as_output=True)[source]#

Bases: tensordict.nn.TensorDictModuleBase

Wrapper to manage module calls with cache.

Parameters:
  • td_module

  • stored_keys

  • cache_key

  • in_key

  • out_key

  • cache_ref

  • flatten

  • cache_as_output

in_keys[source]#
_td_module[source]#
_cache_key = None[source]#
_in_key = None[source]#
_out_key = None[source]#
_flatten = True[source]#
_cache_as_output = True[source]#
_cache_ref[source]#
property cache_ref: tdhook.hooks.MutableWeakRef | tdhook.hooks.TensorDictRef[source]#
Return type:

tdhook.hooks.MutableWeakRef | tdhook.hooks.TensorDictRef

forward(td)[source]#
Parameters:

td (tensordict.TensorDict)

Return type:

tensordict.TensorDict

__repr__()[source]#
class tdhook.modules.PGDModule(td_module, alpha=0.1, n_steps=10, min_value=-float('Inf'), max_value=float('Inf'), grad_key='_grad', working_key='_working', ascent=False, use_sign=True)[source]#

Bases: tensordict.nn.TensorDictModuleBase

Wrapper to manage PGD module calls.

Parameters:
  • td_module

  • alpha

  • n_steps

  • min_value

  • max_value

  • grad_key

  • working_key

  • ascent

  • use_sign

_td_module[source]#
in_keys[source]#
out_keys[source]#
_alpha = 0.1[source]#
_n_steps = 10[source]#
_min_value[source]#
_max_value[source]#
_grad_key = '_grad'[source]#
_working_key = '_working'[source]#
_ascent = False[source]#
_use_sign = True[source]#
forward(td)[source]#
Parameters:

td (tensordict.TensorDict)

Return type:

tensordict.TensorDict

_pgd_step(td)[source]#
Parameters:

td (tensordict.TensorDict)

Return type:

tensordict.TensorDict

__repr__()[source]#
class tdhook.modules.IntermediateKeysCleaner(intermediate_keys)[source]#

Bases: tensordict.nn.TensorDictModuleBase

Wrapper to clean intermediate keys.

Parameters:

intermediate_keys (List[tdhook._types.UnraveledKey])

in_keys[source]#
out_keys = [][source]#
_intermediate_keys[source]#
forward(td)[source]#
Parameters:

td (tensordict.TensorDict)

Return type:

tensordict.TensorDict

__repr__()[source]#
class tdhook.modules.HookedModuleRun(module, data, cache=None, run_name=None, run_sep=None, run_cache=None, grad_enabled=False, run_callback=None)[source]#

Context manager to execute module runs.

Parameters:
  • module (HookedModule)

  • data (tensordict.TensorDict)

  • cache (Optional[tensordict.TensorDict])

  • run_name (Optional[str])

  • run_sep (Optional[str])

  • run_cache (Optional[tensordict.TensorDict])

  • grad_enabled (bool)

  • run_callback (Optional[Callable])

_module[source]#
_data[source]#
_outer_cache = None[source]#
_name = 'run'[source]#
_sep = '.'[source]#
_cache[source]#
_grad_enabled = False[source]#
_run_callback[source]#
_save_cache[source]#
_handles = [][source]#
_in_context = False[source]#
property cache: tensordict.TensorDict[source]#
Return type:

tensordict.TensorDict

__enter__()[source]#
__exit__(exc_type, exc_value, traceback)[source]#
_ensure_in_context(method)[source]#
Parameters:

method (str)

set(key, value, *, callback=None, direction='fwd', prepend=False, relative=True)[source]#
Parameters:
  • key (str)

  • value (Any)

  • callback (Optional[Callable])

  • direction (tdhook.hooks.HookDirection)

  • prepend (bool)

  • relative (bool)

Return type:

None

get(key, *, cache_key=None, callback=None, direction='fwd', prepend=False, relative=True)[source]#
Parameters:
  • key (str)

  • cache_key (Optional[str])

  • callback (Optional[Callable])

  • direction (tdhook.hooks.HookDirection)

  • prepend (bool)

  • relative (bool)

Return type:

tdhook.hooks.CacheProxy

save(key, *, cache_key=None, callback=None, direction='fwd', prepend=False, relative=True)[source]#
Parameters:
  • key (str)

  • cache_key (Optional[str])

  • callback (Optional[Callable])

  • direction (tdhook.hooks.HookDirection)

  • prepend (bool)

  • relative (bool)

Return type:

tdhook.hooks.CacheProxy

set_grad(*args, **kwargs)[source]#
get_grad(*args, **kwargs)[source]#
save_grad(*args, **kwargs)[source]#
set_input(*args, **kwargs)[source]#
get_input(*args, **kwargs)[source]#
save_input(*args, **kwargs)[source]#
set_grad_output(*args, **kwargs)[source]#
get_grad_output(*args, **kwargs)[source]#
save_grad_output(*args, **kwargs)[source]#
stop(key)[source]#
Parameters:

key (str)

Return type:

None

class tdhook.modules.HookedModule(td_module, hooking_context=None, relative_path='td_module')[source]#

Bases: tensordict.nn.TensorDictModuleWrapper

Wrapper to enhance a module with hooking capabilities.

Parameters:
  • td_module

  • hooking_context

  • relative_path

_hooking_context = None[source]#
_relative_path = 'td_module'[source]#
property relative_path: str[source]#
Return type:

str

__repr__()[source]#
property hooking_context: tdhook.contexts.HookingContext | None[source]#
Return type:

Optional[tdhook.contexts.HookingContext]

classmethod from_module(module, in_keys, out_keys, *, hooking_context=None, **kwargs)[source]#
Parameters:
  • module

  • in_keys

  • out_keys

  • hooking_context

  • **kwargs

Return type:

HookedModule

run(data, cache=None, run_name=None, run_sep=None, run_cache=None, grad_enabled=False, run_callback=None)[source]#
Parameters:
  • data (tensordict.TensorDict)

  • cache (Optional[tensordict.TensorDict])

  • run_name (Optional[str])

  • run_sep (Optional[str])

  • run_cache (Optional[tensordict.TensorDict])

  • grad_enabled (bool)

  • run_callback (Optional[Callable])

Return type:

HookedModuleRun

register_submodule_hook(key, hook, direction, prepend=False, relative=True)[source]#
Parameters:
  • key (str)

  • hook (Callable)

  • direction (tdhook.hooks.HookDirection)

  • prepend (bool)

  • relative (bool)

set(module_key, value, callback=None, direction='fwd', prepend=False, relative=True)[source]#
Parameters:
  • module_key (str)

  • value (Any)

  • callback (Optional[Callable])

  • direction (tdhook.hooks.HookDirection)

  • prepend (bool)

  • relative (bool)

Return type:

torch.utils.hooks.RemovableHandle

get(cache, module_key, cache_key=None, callback=None, direction='fwd', prepend=False, relative=True)[source]#
Parameters:
  • cache (tensordict.TensorDict)

  • module_key (str)

  • cache_key (Optional[str])

  • callback (Optional[Callable])

  • direction (tdhook.hooks.HookDirection)

  • prepend (bool)

  • relative (bool)

Return type:

Tuple[torch.utils.hooks.RemovableHandle, tdhook.hooks.CacheProxy]

set_input(*args, **kwargs)[source]#
get_input(*args, **kwargs)[source]#
set_grad(*args, **kwargs)[source]#
get_grad(*args, **kwargs)[source]#
set_grad_output(*args, **kwargs)[source]#
get_grad_output(*args, **kwargs)[source]#
stop(key)[source]#
Parameters:

key (str)

Return type:

None

forward(*args, **kwargs)[source]#
disable_context_hooks()[source]#
disable_context()[source]#
restore()[source]#

Restore the module to its original state. This is useful when using prepare(return_context=False) instead of the context manager.