Files
RL-Sim-Framework/scripts/eval.py
2026-03-22 15:49:13 +01:00

380 lines
13 KiB
Python

"""Evaluate a trained policy on real hardware (or in simulation).
Loads a checkpoint and runs the policy in a closed loop. For real
hardware the serial runner talks to the ESP32; for sim it uses the
MuJoCo runner. A digital-twin MuJoCo viewer mirrors the robot state
in both modes.
Usage (real hardware):
mjpython scripts/eval.py env=rotary_cartpole runner=serial \
checkpoint=runs/26-03-12_00-16-43-308420_PPO/checkpoints/agent_1000000.pt
Usage (simulation):
mjpython scripts/eval.py env=rotary_cartpole runner=mujoco_single \
checkpoint=runs/26-03-12_00-16-43-308420_PPO/checkpoints/agent_1000000.pt
Controls:
Space — pause / resume policy (motor stops while paused)
R — reset environment
Esc — quit
"""
import math
import sys
import time
from pathlib import Path
# Make the project root (two directory levels above this file) importable so
# the ``src.*`` imports below resolve when the script is launched directly.
_PROJECT_ROOT = str(Path(__file__).resolve().parents[1])
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)
import hydra
import mujoco
import mujoco.viewer
import numpy as np
import structlog
import torch
from gymnasium import spaces
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf
from skrl.resources.preprocessors.torch import RunningStandardScaler
from src.core.registry import build_env
from src.models.mlp import SharedMLP
# Module-level structlog logger used by every function in this script.
logger = structlog.get_logger()
# ── keyboard state ───────────────────────────────────────────────────
# Single-element lists act as mutable cells: _key_callback writes index 0 and
# the eval loops poll it, so no ``global`` rebinding is needed anywhere.
_paused: list[bool] = [False]  # toggled by Space; loops skip policy steps while True
_reset_flag: list[bool] = [False]  # set by R; cleared by the loop once it resets
_quit_flag: list[bool] = [False]  # set by Esc; ends the eval loop
def _key_callback(keycode: int) -> None:
    """React to a key press reported by the MuJoCo passive viewer.

    Keycodes are GLFW codes. Space toggles the pause flag, R requests an
    environment reset and Esc requests quit. Any other key is ignored. This
    function only sets flags; the eval loops act on them.
    """
    if keycode == 32:  # GLFW_KEY_SPACE
        _paused[0] = not _paused[0]
        return
    # Keys that latch a one-shot request flag: R -> reset, Esc -> quit.
    request = {82: _reset_flag, 256: _quit_flag}.get(keycode)
    if request is not None:
        request[0] = True
# ── checkpoint loading ───────────────────────────────────────────────
def _infer_hidden_sizes(state_dict: dict[str, torch.Tensor]) -> tuple[int, ...]:
    """Infer hidden layer sizes from a SharedMLP state dict.

    Collects every ``net.<index>.weight`` entry whose tensor is 2-D (a linear
    layer's ``(out_features, in_features)`` matrix), orders the entries by
    the numeric ``<index>`` and returns their output widths.

    Parameter-free layers such as activations have no state-dict entries, and
    1-D weights (for example normalization layers) are skipped. The result
    therefore does not rely on ``net`` strictly alternating Linear and
    activation layers, and a gap in the indices does not truncate it.

    Args:
        state_dict: Model state dict as stored in the checkpoint.

    Returns:
        Output width of each linear layer in ``net``, in forward order. The
        tuple is empty when no such layer is present.
    """
    linear_layers: list[tuple[int, int]] = []
    for key, tensor in state_dict.items():
        parts = key.split(".")
        if len(parts) != 3 or parts[0] != "net" or parts[2] != "weight":
            continue
        if not parts[1].isdigit() or tensor.ndim != 2:
            continue
        linear_layers.append((int(parts[1]), int(tensor.shape[0])))
    # Sort numerically so that net.10 comes after net.2.
    linear_layers.sort()
    return tuple(width for _, width in linear_layers)
def load_policy(
    checkpoint_path: str,
    observation_space: spaces.Space,
    action_space: spaces.Space,
    device: torch.device = torch.device("cpu"),
) -> tuple[SharedMLP, RunningStandardScaler]:
    """Rebuild the trained policy network and its observation normalizer.

    Hidden-layer widths are recovered from the saved ``policy`` weights, so
    the caller does not need to know the training architecture. The
    normalizer's running statistics are restored from ``state_preprocessor``
    and then frozen.

    Args:
        checkpoint_path: Path to a checkpoint containing ``policy`` and
            ``state_preprocessor`` entries.
        observation_space: Observation space the policy was trained on.
        action_space: Action space the policy was trained on.
        device: Device to map tensors to and build modules on.

    Returns:
        ``(model, state_preprocessor)`` ready for inference.
    """
    # NOTE: weights_only=False unpickles arbitrary Python objects; only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    policy_weights = checkpoint["policy"]
    hidden_sizes = _infer_hidden_sizes(policy_weights)

    model = SharedMLP(
        observation_space=observation_space,
        action_space=action_space,
        device=device,
        hidden_sizes=hidden_sizes,
    )
    model.load_state_dict(policy_weights)
    model.eval()

    scaler_state = checkpoint["state_preprocessor"]
    state_preprocessor = RunningStandardScaler(size=observation_space, device=device)
    state_preprocessor.running_mean = scaler_state["running_mean"].to(device)
    state_preprocessor.running_variance = scaler_state["running_variance"].to(device)
    state_preprocessor.current_count = scaler_state["current_count"]
    # Keep the training-time statistics fixed during evaluation.
    state_preprocessor.training = False

    mean_values = state_preprocessor.running_mean.tolist()
    std_values = state_preprocessor.running_variance.sqrt().tolist()
    logger.info(
        "checkpoint_loaded",
        path=checkpoint_path,
        hidden_sizes=hidden_sizes,
        obs_mean=[round(v, 3) for v in mean_values],
        obs_std=[round(v, 3) for v in std_values],
    )
    return model, state_preprocessor
# ── action arrow overlay ─────────────────────────────────────────────
def _add_action_arrow(viewer, model, data, action_val: float) -> None:
    """Append an arrow geom to the viewer's user scene showing the action.

    The arrow is anchored 0.02 above (world +z) the origin of the body moved
    by the first actuator's joint. It points along that joint's world-frame
    axis, flipped by the sign of ``action_val``, and is ``0.08 * |action_val|``
    long. Positive actions are drawn green and negative ones red.

    Nothing is drawn when:

    - the action is negligible (``|action_val| < 0.01``);
    - the model has no actuators;
    - the user scene has no free geom slot;
    - the first actuator does not use a joint transmission.

    Args:
        viewer: MuJoCo passive viewer handle; only ``viewer.user_scn`` is used.
        model: ``mujoco.MjModel``.
        data: ``mujoco.MjData``. ``xpos``/``xmat`` are read, so kinematics
            should be current.
        action_val: Scalar action applied to the first actuator.
    """
    if abs(action_val) < 0.01 or model.nu == 0:
        return
    scene = viewer.user_scn
    if scene.ngeom >= scene.maxgeom:
        # Every user-scene geom slot is taken; writing geoms[ngeom] would run
        # past the end of the geom buffer.
        return
    joint_transmissions = (
        int(mujoco.mjtTrn.mjTRN_JOINT),
        int(mujoco.mjtTrn.mjTRN_JOINTINPARENT),
    )
    if int(model.actuator_trntype[0]) not in joint_transmissions:
        # actuator_trnid[0, 0] is only a joint id for joint transmissions; for
        # tendon/site/body transmissions it indexes a different object type.
        return

    jnt_id = model.actuator_trnid[0, 0]
    body_id = model.jnt_bodyid[jnt_id]
    pos = data.xpos[body_id].copy()
    pos[2] += 0.02  # lift the anchor 0.02 along world z above the body origin

    # Joint axis (stored in the body frame) rotated into the world frame.
    axis = data.xmat[body_id].reshape(3, 3) @ model.jnt_axis[jnt_id]
    arrow_len = 0.08 * action_val
    direction = axis * np.sign(arrow_len)

    # Orthonormal frame whose z column is the arrow direction (mjGEOM_ARROW
    # points along its local z). The helper "up" vector is chosen so it is not
    # parallel to z.
    z = direction / (np.linalg.norm(direction) + 1e-8)
    up = np.array([0, 0, 1]) if abs(z[2]) < 0.99 else np.array([0, 1, 0])
    x = np.cross(up, z)
    x /= np.linalg.norm(x) + 1e-8
    y = np.cross(z, x)
    mat = np.column_stack([x, y, z]).flatten()

    # Green for positive actions, red for negative.
    rgba = np.array(
        [0.2, 0.8, 0.2, 0.8] if action_val > 0 else [0.8, 0.2, 0.2, 0.8],
        dtype=np.float32,
    )
    geom = scene.geoms[scene.ngeom]
    mujoco.mjv_initGeom(
        geom,
        type=mujoco.mjtGeom.mjGEOM_ARROW,
        size=np.array([0.008, 0.008, abs(arrow_len)]),
        pos=pos,
        mat=mat,
        rgba=rgba,
    )
    scene.ngeom += 1
# ── main loops ───────────────────────────────────────────────────────
@hydra.main(version_base=None, config_path="../configs", config_name="config")
def main(cfg: DictConfig) -> None:
    """Script entry point: validate the checkpoint and dispatch to an evaluator.

    Reads the selected ``env``/``runner`` config-group choices from Hydra
    (defaulting to ``cartpole`` / ``mujoco_single``). It joins ``checkpoint``
    onto the directory the script was launched from; an absolute path is
    kept as-is because pathlib's ``/`` discards the left operand in that case.
    It then runs ``_eval_serial`` for the ``serial`` runner and ``_eval_sim``
    for anything else.

    Exits with status 1 when no checkpoint is given or the file does not exist.
    """
    runtime_choices = HydraConfig.get().runtime.choices
    env_name = runtime_choices.get("env", "cartpole")
    runner_name = runtime_choices.get("runner", "mujoco_single")

    raw_checkpoint = cfg.get("checkpoint", None)
    if raw_checkpoint is None:
        logger.error("No checkpoint specified. Use: +checkpoint=path/to/agent.pt")
        sys.exit(1)

    resolved = Path(hydra.utils.get_original_cwd()) / raw_checkpoint
    if not resolved.exists():
        logger.error("checkpoint_not_found", path=str(resolved))
        sys.exit(1)
    checkpoint_path = str(resolved)

    evaluate = _eval_serial if runner_name == "serial" else _eval_sim
    evaluate(cfg, env_name, checkpoint_path)
def _joint_positions_for_log(mj_model, mj_data) -> dict[str, float]:
    """Collect per-joint positions for debug logging, keyed by joint name.

    Hinge joints are reported in degrees (rounded to 0.1) and slide joints in
    the model's native position units (rounded to 0.001). Other joint types
    (free, ball) are omitted because they do not map to a single scalar.

    Each value is read at the joint's own ``jnt_qposadr`` offset. The joint
    index only equals the qpos index when every preceding joint has exactly
    one qpos entry.
    """
    hinge = int(mujoco.mjtJoint.mjJNT_HINGE)
    slide = int(mujoco.mjtJoint.mjJNT_SLIDE)
    positions: dict[str, float] = {}
    for jnt in range(mj_model.njnt):
        value = float(mj_data.qpos[mj_model.jnt_qposadr[jnt]])
        jnt_type = int(mj_model.jnt_type[jnt])
        if jnt_type == hinge:
            positions[mj_model.jnt(jnt).name] = round(math.degrees(value), 1)
        elif jnt_type == slide:
            positions[mj_model.jnt(jnt).name] = round(value, 3)
    return positions


def _eval_sim(cfg: DictConfig, env_name: str, checkpoint_path: str) -> None:
    """Evaluate the policy in MuJoCo simulation with a passive viewer.

    Builds a single-env ``MuJoCoRunner`` from ``cfg.runner``, loads the
    checkpoint, and steps the policy until the viewer is closed or Esc is
    pressed. Space pauses policy stepping (the viewer keeps syncing) and R
    resets. Episodes auto-reset on termination or truncation.

    Each control step is paced to ``dt * substeps`` of wall-clock time, with
    inference/step/render time subtracted. The runner is closed on every exit
    path, including exceptions.
    """
    from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig

    env = build_env(env_name, cfg)
    runner_dict = OmegaConf.to_container(cfg.runner, resolve=True)
    runner_dict["num_envs"] = 1
    runner = MuJoCoRunner(env=env, config=MuJoCoRunnerConfig(**runner_dict))
    try:
        device = runner.device
        model, preprocessor = load_policy(
            checkpoint_path, runner.observation_space, runner.action_space, device
        )
        mj_model = runner._model
        mj_data = runner._data[0]
        # Simulated time covered by one control step (physics dt x substeps).
        dt_ctrl = runner.config.dt * runner.config.substeps
        with mujoco.viewer.launch_passive(
            mj_model, mj_data, key_callback=_key_callback
        ) as viewer:
            obs, _ = runner.reset()
            step = 0
            episode = 0
            episode_reward = 0.0
            logger.info(
                "eval_started",
                env=env_name,
                mode="simulation",
                checkpoint=Path(checkpoint_path).name,
                controls="Space=pause, R=reset, Esc=quit",
            )
            while viewer.is_running() and not _quit_flag[0]:
                if _reset_flag[0]:
                    _reset_flag[0] = False
                    obs, _ = runner.reset()
                    step = 0
                    episode += 1
                    episode_reward = 0.0
                    logger.info("reset", episode=episode)
                if _paused[0]:
                    viewer.sync()
                    time.sleep(0.05)
                    continue

                step_start = time.perf_counter()
                # Policy inference.
                # NOTE(review): act()[0] is whatever SharedMLP.act returns first;
                # for a Gaussian policy that may be a *sampled* action. Confirm
                # whether evaluation should use the mean action instead.
                with torch.no_grad():
                    normalized_obs = preprocessor(obs.unsqueeze(0) if obs.dim() == 1 else obs)
                    action = model.act({"states": normalized_obs}, role="policy")[0]
                    action = action.clamp(-1.0, 1.0)
                obs, reward, terminated, truncated, _ = runner.step(action)
                episode_reward += reward.item()
                step += 1

                # Recompute derived quantities (xpos/xmat) before drawing.
                mujoco.mj_forward(mj_model, mj_data)
                viewer.user_scn.ngeom = 0
                _add_action_arrow(viewer, mj_model, mj_data, action[0, 0].item())
                viewer.sync()

                if step % 50 == 0:
                    logger.debug(
                        "step",
                        n=step,
                        reward=round(reward.item(), 3),
                        action=round(action[0, 0].item(), 2),
                        ep_reward=round(episode_reward, 1),
                        **_joint_positions_for_log(mj_model, mj_data),
                    )
                if terminated.any() or truncated.any():
                    logger.info(
                        "episode_done",
                        episode=episode,
                        steps=step,
                        total_reward=round(episode_reward, 2),
                        reason="terminated" if terminated.any() else "truncated",
                    )
                    obs, _ = runner.reset()
                    step = 0
                    episode += 1
                    episode_reward = 0.0
                # Sleep only for the remainder of the control period so the loop
                # tracks real time instead of adding compute time on top.
                time.sleep(max(0.0, dt_ctrl - (time.perf_counter() - step_start)))
    finally:
        runner.close()
def _eval_serial(cfg: DictConfig, env_name: str, checkpoint_path: str) -> None:
    """Evaluate the policy on real hardware via serial, with a digital-twin viewer.

    Builds a ``SerialRunner`` from ``cfg.runner``, loads the checkpoint, and
    runs the policy in closed loop while mirroring the sensor state in a
    MuJoCo viewer. Controls:

    - Space pauses and sends ``M0`` to stop the motor while paused.
    - R sends ``M0``, calls the runner's drive-to-center and
      wait-for-pendulum-still helpers, then resets.
    - Esc or closing the viewer quits.

    If ``info`` reports a reboot or a motor-limit violation, ``M0`` is sent
    and the loop ends.

    ``M0`` is sent and the runner closed in a ``finally`` clause. This covers
    every exit path: normal quit, the safety break, a failed checkpoint load,
    or an exception such as Ctrl-C. No exit path skips the motor stop command
    or leaks the serial port.
    """
    from src.runners.serial import SerialRunner, SerialRunnerConfig

    env = build_env(env_name, cfg)
    runner_dict = OmegaConf.to_container(cfg.runner, resolve=True)
    serial_runner = SerialRunner(env=env, config=SerialRunnerConfig(**runner_dict))
    try:
        device = serial_runner.device
        model, preprocessor = load_policy(
            checkpoint_path, serial_runner.observation_space, serial_runner.action_space, device
        )
        # Set up digital-twin MuJoCo model for visualization.
        serial_runner._ensure_viz_model()
        mj_model = serial_runner._viz_model
        mj_data = serial_runner._viz_data
        with mujoco.viewer.launch_passive(
            mj_model, mj_data, key_callback=_key_callback
        ) as viewer:
            obs, _ = serial_runner.reset()
            step = 0
            episode = 0
            episode_reward = 0.0
            logger.info(
                "eval_started",
                env=env_name,
                mode="real hardware (serial)",
                port=serial_runner.config.port,
                checkpoint=Path(checkpoint_path).name,
                controls="Space=pause, R=reset, Esc=quit",
            )
            while viewer.is_running() and not _quit_flag[0]:
                if _reset_flag[0]:
                    _reset_flag[0] = False
                    serial_runner._send("M0")
                    serial_runner._drive_to_center()
                    serial_runner._wait_for_pendulum_still()
                    obs, _ = serial_runner.reset()
                    step = 0
                    episode += 1
                    episode_reward = 0.0
                    logger.info("reset", episode=episode)
                if _paused[0]:
                    serial_runner._send("M0")  # safety: stop motor while paused
                    serial_runner._sync_viz()
                    viewer.sync()
                    time.sleep(0.05)
                    continue

                # Policy inference.
                # NOTE(review): act()[0] is whatever SharedMLP.act returns first;
                # for a Gaussian policy that may be a *sampled* action. Confirm
                # whether evaluation should use the mean action instead.
                with torch.no_grad():
                    normalized_obs = preprocessor(obs.unsqueeze(0) if obs.dim() == 1 else obs)
                    action = model.act({"states": normalized_obs}, role="policy")[0]
                    action = action.clamp(-1.0, 1.0)
                obs, reward, terminated, truncated, info = serial_runner.step(action)
                episode_reward += reward.item()
                step += 1

                # Sync digital twin with real sensor data.
                serial_runner._sync_viz()
                viewer.user_scn.ngeom = 0
                _add_action_arrow(viewer, mj_model, mj_data, action[0, 0].item())
                viewer.sync()

                if step % 25 == 0:
                    state = serial_runner._read_state()
                    logger.debug(
                        "step",
                        n=step,
                        reward=round(reward.item(), 3),
                        action=round(action[0, 0].item(), 2),
                        ep_reward=round(episode_reward, 1),
                        motor_enc=state["encoder_count"],
                        pend_deg=round(state["pendulum_angle"], 1),
                    )
                # Check for safety / disconnection.
                if info.get("reboot_detected") or info.get("motor_limit_exceeded"):
                    logger.error(
                        "safety_stop",
                        reboot=info.get("reboot_detected", False),
                        motor_limit=info.get("motor_limit_exceeded", False),
                    )
                    serial_runner._send("M0")
                    break
                if terminated.any() or truncated.any():
                    logger.info(
                        "episode_done",
                        episode=episode,
                        steps=step,
                        total_reward=round(episode_reward, 2),
                        reason="terminated" if terminated.any() else "truncated",
                    )
                    # Auto-reset for next episode.
                    obs, _ = serial_runner.reset()
                    step = 0
                    episode += 1
                    episode_reward = 0.0
                # Real-time pacing is handled by serial_runner.step() (dt sleep).
    finally:
        # Always send the stop command before releasing the port. Sending M0
        # more than once is already routine here (the pause branch sends it on
        # every iteration). A failing send is logged so close() still runs
        # and any in-flight exception keeps propagating.
        try:
            serial_runner._send("M0")
        except Exception:
            logger.exception("motor_stop_failed")
        serial_runner.close()
if __name__ == "__main__":
    # The @hydra.main decorator on main() parses CLI overrides (e.g. env=...,
    # runner=...) and supplies the composed cfg.
    main()