tdhook.latent.probing.estimators
Estimators for probing.
Classes
MeanDifferenceClassifier |
TorchEstimator | Base class for torch estimators.
LinearEstimator | Linear estimator: W h + b.
BilinearEstimator | Bilinear estimator: h_1^T A h_2 + b.
LowRankBilinearEstimator | Low-rank bilinear: (U h_1) * (V h_2) + b.
Module Contents
- class tdhook.latent.probing.estimators.MeanDifferenceClassifier(normalize=True)
- Parameters:
normalize (bool)
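The class carries no description here, so the following is only a minimal sketch of the general difference-of-means technique its name suggests, written in plain PyTorch. The helper names `fit_mean_difference` and `predict_mean_difference`, the binary 0/1 labels, and the midpoint threshold are all assumptions made for illustration; they are not taken from tdhook, and the library's actual behaviour may differ.

```python
import torch


def fit_mean_difference(X, y, normalize=True):
    """Hypothetical helper: fit a binary difference-of-means probe.

    X has shape (n_samples, d_latent); y holds 0/1 labels.
    """
    mu_pos = X[y == 1].mean(dim=0)
    mu_neg = X[y == 0].mean(dim=0)
    direction = mu_pos - mu_neg
    if normalize:
        # Assumed meaning of `normalize`: rescale the direction to unit norm.
        direction = direction / direction.norm()
    # Assumed decision rule: threshold at the projected midpoint of the class means.
    threshold = (0.5 * (mu_pos + mu_neg)) @ direction
    return direction, threshold


def predict_mean_difference(X, direction, threshold):
    """Hypothetical helper: label 1 when the projection exceeds the threshold."""
    return (X @ direction > threshold).long()


torch.manual_seed(0)
# Two well-separated synthetic clusters standing in for latent activations.
X = torch.cat([torch.randn(50, 8) + 2.0, torch.randn(50, 8) - 2.0])
y = torch.cat([torch.ones(50), torch.zeros(50)]).long()

direction, threshold = fit_mean_difference(X, y)
predictions = predict_mean_difference(X, direction, threshold)
accuracy = (predictions == y).float().mean().item()
```

The appeal of this kind of probe is that it needs no gradient-based training: the direction is computed in closed form from the two class means.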
- class tdhook.latent.probing.estimators.TorchEstimator(*, num_classes=None, epochs=100, lr=0.001, batch_size=128, device=None, verbose=False)
Bases: torch.nn.Module
Base class for torch estimators.
- Parameters:
num_classes (Optional[int])
epochs (int)
lr (float)
batch_size (int)
device (Optional[torch.device])
verbose (bool)
- class tdhook.latent.probing.estimators.LinearEstimator(d_latent, bias=True, **kwargs)
Bases: TorchEstimator
Linear estimator: W h + b.
- Parameters:
d_latent (int)
bias (bool)
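The formula W h + b can be illustrated with plain PyTorch. This sketch uses `torch.nn.Linear` directly rather than `LinearEstimator`, and the output dimension of 1 is an assumption, since the page does not say how the output size is chosen.

```python
import torch

torch.manual_seed(0)
d_latent = 16

# Assumed output size of 1 (a single score per latent vector).
linear = torch.nn.Linear(d_latent, 1, bias=True)

h = torch.randn(4, d_latent)  # batch of 4 latent vectors
out = linear(h)               # shape (4, 1)

# The same computation written out explicitly as W h + b.
manual = h @ linear.weight.T + linear.bias
```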
- class tdhook.latent.probing.estimators.BilinearEstimator(d_latent1, d_latent2, bias=True, **kwargs)
Bases: TorchEstimator
Bilinear estimator: h_1^T A h_2 + b.
- Parameters:
d_latent1 (int)
d_latent2 (int)
bias (bool)
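The bilinear form h_1^T A h_2 + b can be sketched with PyTorch's built-in `torch.nn.Bilinear`. This is an illustration of the formula only: whether `BilinearEstimator` wraps that layer is not stated here, and the single output feature is an assumption.

```python
import torch

torch.manual_seed(0)
d_latent1, d_latent2 = 8, 12

# Assumed single output feature; weight A has shape (1, d_latent1, d_latent2).
bilinear = torch.nn.Bilinear(d_latent1, d_latent2, 1, bias=True)

h1 = torch.randn(4, d_latent1)
h2 = torch.randn(4, d_latent2)
out = bilinear(h1, h2)  # shape (4, 1)

# The same computation written out explicitly as h_1^T A h_2 + b.
manual = torch.einsum("bi,oij,bj->bo", h1, bilinear.weight, h2) + bilinear.bias
```

Note that the full matrix A has d_latent1 * d_latent2 parameters per output, which grows quickly with the latent dimensions.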
- class tdhook.latent.probing.estimators.LowRankBilinearEstimator(d_latent1, d_latent2, bias=True, **kwargs)
Bases: TorchEstimator
Low-rank bilinear: (U h_1) * (V h_2) + b.
- Parameters:
d_latent1 (int)
d_latent2 (int)
bias (bool)
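One common reading of (U h_1) * (V h_2) + b is an elementwise product of two low-dimensional projections, summed over the rank dimension. The sketch below follows that reading in plain PyTorch. The class name `LowRankBilinearSketch`, the `rank` argument, and the sum over the rank dimension are assumptions for illustration; the signature above does not show how tdhook sets the rank or reduces the product.

```python
import torch


class LowRankBilinearSketch(torch.nn.Module):
    """Hypothetical illustration, not the tdhook class."""

    def __init__(self, d_latent1, d_latent2, rank=4, bias=True):
        super().__init__()
        # U and V project each latent into a shared `rank`-dimensional space.
        self.U = torch.nn.Linear(d_latent1, rank, bias=False)
        self.V = torch.nn.Linear(d_latent2, rank, bias=False)
        if bias:
            self.b = torch.nn.Parameter(torch.zeros(1))
        else:
            self.register_parameter("b", None)

    def forward(self, h1, h2):
        # Elementwise product of the projections, summed over the rank dimension.
        score = (self.U(h1) * self.V(h2)).sum(dim=-1, keepdim=True)
        if self.b is not None:
            score = score + self.b
        return score


torch.manual_seed(0)
model = LowRankBilinearSketch(8, 12, rank=4)
h1, h2 = torch.randn(5, 8), torch.randn(5, 12)
out = model(h1, h2)  # shape (5, 1)

# Under this reading, the model equals a full bilinear form with A = U^T V,
# a matrix of rank at most `rank`.
A = model.U.weight.T @ model.V.weight  # shape (8, 12)
full = torch.einsum("bi,ij,bj->b", h1, A, h2).unsqueeze(-1) + model.b
```

With this factorisation the parameter count is rank * (d_latent1 + d_latent2) instead of d_latent1 * d_latent2, which is the usual motivation for a low-rank variant.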