tdhook.weights.task_vectors

Classes

TaskVectorsContext

Hooking context for task vectors; holds the candidate alphas and the two evaluation callbacks.

TaskVectorsModule

Hooked module wrapper that computes task and forget vectors and applies them to the wrapped module.

TaskVectors

Task vectors [33].

Module Contents

class tdhook.weights.task_vectors.TaskVectorsContext(*args, alphas, get_test_accuracy, get_control_adequacy, **kwargs)

Bases: tdhook.contexts.HookingContext

Hooking context for task vectors; holds the candidate alphas and the two evaluation callbacks.

Parameters:
  • alphas (Iterable[float])

  • get_test_accuracy (Callable[[torch.nn.Module], float])

  • get_control_adequacy (Callable[[torch.nn.Module], bool])

alphas
get_test_accuracy
get_control_adequacy

compute_alpha(vector)

Compute the scaling coefficient alpha to use with the given vector.

Parameters:

vector (tensordict.TensorDict)

Return type:

float
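The source does not say how compute_alpha chooses a value. As an illustration only, the sketch below shows one plausible selection rule: among the candidate alphas, keep those that pass a control check and return the one with the highest score. The function select_alpha and its callbacks are hypothetical. They take alpha directly so the example is self-contained, whereas the documented callbacks receive a torch.nn.Module. Treating a higher score as better is also an assumption.

```python
from typing import Callable, Iterable, Optional


def select_alpha(
    alphas: Iterable[float],
    evaluate: Callable[[float], float],
    is_adequate: Callable[[float], bool],
) -> Optional[float]:
    """Return the adequate candidate with the highest score, or None.

    Hypothetical helper: not the tdhook implementation.
    """
    best_alpha, best_score = None, float("-inf")
    for alpha in alphas:
        if not is_adequate(alpha):
            continue  # skip coefficients that fail the control check
        score = evaluate(alpha)
        if score > best_score:
            best_alpha, best_score = alpha, score
    return best_alpha


# Toy stand-ins: the score peaks at alpha = 0.5, the control check fails above 0.8.
chosen = select_alpha(
    alphas=[0.0, 0.25, 0.5, 0.75, 1.0],
    evaluate=lambda a: 1.0 - abs(a - 0.5),
    is_adequate=lambda a: a <= 0.8,
)
```

If no candidate passes the control check, this sketch returns None; how the library handles that case is not stated in the source.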

class tdhook.weights.task_vectors.TaskVectorsModule(*args, **kwargs)

Bases: tdhook.modules.HookedModule

Hooked module wrapper that computes task and forget vectors and applies them to the wrapped module.

_weights

get_task_vector(finetuned_module)

Compute the task vector of finetuned_module relative to the wrapped module.

Parameters:

finetuned_module (torch.nn.Module)

Return type:

tensordict.TensorDict
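In the task-arithmetic sense, a task vector is usually the element-wise difference between fine-tuned and pretrained weights. The sketch below illustrates that definition with plain dictionaries of floats standing in for a TensorDict. The helper task_vector and the parameter names are hypothetical, and the source does not confirm that the library computes the vector exactly this way.

```python
from typing import Dict, List

Weights = Dict[str, List[float]]


def task_vector(pretrained: Weights, finetuned: Weights) -> Weights:
    """Element-wise difference (finetuned - pretrained) for each parameter."""
    return {
        name: [f - p for f, p in zip(finetuned[name], pretrained[name])]
        for name in pretrained
    }


# Toy weights; the parameter names are made up for the example.
pretrained = {"linear.weight": [1.0, 2.0], "linear.bias": [0.5]}
finetuned = {"linear.weight": [1.5, 1.0], "linear.bias": [0.5]}
vector = task_vector(pretrained, finetuned)
```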

get_forget_vector(finetuned_module)

Compute the forget vector for finetuned_module.

Parameters:

finetuned_module (torch.nn.Module)

Return type:

tensordict.TensorDict
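The source does not define the forget vector. In task arithmetic, forgetting is commonly done by negating the task vector, so the sketch below assumes that reading. The helper forget_vector is hypothetical and uses plain dictionaries in place of a TensorDict.

```python
from typing import Dict, List

Weights = Dict[str, List[float]]


def forget_vector(pretrained: Weights, finetuned: Weights) -> Weights:
    """Negated task vector: (pretrained - finetuned) for each parameter.

    Assumption: "forget vector" means the negated task vector.
    """
    return {
        name: [p - f for p, f in zip(pretrained[name], finetuned[name])]
        for name in pretrained
    }


pretrained = {"w": [1.0, 2.0]}
finetuned = {"w": [1.5, 1.0]}
forget = forget_vector(pretrained, finetuned)
```

Adding such a vector to the pretrained weights moves them away from the fine-tuned solution rather than toward it.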

get_weights(*vectors, alpha=None)

Return the weights obtained by applying the given vectors, scaled by alpha.

Parameters:
  • vectors (tensordict.TensorDict)

  • alpha (Optional[float])

Return type:

tensordict.TensorDict
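A common way to combine task vectors is to add their sum, scaled by alpha, to the base weights. The sketch below illustrates that rule with plain dictionaries. The helper apply_vectors is hypothetical, and it takes an explicit numeric alpha because the source does not state what alpha=None resolves to.

```python
from typing import Dict, List

Weights = Dict[str, List[float]]


def apply_vectors(base: Weights, *vectors: Weights, alpha: float = 1.0) -> Weights:
    """Return base + alpha * sum(vectors), computed per parameter element."""
    return {
        name: [
            value + alpha * sum(vec[name][i] for vec in vectors)
            for i, value in enumerate(values)
        ]
        for name, values in base.items()
    }


base = {"w": [1.0, 2.0]}
vec_a = {"w": [0.5, -1.0]}
vec_b = {"w": [0.5, 1.0]}
new_weights = apply_vectors(base, vec_a, vec_b, alpha=0.5)
```

The function builds a new dictionary and leaves base untouched, which mirrors a method that returns weights instead of modifying the module in place.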

with_applied_vectors(*vectors, alpha=None)

Apply the given vectors to the model and yield the resulting module.

Parameters:
  • vectors (tensordict.TensorDict)

  • alpha (Optional[float])

Return type:

Generator[torch.nn.Module, None, None]
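The Generator return type suggests the method is meant to be used as a context manager, although the source does not say so explicitly. The sketch below shows the general pattern under that assumption: swap in the edited weights, yield the model, then restore the original weights on exit. ToyModel and with_swapped_weights are hypothetical stand-ins, not tdhook classes or functions.

```python
from contextlib import contextmanager
from typing import Dict, Iterator, List

Weights = Dict[str, List[float]]


class ToyModel:
    """Hypothetical stand-in for a module; it only stores its weights."""

    def __init__(self, weights: Weights) -> None:
        self.weights = weights


@contextmanager
def with_swapped_weights(model: ToyModel, new_weights: Weights) -> Iterator[ToyModel]:
    """Temporarily replace the model's weights, restoring them on exit."""
    original = model.weights
    model.weights = new_weights
    try:
        yield model
    finally:
        # Runs even if the body raises, so the edit never leaks out.
        model.weights = original


model = ToyModel({"w": [1.0, 2.0]})
with with_swapped_weights(model, {"w": [1.5, 2.0]}) as edited:
    inside = dict(edited.weights)  # weights seen while the edit is active
```

Restoring in a finally block keeps the original weights intact even when evaluation inside the with body fails.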

class tdhook.weights.task_vectors.TaskVectors(alphas, get_test_accuracy, get_control_adequacy)

Bases: tdhook.contexts.HookingContextFactory

Task vectors [33].

Parameters:
  • alphas (Iterable[float])

  • get_test_accuracy (Callable[[torch.nn.Module], float])

  • get_control_adequacy (Callable[[torch.nn.Module], bool])

_hooking_context_class
_hooked_module_class