Representation Similarity#
This notebook introduces representation similarity methods in tdhook.
It currently covers:
Centered kernel alignment (CKA) through tdhook.latent.representation_similarity.CkaEstimator
Information Imbalance through tdhook.latent.representation_similarity.InformationImbalanceEstimator
Setup#
[1]:
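import importlib.util

DEV = True
try:
    # find_spec imports the parent package first, so it raises when "google" is absent.
    IN_COLAB = importlib.util.find_spec("google.colab") is not None
except ModuleNotFoundError:
    IN_COLAB = False
if IN_COLAB:
    MODE = "colab-dev" if DEV else "colab"
else:
    MODE = "local"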
[2]:
if MODE == "colab":
    %pip install -q tdhook
elif MODE == "colab-dev":
    !rm -rf tdhook
    !git clone https://github.com/Xmaster6y/tdhook -b main
    %pip install -q ./tdhook
Imports#
[3]:
import torch
from tensordict import TensorDict
from tdhook.latent.representation_similarity import CkaEstimator, InformationImbalanceEstimator
Synthetic Example#
We build a few pairs of representations with known relationships:
same: identical representations, so CKA should be close to 1
rotated: an orthogonal transform of the same representation, which linear CKA should also score near 1
random: an unrelated representation, which should typically score much lower
[4]:
torch.manual_seed(0)
x = torch.randn(256, 32)
q, _ = torch.linalg.qr(torch.randn(32, 32))
examples = {
    "same": (x, x.clone()),
    "rotated": (x, x @ q),
    "random": (x, torch.randn(256, 24)),
}
estimator = CkaEstimator(kernel="linear")


def run_cka(x, y):
    td = TensorDict({"data_a": x, "data_b": y}, batch_size=[])
    return estimator(td.clone())["cka"].item()
[5]:
scores = {name: run_cka(a, b) for name, (a, b) in examples.items()}
# Display CKA scores for each synthetic pair.
scores
[5]:
{'same': 1.0, 'rotated': 1.0, 'random': 0.10095701366662979}
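To see why the rotated pair scores 1, it can help to write linear CKA out by hand. The sketch below uses the standard formulation on column-centered features, the squared Frobenius norm of the cross-product divided by the product of the self-product norms. The function name `linear_cka` is hypothetical, and this is an assumption about what the score measures, not a copy of the tdhook implementation, which may differ in details such as debiasing.

```python
import torch


def linear_cka(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Linear CKA between (N, D1) and (N, D2) feature matrices (illustrative sketch)."""
    # Center every feature column over the N samples.
    x = x - x.mean(dim=0, keepdim=True)
    y = y - y.mean(dim=0, keepdim=True)
    # ||Y^T X||_F^2 / (||X^T X||_F * ||Y^T Y||_F)
    cross = torch.linalg.matrix_norm(y.T @ x) ** 2
    norm_x = torch.linalg.matrix_norm(x.T @ x)
    norm_y = torch.linalg.matrix_norm(y.T @ y)
    return cross / (norm_x * norm_y)


torch.manual_seed(0)
x = torch.randn(256, 32)
q, _ = torch.linalg.qr(torch.randn(32, 32))  # random orthogonal matrix

print("same:   ", linear_cka(x, x).item())
print("rotated:", linear_cka(x, x @ q).item())
print("random: ", linear_cka(x, torch.randn(256, 24)).item())
```

The Frobenius norm is unchanged by orthogonal transforms, so multiplying by `q` leaves both the numerator and the denominator untouched. The unrelated pair does not score exactly 0 because the sample size is finite.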
Batched Inputs#
Like the dimension-estimation modules, CkaEstimator accepts either (N, D) or batched (..., N, D) inputs and returns one scalar score per batch item.
[6]:
batched_x = torch.randn(3, 128, 16)
batched_y = batched_x + 0.05 * torch.randn(3, 128, 16)
td = TensorDict({"data_a": batched_x, "data_b": batched_y}, batch_size=[3])
CkaEstimator()(td.clone())["cka"]
[6]:
tensor([0.9979, 0.9978, 0.9976])
API Notes#
The estimator is named CkaEstimator and already exposes a kernel argument.
At the moment only kernel="linear" is implemented.
Degenerate inputs with zero variance return nan instead of raising.
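The nan behaviour for zero-variance inputs follows naturally from the linear CKA ratio: a constant representation is all zeros after centering, so the score becomes 0/0. Below is a minimal batched sketch of that ratio in plain PyTorch. The helper `batched_linear_cka` is hypothetical and only assumed to mirror the formula, not the library's internals.

```python
import torch


def batched_linear_cka(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Linear CKA over the last two dims of (..., N, D) inputs (illustrative sketch)."""
    # Center features over the sample dimension N.
    x = x - x.mean(dim=-2, keepdim=True)
    y = y - y.mean(dim=-2, keepdim=True)
    # matrix_norm takes the Frobenius norm over the last two dims, one value per batch item.
    cross = torch.linalg.matrix_norm(y.transpose(-1, -2) @ x) ** 2
    norm_x = torch.linalg.matrix_norm(x.transpose(-1, -2) @ x)
    norm_y = torch.linalg.matrix_norm(y.transpose(-1, -2) @ y)
    return cross / (norm_x * norm_y)


torch.manual_seed(0)
x = torch.randn(3, 128, 16)
y = x + 0.05 * torch.randn(3, 128, 16)
x[2] = 1.0  # make the last batch item constant, i.e. zero variance

scores = batched_linear_cka(x, y)
print(scores)  # the last entry is nan because both numerator and denominator are 0
```

Returning nan keeps batched calls usable when a single item is degenerate; downstream code can mask those entries with `torch.isnan`.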
Future methods can extend this notebook with additional sections, comparisons, and visualizations.
Information Imbalance#
InformationImbalanceEstimator compares how much neighborhood information is preserved from one representation space to another using distance ranks.
It returns two directional values:
information_imbalance_a_to_b
information_imbalance_b_to_a
Lower values indicate that the source representation is more informative about the target space.
[7]:
t = torch.linspace(-1.0, 1.0, 128)
data_a = t.unsqueeze(-1)
data_b = (4 * t).round().unsqueeze(-1) / 4.0
td = TensorDict({"data_a": data_a, "data_b": data_b}, batch_size=[])
InformationImbalanceEstimator()(td.clone())
[7]:
TensorDict(
    fields={
        data_a: Tensor(shape=torch.Size([128, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        data_b: Tensor(shape=torch.Size([128, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        information_imbalance_a_to_b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        information_imbalance_b_to_a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
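To make the two directional values concrete, here is a minimal sketch of the usual rank-based definition: for each point, find its nearest neighbor in the source space, look up that neighbor's distance rank in the target space, average the ranks, and scale by 2/N. The helper names are hypothetical, and the normalization and tie handling are assumptions that may not match InformationImbalanceEstimator exactly.

```python
import torch


def pairwise_distances(z: torch.Tensor) -> torch.Tensor:
    """Euclidean distance matrix (N, N) with self-distances set to inf."""
    d = torch.linalg.vector_norm(z[:, None, :] - z[None, :, :], dim=-1)
    d.fill_diagonal_(float("inf"))
    return d


def information_imbalance(source: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Mean target-space rank of each point's source-space nearest neighbor, times 2/N."""
    n = source.shape[0]
    d_source = pairwise_distances(source)
    d_target = pairwise_distances(target)
    nearest = d_source.argmin(dim=1)  # nearest neighbor of each point in the source space
    # Double argsort turns distances into 1-based ranks (1 = closest neighbor).
    ranks = d_target.argsort(dim=1).argsort(dim=1) + 1
    neighbor_ranks = ranks[torch.arange(n), nearest]
    return 2.0 / n * neighbor_ranks.double().mean()


torch.manual_seed(0)
full = torch.randn(200, 2, dtype=torch.float64)  # both coordinates
partial = full[:, :1]                            # first coordinate only

full_to_partial = information_imbalance(full, partial)
partial_to_full = information_imbalance(partial, full)
print(full_to_partial.item(), partial_to_full.item())
```

Here `full` contains everything `partial` does, so the full-to-partial value is expected to come out lower than the reverse direction, in line with the reading that lower values mean a more informative source.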