tdhook.latent.probing

Probing: linear and bilinear probing for model representations.

Submodules

Classes

Probing

Linear probing [22] and concept activation vectors [23].

LinearEstimator

Linear estimator: W h + b.

LowRankBilinearEstimator

Low-rank bilinear: (U h_1) * (V h_2) + b.

MeanDifferenceClassifier

BilinearProbe

Probe for bilinear estimators; caches the first activation when the two keys (h1_key and h2_key) differ.

BilinearProbeManager

Manager for bilinear probes; one probe per (h1, h2) pair.

Probe

ProbeManager

Package Contents

class tdhook.latent.probing.Probing(key_pattern, probe_factory, relative=True, directions=None, additional_keys=None, classes_to_hook=None, classes_to_skip=None)[source]#

Bases: tdhook.contexts.HookingContextFactory

Linear probing [22] and concept activation vectors [23].

Parameters:
  • key_pattern (str)

  • probe_factory (Callable[[str, str], Probe])

  • relative (bool)

  • directions (Optional[List[tdhook.hooks.HookDirection]])

  • additional_keys (Optional[List[str]])

  • classes_to_hook (Optional[List[Type[torch.nn.Module]]])

  • classes_to_skip (Optional[List[Type[torch.nn.Module]]])

default_classes_to_hook#
default_classes_to_skip#
_key_pattern#
_hook_manager#
_relative = True#
_probe_factory#
_directions = ['fwd']#
_additional_keys = None#
property key_pattern: str#
Return type:

str

_hook_module(module)[source]#
Parameters:

module (tdhook.modules.HookedModule)

Return type:

tdhook.hooks.MultiHookHandle

class tdhook.latent.probing.LinearEstimator(d_latent, bias=True, **kwargs)[source]#

Bases: TorchEstimator

Linear estimator: W h + b.

Parameters:
  • d_latent (int)

  • bias (bool)

linear#
forward(*Xs)[source]#
Parameters:

Xs (torch.Tensor)

Return type:

torch.Tensor

class tdhook.latent.probing.LowRankBilinearEstimator(d_latent1, d_latent2, bias=True, **kwargs)[source]#

Bases: TorchEstimator

Low-rank bilinear: (U h_1) * (V h_2) + b.

Parameters:
  • d_latent1 (int)

  • d_latent2 (int)

  • bias (bool)

linear1#
linear2#
forward(h1, h2)[source]#
Parameters:
  • h1 (torch.Tensor)

  • h2 (torch.Tensor)

Return type:

torch.Tensor

class tdhook.latent.probing.MeanDifferenceClassifier(normalize=True)[source]#
Parameters:

normalize (bool)

_normalize = True#
_coef = None#
_intercept = None#
property coef_#
property intercept_#
fit(X, y)[source]#
_decision_function(X)[source]#
predict(X)[source]#
predict_proba(X)[source]#
class tdhook.latent.probing.BilinearProbe(h1_key, h2_key, estimator, predict_callback, fit_callback=None, data_preprocess_callback=None)[source]#

Bases: Probe

Probe for bilinear estimators; caches the first activation when the two keys (h1_key and h2_key) differ.

Parameters:
  • h1_key (str)

  • h2_key (str)

  • estimator (Any)

  • predict_callback (Callable[[Any, Any], Any])

  • fit_callback (Optional[Callable[[Any, Any], Any]])

  • data_preprocess_callback (Optional[Callable[[Any], Any]])

_h1_key#
_h2_key#
_cached: Dict[str, Any]#
_waiting_active = False#
step(data, key, labels, step_type, **kwargs)[source]#
Parameters:
  • data (Any)

  • key (str)

  • labels (Any)

  • step_type (str)

_run(h1, h2, labels, step_type)[source]#
Parameters:
  • h1 (Any)

  • h2 (Any)

  • labels (Any)

  • step_type (str)

before_all()[source]#
after_all()[source]#
Return type:

List[Tuple[str, str]]

property is_waiting: bool#
Return type:

bool

class tdhook.latent.probing.BilinearProbeManager(pairs, estimator_class, estimator_kwargs, compute_metrics, allow_overwrite=False, data_preprocess_callback=None)[source]#

Bases: ProbeManager

Manager for bilinear probes; one probe per (h1, h2) pair.

Parameters:
  • pairs (List[Tuple[str, str]])

  • estimator_class (Any)

  • estimator_kwargs (dict)

  • compute_metrics (Callable[[Any, Any], Dict[str, Any]])

  • allow_overwrite (bool)

  • data_preprocess_callback (Optional[Callable[[Any], Any]])

_pairs#
_pair_probes: Dict[Tuple[str, str, str], BilinearProbe]#
_key_to_probes: Dict[Tuple[str, str], List[BilinearProbe]]#
property key_pattern: str#

Read-only regex alternation of all keys present in pairs.

Return type:

str

probe_factory(key, direction)[source]#
Parameters:
  • key (str)

  • direction (tdhook.hooks.HookDirection)

Return type:

Probe

_create_pair_probe(h1, h2, direction)[source]#
Parameters:
  • h1 (str)

  • h2 (str)

  • direction (tdhook.hooks.HookDirection)

Return type:

BilinearProbe

before_all()[source]#

Initialize waiting state on all bilinear probes for a run.

after_all()[source]#

Clear the waiting state and raise if any probe is still waiting on missing keys.

reset_estimators()[source]#
reset_metrics()[source]#
class tdhook.latent.probing.Probe(estimator, predict_callback, fit_callback=None, data_preprocess_callback=None)[source]#
Parameters:
  • estimator (Any)

  • predict_callback (Callable[[Any, Any], Any])

  • fit_callback (Optional[Callable[[Any, Any], Any]])

  • data_preprocess_callback (Optional[Callable[[Any], Any]])

_estimator#
_predict_callback#
_fit_callback = None#
_data_preprocess_callback#
step(data, **kwargs)[source]#
Parameters:

data (Any)

_default_data_preprocess_callback(data)[source]#
Parameters:

data (Any)

Return type:

Any

class tdhook.latent.probing.ProbeManager(estimator_class, estimator_kwargs, compute_metrics, allow_overwrite=False, data_preprocess_callback=None)[source]#
Parameters:
  • estimator_class (Any)

  • estimator_kwargs (dict)

  • compute_metrics (Callable[[Any, Any], Dict[str, Any]])

  • allow_overwrite (bool)

  • data_preprocess_callback (Optional[Callable[[Any], Any]])

_estimator_class#
_estimator_kwargs#
_compute_metrics#
_allow_overwrite = False#
_data_preprocess_callback = None#
_estimators#
_fit_metrics#
_predict_metrics#
property estimators: dict[str, Any]#
Return type:

dict[str, Any]

property fit_metrics: dict[str, Any]#
Return type:

dict[str, Any]

property predict_metrics: dict[str, Any]#
Return type:

dict[str, Any]

probe_factory(key, direction)[source]#
Parameters:
  • key (str)

  • direction (tdhook.hooks.HookDirection)

Return type:

Probe

reset_estimators()[source]#
reset_metrics()[source]#