tdhook.weights.task_vectors
Classes
TaskVectorsContext | Base class for hooking contexts.
TaskVectorsModule | Wrapper to enhance a module with hooking capabilities.
TaskVectors | Task vectors [33].
Module Contents
- class tdhook.weights.task_vectors.TaskVectorsContext(*args, alphas, get_test_accuracy, get_control_adequacy, **kwargs)
Bases: tdhook.contexts.HookingContext
Base class for hooking contexts.
- Parameters:
alphas (Iterable[float])
get_test_accuracy (Callable[[torch.nn.Module], float])
get_control_adequacy (Callable[[torch.nn.Module], bool])
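For reference, in the usual task-arithmetic formulation a task vector is the parameter-wise difference between a fine-tuned and a pre-trained model, and an edited model is obtained by adding the vector back with a scaling coefficient. Reading each entry of `alphas` as such a coefficient is an assumption about this context, not documented behaviour:

```latex
\tau = \theta_{\mathrm{ft}} - \theta_{\mathrm{pre}}, \qquad
\theta_{\alpha} = \theta_{\mathrm{pre}} + \alpha\,\tau, \qquad
\alpha \in \texttt{alphas}
```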
- class tdhook.weights.task_vectors.TaskVectorsModule(*args, **kwargs)
Bases: tdhook.modules.HookedModule
Wrapper to enhance a module with hooking capabilities.
- get_task_vector(finetuned_module)
Compute the task vector.
- Parameters:
finetuned_module (torch.nn.Module)
- Return type:
tensordict.TensorDict
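A minimal sketch of the quantity involved, assuming the standard task-arithmetic definition (fine-tuned weights minus pre-trained weights, parameter by parameter). The helper `task_vector` and the plain-dictionary representation are hypothetical illustrations, not tdhook's implementation, which returns a `tensordict.TensorDict`:

```python
from typing import Any, Dict, Mapping


def task_vector(pretrained: Mapping[str, Any], finetuned: Mapping[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: per-parameter difference finetuned - pretrained.

    Values can be floats or tensors (anything supporting ``-``), e.g. the
    entries of ``module.state_dict()``.
    """
    if set(pretrained) != set(finetuned):
        raise ValueError("pre-trained and fine-tuned models must share parameter names")
    return {name: finetuned[name] - pretrained[name] for name in pretrained}


# Toy example with scalar "parameters".
pre = {"w": 1.0, "b": 0.5}
ft = {"w": 1.5, "b": 0.25}
print(task_vector(pre, ft))  # {'w': 0.5, 'b': -0.25}
```

Because the helper only relies on subtraction, the same code works unchanged on tensor-valued state dicts.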
- get_forget_vector(finetuned_module)
Compute the forget vector.
- Parameters:
finetuned_module (torch.nn.Module)
- Return type:
tensordict.TensorDict
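A sketch under one assumption: that the forget vector follows the task-negation idea, i.e. it is the negated task vector, so that adding it to the pre-trained weights pushes the model away from the fine-tuned task. tdhook may define it differently; `forget_vector` is a hypothetical name:

```python
from typing import Any, Dict, Mapping


def forget_vector(pretrained: Mapping[str, Any], finetuned: Mapping[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: negated task vector, i.e. pretrained - finetuned.

    Assumption: "forget vector" means task negation. Adding this vector
    (scaled by a positive coefficient) to the pre-trained weights moves
    them away from the fine-tuned task.
    """
    if set(pretrained) != set(finetuned):
        raise ValueError("pre-trained and fine-tuned models must share parameter names")
    return {name: -(finetuned[name] - pretrained[name]) for name in pretrained}


pre = {"w": 1.0, "b": 0.5}
ft = {"w": 1.5, "b": 0.25}
print(forget_vector(pre, ft))  # {'w': -0.5, 'b': 0.25}
```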
- class tdhook.weights.task_vectors.TaskVectors(alphas, get_test_accuracy, get_control_adequacy)
Bases: tdhook.contexts.HookingContextFactory
Task vectors [33].
- Parameters:
alphas (Iterable[float])
get_test_accuracy (Callable[[torch.nn.Module], float])
get_control_adequacy (Callable[[torch.nn.Module], bool])
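One plausible reading of how `alphas`, `get_test_accuracy` and `get_control_adequacy` fit together, stated as an assumption rather than documented behaviour: each coefficient is tried in turn, the vector is added to the pre-trained weights scaled by it, candidates failing the control check are discarded, and the coefficient with the best test accuracy is kept. The sketch works on plain parameter dictionaries instead of `torch.nn.Module` objects, and `apply_vector` and `select_alpha` are hypothetical names:

```python
from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Tuple

Params = Dict[str, Any]


def apply_vector(pretrained: Mapping[str, Any], vector: Mapping[str, Any], alpha: float) -> Params:
    """Return pretrained + alpha * vector, parameter by parameter."""
    return {name: pretrained[name] + alpha * vector[name] for name in pretrained}


def select_alpha(
    pretrained: Mapping[str, Any],
    vector: Mapping[str, Any],
    alphas: Iterable[float],
    get_test_accuracy: Callable[[Params], float],
    get_control_adequacy: Callable[[Params], bool],
) -> Optional[Tuple[float, Params]]:
    """Hypothetical sweep: best-accuracy alpha among candidates passing the control check.

    Assumption: a candidate failing get_control_adequacy is rejected outright.
    Returns (alpha, edited_params), or None if no candidate passes.
    """
    best: Optional[Tuple[float, Params]] = None
    best_acc = float("-inf")
    for alpha in alphas:
        candidate = apply_vector(pretrained, vector, alpha)
        if not get_control_adequacy(candidate):
            continue  # discard edits that break the control behaviour
        acc = get_test_accuracy(candidate)
        if acc > best_acc:  # strict ">" keeps the earlier alpha on ties
            best, best_acc = (alpha, candidate), acc
    return best


# Toy example: accuracy peaks at w = 0.5, control requires |w| <= 0.75.
result = select_alpha(
    pretrained={"w": 0.0},
    vector={"w": 1.0},
    alphas=[0.0, 0.25, 0.5, 1.0],
    get_test_accuracy=lambda p: -(p["w"] - 0.5) ** 2,
    get_control_adequacy=lambda p: abs(p["w"]) <= 0.75,
)
print(result)  # (0.5, {'w': 0.5})
```

Here alpha = 1.0 would overshoot the control bound and is skipped, so the sweep settles on 0.5. Starting from `float("-inf")` guarantees that the first adequate candidate is always recorded.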