tdhook.latent.representation_similarity.information_imbalance
Classes
InformationImbalanceEstimator | Information Imbalance between two representations. |
Functions
|
|
|
_compute_ranks_from_dist(dist) | Convert pairwise distances (N, N) to rank matrix (N, N). |
|
|
|
|
|
Module Contents
- class tdhook.latent.representation_similarity.information_imbalance.InformationImbalanceEstimator(in_key_a='data_a', in_key_b='data_b', out_key_a_to_b='information_imbalance_a_to_b', out_key_b_to_a='information_imbalance_b_to_a', p=2.0)
Bases: tensordict.nn.TensorDictModuleBase
Information Imbalance between two representations.
Reads two data tensors from the input TensorDict under in_key_a and in_key_b. Both tensors must have shape (N, D) or (…, N, D), with the same batch shape and the same number of samples N. Writes both directional imbalances, one value per batch item: A→B under out_key_a_to_b and B→A under out_key_b_to_a.
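This implementation uses the nearest-neighbor definition. Let r^A_ij be the rank of point j among the neighbors of point i in representation A, and r^B_ij the same rank in B. For each point i, select the j with r^A_ij = 1 (its nearest neighbor in A), average the corresponding ranks r^B_ij over all i, and multiply by 2 / N:
Δ(A→B) = (2 / N) · (1 / N) · Σ_i r^B_ij, where for each i the index j satisfies r^A_ij = 1.
Values close to 0 mean that neighborhoods in A predict neighborhoods in B well; values close to 1 mean that A is uninformative about B, since a randomly chosen neighbor has an expected rank of about N / 2.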
- Parameters:
in_key_a (str)
in_key_b (str)
out_key_a_to_b (str)
out_key_b_to_a (str)
p (float)
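To make the nearest-neighbor definition concrete, here is a minimal plain-PyTorch sketch of both directional imbalances. It is not the tdhook implementation: the function names rank_matrix and information_imbalance are hypothetical, and it assumes that p is the Minkowski exponent passed to a pairwise-distance routine such as torch.cdist, which the documentation above does not state. Ties are broken arbitrarily by argsort.

```python
import torch


def rank_matrix(x: torch.Tensor, p: float = 2.0) -> torch.Tensor:
    """Ranks of pairwise distances for x of shape (..., N, D); diagonal gets rank 0."""
    dist = torch.cdist(x, x, p=p)  # (..., N, N); p assumed to be the Minkowski exponent
    eye = torch.eye(dist.shape[-1], dtype=torch.bool, device=dist.device)
    dist = dist.masked_fill(eye, float("-inf"))  # each point sorts before all others
    order = torch.argsort(dist, dim=-1)  # neighbor indices by ascending distance
    return torch.argsort(order, dim=-1)  # inverse permutation = rank of each column


def information_imbalance(x_a: torch.Tensor, x_b: torch.Tensor, p: float = 2.0):
    """Hypothetical sketch: returns (delta_a_to_b, delta_b_to_a), each of shape (...)."""
    r_a, r_b = rank_matrix(x_a, p), rank_matrix(x_b, p)
    n = r_a.shape[-1]

    def delta(r_src: torch.Tensor, r_tgt: torch.Tensor) -> torch.Tensor:
        # Index of each point's nearest neighbor in the source space (rank 1).
        nn_idx = torch.argsort(r_src, dim=-1)[..., 1:2]  # (..., N, 1)
        # Rank of that same neighbor in the target space.
        r = r_tgt.gather(-1, nn_idx).squeeze(-1).to(torch.float32)  # (..., N)
        # Average over points, then normalize by 2 / N.
        return (2.0 / n) * r.mean(dim=-1)

    return delta(r_a, r_b), delta(r_b, r_a)
```

Passing the same tensor as both arguments gives exactly 2 / N in both directions, because every nearest neighbor in A also has rank 1 in B. Two independent random representations give values near 1.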
- tdhook.latent.representation_similarity.information_imbalance._validate_inputs(x, y)
- Parameters:
x (torch.Tensor)
y (torch.Tensor)
- Return type:
None
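The class docstring states the input contract: (N, D) or (…, N, D) tensors with a shared batch shape and sample count. _validate_inputs is documented only by its signature, so the following is a hypothetical check that enforces that stated contract. The name validate_pair and the choice of ValueError are assumptions, not the module's behavior.

```python
import torch


def validate_pair(x: torch.Tensor, y: torch.Tensor) -> None:
    """Hypothetical check mirroring the documented input contract."""
    # Both tensors need at least a sample axis and a feature axis.
    if x.dim() < 2 or y.dim() < 2:
        raise ValueError(
            f"expected (N, D) or (..., N, D), got {tuple(x.shape)} and {tuple(y.shape)}"
        )
    # Leading batch dimensions must match exactly.
    if x.shape[:-2] != y.shape[:-2]:
        raise ValueError(f"batch shapes differ: {tuple(x.shape[:-2])} vs {tuple(y.shape[:-2])}")
    # Both representations must describe the same N samples.
    if x.shape[-2] != y.shape[-2]:
        raise ValueError(f"sample counts differ: {x.shape[-2]} vs {y.shape[-2]}")
```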
- tdhook.latent.representation_similarity.information_imbalance._compute_ranks_from_dist(dist)
Convert pairwise distances (N, N) to rank matrix (N, N).
Convention: in each row, the diagonal entry (the point itself) has rank 0, and the remaining N-1 entries are ranked from 1 to N-1 by ascending distance.
- Parameters:
dist (torch.Tensor)
- Return type:
torch.Tensor
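One way to reproduce this rank convention is a double argsort. This is a minimal sketch assuming a floating-point distance tensor; ranks_from_dist is a hypothetical name, not the module's _compute_ranks_from_dist, whose tie handling is not documented here.

```python
import torch


def ranks_from_dist(dist: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: (..., N, N) distances -> (..., N, N) integer ranks.

    Diagonal (self) gets rank 0; the other N-1 entries of each row get ranks
    1..N-1 by ascending distance. Ties are broken arbitrarily by argsort.
    """
    n = dist.shape[-1]
    eye = torch.eye(n, dtype=torch.bool, device=dist.device)
    # Push the diagonal to -inf so self always sorts first, even if a
    # duplicate point also sits at distance 0.
    masked = dist.masked_fill(eye, float("-inf"))
    order = torch.argsort(masked, dim=-1)  # column indices, nearest first
    return torch.argsort(order, dim=-1)  # position of each column in `order`


# Three points on a line at 0, 1 and 3.
x = torch.tensor([[0.0], [1.0], [3.0]])
print(ranks_from_dist(torch.cdist(x, x)))
# tensor([[0, 1, 2],
#         [1, 0, 2],
#         [2, 1, 0]])
```

The second argsort inverts the permutation returned by the first, so each entry holds the position of that column in its row's nearest-first ordering.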
- tdhook.latent.representation_similarity.information_imbalance._as_distance_input_dtype(x, y)
- Parameters:
x (torch.Tensor)
y (torch.Tensor)
- Return type:
torch.dtype