tdhook.attribution.lrp
Classes
Module Contents
- class tdhook.attribution.lrp.LRP(use_inputs=True, use_outputs=True, input_modules=None, target_modules=None, init_attr_targets=None, init_attr_inputs=None, init_attr_cache_in=None, init_attr_grads=None, additional_init_keys=None, output_grad_callbacks=None, attribution_key='attr', clean_intermediate_keys=True, cache_callback=None, rule_mapper=None, warn_on_missing_rule=True, skip_modules=None)
Bases: tdhook.attribution.gradient_helpers.GradientAttribution
Layer-wise Relevance Propagation (LRP), supporting different LRP rules such as LRP-0, LRP-epsilon, z-plus [3], flat [4], gamma [5, 6], w-square [7] and its conditional variant [1].
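To make the rule names above concrete, here is a minimal NumPy sketch that redistributes relevance through a single dense layer z = a W + b under several of them. It is written independently of tdhook and does not use its API. The function name lrp_dense, its rule strings and the default eps and gamma values are illustrative assumptions. The gamma rule's treatment of the bias follows one common convention, and the conditional variant is not covered.

```python
import numpy as np


def _safe_div(num, den):
    # Element-wise num / den, returning 0 wherever den == 0.
    return np.divide(num, den, out=np.zeros_like(den, dtype=float), where=den != 0)


def lrp_dense(a, W, b, R_out, rule="0", eps=1e-6, gamma=0.25):
    """Redistribute the relevance R_out of the dense layer z = a @ W + b onto a.

    Hypothetical helper (not part of tdhook).
    a: input activations, shape (J,); W: weights, shape (J, K);
    b: bias, shape (K,); R_out: relevance of the K outputs, shape (K,).
    """
    a, W, b, R_out = (np.asarray(x, dtype=float) for x in (a, W, b, R_out))

    # Input-independent rules: each output's relevance is split according to a
    # fixed weighting of its column (squared weights, or uniformly for "flat").
    if rule in ("wsquare", "flat"):
        M = W**2 if rule == "wsquare" else np.ones_like(W)
        return M @ _safe_div(R_out, M.sum(axis=0))

    # Activation-based rules differ only in how weights and bias are modified
    # (the rho function) and in the stabilizer added to the denominator.
    if rule == "0":
        Wr, br, stab = W, b, 0.0
    elif rule == "epsilon":
        Wr, br, stab = W, b, eps
    elif rule == "gamma":
        Wr = W + gamma * np.clip(W, 0.0, None)  # favour positive contributions
        br = b + gamma * np.clip(b, 0.0, None)
        stab = 0.0
    elif rule == "zplus":
        Wr, br, stab = np.clip(W, 0.0, None), np.zeros_like(b), 0.0
    else:
        raise ValueError(f"unknown rule: {rule!r}")

    z = a @ Wr + br
    z = z + stab * np.where(z >= 0, 1.0, -1.0)  # push z away from zero
    s = _safe_div(R_out, z)  # relevance per unit of pre-activation
    return a * (Wr @ s)  # R_j = a_j * sum_k rho(w_jk) * s_k


# Usage on a toy layer with zero bias.
a = np.array([1.0, 2.0])
W = np.array([[1.0, -1.0], [0.5, 2.0]])
b = np.zeros(2)
R_out = np.array([1.0, 1.0])
for rule in ("0", "epsilon", "gamma", "zplus", "wsquare", "flat"):
    R_in = lrp_dense(a, W, b, R_out, rule=rule)
    print(rule, R_in, R_in.sum())  # LRP-0 gives [1/6, 11/6]
```

With a zero bias and strictly positive pre-activations, every rule in the sketch conserves relevance, so each printed sum is (close to) R_out.sum() == 2.0. The epsilon stabilizer enlarges the denominator slightly and absorbs a small amount of relevance, which is why that sum falls just short.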
- Parameters:
use_inputs (bool)
use_outputs (bool)
input_modules (Optional[List[str]])
target_modules (Optional[List[str]])
init_attr_targets (Optional[Callable[[tensordict.TensorDict, tensordict.TensorDict], tensordict.TensorDict]])
init_attr_inputs (Optional[Callable[[tensordict.TensorDict, tensordict.TensorDict], tensordict.TensorDict]])
init_attr_cache_in (Optional[Callable[[tensordict.TensorDict, tensordict.TensorDict], tensordict.TensorDict]])
init_attr_grads (Optional[Callable[[tensordict.TensorDict, tensordict.TensorDict], tensordict.TensorDict]])
additional_init_keys (Optional[List[tdhook._types.UnraveledKey]])
output_grad_callbacks (Optional[Dict[str, Callable]])
attribution_key (tdhook._types.UnraveledKey)
clean_intermediate_keys (bool)
cache_callback (Optional[Callable])
rule_mapper (Callable[[str, torch.nn.Module], tdhook.attribution.lrp_helpers.rules.Rule | None] | None)
warn_on_missing_rule (bool)
skip_modules (Optional[Callable[[str, torch.nn.Module], bool]])
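Judging from their signatures, rule_mapper and skip_modules both take a callable that receives a module's name and the module itself. The sketch below builds two such callables without depending on tdhook or torch. The factory names make_prefix_rule_mapper and make_skip_modules are hypothetical. It assumes that returning None from rule_mapper means "no rule for this module" (presumably what warn_on_missing_rule reports on) and that skip_modules returning True excludes a module. Plain strings stand in for tdhook Rule instances, whose constructors are not documented on this page.

```python
from typing import Any, Callable, Dict, Optional, Sequence


def make_prefix_rule_mapper(
    table: Dict[str, Any], default: Optional[Any] = None
) -> Callable[[str, Any], Optional[Any]]:
    """Build a (name, module) -> rule-or-None callable (hypothetical helper).

    `table` maps module-name prefixes to rule objects; the longest matching
    prefix wins, and `default` (None here) is returned when nothing matches.
    """
    # Try longer prefixes first so "features.0" beats "features".
    prefixes = sorted(table, key=len, reverse=True)

    def rule_mapper(name: str, module: Any) -> Optional[Any]:
        for prefix in prefixes:
            # Match whole dotted components only: "features" matches
            # "features.3" but not "features_extra".
            if name == prefix or name.startswith(prefix + "."):
                return table[prefix]
        return default

    return rule_mapper


def make_skip_modules(type_names: Sequence[str]) -> Callable[[str, Any], bool]:
    """Build a (name, module) -> bool predicate (hypothetical helper).

    Matching on the class name avoids importing torch in this sketch; with
    torch available, isinstance(module, torch.nn.Dropout) is more idiomatic.
    """
    skipped = frozenset(type_names)

    def skip_modules(name: str, module: Any) -> bool:
        return type(module).__name__ in skipped

    return skip_modules


# Usage with stand-in module classes and string placeholders for rules.
class Linear:  # stand-in for torch.nn.Linear
    pass


class Dropout:  # stand-in for torch.nn.Dropout
    pass


mapper = make_prefix_rule_mapper(
    {"features": "gamma-rule", "features.0": "flat-rule", "classifier": "epsilon-rule"}
)
print(mapper("features.0", Linear()))  # flat-rule
print(mapper("features.3", Linear()))  # gamma-rule
print(mapper("head", Linear()))  # None

skip = make_skip_modules(["Dropout"])
print(skip("features.2", Dropout()), skip("features.0", Linear()))  # True False
```

Resolving by longest prefix keeps the common case (one rule per block of layers) short while still allowing a single layer, such as the first one, to be overridden.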
- _prepare_module(module, in_keys, out_keys, extra_relative_path)
- Parameters:
module (tensordict.nn.TensorDictModuleBase)
in_keys (List[tdhook._types.UnraveledKey])
out_keys (List[tdhook._types.UnraveledKey])
extra_relative_path (str)
- Return type:
tensordict.nn.TensorDictModuleBase
- _restore_module(module, in_keys, out_keys, extra_relative_path)
- Parameters:
module (tensordict.nn.TensorDictModuleBase)
in_keys (List[tdhook._types.UnraveledKey])
out_keys (List[tdhook._types.UnraveledKey])
extra_relative_path (str)
- Return type:
tensordict.nn.TensorDictModuleBase