Chess Value Saliency
This notebook demonstrates how to use tdhook’s Saliency to compute attribution maps for chess model predictions, showing which squares on the board are most important for the model’s decisions.
Setup
[1]:
import importlib.util

DEV = True

if importlib.util.find_spec("google.colab") is not None:
    MODE = "colab-dev" if DEV else "colab"
else:
    MODE = "local"
[2]:
if MODE == "colab":
    %pip install -q tdhook lczerolens
elif MODE == "colab-dev":
    !rm -rf tdhook
    !git clone https://github.com/Xmaster6y/tdhook -b main
    %pip install -q ./tdhook lczerolens
Imports
[3]:
from tensordict import TensorDict
from lczerolens import LczeroModel, LczeroBoard
from tdhook.attribution import Saliency
from IPython.display import HTML, display
Load Model and Set Up Board Position
[4]:
# Load a chess model from HuggingFace
model = LczeroModel.from_hf("lczerolens/maia-1100")
# Set up a chess position
fen = "5k2/2R5/1PQ5/2Pp1n2/5P2/2b1r3/3K2P1/8 w - - 11 42"
board = LczeroBoard(fen)
moves = "d2c2 f5d4 c2b1"
for move in moves.split(" "):
    board.push_uci(move)
print(f"Board position after moves: {moves}")
print(f"FEN: {board.fen()}")

# Run the model once; the resulting TensorDict is reused as input for the attribution cells
td = model(board)
Board position after moves: d2c2 f5d4 c2b1
FEN: 5k2/2R5/1PQ5/2Pp4/3n1P2/2b1r3/6P1/1K6 b - - 14 43
Compute Saliency for Best Move
We’ll compute which squares are most important for the model’s best move prediction.
[5]:
# Define a function that extracts the best move logit as the target for attribution
def best_logit_init_targets(td: TensorDict, _):
    policy = td["policy"]
    best_logit = policy.max(dim=-1).values
    return TensorDict(out=best_logit, batch_size=td.batch_size)

# Compute saliency
saliency_context = Saliency(init_attr_targets=best_logit_init_targets)
with saliency_context.prepare(model) as hooked_model:
    output = hooked_model(td)

# Get the best move
move = board.decode_move(output["policy"][0].argmax())
arrows = [(move.from_square, move.to_square)]
print(f"Best move: {move}")

# Get attribution for board squares
batch_index = 0
plane = 1
svg_board, svg_colorbar = board.render_heatmap(
    output.get(("attr", "board"))[batch_index, plane].view(64).detach(), arrows=arrows, normalise="abs"
)
display(HTML(f"{svg_board}{svg_colorbar}"))
Best move: d4c6
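To make the idea behind the cell above concrete, here is a minimal sketch of a saliency map as the gradient of a scalar target (the best logit) with respect to each input. It uses a hypothetical three-input linear "policy head" and finite differences in plain Python; none of these names come from tdhook or lczerolens, and the real network is of course not linear.

```python
# Toy illustration only: a hypothetical linear "policy head", not the Maia network.

def logits(weights, x):
    """One logit per move: the dot product of each weight row with the input."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weights]


def best_logit(weights, x):
    """Scalar attribution target: the largest logit (the 'best move')."""
    return max(logits(weights, x))


def saliency(weights, x, eps=1e-6):
    """Approximate d(best_logit)/dx with central finite differences."""
    grads = []
    for i in range(len(x)):
        up = x[:i] + [x[i] + eps] + x[i + 1:]
        down = x[:i] + [x[i] - eps] + x[i + 1:]
        grads.append((best_logit(weights, up) - best_logit(weights, down)) / (2 * eps))
    return grads


weights = [[1.0, -2.0, 0.5], [0.0, 3.0, -1.0]]  # two moves, three input features
x = [1.0, 1.0, 1.0]
grads = saliency(weights, x)
print([round(g, 4) for g in grads])
```

In this toy case the second move has the largest logit, so the saliency equals that move's weight row: the input with the largest absolute gradient is the one the target is most sensitive to.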
Compute Saliency for Win/Draw/Lose Predictions
Now we’ll compute saliency for the model’s win/draw/lose (WDL) predictions to see which squares influence the outcome evaluation.
[6]:
# Helper function to create init_targets for WDL predictions
def get_init_targets(idx: int):
    def init_targets(td, _):
        return TensorDict(out=td["wdl"][..., idx], batch_size=td.batch_size)

    return init_targets

batch_index = 0
plane = 1

# Compute saliency for win, draw, and lose predictions
for idx, name in enumerate(["win", "draw", "lose"]):
    print(f"\nComputing {name} saliency...")
    saliency_context = Saliency(init_attr_targets=get_init_targets(idx))
    with saliency_context.prepare(model) as hooked_model:
        output = hooked_model(td)
    wdl_value = output["wdl"][0, idx].item()
    print(f"{name.capitalize()} probability: {wdl_value:.2f}")
    # Get attribution for board squares
    svg_board, svg_colorbar = board.render_heatmap(
        output.get(("attr", "board"))[batch_index, plane].view(64).detach(), normalise="abs"
    )
    display(HTML(f"{svg_board}{svg_colorbar}"))
Computing win saliency...
Win probability: 0.54
Computing draw saliency...
Draw probability: 0.12
Computing lose saliency...
Lose probability: 0.34
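The three probabilities printed above sum to 1. If the WDL head ends in a softmax (an assumption here, not something this notebook verifies), the gradients of the three outputs with respect to any input must sum to zero, so the win, draw, and lose maps are not independent. The sketch below checks this on a hypothetical two-input toy model in plain Python.

```python
import math

# Toy illustration only: a hypothetical linear layer followed by softmax.

def softmax(z):
    """Numerically stable softmax over a list of floats."""
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def wdl(weights, x):
    """Toy win/draw/lose probabilities for input x."""
    z = [sum(w * xi for w, xi in zip(row, x)) for row in weights]
    return softmax(z)


def grad_of_output(weights, x, idx, eps=1e-6):
    """Central finite-difference gradient of output `idx` with respect to x."""
    grads = []
    for i in range(len(x)):
        up = x[:i] + [x[i] + eps] + x[i + 1:]
        down = x[:i] + [x[i] - eps] + x[i + 1:]
        grads.append((wdl(weights, up)[idx] - wdl(weights, down)[idx]) / (2 * eps))
    return grads


weights = [[0.5, -1.0], [0.2, 0.3], [-0.7, 0.4]]  # rows: win, draw, lose
x = [1.0, 2.0]
maps = [grad_of_output(weights, x, idx) for idx in range(3)]

# For each input, add up its gradient across the three outputs.
column_sums = [sum(m[i] for m in maps) for i in range(len(x))]
print(column_sums)  # each entry is zero up to floating-point error
```

Under that assumption, a square that pushes the win probability up must push draw and lose down by the same total amount, which is worth keeping in mind when comparing the three heatmaps.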
Understanding the Results
The saliency maps show which squares on the chess board are most important for the model’s predictions:
Best move saliency: Highlights squares that influence which move the model considers best
Win/Draw/Lose saliency: Shows which squares affect the model’s evaluation of the game outcome
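The attribution values indicate how much each square contributes to the prediction: a higher absolute value means greater importance. Positive values suggest the square increases the target value (e.g., makes a win more likely), while negative values suggest it decreases it. Note that each heatmap above shows the attribution for a single input plane (plane = 1) of the first batch element, rendered with normalise="abs"; it is not an aggregate over all input planes.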
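To read the numbers rather than the colours, the squares can be ranked by absolute attribution. The sketch below works on a plain list of 64 floats; the helper names are hypothetical, and the index convention (0 = a1, 1 = b1, ..., 63 = h8) is an assumption that may not match the orientation of the model's input planes, especially with Black to move.

```python
FILES = "abcdefgh"


def square_name(index):
    """Map a flat index to a square name, ASSUMING 0 = a1, 1 = b1, ..., 63 = h8."""
    return f"{FILES[index % 8]}{index // 8 + 1}"


def top_squares(attributions, k=3):
    """Return the k squares with the largest absolute attribution, keeping the sign."""
    if len(attributions) != 64:
        raise ValueError("expected one attribution value per square (64 values)")
    ranked = sorted(range(64), key=lambda i: abs(attributions[i]), reverse=True)
    return [(square_name(i), attributions[i]) for i in ranked[:k]]


# Made-up attribution values for illustration, not model output.
attr = [0.0] * 64
attr[27] = -0.9  # d4 under the assumed indexing
attr[42] = 0.4   # c6
attr[1] = 0.1    # b1
print(top_squares(attr))
```

With a real attribution tensor, the input would be something like the flattened 64-value plane converted to a list; the sign is kept so that strongly negative squares are not mistaken for supportive ones.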