tdhook.latent.probing
Probing: linear and bilinear probing for model representations.
Submodules
Classes
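| Class | Summary |
|---|---|
| LinearEstimator | Linear estimator: W h + b. |
| LowRankBilinearEstimator | Low-rank bilinear: (U h_1) * (V h_2) + b. |
| BilinearProbe | Probe for bilinear estimators; caches the first activation when h1 != h2. |
| BilinearProbeManager | Manager for bilinear probes; one probe per (h1, h2) pair. |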
Package Contents
- class tdhook.latent.probing.Probing(key_pattern, probe_factory, relative=True, directions=None, additional_keys=None, classes_to_hook=None, classes_to_skip=None)
Bases: tdhook.contexts.HookingContextFactory
Linear probing [22] and concept activation vectors [23].
- Parameters:
key_pattern (str)
probe_factory (Callable[[str, str], Probe])
relative (bool)
directions (Optional[List[tdhook.hooks.HookDirection]])
additional_keys (Optional[List[str]])
classes_to_hook (Optional[List[Type[torch.nn.Module]]])
classes_to_skip (Optional[List[Type[torch.nn.Module]]])
- default_classes_to_hook
- default_classes_to_skip
- _key_pattern
- _hook_manager
- _relative = True
- _probe_factory
- _directions = ['fwd']
- _additional_keys = None
- property key_pattern: str
- Return type:
str
- _hook_module(module)
- Parameters:
module (tdhook.modules.HookedModule)
- Return type:
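The Probing signature suggests a simple contract. Modules whose names match key_pattern are hooked, and probe_factory(key, direction) builds the probe for each hook. The directions argument defaults to ['fwd']. The plain-Python sketch below illustrates that contract only. attach_probes is a hypothetical helper, not part of tdhook, and it assumes the pattern must match the whole module name (re.fullmatch). The sketch does not model how tdhook applies relative, additional_keys, or the class filters.

```python
import re
from typing import Callable, Dict, List, Optional, Tuple


# Hypothetical helper (not tdhook's API): illustrates the
# probe_factory(key, direction) contract documented for Probing.
def attach_probes(
    module_names: List[str],
    key_pattern: str,
    probe_factory: Callable[[str, str], object],
    directions: Optional[List[str]] = None,
) -> Dict[Tuple[str, str], object]:
    """Create one probe per (matching module name, hook direction)."""
    if directions is None:
        directions = ["fwd"]  # mirrors the documented default _directions
    pattern = re.compile(key_pattern)
    probes: Dict[Tuple[str, str], object] = {}
    for name in module_names:
        # Assumption: the pattern must match the whole module name.
        if pattern.fullmatch(name) is None:
            continue
        for direction in directions:
            probes[(name, direction)] = probe_factory(name, direction)
    return probes


names = ["transformer.h.0.mlp", "transformer.h.0.attn", "transformer.h.1.mlp"]
probes = attach_probes(names, r"transformer\.h\.\d+\.mlp", lambda k, d: f"probe<{k},{d}>")
print(sorted(probes))
# [('transformer.h.0.mlp', 'fwd'), ('transformer.h.1.mlp', 'fwd')]
```

Keying probes by (module name, direction) means hooking both directions of the same module yields two independent probes.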
- class tdhook.latent.probing.LinearEstimator(d_latent, bias=True, **kwargs)
Bases: TorchEstimator
Linear estimator: W h + b.
- Parameters:
d_latent (int)
bias (bool)
- linear
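The computation W h + b can be written out directly. The sketch below is a minimal pure-Python version for a single latent vector h of length d_latent. It assumes W is a list of rows, one per output unit. Passing b=None stands in for bias=False. The real LinearEstimator wraps a torch layer (the linear attribute), so treat this as an illustration of the formula only.

```python
from typing import List, Optional

Matrix = List[List[float]]
Vector = List[float]


def linear_estimate(W: Matrix, h: Vector, b: Optional[Vector] = None) -> Vector:
    """Return W h + b for one latent vector h; pass b=None for the bias=False case."""
    if any(len(row) != len(h) for row in W):
        raise ValueError("every row of W must have length d_latent == len(h)")
    # One dot product per output unit (row of W).
    out = [sum(w * x for w, x in zip(row, h)) for row in W]
    if b is not None:
        if len(b) != len(out):
            raise ValueError("b must have one entry per row of W")
        out = [o + bi for o, bi in zip(out, b)]
    return out


print(linear_estimate([[1, 2, 3], [0, 1, 0]], [1, 1, 1], b=[0.5, -1]))  # [6.5, 0]
```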
- class tdhook.latent.probing.LowRankBilinearEstimator(d_latent1, d_latent2, bias=True, **kwargs)
Bases: TorchEstimator
Low-rank bilinear: (U h_1) * (V h_2) + b.
- Parameters:
d_latent1 (int)
d_latent2 (int)
bias (bool)
- linear1
- linear2
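The formula (U h_1) * (V h_2) + b can be sketched in plain Python as follows. Three assumptions apply. First, * is read as the elementwise (Hadamard) product. Second, U has shape r x d_latent1 and V has shape r x d_latent2, playing the roles of linear1 and linear2. Third, a single bias b is added at the end. The real torch layers may carry their own biases, so this is an illustration of the formula rather than tdhook's implementation.

```python
from typing import List, Optional

Matrix = List[List[float]]
Vector = List[float]


def matvec(M: Matrix, x: Vector) -> Vector:
    """Plain matrix-vector product."""
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]


def low_rank_bilinear(U: Matrix, V: Matrix, h1: Vector, h2: Vector,
                      b: Optional[Vector] = None) -> Vector:
    """Return (U h1) * (V h2) + b, reading * as the elementwise product."""
    if len(U) != len(V):
        raise ValueError("U and V need the same number of rows (the rank r)")
    if any(len(row) != len(h1) for row in U) or any(len(row) != len(h2) for row in V):
        raise ValueError("U must be r x len(h1) and V must be r x len(h2)")
    u, v = matvec(U, h1), matvec(V, h2)
    out = [ui * vi for ui, vi in zip(u, v)]
    if b is not None:
        out = [o + bi for o, bi in zip(out, b)]
    return out


U = [[1, 0], [0, 1]]
V = [[2, 0], [0, 3]]
print(low_rank_bilinear(U, V, [1, 2], [3, 4], b=[1, 1]))  # [7, 25]
```

Each output component k equals (u_k · h_1)(v_k · h_2) = h_1^T (u_k v_k^T) h_2, a rank-one bilinear form. This is why the estimator is bilinear while storing only two d x r weight matrices instead of a full d_latent1 x d_latent2 matrix per output.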
- class tdhook.latent.probing.MeanDifferenceClassifier(normalize=True)
- Parameters:
normalize (bool)
- _normalize = True
- _coef = None
- _intercept = None
- property coef_
- property intercept_
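MeanDifferenceClassifier has no docstring, but its name and its coef_/intercept_ properties suggest the common difference-of-means probe. The sketch below assumes that reading. It takes the direction as the mean of the positive class minus the mean of the negative class, optionally normalized to unit length when normalize=True. It places the decision threshold at the midpoint between the two class means. MeanDifferenceSketch is hypothetical, and tdhook's actual intercept rule may differ.

```python
import math
from typing import List, Optional, Sequence


def _mean(rows: Sequence[Sequence[float]]) -> List[float]:
    return [sum(col) / len(rows) for col in zip(*rows)]


class MeanDifferenceSketch:
    """Hypothetical difference-of-means classifier (not tdhook's code)."""

    def __init__(self, normalize: bool = True):
        self._normalize = normalize
        self._coef: Optional[List[float]] = None
        self._intercept: Optional[float] = None

    @property
    def coef_(self) -> Optional[List[float]]:
        return self._coef

    @property
    def intercept_(self) -> Optional[float]:
        return self._intercept

    def fit(self, X: Sequence[Sequence[float]], y: Sequence[int]) -> "MeanDifferenceSketch":
        pos = [x for x, label in zip(X, y) if label == 1]
        neg = [x for x, label in zip(X, y) if label == 0]
        if not pos or not neg:
            raise ValueError("need at least one example of each class")
        mu_pos, mu_neg = _mean(pos), _mean(neg)
        # Direction: difference between the two class means.
        w = [p - n for p, n in zip(mu_pos, mu_neg)]
        if self._normalize:
            norm = math.sqrt(sum(wi * wi for wi in w))
            if norm > 0:
                w = [wi / norm for wi in w]
        # Threshold halfway between the class means along w.
        mid = [(p + n) / 2 for p, n in zip(mu_pos, mu_neg)]
        self._coef = w
        self._intercept = -sum(wi * mi for wi, mi in zip(w, mid))
        return self

    def decision_function(self, X: Sequence[Sequence[float]]) -> List[float]:
        if self._coef is None or self._intercept is None:
            raise RuntimeError("call fit() first")
        return [sum(wi * xi for wi, xi in zip(self._coef, x)) + self._intercept for x in X]

    def predict(self, X: Sequence[Sequence[float]]) -> List[int]:
        return [1 if s > 0 else 0 for s in self.decision_function(X)]


clf = MeanDifferenceSketch().fit([[2, 0], [3, 0], [-1, 0], [-2, 0]], [1, 1, 0, 0])
print(clf.coef_, clf.intercept_)  # [1.0, 0.0] -0.5
print(clf.predict([[1, 0], [0, 0]]))  # [1, 0]
```

Because the direction comes from class means alone, fitting needs no optimization.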
- class tdhook.latent.probing.BilinearProbe(h1_key, h2_key, estimator, predict_callback, fit_callback=None, data_preprocess_callback=None)
Bases: Probe
Probe for bilinear estimators; caches the first activation when h1 != h2.
- Parameters:
h1_key (str)
h2_key (str)
estimator (Any)
predict_callback (Callable[[Any, Any], Any])
fit_callback (Optional[Callable[[Any, Any], Any]])
data_preprocess_callback (Optional[Callable[[Any], Any]])
- _h1_key
- _h2_key
- _cached: Dict[str, Any]
- _waiting_active = False
- step(data, key, labels, step_type, **kwargs)
- Parameters:
data (Any)
key (str)
labels (Any)
step_type (str)
- property is_waiting: bool
- Return type:
bool
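A bilinear probe needs two activations, but hooks fire one at a time. When h1 != h2, the first activation therefore has to be held until its partner arrives. The sketch below shows that caching pattern. BilinearPairCache and its on_pair callback are hypothetical names. The sketch assumes the h1 hook fires before the h2 hook in each forward pass, and it omits step_type and the estimator itself.

```python
from typing import Any, Callable, Dict, Optional


class BilinearPairCache:
    """Hypothetical sketch: hold h1 until h2 arrives (not tdhook's code)."""

    def __init__(self, h1_key: str, h2_key: str,
                 on_pair: Callable[[Any, Any, Any], Any]):
        self._h1_key = h1_key
        self._h2_key = h2_key
        self._on_pair = on_pair  # called with (h1, h2, labels)
        self._cached: Dict[str, Any] = {}

    @property
    def is_waiting(self) -> bool:
        # True while h1 has been seen but its partner h2 has not.
        return self._h1_key in self._cached

    def step(self, data: Any, key: str, labels: Any) -> Optional[Any]:
        if self._h1_key == self._h2_key:
            # Same activation feeds both arguments: nothing to cache.
            return self._on_pair(data, data, labels) if key == self._h1_key else None
        if key == self._h1_key:
            self._cached[key] = data  # assumption: h1 fires before h2
            return None
        if key == self._h2_key:
            if self._h1_key not in self._cached:
                raise RuntimeError(f"saw {key!r} before {self._h1_key!r}")
            h1 = self._cached.pop(self._h1_key)
            return self._on_pair(h1, data, labels)
        return None


probe = BilinearPairCache("layer.1", "layer.3", lambda a, b, y: (a, b, y))
print(probe.step([0.1], "layer.1", None), probe.is_waiting)  # None True
print(probe.step([0.2], "layer.3", 1), probe.is_waiting)  # ([0.1], [0.2], 1) False
```

Popping the cached entry after use keeps memory bounded to one pending activation per probe.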
- class tdhook.latent.probing.BilinearProbeManager(pairs, estimator_class, estimator_kwargs, compute_metrics, allow_overwrite=False, data_preprocess_callback=None)
Bases: ProbeManager
Manager for bilinear probes; one probe per (h1, h2) pair.
- Parameters:
pairs (List[Tuple[str, str]])
estimator_class (Any)
estimator_kwargs (dict)
compute_metrics (Callable[[Any, Any], Dict[str, Any]])
allow_overwrite (bool)
data_preprocess_callback (Optional[Callable[[Any], Any]])
- _pairs
- _pair_probes: Dict[Tuple[str, str, str], BilinearProbe]
- _key_to_probes: Dict[Tuple[str, str], List[BilinearProbe]]
- property key_pattern: str
Read-only regex alternation of all keys present in pairs.
- Return type:
str
- probe_factory(key, direction)
- Parameters:
key (str)
direction (tdhook.hooks.HookDirection)
- Return type:
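The key_pattern property is described as a regex alternation over every key that appears in pairs. A hook on any such key may need to feed several pair probes. The sketch below shows one way to build both pieces. Keys are deduplicated, sorted, and escaped with re.escape so dots in module names match literally. build_key_pattern and index_pairs are hypothetical names, and tdhook's exact pattern (ordering, grouping, anchoring) and its (key, direction) indexing may differ.

```python
import re
from typing import Dict, List, Tuple


def build_key_pattern(pairs: List[Tuple[str, str]]) -> str:
    """Regex alternation of every key appearing in pairs (deduplicated, sorted, escaped)."""
    keys = sorted({key for pair in pairs for key in pair})
    return "|".join(re.escape(key) for key in keys)


def index_pairs(pairs: List[Tuple[str, str]]) -> Dict[str, List[Tuple[str, str]]]:
    """Map each key to the pairs it takes part in, so one hook can feed several probes."""
    index: Dict[str, List[Tuple[str, str]]] = {}
    for h1, h2 in pairs:
        for key in dict.fromkeys((h1, h2)):  # deduplicate when h1 == h2, keep order
            index.setdefault(key, []).append((h1, h2))
    return index


pairs = [("blocks.0", "blocks.2"), ("blocks.0", "blocks.1")]
pattern = build_key_pattern(pairs)
print(pattern)  # blocks\.0|blocks\.1|blocks\.2
print(bool(re.fullmatch(pattern, "blocks.1")), bool(re.fullmatch(pattern, "blocksX1")))  # True False
print(index_pairs(pairs)["blocks.0"])  # [('blocks.0', 'blocks.2'), ('blocks.0', 'blocks.1')]
```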
- class tdhook.latent.probing.Probe(estimator, predict_callback, fit_callback=None, data_preprocess_callback=None)
- Parameters:
estimator (Any)
predict_callback (Callable[[Any, Any], Any])
fit_callback (Optional[Callable[[Any, Any], Any]])
data_preprocess_callback (Optional[Callable[[Any], Any]])
- _estimator
- _predict_callback
- _fit_callback = None
- _data_preprocess_callback
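Probe bundles an estimator with an optional preprocessing callback and callbacks for the fit and predict phases. The sketch below shows one plausible flow, and several parts of it are assumptions not confirmed by the signature. The estimator is assumed to expose scikit-learn-style fit/predict. Both callbacks are assumed to receive (predictions, labels). step_type is assumed to take the values "fit" and "predict", mirroring the step signature of BilinearProbe. ProbeSketch, MajorityClass, and accuracy are hypothetical demo names.

```python
from collections import Counter
from typing import Any, Callable, Optional


class ProbeSketch:
    """Hypothetical probe: preprocess, fit or predict, then report via a callback."""

    def __init__(self, estimator: Any,
                 predict_callback: Callable[[Any, Any], Any],
                 fit_callback: Optional[Callable[[Any, Any], Any]] = None,
                 data_preprocess_callback: Optional[Callable[[Any], Any]] = None):
        self._estimator = estimator
        self._predict_callback = predict_callback
        self._fit_callback = fit_callback
        self._data_preprocess_callback = data_preprocess_callback

    def step(self, data: Any, key: str, labels: Any, step_type: str) -> Any:
        # key is unused here; it is kept to mirror the documented step signature.
        if self._data_preprocess_callback is not None:
            data = self._data_preprocess_callback(data)
        if step_type == "fit":
            self._estimator.fit(data, labels)  # assumption: sklearn-style fit/predict
            if self._fit_callback is None:
                return None
            return self._fit_callback(self._estimator.predict(data), labels)
        if step_type == "predict":
            return self._predict_callback(self._estimator.predict(data), labels)
        raise ValueError(f"unknown step_type {step_type!r}")


class MajorityClass:
    """Toy estimator for the demo: always predicts the most common training label."""

    def fit(self, X, y):
        self.label_ = Counter(y).most_common(1)[0][0]
        return self

    def predict(self, X):
        return [self.label_ for _ in X]


def accuracy(preds, labels):
    return sum(p == t for p, t in zip(preds, labels)) / len(labels)


probe = ProbeSketch(MajorityClass(), accuracy, fit_callback=accuracy,
                    data_preprocess_callback=lambda rows: [r[:1] for r in rows])
print(probe.step([[0, 9], [1, 9], [2, 9]], "mlp.0", [1, 1, 0], "fit"))
print(probe.step([[5, 9], [6, 9]], "mlp.0", [0, 1], "predict"))  # 0.5
```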
- class tdhook.latent.probing.ProbeManager(estimator_class, estimator_kwargs, compute_metrics, allow_overwrite=False, data_preprocess_callback=None)
- Parameters:
estimator_class (Any)
estimator_kwargs (dict)
compute_metrics (Callable[[Any, Any], Dict[str, Any]])
allow_overwrite (bool)
data_preprocess_callback (Callable[[Any], Any])
- _estimator_class
- _estimator_kwargs
- _compute_metrics
- _allow_overwrite = False
- _data_preprocess_callback = None
- _estimators
- _fit_metrics
- _predict_metrics
- property estimators: dict[str, Any]
- Return type:
dict[str, Any]
- property fit_metrics: dict[str, Any]
- Return type:
dict[str, Any]
- property predict_metrics: dict[str, Any]
- Return type:
dict[str, Any]
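ProbeManager's parameters and properties suggest a small bookkeeping role. It builds one estimator per key from estimator_class(**estimator_kwargs), refuses to replace an existing one unless allow_overwrite is set, and stores the output of compute_metrics separately for the fit and predict phases. The sketch below models only that bookkeeping. ProbeManagerSketch, new_estimator, and record are hypothetical names. The argument order (labels, predictions) for compute_metrics is also an assumption.

```python
from typing import Any, Callable, Dict


class ProbeManagerSketch:
    """Hypothetical bookkeeping for one estimator per hooked key (not tdhook's code)."""

    def __init__(self, estimator_class: Any, estimator_kwargs: dict,
                 compute_metrics: Callable[[Any, Any], Dict[str, Any]],
                 allow_overwrite: bool = False):
        self._estimator_class = estimator_class
        self._estimator_kwargs = estimator_kwargs
        self._compute_metrics = compute_metrics
        self._allow_overwrite = allow_overwrite
        self._estimators: Dict[str, Any] = {}
        self._fit_metrics: Dict[str, Any] = {}
        self._predict_metrics: Dict[str, Any] = {}

    @property
    def estimators(self) -> Dict[str, Any]:
        return self._estimators

    @property
    def fit_metrics(self) -> Dict[str, Any]:
        return self._fit_metrics

    @property
    def predict_metrics(self) -> Dict[str, Any]:
        return self._predict_metrics

    def new_estimator(self, key: str) -> Any:
        # Refuse to silently replace an existing estimator unless allowed.
        if key in self._estimators and not self._allow_overwrite:
            raise ValueError(f"an estimator for {key!r} already exists")
        estimator = self._estimator_class(**self._estimator_kwargs)
        self._estimators[key] = estimator
        return estimator

    def record(self, key: str, step_type: str, predictions: Any, labels: Any) -> None:
        # Assumption: compute_metrics takes (labels, predictions).
        metrics = self._compute_metrics(labels, predictions)
        target = self._fit_metrics if step_type == "fit" else self._predict_metrics
        target[key] = metrics


manager = ProbeManagerSketch(
    dict, {"C": 1.0},
    lambda y, p: {"accuracy": sum(a == b for a, b in zip(y, p)) / len(y)},
)
manager.new_estimator("mlp.0")
manager.record("mlp.0", "fit", predictions=[1, 0], labels=[1, 1])
print(manager.fit_metrics)  # {'mlp.0': {'accuracy': 0.5}}
```

Keeping fit and predict metrics in separate dictionaries lets training and held-out scores for the same key coexist.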