tdhook.attribution.integrated_gradients
Classes
Module Contents
- class tdhook.attribution.integrated_gradients.IntegratedGradients(use_inputs=True, use_outputs=True, input_modules=None, target_modules=None, init_attr_targets=None, init_attr_inputs=None, init_attr_cache_in=None, init_attr_grads=None, additional_init_keys=None, output_grad_callbacks=None, attribution_key='attr', clean_intermediate_keys=True, cache_callback=None, compute_convergence_delta=False, baseline_key='baseline', multiply_by_inputs=False, method='gausslegendre', n_steps=50)
Bases: tdhook.attribution.gradient_helpers.GradientAttributionWithBaseline
Integrated gradients [12] and its conditional variant [13].
- Parameters:
use_inputs (bool)
use_outputs (bool)
input_modules (Optional[List[str]])
target_modules (Optional[List[str]])
init_attr_targets (Optional[Callable[[tensordict.TensorDict, tensordict.TensorDict], tensordict.TensorDict]])
init_attr_inputs (Optional[Callable[[tensordict.TensorDict, tensordict.TensorDict], tensordict.TensorDict]])
init_attr_cache_in (Optional[Callable[[tensordict.TensorDict, tensordict.TensorDict], tensordict.TensorDict]])
init_attr_grads (Optional[Callable[[tensordict.TensorDict, tensordict.TensorDict], tensordict.TensorDict]])
additional_init_keys (Optional[List[tdhook._types.UnraveledKey]])
output_grad_callbacks (Optional[Dict[str, Callable]])
attribution_key (tdhook._types.UnraveledKey)
clean_intermediate_keys (bool)
cache_callback (Optional[Callable])
compute_convergence_delta (bool)
baseline_key (tdhook._types.UnraveledKey)
multiply_by_inputs (bool)
method (str)
n_steps (int)
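The defaults `method='gausslegendre'` and `n_steps=50` suggest that the path integral is approximated with a Gauss-Legendre quadrature rule using `n_steps` nodes. The following is a minimal sketch of that idea in plain Python, not the tdhook implementation: `gauss_legendre_unit`, `integrated_gradients`, `grad_fn`, and the toy function `f` are hypothetical names introduced for illustration, and plain lists stand in for TensorDicts. It returns the textbook form, the path-averaged gradient scaled by `x - baseline`; how the library's `multiply_by_inputs` flag relates to that scaling is not stated here.

```python
import math


def _legendre(n, z):
    """Return P_n(z) and P_n'(z) using the three-term recurrence."""
    p_prev, p = 1.0, z
    for k in range(2, n + 1):
        p_prev, p = p, ((2 * k - 1) * z * p - (k - 1) * p_prev) / k
    dp = n * (z * p - p_prev) / (z * z - 1.0)
    return p, dp


def gauss_legendre_unit(n):
    """Gauss-Legendre nodes and weights rescaled from [-1, 1] to [0, 1]."""
    nodes, weights = [], []
    for i in range(1, n + 1):
        # Standard initial guess for the i-th root of P_n, refined by Newton.
        z = math.cos(math.pi * (i - 0.25) / (n + 0.5))
        for _ in range(100):
            p, dp = _legendre(n, z)
            step = p / dp
            z -= step
            if abs(step) < 1e-14:
                break
        _, dp = _legendre(n, z)
        nodes.append(0.5 * (z + 1.0))
        # Weight on [-1, 1] is 2 / ((1 - z^2) P_n'(z)^2); halve it for [0, 1].
        weights.append(1.0 / ((1.0 - z * z) * dp * dp))
    return nodes, weights


def integrated_gradients(grad_fn, x, baseline, n_steps=50):
    """Approximate (x - b) * integral_0^1 grad f(b + a (x - b)) da."""
    nodes, weights = gauss_legendre_unit(n_steps)
    avg_grad = [0.0] * len(x)
    for a, w in zip(nodes, weights):
        point = [b + a * (xi - b) for xi, b in zip(x, baseline)]
        for j, g in enumerate(grad_fn(point)):
            avg_grad[j] += w * g
    return [(xi - b) * g for xi, b, g in zip(x, baseline, avg_grad)]


# Toy model (hypothetical): f(x) = x0^3 + x1^3 + x2^3 + x0 * x1.
def f(x):
    return sum(v ** 3 for v in x) + x[0] * x[1]


def grad_f(x):
    return [3 * x[0] ** 2 + x[1], 3 * x[1] ** 2 + x[0], 3 * x[2] ** 2]


x = [1.0, 2.0, -1.0]
baseline = [0.0, 0.0, 0.0]
attr = integrated_gradients(grad_f, x, baseline, n_steps=50)

# Completeness check: attributions should sum to f(x) - f(baseline).
delta = sum(attr) - (f(x) - f(baseline))
```

For this polynomial the exact attributions are 2, 9, and -1, and the quadrature reproduces them up to floating-point error. The quantity `delta` is the kind of completeness gap that an option such as `compute_convergence_delta` would plausibly report: attributions summing to `f(x) - f(baseline)` indicates that the integral was approximated well.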
- _reduce_baselines_fn(td, in_keys)
- Parameters:
td (tensordict.TensorDict)
in_keys (List[tdhook._types.UnraveledKey])
- Return type:
tensordict.TensorDict
- _grad_attr(grads, inputs)
- Parameters:
grads (tensordict.TensorDict)
inputs (tensordict.TensorDict)
- Return type:
tensordict.TensorDict
- static init_attr_targets_with_labels(outputs, additional_init_tensors, selected_out_keys, label_key='label')
- Parameters:
outputs (tensordict.TensorDict)
additional_init_tensors (tensordict.TensorDict)
selected_out_keys (List[tdhook._types.UnraveledKey])
label_key (tdhook._types.UnraveledKey)
- Return type:
tensordict.TensorDict
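This page gives no description of `init_attr_targets_with_labels`. Judging only by its name and its `label_key` parameter, it presumably builds the attribution target by picking, for each sample, the output entry at the index stored under `label_key`. The sketch below illustrates that selection step alone, under that assumption, using plain lists instead of TensorDicts; `select_by_label` is a hypothetical helper and not part of tdhook.

```python
def select_by_label(outputs, labels):
    """For each sample, keep the output score at that sample's label index."""
    return [row[label] for row, label in zip(outputs, labels)]


# Two samples with three class scores each (made-up numbers).
outputs = [[0.1, 2.0, -1.0], [1.5, 0.2, 0.3]]
labels = [1, 0]
targets = select_by_label(outputs, labels)
```

Attributing with respect to the labelled class score, rather than the full output vector, gives one scalar target per sample to differentiate.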