tdhook.hooks

Attributes

HookDirection
T
DIRECTION_TO_PARAMS
DIRECTION_TO_RETURN
DIRECTION_TO_RETURN_INDEX
DIRECTION_TO_TYPE

Exceptions

EarlyStoppingException

Exception for early stopping.

Classes

RemovableHandleProtocol

Structural protocol for hook handles that provide a remove() method.

MultiHookHandle

Handle for multiple hooks.

MultiHookManager

Manager for multiple hooks.

MutableWeakRef

Weak reference to a mutable object.

TensorDictRef

Reference to a TensorDict.

CacheProxy

Proxy for a cache.

HookFactory

Factory for creating hooks.

Functions

_check_hook_signature(hook, direction)

Check the signature of the hook.

merge_paths(*paths)

Merge multiple paths into a single path.

resolve_submodule_path(root, path)

Resolve a submodule path that may contain indexing expressions.

submodule_path_to_name(path)

Convert a submodule path to a name.

register_hook_to_module(module, hook, direction[, prepend])

Register the hook on the module.

Module Contents

tdhook.hooks.HookDirection
tdhook.hooks.T
tdhook.hooks.DIRECTION_TO_PARAMS
tdhook.hooks.DIRECTION_TO_RETURN
tdhook.hooks.DIRECTION_TO_RETURN_INDEX
tdhook.hooks.DIRECTION_TO_TYPE

tdhook.hooks._check_hook_signature(hook, direction)

Check the signature of the hook.

Parameters:
  • hook (Callable)

  • direction (HookDirection)

tdhook.hooks.merge_paths(*paths)

Merge multiple paths into a single path.

Parameters:

paths (str)

Return type:

str

tdhook.hooks.resolve_submodule_path(root, path)

Resolve a submodule path that may contain indexing expressions.

Supports Python attribute access, indexing, slicing, and calls:
  - "[0]" -> root[0]
  - "layers[-1]" -> root.layers[-1]
  - "layers['attr']" -> root.layers['attr']
  - "layers.attention" -> root.layers.attention
  - "layers[1:3]" -> root.layers[1:3]
  - "fn(0)" -> root.fn(0)

Also supports custom attributes: a name wrapped in angle brackets is looked up verbatim with getattr, so it may be a string that is not a valid Python identifier:
  - "<block0/module>" -> getattr(root, "block0/module")
  - "<block0/module>.layers.attention[0]" -> getattr(root, "block0/module").layers.attention[0]
  - "m1.<block0/module>.layers.<module>.linear[0]" -> getattr(getattr(root.m1, "block0/module").layers, "module").linear[0]
  - "m1.<0>.layers" -> getattr(root.m1, "0").layers

Parameters:
  • root (torch.nn.Module)

  • path (str)
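The path grammar above can be illustrated with a small stdlib-only parser. This is a hypothetical sketch, not tdhook's implementation: resolve_path_sketch and _parse_index are illustrative names, types.SimpleNamespace objects stand in for torch.nn.Module, and index and call arguments are assumed to be Python literals.

```python
import ast
import re
from types import SimpleNamespace
from typing import Any

# One path token: "<custom>", "name", "[index]" or "(args)"; a leading "." is optional.
# Hypothetical sketch of the grammar described above, not tdhook's parser.
_TOKEN = re.compile(r"\.?<([^>]*)>|\.?([A-Za-z_]\w*)|\[([^\]]*)\]|\(([^)]*)\)")


def _parse_index(text: str) -> Any:
    """Parse "0", "-1", "'attr'" as literals and "1:3" or "::2" as slices."""
    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError):
        bounds = [part.strip() for part in text.split(":")]
        return slice(*(ast.literal_eval(b) if b else None for b in bounds))


def resolve_path_sketch(root: Any, path: str) -> Any:
    obj, pos = root, 0
    while pos < len(path):
        match = _TOKEN.match(path, pos)
        if match is None:
            raise ValueError(f"cannot parse {path!r} at position {pos}")
        custom, attr, index, args = match.groups()
        if custom is not None:  # "<block0/module>": verbatim getattr
            obj = getattr(obj, custom)
        elif attr is not None:  # "layers" or ".attention"
            obj = getattr(obj, attr)
        elif index is not None:  # "[0]", "[-1]", "['attr']", "[1:3]"
            obj = obj[_parse_index(index)]
        else:  # "(0)": call with literal positional arguments
            obj = obj(*ast.literal_eval(f"({args},)")) if args.strip() else obj()
        pos = match.end()
    return obj


root = SimpleNamespace(layers=["a", "b", "c", "d"], fn=lambda i: i * 10)
setattr(root, "block0/module", SimpleNamespace(layers={"attr": "x"}))
print(resolve_path_sketch(root, "layers[1:3]"))  # ['b', 'c']
print(resolve_path_sketch(root, "<block0/module>.layers['attr']"))  # x
```

Walking the path token by token, rather than evaluating it as a Python expression, keeps arbitrary code out of the lookup; the trade-off in this sketch is that only literal indices and call arguments are accepted.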

tdhook.hooks.submodule_path_to_name(path)

Convert a submodule path to a name.

Parameters:

path (str)

Return type:

str

tdhook.hooks.register_hook_to_module(module, hook, direction, prepend=False)

Register the hook on the module.

Parameters:
  • module (torch.nn.Module)

  • hook (Callable)

  • direction (HookDirection)

  • prepend (bool)

Return type:

torch.utils.hooks.RemovableHandle

class tdhook.hooks.RemovableHandleProtocol

Bases: Protocol

Structural protocol for hook handles. Any object that provides a remove() method satisfies it, for example torch.utils.hooks.RemovableHandle (the return type of register_hook_to_module) or MultiHookHandle. As with any typing.Protocol subclass, conformance is checked structurally by static type checkers (see PEP 544), so no inheritance is required.

remove()
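Structural typing means a class never has to inherit from the protocol: having a remove() method is enough. The sketch below defines its own runtime-checkable stand-in, HandleLike, because whether tdhook's RemovableHandleProtocol is decorated with @typing.runtime_checkable is not stated here; FakeHandle and remove_all are likewise hypothetical.

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class HandleLike(Protocol):
    """Hypothetical stand-in mirroring RemovableHandleProtocol: anything with remove()."""

    def remove(self) -> None: ...


class FakeHandle:
    """Does not inherit from HandleLike, but matches it structurally."""

    def __init__(self) -> None:
        self.removed = False

    def remove(self) -> None:
        self.removed = True


def remove_all(handles: list[HandleLike]) -> None:
    for handle in handles:
        handle.remove()


handle = FakeHandle()
print(isinstance(handle, HandleLike))  # True: the runtime check only looks for remove
remove_all([handle])
print(handle.removed)  # True
```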
class tdhook.hooks.MultiHookHandle(handles=None)

Handle for multiple hooks.

Parameters:

handles (Optional[List[RemovableHandleProtocol]])

_handles = []
remove()
__enter__()
__exit__(exc_type, exc_value, traceback)
__add__(other)
Parameters:

other (Any)
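MultiHookHandle exposes one remove(), context-manager methods, and +. A minimal sketch of such a composite handle, assuming that remove() removes every wrapped handle, that __exit__ calls remove(), and that + returns a new handle holding both lists; MultiHandleSketch and CountingHandle are hypothetical names, and the real class may behave differently.

```python
from typing import Any, List, Optional


class CountingHandle:
    """Hypothetical single-hook handle that counts remove() calls."""

    def __init__(self) -> None:
        self.calls = 0

    def remove(self) -> None:
        self.calls += 1


class MultiHandleSketch:
    def __init__(self, handles: Optional[List[Any]] = None) -> None:
        self._handles = list(handles) if handles else []

    def remove(self) -> None:
        # Remove every wrapped hook, then forget the handles so a second call is a no-op.
        for handle in self._handles:
            handle.remove()
        self._handles = []

    def __enter__(self) -> "MultiHandleSketch":
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        self.remove()  # assumption: leaving the block removes all hooks

    def __add__(self, other: Any) -> "MultiHandleSketch":
        if not isinstance(other, MultiHandleSketch):
            return NotImplemented
        return MultiHandleSketch(self._handles + other._handles)


a, b = CountingHandle(), CountingHandle()
with MultiHandleSketch([a]) + MultiHandleSketch([b]):
    pass  # run the model here while both hooks are active
print(a.calls, b.calls)  # 1 1
```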

class tdhook.hooks.MultiHookManager(pattern=None, classes_to_hook=(nn.Module,), classes_to_skip=())

Manager for multiple hooks.

Parameters:
  • pattern (Optional[str])

  • classes_to_hook (Tuple[Type[torch.nn.Module], ...])

  • classes_to_skip (Tuple[Type[torch.nn.Module], ...])

_pattern = None
_classes_to_hook
_classes_to_skip = ()
_reg_exp
property pattern: str

The regular-expression pattern used to select which modules to hook.

Return type:

str

register_hook(module, hook_factory, *, direction='fwd', prepend=False, relative_path=None)

Register hooks created by hook_factory on the submodules of module that match the manager's pattern and class filters.

Parameters:
  • module (torch.nn.Module)

  • hook_factory (Callable[[str], Callable])

  • direction (HookDirection)

  • prepend (bool)

  • relative_path (Optional[str])
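The constructor arguments suggest a selection step: a regular expression over submodule names plus include and exclude class filters, with hook_factory receiving each selected name. The stdlib sketch below illustrates that idea under stated assumptions: the pattern is applied with re.fullmatch, a plain dictionary stands in for nn.Module.named_modules(), and select_and_hook, make_hook and the stub classes are hypothetical. tdhook's actual matching rules may differ.

```python
import re
from typing import Callable, Dict, Optional, Tuple, Type


class Linear:
    """Stand-in for torch.nn.Linear."""


class Dropout:
    """Stand-in for torch.nn.Dropout."""


def select_and_hook(
    named_modules: Dict[str, object],
    hook_factory: Callable[[str], Callable],
    pattern: Optional[str] = None,
    classes_to_hook: Tuple[Type, ...] = (object,),
    classes_to_skip: Tuple[Type, ...] = (),
) -> Dict[str, Callable]:
    """Return {name: hook} for every module that passes the name and class filters."""
    reg_exp = re.compile(pattern) if pattern is not None else None
    hooks: Dict[str, Callable] = {}
    for name, module in named_modules.items():
        if reg_exp is not None and not reg_exp.fullmatch(name):
            continue  # name does not match the pattern
        if not isinstance(module, classes_to_hook) or isinstance(module, classes_to_skip):
            continue  # wrong class, or explicitly skipped
        hooks[name] = hook_factory(name)  # the factory receives the module's name
    return hooks


def make_hook(name: str) -> Callable:
    def hook(module, args, output):
        print(f"{name} produced {output!r}")
    return hook


modules = {"layers.0.fc": Linear(), "layers.0.drop": Dropout(), "layers.1.fc": Linear()}
hooks = select_and_hook(modules, make_hook, pattern=r"layers\.\d+\..*", classes_to_skip=(Dropout,))
print(sorted(hooks))  # ['layers.0.fc', 'layers.1.fc']
```

Passing the name to the factory lets each hook know where it is attached, for example to use the name as a cache key.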

class tdhook.hooks.MutableWeakRef(referee)

Bases: Generic[T]

Weak reference to a mutable object.

Parameters:

referee (T)

_ref
resolve()
Return type:

T

set(referee)
Parameters:

referee (T)
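A mutable weak reference can be sketched with the standard weakref module: set() swaps the target and resolve() dereferences it. This is an illustrative sketch (MutableWeakRefSketch and Cache are hypothetical names); in particular, raising ReferenceError once the target has been garbage-collected is an assumption, not documented tdhook behavior.

```python
import gc
import weakref
from typing import Generic, TypeVar

T = TypeVar("T")


class MutableWeakRefSketch(Generic[T]):
    def __init__(self, referee: T) -> None:
        self._ref = weakref.ref(referee)  # does not keep referee alive

    def resolve(self) -> T:
        referee = self._ref()
        if referee is None:  # target was garbage-collected (assumed error behavior)
            raise ReferenceError("referee no longer exists")
        return referee

    def set(self, referee: T) -> None:
        self._ref = weakref.ref(referee)  # point at a new object


class Cache:
    """Plain class; its instances support weak references, unlike dict or list."""


first, second = Cache(), Cache()
ref = MutableWeakRefSketch(first)
ref.set(second)
print(ref.resolve() is second)  # True
del second
gc.collect()
try:
    ref.resolve()
except ReferenceError:
    print("target collected")  # reached once 'second' has been freed
```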

class tdhook.hooks.TensorDictRef(td)

Reference to a TensorDict.

Parameters:

td (Optional[tensordict.TensorDict])

_td
resolve()
Return type:

tensordict.TensorDict

set(td)
Parameters:

td (tensordict.TensorDict)

class tdhook.hooks.CacheProxy(key, cache)

Proxy for the value stored under key in cache.

Parameters:
  • key

  • cache

_key
_cache
resolve()
Return type:

Any
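A proxy like this defers the lookup: the key is bound when the proxy is created, but the value is read only when resolve() is called, so the proxy can be handed out before the cache is filled. A hypothetical sketch with a plain dict as the cache (CacheProxySketch is an illustrative name; the cache type tdhook actually expects is not documented here):

```python
from typing import Any, MutableMapping


class CacheProxySketch:
    def __init__(self, key: str, cache: MutableMapping[str, Any]) -> None:
        self._key = key
        self._cache = cache

    def resolve(self) -> Any:
        # Look the value up now, not when the proxy was created.
        return self._cache[self._key]


cache: dict = {}
proxy = CacheProxySketch("layer1.output", cache)  # nothing cached yet
cache["layer1.output"] = [0.5, 1.5]  # e.g. filled later by a caching hook
print(proxy.resolve())  # [0.5, 1.5]
```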

exception tdhook.hooks.EarlyStoppingException(key)

Bases: Exception

Exception for early stopping.

Parameters:

key (str)

_key

class tdhook.hooks.HookFactory

Factory for creating hooks.

static _check_callback_signature(callback, expected_param_names)

Check that the callback's signature matches the expected parameter names.

Parameters:
  • callback (Callable)

  • expected_param_names (set[str])
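A check of this kind is typically built on inspect.signature. The sketch assumes "matches" means the callback's parameter names equal the expected set exactly; check_callback_sketch is a hypothetical name, the expected names used below are illustrative, and tdhook may accept subsets or extra parameters.

```python
import inspect
from typing import Callable


def check_callback_sketch(callback: Callable, expected_param_names: set[str]) -> None:
    """Raise ValueError unless the callback's parameter names equal expected_param_names."""
    actual = set(inspect.signature(callback).parameters)
    if actual != expected_param_names:
        missing = sorted(expected_param_names - actual)
        extra = sorted(actual - expected_param_names)
        raise ValueError(f"bad callback signature: missing={missing}, unexpected={extra}")


def good(module, args, output):
    return output


check_callback_sketch(good, {"module", "args", "output"})  # passes silently
try:
    check_callback_sketch(lambda x: x, {"module", "args", "output"})
except ValueError as err:
    print(err)  # bad callback signature: missing=['args', 'module', 'output'], unexpected=['x']
```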

static make_caching_hook(key, cache, *, callback=None, direction='fwd')

Make a caching hook that stores values in cache under key.

Parameters:
  • key

  • cache

  • callback (Optional[Callable])

  • direction (HookDirection)

Return type:

Callable
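PyTorch calls a forward hook as hook(module, args, output). A caching hook can be sketched as a closure over key and cache that stores output and returns None, which leaves the module's output unchanged. This stdlib sketch is hypothetical: make_caching_hook_sketch and run_with_hook are illustrative names, a plain function stands in for the module, and the callback and direction options are omitted.

```python
from typing import Any, Callable, Dict


def make_caching_hook_sketch(key: str, cache: Dict[str, Any]) -> Callable:
    def hook(module: Any, args: tuple, output: Any) -> None:
        cache[key] = output  # store; returning None keeps the output as-is
    return hook


def run_with_hook(module: Callable, args: tuple, hook: Callable) -> Any:
    """Mimic how PyTorch applies a forward hook after module(*args)."""
    output = module(*args)
    replacement = hook(module, args, output)
    return output if replacement is None else replacement


def double(x: int) -> int:  # stand-in for an nn.Module
    return 2 * x


cache: Dict[str, Any] = {}
result = run_with_hook(double, (3,), make_caching_hook_sketch("double.out", cache))
print(result, cache)  # 6 {'double.out': 6}
```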

static make_setting_hook(value, *, callback=None, direction='fwd')

Make a setting hook that replaces the intercepted value with value.

Parameters:
  • value (Any)

  • callback (Optional[Callable])

  • direction (HookDirection)

Return type:

Callable
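In PyTorch, a forward hook that returns something other than None replaces the module's output. A setting hook can therefore be sketched as a hook that ignores the computed output and returns value. The names below are hypothetical, the callback and direction options are omitted, and run_with_hook only mimics PyTorch's forward-hook rule.

```python
from typing import Any, Callable


def make_setting_hook_sketch(value: Any) -> Callable:
    def hook(module: Any, args: tuple, output: Any) -> Any:
        return value  # a non-None return value replaces the module output
    return hook


def run_with_hook(module: Callable, args: tuple, hook: Callable) -> Any:
    """Mimic how PyTorch applies a forward hook after module(*args)."""
    output = module(*args)
    replacement = hook(module, args, output)
    return output if replacement is None else replacement


def double(x: int) -> int:  # stand-in for an nn.Module
    return 2 * x


print(run_with_hook(double, (3,), make_setting_hook_sketch(0)))  # 0
```

Because None means "keep the output" under this rule, a hook built this way cannot set an output to None.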

static make_reading_hook(*, callback, direction='fwd')

Make a reading hook that passes the intercepted value to callback.

Parameters:
  • callback (Callable)

  • direction (HookDirection)

Return type:

Callable

static make_stopping_hook(key)

Make a hook that raises EarlyStoppingException(key) to stop execution early.

Parameters:

key (str)

Return type:

Callable
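Early stopping can be sketched as a hook that raises an exception carrying the key, which the caller catches to abandon the rest of the forward pass once the values it needs have been recorded. StopSketch, make_stopping_hook_sketch and run_layers are hypothetical stand-ins for EarlyStoppingException, make_stopping_hook and a model's forward pass; treating the key as the marker of where execution stopped is an assumption.

```python
from typing import Any, Callable, Dict, List, Tuple


class StopSketch(Exception):
    """Stand-in for EarlyStoppingException: remembers which key triggered the stop."""

    def __init__(self, key: str) -> None:
        super().__init__(key)
        self.key = key


def make_stopping_hook_sketch(key: str) -> Callable:
    def hook(module: Any, args: tuple, output: Any) -> None:
        raise StopSketch(key)  # abort the remaining forward pass
    return hook


def run_layers(layers: List[Tuple[str, Callable]], x: Any, hooks: Dict[str, Callable]) -> Any:
    """Toy forward pass: apply each layer in turn, then its forward hook if any."""
    for name, layer in layers:
        out = layer(x)
        if name in hooks:
            replacement = hooks[name](layer, (x,), out)
            out = out if replacement is None else replacement
        x = out
    return x


calls: List[str] = []


def add1(v: int) -> int:
    calls.append("add1")
    return v + 1


def mul10(v: int) -> int:
    calls.append("mul10")
    return v * 10


try:
    run_layers([("add1", add1), ("mul10", mul10)], 1, {"add1": make_stopping_hook_sketch("add1")})
except StopSketch as stop:
    print(f"stopped at {stop.key}; layers run: {calls}")  # stopped at add1; layers run: ['add1']
```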