"""Evaluate a trained policy on real hardware (or in simulation). Loads a checkpoint and runs the policy in a closed loop. For real hardware the serial runner talks to the ESP32; for sim it uses the MuJoCo runner. A digital-twin MuJoCo viewer mirrors the robot state in both modes. Usage (real hardware): mjpython scripts/eval.py env=rotary_cartpole runner=serial \ checkpoint=runs/26-03-12_00-16-43-308420_PPO/checkpoints/agent_1000000.pt Usage (simulation): mjpython scripts/eval.py env=rotary_cartpole runner=mujoco_single \ checkpoint=runs/26-03-12_00-16-43-308420_PPO/checkpoints/agent_1000000.pt Controls: Space — pause / resume policy (motor stops while paused) R — reset environment Esc — quit """ import math import sys import time from pathlib import Path # Ensure project root is on sys.path _PROJECT_ROOT = str(Path(__file__).resolve().parent.parent) if _PROJECT_ROOT not in sys.path: sys.path.insert(0, _PROJECT_ROOT) import hydra import mujoco import mujoco.viewer import numpy as np import structlog import torch from gymnasium import spaces from hydra.core.hydra_config import HydraConfig from omegaconf import DictConfig, OmegaConf from skrl.resources.preprocessors.torch import RunningStandardScaler from src.core.registry import build_env from src.models.mlp import SharedMLP logger = structlog.get_logger() # ── keyboard state ─────────────────────────────────────────────────── _reset_flag = [False] _paused = [False] _quit_flag = [False] def _key_callback(keycode: int) -> None: """Called by MuJoCo viewer on key press.""" if keycode == 32: # GLFW_KEY_SPACE _paused[0] = not _paused[0] elif keycode == 82: # GLFW_KEY_R _reset_flag[0] = True elif keycode == 256: # GLFW_KEY_ESCAPE _quit_flag[0] = True # ── checkpoint loading ─────────────────────────────────────────────── def _infer_hidden_sizes(state_dict: dict[str, torch.Tensor]) -> tuple[int, ...]: """Infer hidden layer sizes from a SharedMLP state dict.""" sizes = [] i = 0 while f"net.{i}.weight" in state_dict: sizes.append(state_dict[f"net.{i}.weight"].shape[0]) i += 2 # skip activation layers (ELU) return tuple(sizes) def load_policy( checkpoint_path: str, observation_space: spaces.Space, action_space: spaces.Space, device: torch.device = torch.device("cpu"), ) -> tuple[SharedMLP, RunningStandardScaler]: """Load a trained SharedMLP + observation normalizer from a checkpoint. Returns: (model, state_preprocessor) ready for inference. """ ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False) # Infer architecture from saved weights. hidden_sizes = _infer_hidden_sizes(ckpt["policy"]) # Reconstruct model. model = SharedMLP( observation_space=observation_space, action_space=action_space, device=device, hidden_sizes=hidden_sizes, ) model.load_state_dict(ckpt["policy"]) model.eval() # Reconstruct observation normalizer. state_preprocessor = RunningStandardScaler(size=observation_space, device=device) state_preprocessor.running_mean = ckpt["state_preprocessor"]["running_mean"].to(device) state_preprocessor.running_variance = ckpt["state_preprocessor"]["running_variance"].to(device) state_preprocessor.current_count = ckpt["state_preprocessor"]["current_count"] # Freeze the normalizer — don't update stats during eval. 
# ── action arrow overlay ─────────────────────────────────────────────
def _add_action_arrow(viewer, model, data, action_val: float) -> None:
    """Draw an arrow showing applied torque direction."""
    if abs(action_val) < 0.01 or model.nu == 0:
        return

    jnt_id = model.actuator_trnid[0, 0]
    body_id = model.jnt_bodyid[jnt_id]
    pos = data.xpos[body_id].copy()
    pos[2] += 0.02

    axis = data.xmat[body_id].reshape(3, 3) @ model.jnt_axis[jnt_id]
    arrow_len = 0.08 * action_val
    direction = axis * np.sign(arrow_len)

    z = direction / (np.linalg.norm(direction) + 1e-8)
    up = np.array([0, 0, 1]) if abs(z[2]) < 0.99 else np.array([0, 1, 0])
    x = np.cross(up, z)
    x /= np.linalg.norm(x) + 1e-8
    y = np.cross(z, x)
    mat = np.column_stack([x, y, z]).flatten()

    rgba = np.array(
        [0.2, 0.8, 0.2, 0.8] if action_val > 0 else [0.8, 0.2, 0.2, 0.8],
        dtype=np.float32,
    )
    geom = viewer.user_scn.geoms[viewer.user_scn.ngeom]
    mujoco.mjv_initGeom(
        geom,
        type=mujoco.mjtGeom.mjGEOM_ARROW,
        size=np.array([0.008, 0.008, abs(arrow_len)]),
        pos=pos,
        mat=mat,
        rgba=rgba,
    )
    viewer.user_scn.ngeom += 1
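
# A note on the frame math above: mjGEOM_ARROW is drawn along the +z axis
# of its local frame, and mjv_initGeom's `mat` is the flattened row-major
# 3x3 world rotation, so stacking [x, y, z] as columns points the arrow
# along the (signed) joint axis. The `up` fallback avoids a degenerate
# cross product when that axis is nearly vertical.
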
# ── main loops ───────────────────────────────────────────────────────
@hydra.main(version_base=None, config_path="../configs", config_name="config")
def main(cfg: DictConfig) -> None:
    choices = HydraConfig.get().runtime.choices
    env_name = choices.get("env", "cartpole")
    runner_name = choices.get("runner", "mujoco_single")

    checkpoint_path = cfg.get("checkpoint", None)
    if checkpoint_path is None:
        logger.error("No checkpoint specified. Use: +checkpoint=path/to/agent.pt")
        sys.exit(1)

    # Resolve relative paths against the original working directory.
    checkpoint_path = str(Path(hydra.utils.get_original_cwd()) / checkpoint_path)
    if not Path(checkpoint_path).exists():
        logger.error("checkpoint_not_found", path=checkpoint_path)
        sys.exit(1)

    if runner_name == "serial":
        _eval_serial(cfg, env_name, checkpoint_path)
    else:
        _eval_sim(cfg, env_name, checkpoint_path)


def _eval_sim(cfg: DictConfig, env_name: str, checkpoint_path: str) -> None:
    """Evaluate policy in MuJoCo simulation with viewer."""
    from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig

    env = build_env(env_name, cfg)
    runner_dict = OmegaConf.to_container(cfg.runner, resolve=True)
    runner_dict["num_envs"] = 1
    runner = MuJoCoRunner(env=env, config=MuJoCoRunnerConfig(**runner_dict))
    device = runner.device

    model, preprocessor = load_policy(
        checkpoint_path, runner.observation_space, runner.action_space, device
    )

    mj_model = runner._model
    mj_data = runner._data[0]
    dt_ctrl = runner.config.dt * runner.config.substeps

    with mujoco.viewer.launch_passive(mj_model, mj_data, key_callback=_key_callback) as viewer:
        obs, _ = runner.reset()
        step = 0
        episode = 0
        episode_reward = 0.0

        logger.info(
            "eval_started",
            env=env_name,
            mode="simulation",
            checkpoint=Path(checkpoint_path).name,
            controls="Space=pause, R=reset, Esc=quit",
        )

        while viewer.is_running() and not _quit_flag[0]:
            if _reset_flag[0]:
                _reset_flag[0] = False
                obs, _ = runner.reset()
                step = 0
                episode += 1
                episode_reward = 0.0
                logger.info("reset", episode=episode)

            if _paused[0]:
                viewer.sync()
                time.sleep(0.05)
                continue

            # Policy inference
            with torch.no_grad():
                normalized_obs = preprocessor(obs.unsqueeze(0) if obs.dim() == 1 else obs)
                action = model.act({"states": normalized_obs}, role="policy")[0]
                action = action.clamp(-1.0, 1.0)

            obs, reward, terminated, truncated, info = runner.step(action)
            episode_reward += reward.item()
            step += 1

            # Sync viewer
            mujoco.mj_forward(mj_model, mj_data)
            viewer.user_scn.ngeom = 0
            _add_action_arrow(viewer, mj_model, mj_data, action[0, 0].item())
            viewer.sync()

            if step % 50 == 0:
                # Index qpos via jnt_qposadr so joint i maps to the right slot.
                joints = {
                    mj_model.jnt(i).name: round(math.degrees(mj_data.qpos[mj_model.jnt_qposadr[i]]), 1)
                    for i in range(mj_model.njnt)
                }
                logger.debug(
                    "step",
                    n=step,
                    reward=round(reward.item(), 3),
                    action=round(action[0, 0].item(), 2),
                    ep_reward=round(episode_reward, 1),
                    **joints,
                )

            if terminated.any() or truncated.any():
                logger.info(
                    "episode_done",
                    episode=episode,
                    steps=step,
                    total_reward=round(episode_reward, 2),
                    reason="terminated" if terminated.any() else "truncated",
                )
                obs, _ = runner.reset()
                step = 0
                episode += 1
                episode_reward = 0.0

            time.sleep(dt_ctrl)

    runner.close()
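
# Note: model.act() draws a stochastic action when the policy head is a
# Gaussian. If SharedMLP follows skrl's GaussianMixin convention (an
# assumption about our model class), the deterministic mean is in the third
# element of act()'s return value; a minimal sketch of a helper:
#
#   def _act_mean(model, normalized_obs):
#       actions, _, outputs = model.act({"states": normalized_obs}, role="policy")
#       return outputs.get("mean_actions", actions)
#
# Swapping this in for the sampled action gives repeatable eval rollouts.
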
def _eval_serial(cfg: DictConfig, env_name: str, checkpoint_path: str) -> None:
    """Evaluate policy on real hardware via serial, with digital-twin viewer."""
    from src.runners.serial import SerialRunner, SerialRunnerConfig

    env = build_env(env_name, cfg)
    runner_dict = OmegaConf.to_container(cfg.runner, resolve=True)
    serial_runner = SerialRunner(env=env, config=SerialRunnerConfig(**runner_dict))
    device = serial_runner.device

    model, preprocessor = load_policy(
        checkpoint_path, serial_runner.observation_space, serial_runner.action_space, device
    )

    # Set up digital-twin MuJoCo model for visualization.
    serial_runner._ensure_viz_model()
    mj_model = serial_runner._viz_model
    mj_data = serial_runner._viz_data

    with mujoco.viewer.launch_passive(mj_model, mj_data, key_callback=_key_callback) as viewer:
        obs, _ = serial_runner.reset()
        step = 0
        episode = 0
        episode_reward = 0.0

        logger.info(
            "eval_started",
            env=env_name,
            mode="real hardware (serial)",
            port=serial_runner.config.port,
            checkpoint=Path(checkpoint_path).name,
            controls="Space=pause, R=reset, Esc=quit",
        )

        while viewer.is_running() and not _quit_flag[0]:
            if _reset_flag[0]:
                _reset_flag[0] = False
                serial_runner._send("M0")
                serial_runner._drive_to_center()
                serial_runner._wait_for_pendulum_still()
                obs, _ = serial_runner.reset()
                step = 0
                episode += 1
                episode_reward = 0.0
                logger.info("reset", episode=episode)

            if _paused[0]:
                serial_runner._send("M0")  # safety: stop motor while paused
                serial_runner._sync_viz()
                viewer.sync()
                time.sleep(0.05)
                continue

            # Policy inference
            with torch.no_grad():
                normalized_obs = preprocessor(obs.unsqueeze(0) if obs.dim() == 1 else obs)
                action = model.act({"states": normalized_obs}, role="policy")[0]
                action = action.clamp(-1.0, 1.0)

            obs, reward, terminated, truncated, info = serial_runner.step(action)
            episode_reward += reward.item()
            step += 1

            # Sync digital twin with real sensor data.
            serial_runner._sync_viz()
            viewer.user_scn.ngeom = 0
            _add_action_arrow(viewer, mj_model, mj_data, action[0, 0].item())
            viewer.sync()

            if step % 25 == 0:
                state = serial_runner._read_state()
                logger.debug(
                    "step",
                    n=step,
                    reward=round(reward.item(), 3),
                    action=round(action[0, 0].item(), 2),
                    ep_reward=round(episode_reward, 1),
                    motor_enc=state["encoder_count"],
                    pend_deg=round(state["pendulum_angle"], 1),
                )

            # Check for safety / disconnection.
            if info.get("reboot_detected") or info.get("motor_limit_exceeded"):
                logger.error(
                    "safety_stop",
                    reboot=info.get("reboot_detected", False),
                    motor_limit=info.get("motor_limit_exceeded", False),
                )
                serial_runner._send("M0")
                break

            if terminated.any() or truncated.any():
                logger.info(
                    "episode_done",
                    episode=episode,
                    steps=step,
                    total_reward=round(episode_reward, 2),
                    reason="terminated" if terminated.any() else "truncated",
                )
                # Auto-reset for next episode.
                obs, _ = serial_runner.reset()
                step = 0
                episode += 1
                episode_reward = 0.0

            # Real-time pacing is handled by serial_runner.step() (dt sleep).

    serial_runner.close()


if __name__ == "__main__":
    main()