tdhook.latent

Module for latent methods.

Submodules

Classes

ActivationCaching

Maximally activating samples [15] and attention visualisation [16].

ActivationPatching

Causal mediation analysis [17] and latent editing [18, 19].

LocalKnnDimensionEstimator

Local intrinsic dimension estimation via k-NN distances [20].

TwoNnDimensionEstimator

Intrinsic dimension estimation via the Two NN algorithm [21].

BilinearProbe

Probe for bilinear estimators; caches first activation when h1 != h2.

BilinearProbeManager

Manager for bilinear probes; one probe per (h1, h2) pair.

LinearEstimator

Linear estimator: W h + b.

LowRankBilinearEstimator

Low-rank bilinear: (U h_1) * (V h_2) + b.

MeanDifferenceClassifier

Probing

Linear probing [22] and concept activation vectors [23].

Probe

ProbeManager

CkaEstimator

Centered kernel alignment (CKA) between two representations.

InformationImbalanceEstimator

Information Imbalance between two representations.

SteeringVectors

Steering vectors [24].

ActivationAddition

Factory for creating hooking contexts.

Package Contents

class tdhook.latent.ActivationCaching(key_pattern, relative=True, cache=None, callback=None, directions=None, use_nested_keys=False, clear_cache=True)[source]#

Bases: tdhook.contexts.HookingContextFactory

Maximally activating samples [15] and attention visualisation [16].

Parameters:
  • key_pattern (str)

  • relative (bool)

  • cache (Optional[tensordict.TensorDict])

  • callback (Optional[Callable])

  • directions (Optional[List[tdhook.hooks.HookDirection]])

  • use_nested_keys (bool)

  • clear_cache (bool)

_hooking_context_class#
_key_pattern#
_relative = True#
_hook_manager#
_callback = None#
_directions = ['fwd']#
_use_nested_keys#
property key_pattern: str#
Return type:

str

_hook_module(module)[source]#
Parameters:

module (tdhook.modules.HookedModule)

Return type:

tdhook.hooks.MultiHookHandle

class tdhook.latent.ActivationPatching(modules_to_patch, patch_key='patched', clean_intermediate_keys=True, patch_fn=None, cache_callback=None)[source]#

Bases: tdhook.contexts.HookingContextFactory

Causal mediation analysis [17] and latent editing [18, 19].

Parameters:
  • modules_to_patch (List[str])

  • patch_key (tdhook._types.UnraveledKey)

  • clean_intermediate_keys (bool)

  • patch_fn (Optional[Callable])

  • cache_callback (Optional[Callable])

_modules_to_patch#
_patch_key = 'patched'#
_clean_intermediate_keys = True#
_patch_fn = None#
_cache_callback = None#
_prepare_module(module, in_keys, out_keys, extra_relative_path)[source]#
Parameters:
Return type:

tensordict.nn.TensorDictModuleBase

_hook_module(module)[source]#
Parameters:

module (tdhook.modules.HookedModule)

Return type:

tdhook.hooks.MultiHookHandle

class tdhook.latent.LocalKnnDimensionEstimator(k='auto', in_key='data', out_key='dimension', eps=1e-05)[source]#

Bases: tensordict.nn.TensorDictModuleBase

Local intrinsic dimension estimation via k-NN distances [20].

For each point x, d(x) = ln(2) / ln(R_2k / R_k), where R_k and R_2k are the distances from x to its k-th and 2k-th nearest neighbors, respectively.

Reads a data tensor from the input TensorDict. Expects (N, D) or (…, N, D). Outputs per-point dimension estimates of shape (…, N).
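
The formula above can be illustrated with a minimal sketch in plain Python. It is not the tdhook implementation: the function name `local_knn_dimension` is hypothetical, the input is a list of coordinate tuples rather than a TensorDict, and duplicate points (which the `eps` parameter presumably guards against) are not handled.

```python
import math


def local_knn_dimension(points, k):
    """Per-point estimate d(x) = ln(2) / ln(R_2k / R_k) (illustrative sketch)."""
    if len(points) <= 2 * k:
        raise ValueError("need more than 2k points")
    estimates = []
    for i, x in enumerate(points):
        # Sorted Euclidean distances from x to every other point.
        dists = sorted(math.dist(x, y) for j, y in enumerate(points) if j != i)
        r_k, r_2k = dists[k - 1], dists[2 * k - 1]
        # Assumes r_k > 0 and r_2k > r_k (no duplicates, no exact ties).
        estimates.append(math.log(2) / math.log(r_2k / r_k))
    return estimates


# Evenly spaced points on a line: for the middle point, R_2k / R_k = 2.
line = [(float(i),) for i in range(21)]
dims = local_knn_dimension(line, k=2)
print(dims[10])
```

For the middle point of the line, the 2nd and 4th nearest neighbors lie at distances 1 and 2, so the estimate is ln(2) / ln(2) = 1, matching the one-dimensional data.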

Parameters:
  • k (Union[int, Literal['auto']])

  • in_key (str)

  • out_key (str)

  • eps (float)

k = 'auto'#
in_key = 'data'#
out_key = 'dimension'#
eps = 1e-05#
in_keys#
out_keys#
forward(td)[source]#
Parameters:

td (tensordict.TensorDict)

Return type:

tensordict.TensorDict

__repr__()[source]#
class tdhook.latent.TwoNnDimensionEstimator(in_key='data', out_key='dimension', return_xy=False, eps=1e-05)[source]#

Bases: tensordict.nn.TensorDictModuleBase

Intrinsic dimension estimation via the Two NN algorithm [21].

Reads a data tensor from the input TensorDict. Expects (N, D) or (…, N, D). For (…, N, D) input, the leading dimensions are flattened, one dimension estimate is computed per (N, D) dataset, and the results are stacked and reshaped to the original batch shape (the input shape without its last two dimensions).
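
As a rough illustration of the Two NN idea for a single (N, D) dataset, the sketch below fits a line through the origin to the pairs (ln mu, -ln(1 - F(mu))), where mu is the ratio of second to first nearest-neighbor distance. This follows the published method in outline only. The function name, the `discard_fraction` parameter, and the plain-list input are assumptions for illustration and may differ from what tdhook does internally.

```python
import math
import random


def two_nn_dimension(points, discard_fraction=0.1):
    """Two NN estimate for one dataset (illustrative sketch, not tdhook's code)."""
    n = len(points)
    mus = []
    for i, x in enumerate(points):
        dists = sorted(math.dist(x, y) for j, y in enumerate(points) if j != i)
        # Ratio of second to first nearest-neighbor distance (assumes no duplicates).
        mus.append(dists[1] / dists[0])
    mus.sort()
    # Keep the smallest ratios; always drop at least the last one, where F = 1.
    keep = min(int(n * (1.0 - discard_fraction)), n - 1)
    xs = [math.log(mu) for mu in mus[:keep]]
    # Empirical CDF F(mu_i) = i / n for the i-th smallest ratio (1-based).
    ys = [-math.log(1.0 - (i + 1) / n) for i in range(keep)]
    # Least-squares slope of a line through the origin.
    return sum(x * y for x, y in zip(xs, ys)) / sum(x * x for x in xs)


# A flat 2D sheet embedded linearly in 3D.
random.seed(0)
sheet = []
for _ in range(300):
    u, v = random.random(), random.random()
    sheet.append((u, v, 0.5 * u + 0.25 * v))
estimate = two_nn_dimension(sheet)
print(estimate)
```

Although the points have three coordinates, they lie on a two-dimensional sheet, so the estimate should come out near 2 rather than 3.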

Parameters:
  • in_key (str)

  • out_key (str)

  • return_xy (bool)

  • eps (float)

in_key = 'data'#
out_key = 'dimension'#
return_xy = False#
eps = 1e-05#
in_keys#
out_keys#
forward(td)[source]#
Parameters:

td (tensordict.TensorDict)

Return type:

tensordict.TensorDict

__repr__()[source]#
class tdhook.latent.BilinearProbe(h1_key, h2_key, estimator, predict_callback, fit_callback=None, data_preprocess_callback=None)[source]#

Bases: Probe

Probe for bilinear estimators; caches first activation when h1 != h2.

Parameters:
  • h1_key (str)

  • h2_key (str)

  • estimator (Any)

  • predict_callback (Callable[[Any, Any], Any])

  • fit_callback (Optional[Callable[[Any, Any], Any]])

  • data_preprocess_callback (Optional[Callable[[Any], Any]])

_h1_key#
_h2_key#
_cached: Dict[str, Any]#
_waiting_active = False#
step(data, key, labels, step_type, **kwargs)[source]#
Parameters:
  • data (Any)

  • key (str)

  • labels (Any)

  • step_type (str)

_run(h1, h2, labels, step_type)[source]#
Parameters:
  • h1 (Any)

  • h2 (Any)

  • labels (Any)

  • step_type (str)

before_all()[source]#
after_all()[source]#
Return type:

List[Tuple[str, str]]

property is_waiting: bool#
Return type:

bool

class tdhook.latent.BilinearProbeManager(pairs, estimator_class, estimator_kwargs, compute_metrics, allow_overwrite=False, data_preprocess_callback=None)[source]#

Bases: ProbeManager

Manager for bilinear probes; one probe per (h1, h2) pair.

Parameters:
  • pairs (List[Tuple[str, str]])

  • estimator_class (Any)

  • estimator_kwargs (dict)

  • compute_metrics (Callable[[Any, Any], Dict[str, Any]])

  • allow_overwrite (bool)

  • data_preprocess_callback (Optional[Callable[[Any], Any]])

_pairs#
_pair_probes: Dict[Tuple[str, str, str], BilinearProbe]#
_key_to_probes: Dict[Tuple[str, str], List[BilinearProbe]]#
property key_pattern: str#

Read-only regex alternation of all keys present in pairs.

Return type:

str
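
One plausible way to build such an alternation is sketched below. The helper name `alternation_pattern` and the module keys are hypothetical, and whether tdhook escapes or orders the keys this way is an assumption.

```python
import re


def alternation_pattern(pairs):
    """Join every distinct key in `pairs` into a regex alternation (sketch)."""
    # Collect each key once, preserving first-seen order.
    keys = list(dict.fromkeys(key for pair in pairs for key in pair))
    # Escape keys so characters such as "." match literally.
    return "|".join(re.escape(key) for key in keys)


# Hypothetical module keys, for illustration only.
pairs = [("layer1", "layer2"), ("layer2", "head.out")]
pattern = alternation_pattern(pairs)
print(pattern)
```

A key shared by several pairs appears once in the pattern, so a module is matched once however many pairs reference it.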

probe_factory(key, direction)[source]#
Parameters:
  • key (str)

  • direction (tdhook.hooks.HookDirection)

Return type:

Probe

_create_pair_probe(h1, h2, direction)[source]#
Parameters:
  • h1 (str)

  • h2 (str)

  • direction (tdhook.hooks.HookDirection)

Return type:

BilinearProbe

before_all()[source]#

Initialize waiting state on all bilinear probes for a run.

after_all()[source]#

Clear waiting state and raise if any probes still wait on missing keys.

reset_estimators()[source]#
reset_metrics()[source]#
class tdhook.latent.LinearEstimator(d_latent, bias=True, **kwargs)[source]#

Bases: TorchEstimator

Linear estimator: W h + b.

Parameters:
  • d_latent (int)

  • bias (bool)

linear#
forward(*Xs)[source]#
Parameters:

Xs (torch.Tensor)

Return type:

torch.Tensor

class tdhook.latent.LowRankBilinearEstimator(d_latent1, d_latent2, bias=True, **kwargs)[source]#

Bases: TorchEstimator

Low-rank bilinear: (U h_1) * (V h_2) + b.

Parameters:
  • d_latent1 (int)

  • d_latent2 (int)

  • bias (bool)

linear1#
linear2#
forward(h1, h2)[source]#
Parameters:
  • h1 (torch.Tensor)

  • h2 (torch.Tensor)

Return type:

torch.Tensor

class tdhook.latent.MeanDifferenceClassifier(normalize=True)[source]#
Parameters:

normalize (bool)

_normalize = True#
_coef = None#
_intercept = None#
property coef_#
property intercept_#
fit(X, y)[source]#
_decision_function(X)[source]#
predict(X)[source]#
predict_proba(X)[source]#
class tdhook.latent.Probing(key_pattern, probe_factory, relative=True, directions=None, additional_keys=None, classes_to_hook=None, classes_to_skip=None)[source]#

Bases: tdhook.contexts.HookingContextFactory

Linear probing [22] and concept activation vectors [23].

Parameters:
  • key_pattern (str)

  • probe_factory (Callable[[str, str], Probe])

  • relative (bool)

  • directions (Optional[List[tdhook.hooks.HookDirection]])

  • additional_keys (Optional[List[str]])

  • classes_to_hook (Optional[List[Type[torch.nn.Module]]])

  • classes_to_skip (Optional[List[Type[torch.nn.Module]]])

default_classes_to_hook#
default_classes_to_skip#
_key_pattern#
_hook_manager#
_relative = True#
_probe_factory#
_directions = ['fwd']#
_additional_keys = None#
property key_pattern: str#
Return type:

str

_hook_module(module)[source]#
Parameters:

module (tdhook.modules.HookedModule)

Return type:

tdhook.hooks.MultiHookHandle

class tdhook.latent.Probe(estimator, predict_callback, fit_callback=None, data_preprocess_callback=None)[source]#
Parameters:
  • estimator (Any)

  • predict_callback (Callable[[Any, Any], Any])

  • fit_callback (Optional[Callable[[Any, Any], Any]])

  • data_preprocess_callback (Optional[Callable[[Any], Any]])

_estimator#
_predict_callback#
_fit_callback = None#
_data_preprocess_callback#
step(data, **kwargs)[source]#
Parameters:

data (Any)

_default_data_preprocess_callback(data)[source]#
Parameters:

data (Any)

Return type:

Any

class tdhook.latent.ProbeManager(estimator_class, estimator_kwargs, compute_metrics, allow_overwrite=False, data_preprocess_callback=None)[source]#
Parameters:
  • estimator_class (Any)

  • estimator_kwargs (dict)

  • compute_metrics (Callable[[Any, Any], Dict[str, Any]])

  • allow_overwrite (bool)

  • data_preprocess_callback (Callable[[Any], Any])

_estimator_class#
_estimator_kwargs#
_compute_metrics#
_allow_overwrite = False#
_data_preprocess_callback = None#
_estimators#
_fit_metrics#
_predict_metrics#
property estimators: dict[str, Any]#
Return type:

dict[str, Any]

property fit_metrics: dict[str, Any]#
Return type:

dict[str, Any]

property predict_metrics: dict[str, Any]#
Return type:

dict[str, Any]

probe_factory(key, direction)[source]#
Parameters:
  • key (str)

  • direction (tdhook.hooks.HookDirection)

Return type:

Probe

reset_estimators()[source]#
reset_metrics()[source]#
class tdhook.latent.CkaEstimator(in_key_a='data_a', in_key_b='data_b', out_key='cka', kernel='linear', eps=1e-12)[source]#

Bases: tensordict.nn.TensorDictModuleBase

Centered kernel alignment (CKA) between two representations.

Reads two data tensors from the input TensorDict. Expects (N, D) or (…, N, D) for both tensors, with shared batch shape and sample count. Outputs one scalar similarity value per batch item.
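
For reference, linear CKA on a single pair of (N, D) representations can be sketched in plain Python as below. This illustrates the standard formula with a linear kernel. The function names, the list-of-rows input, and the placement of `eps` in the denominator are assumptions, not the tdhook implementation.

```python
import math


def _centered_gram(rows):
    """Linear-kernel Gram matrix with rows and columns mean-centered."""
    n = len(rows)
    gram = [[sum(a * b for a, b in zip(u, v)) for v in rows] for u in rows]
    row_mean = [sum(r) / n for r in gram]
    total_mean = sum(row_mean) / n
    # Double centering; the Gram matrix is symmetric, so row means equal column means.
    return [
        [gram[i][j] - row_mean[i] - row_mean[j] + total_mean for j in range(n)]
        for i in range(n)
    ]


def _hsic(k, l):
    """Unnormalized HSIC: elementwise product of two centered Gram matrices, summed."""
    return sum(a * b for rk, rl in zip(k, l) for a, b in zip(rk, rl))


def linear_cka(xs, ys, eps=1e-12):
    k, l = _centered_gram(xs), _centered_gram(ys)
    return _hsic(k, l) / (math.sqrt(_hsic(k, k) * _hsic(l, l)) + eps)


x = [[1.0, 2.0], [3.0, 1.0], [0.0, 4.0], [2.0, 2.0]]
# y is x rotated by 90 degrees and scaled by 3.
y = [[-3.0 * b, 3.0 * a] for a, b in x]
z = [[1.0], [0.0], [0.0], [1.0]]
same = linear_cka(x, y)
different = linear_cka(x, z)
print(same, different)
```

Linear CKA is invariant to orthogonal transformations and isotropic scaling, so `same` is 1 up to the `eps` term, while the unrelated representation `z` scores much lower.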

Parameters:
  • in_key_a (str)

  • in_key_b (str)

  • out_key (str)

  • kernel (str)

  • eps (float)

in_key_a = 'data_a'#
in_key_b = 'data_b'#
out_key = 'cka'#
kernel = 'linear'#
eps = 1e-12#
in_keys#
out_keys#
forward(td)[source]#
Parameters:

td (tensordict.TensorDict)

Return type:

tensordict.TensorDict

__repr__()[source]#
class tdhook.latent.InformationImbalanceEstimator(in_key_a='data_a', in_key_b='data_b', out_key_a_to_b='information_imbalance_a_to_b', out_key_b_to_a='information_imbalance_b_to_a', p=2.0)[source]#

Bases: tensordict.nn.TensorDictModuleBase

Information Imbalance between two representations.

Reads two data tensors from the input TensorDict. Expects (N, D) or (…, N, D) for both tensors, with shared batch shape and sample count. Outputs both directional imbalances per batch item: A->B and B->A.

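This implementation uses the nearest-neighbor definition, where r^A_ij is the distance rank of point j relative to point i in representation A. For each point i, select the j with r^A_ij = 1 (the nearest neighbor of i in A), average r^B_ij over all i, and normalize by 2 / N. Values close to 0 indicate strong neighborhood predictability; values close to 1 indicate an uninformative mapping.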
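
The nearest-neighbor definition can be sketched for one pair of datasets as follows. The function names and the list-of-tuples input are hypothetical. The sketch uses Euclidean distances and breaks ties by index order, which may differ from how tdhook applies `p` or handles ties.

```python
import math


def _neighbor_order(i, points):
    """Indices of all other points, sorted by distance from points[i]."""
    others = [j for j in range(len(points)) if j != i]
    return sorted(others, key=lambda j: math.dist(points[i], points[j]))


def information_imbalance(a, b):
    """Imbalance A -> B: (2 / N) times the mean rank in B of each point's nearest neighbor in A."""
    n = len(a)
    total_rank = 0
    for i in range(n):
        nearest_in_a = _neighbor_order(i, a)[0]  # the j with r^A_ij = 1
        rank_in_b = _neighbor_order(i, b).index(nearest_in_a) + 1  # r^B_ij
        total_rank += rank_in_b
    return 2.0 * total_rank / (n * n)


a = [(0.0,), (1.0,), (3.0,), (6.0,), (10.0,)]
print(information_imbalance(a, a))
```

When both representations are identical, every nearest neighbor in A has rank 1 in B, so the result is exactly 2 / N, which approaches 0 as N grows. Swapping the arguments gives the B->A direction.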

Parameters:
  • in_key_a (str)

  • in_key_b (str)

  • out_key_a_to_b (str)

  • out_key_b_to_a (str)

  • p (float)

in_key_a = 'data_a'#
in_key_b = 'data_b'#
out_key_a_to_b = 'information_imbalance_a_to_b'#
out_key_b_to_a = 'information_imbalance_b_to_a'#
p = 2.0#
in_keys#
out_keys#
forward(td)[source]#
Parameters:

td (tensordict.TensorDict)

Return type:

tensordict.TensorDict

__repr__()[source]#
class tdhook.latent.SteeringVectors(modules_to_steer, steer_fn)[source]#

Bases: tdhook.contexts.HookingContextFactory

Steering vectors [24].

Parameters:
  • modules_to_steer (List[str])

  • steer_fn (Callable)

_modules_to_steer#
_steer_fn#
_hook_module(module)[source]#
Parameters:

module (tdhook.modules.HookedModule)

Return type:

tdhook.hooks.MultiHookHandle

class tdhook.latent.ActivationAddition(modules_to_steer, positive_key='positive', negative_key='negative', steer_key='steer', clean_intermediate_keys=True, cache_callback=None)[source]#

Bases: tdhook.contexts.HookingContextFactory

Factory for creating hooking contexts.

Parameters:
_modules_to_steer#
_positive_key = 'positive'#
_negative_key = 'negative'#
_steer_key = 'steer'#
_clean_intermediate_keys = True#
_cache_callback = None#
_prepare_module(module, in_keys, out_keys, extra_relative_path)[source]#
Parameters:
Return type:

tensordict.nn.TensorDictModuleBase

_hook_module(module)[source]#
Parameters:

module (tdhook.modules.HookedModule)

Return type:

tdhook.hooks.MultiHookHandle

_compute_steering_vectors(td)[source]#
Parameters:

td (tensordict.TensorDict)

Return type:

tensordict.TensorDict