# Source code for tdhook.latent.probing.estimators

"""
Estimators for probing.
"""

from typing import Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


class MeanDifferenceClassifier:
    """Binary linear classifier built from the difference of the class means.

    The weight vector is the mean of the positive samples minus the mean of
    the negative samples, and the intercept places the decision boundary
    halfway between the two means. With ``normalize=True`` the weight vector
    and the intercept are both divided by the norm of the weight vector
    (skipped when that norm is zero).
    """

    def __init__(self, normalize: bool = True):
        self._normalize = normalize
        self._coef = None
        self._intercept = None

    @property
    def coef_(self):
        """Fitted weights, shape ``(1, n_features)``.

        Raises:
            ValueError: if :meth:`fit` has not been called.
        """
        if self._coef is None:
            raise ValueError("Model not fitted")
        return self._coef

    @property
    def intercept_(self):
        """Fitted intercept, shape ``(1,)``.

        Raises:
            ValueError: if :meth:`fit` has not been called.
        """
        if self._intercept is None:
            raise ValueError("Model not fitted")
        return self._intercept

    def fit(self, X, y):
        """Estimate the separating direction from labelled samples.

        Args:
            X: array-like of samples, one row per sample.
            y: 1-D array-like of labels, expected to hold 0/1 (or boolean)
                values; ``y`` weights the positive class and ``1 - y`` the
                negative class. Other label values are not validated.

        Returns:
            The fitted classifier itself, so calls can be chained.

        Raises:
            ValueError: if ``y`` has more than one dimension, or if one of
                the two classes is absent from ``y``.
        """
        # Accept lists and other array-likes, not only ndarrays.
        X = np.asarray(X)
        y = np.asarray(y)
        if y.ndim > 1:
            raise ValueError("Multiclass classification not supported")

        # Column vector so that it broadcasts across the feature axis of X.
        y = np.expand_dims(y, 1)
        pos_count = y.sum()
        neg_count = (1 - y).sum()
        if pos_count == 0 or neg_count == 0:
            raise ValueError("Both classes must be present in y")

        pos = (X * y).sum(axis=0) / pos_count
        neg = (X * (1 - y)).sum(axis=0) / neg_count
        pos_norm = np.linalg.norm(pos)
        neg_norm = np.linalg.norm(neg)

        self._coef = pos - neg
        # Equals -coef . (pos + neg) / 2: the boundary passes through the
        # midpoint of the two class means.
        self._intercept = -0.5 * (pos_norm**2 - neg_norm**2)

        if self._normalize:
            coef_norm = np.linalg.norm(self._coef)
            if coef_norm > 0:
                self._intercept = self._intercept / coef_norm
                self._coef = self._coef / coef_norm

        self._intercept = self._intercept.reshape((1,))
        self._coef = self._coef.reshape((1, -1))
        return self

    def _decision_function(self, X):
        """Return the signed score ``X . coef + intercept`` for each row of ``X``.

        Raises:
            ValueError: if the model has not been fitted.
        """
        # Go through the properties so an unfitted model reports
        # "Model not fitted" instead of failing on arithmetic with None.
        return (X * self.coef_).sum(axis=1) + self.intercept_

    def predict(self, X):
        """Return a boolean array, True where the decision score is positive."""
        return self._decision_function(X) > 0

    def predict_proba(self, X):
        """Return logistic-sigmoid pseudo-probabilities of the decision scores.

        Returns:
            Array with two columns: ``[negative, positive]`` per sample.
        """
        pos_proba = 1 / (1 + np.exp(-self._decision_function(X)))
        neg_proba = 1 - pos_proba
        return np.stack([neg_proba, pos_proba], axis=1)
class TorchEstimator(nn.Module):
    """Base class for torch estimators.

    Provides a mini-batch Adam training loop (:meth:`fit`), inference
    (:meth:`predict`) and the loss selection (:meth:`_loss_fn`). This class
    does not define ``forward``; ``fit`` and ``predict`` call ``self(*Xs)``,
    so subclasses are expected to implement it.

    Args:
        num_classes: number of classes for classification; ``None`` selects
            regression (MSE loss, raw outputs returned by ``predict``).
        epochs: number of passes over the training data.
        lr: learning rate handed to Adam.
        batch_size: mini-batch size used by the training ``DataLoader``.
        device: stored for subclasses to create their parameters on.
        verbose: if True, print a progress line every 10 epochs.
    """

    def __init__(
        self,
        *,
        num_classes: Optional[int] = None,
        epochs: int = 100,
        lr: float = 1e-3,
        batch_size: int = 128,
        device: Optional[torch.device] = None,
        verbose: bool = False,
    ):
        super().__init__()
        self._num_classes = num_classes
        self._epochs = epochs
        self._lr = lr
        self._batch_size = batch_size
        self._device = device
        self._verbose = verbose

    def fit(self, *Xs: torch.Tensor, y: torch.Tensor):
        """Train the estimator on ``(*Xs, y)`` with Adam and shuffled mini-batches.

        Args:
            *Xs: input tensors, forwarded positionally to ``self(...)``.
            y: target tensor (keyword-only).

        Returns:
            The fitted estimator itself.
        """
        dataset = TensorDataset(*Xs, y)
        train_loader = DataLoader(dataset, batch_size=self._batch_size, shuffle=True)
        optimizer = torch.optim.Adam(self.parameters(), lr=self._lr)

        for epoch in range(self._epochs):
            self.train()
            for batch in train_loader:
                # The last tensor of each batch is the target, the rest are inputs.
                *Xs_b, y_b = batch
                o_b = self(*Xs_b)
                loss = self._loss_fn(o_b, y_b)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            if self._verbose and (epoch + 1) % 10 == 0:
                print(f"Epoch {epoch + 1}/{self._epochs}")
        return self

    def predict(self, *Xs: torch.Tensor) -> torch.Tensor:
        """Run the model in eval mode without gradients.

        Returns:
            The argmax over the last dimension when ``num_classes`` is set
            (class indices), otherwise the raw model output.
        """
        self.eval()
        with torch.no_grad():
            Y = self(*Xs)
        return Y.argmax(dim=-1) if self._num_classes is not None else Y

    def _loss_fn(self, output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the training loss for one batch.

        Regression (``num_classes is None``) uses MSE and requires identical
        shapes. Classification uses cross-entropy with ``output`` shaped
        ``(..., num_classes)`` and ``target`` shaped ``(...)``.

        Raises:
            ValueError: if the output and target shapes are inconsistent.
        """
        if self._num_classes is None:  # regression
            if output.shape != target.shape:
                raise ValueError(f"Output shape {output.shape} does not match target shape {target.shape}")
            return F.mse_loss(output, target)

        if output.shape[:-1] != target.shape or output.shape[-1] != self._num_classes:
            raise ValueError(f"Output shape {output.shape} does not match target shape {target.shape}")

        # F.cross_entropy expects the class dimension at index 1, whereas the
        # shape check above puts it last. For outputs with extra leading
        # dimensions, flatten them so the loss is taken over the class axis.
        if output.dim() > 2:
            output = output.reshape(-1, self._num_classes)
            target = target.reshape(-1)
        return F.cross_entropy(output, target)
class LinearEstimator(TorchEstimator):
    """Affine probe over a single latent: ``W h + b``.

    Args:
        d_latent: size of the input latent.
        bias: whether the linear layer has a bias term.
        **kwargs: forwarded unchanged to the base-class constructor.

    The layer has ``num_classes`` outputs, or a single output when
    ``num_classes`` is unset.
    """

    def __init__(
        self,
        d_latent: int,
        bias: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        out_features = self._num_classes or 1
        self.linear = nn.Linear(d_latent, out_features, bias=bias, device=self._device)

    def forward(self, *Xs: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer to the one and only input tensor.

        Raises:
            ValueError: if the number of input tensors is not exactly one.
        """
        n_inputs = len(Xs)
        if n_inputs == 1:
            (latent,) = Xs
            return self.linear(latent)
        raise ValueError(f"Linear estimator expects 1 input tensor, got {n_inputs}")
class BilinearEstimator(TorchEstimator):
    """Full bilinear probe over two latents: ``h_1^T A h_2 + b``.

    Args:
        d_latent1: size of the first latent.
        d_latent2: size of the second latent.
        bias: whether the bilinear layer has a bias term.
        **kwargs: forwarded unchanged to the base-class constructor.

    The layer has ``num_classes`` outputs, or a single output when
    ``num_classes`` is unset.
    """

    def __init__(
        self,
        d_latent1: int,
        d_latent2: int,
        bias: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        out_features = self._num_classes or 1
        self.bilinear = nn.Bilinear(
            d_latent1,
            d_latent2,
            out_features,
            bias=bias,
            device=self._device,
        )

    def forward(self, h1: torch.Tensor, h2: torch.Tensor) -> torch.Tensor:
        """Return the bilinear form evaluated on the pair ``(h1, h2)``."""
        scores = self.bilinear(h1, h2)
        return scores
class LowRankBilinearEstimator(TorchEstimator):
    """Low-rank bilinear probe: ``(U h_1) * (V h_2) + b``.

    Each latent is projected by its own ``nn.Linear`` layer and the two
    projections are multiplied element-wise. Both projection layers are
    created with ``nn.Linear``'s default bias; the ``bias`` argument only
    controls the additive term ``b`` applied after the product.

    Args:
        d_latent1: size of the first latent.
        d_latent2: size of the second latent.
        bias: whether to add a learnable, zero-initialised output bias.
        **kwargs: forwarded unchanged to the base-class constructor.
    """

    def __init__(
        self,
        d_latent1: int,
        d_latent2: int,
        bias: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        n_out = self._num_classes or 1
        target_device = self._device

        self.linear1 = nn.Linear(d_latent1, n_out, device=target_device)
        self.linear2 = nn.Linear(d_latent2, n_out, device=target_device)

        if not bias:
            # Register the name explicitly so that ``self.bias`` exists as None.
            self.register_parameter("bias", None)
        else:
            self.bias = nn.Parameter(torch.zeros(n_out, device=target_device))

    def forward(self, h1: torch.Tensor, h2: torch.Tensor) -> torch.Tensor:
        """Return the element-wise product of both projections, plus the bias if any."""
        product = self.linear1(h1) * self.linear2(h2)
        if self.bias is None:
            return product
        return product + self.bias