# Steering Vectors
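This notebook shows how to modify a model's behavior with steering vectors. First, `ActivationAddition` extracts a steering vector from a pair of contrasting prompts. Then `SteeringVectors` adds that vector to GPT-2's activations, which changes the model's next-token prediction.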
## Setup
[1]:
import importlib.util

# Set to True to install tdhook from the GitHub `main` branch (instead of PyPI) when on Colab.
DEV = True

# find_spec on a dotted name imports the parent package ("google") first. If no `google`
# package is installed at all (common on local machines), that raises ModuleNotFoundError
# rather than returning None, so treat that case as "not on Colab".
try:
    IN_COLAB = importlib.util.find_spec("google.colab") is not None
except ModuleNotFoundError:
    IN_COLAB = False

if IN_COLAB:
    MODE = "colab-dev" if DEV else "colab"
else:
    MODE = "local"
[2]:
if MODE == "colab":
%pip install -q tdhook
elif MODE == "colab-dev":
!rm -r tdhook
!git clone https://github.com/Xmaster6y/tdhook -b main
%pip install -q ./tdhook
## Usage
[3]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from tensordict import TensorDict
from tdhook.latent import ActivationAddition, SteeringVectors
/home/docs/checkouts/readthedocs.org/user_builds/tdhook/checkouts/v0.1.2/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm
Load model and tokenizer
[4]:
# Load the tokenizer and causal LM from the same checkpoint so their vocabularies match.
MODEL_NAME = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
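Prepare inputs: a positive/negative prompt pair that defines the steering direction, and a base prompt whose continuation we will steer.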
[5]:
def _encode(text):
    """Tokenize ``text`` with the notebook's tokenizer, returning PyTorch tensors."""
    return tokenizer.encode(text, return_tensors="pt")


# Contrast pair: the steering direction is derived from these two prompts.
positive_inputs = _encode("I am rich.")
negative_inputs = _encode("I am poor.")

# Prompt whose next-token prediction is steered below.
base_inputs = _encode("I work as a")
Extract the steering vector (rich - poor) from the `transformer.h.7.mlp` activations on the two contrast prompts.
[6]:
# Run both contrast prompts through the hooked model. ActivationAddition writes its
# result under the ("steer", <module name>) key of the returned TensorDict.
with ActivationAddition(["transformer.h.7.mlp"]).prepare(model) as extractor:
    td = TensorDict(
        {
            ("positive", "input"): positive_inputs,
            ("negative", "input"): negative_inputs,
        },
        batch_size=1,
    )
    td = extractor(td)

# Reduce over the leading batch dimension (size 1, per batch_size=1) to get one vector.
steering_vector = td.get(("steer", "transformer.h.7.mlp")).sum(dim=0)
Define the steering function: it adds the steering vector, scaled by a coefficient, to the output of the hooked module.
[7]:
# Strength of the intervention; larger values push generations harder toward "rich".
STEERING_COEF = 4


def steer_fn(module_key, output, coef=STEERING_COEF):
    """Add the scaled steering vector to a hooked module's output.

    Args:
        module_key: Name of the hooked module (unused; part of the callback signature).
        output: The module's output activations.
        coef: Multiplier applied to ``steering_vector`` (defaults to ``STEERING_COEF``).

    Returns:
        ``output + coef * steering_vector``.
    """
    # NOTE(review): the vector is added by broadcasting. If steering_vector keeps a
    # per-token (seq, hidden) shape, the base prompt must have the same token count
    # as the contrast prompts. TODO confirm the shape returned by ActivationAddition.
    return output + coef * steering_vector
Apply steering during inference
[8]:
# Re-run the base prompt while the hooks call steer_fn on transformer.h.7.mlp's output.
with SteeringVectors(["transformer.h.7.mlp"], steer_fn=steer_fn).prepare(model) as steered_model:
    base_td = TensorDict({"input": base_inputs}, batch_size=1)
    td = steered_model(base_td)
Compare results: the most likely next token after the base prompt, with and without steering.
[9]:
# Greedy (argmax) prediction at the last position of the base prompt.
# The unsteered run happens after the hooked context has exited.
steered_logits = td.get(("output", "logits"))
original_logits = model(base_inputs)["logits"]
steered_token = steered_logits.argmax(dim=-1)[0, -1]
original_token = original_logits.argmax(dim=-1)[0, -1]

print(f"Steered: {tokenizer.decode(steered_token)}")  # expected: "pilot"
print(f"Original: {tokenizer.decode(original_token)}")  # expected: "writer"
Steered: pilot
Original: writer