tdhook.latent.representation_similarity

Representation similarity methods.

Submodules

Classes

  • CkaEstimator: Centered kernel alignment (CKA) between two representations.

  • InformationImbalanceEstimator: Information Imbalance between two representations.

Package Contents

class tdhook.latent.representation_similarity.CkaEstimator(in_key_a='data_a', in_key_b='data_b', out_key='cka', kernel='linear', eps=1e-12)

Bases: tensordict.nn.TensorDictModuleBase

Centered kernel alignment (CKA) between two representations.

Reads two data tensors from the input TensorDict, stored under in_key_a and in_key_b. Both tensors must have shape (N, D) or (…, N, D) and share the same batch shape and sample count N. Writes one scalar similarity value per batch item under out_key.
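For reference, the standard HSIC-based definition of CKA is given below. Whether tdhook computes it in exactly this form (for example, where eps enters) is an assumption; the definition itself is the usual one, with K and L the N × N Gram matrices of the two representations X and Y, computed independently for each batch item.

$$H = I_N - \tfrac{1}{N}\mathbf{1}\mathbf{1}^\top, \qquad \mathrm{HSIC}(K, L) = \frac{\mathrm{tr}(K H L H)}{(N-1)^2}$$

$$\mathrm{CKA}(K, L) = \frac{\mathrm{HSIC}(K, L)}{\sqrt{\mathrm{HSIC}(K, K)\,\mathrm{HSIC}(L, L)}}$$

With the linear kernel (K = X Xᵀ, L = Y Yᵀ) and column-centered X and Y, the normalization constant cancels and this reduces to

$$\mathrm{CKA}_{\mathrm{lin}}(X, Y) = \frac{\lVert Y^\top X \rVert_F^2}{\lVert X^\top X \rVert_F \, \lVert Y^\top Y \rVert_F}$$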

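Parameters:
  • in_key_a (str): key of the first representation in the input TensorDict (default 'data_a').

  • in_key_b (str): key of the second representation (default 'data_b').

  • out_key (str): key under which the CKA value is written (default 'cka').

  • kernel (str): kernel used to build the Gram matrices (default 'linear').

  • eps (float): small positive constant (default 1e-12), presumably a numerical-stability guard for the normalization.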
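A minimal NumPy sketch of the linear-kernel case, for illustration only. It computes ‖Y_cᵀX_c‖_F² / (‖X_cᵀX_c‖_F ‖Y_cᵀY_c‖_F) on inputs centered over the sample axis, and mirrors the documented shape contract: (N, D) or (…, N, D) inputs with shared batch shape and N, one value per batch item. The function name linear_cka is hypothetical (not tdhook API), and adding eps to the denominator is an assumption, since the source does not say where eps is applied.

```python
import numpy as np


def linear_cka(x: np.ndarray, y: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Hypothetical helper: linear CKA of x (..., N, D_x) and y (..., N, D_y).

    Returns one value per batch item, i.e. an array of shape (...).
    """
    if x.shape[:-1] != y.shape[:-1]:
        raise ValueError("x and y must share batch shape and sample count N")
    # Center every feature over the sample axis N (same effect as H = I - 11^T / N).
    xc = x - x.mean(axis=-2, keepdims=True)
    yc = y - y.mean(axis=-2, keepdims=True)
    # Feature-space products; swapaxes transposes only the last two axes.
    xt = np.swapaxes(xc, -1, -2)
    yt = np.swapaxes(yc, -1, -2)
    cross = xt @ yc   # (..., D_x, D_y)
    self_x = xt @ xc  # (..., D_x, D_x)
    self_y = yt @ yc  # (..., D_y, D_y)
    numerator = np.sum(cross**2, axis=(-2, -1))
    # np.linalg.norm over the last two axes defaults to the Frobenius norm.
    denominator = np.linalg.norm(self_x, axis=(-2, -1)) * np.linalg.norm(self_y, axis=(-2, -1))
    # Assumption: eps only guards the division (stand-in for tdhook's eps).
    return numerator / (denominator + eps)


rng = np.random.default_rng(0)
x = rng.normal(size=(100, 8))
q, _ = np.linalg.qr(rng.normal(size=(8, 8)))  # random orthogonal matrix
# Close to 1.0: linear CKA is invariant to rotations and isotropic scaling.
print(linear_cka(x, 3.0 * x @ q))
# Batched input: one value per batch item.
print(linear_cka(rng.normal(size=(4, 50, 6)), rng.normal(size=(4, 50, 3))).shape)  # (4,)
```

Centering over the sample axis plays the role of the centering matrix H, and working with D × D feature-space products avoids forming N × N Gram matrices, which is cheaper when N exceeds the feature dimensions.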

in_key_a = 'data_a'
in_key_b = 'data_b'
out_key = 'cka'
kernel = 'linear'
eps = 1e-12
in_keys
out_keys
forward(td)
Parameters:

td (tensordict.TensorDict)

Return type:

tensordict.TensorDict

__repr__()
class tdhook.latent.representation_similarity.InformationImbalanceEstimator(in_key_a='data_a', in_key_b='data_b', out_key_a_to_b='information_imbalance_a_to_b', out_key_b_to_a='information_imbalance_b_to_a', p=2.0)

Bases: tensordict.nn.TensorDictModuleBase

Information Imbalance between two representations.

Reads two data tensors from the input TensorDict, stored under in_key_a and in_key_b. Both tensors must have shape (N, D) or (…, N, D) and share the same batch shape and sample count N. Writes both directional imbalances, one value per batch item: A→B under out_key_a_to_b and B→A under out_key_b_to_a.

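This implementation uses the nearest-neighbor definition. Let r^A_ij be the rank of point j among the neighbors of point i when distances are measured in representation A (r^A_ij = 1 for the nearest neighbor), and r^B_ij the corresponding rank in B. For each point i, take the j with r^A_ij = 1, look up its rank r^B_ij, average over i, and scale by 2 / N:

$$\Delta(A \to B) = \frac{2}{N}\,\langle r^B \mid r^A = 1 \rangle = \frac{2}{N^2} \sum_{i,j \,:\, r^A_{ij} = 1} r^B_{ij}$$

Values close to 0 mean that neighborhoods in A strongly predict neighborhoods in B; values close to 1 mean that A carries essentially no information about B's neighborhoods. Δ(B → A) is defined symmetrically by swapping A and B.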
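A minimal NumPy sketch of the nearest-neighbor definition above, for illustration only. Assumptions not stated in the source: p is treated as the exponent of a Minkowski distance, each point is excluded from its own neighbor list, and ties are broken arbitrarily by argsort. The helper names neighbor_ranks and information_imbalance are hypothetical, not tdhook API.

```python
import numpy as np


def neighbor_ranks(x: np.ndarray, p: float) -> tuple:
    """Hypothetical helper: for x of shape (..., N, D), return (ranks, nearest).

    ranks[..., i, j] is the 1-based rank of j among i's neighbors; i itself is
    pushed to the last rank N and never used. nearest[..., i] is the j of rank 1.
    """
    n = x.shape[-2]
    diff = x[..., :, None, :] - x[..., None, :, :]          # (..., N, N, D)
    dist = np.sum(np.abs(diff) ** p, axis=-1) ** (1.0 / p)  # Minkowski distance (assumed meaning of p)
    idx = np.arange(n)
    dist[..., idx, idx] = np.inf                            # a point is not its own neighbor
    order = np.argsort(dist, axis=-1)                       # neighbors of i, nearest first
    ranks = np.argsort(order, axis=-1) + 1                  # inverse permutation gives 1-based ranks
    return ranks, order[..., 0]


def information_imbalance(a: np.ndarray, b: np.ndarray, p: float = 2.0) -> tuple:
    """Hypothetical helper: return (delta_a_to_b, delta_b_to_a), each of shape (...)."""
    if a.shape[:-1] != b.shape[:-1]:
        raise ValueError("a and b must share batch shape and sample count N")
    n = a.shape[-2]
    ranks_a, nearest_a = neighbor_ranks(a, p)
    ranks_b, nearest_b = neighbor_ranks(b, p)
    # r^B_ij for the j with r^A_ij = 1, and r^A_ij for the j with r^B_ij = 1.
    rb_at_nn_a = np.take_along_axis(ranks_b, nearest_a[..., None], axis=-1)[..., 0]
    ra_at_nn_b = np.take_along_axis(ranks_a, nearest_b[..., None], axis=-1)[..., 0]
    # Average over i and scale by 2 / N.
    delta_a_to_b = 2.0 / n * rb_at_nn_a.mean(axis=-1)
    delta_b_to_a = 2.0 / n * ra_at_nn_b.mean(axis=-1)
    return delta_a_to_b, delta_b_to_a


rng = np.random.default_rng(0)
a = rng.normal(size=(200, 3))
print(information_imbalance(a, a))  # both directions equal 2 / 200
# b is a function of a (its first coordinate), but not the reverse.
d_ab, d_ba = information_imbalance(a, a[:, :1])
print(d_ab < d_ba)  # expected to be True
```

In this sketch, identical representations give exactly 2 / N in both directions (every nearest neighbor keeps rank 1), while independent random representations give values near 1. The pairwise-difference array has shape (…, N, N, D), so the sketch is only practical for small N.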

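Parameters:
  • in_key_a (str): key of representation A in the input TensorDict (default 'data_a').

  • in_key_b (str): key of representation B (default 'data_b').

  • out_key_a_to_b (str): key under which Δ(A → B) is written (default 'information_imbalance_a_to_b').

  • out_key_b_to_a (str): key under which Δ(B → A) is written (default 'information_imbalance_b_to_a').

  • p (float): default 2.0; most likely the order of the p-norm (Minkowski) distance used to rank neighbors, with 2.0 giving Euclidean distance.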

in_key_a = 'data_a'
in_key_b = 'data_b'
out_key_a_to_b = 'information_imbalance_a_to_b'
out_key_b_to_a = 'information_imbalance_b_to_a'
p = 2.0
in_keys
out_keys
forward(td)
Parameters:

td (tensordict.TensorDict)

Return type:

tensordict.TensorDict

__repr__()