"""
LRP rules.
This code is adapted from the Zennit library (LGPL-3.0) and the LXT library (Clear BSD)
Original sources:
- https://github.com/chr5tphr/zennit/blob/main/src/zennit/rules.py
- https://github.com/rachtibat/LRP-eXplains-Transformers/blob/main/lxt/explicit/rules.py
"""
from abc import ABCMeta, abstractmethod
from contextlib import contextmanager
from typing import Callable, List, Optional
import weakref
import torch
import torch.nn as nn
from torch.autograd.function import Function, FunctionMeta
from .types import Activation, AvgPool, BatchNorm, Convolution, Linear, MaxPool
from .layers import Sum
[docs]
def stabilize(tensor, epsilon=1e-6):
return tensor + epsilon * ((-1) ** (tensor < 0))
[docs]
class ParamModifier:
def __init__(
self,
modify_fn: Optional[Callable[[str, nn.Parameter], torch.Tensor]] = None,
select_fn: Optional[Callable[[str, nn.Parameter], bool]] = None,
):
[docs]
self._modify_fn = modify_fn or (lambda x, _: x)
[docs]
self._select_fn = select_fn or (lambda _, __: False)
[docs]
def state_dicts(self, module: nn.Module):
original_state = {
name: param for name, param in module.named_parameters(recurse=False) if self._select_fn(name, param)
}
modified_state = {
name: self._modify_fn(name, param)
for name, param in module.named_parameters(recurse=False)
if self._select_fn(name, param)
}
return original_state, modified_state
@contextmanager
[docs]
def __call__(self, module: nn.Module):
original_state = {}
try:
original_state, modified_state = self.state_dicts(module)
module.load_state_dict(modified_state, strict=False, assign=True)
yield module
finally:
module.load_state_dict(original_state, strict=False, assign=True)
@staticmethod
[docs]
def from_modifiers(modifiers: List["ParamModifier"]):
def select_fn(name, param):
return any(modifier._select_fn(name, param) for modifier in modifiers)
def modify_fn(name, param):
new_param = param
for modifier in modifiers:
if modifier._select_fn(name, new_param):
new_param = modifier._modify_fn(name, new_param)
return new_param
return ParamModifier(modify_fn=modify_fn, select_fn=select_fn)
@staticmethod
[docs]
def select_all(name: str, param: nn.Parameter):
return True
[docs]
class RemovableRuleHandle:
def __init__(self, rule: "Rule", module: nn.Module):
[docs]
self._module_ref = weakref.ref(module)
[docs]
def __enter__(self):
return self
[docs]
def __exit__(self, exc_type, exc_value, traceback):
self.remove()
[docs]
def remove(self):
module = self._module_ref()
if module is not None:
self._rule.unregister(module)
[docs]
class Rule(Function, metaclass=AbstractFunctionMeta):
def __init__(self):
[docs]
self._apply_kwargs = {}
# TODO: Add zero_params argument for all rules
[docs]
def register(self, module: nn.Module):
module._prev_forward = module.forward
def forward(*inputs, **model_kwargs):
nonlocal module
return self.apply(self._apply_kwargs, module, model_kwargs, *inputs)
module.forward = forward
return RemovableRuleHandle(self, module)
[docs]
def unregister(self, module: nn.Module):
module.forward = module._prev_forward
del module._prev_forward
@staticmethod
@abstractmethod
[docs]
def forward(ctx, apply_kwargs, module, model_kwargs, *inputs):
pass
@staticmethod
@abstractmethod
[docs]
def backward(ctx, *out_relevance):
pass
[docs]
class EpsilonRule(Rule):
def __init__(self, epsilon=1e-6):
super().__init__()
self._apply_kwargs["epsilon"] = epsilon
@property
[docs]
def epsilon(self):
return self._apply_kwargs["epsilon"]
@epsilon.setter
def epsilon(self, value):
self._apply_kwargs["epsilon"] = value
@staticmethod
[docs]
def forward(ctx, apply_kwargs, module, model_kwargs, *inputs):
# TODO: Move logic to backward
if not any(ctx.needs_input_grad):
return module._prev_forward(*inputs, **model_kwargs)
# TODO: Check if really needed
inputs = tuple(inp.detach().requires_grad_() if inp.requires_grad else inp for inp in inputs)
with torch.enable_grad():
output = module._prev_forward(*inputs, **model_kwargs)
ctx.save_for_backward(output, *inputs)
ctx.epsilon = apply_kwargs["epsilon"]
return output.detach()
@staticmethod
[docs]
def backward(ctx, *out_relevance):
output, *inputs = ctx.saved_tensors
relevance_norm = out_relevance[0] / stabilize(output, ctx.epsilon)
grads = torch.autograd.grad(output, inputs, relevance_norm)
return (
None,
None,
None,
*(grads[i].mul(inputs[i]) if ctx.needs_input_grad[i + 3] else None for i in range(len(inputs))),
)
[docs]
class PassRule(Rule):
@staticmethod
[docs]
def forward(ctx, apply_kwargs, module, model_kwargs, *inputs):
# TODO: Move logic to backward
n_inputs = len(inputs)
output = module._prev_forward(*inputs, **model_kwargs)
n_outputs = len(output) if isinstance(output, tuple) else 1
if n_inputs != n_outputs:
raise ValueError(
(
"PassRule requires the number of inputs and outputs to be the same, ",
f"got {n_inputs} inputs and {n_outputs} outputs",
)
)
all_outputs = [output] if isinstance(output, torch.Tensor) else output
for index, (i, o) in enumerate(zip(inputs, all_outputs)):
if i.shape != o.shape:
raise ValueError(
f"Input (shape={i.shape}) and output (shape={o.shape}) have different shapes at index {index}"
)
return output
@staticmethod
[docs]
def backward(ctx, *out_relevance):
return None, None, None, *out_relevance
[docs]
class IgnoreRule(Rule):
[docs]
def register(self, module: nn.Module):
return RemovableRuleHandle(self, module)
[docs]
def unregister(self, module: nn.Module):
pass
@staticmethod
[docs]
def forward(ctx, apply_kwargs, module, model_kwargs, *inputs):
pass
@staticmethod
[docs]
def backward(ctx, *out_relevance):
pass
[docs]
class WSquareRule(Rule):
def __init__(self, stabilizer=1e-6):
super().__init__()
self._apply_kwargs["_modifier"] = ParamModifier(
select_fn=ParamModifier.select_all, modify_fn=lambda _, param: param**2
)
self._apply_kwargs["stabilizer"] = stabilizer
@staticmethod
[docs]
def forward(ctx, apply_kwargs, module, model_kwargs, *inputs):
# TODO: Move logic to backward
ctx.stabilizer = apply_kwargs["stabilizer"]
output = module._prev_forward(*inputs, **model_kwargs)
mod_inputs = tuple(torch.ones_like(inp).requires_grad_() for inp in inputs)
with torch.enable_grad():
with apply_kwargs["_modifier"](module) as modified_module:
mod_output = modified_module._prev_forward(*mod_inputs, **model_kwargs)
ctx.save_for_backward(mod_output, *mod_inputs)
return output
@staticmethod
[docs]
def backward(ctx, *out_relevance):
mod_output, *mod_inputs = ctx.saved_tensors
normed_relevance = out_relevance[0] / stabilize(mod_output, ctx.stabilizer)
in_relevance = torch.autograd.grad(mod_output, mod_inputs, normed_relevance)
return None, None, None, *in_relevance
[docs]
class FlatRule(WSquareRule):
def __init__(self, stabilizer=1e-6):
super().__init__(stabilizer)
to_ones = ParamModifier(select_fn=ParamModifier.select_all, modify_fn=lambda _, param: torch.ones_like(param))
zero_bias = ParamModifier(
select_fn=lambda name, param: name == "bias", modify_fn=lambda _, param: torch.zeros_like(param)
)
self._apply_kwargs["_modifier"] = ParamModifier.from_modifiers([to_ones, zero_bias])
[docs]
class StopRule(Rule):
@staticmethod
[docs]
def forward(ctx, apply_kwargs, module, model_kwargs, *inputs):
ctx.n_inputs = len(inputs)
return module._prev_forward(*inputs, **model_kwargs)
@staticmethod
[docs]
def backward(ctx, *out_relevances):
return None, None, None, *(None,) * ctx.n_inputs
[docs]
class AlphaBetaRule(Rule):
def __init__(self, alpha=2.0, beta=1.0, stabilizer=1e-6):
if alpha < 0 or beta < 0:
raise ValueError("Both alpha and beta parameters must be non-negative!")
if (alpha - beta) != 1.0:
raise ValueError("The difference of parameters alpha - beta must equal 1!")
super().__init__()
self._apply_kwargs["alpha"] = alpha
self._apply_kwargs["beta"] = beta
self._apply_kwargs["stabilizer"] = stabilizer
self._apply_kwargs["_zero_bias"] = ParamModifier(
select_fn=lambda name, param: name == "bias", modify_fn=lambda _, param: torch.zeros_like(param)
)
self._apply_kwargs["_positive"] = ParamModifier(
select_fn=ParamModifier.select_all, modify_fn=lambda _, param: param.clamp(min=0)
)
self._apply_kwargs["_negative"] = ParamModifier(
select_fn=ParamModifier.select_all, modify_fn=lambda _, param: param.clamp(max=0)
)
@staticmethod
[docs]
def forward(ctx, apply_kwargs, module, model_kwargs, *inputs):
# TODO: Move logic to backward
output = module._prev_forward(*inputs, **model_kwargs)
if len(inputs) > 1:
raise NotImplementedError("AlphaBetaRule does not support multiple inputs")
pos_input = inputs[0].clamp(min=0).detach().requires_grad_()
neg_input = inputs[0].clamp(max=0).detach().requires_grad_()
with torch.enable_grad():
with apply_kwargs["_positive"](module) as positive_module:
out_pos = positive_module._prev_forward(pos_input, **model_kwargs)
with apply_kwargs["_zero_bias"](positive_module) as modified_module:
out_pos_zero = modified_module._prev_forward(neg_input, **model_kwargs)
with apply_kwargs["_negative"](module) as negative_module:
out_neg = negative_module._prev_forward(pos_input, **model_kwargs)
with apply_kwargs["_zero_bias"](negative_module) as modified_module:
out_neg_zero = modified_module._prev_forward(neg_input, **model_kwargs)
ctx.save_for_backward(out_pos, out_pos_zero, out_neg, out_neg_zero, neg_input, pos_input)
ctx.alpha = apply_kwargs["alpha"]
ctx.beta = apply_kwargs["beta"]
ctx.stabilizer = apply_kwargs["stabilizer"]
return output
@staticmethod
[docs]
def backward(ctx, *out_relevance):
out_pos, out_pos_zero, out_neg, out_neg_zero, neg_input, pos_input = ctx.saved_tensors
relevance_pos = out_relevance[0] / stabilize(out_pos + out_neg_zero, ctx.stabilizer)
relevance_neg = out_relevance[0] / stabilize(out_neg + out_pos_zero, ctx.stabilizer)
pos_grad = torch.autograd.grad(out_pos, pos_input, relevance_pos)
pos_zero_grad = torch.autograd.grad(out_pos_zero, neg_input, relevance_neg)
neg_grad = torch.autograd.grad(out_neg, pos_input, relevance_neg)
neg_zero_grad = torch.autograd.grad(out_neg_zero, neg_input, relevance_pos)
in_relevance = ctx.alpha * (pos_input * pos_grad[0] + neg_input * neg_zero_grad[0]) - ctx.beta * (
pos_input * neg_grad[0] + neg_input * pos_zero_grad[0]
)
return None, None, None, in_relevance
[docs]
class SoftmaxEpsilonRule(EpsilonRule):
@staticmethod
[docs]
def backward(ctx, *out_relevances):
output, *inputs = ctx.saved_tensors
relevance = (out_relevances[0] - (output * out_relevances[0].sum(-1, keepdim=True))) * inputs[0]
return None, None, None, relevance
[docs]
class LayerNormRule(Rule):
@staticmethod
[docs]
def forward(ctx, apply_kwargs, module, model_kwargs, *inputs):
if len(inputs) > 1:
raise NotImplementedError("LayerNormRule does not support multiple inputs")
x = inputs[0]
weight = module.weight
bias = module.bias
eps = module.eps
with torch.enable_grad():
mean = x.mean(dim=-1, keepdim=True)
var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)
std = (var + eps).sqrt()
y = (
(x - mean) / std.detach()
) # detach std operation will remove it from computational graph i.e. identity rule on x/std
if weight is not None:
y *= weight
if bias is not None:
y += bias
ctx.save_for_backward(x, y)
return y.detach()
@staticmethod
[docs]
def backward(ctx, *out_relevances):
x, y = ctx.saved_tensors
input_relevances = torch.autograd.grad(y, x, out_relevances[0])
return (None, None, None, input_relevances[0])
[docs]
class PseudoIdentityRule(Rule):
def __init__(self, stabilizer=1e-6):
super().__init__()
self._apply_kwargs["stabilizer"] = stabilizer
@staticmethod
[docs]
def forward(ctx, apply_kwargs, module, model_kwargs, *inputs):
if len(inputs) > 1:
raise NotImplementedError("PseudoIdentityRule does not support multiple inputs")
ctx.stabilizer = apply_kwargs["stabilizer"]
outputs = module._prev_forward(inputs[0], **model_kwargs)
ctx.save_for_backward(outputs, inputs[0])
return outputs
@staticmethod
[docs]
def backward(ctx, *out_relevances):
outputs, inputs = ctx.saved_tensors
pseudo_identity = outputs / stabilize(inputs, ctx.stabilizer)
return None, None, None, pseudo_identity * out_relevances[0]
[docs]
class AHQKVRule(Rule):
@staticmethod
[docs]
def forward(ctx, apply_kwargs, module, model_kwargs, *inputs):
with torch.enable_grad():
outputs = module._prev_forward(*inputs, **model_kwargs)
d_model = outputs.shape[-1] // 3
query, key, value = outputs.split(d_model, dim=-1)
mod_outputs = torch.cat([query.detach(), key.detach(), value], dim=-1)
ctx.save_for_backward(*inputs, mod_outputs)
return outputs
@staticmethod
[docs]
def backward(ctx, *out_relevances):
*inputs, mod_outputs = ctx.saved_tensors
in_relevances = torch.autograd.grad(mod_outputs, inputs, out_relevances[0], retain_graph=True)
return None, None, None, *in_relevances
[docs]
class BaseRuleMapper:
def __init__(self, stabilizer=1e-6, rule_mapper: Optional[Callable[[str, nn.Module], Rule | None]] = None):
[docs]
self._stabilizer = stabilizer
[docs]
self._rule_mapper = rule_mapper or (lambda name, module: None)
[docs]
self._rules = {
"pass": PassRule(),
"norm": EpsilonRule(epsilon=self._stabilizer),
"ignore": IgnoreRule(),
}
[docs]
def _call(self, name: str, module: nn.Module) -> Rule | None:
if isinstance(module, (Activation, BatchNorm)):
return self._rules["pass"]
if isinstance(module, (MaxPool, nn.Identity, nn.Dropout, nn.Flatten)):
return self._rules["ignore"]
elif isinstance(module, (Sum, AvgPool)):
return self._rules["norm"]
[docs]
def __call__(self, name: str, module: nn.Module) -> Rule | None:
rule = self._rule_mapper(name, module)
if rule is None:
return self._call(name, module)
return rule
[docs]
class EpsilonPlus(BaseRuleMapper):
def __init__(
self, epsilon=1e-6, stabilizer=1e-6, rule_mapper: Optional[Callable[[str, nn.Module], Rule | None]] = None
):
super().__init__(stabilizer, rule_mapper)
self._rules["epsilon"] = EpsilonRule(epsilon=epsilon)
self._rules["zplus"] = AlphaBetaRule(alpha=1.0, beta=0.0, stabilizer=stabilizer)
[docs]
def _call(self, name: str, module: nn.Module) -> Rule | None:
if isinstance(module, Convolution):
return self._rules["zplus"]
elif isinstance(module, Linear):
return self._rules["epsilon"]
return super()._call(name, module)
[docs]
def raise_for_unconserved_rel_factory(atol: float = 1e-6, rtol: float = 1e-6):
def raise_for_unconserved_rel(module, in_relevances, out_relevances):
if isinstance(in_relevances, tuple):
in_rel_sum = sum(in_rel.sum() for in_rel in in_relevances)
else:
in_rel_sum = in_relevances.sum()
if isinstance(out_relevances, tuple):
out_rel_sum = sum(out_rel.sum() for out_rel in out_relevances)
else:
out_rel_sum = out_relevances.sum()
if not torch.isclose(in_rel_sum, out_rel_sum, atol=atol, rtol=rtol):
raise RuntimeError(
(f"Unconserved relevance for module {module.__class__.__name__} ({in_rel_sum=}) ({out_rel_sum=})")
)
return raise_for_unconserved_rel
# TODO: Add lxt rules and tests