tdhook.attribution.gradient_helpers

Gradient attribution helpers.

Submodules#

Classes

GradientAttribution

Base class for gradient attribution.

GradientAttributionWithBaseline

Gradient attribution with baseline.

Package Contents

class tdhook.attribution.gradient_helpers.GradientAttribution(use_inputs=True, use_outputs=True, input_modules=None, target_modules=None, init_attr_targets=None, init_attr_inputs=None, init_attr_cache_in=None, init_attr_grads=None, additional_init_keys=None, output_grad_callbacks=None, attribution_key='attr', clean_intermediate_keys=True, cache_callback=None)[source]#

Bases: tdhook.contexts.HookingContextFactory

Base class for gradient attribution.

Parameters:
  • use_inputs (bool)

  • use_outputs (bool)

  • input_modules (Optional[List[str]])

  • target_modules (Optional[List[str]])

  • init_attr_targets (Optional[Callable[[tensordict.TensorDict, tensordict.TensorDict], tensordict.TensorDict]])

  • init_attr_inputs (Optional[Callable[[tensordict.TensorDict, tensordict.TensorDict], tensordict.TensorDict]])

  • init_attr_cache_in (Optional[Callable[[tensordict.TensorDict, tensordict.TensorDict], tensordict.TensorDict]])

  • init_attr_grads (Optional[Callable[[tensordict.TensorDict, tensordict.TensorDict], tensordict.TensorDict]])

  • additional_init_keys (Optional[List[tdhook._types.UnraveledKey]])

  • output_grad_callbacks (Optional[Dict[str, Callable]])

  • attribution_key (tdhook._types.UnraveledKey)

  • clean_intermediate_keys (bool)

  • cache_callback (Optional[Callable])

_use_inputs = True#
_use_outputs = True#
_input_modules = []#
_target_modules = []#
_init_attr_targets = None#
_init_attr_inputs = None#
_init_attr_cache_in = None#
_init_attr_grads = None#
_output_grad_callbacks#
_cache_callback = None#
_additional_init_keys = []#
_attr_key = 'attr'#
_clean_intermediate_keys = True#
_prepare_module(module, in_keys, out_keys, extra_relative_path)[source]#
Parameters:
  • module

  • in_keys

  • out_keys

  • extra_relative_path

Return type:

tensordict.nn.TensorDictModuleBase

_hook_module(module)[source]#
Parameters:

module (tdhook.modules.HookedModule)

Return type:

tdhook.hooks.MultiHookHandle

_register_inputs_fn(td)[source]#
Parameters:

td (tensordict.TensorDict)

Return type:

tensordict.TensorDict

_attributor_fn(td, cache_ref)[source]#
Parameters:
  • td

  • cache_ref

Return type:

tensordict.TensorDict

abstract _grad_attr(grads, inputs)[source]#
Parameters:
  • grads (tensordict.TensorDict)

  • inputs (tensordict.TensorDict)

Return type:

tensordict.TensorDict

class tdhook.attribution.gradient_helpers.GradientAttributionWithBaseline(*args, compute_convergence_delta=False, baseline_key='baseline', multiply_by_inputs=False, **kwargs)[source]#

Bases: GradientAttribution

Gradient attribution with baseline.

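Parameters:
  • *args

  • compute_convergence_delta (default: False)

  • baseline_key (default: 'baseline')

  • multiply_by_inputs (default: False)

  • **kwargs
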
_compute_convergence_delta = False#
_baseline_key = 'baseline'#
_multiply_by_inputs = False#
_prepare_module(module, in_keys, out_keys, extra_relative_path)[source]#
Parameters:
  • module

  • in_keys

  • out_keys

  • extra_relative_path

Return type:

tensordict.nn.TensorDictModuleBase

abstract _reduce_baselines_fn(td, in_keys)[source]#
Parameters:
  • td

  • in_keys

Return type:

tensordict.TensorDict

_multiply_by_inputs_fn(inputs, baselines, attrs)[source]#
Parameters:
  • inputs (Tuple[torch.Tensor, ...])

  • baselines (Tuple[torch.Tensor, ...])

  • attrs (Tuple[torch.Tensor, ...])

Return type:

Tuple[torch.Tensor, ...]

_compute_convergence_delta_fn(td, in_keys, out_keys, module)[source]#
Parameters:
  • td

  • in_keys

  • out_keys

  • module

Return type:

tensordict.TensorDict