tdhook.latent.probing.managers
Classes
Probe |
BilinearProbe | Probe for bilinear estimators; caches first activation when h1 != h2.
ProbeManager |
BilinearProbeManager | Manager for bilinear probes; one probe per (h1, h2) pair.
Module Contents
- class tdhook.latent.probing.managers.Probe(estimator, predict_callback, fit_callback=None, data_preprocess_callback=None)
- Parameters:
estimator (Any)
predict_callback (Callable[[Any, Any], Any])
fit_callback (Optional[Callable[[Any, Any], Any]])
data_preprocess_callback (Optional[Callable[[Any], Any]])
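The constructor arguments above can be illustrated with a minimal, self-contained sketch. It does not import tdhook: `SketchProbe`, `MajorityEstimator`, and `accuracy` are hypothetical stand-ins, and the sketch assumes that the two-argument callbacks receive `(predictions, labels)` and that `data_preprocess_callback` is applied to the data before fitting or predicting. The real `Probe` may wire these differently.

```python
from collections import Counter
from typing import Any, Callable, Dict, List, Optional


class MajorityEstimator:
    """Toy estimator (hypothetical): always predicts the most frequent training label."""

    def fit(self, data: List[Any], labels: List[Any]) -> None:
        self.label_ = Counter(labels).most_common(1)[0][0]

    def predict(self, data: List[Any]) -> List[Any]:
        return [self.label_ for _ in data]


class SketchProbe:
    """Hypothetical stand-in that mirrors the documented constructor arguments."""

    def __init__(
        self,
        estimator: Any,
        predict_callback: Callable[[Any, Any], Any],
        fit_callback: Optional[Callable[[Any, Any], Any]] = None,
        data_preprocess_callback: Optional[Callable[[Any], Any]] = None,
    ) -> None:
        self.estimator = estimator
        self.predict_callback = predict_callback
        self.fit_callback = fit_callback
        self.data_preprocess_callback = data_preprocess_callback

    def _prepare(self, data: Any) -> Any:
        # Assumption: preprocessing is optional and applied to the raw data.
        if self.data_preprocess_callback is None:
            return data
        return self.data_preprocess_callback(data)

    def fit(self, data: Any, labels: Any) -> Any:
        data = self._prepare(data)
        self.estimator.fit(data, labels)
        if self.fit_callback is None:
            return None
        # Assumption: the callback sees (predictions, labels).
        return self.fit_callback(self.estimator.predict(data), labels)

    def predict(self, data: Any, labels: Any) -> Any:
        data = self._prepare(data)
        return self.predict_callback(self.estimator.predict(data), labels)


def accuracy(predictions: List[Any], labels: List[Any]) -> Dict[str, float]:
    correct = sum(p == y for p, y in zip(predictions, labels))
    return {"accuracy": correct / len(labels)}


probe = SketchProbe(MajorityEstimator(), predict_callback=accuracy, fit_callback=accuracy)
train_metrics = probe.fit([[0.1], [0.2], [0.9]], ["a", "a", "b"])
test_metrics = probe.predict([[0.3], [0.8]], ["a", "b"])
```

Keeping the estimator and the callbacks separate means the same probe wrapper can be reused with any object that exposes `fit` and `predict`.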
- class tdhook.latent.probing.managers.BilinearProbe(h1_key, h2_key, estimator, predict_callback, fit_callback=None, data_preprocess_callback=None)
Bases: Probe
Probe for bilinear estimators; caches the first activation when h1 != h2.
- Parameters:
h1_key (str)
h2_key (str)
estimator (Any)
predict_callback (Callable[[Any, Any], Any])
fit_callback (Optional[Callable[[Any, Any], Any]])
data_preprocess_callback (Optional[Callable[[Any], Any]])
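One way to read "caches the first activation when h1 != h2" is sketched below. This is an interpretation, not the library's code: `SketchBilinearCache` and its `observe` method are hypothetical, and the sketch assumes activations arrive one key at a time, with the `h1_key` activation arriving before the `h2_key` activation.

```python
from typing import Any, Optional, Tuple


class SketchBilinearCache:
    """Hypothetical helper: pairs up activations that arrive one key at a time."""

    def __init__(self, h1_key: str, h2_key: str) -> None:
        self.h1_key = h1_key
        self.h2_key = h2_key
        self._cached: Optional[Any] = None

    def observe(self, key: str, activation: Any) -> Optional[Tuple[Any, Any]]:
        """Return (h1, h2) once both sides are available, otherwise None."""
        if self.h1_key == self.h2_key:
            # Same key on both sides: the activation is paired with itself,
            # so nothing needs to be cached.
            return (activation, activation) if key == self.h1_key else None
        if key == self.h1_key:
            # Different keys: hold the first activation until the second arrives.
            self._cached = activation
            return None
        if key == self.h2_key and self._cached is not None:
            pair = (self._cached, activation)
            self._cached = None  # clear so a stale activation is never reused
            return pair
        return None


cache = SketchBilinearCache("layer1", "layer2")
first = cache.observe("layer1", [1.0, 2.0])
second = cache.observe("layer2", [3.0])
same = SketchBilinearCache("layer1", "layer1").observe("layer1", [5.0])
```

Here `first` is `None` because only one side has been seen, `second` is the completed pair, and `same` pairs the single activation with itself.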
- class tdhook.latent.probing.managers.ProbeManager(estimator_class, estimator_kwargs, compute_metrics, allow_overwrite=False, data_preprocess_callback=None)
- Parameters:
estimator_class (Any)
estimator_kwargs (dict)
compute_metrics (Callable[[Any, Any], Dict[str, Any]])
allow_overwrite (bool)
data_preprocess_callback (Callable[[Any], Any])
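The sketch below shows how arguments of this shape could fit together. It is only an illustration under stated assumptions: `SketchManager`, `make_estimator`, and `ThresholdEstimator` are hypothetical names, the estimator is assumed to be built as `estimator_class(**estimator_kwargs)`, and `allow_overwrite` is assumed to control whether an existing entry for a key may be replaced. The only part taken from the documentation is the `compute_metrics` signature, which takes two arguments and returns a dictionary.

```python
from typing import Any, Callable, Dict, List


class ThresholdEstimator:
    """Toy estimator used only for this sketch (hypothetical)."""

    def __init__(self, threshold: float = 0.0) -> None:
        self.threshold = threshold

    def predict(self, data: List[float]) -> List[int]:
        return [int(x > self.threshold) for x in data]


def compute_metrics(predictions: List[int], labels: List[int]) -> Dict[str, Any]:
    """Matches the documented shape: two inputs in, a dict of metrics out."""
    correct = sum(p == y for p, y in zip(predictions, labels))
    return {"accuracy": correct / len(labels), "n": len(labels)}


class SketchManager:
    """Hypothetical manager that builds one estimator per key."""

    def __init__(
        self,
        estimator_class: Any,
        estimator_kwargs: dict,
        compute_metrics: Callable[[Any, Any], Dict[str, Any]],
        allow_overwrite: bool = False,
    ) -> None:
        self.estimator_class = estimator_class
        self.estimator_kwargs = estimator_kwargs
        self.compute_metrics = compute_metrics
        self.allow_overwrite = allow_overwrite
        self.estimators: Dict[str, Any] = {}

    def make_estimator(self, key: str) -> Any:
        # Assumption: overwriting an existing entry is an error unless allowed.
        if key in self.estimators and not self.allow_overwrite:
            raise ValueError(f"An estimator for {key!r} already exists")
        # Assumption: a fresh estimator is created from class + kwargs.
        estimator = self.estimator_class(**self.estimator_kwargs)
        self.estimators[key] = estimator
        return estimator


manager = SketchManager(ThresholdEstimator, {"threshold": 0.5}, compute_metrics)
estimator = manager.make_estimator("layer1")
metrics = manager.compute_metrics(estimator.predict([0.2, 0.7, 0.9]), [0, 1, 0])
```

Passing the class and its keyword arguments separately, rather than a ready-made instance, lets a manager create an independent estimator for each key.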
- class tdhook.latent.probing.managers.BilinearProbeManager(pairs, estimator_class, estimator_kwargs, compute_metrics, allow_overwrite=False, data_preprocess_callback=None)
Bases: ProbeManager
Manager for bilinear probes; one probe per (h1, h2) pair.
- Parameters:
pairs (List[Tuple[str, str]])
estimator_class (Any)
estimator_kwargs (dict)
compute_metrics (Callable[[Any, Any], Dict[str, Any]])
allow_overwrite (bool)
data_preprocess_callback (Optional[Callable[[Any], Any]])
- _pair_probes: Dict[Tuple[str, str, str], BilinearProbe]
- _key_to_probes: Dict[Tuple[str, str], List[BilinearProbe]]
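The two private mappings are not described beyond their type annotations. As a rough sketch of how indices with these shapes could be built, the example below assumes that the three-string key is `(h1, h2, direction)` and the two-string key is `(key, direction)`. Both readings are guesses, and plain strings stand in for `BilinearProbe` objects.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

pairs: List[Tuple[str, str]] = [("layer1", "layer2"), ("layer1", "layer1")]
direction = "fwd"  # assumed third component of the pair key

pair_probes: Dict[Tuple[str, str, str], str] = {}
key_to_probes: Dict[Tuple[str, str], List[str]] = defaultdict(list)

for h1, h2 in pairs:
    probe = f"probe({h1},{h2},{direction})"  # placeholder for a probe object
    pair_probes[(h1, h2, direction)] = probe
    # Register the probe once under every distinct key it depends on;
    # dict.fromkeys removes the duplicate when h1 == h2.
    for key in dict.fromkeys((h1, h2)):
        key_to_probes[(key, direction)].append(probe)
```

With this layout, a lookup by a single key returns every pair probe that needs that key's activation, while the pair-level mapping keeps exactly one probe per pair.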
- property key_pattern: str
Read-only regex alternation of all keys present in pairs.
- Return type:
str
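A regex alternation over the keys in `pairs` could be assembled as in the following sketch. `build_key_pattern` is a hypothetical helper, not part of tdhook, and whether the real property escapes or deduplicates keys is an assumption here.

```python
import re
from typing import List, Tuple


def build_key_pattern(pairs: List[Tuple[str, str]]) -> str:
    """Join every distinct key from the pairs into one alternation."""
    # dict.fromkeys keeps first-seen order while dropping duplicates.
    keys = list(dict.fromkeys(key for pair in pairs for key in pair))
    # re.escape keeps characters such as "." from acting as wildcards.
    return "|".join(re.escape(key) for key in keys)


pairs = [("blocks.0", "blocks.1"), ("blocks.0", "blocks.0")]
pattern = build_key_pattern(pairs)
matches_key = re.fullmatch(pattern, "blocks.1") is not None
matches_other = re.fullmatch(pattern, "blocksX0") is not None
```

Because the keys are escaped, `matches_key` is `True` while `matches_other` is `False`; without escaping, the `.` in `blocks.0` would also match `blocksX0`.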
- probe_factory(key, direction)
- Parameters:
key (str)
direction (tdhook.hooks.HookDirection)
- Return type: