# Source code for tdhook.attribution.saliency
from typing import List, Optional, Callable, Dict
from tensordict import TensorDict
import torch
from tdhook._types import UnraveledKey
from tdhook.attribution.gradient_helpers import GradientAttribution
[docs]
class Saliency(GradientAttribution):
    """
    Gradient attribution :cite:`Simonyan2013DeepIC` and its gradient-times-input variation :cite:`Shrikumar2016NotJA`.

    All constructor arguments except ``absolute`` and ``multiply_by_inputs``
    are forwarded unchanged to :class:`GradientAttribution`; see that class
    for their meaning.

    Args:
        absolute: If ``True``, gradients are replaced in place by their
            absolute value before being returned.
        multiply_by_inputs: If ``True``, gradients are multiplied in place by
            the inputs (applied after the optional absolute value).
    """

    def __init__(
        self,
        use_inputs: bool = True,
        use_outputs: bool = True,
        input_modules: Optional[List[str]] = None,
        target_modules: Optional[List[str]] = None,
        init_attr_targets: Optional[Callable[[TensorDict, TensorDict], TensorDict]] = None,
        init_attr_inputs: Optional[Callable[[TensorDict, TensorDict], TensorDict]] = None,
        init_attr_cache_in: Optional[Callable[[TensorDict, TensorDict], TensorDict]] = None,
        init_attr_grads: Optional[Callable[[TensorDict, TensorDict], TensorDict]] = None,
        additional_init_keys: Optional[List[UnraveledKey]] = None,
        output_grad_callbacks: Optional[Dict[str, Callable]] = None,
        attribution_key: UnraveledKey = "attr",
        clean_intermediate_keys: bool = True,
        cache_callback: Optional[Callable] = None,
        absolute: bool = False,
        multiply_by_inputs: bool = False,
    ):
        super().__init__(
            use_inputs=use_inputs,
            use_outputs=use_outputs,
            input_modules=input_modules,
            target_modules=target_modules,
            init_attr_targets=init_attr_targets,
            init_attr_inputs=init_attr_inputs,
            init_attr_cache_in=init_attr_cache_in,
            init_attr_grads=init_attr_grads,
            additional_init_keys=additional_init_keys,
            output_grad_callbacks=output_grad_callbacks,
            attribution_key=attribution_key,
            clean_intermediate_keys=clean_intermediate_keys,
            cache_callback=cache_callback,
        )
        # Both flags are read by `_grad_attr`; neither is passed to the base
        # class, so they must be stored on the instance here.
        self._absolute = absolute
        self._multiply_by_inputs = multiply_by_inputs

    @torch.no_grad()
    def _grad_attr(
        self,
        grads: TensorDict,
        inputs: TensorDict,
    ) -> TensorDict:
        """Turn raw gradients into attributions, modifying ``grads`` in place.

        Args:
            grads: Gradients to post-process. Mutated in place.
            inputs: Values the gradients are multiplied by when
                ``multiply_by_inputs`` was enabled at construction.

        Returns:
            The same ``grads`` object that was passed in.
        """
        if self._absolute:
            grads.abs_()
        # Order matters: with both flags set the result is |grad| * input,
        # not |grad * input|.
        if self._multiply_by_inputs:
            grads.mul_(inputs)
        return grads