Source code for tdhook.latent.probing

from typing import Callable, Optional, List, Protocol, Any, Dict, Type

import numpy as np
from tensordict import TensorDict
import torch.nn as nn
from tensordict.nn import TensorDictModuleBase

from tdhook.contexts import HookingContextFactory
from tdhook.hooks import (
    MultiHookManager,
    HookFactory,
    HookDirection,
    MultiHookHandle,
    DIRECTION_TO_RETURN,
)
from tdhook.modules import HookedModule


class Probe(Protocol):
    """Structural interface for objects that consume hooked data.

    Any object exposing a compatible ``step`` method satisfies this protocol;
    no inheritance is required.
    """

    def step(self, data: Any, **kwargs) -> Any:
        """Process one batch of hooked ``data``.

        Args:
            data: The value captured by the hook.
            **kwargs: Extra named items forwarded alongside ``data``.

        Returns:
            Whatever the concrete probe chooses to return.
        """
        ...
class Probing(HookingContextFactory):
    """
    Linear probing :cite:`alain2018understanding` and concept activation vectors :cite:`kim2018interpretability`.

    For every module matched by ``key_pattern`` and every requested hook
    direction, a probe is created through ``probe_factory`` and its ``step``
    method is invoked from a reading hook with the hooked value.

    Args:
        key_pattern: Pattern handed to :class:`MultiHookManager` to select the
            modules to hook.
        probe_factory: Callable invoked as ``probe_factory(name, direction)``
            that returns the probe used for that module/direction pair.
        relative: When true, ``module.relative_path`` is passed as
            ``relative_path`` while registering hooks; otherwise ``None`` is
            passed.
        directions: Hook directions to register. Defaults to ``["fwd"]`` when
            ``None`` or empty.
        additional_keys: Optional keys selected from the first positional
            argument of the ``td_module`` forward-pre call and forwarded to
            ``probe.step`` as keyword arguments.
        classes_to_hook: Module classes eligible for hooking. Defaults to
            ``default_classes_to_hook``.
        classes_to_skip: Module classes excluded from hooking. Defaults to
            ``default_classes_to_skip``.
    """

    # Used when ``classes_to_hook`` is not supplied to ``__init__``.
    default_classes_to_hook = (nn.Module,)
    # Used when ``classes_to_skip`` is not supplied to ``__init__``.
    default_classes_to_skip = (nn.ModuleList, nn.Sequential, TensorDictModuleBase)

    def __init__(
        self,
        key_pattern: str,
        probe_factory: Callable[[str, str], Probe],
        relative: bool = True,
        directions: Optional[List[HookDirection]] = None,
        additional_keys: Optional[List[str]] = None,
        classes_to_hook: Optional[List[Type[nn.Module]]] = None,
        classes_to_skip: Optional[List[Type[nn.Module]]] = None,
    ):
        super().__init__()

        self._key_pattern = key_pattern

        classes_to_hook = self.default_classes_to_hook if classes_to_hook is None else classes_to_hook
        classes_to_skip = self.default_classes_to_skip if classes_to_skip is None else classes_to_skip

        self._hook_manager = MultiHookManager(key_pattern, classes_to_hook, classes_to_skip)
        self._relative = relative
        self._probe_factory = probe_factory
        self._directions = directions or ["fwd"]
        self._additional_keys = additional_keys

    @property
    def key_pattern(self) -> str:
        """Pattern currently used to select the modules to hook."""
        return self._key_pattern

    @key_pattern.setter
    def key_pattern(self, key_pattern: str):
        # Keep the hook manager's pattern in sync with the stored one.
        self._key_pattern = key_pattern
        self._hook_manager.pattern = key_pattern

    def _hook_module(self, module: HookedModule) -> MultiHookHandle:
        """Register the probing hooks on ``module``.

        Args:
            module: The hooked module on which hooks are registered.

        Returns:
            A :class:`MultiHookHandle` grouping every handle created here
            (the optional additional-keys handle plus one per direction).
        """
        handles = []

        if self._additional_keys is not None:
            # Capture the requested keys from the first positional argument of
            # the ``td_module`` forward-pre call so they can be passed to the
            # probes later on.
            tmp_cache = TensorDict()
            handle, additional_items = module.get(
                cache=tmp_cache,
                module_key="td_module",
                cache_key="_additional_keys",
                callback=lambda **kwargs: kwargs["args"][0].select(*self._additional_keys),
                direction="fwd_pre",
                relative=False,
            )
            handles.append(handle)
        else:
            additional_items = None

        def hook_factory(name: str, direction: HookDirection) -> Callable:
            # One probe per (module name, direction) pair.
            probe = self._probe_factory(name, direction)

            def callback(**kwargs):
                # The additional items are resolved at call time, not at
                # registration time.
                if additional_items is not None:
                    _additional_items = additional_items.resolve()
                else:
                    _additional_items = {}
                return probe.step(kwargs[DIRECTION_TO_RETURN[direction]], **_additional_items)

            return HookFactory.make_reading_hook(callback=callback, direction=direction)

        def bind_direction(direction: HookDirection) -> Callable[[str], Callable]:
            # Bind ``direction`` eagerly: a lambda written directly in the loop
            # below would capture the loop variable by reference and observe
            # its last value if it were ever called after the loop advanced.
            return lambda name: hook_factory(name, direction)

        for direction in self._directions:
            handles.append(
                self._hook_manager.register_hook(
                    module,
                    bind_direction(direction),
                    direction=direction,
                    relative_path=module.relative_path if self._relative else None,
                )
            )

        return MultiHookHandle(handles)
class SklearnProbe:
    """Adapter driving an estimator with ``fit``/``predict`` from hooked data.

    Args:
        probe: Estimator object; ``fit(data, labels)`` and ``predict(data)``
            are called on it.
        predict_callback: Called as ``predict_callback(predictions, labels)``
            on every ``"predict"`` step.
        fit_callback: Optional; called as ``fit_callback(predictions, labels)``
            after fitting, with predictions computed on the fitting data.
        data_preprocess_callback: Optional transformation applied to the data
            before it reaches the estimator. When falsy, the default
            preprocessing of this class is used.
    """

    def __init__(
        self,
        probe: Any,
        predict_callback: Callable[[Any, Any], Any],
        fit_callback: Optional[Callable[[Any, Any], Any]] = None,
        data_preprocess_callback: Optional[Callable[[Any], Any]] = None,
    ):
        self._probe = probe
        self._predict_callback = predict_callback
        self._fit_callback = fit_callback
        if data_preprocess_callback:
            self._data_preprocess_callback = data_preprocess_callback
        else:
            self._data_preprocess_callback = self._default_data_preprocess_callback

    def step(self, data: Any, labels: Any, step_type: str):
        """Run one ``"fit"`` or ``"predict"`` step on preprocessed ``data``.

        Args:
            data: Raw hooked data; preprocessed before use.
            labels: Targets handed to ``fit`` and to the callbacks.
            step_type: Either ``"fit"`` or ``"predict"``.

        Raises:
            ValueError: If ``step_type`` is neither ``"fit"`` nor ``"predict"``.
        """
        # Preprocessing always runs first, even for an invalid step type.
        features = self._data_preprocess_callback(data)

        if step_type == "predict":
            self._predict_callback(self._probe.predict(features), labels)
            return

        if step_type != "fit":
            raise ValueError(f"Invalid step type: {step_type}")

        self._probe.fit(features, labels)
        if self._fit_callback is None:
            return
        # Report predictions on the very data the estimator was fitted on.
        self._fit_callback(self._probe.predict(features), labels)

    def _default_data_preprocess_callback(self, data: Any) -> Any:
        """Detach ``data`` and flatten it from dimension 1 onwards.

        Assumes ``data`` exposes ``detach()`` and ``flatten(start_dim)``
        (presumably a ``torch.Tensor`` -- confirm against callers).
        """
        detached = data.detach()
        return detached.flatten(1)
class SklearnProbeManager:
    """Create :class:`SklearnProbe` objects and collect their metrics.

    The manager instantiates one estimator per ``"{key}_{direction}"`` entry
    and records the metrics computed after fitting and after predicting.

    Args:
        probe_class: Callable instantiated as ``probe_class(**probe_kwargs)``
            to build each estimator.
        probe_kwargs: Keyword arguments forwarded to ``probe_class``.
        compute_metrics: Called as ``compute_metrics(predictions, labels)``;
            its result is stored as the metrics of the probe.
        allow_overwrite: When false, creating a probe or recording metrics
            under an already-used key raises ``ValueError``.
        data_preprocess_callback: Optional preprocessing forwarded to every
            created :class:`SklearnProbe`.
    """

    def __init__(
        self,
        probe_class: Any,
        probe_kwargs: dict,
        compute_metrics: Callable[[Any, Any], Dict[str, Any]],
        allow_overwrite: bool = False,
        data_preprocess_callback: Optional[Callable[[Any], Any]] = None,
    ):
        self._probe_class = probe_class
        self._probe_kwargs = probe_kwargs
        self._compute_metrics = compute_metrics
        self._allow_overwrite = allow_overwrite
        self._data_preprocess_callback = data_preprocess_callback

        self._probes = {}
        self._fit_metrics = {}
        self._predict_metrics = {}

    @property
    def probes(self) -> Dict[str, Any]:
        """Estimator instances (as built by ``probe_class``) keyed by ``"{key}_{direction}"``."""
        return self._probes

    @property
    def fit_metrics(self) -> Dict[str, Any]:
        """Metrics recorded after ``"fit"`` steps, keyed like :attr:`probes`."""
        return self._fit_metrics

    @property
    def predict_metrics(self) -> Dict[str, Any]:
        """Metrics recorded after ``"predict"`` steps, keyed like :attr:`probes`."""
        return self._predict_metrics

    def _store_metrics(self, store: Dict[str, Any], key: str, predictions: Any, labels: Any) -> None:
        """Compute metrics for ``key`` and record them in ``store``.

        Raises:
            ValueError: If ``key`` already has metrics and overwriting is not
                allowed.
        """
        if key in store and not self._allow_overwrite:
            raise ValueError(
                f"Metrics for {key} already exist, call reset_metrics() to reset the metrics or use allow_overwrite=True to overwrite the existing metrics"
            )
        store[key] = self._compute_metrics(predictions, labels)

    def probe_factory(self, key: str, direction: HookDirection) -> SklearnProbe:
        """Build the probe for ``key`` and ``direction``.

        Args:
            key: Name of the hooked module.
            direction: Hook direction the probe is attached to.

        Returns:
            A :class:`SklearnProbe` wrapping a freshly created estimator whose
            callbacks record metrics on this manager.

        Raises:
            ValueError: If a probe already exists for this key/direction and
                overwriting is not allowed.
        """
        _key = f"{key}_{direction}"
        if _key in self._probes and not self._allow_overwrite:
            raise ValueError(
                f"Probe {_key} already exists, call reset_probes() to reset the probes or use allow_overwrite=True to overwrite the existing probes"
            )
        probe = self._probe_class(**self._probe_kwargs)
        self._probes[_key] = probe

        # The metric dictionaries are looked up on ``self`` at call time so the
        # callbacks keep writing to the current ones after ``reset_metrics()``
        # has replaced them.
        def predict_callback(predictions: Any, labels: Any):
            self._store_metrics(self._predict_metrics, _key, predictions, labels)

        def fit_callback(predictions: Any, labels: Any):
            self._store_metrics(self._fit_metrics, _key, predictions, labels)

        return SklearnProbe(probe, predict_callback, fit_callback, self._data_preprocess_callback)

    def reset_probes(self):
        """Forget every estimator created so far."""
        self._probes = {}

    def reset_metrics(self):
        """Forget every recorded fit and predict metric."""
        self._fit_metrics = {}
        self._predict_metrics = {}
class MeanDifferenceClassifier:
    """Binary linear classifier built from the difference of the class means.

    The weight vector is ``mean(positive) - mean(negative)`` and the intercept
    is ``-0.5 * (||mean(positive)||**2 - ||mean(negative)||**2)``, so the
    decision boundary passes through the midpoint of the two means.

    Args:
        normalize: When true, both the weight vector and the intercept are
            divided by the norm of the weight vector after fitting. If the two
            class means coincide this norm is zero and the division produces
            non-finite values.
    """

    def __init__(self, normalize: bool = True):
        self._normalize = normalize
        self._coef = None
        self._intercept = None

    @property
    def coef_(self):
        """Fitted weight vector, shape ``(1, n_features)``.

        Raises:
            ValueError: If the model has not been fitted.
        """
        if self._coef is None:
            raise ValueError("Model not fitted")
        return self._coef

    @property
    def intercept_(self):
        """Fitted intercept, shape ``(1,)``.

        Raises:
            ValueError: If the model has not been fitted.
        """
        if self._intercept is None:
            raise ValueError("Model not fitted")
        return self._intercept

    def fit(self, X, y):
        """Estimate the weight vector and intercept from labelled data.

        Args:
            X: Array-like of shape ``(n_samples, n_features)``.
            y: One-dimensional array-like of labels, 1 for the positive class
                and 0 for the negative class.

        Raises:
            ValueError: If ``y`` has more than one dimension, or if one of the
                two classes has no weight in ``y``.
        """
        # Accept plain sequences as well as arrays.
        X = np.asarray(X)
        y = np.asarray(y)
        if y.ndim > 1:
            raise ValueError("Multiclass classification not supported")

        y = np.expand_dims(y, 1)
        pos_weight = y.sum()
        neg_weight = (1 - y).sum()
        # Without this guard a missing class divides by zero below and the
        # model silently ends up with NaN coefficients.
        if pos_weight == 0 or neg_weight == 0:
            raise ValueError("Both classes must be present in y to fit the model")

        pos = (X * y).sum(axis=0) / pos_weight
        neg = (X * (1 - y)).sum(axis=0) / neg_weight
        pos_norm = np.linalg.norm(pos)
        neg_norm = np.linalg.norm(neg)

        coef = pos - neg
        intercept = -0.5 * (pos_norm**2 - neg_norm**2)
        if self._normalize:
            coef_norm = np.linalg.norm(coef)
            intercept = intercept / coef_norm
            coef = coef / coef_norm

        # Assign only once everything has been computed so a failed fit never
        # leaves the model half-updated.
        self._intercept = np.asarray(intercept).reshape((1,))
        self._coef = coef.reshape((1, -1))

    def _decision_function(self, X):
        """Return the signed score of each row of ``X``.

        Raises:
            ValueError: If the model has not been fitted.
        """
        # Going through the properties turns an unfitted model into the
        # explicit "Model not fitted" error instead of a TypeError on None.
        return (np.asarray(X) * self.coef_).sum(axis=1) + self.intercept_

    def predict(self, X):
        """Return a boolean array, true where the score is strictly positive."""
        return self._decision_function(X) > 0

    def predict_proba(self, X):
        """Return class probabilities of shape ``(n_samples, 2)``.

        Column 0 is the negative class and column 1 the positive class; the
        positive probability is the logistic sigmoid of the decision score.
        """
        pos_proba = 1 / (1 + np.exp(-self._decision_function(X)))
        neg_proba = 1 - pos_proba
        return np.stack([neg_proba, pos_proba], axis=1)