tdhook.attribution.lrp_helpers.rules

Layer-wise Relevance Propagation (LRP) rules.

This code is adapted from the Zennit library (LGPL-3.0) and the LXT library (Clear BSD).

Original sources:
  • chr5tphr/zennit

  • rachtibat/LRP-eXplains-Transformers

Classes

ParamModifier

RemovableRuleHandle

AbstractFunctionMeta

Metaclass combining abc.ABCMeta and torch.autograd.function.FunctionMeta.

Rule

Base class for LRP rules implemented as custom autograd functions.

EpsilonRule

UniformEpsilonRule

PassRule

IgnoreRule

WSquareRule

FlatRule

UniformRule

StopRule

AlphaBetaRule

SoftmaxEpsilonRule

LayerNormRule

PseudoIdentityRule

AHQKVRule

BaseRuleMapper

EpsilonPlus

Functions

stabilize(tensor[, epsilon])

raise_for_unconserved_rel_factory([atol, rtol])

Module Contents

tdhook.attribution.lrp_helpers.rules.stabilize(tensor, epsilon=1e-06)
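
stabilize has no docstring here. In LRP, a stabilizer keeps denominators away from zero. The sketch below assumes the Zennit-style convention: shift each value by epsilon in the direction of its sign, with zeros treated as positive. tdhook's actual implementation may differ.

```python
import numpy as np


def stabilize(tensor, epsilon=1e-6):
    """Shift values away from zero by epsilon, keeping their sign.

    Assumption (Zennit-style): zeros count as positive and become +epsilon.
    """
    tensor = np.asarray(tensor, dtype=float)
    sign = np.where(tensor >= 0, 1.0, -1.0)  # 0 -> +1, x > 0 -> +1, x < 0 -> -1
    return tensor + sign * epsilon
```

Shifting along the sign, rather than adding epsilon everywhere, means a negative denominator can never be pushed through zero.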
class tdhook.attribution.lrp_helpers.rules.ParamModifier(modify_fn=None, select_fn=None)
Parameters:
  • modify_fn (Optional[Callable[[str, torch.nn.Parameter], torch.Tensor]])

  • select_fn (Optional[Callable[[str, torch.nn.Parameter], bool]])

_modify_fn
_select_fn
state_dicts(module)
Parameters:

module (torch.nn.Module)

__call__(module)
Parameters:

module (torch.nn.Module)

static from_modifiers(modifiers)
Parameters:

modifiers (List[ParamModifier])

static select_all(name, param)
Parameters:
  • name (str)

  • param (torch.nn.Parameter)
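
ParamModifier pairs a modify_fn with a select_fn, and both take (name, param) according to the signatures above. The state_dicts(module) method suggests that original parameter values are saved and later restored, but that is an inference. The sketch below shows that save, modify, restore pattern on a plain dict of NumPy arrays standing in for a module's parameters. modified_params, keep_positive, and weights_only are hypothetical names.

```python
from contextlib import contextmanager

import numpy as np


def select_all(name, param):
    """Hypothetical counterpart of ParamModifier.select_all: accept every parameter."""
    return True


@contextmanager
def modified_params(params, modify_fn, select_fn=select_all):
    """Temporarily replace selected entries of a {name: array} dict, then restore them."""
    saved = {name: value.copy() for name, value in params.items()}  # like a saved state dict
    try:
        for name, value in list(params.items()):
            if select_fn(name, value):
                params[name] = modify_fn(name, value)
        yield params
    finally:
        params.update(saved)  # restore every original value, even on error


params = {"weight": np.array([[1.0, -2.0]]), "bias": np.array([0.5])}


def keep_positive(name, param):
    # z+-style modification: zero out negative entries.
    return np.clip(param, 0.0, None)


def weights_only(name, param):
    return name == "weight"


with modified_params(params, keep_positive, weights_only) as p:
    inside = p["weight"].copy()
```

Inside the block, the weight is [[1.0, 0.0]]. After the block exits, the original [[1.0, -2.0]] is back, and the bias is never touched because select_fn rejects it.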

class tdhook.attribution.lrp_helpers.rules.RemovableRuleHandle(rule, module)
Parameters:
  • rule (Rule)

  • module (torch.nn.Module)

_rule
_module_ref
__enter__()
__exit__(exc_type, exc_value, traceback)
remove()
class tdhook.attribution.lrp_helpers.rules.AbstractFunctionMeta(name, bases, attrs)

Bases: abc.ABCMeta, torch.autograd.function.FunctionMeta

Metaclass combining abc.ABCMeta with torch.autograd.function.FunctionMeta.

torch.autograd.Function already uses FunctionMeta as its metaclass. A Function subclass that also inherited from abc.ABC would therefore fail with a metaclass conflict. Deriving from both metaclasses lets a Function subclass declare methods with abc.abstractmethod, which is presumably how Rule declares its abstract forward() and backward(). In all other respects the metaclass behaves like abc.ABCMeta, including support for registering unrelated classes as virtual subclasses.

class tdhook.attribution.lrp_helpers.rules.Rule

Bases: torch.autograd.function.Function

Base class for LRP rules implemented as custom autograd functions.

Subclasses override forward() and backward() to define how relevance is propagated through a wrapped module. Instances are registered onto modules by temporarily replacing module.forward with Function.apply(…).
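
The registration mechanism described above shadows module.forward for as long as a rule is attached. It can be sketched in plain Python. ToyModule, ToyRule and ToyHandle are hypothetical, and ToyRule.apply stands in for Function.apply. ToyHandle only mirrors the attribute names and context-manager methods listed for RemovableRuleHandle. The weak reference is an assumption about what _module_ref holds.

```python
import weakref


class ToyModule:
    """Hypothetical stand-in for torch.nn.Module."""

    def forward(self, x):
        return 2 * x


class ToyRule:
    """Hypothetical rule that reroutes module.forward while registered."""

    def apply(self, module, *inputs):
        # A real Rule would go through its autograd Function here so that
        # backward() can redistribute relevance; this stub only tags the output.
        return ("via rule", type(module).forward(module, *inputs))

    def register(self, module):
        # An instance attribute shadows the class-level forward method.
        module.forward = lambda *inputs: self.apply(module, *inputs)
        return ToyHandle(self, module)

    def unregister(self, module):
        # Dropping the instance attribute exposes the original forward again.
        module.__dict__.pop("forward", None)


class ToyHandle:
    """Context manager shaped like RemovableRuleHandle (assumed semantics)."""

    def __init__(self, rule, module):
        self._rule = rule
        self._module_ref = weakref.ref(module)  # the handle itself does not keep the module alive

    def remove(self):
        module = self._module_ref()
        if module is not None:
            self._rule.unregister(module)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.remove()
        return False  # do not swallow exceptions


module = ToyModule()
with ToyRule().register(module):
    patched = module.forward(3)
restored = module.forward(3)
```

Inside the with block, calls are routed through the rule. On exit, the class-level forward takes over again, even if the block raised.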

_apply_kwargs
register(module)
Parameters:

module (torch.nn.Module)

unregister(module)
Parameters:

module (torch.nn.Module)

static forward(ctx, apply_kwargs, module, model_kwargs, *inputs)
Abstract method.

Run the wrapped module and save any tensors needed for relevance propagation.

static backward(ctx, *out_relevance)
Abstract method.

Propagate output relevance back to the inputs of the wrapped module.

class tdhook.attribution.lrp_helpers.rules.EpsilonRule(epsilon=1e-06)

Bases: Rule

property epsilon
static forward(ctx, apply_kwargs, module, model_kwargs, *inputs)

Run the wrapped module and save any tensors needed for relevance propagation.

static backward(ctx, *out_relevance)

Propagate output relevance back to the inputs of the wrapped module.
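
For orientation, the sketch below implements the textbook LRP-epsilon rule (as in Zennit's Epsilon rule) for a dense layer y = x @ W.T + b in NumPy. It is not tdhook's code. The dense layer, the array shapes, and the sign-aware stabilizer are assumptions, whereas EpsilonRule works on a wrapped module through a custom autograd Function.

```python
import numpy as np


def stabilize(z, epsilon=1e-6):
    # Sign-aware shift away from zero (zeros count as positive), Zennit-style.
    return z + np.where(z >= 0, 1.0, -1.0) * epsilon


def epsilon_lrp_dense(x, weight, bias, out_relevance, epsilon=1e-6):
    """LRP-epsilon for y = x @ weight.T + bias.

    x: (batch, in), weight: (out, in), bias: (out,), out_relevance: (batch, out).
    """
    z = x @ weight.T + bias                    # forward pre-activations, (batch, out)
    s = out_relevance / stabilize(z, epsilon)  # relevance per unit of output
    c = s @ weight                             # gradient of (z * s).sum() w.r.t. x, s held fixed
    return x * c                               # input relevance, (batch, in)


x = np.array([[1.0, 2.0]])
weight = np.array([[1.0, 1.0], [3.0, -1.0]])
relevance_in = epsilon_lrp_dense(x, weight, np.zeros(2), np.array([[1.0, 1.0]]))
```

With zero bias, the input relevance sums to the output relevance up to the stabilizer (here about [[10/3, -4/3]], summing to 2). A nonzero bias absorbs a share b_j / z_j of each output's relevance.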

class tdhook.attribution.lrp_helpers.rules.UniformEpsilonRule(epsilon=1e-06)

Bases: EpsilonRule

static backward(ctx, *out_relevance)

Propagate output relevance back to the inputs of the wrapped module.

class tdhook.attribution.lrp_helpers.rules.PassRule

Bases: Rule

static forward(ctx, apply_kwargs, module, model_kwargs, *inputs)

Run the wrapped module and save any tensors needed for relevance propagation.

static backward(ctx, *out_relevance)

Propagate output relevance back to the inputs of the wrapped module.

class tdhook.attribution.lrp_helpers.rules.IgnoreRule

Bases: Rule

register(module)
Parameters:

module (torch.nn.Module)

unregister(module)
Parameters:

module (torch.nn.Module)

static forward(ctx, apply_kwargs, module, model_kwargs, *inputs)

Run the wrapped module and save any tensors needed for relevance propagation.

static backward(ctx, *out_relevance)

Propagate output relevance back to the inputs of the wrapped module.

class tdhook.attribution.lrp_helpers.rules.WSquareRule(stabilizer=1e-06)

Bases: Rule

static forward(ctx, apply_kwargs, module, model_kwargs, *inputs)

Run the wrapped module and save any tensors needed for relevance propagation.

static backward(ctx, *out_relevance)

Propagate output relevance back to the inputs of the wrapped module.

class tdhook.attribution.lrp_helpers.rules.FlatRule(stabilizer=1e-06)

Bases: WSquareRule
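
In the standard formulation (as in Zennit), the WSquare rule ignores input values and distributes each output's relevance in proportion to the squared weights. The Flat rule additionally treats every weight as 1, which fits FlatRule deriving from WSquareRule. The NumPy sketch below assumes a dense layer and ignores the bias. wsquare_lrp_dense is a hypothetical helper, not tdhook's API.

```python
import numpy as np


def wsquare_lrp_dense(weight, out_relevance, flat=False, stabilizer=1e-6):
    """Distribute relevance by squared weights (WSquare) or uniformly (Flat).

    weight: (out, in); out_relevance: (batch, out). Input values are not used.
    """
    w = np.ones_like(weight) if flat else weight ** 2  # Flat: every weight treated as 1
    z = w.sum(axis=1)                                 # (out,), always >= 0
    s = out_relevance / (z + stabilizer)              # z >= 0, so a plain +stabilizer suffices
    return s @ w                                      # (batch, in)


weight = np.array([[1.0, 2.0], [3.0, 0.0]])
relevance_out = np.array([[5.0, 9.0]])
wsquare = wsquare_lrp_dense(weight, relevance_out)         # approximately [[10, 4]]
flat = wsquare_lrp_dense(weight, relevance_out, flat=True)  # approximately [[7, 7]]
```

Both variants conserve total relevance, up to the stabilizer, because every output's relevance is split into shares that sum to one.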

class tdhook.attribution.lrp_helpers.rules.UniformRule

Bases: Rule

static forward(ctx, apply_kwargs, module, model_kwargs, *inputs)

Run the wrapped module and save any tensors needed for relevance propagation.

static backward(ctx, *out_relevances)

Propagate output relevance back to the inputs of the wrapped module.

class tdhook.attribution.lrp_helpers.rules.StopRule

Bases: Rule

static forward(ctx, apply_kwargs, module, model_kwargs, *inputs)

Run the wrapped module and save any tensors needed for relevance propagation.

static backward(ctx, *out_relevances)

Propagate output relevance back to the inputs of the wrapped module.

class tdhook.attribution.lrp_helpers.rules.AlphaBetaRule(alpha=2.0, beta=1.0, stabilizer=1e-06)

Bases: Rule

static forward(ctx, apply_kwargs, module, model_kwargs, *inputs)

Run the wrapped module and save any tensors needed for relevance propagation.

static backward(ctx, *out_relevance)

Propagate output relevance back to the inputs of the wrapped module.
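
The standard alpha-beta rule splits each output's pre-activation into positive and negative contributions x_i * w_ji. It then gives alpha times the relevance to the positive part and subtracts beta times the relevance along the negative part. When alpha - beta = 1, which holds for the defaults 2.0 and 1.0, total relevance is conserved. The NumPy sketch assumes a dense layer without bias; alpha_beta_lrp_dense is hypothetical.

```python
import numpy as np


def stabilize(z, epsilon=1e-6):
    # Sign-aware shift away from zero (zeros count as positive), Zennit-style.
    return z + np.where(z >= 0, 1.0, -1.0) * epsilon


def alpha_beta_lrp_dense(x, weight, out_relevance, alpha=2.0, beta=1.0, stabilizer=1e-6):
    """Alpha-beta LRP for y = x @ weight.T (bias ignored in this sketch).

    x: (batch, in), weight: (out, in), out_relevance: (batch, out).
    """
    contrib = x[:, None, :] * weight[None, :, :]  # (batch, out, in): x_i * w_ji
    pos = np.clip(contrib, 0.0, None)             # positive contributions
    neg = np.clip(contrib, None, 0.0)             # negative contributions
    z_pos = stabilize(pos.sum(axis=2, keepdims=True), stabilizer)
    z_neg = stabilize(neg.sum(axis=2, keepdims=True), stabilizer)
    share = alpha * pos / z_pos - beta * neg / z_neg  # neg / z_neg is nonnegative
    return (share * out_relevance[:, :, None]).sum(axis=1)  # (batch, in)


x = np.array([[1.0, 2.0]])
weight = np.array([[1.0, -1.0], [-2.0, 3.0]])
relevance_in = alpha_beta_lrp_dense(x, weight, np.array([[1.0, 1.0]]))
```

Here both outputs have positive and negative contributions, so the result is about [[1.0, 1.0]] and sums to the output relevance of 2. If an output has no negative contributions, its beta term vanishes and conservation breaks for that output.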

class tdhook.attribution.lrp_helpers.rules.SoftmaxEpsilonRule(epsilon=1e-06)

Bases: EpsilonRule

static backward(ctx, *out_relevances)

Propagate output relevance back to the inputs of the wrapped module.
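
SoftmaxEpsilonRule derives from EpsilonRule, which suggests an epsilon-style treatment of s = softmax(x): R_in = x * J^T (R_out / stabilize(s)), where J is the softmax Jacobian. The sketch below computes that in NumPy under this assumption; the helper names are hypothetical. As epsilon goes to 0, the rule reduces to x_i * (R_i - s_i * sum_j R_j). As far as we know, this matches the softmax rule used in LXT, but tdhook's exact computation is not documented here.

```python
import numpy as np


def softmax(x, axis=-1):
    shifted = x - x.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)


def softmax_epsilon_lrp(x, out_relevance, epsilon=1e-6):
    """Epsilon-style LRP through s = softmax(x) along the last axis."""
    s = softmax(x)
    c = out_relevance / (s + epsilon)  # s > 0, so the sign-aware shift is just +epsilon
    # J^T c with J[j, i] = s_j * (delta_ij - s_i) simplifies to s * (c - sum_j s_j c_j)
    return x * s * (c - (s * c).sum(axis=-1, keepdims=True))


x = np.array([[0.5, -1.0, 2.0]])
relevance_out = np.array([[0.2, 0.3, 0.5]])
relevance_in = softmax_epsilon_lrp(x, relevance_out)
```

The closed form avoids building the full Jacobian, which would cost memory quadratic in the softmax dimension for attention scores.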

class tdhook.attribution.lrp_helpers.rules.LayerNormRule

Bases: Rule

static forward(ctx, apply_kwargs, module, model_kwargs, *inputs)

Run the wrapped module and save any tensors needed for relevance propagation.

static backward(ctx, *out_relevances)

Propagate output relevance back to the inputs of the wrapped module.

class tdhook.attribution.lrp_helpers.rules.PseudoIdentityRule(stabilizer=1e-06)

Bases: Rule

static forward(ctx, apply_kwargs, module, model_kwargs, *inputs)

Run the wrapped module and save any tensors needed for relevance propagation.

static backward(ctx, *out_relevances)

Propagate output relevance back to the inputs of the wrapped module.

class tdhook.attribution.lrp_helpers.rules.AHQKVRule

Bases: Rule

static forward(ctx, apply_kwargs, module, model_kwargs, *inputs)

Run the wrapped module and save any tensors needed for relevance propagation.

static backward(ctx, *out_relevances)

Propagate output relevance back to the inputs of the wrapped module.

class tdhook.attribution.lrp_helpers.rules.BaseRuleMapper(stabilizer=1e-06, rule_mapper=None)
Parameters:

rule_mapper (Optional[Callable[[str, torch.nn.Module], Rule | None]])

_stabilizer = 1e-06
_rule_mapper
_rules
_call(name, module)
Parameters:
  • name (str)

  • module (torch.nn.Module)

Return type:

Rule | None

__call__(name, module)
Parameters:
  • name (str)

  • module (torch.nn.Module)

Return type:

Rule | None
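
A rule mapper is a callable from (name, module) to Rule | None, and subclasses such as EpsilonPlus override _call. The sketch below illustrates that dispatch pattern. Everything in it is hypothetical: RuleStub replaces real Rule objects, the toy classes replace torch.nn layers, the mapping policy is made up, and giving a user-supplied rule_mapper precedence is an assumption about how __call__ combines it with _call.

```python
from typing import Callable, Optional


class RuleStub:
    """Hypothetical placeholder for a Rule instance."""

    def __init__(self, kind, **params):
        self.kind = kind
        self.params = params


class MapperSketch:
    """Hypothetical (name, module) -> rule-or-None dispatcher, loosely shaped like BaseRuleMapper."""

    def __init__(self, epsilon=1e-6, stabilizer=1e-6,
                 rule_mapper: Optional[Callable[[str, object], Optional[RuleStub]]] = None):
        self._epsilon = epsilon
        self._stabilizer = stabilizer
        self._rule_mapper = rule_mapper

    def _call(self, name, module):
        # Default policy keyed on the module's class name (an assumption for this sketch).
        kind = type(module).__name__
        if kind == "Linear":
            return RuleStub("epsilon", epsilon=self._epsilon)
        if kind == "ReLU":
            return RuleStub("pass")
        return None  # no rule for this module

    def __call__(self, name, module):
        # Assumption: a user-supplied rule_mapper wins whenever it returns a rule.
        if self._rule_mapper is not None:
            rule = self._rule_mapper(name, module)
            if rule is not None:
                return rule
        return self._call(name, module)


# Toy module types standing in for torch.nn layers.
class Linear:
    pass


class ReLU:
    pass


class Dropout:
    pass


mapper = MapperSketch(rule_mapper=lambda name, m: RuleStub("stop") if name == "head" else None)
```

Returning None from the override hook lets the default policy decide. That keeps per-layer customization small: a user only lists the exceptions.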

class tdhook.attribution.lrp_helpers.rules.EpsilonPlus(epsilon=1e-06, stabilizer=1e-06, rule_mapper=None)

Bases: BaseRuleMapper

Parameters:

rule_mapper (Optional[Callable[[str, torch.nn.Module], Rule | None]])

_call(name, module)
Parameters:
  • name (str)

  • module (torch.nn.Module)

Return type:

Rule | None

tdhook.attribution.lrp_helpers.rules.raise_for_unconserved_rel_factory(atol=1e-06, rtol=1e-06)
Parameters:
  • atol (float)

  • rtol (float)
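
Judging by its name and its atol and rtol parameters, raise_for_unconserved_rel_factory builds a check that total relevance is conserved. The sketch below is a hypothetical version of that idea. The returned callable's signature, the use of summed totals, and the choice of ValueError are assumptions. The tolerance test uses the same form as numpy.isclose: |a - b| <= atol + rtol * |b|.

```python
import numpy as np


def raise_for_unconserved_rel_factory_sketch(atol=1e-6, rtol=1e-6):
    """Hypothetical sketch: build a check that total relevance is preserved."""

    def check(in_relevance, out_relevance):
        total_in = float(np.sum(in_relevance))
        total_out = float(np.sum(out_relevance))
        # Same tolerance form as numpy.isclose: |a - b| <= atol + rtol * |b|
        if abs(total_in - total_out) > atol + rtol * abs(total_out):
            raise ValueError(
                f"relevance not conserved: in={total_in:.6g}, out={total_out:.6g}"
            )

    return check


check = raise_for_unconserved_rel_factory_sketch()
check(np.array([0.25, 0.75]), np.array([1.0]))  # passes: totals match
```

Rules that are not conservative by design, such as epsilon with a nonzero bias or alpha-beta on outputs lacking negative contributions, would need looser tolerances for a check like this.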