TorchRL PPO Action Probing#
This notebook demonstrates how to set up a TorchRL PPO agent and use tdhook to probe action representations.
Setup#
[1]:
import importlib.util

# Set to False to select the "colab" mode (released packages) instead of
# "colab-dev" (fresh clone of the repository) when running on Colab.
DEV = True

# For a dotted name, find_spec imports the parent package first. On a machine
# with no `google` package at all it therefore raises ModuleNotFoundError
# instead of returning None, so that case is treated as "not on Colab".
try:
    _on_colab = importlib.util.find_spec("google.colab") is not None
except ModuleNotFoundError:
    _on_colab = False

# MODE is one of "colab-dev", "colab" or "local" and drives the install cell.
if _on_colab:
    MODE = "colab-dev" if DEV else "colab"
else:
    MODE = "local"
[2]:
if MODE == "colab":
%pip install -q tdhook torchrl
elif MODE == "colab-dev":
!rm -rf tdhook
!git clone https://github.com/Xmaster6y/tdhook -b main
%pip install -q ./tdhook torchrl
Imports#
[3]:
import torch
from torch import nn
from torchrl.envs import TransformedEnv, Compose, DoubleToFloat, StepCounter
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import MLP, ProbabilisticActor, NormalParamExtractor, TanhNormal
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE
from tensordict.nn import TensorDictModule
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
from tdhook.latent.probing import Probing, ProbeManager
# Use the GPU when PyTorch reports one as available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
Create Environment#
[4]:
# Gym environment used throughout the notebook.
env_name = "InvertedDoublePendulum-v4"
base_env = GymEnv(env_name)

# Transforms layered on top of the raw environment, applied in this order.
# The printed specs show float32 observations and an added "step_count" entry.
env_transforms = Compose(DoubleToFloat(), StepCounter())
env = TransformedEnv(base_env, env_transforms)
print(f"Observation space: {env.observation_spec}")
print(f"Action space: {env.action_spec}")
Observation space: Composite(
observation: UnboundedContinuous(
shape=torch.Size([11]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, contiguous=True),
high=Tensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, contiguous=True)),
device=cpu,
dtype=torch.float32,
domain=continuous),
step_count: BoundedDiscrete(
shape=torch.Size([1]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, contiguous=True),
high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, contiguous=True)),
device=cpu,
dtype=torch.int64,
domain=discrete),
device=None,
shape=torch.Size([]),
data_cls=None)
Action space: BoundedContinuous(
shape=torch.Size([1]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
device=cpu,
dtype=torch.float32,
domain=continuous)
Create Actor and Critic Networks#
[5]:
# Width of each hidden layer, and how many hidden layers each MLP gets.
hidden_size = 32
num_cells = 6

# Feature sizes read once from the environment specs.
obs_dim = env.observation_spec["observation"].shape[-1]
action_dim = env.action_spec.shape[-1]

# Actor network
# The MLP emits 2 * action_dim features; the wrapped module writes them out
# under the "loc" and "scale" keys.
actor_backbone = MLP(
    in_features=obs_dim,
    out_features=2 * action_dim,
    num_cells=[hidden_size] * num_cells,
)
actor_module = TensorDictModule(
    nn.Sequential(actor_backbone, NormalParamExtractor()),
    in_keys=["observation"],
    out_keys=["loc", "scale"],
)

# Policy: builds a TanhNormal from "loc"/"scale", limited to the bounds of the
# action spec, and also returns the log-probability of the sampled action.
action_bounds = {
    "low": env.action_spec.space.low,
    "high": env.action_spec.space.high,
}
actor = ProbabilisticActor(
    module=actor_module,
    spec=env.action_spec,
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
    distribution_kwargs=action_bounds,
    return_log_prob=True,
)

# Critic network
# Same hidden layout as the actor, with a single output written to
# "state_value".
critic_backbone = MLP(
    in_features=obs_dim,
    out_features=1,
    num_cells=[hidden_size] * num_cells,
)
critic = TensorDictModule(
    critic_backbone,
    in_keys=["observation"],
    out_keys=["state_value"],
)

print("Actor network created")
print("Critic network created")
Actor network created
Critic network created
Create PPO Loss Module#
[6]:
# Advantage estimator: GAE driven by the critic's value predictions.
gae_settings = {
    "gamma": 0.99,
    "lmbda": 0.95,
    "average_gae": True,
    "device": device,
    "deactivate_vmap": True,
}
advantage_module = GAE(value_network=critic, **gae_settings)

# Clipped PPO objective over the actor and critic defined earlier.
# NOTE(review): functional=False presumably keeps the actor/critic as ordinary
# submodules of the loss so they can be addressed by name later — confirm.
ppo_settings = {
    "clip_epsilon": 0.2,
    "entropy_bonus": True,
    "entropy_coeff": 1e-4,
    "critic_coeff": 1.0,
    "loss_critic_type": "smooth_l1",
    "functional": False,
}
loss_module = ClipPPOLoss(
    actor_network=actor,
    critic_network=critic,
    **ppo_settings,
)

print("PPO loss module created")
PPO loss module created
Collect Sample Data#
[7]:
from torchrl.collectors import SyncDataCollector

# One batch is all that is needed: the batch size equals the total frame
# budget, so the collector is configured for exactly 100 frames.
n_frames = 100
collector = SyncDataCollector(
    env,
    actor,
    frames_per_batch=n_frames,
    total_frames=n_frames,
)

# Collect a batch of data
batch_iterator = iter(collector)
batch = next(batch_iterator)

# Run the advantage estimator on the batch; the printed keys below include
# "advantage", "value_target" and "state_value" afterwards.
advantage_module(batch)

print(f"Batch shape: {batch.shape}")
print(f"Batch keys: {batch.keys(True, True)}")
Batch shape: torch.Size([100])
Batch keys: _TensorDictKeysView(['step_count', 'action', 'observation', 'done', 'terminated', 'truncated', ('next', 'observation'), ('next', 'step_count'), ('next', 'reward'), ('next', 'done'), ('next', 'terminated'), ('next', 'truncated'), ('next', 'state_value'), ('collector', 'traj_ids'), 'loc', 'scale', 'action_log_prob', 'state_value', 'advantage', 'value_target'],
include_nested=True,
leaves_only=True)
Set Up Action Probing#
[8]:
# Split data into train and test: shuffle every sample index once, then cut
# the permutation at the 80% mark.
n_samples = batch.numel()
indices = torch.randperm(n_samples)
split_idx = int(0.8 * n_samples)
train_indices = indices[:split_idx]
test_indices = indices[split_idx:]
train_batch, test_batch = batch[train_indices], batch[test_indices]


def _r2_metrics(preds, labels):
    """Return the R² of `preds` against `labels` under the "r2" key."""
    return {"r2": r2_score(labels, preds)}


# Create probe manager
# NOTE(review): the arguments look like (probe class, probe kwargs, metrics
# function) — confirm against the ProbeManager signature.
probe_manager = ProbeManager(LinearRegression, {}, _r2_metrics)
Run Probing on Actor and Critic Layers#
[9]:
# Hook into actor and critic layers to probe action representations.
# The pattern is a raw string: `\d` is not a valid escape sequence in a normal
# string literal and triggers a warning on recent Python versions. The
# resulting pattern text is unchanged.
with Probing(
    r"td_module.(critic_network.module.\d+|actor_network.module.0.module.0.\d+)",
    probe_manager.probe_factory,
    additional_keys=["labels", "step_type"],
    relative=False,
).prepare(loss_module) as hooked_module:
    # Fit probes on training data; the sampled actions are the regression
    # targets. Both extra keys are written into the batch in place.
    train_batch["labels"] = train_batch["action"]
    train_batch["step_type"] = "fit"
    hooked_module(train_batch)

    # Evaluate probes on test data
    test_batch["labels"] = test_batch["action"]
    test_batch["step_type"] = "predict"
    hooked_module(test_batch)
Display Results#
[10]:
def _print_r2_scores(header, metrics):
    """Print `header`, then one line per entry of `metrics` with its R² score."""
    print(header)
    for key, value in metrics.items():
        print(f" {key}: {value['r2']:.3f}")


# Scores recorded while fitting the probes, then while evaluating them.
_print_r2_scores("Training R² scores:", probe_manager.fit_metrics)
_print_r2_scores("\nTest R² scores:", probe_manager.predict_metrics)
Training R² scores:
td_module.actor_network.module.0.module.0.0_fwd: 0.057
td_module.actor_network.module.0.module.0.1_fwd: 0.457
td_module.actor_network.module.0.module.0.2_fwd: 0.457
td_module.actor_network.module.0.module.0.3_fwd: 0.424
td_module.actor_network.module.0.module.0.4_fwd: 0.424
td_module.actor_network.module.0.module.0.5_fwd: 0.392
td_module.actor_network.module.0.module.0.6_fwd: 0.391
td_module.actor_network.module.0.module.0.7_fwd: 0.475
td_module.actor_network.module.0.module.0.8_fwd: 0.429
td_module.actor_network.module.0.module.0.9_fwd: 0.399
td_module.actor_network.module.0.module.0.10_fwd: 0.376
td_module.actor_network.module.0.module.0.11_fwd: 0.377
td_module.actor_network.module.0.module.0.12_fwd: 0.003
td_module.critic_network.module.0_fwd: 0.057
td_module.critic_network.module.1_fwd: 0.380
td_module.critic_network.module.2_fwd: 0.380
td_module.critic_network.module.3_fwd: 0.315
td_module.critic_network.module.4_fwd: 0.314
td_module.critic_network.module.5_fwd: 0.318
td_module.critic_network.module.6_fwd: 0.316
td_module.critic_network.module.7_fwd: 0.333
td_module.critic_network.module.8_fwd: 0.307
td_module.critic_network.module.9_fwd: 0.307
td_module.critic_network.module.10_fwd: 0.288
td_module.critic_network.module.11_fwd: 0.337
td_module.critic_network.module.12_fwd: 0.013
Test R² scores:
td_module.actor_network.module.0.module.0.0_fwd: 0.160
td_module.actor_network.module.0.module.0.1_fwd: -0.743
td_module.actor_network.module.0.module.0.2_fwd: -0.743
td_module.actor_network.module.0.module.0.3_fwd: -2.509
td_module.actor_network.module.0.module.0.4_fwd: -2.506
td_module.actor_network.module.0.module.0.5_fwd: -2.913
td_module.actor_network.module.0.module.0.6_fwd: -2.848
td_module.actor_network.module.0.module.0.7_fwd: -2.626
td_module.actor_network.module.0.module.0.8_fwd: -3.136
td_module.actor_network.module.0.module.0.9_fwd: -2.513
td_module.actor_network.module.0.module.0.10_fwd: -2.363
td_module.actor_network.module.0.module.0.11_fwd: -1.602
td_module.actor_network.module.0.module.0.12_fwd: 0.038
td_module.critic_network.module.0_fwd: 0.160
td_module.critic_network.module.1_fwd: -2.264
td_module.critic_network.module.2_fwd: -2.264
td_module.critic_network.module.3_fwd: -1.375
td_module.critic_network.module.4_fwd: -1.235
td_module.critic_network.module.5_fwd: -1.553
td_module.critic_network.module.6_fwd: -1.480
td_module.critic_network.module.7_fwd: -0.289
td_module.critic_network.module.8_fwd: -0.219
td_module.critic_network.module.9_fwd: -0.883
td_module.critic_network.module.10_fwd: -0.278
td_module.critic_network.module.11_fwd: -1.791
td_module.critic_network.module.12_fwd: 0.120
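Note: The R² scores shown above are expected to be poor (often negative) because the model is not trained. The actor and critic networks are initialized with random weights, so their internal representations do not yet encode meaningful information about actions. The probes are also fit on only 80 samples (80% of a single 100-frame batch) and evaluated on the remaining 20, so part of the gap between the training and test R² scores likely reflects overfitting of the linear probes rather than a property of the networks. After training the PPO agent, you would expect to see higher R² scores, indicating that the network layers learn to represent action-relevant information.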