"""Source code for tdhook.latent.probing.managers."""

import re
from typing import Callable, Optional, Any, Dict, List, Tuple

from tdhook.hooks import (
    HookDirection,
)


class Probe:
    """Wrap an estimator so it can be fit or evaluated on hooked activations.

    Args:
        estimator: Object exposing ``fit(data, y=labels)`` and ``predict(data)``.
        predict_callback: Called as ``predict_callback(predictions, labels)``
            after every ``"predict"`` step.
        fit_callback: Optional; called as ``fit_callback(predictions, labels)``
            with the estimator's predictions on the training data after every
            ``"fit"`` step.
        data_preprocess_callback: Optional transform applied to ``data`` before
            it reaches the estimator. ``None`` selects
            :meth:`_default_data_preprocess_callback`.
    """

    def __init__(
        self,
        estimator: Any,
        predict_callback: Callable[[Any, Any], Any],
        fit_callback: Optional[Callable[[Any, Any], Any]] = None,
        data_preprocess_callback: Optional[Callable[[Any], Any]] = None,
    ):
        self._estimator = estimator
        self._predict_callback = predict_callback
        self._fit_callback = fit_callback
        self._data_preprocess_callback = data_preprocess_callback or self._default_data_preprocess_callback

    def step(self, data: Any, **kwargs):
        """Fit or evaluate the estimator on one batch of activations.

        Args:
            data: Raw activation; passed through the preprocess callback first.
            **kwargs: ``labels`` (targets forwarded to the estimator and the
                callbacks) and ``step_type`` (``"fit"`` or ``"predict"``).
                Other keys are ignored.

        Raises:
            ValueError: If ``step_type`` is missing or not ``"fit"``/``"predict"``.
        """
        labels = kwargs.get("labels")
        step_type = kwargs.get("step_type")
        # Validate before preprocessing so a bad call does no work.
        if step_type not in ("fit", "predict"):
            raise ValueError(f"step_type must be 'fit' or 'predict', got {step_type!r}")
        data = self._data_preprocess_callback(data)
        if step_type == "fit":
            self._estimator.fit(data, y=labels)
            if self._fit_callback is not None:
                # Report metrics on the training data itself.
                self._fit_callback(self._estimator.predict(data), labels)
        else:  # step_type == "predict" (guaranteed by the check above)
            self._predict_callback(self._estimator.predict(data), labels)

    def _default_data_preprocess_callback(self, data: Any) -> Any:
        """Detach ``data`` and flatten every dimension after the first.

        NOTE(review): assumes a torch-like tensor exposing ``detach()`` and
        ``flatten(start_dim)``; the result is presumably ``(batch, features)``.
        """
        return data.detach().flatten(1)
class BilinearProbe(Probe):
    """Probe for bilinear estimators; caches first activation when h1 != h2.

    The estimator is called as ``fit(h1, h2, y=labels)`` and ``predict(h1, h2)``.
    When ``h1_key`` and ``h2_key`` differ, the activation of whichever key arrives
    first is cached until its partner arrives. The pair is then consumed with the
    ``labels`` and ``step_type`` of the call that completed it. When the two keys
    are equal, each matching activation is used as both inputs. Activations for
    any other key are ignored.

    Args:
        h1_key: Key whose activation is the estimator's first input.
        h2_key: Key whose activation is the estimator's second input.
        estimator: Object exposing ``fit(h1, h2, y=labels)`` and ``predict(h1, h2)``.
        predict_callback: Called as ``predict_callback(predictions, labels)``.
        fit_callback: Optional; called as ``fit_callback(predictions, labels)``
            with predictions on the training pair after a ``"fit"`` step.
        data_preprocess_callback: Optional transform applied to each activation
            before caching; ``None`` selects the :class:`Probe` default.
    """

    def __init__(
        self,
        h1_key: str,
        h2_key: str,
        estimator: Any,
        predict_callback: Callable[[Any, Any], Any],
        fit_callback: Optional[Callable[[Any, Any], Any]] = None,
        data_preprocess_callback: Optional[Callable[[Any], Any]] = None,
    ):
        super().__init__(estimator, predict_callback, fit_callback, data_preprocess_callback)
        self._h1_key = h1_key
        self._h2_key = h2_key
        # Preprocessed activations stored under "h1"/"h2" until the pair is complete.
        self._cached: Dict[str, Any] = {}
        # Toggled by before_all/after_all; not read elsewhere in this class.
        self._waiting_active = False

    @staticmethod
    def _validate_step_type(step_type: Any) -> None:
        """Raise ``ValueError`` unless ``step_type`` is ``"fit"`` or ``"predict"``."""
        if step_type not in ("fit", "predict"):
            raise ValueError(f"step_type must be 'fit' or 'predict', got {step_type!r}")

    def step(self, data: Any, key: str, labels: Any = None, step_type: Optional[str] = None, **kwargs):
        """Feed one activation to the probe.

        Args:
            data: Raw activation; preprocessed only if ``key`` is relevant.
            key: Which input ``data`` is. Keys other than ``h1_key``/``h2_key``
                are ignored.
            labels: Targets forwarded to the estimator and callbacks.
            step_type: ``"fit"`` or ``"predict"``.
            **kwargs: Ignored.

        Raises:
            ValueError: If ``key`` is relevant and ``step_type`` is invalid.
        """
        if key == self._h1_key:
            slot = "h1"
        elif key == self._h2_key:
            slot = "h2"
        else:
            # Unrelated key: ignore it, in the h1 == h2 case as well.
            return
        # Fail fast instead of caching half a pair that can never be run.
        self._validate_step_type(step_type)
        data = self._data_preprocess_callback(data)
        if self._h1_key == self._h2_key:
            self._run(data, data, labels, step_type)
            return
        self._cached[slot] = data
        if "h1" in self._cached and "h2" in self._cached:
            h1, h2 = self._cached["h1"], self._cached["h2"]
            self._cached.clear()
            self._run(h1, h2, labels, step_type)

    def _run(self, h1: Any, h2: Any, labels: Any, step_type: str):
        """Fit or evaluate the estimator on one ``(h1, h2)`` pair.

        Raises:
            ValueError: If ``step_type`` is neither ``"fit"`` nor ``"predict"``.
        """
        self._validate_step_type(step_type)
        if step_type == "fit":
            self._estimator.fit(h1, h2, y=labels)
            if self._fit_callback is not None:
                self._fit_callback(self._estimator.predict(h1, h2), labels)
        else:  # step_type == "predict"
            self._predict_callback(self._estimator.predict(h1, h2), labels)

    def before_all(self):
        """Start a run: mark the probe active and drop any cached activation."""
        self._waiting_active = True
        self._cached.clear()

    def after_all(self) -> List[Tuple[str, str]]:
        """End a run and report whether the pair was left half-filled.

        Returns:
            ``[(h1_key, h2_key)]`` if exactly one side of the pair is still
            cached, otherwise ``[]``. The cache is cleared in both cases.
        """
        self._waiting_active = False
        still_waiting: List[Tuple[str, str]] = []
        if self._h1_key != self._h2_key and self._cached:
            missing = [k for slot, k in (("h1", self._h1_key), ("h2", self._h2_key)) if slot not in self._cached]
            if missing:
                still_waiting.append((self._h1_key, self._h2_key))
        self._cached.clear()
        return still_waiting

    @property
    def is_waiting(self) -> bool:
        """True when exactly one side of a two-key pair is cached."""
        return (
            self._h1_key != self._h2_key
            and ("h1" in self._cached or "h2" in self._cached)
            and not ("h1" in self._cached and "h2" in self._cached)
        )
class ProbeManager:
    """Create one :class:`Probe` per ``(key, direction)`` and collect its metrics.

    Args:
        estimator_class: Called as ``estimator_class(**estimator_kwargs)`` to
            build a fresh estimator for every probe.
        estimator_kwargs: Keyword arguments forwarded to ``estimator_class``.
        compute_metrics: Called as ``compute_metrics(predictions, labels)``. Its
            result is stored in :attr:`fit_metrics` / :attr:`predict_metrics`.
        allow_overwrite: When False, creating a probe or recording metrics under
            an already-used name raises ``ValueError``.
        data_preprocess_callback: Forwarded to every :class:`Probe`; ``None``
            selects the probe's default preprocessing.
    """

    def __init__(
        self,
        estimator_class: Any,
        estimator_kwargs: dict,
        compute_metrics: Callable[[Any, Any], Dict[str, Any]],
        allow_overwrite: bool = False,
        data_preprocess_callback: Optional[Callable[[Any], Any]] = None,
    ):
        self._estimator_class = estimator_class
        self._estimator_kwargs = estimator_kwargs
        self._compute_metrics = compute_metrics
        self._allow_overwrite = allow_overwrite
        self._data_preprocess_callback = data_preprocess_callback
        self._estimators = {}
        self._fit_metrics = {}
        self._predict_metrics = {}

    @property
    def estimators(self) -> dict[str, Any]:
        """Estimators keyed by ``f"{key}_{direction}"``."""
        return self._estimators

    @property
    def fit_metrics(self) -> dict[str, Any]:
        """Metrics recorded after ``"fit"`` steps, keyed like :attr:`estimators`."""
        return self._fit_metrics

    @property
    def predict_metrics(self) -> dict[str, Any]:
        """Metrics recorded after ``"predict"`` steps, keyed like :attr:`estimators`."""
        return self._predict_metrics

    def _make_metrics_callback(self, name: str, store_attr: str) -> Callable[[Any, Any], None]:
        """Return a callback that stores ``compute_metrics`` output under ``name``.

        ``store_attr`` names the metrics dict attribute (``"_fit_metrics"`` or
        ``"_predict_metrics"``).
        """

        def callback(predictions: Any, labels: Any) -> None:
            # Look the dict up at call time so that reset_metrics(), which
            # rebinds the attribute, is honored by probes created earlier.
            store = getattr(self, store_attr)
            if name in store and not self._allow_overwrite:
                raise ValueError(
                    f"Metrics for {name} already exist, call reset_metrics() to reset the metrics or use allow_overwrite=True to overwrite the existing metrics"
                )
            store[name] = self._compute_metrics(predictions, labels)

        return callback

    def probe_factory(self, key: str, direction: HookDirection) -> Probe:
        """Create a :class:`Probe` whose estimator is registered as ``f"{key}_{direction}"``.

        Raises:
            ValueError: If that name is already registered and ``allow_overwrite``
                is False.
        """
        _key = f"{key}_{direction}"
        if _key in self._estimators and not self._allow_overwrite:
            raise ValueError(
                f"Probe {_key} already exists, call reset_estimators() to reset the estimators or use allow_overwrite=True to overwrite the existing estimators"
            )
        estimator = self._estimator_class(**self._estimator_kwargs)
        self._estimators[_key] = estimator
        return Probe(
            estimator,
            self._make_metrics_callback(_key, "_predict_metrics"),
            self._make_metrics_callback(_key, "_fit_metrics"),
            self._data_preprocess_callback,
        )

    def reset_estimators(self):
        """Forget all registered estimators.

        Probes created earlier keep references to their own estimators.
        """
        self._estimators = {}

    def reset_metrics(self):
        """Drop all recorded fit and predict metrics."""
        self._fit_metrics = {}
        self._predict_metrics = {}
class BilinearProbeManager(ProbeManager):
    """Manager for bilinear probes; one probe per (h1, h2) pair.

    A :class:`BilinearProbe` is created lazily for each ``(h1, h2, direction)``
    the first time :meth:`probe_factory` is asked for a key belonging to that
    pair. The same probe is then shared by both keys. Estimators and metrics are
    stored under ``f"{h1}_{h2}_{direction}"``.

    Args:
        pairs: ``(h1_key, h2_key)`` pairs to probe.
        estimator_class: As for :class:`ProbeManager`.
        estimator_kwargs: As for :class:`ProbeManager`.
        compute_metrics: As for :class:`ProbeManager`.
        allow_overwrite: As for :class:`ProbeManager`.
        data_preprocess_callback: Forwarded to every :class:`BilinearProbe`.
    """

    def __init__(
        self,
        pairs: List[Tuple[str, str]],
        estimator_class: Any,
        estimator_kwargs: dict,
        compute_metrics: Callable[[Any, Any], Dict[str, Any]],
        allow_overwrite: bool = False,
        data_preprocess_callback: Optional[Callable[[Any], Any]] = None,
    ):
        super().__init__(
            estimator_class=estimator_class,
            estimator_kwargs=estimator_kwargs,
            compute_metrics=compute_metrics,
            allow_overwrite=allow_overwrite,
            data_preprocess_callback=data_preprocess_callback,
        )
        self._pairs = list(pairs)
        # One probe per (h1, h2, direction), shared by both keys of the pair.
        self._pair_probes: Dict[Tuple[str, str, HookDirection], BilinearProbe] = {}
        # Memoized dispatch lists per (key, direction).
        self._key_to_probes: Dict[Tuple[str, HookDirection], List[BilinearProbe]] = {}

    @property
    def key_pattern(self) -> str:
        """Read-only regex alternation of all keys present in pairs.

        Each key is regex-escaped and anchored at its end with ``$``.
        """
        # NOTE(review): with no pairs this returns "", which matches every key
        # under re.match/re.search -- confirm the hook layer expects that.
        keys = sorted({k for pair in self._pairs for k in pair})
        return "|".join(re.escape(k) + "$" for k in keys)

    def probe_factory(self, key: str, direction: HookDirection) -> Probe:
        """Return a dispatcher that forwards ``key``'s activations to its pair probes.

        The returned object is not a :class:`Probe` instance. It exposes
        ``step(data, **kwargs)`` and calls ``probe.step(data, key=key, **kwargs)``
        on every pair probe that contains ``key``, once per distinct pair.
        """
        key_dir = (key, direction)
        probes = self._key_to_probes.get(key_dir)
        if probes is None:
            probes = []
            # A pair listed more than once in `pairs` must still be stepped only
            # once per activation. Otherwise an h1 == h2 pair runs twice, or an
            # h1 != h2 pair re-caches a stale half that after_all reports.
            seen = set()
            for h1, h2 in self._pairs:
                if key not in (h1, h2):
                    continue
                pair_key = (h1, h2, direction)
                if pair_key in seen:
                    continue
                seen.add(pair_key)
                if pair_key not in self._pair_probes:
                    self._pair_probes[pair_key] = self._create_pair_probe(h1, h2, direction)
                probes.append(self._pair_probes[pair_key])
            self._key_to_probes[key_dir] = probes

        def dispatcher_step(data: Any, **kwargs):
            for probe in probes:
                probe.step(data, key=key, **kwargs)

        class DispatcherProbe:
            def step(self, data: Any, **kwargs):
                dispatcher_step(data, **kwargs)

        return DispatcherProbe()

    def _create_pair_probe(self, h1: str, h2: str, direction: HookDirection) -> BilinearProbe:
        """Build the probe for one pair and register its estimator as ``f"{h1}_{h2}_{direction}"``.

        Raises:
            ValueError: If that name is already registered and ``allow_overwrite``
                is False.
        """
        pair_key = f"{h1}_{h2}_{direction}"
        if pair_key in self._estimators and not self._allow_overwrite:
            raise ValueError(
                f"Probe {pair_key} already exists, call reset_estimators() to reset or use allow_overwrite=True"
            )
        estimator = self._estimator_class(**self._estimator_kwargs)
        self._estimators[pair_key] = estimator

        # The metrics dicts are read through `self` at call time, so
        # reset_metrics() rebinding is honored by probes created earlier.
        def predict_callback(predictions: Any, labels: Any):
            if pair_key in self._predict_metrics and not self._allow_overwrite:
                raise ValueError(
                    f"Metrics for {pair_key} already exist, call reset_metrics() or use allow_overwrite=True"
                )
            self._predict_metrics[pair_key] = self._compute_metrics(predictions, labels)

        def fit_callback(predictions: Any, labels: Any):
            if pair_key in self._fit_metrics and not self._allow_overwrite:
                raise ValueError(
                    f"Metrics for {pair_key} already exist, call reset_metrics() or use allow_overwrite=True"
                )
            self._fit_metrics[pair_key] = self._compute_metrics(predictions, labels)

        return BilinearProbe(
            h1_key=h1,
            h2_key=h2,
            estimator=estimator,
            predict_callback=predict_callback,
            fit_callback=fit_callback,
            data_preprocess_callback=self._data_preprocess_callback,
        )

    def before_all(self):
        """Initialize waiting state on all bilinear probes for a run."""
        for probe in self._pair_probes.values():
            probe.before_all()

    def after_all(self):
        """Clear waiting state and raise if any probes still wait on missing keys.

        Raises:
            ValueError: Listing every pair that received only one of its two
                keys during the run. The reported keys are all members of
                those pairs.
        """
        still_waiting: List[Tuple[str, str]] = []
        for probe in self._pair_probes.values():
            still_waiting.extend(probe.after_all())
        if still_waiting:
            keys = sorted(set(k for pair in still_waiting for k in pair))
            raise ValueError(f"Bilinear probes still waiting on keys: {keys}. Unresolved pairs: {still_waiting}")

    def reset_estimators(self):
        """Forget all estimators, pair probes and memoized dispatch lists."""
        super().reset_estimators()
        self._pair_probes.clear()
        self._key_to_probes.clear()

    def reset_metrics(self):
        """Drop all recorded fit and predict metrics."""
        super().reset_metrics()