tdhook.latent.probing
Classes
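- Probe: Base class for protocol classes.
- Probing: Linear probing [19] and concept activation vectors [20].
- SklearnProbe
- SklearnProbeManager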
Module Contents
- class tdhook.latent.probing.Probe
Bases: Protocol
Base class for protocol classes.
Protocol classes are defined as:

    class Proto(Protocol):
        def meth(self) -> int:
            ...
Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).
For example:
    class C:
        def meth(self) -> int:
            return 0

    def func(x: Proto) -> int:
        return x.meth()

    func(C())  # Passes static type check
See PEP 544 for details.
Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols: they check only for the presence of the given attributes and ignore their type signatures.
Protocol classes can be generic; they are defined as:
    class GenProto(Protocol[T]):
        def meth(self) -> T:
            ...
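As a runnable illustration of the two behaviours described above, using only the standard `typing` module: a `@runtime_checkable` protocol accepts any object that has the named method, even one whose return type disagrees, while a generic protocol is parameterised like any other generic. The classes `D`, `E`, `IntImpl` and the function `use` are illustrative names, not part of tdhook.

```python
from typing import Protocol, TypeVar, runtime_checkable

T = TypeVar("T")


@runtime_checkable
class Proto(Protocol):
    def meth(self) -> int: ...


class C:
    # Structurally matches Proto: it defines meth() without inheriting from Proto.
    def meth(self) -> int:
        return 0


class D:
    # Wrong return type, but isinstance() still accepts it: runtime checks
    # only look for the attribute, not its signature.
    def meth(self) -> str:
        return "not an int"


class E:
    # No meth() at all, so the runtime check fails.
    pass


class GenProto(Protocol[T]):
    def meth(self) -> T: ...


class IntImpl:
    def meth(self) -> int:
        return 1


def use(x: GenProto[int]) -> int:
    # A static type checker verifies that x.meth() returns int.
    return x.meth()


print(isinstance(C(), Proto))  # True
print(isinstance(D(), Proto))  # True
print(isinstance(E(), Proto))  # False
print(use(IntImpl()))          # 1
```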
- class tdhook.latent.probing.Probing(key_pattern, probe_factory, relative=True, directions=None, additional_keys=None, classes_to_hook=None, classes_to_skip=None)
Bases: tdhook.contexts.HookingContextFactory
Linear probing [19] and concept activation vectors [20].
- Parameters:
key_pattern (str)
probe_factory (Callable[[str, str], Probe])
relative (bool)
directions (Optional[List[tdhook.hooks.HookDirection]])
additional_keys (Optional[List[str]])
classes_to_hook (Optional[List[Type[torch.nn.Module]]])
classes_to_skip (Optional[List[Type[torch.nn.Module]]])
- _hook_module(module)
- Parameters:
module (tdhook.modules.HookedModule)
- Return type:
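How `Probing` consumes its arguments is not spelled out on this page. The sketch below assumes that `key_pattern` is a regular expression tested against submodule names and that `probe_factory` is called once per matched module with a key string and a direction string, matching the documented `Callable[[str, str], Probe]` shape. `RecordingProbe`, the module names and the `"fwd"` direction label are all hypothetical, and the stand-in probe does not implement whatever methods the `Probe` protocol actually requires.

```python
import re
from typing import Dict, List


class RecordingProbe:
    """Hypothetical stand-in for a Probe: it only remembers where it was attached."""

    def __init__(self, key: str, direction: str) -> None:
        self.key = key
        self.direction = direction


# Probes produced by the factory, keyed by "<module>_<direction>".
created: Dict[str, RecordingProbe] = {}


def probe_factory(key: str, direction: str) -> RecordingProbe:
    # Same shape as Callable[[str, str], Probe].
    # Assumption: the two strings are a module key and a hook direction.
    probe = RecordingProbe(key, direction)
    created[f"{key}_{direction}"] = probe
    return probe


# Assumption: key_pattern is a regular expression tested against submodule names.
key_pattern = r"transformer\.h\.(0|1)\.mlp"
module_names: List[str] = [
    "transformer.h.0.mlp",
    "transformer.h.0.attn",
    "transformer.h.1.mlp",
    "lm_head",
]
selected = [name for name in module_names if re.match(key_pattern, name)]

for name in selected:
    probe_factory(name, "fwd")  # "fwd" is a hypothetical direction label

print(selected)         # ['transformer.h.0.mlp', 'transformer.h.1.mlp']
print(sorted(created))  # ['transformer.h.0.mlp_fwd', 'transformer.h.1.mlp_fwd']
```

One plausible reason for taking a factory rather than a single probe is that each matched module then gets its own, independently fitted probe.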
- class tdhook.latent.probing.SklearnProbe(probe, predict_callback, fit_callback=None, data_preprocess_callback=None)
- Parameters:
probe (Any)
predict_callback (Callable[[Any, Any], Any])
fit_callback (Optional[Callable[[Any, Any], Any]])
data_preprocess_callback (Optional[Callable[[Any], Any]])
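The `SklearnProbe` callback types can be read as follows, shown here as a minimal sketch. `data_preprocess_callback` reshapes raw activations, and `fit_callback`/`predict_callback` call into the wrapped estimator. The sketch assumes that the first callback argument is the wrapped `probe` and that `fit_callback` receives an `(inputs, labels)` pair; this page confirms neither. To stay dependency-free it uses a hand-written nearest-centroid classifier with scikit-learn-style `fit(X, y)`/`predict(X)` methods. In practice a scikit-learn estimator such as `sklearn.linear_model.LogisticRegression`, which exposes the same two methods, would take its place.

```python
from typing import Any, Dict, List, Sequence, Tuple


class NearestCentroid:
    """Minimal nearest-centroid classifier with sklearn-style fit/predict (illustrative)."""

    def fit(self, X: Sequence[Sequence[float]], y: Sequence[int]) -> "NearestCentroid":
        sums: Dict[int, List[float]] = {}
        counts: Dict[int, int] = {}
        for row, label in zip(X, y):
            acc = sums.setdefault(label, [0.0] * len(row))
            for i, v in enumerate(row):
                acc[i] += v
            counts[label] = counts.get(label, 0) + 1
        # Per-class mean feature vector.
        self.centroids_ = {c: [s / counts[c] for s in acc] for c, acc in sums.items()}
        return self

    def predict(self, X: Sequence[Sequence[float]]) -> List[int]:
        def dist2(a: Sequence[float], b: Sequence[float]) -> float:
            return sum((p - q) ** 2 for p, q in zip(a, b))

        # Assign each row to the class whose centroid is closest.
        return [min(self.centroids_, key=lambda c: dist2(row, self.centroids_[c])) for row in X]


def flatten(batch: Sequence[Sequence[Sequence[float]]]) -> List[List[float]]:
    # data_preprocess_callback: flatten each sample's activations into one feature vector.
    return [[v for part in sample for v in part] for sample in batch]


def fit_callback(probe: NearestCentroid, data: Tuple[Any, Any]) -> NearestCentroid:
    # Assumption: data is an (inputs, labels) pair.
    X, y = data
    return probe.fit(X, y)


def predict_callback(probe: NearestCentroid, X: Any) -> List[int]:
    # Assumption: the first argument is the wrapped probe.
    return probe.predict(X)


acts = [
    [[0.0, 0.1], [0.0, 0.0]],
    [[1.0, 0.9], [1.0, 1.0]],
    [[0.1, 0.0], [0.0, 0.1]],
    [[0.9, 1.0], [1.0, 0.9]],
]
labels = [0, 1, 0, 1]
probe = fit_callback(NearestCentroid(), (flatten(acts), labels))
print(predict_callback(probe, flatten([[[0.05, 0.0], [0.0, 0.05]]])))  # [0]
```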
- class tdhook.latent.probing.SklearnProbeManager(probe_class, probe_kwargs, compute_metrics, allow_overwrite=False, data_preprocess_callback=None)
- Parameters:
probe_class (Any)
probe_kwargs (dict)
compute_metrics (Callable[[Any, Any], Dict[str, Any]])
allow_overwrite (bool)
data_preprocess_callback (Callable[[Any], Any])
- property probes: dict[str, SklearnProbe]
- Return type:
dict[str, SklearnProbe]
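For `SklearnProbeManager`, `compute_metrics` has the shape `Callable[[Any, Any], Dict[str, Any]]`. The sketch below assumes it is called as `compute_metrics(labels, predictions)` and that the manager instantiates each probe as `probe_class(**probe_kwargs)`. Both are assumptions, as is the hypothetical `ThresholdProbe` class. With scikit-learn available, `probe_class=LogisticRegression`, `probe_kwargs={"max_iter": 1000}`, and a call to `sklearn.metrics.accuracy_score(y_true, y_pred)` inside `compute_metrics` would fill the same roles.

```python
from typing import Any, Dict, List, Sequence


def compute_metrics(labels: Sequence[Any], predictions: Sequence[Any]) -> Dict[str, Any]:
    # Assumption: argument order is (labels, predictions).
    if len(labels) != len(predictions):
        raise ValueError("labels and predictions must have the same length")
    n = len(labels)
    correct = sum(1 for t, p in zip(labels, predictions) if t == p)
    return {"accuracy": correct / n if n else float("nan"), "n_samples": n}


class ThresholdProbe:
    """Hypothetical probe class: predicts 1 when the first feature exceeds a threshold."""

    def __init__(self, threshold: float = 0.5) -> None:
        self.threshold = threshold

    def predict(self, X: Sequence[Sequence[float]]) -> List[int]:
        return [int(x[0] > self.threshold) for x in X]


probe_class = ThresholdProbe
probe_kwargs = {"threshold": 0.0}

# Assumption: the manager builds each probe as probe_class(**probe_kwargs).
probe = probe_class(**probe_kwargs)
preds = probe.predict([[1.2], [-0.3], [0.4]])
print(compute_metrics([1, 0, 0], preds))  # {'accuracy': 0.6666666666666666, 'n_samples': 3}
```

Passing the class and its keyword arguments separately, rather than a ready-made instance, lets a manager create a fresh, unfitted probe for every key it tracks.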