tdhook.latent.probing

Classes

Probe

Protocol for probe objects: anything that exposes a `step(data, **kwargs)` method.

Probing

Linear probing [19] and concept activation vectors [20].

SklearnProbe

SklearnProbeManager

MeanDifferenceClassifier

Module Contents

class tdhook.latent.probing.Probe[source]#

Bases: Protocol

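Structural protocol for probe objects: any object exposing a `step(data, **kwargs)` method satisfies it. The text below is the generic `typing.Protocol` documentation inherited from the base class: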

Protocol classes are defined as:

class Proto(Protocol):
    def meth(self) -> int:
        ...

Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).

For example:

class C:
    def meth(self) -> int:
        return 0

def func(x: Proto) -> int:
    return x.meth()

func(C())  # Passes static type check

See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic; they are defined as:

class GenProto(Protocol[T]):
    def meth(self) -> T:
        ...

step(data, **kwargs)[source]#
Parameters:

data (Any)

Return type:

Any

class tdhook.latent.probing.Probing(key_pattern, probe_factory, relative=True, directions=None, additional_keys=None, classes_to_hook=None, classes_to_skip=None)[source]#

Bases: tdhook.contexts.HookingContextFactory

Linear probing [19] and concept activation vectors [20].

Parameters:
  • key_pattern (str)

  • probe_factory (Callable[[str, str], Probe])

  • relative (bool)

  • directions (Optional[List[tdhook.hooks.HookDirection]])

  • additional_keys (Optional[List[str]])

  • classes_to_hook (Optional[List[Type[torch.nn.Module]]])

  • classes_to_skip (Optional[List[Type[torch.nn.Module]]])

default_classes_to_hook[source]#
default_classes_to_skip[source]#
_key_pattern[source]#
_hook_manager[source]#
_relative = True[source]#
_probe_factory[source]#
_directions = ['fwd'][source]#
_additional_keys = None[source]#
property key_pattern: str[source]#
Return type:

str

_hook_module(module)[source]#
Parameters:

module (tdhook.modules.HookedModule)

Return type:

tdhook.hooks.MultiHookHandle

class tdhook.latent.probing.SklearnProbe(probe, predict_callback, fit_callback=None, data_preprocess_callback=None)[source]#
Parameters:
  • probe (Any)

  • predict_callback (Callable[[Any, Any], Any])

  • fit_callback (Optional[Callable[[Any, Any], Any]])

  • data_preprocess_callback (Optional[Callable[[Any], Any]])

_probe[source]#
_predict_callback[source]#
_fit_callback = None[source]#
_data_preprocess_callback[source]#
step(data, labels, step_type)[source]#
Parameters:
  • data (Any)

  • labels (Any)

  • step_type (str)

_default_data_preprocess_callback(data)[source]#
Parameters:

data (Any)

Return type:

Any

class tdhook.latent.probing.SklearnProbeManager(probe_class, probe_kwargs, compute_metrics, allow_overwrite=False, data_preprocess_callback=None)[source]#
Parameters:
  • probe_class (Any)

  • probe_kwargs (dict)

  • compute_metrics (Callable[[Any, Any], Dict[str, Any]])

  • allow_overwrite (bool)

  • data_preprocess_callback (Callable[[Any], Any])

_probe_class[source]#
_probe_kwargs[source]#
_compute_metrics[source]#
_allow_overwrite = False[source]#
_data_preprocess_callback = None[source]#
_probes[source]#
_fit_metrics[source]#
_predict_metrics[source]#
property probes: dict[str, SklearnProbe][source]#
Return type:

dict[str, SklearnProbe]

property fit_metrics: dict[str, Any][source]#
Return type:

dict[str, Any]

property predict_metrics: dict[str, Any][source]#
Return type:

dict[str, Any]

probe_factory(key, direction)[source]#
Parameters:
  • key (str)

  • direction (tdhook.hooks.HookDirection)

Return type:

SklearnProbe

reset_probes()[source]#
reset_metrics()[source]#
class tdhook.latent.probing.MeanDifferenceClassifier(normalize=True)[source]#
Parameters:

normalize (bool)

_normalize = True[source]#
_coef = None[source]#
_intercept = None[source]#
property coef_[source]#
property intercept_[source]#
fit(X, y)[source]#
_decision_function(X)[source]#
predict(X)[source]#
predict_proba(X)[source]#