tdhook.latent.probing.estimators

Estimators for probing.

Classes

MeanDifferenceClassifier

TorchEstimator

Base class for torch estimators.

LinearEstimator

Linear estimator: W h + b.

BilinearEstimator

Bilinear estimator: h_1^T A h_2 + b.

LowRankBilinearEstimator

Low-rank bilinear estimator: (U h_1) * (V h_2) + b.

Module Contents

class tdhook.latent.probing.estimators.MeanDifferenceClassifier(normalize=True)[source]#
Parameters:

normalize (bool)

_normalize = True[source]#
_coef = None[source]#
_intercept = None[source]#
property coef_[source]#
property intercept_[source]#
fit(X, y)[source]#
_decision_function(X)[source]#
predict(X)[source]#
predict_proba(X)[source]#
class tdhook.latent.probing.estimators.TorchEstimator(*, num_classes=None, epochs=100, lr=0.001, batch_size=128, device=None, verbose=False)[source]#

Bases: torch.nn.Module

Base class for torch estimators.

Parameters:
  • num_classes (Optional[int])

  • epochs (int)

  • lr (float)

  • batch_size (int)

  • device (Optional[torch.device])

  • verbose (bool)

_num_classes = None[source]#
_epochs = 100[source]#
_lr = 0.001[source]#
_batch_size = 128[source]#
_device = None[source]#
_verbose = False[source]#
fit(*Xs, y)[source]#
Parameters:
  • Xs (torch.Tensor)

  • y (torch.Tensor)

predict(*Xs)[source]#
Parameters:

Xs (torch.Tensor)

Return type:

torch.Tensor

_loss_fn(output, target)[source]#
Parameters:
  • output (torch.Tensor)

  • target (torch.Tensor)

Return type:

torch.Tensor

class tdhook.latent.probing.estimators.LinearEstimator(d_latent, bias=True, **kwargs)[source]#

Bases: TorchEstimator

Linear estimator: W h + b.

Parameters:
  • d_latent (int)

  • bias (bool)

linear[source]#
forward(*Xs)[source]#
Parameters:

Xs (torch.Tensor)

Return type:

torch.Tensor

class tdhook.latent.probing.estimators.BilinearEstimator(d_latent1, d_latent2, bias=True, **kwargs)[source]#

Bases: TorchEstimator

Bilinear estimator: h_1^T A h_2 + b.

Parameters:
  • d_latent1 (int)

  • d_latent2 (int)

  • bias (bool)

bilinear[source]#
forward(h1, h2)[source]#
Parameters:
  • h1 (torch.Tensor)

  • h2 (torch.Tensor)

Return type:

torch.Tensor

class tdhook.latent.probing.estimators.LowRankBilinearEstimator(d_latent1, d_latent2, bias=True, **kwargs)[source]#

Bases: TorchEstimator

Low-rank bilinear estimator: (U h_1) * (V h_2) + b.

Parameters:
  • d_latent1 (int)

  • d_latent2 (int)

  • bias (bool)

linear1[source]#
linear2[source]#
forward(h1, h2)[source]#
Parameters:
  • h1 (torch.Tensor)

  • h2 (torch.Tensor)

Return type:

torch.Tensor