Chess Value Saliency

This notebook demonstrates how to use tdhook’s Saliency to compute attribution maps for a Leela Chess Zero (lc0) model loaded with lczerolens. The maps show which squares of the board the model’s best-move prediction and its win/draw/lose evaluation are most sensitive to.

Setup

[1]:
import importlib.util

DEV = True  # on Colab, install tdhook from the GitHub main branch instead of PyPI

if importlib.util.find_spec("google.colab") is not None:
    MODE = "colab-dev" if DEV else "colab"
else:
    MODE = "local"
[2]:
if MODE == "colab":
    %pip install -q tdhook lczerolens
elif MODE == "colab-dev":
    !rm -rf tdhook
    !git clone https://github.com/Xmaster6y/tdhook -b main
    %pip install -q ./tdhook lczerolens

Imports

[3]:
from tensordict import TensorDict
from lczerolens import LczeroModel, LczeroBoard
from tdhook.attribution import Saliency
from IPython.display import HTML, display

Load Model and Set Up Board Position

[4]:
# Load a chess model from HuggingFace
model = LczeroModel.from_hf("lczerolens/maia-1100")

# Set up a chess position
fen = "5k2/2R5/1PQ5/2Pp1n2/5P2/2b1r3/3K2P1/8 w - - 11 42"
board = LczeroBoard(fen)

moves = "d2c2 f5d4 c2b1"
for move in moves.split(" "):
    board.push_uci(move)

print(f"Board position after moves: {moves}")
print(f"FEN: {board.fen()}")

td = model(board)
Board position after moves: d2c2 f5d4 c2b1
FEN: 5k2/2R5/1PQ5/2Pp4/3n1P2/2b1r3/6P1/1K6 b - - 14 43

Compute Saliency for Best Move

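The attribution target is the largest policy logit, that is, the logit of the move the model ranks highest. The resulting saliency map shows which squares that logit is most sensitive to.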
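Gradient saliency scores each input value by the partial derivative of the target with respect to it. Because the target in best_logit_init_targets is a maximum over logits, its gradient is (away from ties) the gradient of the single top-ranked move's logit, so the map explains that move only. The sketch below checks this on a hypothetical linear "policy" over one flattened 8x8 plane, using central finite differences in place of the autograd gradients a library like tdhook would presumably use; the toy model and helper names are illustrative and not part of tdhook or lczerolens.

```python
import random

N_SQUARES = 64  # one flattened 8x8 input plane
N_MOVES = 4     # toy policy size (a real lc0 policy head is much larger)

random.seed(0)
# Hypothetical linear "policy": logit_m(x) = sum_i W[m][i] * x[i]
W = [[random.uniform(-1.0, 1.0) for _ in range(N_SQUARES)] for _ in range(N_MOVES)]


def logits(x):
    return [sum(w_i * x_i for w_i, x_i in zip(row, x)) for row in W]


def best_logit(x):
    # Same kind of target as best_logit_init_targets: the largest logit
    return max(logits(x))


def saliency(f, x, eps=1e-4):
    """Central finite-difference gradient of scalar f at x, one value per square."""
    grad = []
    for i in range(len(x)):
        up = list(x)
        up[i] += eps
        down = list(x)
        down[i] -= eps
        grad.append((f(up) - f(down)) / (2 * eps))
    return grad


x = [random.choice([0.0, 1.0]) for _ in range(N_SQUARES)]  # a binary input plane
scores = logits(x)
best = scores.index(max(scores))
grad = saliency(best_logit, x)

# For a linear model, the saliency of the max logit equals the weights of the argmax move
assert all(abs(g - w) < 1e-6 for g, w in zip(grad, W[best]))
print(f"best move index: {best}")
```

For a linear model the gradient of each logit is just its weight row, which is why the check compares against W[best]. A real network is nonlinear, so its saliency only describes how the top logit reacts to small changes around the given position.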

[5]:
# Define a function that extracts the best move logit as the target for attribution
def best_logit_init_targets(td: TensorDict, _):
    policy = td["policy"]
    best_logit = policy.max(dim=-1).values
    return TensorDict(out=best_logit, batch_size=td.batch_size)


# Compute saliency
saliency_context = Saliency(init_attr_targets=best_logit_init_targets)
with saliency_context.prepare(model) as hooked_model:
    output = hooked_model(td)

    # Get the best move
    move = board.decode_move(output["policy"][0].argmax())
    arrows = [(move.from_square, move.to_square)]
    print(f"Best move: {move}")

    # Get attribution for board squares
    # Index into the attribution tensor: first batch entry, input plane 1
    batch_index = 0
    plane = 1

    svg_board, svg_colorbar = board.render_heatmap(
        output.get(("attr", "board"))[batch_index, plane].view(64).detach(), arrows=arrows, normalise="abs"
    )
    display(HTML(f"{svg_board}{svg_colorbar}"))
Best move: d4c6
[Figure: best-move saliency heatmap (plane 1) with the move d4c6 drawn as an arrow, followed by its colorbar]

Compute Saliency for Win/Draw/Lose Predictions

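Next, we compute saliency for each of the model’s win/draw/lose (WDL) outputs in turn to see which squares influence its evaluation of the game outcome. Black is to move in this position, and lc0 networks evaluate from the perspective of the side to move, so "win" should be read as a Black win.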
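Because the three WDL outputs are probabilities that sum to one (0.54 + 0.12 + 0.34 in the run below), their gradients with respect to any input must sum to zero: at every square, the win, draw and lose saliency values cancel, so the three maps are not independent. A minimal sketch, assuming for illustration only a hypothetical WDL head made of three linear logits followed by a softmax; the cancellation itself relies only on the outputs summing to one.

```python
import math
import random

random.seed(1)
N_SQUARES = 64
# Hypothetical WDL head: three linear logits followed by a softmax
W = [[random.uniform(-0.5, 0.5) for _ in range(N_SQUARES)] for _ in range(3)]


def wdl(x):
    z = [sum(w_i * x_i for w_i, x_i in zip(row, x)) for row in W]
    m = max(z)  # subtract the max for numerical stability
    e = [math.exp(v - m) for v in z]
    total = sum(e)
    return [v / total for v in e]


def saliency(idx, x, eps=1e-5):
    """Central finite-difference gradient of wdl(x)[idx], one value per square."""
    grad = []
    for i in range(len(x)):
        up = list(x)
        up[i] += eps
        down = list(x)
        down[i] -= eps
        grad.append((wdl(up)[idx] - wdl(down)[idx]) / (2 * eps))
    return grad


x = [random.choice([0.0, 1.0]) for _ in range(N_SQUARES)]
maps = {name: saliency(idx, x) for idx, name in enumerate(["win", "draw", "lose"])}

# Win + draw + lose saliency cancels square by square
residual = max(abs(sum(m[i] for m in maps.values())) for i in range(N_SQUARES))
assert residual < 1e-8
print(f"largest per-square |win + draw + lose| saliency: {residual:.1e}")
```

In practice, the draw map carries no information beyond the other two: square by square it equals -(win + lose).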

[6]:
# Helper function to create init_targets for WDL predictions
def get_init_targets(idx: int):
    def init_targets(td, _):
        return TensorDict(out=td["wdl"][..., idx], batch_size=td.batch_size)

    return init_targets


batch_index = 0
plane = 1

# Compute saliency for win, draw, and lose predictions
for idx, name in enumerate(["win", "draw", "lose"]):
    print(f"\nComputing {name} saliency...")
    saliency_context = Saliency(init_attr_targets=get_init_targets(idx))
    with saliency_context.prepare(model) as hooked_model:
        output = hooked_model(td)
        wdl_value = output["wdl"][0, idx].item()
        print(f"{name.capitalize()} probability: {wdl_value:.2f}")

        # Get attribution for board squares
        svg_board, svg_colorbar = board.render_heatmap(
            output.get(("attr", "board"))[batch_index, plane].view(64).detach(), normalise="abs"
        )
        display(HTML(f"{svg_board}{svg_colorbar}"))

Computing win saliency...
Win probability: 0.54
[Figure: win saliency heatmap (plane 1), followed by its colorbar]

Computing draw saliency...
Draw probability: 0.12
[Figure: draw saliency heatmap (plane 1), followed by its colorbar]

Computing lose saliency...
Lose probability: 0.34
[Figure: lose saliency heatmap (plane 1), followed by its colorbar]

Understanding the Results

The saliency maps show which squares on the chess board are most important for the model’s predictions:

  • Best move saliency: Highlights the squares that the logit of the top-ranked move (here d4c6) is most sensitive to

  • Win/Draw/Lose saliency: Shows which squares affect the model’s evaluation of the game outcome

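Saliency attributions are gradients: each value measures how sensitive the target is to the selected input plane (plane = 1 above) at that square, with higher absolute values indicating greater sensitivity. The sign gives the direction of the local effect: a positive value means that increasing the input at that square would increase the target (e.g., make a win more likely), while a negative value means it would decrease it. Because gradients describe only small changes around the current position, they indicate sensitivity rather than an exact share of the prediction.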
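To read the strongest attributions off as numbers rather than colours, the per-square values can be ranked by absolute value. The sketch below works on plain nested lists shaped [batch][plane][8][8] to mirror output.get(("attr", "board"))[batch_index, plane].view(64). That shape, the hypothetical helper top_squares, and the mapping from flat index to square name (python-chess order, a1 = 0 ... h8 = 63) are all assumptions. lc0-style encodings present the position from the side to move, so the rows may be mirrored when Black is to move; check the mapping against the rendered heatmap before relying on it.

```python
FILES = "abcdefgh"


def square_name(index):
    """Assumed python-chess ordering: a1 = 0, b1 = 1, ..., h8 = 63."""
    rank, file = divmod(index, 8)
    return f"{FILES[file]}{rank + 1}"


def top_squares(attr, batch_index, plane, k=3):
    """Hypothetical helper: the k squares with the largest |attribution| in one plane."""
    # Nested-list equivalent of attr[batch_index, plane].view(64)
    flat = [v for row in attr[batch_index][plane] for v in row]
    order = sorted(range(64), key=lambda i: abs(flat[i]), reverse=True)
    return [(square_name(i), flat[i]) for i in order[:k]]


# Toy attribution tensor: 1 batch entry, 2 planes of 8x8 zeros
attr = [[[[0.0] * 8 for _ in range(8)] for _ in range(2)]]
attr[0][1][3][3] = -0.9  # d4 under the assumed ordering
attr[0][1][5][2] = 0.4   # c6
attr[0][1][0][1] = 0.1   # b1
print(top_squares(attr, batch_index=0, plane=1))
# → [('d4', -0.9), ('c6', 0.4), ('b1', 0.1)]
```

Ranking by absolute value matches the "greater importance" reading above, while the returned values keep their sign so the direction of each effect is not lost.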