Representation Similarity

This notebook introduces representation similarity methods in tdhook.

It currently covers:

  • centered kernel alignment (CKA) through tdhook.latent.representation_similarity.CkaEstimator

  • Information Imbalance through tdhook.latent.representation_similarity.InformationImbalanceEstimator

Setup

[1]:
import importlib.util

DEV = True  # on Colab, install tdhook from the GitHub main branch instead of the PyPI release

if importlib.util.find_spec("google.colab") is not None:
    MODE = "colab-dev" if DEV else "colab"
else:
    MODE = "local"
[2]:
if MODE == "colab":
    %pip install -q tdhook
elif MODE == "colab-dev":
    !rm -rf tdhook
    !git clone https://github.com/Xmaster6y/tdhook -b main
    %pip install -q ./tdhook

Imports

[3]:
import torch
from tensordict import TensorDict

from tdhook.latent.representation_similarity import CkaEstimator, InformationImbalanceEstimator

Synthetic Example

We build a few pairs of representations with known relationships:

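  • same: identical representations, so CKA should be exactly 1 (up to floating-point error)

  • rotated: the same representation multiplied by a random orthogonal matrix; linear CKA is invariant to orthogonal transformations, so it should also score 1

  • random: an independent Gaussian representation with a different width (24 features instead of 32), which should score much lower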
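These expectations follow from the standard linear CKA formula: with X (N × Dx) and Y (N × Dy) centered column-wise, CKA(X, Y) = ||YᵀX||²_F / (||XᵀX||_F · ||YᵀY||_F). An orthogonal Q preserves Frobenius norms, so substituting Y = XQ leaves every term unchanged, and a scalar multiple of X cancels between numerator and denominator. The following is a minimal from-scratch sketch in plain PyTorch, assuming this formula; it does not call tdhook, and linear_cka is a hypothetical helper name, not the library's implementation.

```python
import torch


def linear_cka(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Linear CKA between x (N, Dx) and y (N, Dy); a sketch of the standard formula."""
    # Center every feature over the N samples.
    x = x - x.mean(dim=0, keepdim=True)
    y = y - y.mean(dim=0, keepdim=True)
    # ||Y^T X||_F^2 / (||X^T X||_F * ||Y^T Y||_F); matrix_norm defaults to Frobenius.
    cross = torch.linalg.matrix_norm(y.T @ x) ** 2
    return cross / (torch.linalg.matrix_norm(x.T @ x) * torch.linalg.matrix_norm(y.T @ y))


torch.manual_seed(0)
x = torch.randn(256, 32)
q, _ = torch.linalg.qr(torch.randn(32, 32))  # random orthogonal matrix

same = linear_cka(x, x.clone()).item()
rotated = linear_cka(x, x @ q).item()
unrelated = linear_cka(x, torch.randn(256, 24)).item()
scaled = linear_cka(x, 3.0 * x).item()  # isotropic scaling is ignored as well
print(same, rotated, unrelated, scaled)
```

The same and rotated pairs land at 1 up to float32 rounding, while the independent pair stays far below it.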

[4]:
torch.manual_seed(0)

x = torch.randn(256, 32)
q, _ = torch.linalg.qr(torch.randn(32, 32))

examples = {
    "same": (x, x.clone()),
    "rotated": (x, x @ q),
    "random": (x, torch.randn(256, 24)),
}

estimator = CkaEstimator(kernel="linear")


def run_cka(x, y):
    # Wrap the pair in an unbatched TensorDict and read back the scalar "cka" entry.
    td = TensorDict({"data_a": x, "data_b": y}, batch_size=[])
    return estimator(td.clone())["cka"].item()
[5]:
scores = {name: run_cka(a, b) for name, (a, b) in examples.items()}
# Display CKA scores for each synthetic pair.
scores
[5]:
{'same': 1.0, 'rotated': 1.0, 'random': 0.10095701366662979}

Batched Inputs

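Like the dimension-estimation modules, CkaEstimator accepts either unbatched (N, D) inputs or batched (..., N, D) inputs and returns one scalar score per batch element. In the cell below, the TensorDict's batch_size is set to the leading batch dimension, [3]. In both cases the two representations must share the number of samples N, but their feature dimensions may differ (as in the random pair above).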
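One way to get one score per batch element is to center over the sample axis (the second-to-last dimension) and let the Frobenius norm reduce only the last two dimensions. The sketch below assumes the standard linear CKA formula and is written in plain PyTorch; batched_linear_cka is a hypothetical helper, not tdhook's code. It also checks that the vectorized result matches a Python loop over the batch.

```python
import torch


def batched_linear_cka(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Linear CKA for x (..., N, Dx) and y (..., N, Dy); returns a tensor of shape (...)."""
    # Center over the sample axis N, which is the second-to-last dimension.
    x = x - x.mean(dim=-2, keepdim=True)
    y = y - y.mean(dim=-2, keepdim=True)
    xt, yt = x.transpose(-2, -1), y.transpose(-2, -1)
    # matrix_norm reduces the last two dims, leaving one value per batch element.
    cross = torch.linalg.matrix_norm(yt @ x) ** 2
    return cross / (torch.linalg.matrix_norm(xt @ x) * torch.linalg.matrix_norm(yt @ y))


torch.manual_seed(0)
batched_x = torch.randn(3, 128, 16)
batched_y = batched_x + 0.05 * torch.randn(3, 128, 16)  # small perturbation

scores = batched_linear_cka(batched_x, batched_y)
# Iterating over dim 0 and scoring each (N, D) pair separately gives the same values.
looped = torch.stack([batched_linear_cka(a, b) for a, b in zip(batched_x, batched_y)])
print(scores.shape, scores)
```

Because the small noise barely changes the representation, all three scores sit just below 1.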

[6]:
batched_x = torch.randn(3, 128, 16)
batched_y = batched_x + 0.05 * torch.randn(3, 128, 16)

td = TensorDict({"data_a": batched_x, "data_b": batched_y}, batch_size=[3])
CkaEstimator()(td.clone())["cka"]
[6]:
tensor([0.9979, 0.9978, 0.9976])

API Notes

  • CkaEstimator already exposes a kernel argument, but at the moment only kernel="linear" is implemented.

  • Degenerate inputs with zero variance (for example, a constant representation) produce a nan score instead of raising an error.
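Under the linear CKA formula, a zero-variance representation centers to an all-zero matrix, so both the numerator and the denominator vanish and floating-point division yields NaN rather than an exception. The snippet below illustrates that arithmetic in plain PyTorch (it does not call tdhook), then shows one possible way to aggregate per-batch scores that may contain NaN, using torch.nanmean; the score values in batch_scores are made up for illustration.

```python
import torch

torch.manual_seed(0)
x = torch.ones(64, 8)   # zero-variance representation: every sample is identical
y = torch.randn(64, 8)

xc = x - x.mean(dim=0, keepdim=True)  # centering leaves an all-zero matrix
yc = y - y.mean(dim=0, keepdim=True)
cross = torch.linalg.matrix_norm(yc.T @ xc) ** 2
denom = torch.linalg.matrix_norm(xc.T @ xc) * torch.linalg.matrix_norm(yc.T @ yc)
score = cross / denom  # 0 / 0 evaluates to nan, no exception is raised
print(score)

# Illustrative batched scores with one degenerate element; nanmean skips it.
batch_scores = torch.tensor([0.98, float("nan"), 0.96])
print(torch.nanmean(batch_scores))
```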

Future methods can extend this notebook with additional sections, comparisons, and visualizations.

Information Imbalance

InformationImbalanceEstimator compares how much neighborhood information is preserved from one representation space to another using distance ranks.

It returns two directional values:

  • information_imbalance_a_to_b

  • information_imbalance_b_to_a

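Lower values indicate that the source representation is more informative about the target space. In the example below, data_b is data_a rounded to multiples of 0.25, a coarse-grained copy in which points that fall into the same bin become indistinguishable.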

[7]:
t = torch.linspace(-1.0, 1.0, 128)
data_a = t.unsqueeze(-1)
data_b = (4 * t).round().unsqueeze(-1) / 4.0

td = TensorDict({"data_a": data_a, "data_b": data_b}, batch_size=[])
InformationImbalanceEstimator()(td.clone())
[7]:
TensorDict(
    fields={
        data_a: Tensor(shape=torch.Size([128, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        data_b: Tensor(shape=torch.Size([128, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        information_imbalance_a_to_b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        information_imbalance_b_to_a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
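The section does not spell out the formula, so here is a sketch of the standard definition of Information Imbalance: Δ(A→B) = (2/N) · mean over i of r_B(i, n_A(i)), where n_A(i) is the nearest neighbour of point i in space A and r_B(i, j) is the rank of j among the neighbours of i in space B (1 = nearest). Values near 0 mean that A's nearest neighbours stay close in B; values near 1 mean the A-neighbours are no better than random picks. This sketch assumes that definition; tdhook's estimator may differ in details such as tie handling, and neighbor_ranks and information_imbalance are hypothetical helper names.

```python
import torch


def neighbor_ranks(x: torch.Tensor) -> torch.Tensor:
    """Rank matrix r[i, j]: rank of j among the neighbours of i (1 = nearest)."""
    # Exact pairwise Euclidean distances (no matrix-multiply shortcut).
    dist = torch.cdist(x, x, compute_mode="donot_use_mm_for_euclid_dist")
    # A point is not its own neighbour: push the diagonal to the end of the ranking.
    dist.fill_diagonal_(float("inf"))
    # Double argsort turns sorted positions into 0-based ranks; shift to 1-based.
    return dist.argsort(dim=1).argsort(dim=1) + 1


def information_imbalance(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Delta(A -> B) = (2 / N) * mean_i r_B[i, nn_A(i)]."""
    n = a.shape[0]
    ranks_a = neighbor_ranks(a)
    ranks_b = neighbor_ranks(b)
    nn_a = ranks_a.argmin(dim=1)  # index of each point's nearest neighbour in A
    # Rank, in space B, of each point's A-space nearest neighbour.
    ranks_in_b = ranks_b.gather(1, nn_a.unsqueeze(1)).squeeze(1)
    return 2.0 * ranks_in_b.float().mean() / n


torch.manual_seed(0)
full = torch.rand(200, 2)   # space A: two coordinates
subset = full[:, :1]        # space B: only the first coordinate of A
noise = torch.rand(200, 2)  # an unrelated space

ii_full_to_subset = information_imbalance(full, subset).item()
ii_subset_to_full = information_imbalance(subset, full).item()
ii_full_to_noise = information_imbalance(full, noise).item()
print(ii_full_to_subset, ii_subset_to_full, ii_full_to_noise)
```

The full space predicts the neighbourhoods of its own first coordinate well (low value), the single coordinate says little about the second one (higher value in the reverse direction), and an unrelated space scores close to 1. The asymmetry between the two directions is what the a_to_b and b_to_a outputs are meant to expose.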