tdhook.latent.representation_similarity.information_imbalance

Classes

InformationImbalanceEstimator

Information Imbalance between two representations.

Functions

_validate_inputs(x, y)

_compute_ranks_from_dist(dist)

Convert pairwise distances (N, N) to rank matrix (N, N).

_as_distance_input_dtype(x, y)

_directional_imbalance(source_ranks, target_ranks)

_information_imbalance(x, y, *, p)

Module Contents

class tdhook.latent.representation_similarity.information_imbalance.InformationImbalanceEstimator(in_key_a='data_a', in_key_b='data_b', out_key_a_to_b='information_imbalance_a_to_b', out_key_b_to_a='information_imbalance_b_to_a', p=2.0)

Bases: tensordict.nn.TensorDictModuleBase

Information Imbalance between two representations.

Reads two data tensors from the input TensorDict. Both tensors must have shape (N, D) or (…, N, D) and share the same batch shape and sample count N. Writes both directional imbalances, A->B and B->A, for each batch item.
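The shape contract above can be illustrated with a small check on plain shape tuples. This is a sketch only: `check_shapes_sketch` is a hypothetical helper, not the library's `_validate_inputs`, whose actual checks and error messages may differ. It also assumes the feature dimension D may differ between the two tensors, since the text only requires the batch shape and N to match.

```python
def check_shapes_sketch(shape_a, shape_b):
    """Raise ValueError unless the two shapes fit the documented contract."""
    for name, shape in (("a", shape_a), ("b", shape_b)):
        if len(shape) < 2:
            raise ValueError(
                f"tensor {name} needs at least 2 dimensions, got {len(shape)}"
            )
    # Everything except the trailing feature dimension D must agree
    # (assumption: D itself is allowed to differ).
    if tuple(shape_a[:-1]) != tuple(shape_b[:-1]):
        raise ValueError(
            "batch shape and sample count differ: "
            f"{tuple(shape_a[:-1])} vs {tuple(shape_b[:-1])}"
        )


check_shapes_sketch((100, 8), (100, 3))        # same N, different D
check_shapes_sketch((4, 100, 8), (4, 100, 3))  # shared batch shape
```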

This implementation uses the nearest-neighbor definition: for each point i, select the point j with r^A_ij = 1 (the nearest neighbor of i in A), average r^B_ij over all i, and multiply by 2 / N. The result is close to 0 when neighborhoods in A predict neighborhoods in B and close to 1 when A is uninformative about B. The B->A direction swaps the roles of the two rank matrices.
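The nearest-neighbor definition can be written out in plain Python for small inputs. The function below is an illustrative sketch, not the library's `_directional_imbalance`: the name `directional_imbalance_sketch` and the list-of-lists input format are assumptions made for this example, and the rank matrices are written by hand with rank 0 on the diagonal.

```python
def directional_imbalance_sketch(source_ranks, target_ranks):
    """Nearest-neighbor information imbalance from source to target.

    Both arguments are N x N rank matrices (lists of lists) in which the
    diagonal holds rank 0 and off-diagonal ranks run from 1 to N - 1.
    """
    n = len(source_ranks)
    total = 0
    for i in range(n):
        # j is the nearest neighbor of i in the source space (rank 1).
        j = source_ranks[i].index(1)
        # Rank of that same point j among the neighbors of i in the target space.
        total += target_ranks[i][j]
    # Mean target rank, scaled by 2 / N.
    return (2.0 / n) * (total / n)


# Hand-written rank matrices for N = 4 points.
ranks_a = [[0, 1, 2, 3], [1, 0, 2, 3], [3, 2, 0, 1], [3, 2, 1, 0]]
ranks_b = [[0, 3, 1, 2], [3, 0, 1, 2], [1, 2, 0, 3], [2, 1, 3, 0]]

a_to_a = directional_imbalance_sketch(ranks_a, ranks_a)  # 0.5
a_to_b = directional_imbalance_sketch(ranks_a, ranks_b)  # 1.5
b_to_a = directional_imbalance_sketch(ranks_b, ranks_a)  # 1.125
```

When the two rank matrices are identical, every selected target rank is 1 and the result is 2 / N, which is why the value only approaches 0 as N grows. With so few points the estimate can also exceed 1, as in the second call.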

Parameters:
  • in_key_a (str)

  • in_key_b (str)

  • out_key_a_to_b (str)

  • out_key_b_to_a (str)

  • p (float)

in_key_a = 'data_a'
in_key_b = 'data_b'
out_key_a_to_b = 'information_imbalance_a_to_b'
out_key_b_to_a = 'information_imbalance_b_to_a'
p = 2.0
in_keys
out_keys
forward(td)
Parameters:

td (tensordict.TensorDict)

Return type:

tensordict.TensorDict

__repr__()
tdhook.latent.representation_similarity.information_imbalance._validate_inputs(x, y)
Parameters:
  • x (torch.Tensor)

  • y (torch.Tensor)

Return type:

None

tdhook.latent.representation_similarity.information_imbalance._compute_ranks_from_dist(dist)

Convert pairwise distances (N, N) to rank matrix (N, N).

Convention: diagonal entries (self) have rank 0 and non-diagonal entries are ranked from 1 to N-1 according to ascending distance for each row.
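The rank convention can be illustrated on a tiny distance matrix. This is a plain-Python sketch under stated assumptions, not the library's tensor implementation: `ranks_from_dist_sketch` is a hypothetical name, and ties are broken by index order here, which may not match how the library handles equal distances.

```python
def ranks_from_dist_sketch(dist):
    """Turn an N x N distance matrix (list of lists) into a rank matrix."""
    n = len(dist)
    ranks = []
    for i in range(n):
        # Sort the other points by ascending distance from point i.
        # sorted() is stable, so equal distances keep index order (assumption).
        others = sorted(
            (j for j in range(n) if j != i), key=lambda j: dist[i][j]
        )
        row = [0] * n  # the self entry keeps rank 0
        for rank, j in enumerate(others, start=1):
            row[j] = rank
        ranks.append(row)
    return ranks


# Three points on a line at positions 0, 1 and 3.
dist = [[0, 1, 3], [1, 0, 2], [3, 2, 0]]
ranks = ranks_from_dist_sketch(dist)
# ranks == [[0, 1, 2], [1, 0, 2], [2, 1, 0]]
```

Ranks are computed per row, so the matrix need not be symmetric even when the distances are: here point 1 is the nearest neighbor of point 2, but point 2 is only the second-nearest neighbor of point 1.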

Parameters:

dist (torch.Tensor)

Return type:

torch.Tensor

tdhook.latent.representation_similarity.information_imbalance._as_distance_input_dtype(x, y)
Parameters:
  • x (torch.Tensor)

  • y (torch.Tensor)

Return type:

torch.dtype

tdhook.latent.representation_similarity.information_imbalance._directional_imbalance(source_ranks, target_ranks)
Parameters:
  • source_ranks (torch.Tensor)

  • target_ranks (torch.Tensor)

Return type:

torch.Tensor

tdhook.latent.representation_similarity.information_imbalance._information_imbalance(x, y, *, p)
Parameters:
  • x (torch.Tensor)

  • y (torch.Tensor)

  • p (float)

Return type:

tuple[torch.Tensor, torch.Tensor]