tdhook.latent#
Module for latent methods.
Submodules#
Classes#
ActivationCaching | Maximally activating samples [15] and attention visualisation [16].
LocalKnnDimensionEstimator | Local intrinsic dimension estimation via k-NN distances [20].
TwoNnDimensionEstimator | Intrinsic dimension estimation via the Two NN algorithm [21].
BilinearProbe | Probe for bilinear estimators; caches first activation when h1 != h2.
BilinearProbeManager | Manager for bilinear probes; one probe per (h1, h2) pair.
LinearEstimator | Linear estimator: W h + b.
LowRankBilinearEstimator | Low-rank bilinear: (U h_1) * (V h_2) + b.
CkaEstimator | Centered kernel alignment (CKA) between two representations.
InformationImbalanceEstimator | Information Imbalance between two representations.
SteeringVectors | Steering vectors [24].
ActivationAddition | Factory for creating hooking contexts.
Package Contents#
- class tdhook.latent.ActivationCaching(key_pattern, relative=True, cache=None, callback=None, directions=None, use_nested_keys=False, clear_cache=True)[source]#
Bases:
tdhook.contexts.HookingContextFactory
Maximally activating samples [15] and attention visualisation [16].
- Parameters:
key_pattern (str)
relative (bool)
cache (Optional[tensordict.TensorDict])
callback (Optional[Callable])
directions (Optional[List[tdhook.hooks.HookDirection]])
use_nested_keys (bool)
clear_cache (bool)
- _hooking_context_class#
- _key_pattern#
- _relative = True#
- _hook_manager#
- _callback = None#
- _directions = ['fwd']#
- _use_nested_keys#
- property key_pattern: str#
- Return type:
str
- _hook_module(module)[source]#
- Parameters:
module (tdhook.modules.HookedModule)
- Return type:
- class tdhook.latent.ActivationPatching(modules_to_patch, patch_key='patched', clean_intermediate_keys=True, patch_fn=None, cache_callback=None)[source]#
Bases:
tdhook.contexts.HookingContextFactory
Causal mediation analysis [17] and latent editing [18, 19].
- Parameters:
modules_to_patch (List[str])
patch_key (tdhook._types.UnraveledKey)
clean_intermediate_keys (bool)
patch_fn (Optional[Callable])
cache_callback (Optional[Callable])
- _modules_to_patch#
- _patch_key = 'patched'#
- _clean_intermediate_keys = True#
- _patch_fn = None#
- _cache_callback = None#
- _prepare_module(module, in_keys, out_keys, extra_relative_path)[source]#
- Parameters:
module (tensordict.nn.TensorDictModuleBase)
in_keys (List[tdhook._types.UnraveledKey])
out_keys (List[tdhook._types.UnraveledKey])
extra_relative_path (str)
- Return type:
tensordict.nn.TensorDictModuleBase
- _hook_module(module)[source]#
- Parameters:
module (tdhook.modules.HookedModule)
- Return type:
- class tdhook.latent.LocalKnnDimensionEstimator(k='auto', in_key='data', out_key='dimension', eps=1e-05)[source]#
Bases:
tensordict.nn.TensorDictModuleBase
Local intrinsic dimension estimation via k-NN distances [20].
For each point x, d(x) = ln(2) / ln(R_{2k} / R_k), where R_k and R_{2k} are the distances from x to its k-th and 2k-th nearest neighbors, respectively.
Reads a data tensor from the input TensorDict. Expects (N, D) or (…, N, D). Outputs per-point dimension estimates of shape (…, N).
- Parameters:
k (Union[int, Literal['auto']])
in_key (str)
out_key (str)
eps (float)
- k = 'auto'#
- in_key = 'data'#
- out_key = 'dimension'#
- eps = 1e-05#
- in_keys#
- out_keys#
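The per-point formula above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not tdhook's implementation: the helper `local_knn_dimension` is hypothetical, `k` is passed explicitly (the behaviour of `k='auto'` is not described here), and the way `eps` is applied is only a guess at its stabilising role.

```python
import math
import random


def local_knn_dimension(points, k, eps=1e-5):
    """Per-point estimate d(x) = ln(2) / ln(R_2k / R_k) (hypothetical helper)."""
    n = len(points)
    if n < 2 * k + 1:
        raise ValueError("need at least 2k + 1 points")
    dims = []
    for i, x in enumerate(points):
        # Sorted distances from x to every other point (x itself excluded).
        dists = sorted(math.dist(x, y) for j, y in enumerate(points) if j != i)
        r_k, r_2k = dists[k - 1], dists[2 * k - 1]
        # Assumed use of eps: guard against duplicate points and a zero log-ratio.
        ratio = (r_2k + eps) / (r_k + eps)
        dims.append(math.log(2.0) / max(math.log(ratio), eps))
    return dims


random.seed(0)
# A flat 2-D square embedded in a 5-D ambient space.
points = [(random.random(), random.random(), 0.0, 0.0, 0.0) for _ in range(400)]
dims = local_knn_dimension(points, k=8)
mean_dim = sum(dims) / len(dims)
```

On this toy data the mean of `dims` should land near the manifold dimension of 2 rather than the ambient dimension of 5, although individual per-point estimates are noisy.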
- class tdhook.latent.TwoNnDimensionEstimator(in_key='data', out_key='dimension', return_xy=False, eps=1e-05)[source]#
Bases:
tensordict.nn.TensorDictModuleBase
Intrinsic dimension estimation via the Two NN algorithm [21].
Reads a data tensor from the input TensorDict. Expects (N, D) or (…, N, D). For (…, N, D), flattens all leading dims, computes one dimension per dataset, stacks and reshapes to preserve the original batch shape (excluding last two dims).
- Parameters:
in_key (str)
out_key (str)
return_xy (bool)
eps (float)
- in_key = 'data'#
- out_key = 'dimension'#
- return_xy = False#
- eps = 1e-05#
- in_keys#
- out_keys#
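For orientation, the Two NN idea can be sketched for a single (N, D) dataset in plain Python. The sketch assumes the usual formulation: the ratio mu = r2 / r1 of second to first nearest-neighbor distance follows a Pareto law, so -ln(1 - F(mu)) is linear in ln(mu) with slope equal to the dimension. The helper name, the `discard_fraction` argument, and the role of `eps` are assumptions for illustration and may differ from what tdhook does internally.

```python
import math
import random


def two_nn_dimension(points, discard_fraction=0.1, eps=1e-5):
    """Fit the slope of -ln(1 - F(mu)) against ln(mu) (hypothetical helper)."""
    n = len(points)
    mus = []
    for i, x in enumerate(points):
        dists = sorted(math.dist(x, y) for j, y in enumerate(points) if j != i)
        # mu = r2 / r1; eps (assumed role) avoids division by zero on duplicates.
        mus.append((dists[1] + eps) / (dists[0] + eps))
    mus.sort()
    # Drop the largest ratios, which are the noisiest, and never keep the last
    # point, where the empirical CDF reaches 1 and the logarithm diverges.
    keep = min(int(n * (1.0 - discard_fraction)), n - 1)
    xs = [math.log(mu) for mu in mus[:keep]]
    ys = [-math.log(1.0 - (i + 1) / n) for i in range(keep)]
    # Least-squares line through the origin: slope = sum(x * y) / sum(x * x).
    slope = sum(x * y for x, y in zip(xs, ys)) / sum(x * x for x in xs)
    return slope, xs, ys


random.seed(0)
# A flat 2-D square embedded in a 6-D ambient space.
points = [(random.random(), random.random(), 0.0, 0.0, 0.0, 0.0) for _ in range(500)]
dimension, xs, ys = two_nn_dimension(points)
```

The returned `xs` and `ys` are the coordinates of the fitted line, which is presumably the kind of data `return_xy=True` exposes; that correspondence is an assumption.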
- class tdhook.latent.BilinearProbe(h1_key, h2_key, estimator, predict_callback, fit_callback=None, data_preprocess_callback=None)[source]#
Bases:
Probe
Probe for bilinear estimators; caches first activation when h1 != h2.
- Parameters:
h1_key (str)
h2_key (str)
estimator (Any)
predict_callback (Callable[[Any, Any], Any])
fit_callback (Optional[Callable[[Any, Any], Any]])
data_preprocess_callback (Optional[Callable[[Any], Any]])
- _h1_key#
- _h2_key#
- _cached: Dict[str, Any]#
- _waiting_active = False#
- step(data, key, labels, step_type, **kwargs)[source]#
- Parameters:
data (Any)
key (str)
labels (Any)
step_type (str)
- property is_waiting: bool#
- Return type:
bool
- class tdhook.latent.BilinearProbeManager(pairs, estimator_class, estimator_kwargs, compute_metrics, allow_overwrite=False, data_preprocess_callback=None)[source]#
Bases:
ProbeManager
Manager for bilinear probes; one probe per (h1, h2) pair.
- Parameters:
pairs (List[Tuple[str, str]])
estimator_class (Any)
estimator_kwargs (dict)
compute_metrics (Callable[[Any, Any], Dict[str, Any]])
allow_overwrite (bool)
data_preprocess_callback (Optional[Callable[[Any], Any]])
- _pairs#
- _pair_probes: Dict[Tuple[str, str, str], BilinearProbe]#
- _key_to_probes: Dict[Tuple[str, str], List[BilinearProbe]]#
- property key_pattern: str#
Read-only regex alternation of all keys present in pairs.
- Return type:
str
- probe_factory(key, direction)[source]#
- Parameters:
key (str)
direction (tdhook.hooks.HookDirection)
- Return type:
- class tdhook.latent.LinearEstimator(d_latent, bias=True, **kwargs)[source]#
Bases:
TorchEstimator
Linear estimator: W h + b.
- Parameters:
d_latent (int)
bias (bool)
- linear#
- class tdhook.latent.LowRankBilinearEstimator(d_latent1, d_latent2, bias=True, **kwargs)[source]#
Bases:
TorchEstimator
Low-rank bilinear: (U h_1) * (V h_2) + b.
- Parameters:
d_latent1 (int)
d_latent2 (int)
bias (bool)
- linear1#
- linear2#
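The two estimator formulas, W h + b and (U h_1) * (V h_2) + b, can be worked through on small hand-written matrices. The sketch below is plain Python rather than the torch modules themselves. It assumes that the elementwise product is summed over the rank dimension to give a scalar score, which makes the form equal to h_1^T (U^T V) h_2 + b; the documentation above only shows the elementwise product, so that reduction, the rank of 2, and all helper names are assumptions.

```python
def matvec(m, v):
    """Matrix-vector product for a list-of-rows matrix."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def linear_score(w, h, b=0.0):
    """W h + b for a single output row w."""
    return sum(wi * hi for wi, hi in zip(w, h)) + b


def low_rank_bilinear_score(u, v, h1, h2, b=0.0):
    """(U h1) * (V h2) summed over the rank dimension, plus b (assumed reduction)."""
    p1 = matvec(u, h1)  # rank-dimensional projection of h1
    p2 = matvec(v, h2)  # rank-dimensional projection of h2
    return sum(a * c for a, c in zip(p1, p2)) + b


def full_bilinear_score(u, v, h1, h2, b=0.0):
    """Same score through the explicit matrix M = U^T V."""
    m = [[sum(u[r][i] * v[r][j] for r in range(len(u))) for j in range(len(h2))]
         for i in range(len(h1))]
    return sum(h1[i] * m[i][j] * h2[j]
               for i in range(len(h1)) for j in range(len(h2))) + b


u = [[1.0, 0.0, 2.0], [0.0, 1.0, 0.0]]  # rank 2, d_latent1 = 3
v = [[1.0, 1.0], [0.5, -1.0]]           # rank 2, d_latent2 = 2
h1 = [1.0, 2.0, 3.0]
h2 = [2.0, 1.0]

low_rank = low_rank_bilinear_score(u, v, h1, h2, b=0.5)
full = full_bilinear_score(u, v, h1, h2, b=0.5)
linear = linear_score([0.5, -1.0, 2.0], h1, b=0.1)
```

Here U h_1 = [7, 2] and V h_2 = [3, 0], so both bilinear scores come out to 21.5. The low-rank form stores only rank x (d_latent1 + d_latent2) weights instead of the full d_latent1 x d_latent2 matrix.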
- class tdhook.latent.MeanDifferenceClassifier(normalize=True)[source]#
- Parameters:
normalize (bool)
- _normalize = True#
- _coef = None#
- _intercept = None#
- property coef_#
- property intercept_#
- class tdhook.latent.Probing(key_pattern, probe_factory, relative=True, directions=None, additional_keys=None, classes_to_hook=None, classes_to_skip=None)[source]#
Bases:
tdhook.contexts.HookingContextFactory
Linear probing [22] and concept activation vectors [23].
- Parameters:
key_pattern (str)
probe_factory (Callable[[str, str], Probe])
relative (bool)
directions (Optional[List[tdhook.hooks.HookDirection]])
additional_keys (Optional[List[str]])
classes_to_hook (Optional[List[Type[torch.nn.Module]]])
classes_to_skip (Optional[List[Type[torch.nn.Module]]])
- default_classes_to_hook#
- default_classes_to_skip#
- _key_pattern#
- _hook_manager#
- _relative = True#
- _probe_factory#
- _directions = ['fwd']#
- _additional_keys = None#
- property key_pattern: str#
- Return type:
str
- _hook_module(module)[source]#
- Parameters:
module (tdhook.modules.HookedModule)
- Return type:
- class tdhook.latent.Probe(estimator, predict_callback, fit_callback=None, data_preprocess_callback=None)[source]#
- Parameters:
estimator (Any)
predict_callback (Callable[[Any, Any], Any])
fit_callback (Optional[Callable[[Any, Any], Any]])
data_preprocess_callback (Optional[Callable[[Any], Any]])
- _estimator#
- _predict_callback#
- _fit_callback = None#
- _data_preprocess_callback#
- class tdhook.latent.ProbeManager(estimator_class, estimator_kwargs, compute_metrics, allow_overwrite=False, data_preprocess_callback=None)[source]#
- Parameters:
estimator_class (Any)
estimator_kwargs (dict)
compute_metrics (Callable[[Any, Any], Dict[str, Any]])
allow_overwrite (bool)
data_preprocess_callback (Callable[[Any], Any])
- _estimator_class#
- _estimator_kwargs#
- _compute_metrics#
- _allow_overwrite = False#
- _data_preprocess_callback = None#
- _estimators#
- _fit_metrics#
- _predict_metrics#
- property estimators: dict[str, Any]#
- Return type:
dict[str, Any]
- property fit_metrics: dict[str, Any]#
- Return type:
dict[str, Any]
- property predict_metrics: dict[str, Any]#
- Return type:
dict[str, Any]
- class tdhook.latent.CkaEstimator(in_key_a='data_a', in_key_b='data_b', out_key='cka', kernel='linear', eps=1e-12)[source]#
Bases:
tensordict.nn.TensorDictModuleBase
Centered kernel alignment (CKA) between two representations.
Reads two data tensors from the input TensorDict. Expects (N, D) or (…, N, D) for both tensors, with shared batch shape and sample count. Outputs one scalar similarity value per batch item.
- Parameters:
in_key_a (str)
in_key_b (str)
out_key (str)
kernel (str)
eps (float)
- in_key_a = 'data_a'#
- in_key_b = 'data_b'#
- out_key = 'cka'#
- kernel = 'linear'#
- eps = 1e-12#
- in_keys#
- out_keys#
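As a reference point for the default `kernel='linear'`, linear CKA on one pair of (N, D) inputs can be sketched in plain Python. The sketch uses the standard Gram-matrix formulation, CKA = <K, L> / (||K|| ||L||) with K and L built from column-centered features; the helper names and the placement of `eps` in the denominator are assumptions, not tdhook's code.

```python
import math
import random


def center_columns(x):
    """Subtract the per-feature mean from every row."""
    n = len(x)
    means = [sum(col) / n for col in zip(*x)]
    return [[v - m for v, m in zip(row, means)] for row in x]


def gram(x):
    """Linear-kernel Gram matrix X X^T."""
    return [[sum(a * b for a, b in zip(r1, r2)) for r2 in x] for r1 in x]


def frobenius_inner(k, l):
    """Sum of elementwise products of two equally shaped matrices."""
    return sum(a * b for rk, rl in zip(k, l) for a, b in zip(rk, rl))


def linear_cka(x, y, eps=1e-12):
    """Linear CKA between representations with the same number of samples."""
    k = gram(center_columns(x))
    l = gram(center_columns(y))
    norm = math.sqrt(frobenius_inner(k, k) * frobenius_inner(l, l))
    return frobenius_inner(k, l) / (norm + eps)  # eps placement is assumed


random.seed(0)
x = [[random.gauss(0.0, 1.0) for _ in range(4)] for _ in range(40)]
# y is x with permuted columns and a global rescaling; z is unrelated noise.
y = [[3.0 * r[2], 3.0 * r[0], 3.0 * r[3], 3.0 * r[1]] for r in x]
z = [[random.gauss(0.0, 1.0) for _ in range(3)] for _ in range(40)]

cka_same = linear_cka(x, y)
cka_indep = linear_cka(x, z)
```

Because linear CKA is invariant to orthogonal transforms and isotropic scaling, `cka_same` is numerically 1, while `cka_indep` stays well below it. The two inputs may have different feature widths as long as the sample count matches.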
- class tdhook.latent.InformationImbalanceEstimator(in_key_a='data_a', in_key_b='data_b', out_key_a_to_b='information_imbalance_a_to_b', out_key_b_to_a='information_imbalance_b_to_a', p=2.0)[source]#
Bases:
tensordict.nn.TensorDictModuleBase
Information Imbalance between two representations.
Reads two data tensors from the input TensorDict. Expects (N, D) or (…, N, D) for both tensors, with shared batch shape and sample count. Outputs both directional imbalances per batch item: A->B and B->A.
This implementation uses the nearest-neighbor definition: for each point i, select the point j with rank r^A_ij = 1 (the nearest neighbor of i in A), then average its rank r^B_ij in B and normalize by 2 / N. Values close to 0 indicate strong neighborhood predictability; values close to 1 indicate an uninformative mapping.
- Parameters:
in_key_a (str)
in_key_b (str)
out_key_a_to_b (str)
out_key_b_to_a (str)
p (float)
- in_key_a = 'data_a'#
- in_key_b = 'data_b'#
- out_key_a_to_b = 'information_imbalance_a_to_b'#
- out_key_b_to_a = 'information_imbalance_b_to_a'#
- p = 2.0#
- in_keys#
- out_keys#
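The nearest-neighbor definition can be made concrete with a small plain-Python sketch for one pair of (N, D) inputs. It assumes that `p` is the order of the Minkowski distance used for ranking and that ties are broken by counting strictly closer points; both are assumptions, and the helper names are hypothetical.

```python
import random


def minkowski(x, y, p):
    """Minkowski distance of order p (assumed meaning of the `p` argument)."""
    return sum(abs(a - b) ** p for a, b in zip(x, y)) ** (1.0 / p)


def information_imbalance(a, b, p=2.0):
    """Imbalance A -> B: (2 / N) times the mean B-rank of each point's A-neighbor."""
    n = len(a)
    total_rank = 0
    for i in range(n):
        others = [j for j in range(n) if j != i]
        # Nearest neighbor of i in space A (rank 1 in A).
        j_star = min(others, key=lambda j: minkowski(a[i], a[j], p))
        d_ref = minkowski(b[i], b[j_star], p)
        # Rank of that neighbor in space B: 1 + number of strictly closer points.
        closer = sum(1 for j in others
                     if j != j_star and minkowski(b[i], b[j], p) < d_ref)
        total_rank += 1 + closer
    return 2.0 * total_rank / (n * n)


random.seed(0)
full = [(random.random(), random.random()) for _ in range(200)]
partial = [(x,) for x, _ in full]  # keeps only the first coordinate

identical = information_imbalance(full, full)
full_to_partial = information_imbalance(full, partial)
partial_to_full = information_imbalance(partial, full)
```

With identical inputs every neighbor keeps rank 1, so the value is exactly 2 / N. The full representation predicts neighborhoods in the partial one much better than the reverse, so `full_to_partial` is smaller than `partial_to_full`.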
- class tdhook.latent.SteeringVectors(modules_to_steer, steer_fn)[source]#
Bases:
tdhook.contexts.HookingContextFactory
Steering vectors [24].
- Parameters:
modules_to_steer (List[str])
steer_fn (Callable)
- _modules_to_steer#
- _steer_fn#
- _hook_module(module)[source]#
- Parameters:
module (tdhook.modules.HookedModule)
- Return type:
- class tdhook.latent.ActivationAddition(modules_to_steer, positive_key='positive', negative_key='negative', steer_key='steer', clean_intermediate_keys=True, cache_callback=None)[source]#
Bases:
tdhook.contexts.HookingContextFactory
Factory for creating hooking contexts.
- Parameters:
modules_to_steer (List[str])
positive_key (tdhook._types.UnraveledKey)
negative_key (tdhook._types.UnraveledKey)
steer_key (tdhook._types.UnraveledKey)
clean_intermediate_keys (bool)
cache_callback (Optional[Callable])
- _modules_to_steer#
- _positive_key = 'positive'#
- _negative_key = 'negative'#
- _steer_key = 'steer'#
- _clean_intermediate_keys = True#
- _cache_callback = None#
- _prepare_module(module, in_keys, out_keys, extra_relative_path)[source]#
- Parameters:
module (tensordict.nn.TensorDictModuleBase)
in_keys (List[tdhook._types.UnraveledKey])
out_keys (List[tdhook._types.UnraveledKey])
extra_relative_path (str)
- Return type:
tensordict.nn.TensorDictModuleBase
- _hook_module(module)[source]#
- Parameters:
module (tdhook.modules.HookedModule)
- Return type: