tdhook.latent.probing.managers

Classes

Probe

BilinearProbe

Probe for bilinear estimators; caches the first activation when `h1 != h2`.

ProbeManager

BilinearProbeManager

Manager for bilinear probes; one probe per (h1, h2) pair.

Module Contents

class tdhook.latent.probing.managers.Probe(estimator, predict_callback, fit_callback=None, data_preprocess_callback=None)[source]#
Parameters:
  • estimator (Any)

  • predict_callback (Callable[[Any, Any], Any])

  • fit_callback (Optional[Callable[[Any, Any], Any]])

  • data_preprocess_callback (Optional[Callable[[Any], Any]])

_estimator[source]#
_predict_callback[source]#
_fit_callback = None[source]#
_data_preprocess_callback[source]#
step(data, **kwargs)[source]#
Parameters:

  • data (Any)

_default_data_preprocess_callback(data)[source]#
Parameters:

  • data (Any)

Return type:

Any

class tdhook.latent.probing.managers.BilinearProbe(h1_key, h2_key, estimator, predict_callback, fit_callback=None, data_preprocess_callback=None)[source]#

Bases: Probe

Probe for bilinear estimators; caches the first activation when `h1 != h2`.

Parameters:
  • h1_key (str)

  • h2_key (str)

  • estimator (Any)

  • predict_callback (Callable[[Any, Any], Any])

  • fit_callback (Optional[Callable[[Any, Any], Any]])

  • data_preprocess_callback (Optional[Callable[[Any], Any]])

_h1_key[source]#
_h2_key[source]#
_cached: Dict[str, Any][source]#
_waiting_active = False[source]#
step(data, key, labels, step_type, **kwargs)[source]#
Parameters:
  • data (Any)

  • key (str)

  • labels (Any)

  • step_type (str)

_run(h1, h2, labels, step_type)[source]#
Parameters:
  • h1 (Any)

  • h2 (Any)

  • labels (Any)

  • step_type (str)

before_all()[source]#
after_all()[source]#
Return type:

List[Tuple[str, str]]

property is_waiting: bool[source]#
Return type:

bool

class tdhook.latent.probing.managers.ProbeManager(estimator_class, estimator_kwargs, compute_metrics, allow_overwrite=False, data_preprocess_callback=None)[source]#
Parameters:
  • estimator_class (Any)

  • estimator_kwargs (dict)

  • compute_metrics (Callable[[Any, Any], Dict[str, Any]])

  • allow_overwrite (bool)

  • data_preprocess_callback (Callable[[Any], Any])

_estimator_class[source]#
_estimator_kwargs[source]#
_compute_metrics[source]#
_allow_overwrite = False[source]#
_data_preprocess_callback = None[source]#
_estimators[source]#
_fit_metrics[source]#
_predict_metrics[source]#
property estimators: dict[str, Any][source]#
Return type:

dict[str, Any]

property fit_metrics: dict[str, Any][source]#
Return type:

dict[str, Any]

property predict_metrics: dict[str, Any][source]#
Return type:

dict[str, Any]

probe_factory(key, direction)[source]#
Parameters:
  • key (str)

  • direction (tdhook.hooks.HookDirection)

Return type:

Probe

reset_estimators()[source]#
reset_metrics()[source]#
class tdhook.latent.probing.managers.BilinearProbeManager(pairs, estimator_class, estimator_kwargs, compute_metrics, allow_overwrite=False, data_preprocess_callback=None)[source]#

Bases: ProbeManager

Manager for bilinear probes; one probe per (h1, h2) pair.

Parameters:
  • pairs (List[Tuple[str, str]])

  • estimator_class (Any)

  • estimator_kwargs (dict)

  • compute_metrics (Callable[[Any, Any], Dict[str, Any]])

  • allow_overwrite (bool)

  • data_preprocess_callback (Optional[Callable[[Any], Any]])

_pairs[source]#
_pair_probes: Dict[Tuple[str, str, str], BilinearProbe][source]#
_key_to_probes: Dict[Tuple[str, str], List[BilinearProbe]][source]#
property key_pattern: str[source]#

Read-only regex pattern: an alternation of all keys that appear in `pairs`.

Return type:

str

probe_factory(key, direction)[source]#
Parameters:
  • key (str)

  • direction (tdhook.hooks.HookDirection)

Return type:

Probe

_create_pair_probe(h1, h2, direction)[source]#
Parameters:
  • h1 (str)

  • h2 (str)

  • direction (tdhook.hooks.HookDirection)

Return type:

BilinearProbe

before_all()[source]#

Initialize the waiting state of all bilinear probes before a run.

after_all()[source]#

Clear the waiting state, and raise an error if any probe is still waiting on missing keys.

reset_estimators()[source]#
reset_metrics()[source]#