tdhook.hooks
Attributes
Exceptions
EarlyStoppingException | Exception for early stopping.
Classes
RemovableHandleProtocol | Base class for protocol classes.
MultiHookHandle | Handle for multiple hooks.
MultiHookManager | Manager for multiple hooks.
MutableWeakRef | Weak reference to a mutable object.
TensorDictRef | Reference to a TensorDict.
CacheProxy | Proxy for a cache.
HookFactory | Factory for creating hooks.
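Functions
_check_hook_signature(hook, direction) | Check the signature of the hook.
merge_paths(*paths) | Merge multiple paths into a single path.
resolve_submodule_path(root, path) | Resolve a submodule path that may contain indexing expressions.
submodule_path_to_name(path) | Convert a submodule path to a name.
register_hook_to_module(module, hook, direction, prepend=False) | Register the hook to the module.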
Module Contents
- tdhook.hooks._check_hook_signature(hook, direction)
Check the signature of the hook.
- Parameters:
hook (Callable)
direction (HookDirection)
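The reference does not say what the check verifies. A minimal sketch, assuming it tests whether the hook can accept the positional arguments PyTorch passes to each kind of module hook; only "fwd" appears in this page (as HookFactory's default direction), so the other direction keys and the name check_hook_signature_sketch are hypothetical.

```python
import inspect
from typing import Callable

# Hypothetical direction keys mapped to the positional arguments PyTorch
# passes to the corresponding kind of module hook.
_EXPECTED_ARGS = {
    "fwd": ("module", "args", "output"),              # register_forward_hook
    "fwd_pre": ("module", "args"),                    # register_forward_pre_hook
    "bwd": ("module", "grad_input", "grad_output"),   # register_full_backward_hook
    "bwd_pre": ("module", "grad_output"),             # register_full_backward_pre_hook
}


def check_hook_signature_sketch(hook: Callable, direction: str) -> None:
    """Raise ValueError if `hook` cannot be called with the args for `direction`."""
    if direction not in _EXPECTED_ARGS:
        raise ValueError(f"Unknown direction: {direction!r}")
    expected = _EXPECTED_ARGS[direction]
    try:
        # bind() performs the same argument matching as a real call, without calling.
        inspect.signature(hook).bind(*expected)
    except TypeError as exc:
        raise ValueError(
            f"Hook for {direction!r} must accept {len(expected)} positional "
            f"arguments {expected}: {exc}"
        ) from None
```

Using `Signature.bind` rather than counting parameters means hooks written with `*args` or default values are accepted, just as they would be at call time.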
- tdhook.hooks.merge_paths(*paths)
Merge multiple paths into a single path.
- Parameters:
paths (str)
- Return type:
str
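The merging rule is not documented here. Assuming merge_paths joins path fragments in the syntax accepted by resolve_submodule_path, a sketch might skip empty fragments and omit the dot before an indexing fragment such as "[0]". The helper name merge_paths_sketch and both rules are assumptions, not the library's confirmed behavior.

```python
def merge_paths_sketch(*paths: str) -> str:
    """Join submodule path fragments with '.', keeping '[...]' fragments attached."""
    merged = ""
    for path in paths:
        if not path:
            continue  # assumed: empty fragments (e.g. the root "") are ignored
        if merged and not path.startswith("["):
            merged += "."  # attribute access needs a separator
        merged += path      # indexing like "[0]" attaches directly
    return merged
```

For example, `merge_paths_sketch("m1", "layers", "[0]", "attention")` gives `"m1.layers[0].attention"`.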
- tdhook.hooks.resolve_submodule_path(root, path)
Resolve a submodule path that may contain indexing expressions.
Supports any valid Python attribute access, indexing, and call expression:
- "[0]" -> root[0]
- "layers[-1]" -> root.layers[-1]
- "layers['attr']" -> root.layers['attr']
- "layers.attention" -> root.layers.attention
- "layers[1:3]" -> root.layers[1:3]
- "fn(0)" -> root.fn(0)
Supports custom attributes: a name wrapped in angle brackets is looked up verbatim with getattr, so it may contain characters (such as "/") or start with digits, which plain attribute syntax does not allow:
- "<block0/module>" -> getattr(root, "block0/module")
- "<block0/module>.layers.attention[0]" -> getattr(root, "block0/module").layers.attention[0]
- "m1.<block0/module>.layers.<module>.linear[0]" -> getattr(getattr(root.m1, "block0/module").layers, "module").linear[0]
- "m1.<0>.layers" -> getattr(root.m1, "0").layers
- Parameters:
root (torch.nn.Module)
path (str)
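The path syntax above can be illustrated with a small tokenizer. This is a simplified sketch, not tdhook's implementation: resolve_path_sketch is a hypothetical name, it supports attribute access, "<raw>" attributes, integer and string indexing, and integer slices, but deliberately omits call expressions such as "fn(0)". It works on any Python object, not only torch.nn.Module.

```python
import ast
import re
from typing import Any

# One token at a time: a '.' separator, a <raw name>, a [subscript], or an identifier.
_TOKEN = re.compile(
    r"\.|<(?P<raw>[^>]+)>|\[(?P<index>[^\]]+)\]|(?P<attr>[A-Za-z_]\w*)"
)


def _parse_index(text: str) -> Any:
    """Turn the text between brackets into a slice, int, or str."""
    if ":" in text:
        parts = [p.strip() for p in text.split(":")]
        return slice(*(int(p) if p else None for p in parts))
    return ast.literal_eval(text)  # handles 0, -1, 'attr'


def resolve_path_sketch(root: Any, path: str) -> Any:
    """Walk `path` from `root`, applying getattr or indexing per token."""
    obj = root
    pos = 0
    while pos < len(path):
        match = _TOKEN.match(path, pos)
        if match is None:
            # Calls like "fn(0)" are not supported in this sketch.
            raise ValueError(f"Cannot parse {path!r} at position {pos}")
        if match.group("raw") is not None:
            obj = getattr(obj, match.group("raw"))
        elif match.group("index") is not None:
            obj = obj[_parse_index(match.group("index"))]
        elif match.group("attr") is not None:
            obj = getattr(obj, match.group("attr"))
        # A bare '.' matches no group and is simply skipped.
        pos = match.end()
    return obj
```

`ast.literal_eval` is used for the bracket contents so that only literals are accepted; nothing in the path is executed as code.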
- tdhook.hooks.submodule_path_to_name(path)
Convert a submodule path to a name.
- Parameters:
path (str)
- Return type:
str
- tdhook.hooks.register_hook_to_module(module, hook, direction, prepend=False)
Register the hook to the module.
- Parameters:
module (torch.nn.Module)
hook (Callable)
direction (HookDirection)
prepend (bool)
- Return type:
torch.utils.hooks.RemovableHandle
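Registration by direction can be sketched as a dispatch onto the standard torch.nn.Module registration methods, whose returned handle is a torch.utils.hooks.RemovableHandle. The four method names and their prepend keyword are real PyTorch (2.0+) APIs; the direction keys other than "fwd" and the name register_hook_sketch are hypothetical.

```python
from typing import Any, Callable

# Hypothetical direction keys mapped to real torch.nn.Module methods.
_REGISTER_METHOD = {
    "fwd": "register_forward_hook",
    "fwd_pre": "register_forward_pre_hook",
    "bwd": "register_full_backward_hook",
    "bwd_pre": "register_full_backward_pre_hook",
}


def register_hook_sketch(
    module: Any, hook: Callable, direction: str, prepend: bool = False
) -> Any:
    """Register `hook` on `module` for `direction` and return the removable handle."""
    try:
        method_name = _REGISTER_METHOD[direction]
    except KeyError:
        raise ValueError(f"Unknown direction: {direction!r}") from None
    register = getattr(module, method_name)
    # All four PyTorch methods accept `prepend`, which puts the hook
    # before previously registered hooks of the same kind.
    return register(hook, prepend=prepend)
```

With PyTorch installed, `register_hook_sketch(torch.nn.Linear(2, 2), lambda m, a, o: None, "fwd")` returns a handle whose `remove()` detaches the hook.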
- class tdhook.hooks.RemovableHandleProtocol
Bases: Protocol
Base class for protocol classes.
Protocol classes are defined as:
    class Proto(Protocol):
        def meth(self) -> int:
            ...
Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).
For example:
    class C:
        def meth(self) -> int:
            return 0

    def func(x: Proto) -> int:
        return x.meth()

    func(C())  # Passes static type check
See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic; they are defined as:
    class GenProto(Protocol[T]):
        def meth(self) -> T:
            ...
- class tdhook.hooks.MultiHookHandle(handles=None)
Handle for multiple hooks.
- Parameters:
handles (Optional[List[RemovableHandleProtocol]])
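A handle for multiple hooks is naturally a composite: it stores several removable handles and removes them all at once. This is a sketch of that pattern, assuming RemovableHandleProtocol requires a remove() method (as torch.utils.hooks.RemovableHandle provides); MultiHookHandleSketch, RemovableHandleLike, and the context-manager support are illustrative assumptions.

```python
from typing import List, Optional, Protocol


class RemovableHandleLike(Protocol):
    """Assumed shape of RemovableHandleProtocol: anything with remove()."""

    def remove(self) -> None: ...


class MultiHookHandleSketch:
    """Group several hook handles so they can be removed together."""

    def __init__(self, handles: Optional[List[RemovableHandleLike]] = None) -> None:
        # Copy so later changes to the caller's list do not affect this handle.
        self.handles: List[RemovableHandleLike] = list(handles or [])

    def remove(self) -> None:
        """Remove every wrapped handle, then forget them."""
        for handle in self.handles:
            handle.remove()
        self.handles.clear()

    def __enter__(self) -> "MultiHookHandleSketch":
        return self

    def __exit__(self, *exc_info: object) -> None:
        # Hooks are removed on leaving a `with` block, even if an error occurred.
        self.remove()
```

Clearing the list after removal makes a second `remove()` call a harmless no-op.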
- class tdhook.hooks.MultiHookManager(pattern=None, classes_to_hook=(nn.Module,), classes_to_skip=())
Manager for multiple hooks.
- Parameters:
pattern (Optional[str])
classes_to_hook (Tuple[Type[torch.nn.Module], ...])
classes_to_skip (Tuple[Type[torch.nn.Module], ...])
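The manager's parameters suggest a filter over a model's named submodules. A torch-free sketch of such a filter, assuming pattern is a regular expression matched against the whole module name with re.fullmatch and that classes_to_skip wins over classes_to_hook; select_modules_sketch is hypothetical and defaults classes_to_hook to (object,) instead of (nn.Module,) so it runs without PyTorch.

```python
import re
from typing import Any, Iterable, Iterator, Optional, Tuple, Type


def select_modules_sketch(
    named_modules: Iterable[Tuple[str, Any]],
    pattern: Optional[str] = None,
    classes_to_hook: Tuple[Type, ...] = (object,),
    classes_to_skip: Tuple[Type, ...] = (),
) -> Iterator[Tuple[str, Any]]:
    """Yield the (name, module) pairs a manager like this might attach hooks to."""
    regex = re.compile(pattern) if pattern is not None else None
    for name, module in named_modules:
        if regex is not None and regex.fullmatch(name) is None:
            continue  # assumed: the pattern must match the full name
        if not isinstance(module, classes_to_hook):
            continue  # not one of the classes to hook
        if isinstance(module, classes_to_skip):
            continue  # explicitly excluded, even if it matched above
        yield name, module
```

With PyTorch, one would pass `model.named_modules()` and, for example, `classes_to_hook=(nn.Linear,)`.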
- class tdhook.hooks.MutableWeakRef(referee)
Bases: Generic[T]
Weak reference to a mutable object.
- Parameters:
referee (T)
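One reading of this class is a weak reference that does not keep its referee alive and that can be re-pointed at a new object. A sketch under that assumption using the standard weakref module; MutableWeakRefSketch and its set/resolve methods are hypothetical names, not tdhook's API.

```python
import weakref
from typing import Generic, Optional, TypeVar

T = TypeVar("T")


class MutableWeakRefSketch(Generic[T]):
    """A weak reference whose target can be swapped after construction."""

    def __init__(self, referee: T) -> None:
        # weakref.ref does not keep `referee` alive.
        self._ref = weakref.ref(referee)

    def set(self, referee: T) -> None:
        """Point the reference at a new object (still held weakly)."""
        self._ref = weakref.ref(referee)

    def resolve(self) -> Optional[T]:
        """Return the current target, or None once it has been garbage-collected."""
        return self._ref()
```

Holding a cache weakly in this way would let a hook outlive the cache without preventing it from being freed.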
- class tdhook.hooks.TensorDictRef(td)
Reference to a TensorDict.
- Parameters:
td (Optional[tensordict.TensorDict])
- class tdhook.hooks.CacheProxy(key, cache)
Proxy for a cache.
- Parameters:
key (str)
cache (tensordict.TensorDict | MutableWeakRef[tensordict.TensorDict] | TensorDictRef)
- exception tdhook.hooks.EarlyStoppingException(key)
Bases: Exception
Exception for early stopping.
- Parameters:
key (str)
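Early stopping with an exception typically means a hook caches the value it needs and then raises, so the rest of the forward pass is skipped and the caller catches the exception. A plain-Python sketch of that control flow; EarlyStopSketch stands in for EarlyStoppingException, and make_stopping_hook and run_layers (a toy forward pass) are hypothetical.

```python
from typing import Any, Callable, Dict, List


class EarlyStopSketch(Exception):
    """Stand-in for EarlyStoppingException: records which key triggered the stop."""

    def __init__(self, key: str) -> None:
        super().__init__(key)
        self.key = key


def make_stopping_hook(key: str, cache: Dict[str, Any]) -> Callable[[Any], None]:
    """Cache a value under `key`, then abort the rest of the computation."""

    def hook(output: Any) -> None:
        cache[key] = output
        raise EarlyStopSketch(key)

    return hook


def run_layers(
    layers: List[Callable[[Any], Any]], x: Any, hooks: Dict[int, Callable[[Any], None]]
) -> Any:
    """Toy forward pass: apply each layer, calling any hook attached to its index."""
    for i, layer in enumerate(layers):
        x = layer(x)
        if i in hooks:
            hooks[i](x)
    return x
```

The caller wraps `run_layers` in `try`/`except EarlyStopSketch as stop:` and reads `cache[stop.key]`; layers after the hooked one never run.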
- class tdhook.hooks.HookFactory
Factory for creating hooks.
- static _check_callback_signature(callback, expected_param_names)
Check callback signature matches expected parameter names.
- Parameters:
callback (Callable)
expected_param_names (set[str])
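Matching a callback against expected parameter names can be done with inspect.signature. A sketch, assuming "matches" means the named parameters equal the expected set exactly, with *args/**kwargs catch-alls ignored; the function name and the example names "value" and "key" are hypothetical, since this page does not list the names HookFactory expects.

```python
import inspect
from typing import Callable, Set

_CATCH_ALL = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)


def check_callback_signature_sketch(
    callback: Callable, expected_param_names: Set[str]
) -> None:
    """Raise ValueError unless the callback's named parameters equal the expected set."""
    params = inspect.signature(callback).parameters
    # Compare only explicitly named parameters, not *args / **kwargs.
    names = {name for name, p in params.items() if p.kind not in _CATCH_ALL}
    expected = set(expected_param_names)
    if names != expected:
        raise ValueError(
            f"Callback parameters mismatch: missing={sorted(expected - names)}, "
            f"unexpected={sorted(names - expected)}"
        )
```

Checking names rather than positions lets the factory call the callback with keyword arguments in any order.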
- static make_caching_hook(key, cache, *, callback=None, direction='fwd')
Make a caching hook.
- Parameters:
cache (tensordict.TensorDict | MutableWeakRef)
callback (Optional[Callable])
direction (HookDirection)
- Return type:
Callable
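A caching hook for the default direction 'fwd' can be sketched as a closure with PyTorch's forward-hook signature (module, args, output) that stores the output under key. This is an illustration, not tdhook's implementation: a plain dict stands in for tensordict.TensorDict, and the callback contract (receive the output, return the value to store) is an assumption.

```python
from typing import Any, Callable, Dict, Optional


def make_caching_hook_sketch(
    key: str,
    cache: Dict[str, Any],
    *,
    callback: Optional[Callable[[Any], Any]] = None,
) -> Callable[[Any, Any, Any], None]:
    """Build a forward-hook-shaped function that stores `output` in `cache[key]`."""

    def hook(module: Any, args: Any, output: Any) -> None:
        # Assumed callback contract: transform the output before storing it.
        value = output if callback is None else callback(output)
        cache[key] = value
        # Returning None leaves the module's output unchanged in PyTorch.

    return hook
```

Passed to `module.register_forward_hook`, the hook fills `cache[key]` on every forward call.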
- static make_setting_hook(value, *, callback=None, direction='fwd')
Make a setting hook.
- Parameters:
value (Any)
callback (Optional[Callable])
direction (HookDirection)
- Return type:
Callable
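In PyTorch, a forward hook that returns a non-None value replaces the module's output, which is one way a setting hook could work. A sketch under that assumption; make_setting_hook_sketch is hypothetical, and the callback contract (combine the original output with value) is an assumption, not tdhook's documented behavior.

```python
from typing import Any, Callable, Optional


def make_setting_hook_sketch(
    value: Any,
    *,
    callback: Optional[Callable[[Any, Any], Any]] = None,
) -> Callable[[Any, Any, Any], Any]:
    """Build a forward-hook-shaped function that overrides the module output."""

    def hook(module: Any, args: Any, output: Any) -> Any:
        if callback is not None:
            # Assumed callback contract: combine the original output with `value`.
            return callback(output, value)
        # A non-None return from a PyTorch forward hook replaces the output.
        return value

    return hook
```

For example, `make_setting_hook_sketch(0)` would zero out (replace) a module's output when registered as a forward hook.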