tdhook.latent.representation_similarity.cka

Classes

CkaEstimator

Centered kernel alignment (CKA) between two representations.

Functions

_validate_inputs(x, y)

_linear_cka(x, y, eps)

Module Contents

class tdhook.latent.representation_similarity.cka.CkaEstimator(in_key_a='data_a', in_key_b='data_b', out_key='cka', kernel='linear', eps=1e-12)

Bases: tensordict.nn.TensorDictModuleBase

Centered kernel alignment (CKA) between two representations.

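Reads two tensors from the input TensorDict, stored under in_key_a and in_key_b. Each must have shape (N, D) or (…, N, D), where N is the number of samples and D the number of features, and the two tensors must share the same batch shape (…) and the same sample count N. Writes one scalar similarity value per batch item under out_key: the result has the batch shape (…), or is a single scalar for unbatched (N, D) inputs.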
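To illustrate the shape contract above, here is a minimal sketch of batched linear CKA in PyTorch. It is not tdhook's `_linear_cka`: the function name `linear_cka` is hypothetical, and adding `eps` to the denominator is an assumption about how such a constant is typically used. It centers each feature over the N samples and reduces over the last two dimensions, so any leading batch dimensions are preserved.

```python
import torch


def linear_cka(x: torch.Tensor, y: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Hypothetical batched linear CKA sketch (not tdhook's _linear_cka).

    x: (..., N, D_a), y: (..., N, D_b) with the same batch shape and N.
    Returns a tensor of shape (...): one similarity per batch item.
    """
    # Center every feature column over the N samples (dim -2).
    x = x - x.mean(dim=-2, keepdim=True)
    y = y - y.mean(dim=-2, keepdim=True)

    # Feature-by-feature cross and self products.
    xy = y.transpose(-2, -1) @ x  # (..., D_b, D_a)
    xx = x.transpose(-2, -1) @ x  # (..., D_a, D_a)
    yy = y.transpose(-2, -1) @ y  # (..., D_b, D_b)

    # Numerator: squared Frobenius norm over the last two dims.
    num = xy.pow(2).sum(dim=(-2, -1))
    # Denominator: product of Frobenius norms; eps (assumed placement) avoids 0/0.
    den = xx.pow(2).sum(dim=(-2, -1)).sqrt() * yy.pow(2).sum(dim=(-2, -1)).sqrt()
    return num / (den + eps)


# Usage: a batch of 4 representation pairs, each with N = 10 samples.
torch.manual_seed(0)
a = torch.randn(4, 10, 5)
b = torch.randn(4, 10, 7)
print(linear_cka(a, b).shape)  # torch.Size([4])
```

Because every reduction runs over dims (-2, -1), an unbatched (N, D) input yields a 0-d tensor and a (B, N, D) input yields shape (B,), matching "one scalar similarity value per batch item".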

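Parameters:
  • in_key_a (str) – Key of the first representation in the input TensorDict. Default: 'data_a'.

  • in_key_b (str) – Key of the second representation in the input TensorDict. Default: 'data_b'.

  • out_key (str) – Key under which the CKA value is written. Default: 'cka'.

  • kernel (str) – Kernel used to compare the representations. Default: 'linear', the kernel implemented by _linear_cka in this module.

  • eps (float) – Small constant for numerical stability. Default: 1e-12.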

in_key_a = 'data_a'
in_key_b = 'data_b'
out_key = 'cka'
kernel = 'linear'
eps = 1e-12
in_keys
out_keys
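forward(td)
Computes CKA between td[in_key_a] and td[in_key_b] and returns a TensorDict holding the result under out_key.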
Parameters:

td (tensordict.TensorDict)

Return type:

tensordict.TensorDict

__repr__()
tdhook.latent.representation_similarity.cka._validate_inputs(x, y)
Parameters:
  • x (torch.Tensor)

  • y (torch.Tensor)

Return type:

None
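A minimal sketch of the shape checks implied by the class description, assuming violations raise `ValueError`. The name `validate_inputs`, the exception type, and the messages are assumptions; the actual `_validate_inputs` may check more or less. Following the description, it checks only the batch shape and the sample count. The feature dimension is left unchecked because linear CKA itself does not require it to match.

```python
import torch


def validate_inputs(x: torch.Tensor, y: torch.Tensor) -> None:
    """Hypothetical shape checks for CKA inputs (not tdhook's _validate_inputs)."""
    # Both tensors need at least a sample axis and a feature axis.
    if x.ndim < 2 or y.ndim < 2:
        raise ValueError(
            f"expected shapes (N, D) or (..., N, D), got {tuple(x.shape)} and {tuple(y.shape)}"
        )
    # Leading batch dimensions must match exactly.
    if x.shape[:-2] != y.shape[:-2]:
        raise ValueError(
            f"batch shapes differ: {tuple(x.shape[:-2])} vs {tuple(y.shape[:-2])}"
        )
    # Both representations must describe the same N samples.
    if x.shape[-2] != y.shape[-2]:
        raise ValueError(f"sample counts differ: {x.shape[-2]} vs {y.shape[-2]}")


# Passes: batch shape (3,) and N = 10 agree; only the feature dimension differs.
validate_inputs(torch.zeros(3, 10, 4), torch.zeros(3, 10, 8))
```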

tdhook.latent.representation_similarity.cka._linear_cka(x, y, eps)
Parameters:
  • x (torch.Tensor)

  • y (torch.Tensor)

  • eps (float)

Return type:

torch.Tensor
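For reference, this is the standard linear-kernel CKA (Kornblith et al., 2019) for representations X of shape N × D_a and Y of shape N × D_b. The source does not say where eps enters the computation. Adding it to the denominator is an assumption, shown here only to indicate its likely role as a guard against division by zero.

```latex
% Column-centering with H = I_N - \tfrac{1}{N}\mathbf{1}\mathbf{1}^\top
\tilde{X} = H X, \qquad \tilde{Y} = H Y

% Linear CKA; \varepsilon placement is an assumption
\mathrm{CKA}_{\text{linear}}(X, Y)
  = \frac{\lVert \tilde{Y}^\top \tilde{X} \rVert_F^{2}}
         {\lVert \tilde{X}^\top \tilde{X} \rVert_F \,\lVert \tilde{Y}^\top \tilde{Y} \rVert_F + \varepsilon}
```

Without eps, the value lies in [0, 1] by the Cauchy-Schwarz inequality. It is invariant to isotropic scaling and orthogonal transformations of either representation.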