# Steering Vectors

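This notebook demonstrates how to use steering vectors to modify model behavior. It extracts a vector from the contrast between the prompts "I am rich." and "I am poor." at GPT-2's layer-7 MLP. It then adds a scaled copy of that vector during inference and shows how the next-token prediction for "I work as a" changes.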

## Setup

[1]:
import importlib.util

# In Colab, True installs tdhook from the GitHub `main` branch; False installs the PyPI release.
DEV = True

try:
    # find_spec imports the parent package first, so it raises ModuleNotFoundError
    # (instead of returning None) when no top-level `google` package exists at all.
    IN_COLAB = importlib.util.find_spec("google.colab") is not None
except ModuleNotFoundError:
    IN_COLAB = False

if IN_COLAB:
    MODE = "colab-dev" if DEV else "colab"
else:
    MODE = "local"
[2]:
if MODE == "colab":
    %pip install -q tdhook
elif MODE == "colab-dev":
    !rm -r tdhook
    !git clone https://github.com/Xmaster6y/tdhook -b main
    %pip install -q ./tdhook

## Usage

[3]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from tensordict import TensorDict
from tdhook.latent import ActivationAddition, SteeringVectors
/home/docs/checkouts/readthedocs.org/user_builds/tdhook/checkouts/v0.1.2/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm

Load model and tokenizer

[4]:
# Both the tokenizer and the causal-LM weights come from the same Hub checkpoint.
checkpoint = "gpt2"
model = AutoModelForCausalLM.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

Prepare inputs

[5]:
# Token-id tensors (return_tensors="pt") for the contrastive pair and for the prompt to steer.
# NOTE(review): these three prompts appear to tokenize to the same length with GPT-2; if the
# extracted steering vector is per-position, that match is what lets it broadcast later — confirm.
positive_inputs, negative_inputs, base_inputs = (
    tokenizer.encode(prompt, return_tensors="pt")
    for prompt in ("I am rich.", "I am poor.", "I work as a")
)

Extract a steering vector (rich − poor) from GPT-2's layer-7 MLP

[6]:
# The contrastive pair goes in under the "positive" / "negative" sub-keys; the hooked run
# writes its result for the hooked module under ("steer", <module name>).
contrast_inputs = {
    ("positive", "input"): positive_inputs,
    ("negative", "input"): negative_inputs,
}
with ActivationAddition(["transformer.h.7.mlp"]).prepare(model) as hooked_model:
    td = hooked_model(TensorDict(contrast_inputs, batch_size=1))

# Reduce over dim 0 (the batch dimension of size 1 declared above) to get the vector to add.
steering_vector = td.get(("steer", "transformer.h.7.mlp")).sum(dim=0)

Define the steering function, which adds a scaled copy of the steering vector to the hooked module's output

[7]:
# Strength of the intervention: how many copies of the steering vector are added.
STEERING_COEFF = 4


def steer_fn(module_key, output):
    """Return the hooked module's output shifted by the scaled steering vector.

    Args:
        module_key: name of the hooked module; unused here, but part of the
            callback signature passed to ``SteeringVectors``.
        output: the module's output activation.

    Returns:
        ``output + STEERING_COEFF * steering_vector``, where ``steering_vector``
        is the module-level vector extracted earlier in the notebook.
    """
    return output + STEERING_COEFF * steering_vector

Apply steering during inference

[8]:
# Hook the same module as during extraction; steer_fn receives that module's output and
# presumably its return value replaces it for this forward pass (confirm against tdhook docs).
steering_hook = SteeringVectors(["transformer.h.7.mlp"], steer_fn=steer_fn)
with steering_hook.prepare(model) as hooked_model:
    td = hooked_model(TensorDict({"input": base_inputs}, batch_size=1))

Compare the greedy next-token prediction with and without steering

[9]:
# Greedy prediction for the next token: highest-scoring vocabulary entry at the last
# position of the first (only) sequence, from the steered run and from the plain model.
steered_logits = td.get(("output", "logits"))
original_logits = model(base_inputs)["logits"]
steered_token = steered_logits[0, -1].argmax(dim=-1)
original_token = original_logits[0, -1].argmax(dim=-1)

print(f"Steered: {tokenizer.decode(steered_token)}")  # rendered output: " pilot"
print(f"Original: {tokenizer.decode(original_token)}")  # rendered output: " writer"
Steered:  pilot
Original:  writer