Source code for tdhook.metrics

import torch
from tensordict import TensorDict
from typing import List

from tdhook.modules import HookedModule


# TODO: test against captum
# TODO: fix this
class InfidelityMetric:
    """Score how well stored attributions predict the effect of input noise.

    For every input key the input is perturbed ``n_perturb_samples`` times with
    Gaussian noise.  Each draw yields two values per batch element: the
    attribution-weighted perturbation and the change in model output.  The
    infidelity is the mean squared difference between the two over all draws.
    """

    def __init__(self, n_perturb_samples: int = 10, noise_std: float = 0.01):
        """Configure the metric.

        Args:
            n_perturb_samples: Number of random perturbations drawn per input key.
            noise_std: Standard deviation of the Gaussian noise added to an input.

        Raises:
            ValueError: If ``n_perturb_samples`` is smaller than 1 (there would
                be nothing to average over).
        """
        if n_perturb_samples < 1:
            raise ValueError(f"n_perturb_samples must be >= 1, got {n_perturb_samples}")
        self.n_perturb_samples = n_perturb_samples
        self.noise_std = noise_std

    def __call__(
        self,
        module: HookedModule,
        original_data: TensorDict,
    ) -> TensorDict:
        """Compute infidelity as the difference between attribution-weighted perturbations and model output changes.

        Args:
            module: Hooked module. It is called on a tensordict and the entry
                ``("_mod_out", "output")`` is read from what it returns.
            original_data: Holds the inputs under ``module.in_keys`` and their
                attributions under ``("attr", key)``.

        Returns:
            A tensordict with the batch size of ``original_data`` containing one
            infidelity tensor per input key.
        """
        infidelities = TensorDict(batch_size=original_data.batch_size)
        n_batch_dims = len(original_data.batch_size)

        for key in module.in_keys:
            # Read the stored attribution before the module is called again.
            original_attr = original_data.get(("attr", key))

            # The reference output does not depend on the random draw, so it is
            # evaluated once per key instead of once per perturbation.
            # NOTE(review): assumes the module is deterministic (eval mode) - confirm.
            original_output = module(original_data)[("_mod_out", "output")]

            perturbation_scores = []
            output_changes = []
            for _ in range(self.n_perturb_samples):
                perturbed_data = self._perturb_data(original_data, [key])
                # A single forward pass on the perturbed input is enough.
                perturbed_output = module(perturbed_data)[("_mod_out", "output")]

                # Difference between original and perturbed input.
                perturbation = original_data[key] - perturbed_data[key]

                perturbation_scores.append(
                    self._sum_non_batch(original_attr * perturbation, n_batch_dims)
                )
                output_changes.append(
                    self._sum_non_batch(original_output - perturbed_output, n_batch_dims)
                )

            # Samples are stacked on a new trailing axis: [*batch, n_samples].
            stacked_scores = torch.stack(perturbation_scores, dim=-1)
            stacked_changes = torch.stack(output_changes, dim=-1)

            # MSE between attribution-weighted perturbations and output changes.
            infidelities[key] = ((stacked_scores - stacked_changes) ** 2).mean(dim=-1)

        return infidelities

    @staticmethod
    def _sum_non_batch(tensor: torch.Tensor, n_batch_dims: int) -> torch.Tensor:
        """Sum ``tensor`` over every dimension after the first ``n_batch_dims``."""
        non_batch_dims = tuple(range(n_batch_dims, tensor.dim()))
        if not non_batch_dims:
            # An empty ``dim`` tuple is not a no-op for PyTorch reductions (it
            # has historically meant "reduce every dimension"), which would fold
            # the batch axis.  With nothing to reduce, return the tensor as is.
            return tensor
        return tensor.sum(dim=non_batch_dims)

    @torch.no_grad()
    def _perturb_data(self, data: TensorDict, in_keys: List[str]) -> TensorDict:
        """Return a clone of ``data`` with Gaussian noise added to ``in_keys``.

        Entries that are not tensors are left unchanged; ``data`` itself is not
        modified.
        """
        perturbed_data = data.clone()
        for key in in_keys:
            value = perturbed_data[key]
            if isinstance(value, torch.Tensor):
                noise = torch.randn_like(value) * self.noise_std
                perturbed_data[key] = value + noise
        return perturbed_data
class SensitivityMetric:
    """Measure how much attributions move when the inputs are slightly perturbed.

    All inputs are perturbed once with uniform noise in
    ``[-perturb_radius, perturb_radius]``, the module is run on the perturbed
    data, and the mean absolute change of each attribution is reported relative
    to the mean absolute magnitude of the original attribution.
    """

    def __init__(self, perturb_radius: float = 0.02):
        """Configure the metric.

        Args:
            perturb_radius: Half-width of the uniform noise added to each input.
        """
        self.perturb_radius = perturb_radius

    def __call__(
        self,
        module: HookedModule,
        original_data: TensorDict,
    ) -> TensorDict:
        """Compute sensitivity as the relative change in explanation when input is perturbed.

        Args:
            module: Hooked module. After it is called on the perturbed data, the
                attributions are read from that data under ``("attr", key)``.
            original_data: Holds the inputs under ``module.in_keys`` and their
                attributions under ``("attr", key)``.

        Returns:
            A tensordict with the batch size of ``original_data`` containing one
            sensitivity tensor per input key.
        """
        perturbed_data = self._perturb_data(original_data, module.in_keys)
        # The attributions for the perturbed inputs are read back from
        # ``perturbed_data`` below, so the return value is not needed here.
        module(perturbed_data)

        sensitivities = TensorDict(batch_size=original_data.batch_size)
        n_batch_dims = len(original_data.batch_size)

        for key in module.in_keys:
            original_attr = original_data.get(("attr", key))
            perturbed_attr = perturbed_data.get(("attr", key))

            explanation_diff_mean = self._mean_non_batch(
                (original_attr - perturbed_attr).abs(), n_batch_dims
            )
            original_magnitude = self._mean_non_batch(original_attr.abs(), n_batch_dims)

            # Where the original attribution is all zero, report the absolute
            # change instead of dividing by zero.
            sensitivities[key] = torch.where(
                original_magnitude == 0,
                explanation_diff_mean,
                explanation_diff_mean / original_magnitude,
            )

        return sensitivities

    @staticmethod
    def _mean_non_batch(tensor: torch.Tensor, n_batch_dims: int) -> torch.Tensor:
        """Average ``tensor`` over every dimension after the first ``n_batch_dims``."""
        non_batch_dims = tuple(range(n_batch_dims, tensor.dim()))
        if not non_batch_dims:
            # An empty ``dim`` tuple is not a no-op for PyTorch reductions (it
            # has historically meant "reduce every dimension"), which would fold
            # the batch axis.  With nothing to reduce, return the tensor as is.
            return tensor
        return tensor.mean(dim=non_batch_dims)

    def _uniform_noise(self, value: torch.Tensor) -> torch.Tensor:
        """Draw uniform noise in ``[-perturb_radius, perturb_radius]`` shaped like ``value``.

        The noise is allocated directly on ``value``'s device.  For floating
        point inputs it uses ``value``'s dtype, so adding it does not promote
        half-precision inputs to float32; other dtypes get float32 noise.
        """
        noise_dtype = value.dtype if value.is_floating_point() else torch.float32
        noise = torch.empty(value.shape, dtype=noise_dtype, device=value.device)
        # TODO: replace with actual radius dist
        return noise.uniform_(-self.perturb_radius, self.perturb_radius)

    @torch.no_grad()
    def _perturb_data(self, data: TensorDict, in_keys: List[str]) -> TensorDict:
        """Add random noise within the perturbation radius.

        Returns a clone of ``data`` in which every tensor entry listed in
        ``in_keys`` has uniform noise added; ``data`` itself is not modified and
        non-tensor entries are left unchanged.
        """
        perturbed_data = data.clone()
        for key in in_keys:
            value = perturbed_data[key]
            if isinstance(value, torch.Tensor):
                perturbed_data[key] = value + self._uniform_noise(value)
        return perturbed_data