TorchRL PPO Action Probing#

This notebook demonstrates how to set up a TorchRL PPO agent and use tdhook to probe action representations.

Setup#

[1]:
import importlib.util

# Set DEV to False to install the released tdhook package on Colab instead of
# cloning its main branch (see the install cell that follows).
DEV = True


def _running_on_colab() -> bool:
    """Return True when the ``google.colab`` package can be located.

    ``importlib.util.find_spec`` imports the parent package of a dotted name,
    so on machines without any ``google`` package it raises
    ``ModuleNotFoundError`` instead of returning ``None``. That situation
    simply means the notebook is not running on Colab.
    """
    try:
        return importlib.util.find_spec("google.colab") is not None
    except ModuleNotFoundError:
        return False


# MODE is one of "colab-dev", "colab" or "local".
if _running_on_colab():
    MODE = "colab-dev" if DEV else "colab"
else:
    MODE = "local"
[2]:
# Install dependencies only when MODE says we are on Google Colab; nothing is
# installed when MODE == "local".
if MODE == "colab":
    # Install tdhook and torchrl with pip.
    %pip install -q tdhook torchrl
elif MODE == "colab-dev":
    # Development install: remove any previous checkout, clone the main branch
    # of tdhook, then install from that local source tree.
    !rm -rf tdhook
    !git clone https://github.com/Xmaster6y/tdhook -b main
    %pip install -q ./tdhook torchrl

Imports#

[3]:
import torch
from torch import nn
from torchrl.envs import TransformedEnv, Compose, DoubleToFloat, StepCounter
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import MLP, ProbabilisticActor, NormalParamExtractor, TanhNormal
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE
from tensordict.nn import TensorDictModule
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score

from tdhook.latent.probing import Probing, ProbeManager

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

Create Environment#

[4]:
# Gym task wrapped as a TorchRL environment.
env_name = "InvertedDoublePendulum-v4"
base_env = GymEnv(env_name)

# Transforms stacked on top of the raw environment, applied in this order:
# DoubleToFloat first, then StepCounter.
env_transforms = Compose(DoubleToFloat(), StepCounter())
env = TransformedEnv(base_env, env_transforms)

# Show the specs the networks below are sized from.
print(f"Observation space: {env.observation_spec}")
print(f"Action space: {env.action_spec}")
Observation space: Composite(
    observation: UnboundedContinuous(
        shape=torch.Size([11]),
        space=ContinuousBox(
            low=Tensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, contiguous=True),
            high=Tensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, contiguous=True)),
        device=cpu,
        dtype=torch.float32,
        domain=continuous),
    step_count: BoundedDiscrete(
        shape=torch.Size([1]),
        space=ContinuousBox(
            low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, contiguous=True),
            high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, contiguous=True)),
        device=cpu,
        dtype=torch.int64,
        domain=discrete),
    device=None,
    shape=torch.Size([]),
    data_cls=None)
Action space: BoundedContinuous(
    shape=torch.Size([1]),
    space=ContinuousBox(
        low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
        high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
    device=cpu,
    dtype=torch.float32,
    domain=continuous)

Create Actor and Critic Networks#

[5]:
hidden_size = 32
num_cells = 6

# Input/output sizes are read from the environment specs.
obs_dim = env.observation_spec["observation"].shape[-1]
action_dim = env.action_spec.shape[-1]

# Actor network: the MLP emits 2 * action_dim values, which
# NormalParamExtractor turns into the "loc" and "scale" entries.
actor_backbone = MLP(
    in_features=obs_dim,
    out_features=2 * action_dim,
    num_cells=[hidden_size] * num_cells,
)
actor_module = TensorDictModule(
    nn.Sequential(actor_backbone, NormalParamExtractor()),
    in_keys=["observation"],
    out_keys=["loc", "scale"],
)

# Policy: builds a TanhNormal distribution from "loc"/"scale", bounded by the
# limits of the action spec, and returns the log-probability as well.
action_bounds = {
    "low": env.action_spec.space.low,
    "high": env.action_spec.space.high,
}
actor = ProbabilisticActor(
    module=actor_module,
    spec=env.action_spec,
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
    distribution_kwargs=action_bounds,
    return_log_prob=True,
)

# Critic network: same hidden layout as the actor, one scalar output written
# to "state_value".
critic_backbone = MLP(
    in_features=obs_dim,
    out_features=1,
    num_cells=[hidden_size] * num_cells,
)
critic = TensorDictModule(
    critic_backbone,
    in_keys=["observation"],
    out_keys=["state_value"],
)

print("Actor network created")
print("Critic network created")
Actor network created
Critic network created

Create PPO Loss Module#

[6]:
# Generalized Advantage Estimation driven by the critic defined above.
gae_settings = dict(
    gamma=0.99,
    lmbda=0.95,
    average_gae=True,
    device=device,
    deactivate_vmap=True,
)
advantage_module = GAE(value_network=critic, **gae_settings)

# Clipped PPO objective over the actor and critic; this is the module that
# gets hooked for probing later in the notebook.
ppo_settings = dict(
    clip_epsilon=0.2,
    entropy_bonus=True,
    entropy_coeff=1e-4,
    critic_coeff=1.0,
    loss_critic_type="smooth_l1",
    functional=False,
)
loss_module = ClipPPOLoss(
    actor_network=actor,
    critic_network=critic,
    **ppo_settings,
)

print("PPO loss module created")
PPO loss module created

Collect Sample Data#

[7]:
from torchrl.collectors import SyncDataCollector

# A single batch is enough for this demo, so the collector's total frame
# budget equals one batch.
frames_to_collect = 100
collector = SyncDataCollector(
    env,
    actor,
    frames_per_batch=frames_to_collect,
    total_frames=frames_to_collect,
)

# Pull that one batch, then run GAE on it; the return value is not needed
# because the advantage entries are written into the batch itself.
batch = next(iter(collector))
advantage_module(batch)

print(f"Batch shape: {batch.shape}")
print(f"Batch keys: {batch.keys(True, True)}")
Batch shape: torch.Size([100])
Batch keys: _TensorDictKeysView(['step_count', 'action', 'observation', 'done', 'terminated', 'truncated', ('next', 'observation'), ('next', 'step_count'), ('next', 'reward'), ('next', 'done'), ('next', 'terminated'), ('next', 'truncated'), ('next', 'state_value'), ('collector', 'traj_ids'), 'loc', 'scale', 'action_log_prob', 'state_value', 'advantage', 'value_target'],
    include_nested=True,
    leaves_only=True)

Set Up Action Probing#

[8]:
# Split data into train and test (80% / 20%).
# A dedicated, seeded generator makes the shuffle repeatable from run to run
# without altering PyTorch's global random state.
SPLIT_SEED = 0
TRAIN_FRACTION = 0.8

num_samples = batch.numel()
split_rng = torch.Generator().manual_seed(SPLIT_SEED)
indices = torch.randperm(num_samples, generator=split_rng)
split_idx = int(TRAIN_FRACTION * num_samples)
train_indices, test_indices = indices[:split_idx], indices[split_idx:]
train_batch = batch[train_indices]
test_batch = batch[test_indices]

# Every sample must land in exactly one of the two splits.
assert train_batch.numel() + test_batch.numel() == num_samples

# Create probe manager: linear-regression probes scored with R².
# The empty dict is presumably the keyword arguments forwarded to each
# LinearRegression instance -- TODO confirm against the tdhook API.
# Note that r2_score expects (y_true, y_pred), hence labels come first.
probe_manager = ProbeManager(
    LinearRegression,
    {},
    lambda preds, labels: {"r2": r2_score(labels, preds)},
)

Run Probing on Actor and Critic Layers#

[9]:
# Hook into actor and critic layers to probe action representations.
# The layer pattern is a raw string so that "\d" reaches the regex engine
# unchanged; in a normal string literal it is an invalid escape sequence that
# newer Python versions warn about.
with Probing(
    r"td_module.(critic_network.module.\d+|actor_network.module.0.module.0.\d+)",
    probe_manager.probe_factory,
    additional_keys=["labels", "step_type"],
    relative=False,
).prepare(loss_module) as hooked_module:
    # Fit probes on training data: the sampled actions are the regression
    # targets, and "step_type" marks this pass as a fitting pass.
    train_batch["labels"] = train_batch["action"]
    train_batch["step_type"] = "fit"
    hooked_module(train_batch)

    # Evaluate probes on test data: same targets, but the pass is marked as
    # prediction so the held-out frames are only used for scoring.
    test_batch["labels"] = test_batch["action"]
    test_batch["step_type"] = "predict"
    hooked_module(test_batch)

Display Results#

[10]:
# Print one section per split: a heading followed by the R² of every probed
# layer, formatted to three decimals.
report_sections = (
    ("Training R² scores:", probe_manager.fit_metrics),
    ("\nTest R² scores:", probe_manager.predict_metrics),
)
for heading, metrics in report_sections:
    print(heading)
    for key, value in metrics.items():
        print(f"  {key}: {value['r2']:.3f}")
Training R² scores:
  td_module.actor_network.module.0.module.0.0_fwd: 0.057
  td_module.actor_network.module.0.module.0.1_fwd: 0.457
  td_module.actor_network.module.0.module.0.2_fwd: 0.457
  td_module.actor_network.module.0.module.0.3_fwd: 0.424
  td_module.actor_network.module.0.module.0.4_fwd: 0.424
  td_module.actor_network.module.0.module.0.5_fwd: 0.392
  td_module.actor_network.module.0.module.0.6_fwd: 0.391
  td_module.actor_network.module.0.module.0.7_fwd: 0.475
  td_module.actor_network.module.0.module.0.8_fwd: 0.429
  td_module.actor_network.module.0.module.0.9_fwd: 0.399
  td_module.actor_network.module.0.module.0.10_fwd: 0.376
  td_module.actor_network.module.0.module.0.11_fwd: 0.377
  td_module.actor_network.module.0.module.0.12_fwd: 0.003
  td_module.critic_network.module.0_fwd: 0.057
  td_module.critic_network.module.1_fwd: 0.380
  td_module.critic_network.module.2_fwd: 0.380
  td_module.critic_network.module.3_fwd: 0.315
  td_module.critic_network.module.4_fwd: 0.314
  td_module.critic_network.module.5_fwd: 0.318
  td_module.critic_network.module.6_fwd: 0.316
  td_module.critic_network.module.7_fwd: 0.333
  td_module.critic_network.module.8_fwd: 0.307
  td_module.critic_network.module.9_fwd: 0.307
  td_module.critic_network.module.10_fwd: 0.288
  td_module.critic_network.module.11_fwd: 0.337
  td_module.critic_network.module.12_fwd: 0.013

Test R² scores:
  td_module.actor_network.module.0.module.0.0_fwd: 0.160
  td_module.actor_network.module.0.module.0.1_fwd: -0.743
  td_module.actor_network.module.0.module.0.2_fwd: -0.743
  td_module.actor_network.module.0.module.0.3_fwd: -2.509
  td_module.actor_network.module.0.module.0.4_fwd: -2.506
  td_module.actor_network.module.0.module.0.5_fwd: -2.913
  td_module.actor_network.module.0.module.0.6_fwd: -2.848
  td_module.actor_network.module.0.module.0.7_fwd: -2.626
  td_module.actor_network.module.0.module.0.8_fwd: -3.136
  td_module.actor_network.module.0.module.0.9_fwd: -2.513
  td_module.actor_network.module.0.module.0.10_fwd: -2.363
  td_module.actor_network.module.0.module.0.11_fwd: -1.602
  td_module.actor_network.module.0.module.0.12_fwd: 0.038
  td_module.critic_network.module.0_fwd: 0.160
  td_module.critic_network.module.1_fwd: -2.264
  td_module.critic_network.module.2_fwd: -2.264
  td_module.critic_network.module.3_fwd: -1.375
  td_module.critic_network.module.4_fwd: -1.235
  td_module.critic_network.module.5_fwd: -1.553
  td_module.critic_network.module.6_fwd: -1.480
  td_module.critic_network.module.7_fwd: -0.289
  td_module.critic_network.module.8_fwd: -0.219
  td_module.critic_network.module.9_fwd: -0.883
  td_module.critic_network.module.10_fwd: -0.278
  td_module.critic_network.module.11_fwd: -1.791
  td_module.critic_network.module.12_fwd: 0.120

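Note: The R² scores shown above are expected to be poor (often negative on the test split) because the model is not trained. The actor and critic networks are initialized with random weights, so their internal representations do not yet encode meaningful information about actions. The probes are also fit on only 80 of the 100 collected frames, so a moderate training R² combined with a negative test R² indicates that the probes overfit rather than recover a real signal. After training the PPO agent, you would expect to see higher R² scores, indicating that the network layers learn to represent action-relevant information.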