380 lines
13 KiB
Python
380 lines
13 KiB
Python
"""Evaluate a trained policy on real hardware (or in simulation).

Loads a checkpoint and runs the policy in a closed loop. For real
hardware the serial runner talks to the ESP32; for sim it uses the
MuJoCo runner. A digital-twin MuJoCo viewer mirrors the robot state
in both modes.

Usage (real hardware):
    mjpython scripts/eval.py env=rotary_cartpole runner=serial \
        checkpoint=runs/26-03-12_00-16-43-308420_PPO/checkpoints/agent_1000000.pt

Usage (simulation):
    mjpython scripts/eval.py env=rotary_cartpole runner=mujoco_single \
        checkpoint=runs/26-03-12_00-16-43-308420_PPO/checkpoints/agent_1000000.pt

If ``checkpoint`` is not already a key in the Hydra config, prefix the
override with ``+`` (``+checkpoint=...``).

Controls:
    Space — pause / resume policy (motor stops while paused)
    R     — reset environment
    Esc   — quit
"""
|
|
|
|
import math
|
|
import sys
|
|
import time
|
|
from pathlib import Path
|
|
|
|
# Put the repository root (the parent of this script's directory) at the
# front of sys.path so the ``src.*`` imports below resolve when the file is
# run directly as a script.
_PROJECT_ROOT = str(Path(__file__).resolve().parent.parent)
if not any(entry == _PROJECT_ROOT for entry in sys.path):
    sys.path[:0] = [_PROJECT_ROOT]
|
|
|
|
import hydra
|
|
import mujoco
|
|
import mujoco.viewer
|
|
import numpy as np
|
|
import structlog
|
|
import torch
|
|
from gymnasium import spaces
|
|
from hydra.core.hydra_config import HydraConfig
|
|
from omegaconf import DictConfig, OmegaConf
|
|
from skrl.resources.preprocessors.torch import RunningStandardScaler
|
|
|
|
from src.core.registry import build_env
|
|
from src.models.mlp import SharedMLP
|
|
|
|
logger = structlog.get_logger()


# ── keyboard state ───────────────────────────────────────────────────
# One-element lists act as mutable cells: _key_callback (registered as the
# viewer's key_callback) writes them, and the eval loops poll them.
# The loops clear _reset_flag themselves after handling a reset.
_reset_flag, _paused, _quit_flag = [False], [False], [False]
|
|
|
|
|
|
def _key_callback(keycode: int) -> None:
    """Handle a key press forwarded by the MuJoCo passive viewer.

    Esc requests quit, R requests a reset, and Space toggles pause. Any
    other keycode is ignored. The handler only writes the module-level flag
    cells; the eval loops decide what to do with them.

    Args:
        keycode: GLFW key code delivered by the viewer.
    """
    if keycode == 256:  # GLFW_KEY_ESCAPE
        _quit_flag[0] = True
        return
    if keycode == 82:  # GLFW_KEY_R
        _reset_flag[0] = True
        return
    if keycode == 32:  # GLFW_KEY_SPACE
        _paused[0] = not _paused[0]
|
|
|
|
|
|
# ── checkpoint loading ───────────────────────────────────────────────
|
|
|
|
def _infer_hidden_sizes(state_dict: dict[str, torch.Tensor]) -> tuple[int, ...]:
    """Recover the hidden-layer widths of a SharedMLP from its state dict.

    Probes ``net.0.weight``, ``net.2.weight``, ``net.4.weight``, ... and stops
    at the first key that is absent. Only even indices are probed because the
    odd ``net`` entries are activation layers (ELU) with no weights. Each
    width is dim 0 of the weight tensor (out_features, presumably for
    ``nn.Linear`` layers).

    Args:
        state_dict: The saved ``policy`` state dict.

    Returns:
        The widths in layer order; empty when ``net.0.weight`` is missing.
    """
    widths: list[int] = []
    layer_idx = 0
    while (key := f"net.{layer_idx}.weight") in state_dict:
        widths.append(state_dict[key].shape[0])
        layer_idx += 2
    return tuple(widths)
|
|
|
|
|
|
def load_policy(
    checkpoint_path: str,
    observation_space: spaces.Space,
    action_space: spaces.Space,
    device: torch.device = torch.device("cpu"),
) -> tuple[SharedMLP, RunningStandardScaler]:
    """Load a trained SharedMLP + observation normalizer from a checkpoint.

    The hidden-layer widths are inferred from the saved ``policy`` weights,
    so the caller does not need to know the training architecture.

    Args:
        checkpoint_path: File readable by ``torch.load``. It must contain a
            ``"policy"`` state dict and a ``"state_preprocessor"`` mapping with
            ``running_mean``, ``running_variance`` and ``current_count``.
        observation_space: Used to build both the model and the normalizer.
        action_space: Used to build the model.
        device: Every checkpoint tensor is mapped to this device.

    Returns:
        (model, state_preprocessor) ready for inference, both in eval mode.

    Raises:
        KeyError: If the checkpoint lacks ``policy`` / ``state_preprocessor``
            or one of the normalizer statistics.
    """
    # SECURITY: weights_only=False unpickles arbitrary Python objects, which
    # can execute code. Only load checkpoints you produced or otherwise trust.
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)

    missing = [key for key in ("policy", "state_preprocessor") if key not in ckpt]
    if missing:
        raise KeyError(f"checkpoint {checkpoint_path!r} is missing entries {missing}")

    # Infer architecture from saved weights, then rebuild and load the model.
    policy_state = ckpt["policy"]
    hidden_sizes = _infer_hidden_sizes(policy_state)
    model = SharedMLP(
        observation_space=observation_space,
        action_space=action_space,
        device=device,
        hidden_sizes=hidden_sizes,
    )
    model.load_state_dict(policy_state)
    model.eval()

    # Reconstruct the observation normalizer from its saved running statistics.
    scaler_state = ckpt["state_preprocessor"]
    state_preprocessor = RunningStandardScaler(size=observation_space, device=device)
    state_preprocessor.running_mean = scaler_state["running_mean"].to(device)
    state_preprocessor.running_variance = scaler_state["running_variance"].to(device)
    state_preprocessor.current_count = scaler_state["current_count"]
    # Freeze the normalizer: eval() is the standard Module way to clear the
    # training flag, and it also propagates to any submodules.
    state_preprocessor.eval()

    # flatten() keeps the per-element rounding valid for non-1-D observation shapes.
    obs_mean = state_preprocessor.running_mean.flatten().tolist()
    obs_std = state_preprocessor.running_variance.sqrt().flatten().tolist()
    logger.info(
        "checkpoint_loaded",
        path=checkpoint_path,
        hidden_sizes=hidden_sizes,
        obs_mean=[round(x, 3) for x in obs_mean],
        obs_std=[round(x, 3) for x in obs_std],
    )

    return model, state_preprocessor
|
|
|
|
|
|
# ── action arrow overlay ─────────────────────────────────────────────
|
|
|
|
def _add_action_arrow(viewer, model, data, action_val: float) -> None:
    """Draw an arrow in the viewer's user scene showing the applied torque.

    The arrow is anchored 2 cm above the body of the joint driven by actuator
    0. It points along that joint's world-frame axis, flipped for negative
    actions, and its length is ``0.08 * |action_val|``. It is green for
    positive actions and red for negative ones.

    Nothing is drawn when:
      * ``|action_val| < 0.01``;
      * the model has no actuators;
      * actuator 0 does not use a joint transmission;
      * the user scene has no free geom slot.

    Args:
        viewer: MuJoCo passive-viewer handle whose ``user_scn`` gets the geom.
        model: The ``mujoco.MjModel`` being displayed.
        data: The matching ``mujoco.MjData``. Its kinematics (``xpos``,
            ``xaxis``) must be current.
        action_val: Action value for actuator 0.
    """
    if abs(action_val) < 0.01 or model.nu == 0:
        return

    # actuator_trnid[0, 0] is a joint id only for joint transmissions. For a
    # tendon/site/body transmission it would index jnt_* arrays incorrectly.
    joint_trn_types = (int(mujoco.mjtTrn.mjTRN_JOINT), int(mujoco.mjtTrn.mjTRN_JOINTINPARENT))
    if int(model.actuator_trntype[0]) not in joint_trn_types:
        return

    scn = viewer.user_scn
    if scn.ngeom >= scn.maxgeom:
        return  # scene is full; geoms[ngeom] would be out of range

    jnt_id = model.actuator_trnid[0, 0]
    body_id = model.jnt_bodyid[jnt_id]
    pos = data.xpos[body_id].copy()
    pos[2] += 0.02  # lift slightly so the arrow is not buried in the body

    arrow_len = 0.08 * action_val
    # xaxis holds the joint axis already expressed in world coordinates.
    direction = data.xaxis[jnt_id] * np.sign(arrow_len)

    # Build an orthonormal frame whose z axis is the arrow direction.
    # The reference vector only needs to be non-parallel to z.
    z = direction / (np.linalg.norm(direction) + 1e-8)
    ref = np.array([0.0, 0.0, 1.0]) if abs(z[2]) < 0.99 else np.array([0.0, 1.0, 0.0])
    x = np.cross(ref, z)
    x /= np.linalg.norm(x) + 1e-8
    y = np.cross(z, x)
    # Row-major 3x3 rotation whose columns are the geom's local x/y/z axes in
    # world coordinates; the arrow geom extends along its local z axis.
    mat = np.column_stack([x, y, z]).flatten()

    rgba = np.array(
        [0.2, 0.8, 0.2, 0.8] if action_val > 0 else [0.8, 0.2, 0.2, 0.8],
        dtype=np.float32,
    )

    geom = scn.geoms[scn.ngeom]
    mujoco.mjv_initGeom(
        geom,
        type=mujoco.mjtGeom.mjGEOM_ARROW,
        size=np.array([0.008, 0.008, abs(arrow_len)]),
        pos=pos,
        mat=mat,
        rgba=rgba,
    )
    scn.ngeom += 1
|
|
|
|
|
|
# ── main loops ───────────────────────────────────────────────────────
|
|
|
|
@hydra.main(version_base=None, config_path="../configs", config_name="config")
def main(cfg: DictConfig) -> None:
    """Hydra entry point: validate the checkpoint, then dispatch to an eval loop.

    The env and runner names come from Hydra's config-group choices and
    default to ``cartpole`` and ``mujoco_single``. A relative ``checkpoint``
    is resolved against the directory the script was launched from. The
    process exits with status 1 if no checkpoint is given or the file does
    not exist. Runner ``serial`` selects the hardware loop; any other runner
    selects the simulation loop.
    """
    group_choices = HydraConfig.get().runtime.choices
    env_name = group_choices.get("env", "cartpole")
    runner_name = group_choices.get("runner", "mujoco_single")

    requested = cfg.get("checkpoint", None)
    if requested is None:
        logger.error("No checkpoint specified. Use: +checkpoint=path/to/agent.pt")
        sys.exit(1)

    # Hydra may change the working directory, so anchor relative paths to the
    # original launch directory (absolute paths are left as-is by ``/``).
    resolved = Path(hydra.utils.get_original_cwd()) / requested
    checkpoint_path = str(resolved)
    if not resolved.exists():
        logger.error("checkpoint_not_found", path=checkpoint_path)
        sys.exit(1)

    evaluate = _eval_serial if runner_name == "serial" else _eval_sim
    evaluate(cfg, env_name, checkpoint_path)
|
|
|
|
|
|
def _eval_sim(cfg: DictConfig, env_name: str, checkpoint_path: str) -> None:
    """Evaluate the policy in MuJoCo simulation with a live passive viewer.

    Builds the environment and a ``MuJoCoRunner`` forced to ``num_envs=1``,
    then loads the checkpoint. Each loop iteration runs one clamped policy
    step. The iteration is paced so that it takes about
    ``dt * substeps`` seconds of wall-clock time in total, compute included.

    Episodes auto-reset on termination or truncation. The keyboard flags set
    by ``_key_callback`` drive pause, reset and quit. The runner is closed on
    every exit path, including exceptions.
    """
    from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig

    env = build_env(env_name, cfg)
    runner_dict = OmegaConf.to_container(cfg.runner, resolve=True)
    runner_dict["num_envs"] = 1  # the loop reads reward/action of env 0 only
    runner = MuJoCoRunner(env=env, config=MuJoCoRunnerConfig(**runner_dict))

    try:
        device = runner.device
        model, preprocessor = load_policy(
            checkpoint_path, runner.observation_space, runner.action_space, device
        )

        mj_model = runner._model
        mj_data = runner._data[0]
        dt_ctrl = runner.config.dt * runner.config.substeps
        hinge_type = int(mujoco.mjtJoint.mjJNT_HINGE)

        with mujoco.viewer.launch_passive(mj_model, mj_data, key_callback=_key_callback) as viewer:
            obs, _ = runner.reset()
            step = 0
            episode = 0
            episode_reward = 0.0

            logger.info(
                "eval_started",
                env=env_name,
                mode="simulation",
                checkpoint=Path(checkpoint_path).name,
                controls="Space=pause, R=reset, Esc=quit",
            )

            while viewer.is_running() and not _quit_flag[0]:
                tick = time.perf_counter()

                if _reset_flag[0]:
                    _reset_flag[0] = False
                    obs, _ = runner.reset()
                    step = 0
                    episode += 1
                    episode_reward = 0.0
                    logger.info("reset", episode=episode)

                if _paused[0]:
                    viewer.sync()
                    time.sleep(0.05)
                    continue

                # Policy inference
                with torch.no_grad():
                    normalized_obs = preprocessor(obs.unsqueeze(0) if obs.dim() == 1 else obs)
                    action = model.act({"states": normalized_obs}, role="policy")[0]
                    action = action.clamp(-1.0, 1.0)

                obs, reward, terminated, truncated, _info = runner.step(action)
                episode_reward += reward.item()
                step += 1

                # The passive viewer renders from its own thread; hold its lock
                # while mutating the data and user scene it reads.
                with viewer.lock():
                    mujoco.mj_forward(mj_model, mj_data)
                    viewer.user_scn.ngeom = 0
                    _add_action_arrow(viewer, mj_model, mj_data, action[0, 0].item())
                viewer.sync()

                if step % 50 == 0:
                    joints = {}
                    for j in range(mj_model.njnt):
                        # qpos is addressed via jnt_qposadr, not by joint id;
                        # free/ball joints log only their first coordinate.
                        q = float(mj_data.qpos[mj_model.jnt_qposadr[j]])
                        if int(mj_model.jnt_type[j]) == hinge_type:
                            joints[mj_model.jnt(j).name] = round(math.degrees(q), 1)
                        else:
                            joints[mj_model.jnt(j).name] = round(q, 3)
                    logger.debug(
                        "step", n=step, reward=round(reward.item(), 3),
                        action=round(action[0, 0].item(), 2),
                        ep_reward=round(episode_reward, 1), **joints,
                    )

                if terminated.any() or truncated.any():
                    logger.info(
                        "episode_done", episode=episode, steps=step,
                        total_reward=round(episode_reward, 2),
                        reason="terminated" if terminated.any() else "truncated",
                    )
                    obs, _ = runner.reset()
                    step = 0
                    episode += 1
                    episode_reward = 0.0

                # Sleep only for what is left of the control period, so the
                # simulation stays near real time instead of drifting by the
                # inference/render cost of each step.
                time.sleep(max(0.0, dt_ctrl - (time.perf_counter() - tick)))
    finally:
        runner.close()
|
|
|
|
|
|
def _eval_serial(cfg: DictConfig, env_name: str, checkpoint_path: str) -> None:
    """Evaluate the policy on real hardware via serial, with a digital-twin viewer.

    Each loop iteration does three things:
      * runs one clamped policy step through ``SerialRunner.step``, which also
        does the real-time pacing;
      * mirrors the sensor state into the MuJoCo twin;
      * draws the commanded action.

    The loop ends on Esc, when the viewer is closed, or when ``info`` reports
    a reboot or motor-limit condition.

    The motor is commanded off ("M0") and the runner is closed on every exit
    path, including exceptions and Ctrl-C, so the hardware is not left
    running the last command.
    """
    from src.runners.serial import SerialRunner, SerialRunnerConfig

    env = build_env(env_name, cfg)
    runner_dict = OmegaConf.to_container(cfg.runner, resolve=True)
    serial_runner = SerialRunner(env=env, config=SerialRunnerConfig(**runner_dict))

    try:
        device = serial_runner.device
        model, preprocessor = load_policy(
            checkpoint_path, serial_runner.observation_space, serial_runner.action_space, device
        )

        # Set up digital-twin MuJoCo model for visualization.
        serial_runner._ensure_viz_model()
        mj_model = serial_runner._viz_model
        mj_data = serial_runner._viz_data

        with mujoco.viewer.launch_passive(mj_model, mj_data, key_callback=_key_callback) as viewer:
            obs, _ = serial_runner.reset()
            step = 0
            episode = 0
            episode_reward = 0.0

            logger.info(
                "eval_started",
                env=env_name,
                mode="real hardware (serial)",
                port=serial_runner.config.port,
                checkpoint=Path(checkpoint_path).name,
                controls="Space=pause, R=reset, Esc=quit",
            )

            while viewer.is_running() and not _quit_flag[0]:
                if _reset_flag[0]:
                    _reset_flag[0] = False
                    serial_runner._send("M0")
                    serial_runner._drive_to_center()
                    serial_runner._wait_for_pendulum_still()
                    obs, _ = serial_runner.reset()
                    step = 0
                    episode += 1
                    episode_reward = 0.0
                    logger.info("reset", episode=episode)

                if _paused[0]:
                    serial_runner._send("M0")  # safety: stop motor while paused
                    with viewer.lock():
                        serial_runner._sync_viz()
                    viewer.sync()
                    time.sleep(0.05)
                    continue

                # NOTE(review): on resume, the first action is computed from
                # the observation captured before the pause (obs is only
                # refreshed by step()/reset()). Confirm this is acceptable.
                # Policy inference
                with torch.no_grad():
                    normalized_obs = preprocessor(obs.unsqueeze(0) if obs.dim() == 1 else obs)
                    action = model.act({"states": normalized_obs}, role="policy")[0]
                    action = action.clamp(-1.0, 1.0)

                obs, reward, terminated, truncated, info = serial_runner.step(action)
                episode_reward += reward.item()
                step += 1

                # Sync digital twin with real sensor data. The passive viewer
                # renders from its own thread, so mutate data/scene under its lock.
                with viewer.lock():
                    serial_runner._sync_viz()
                    viewer.user_scn.ngeom = 0
                    _add_action_arrow(viewer, mj_model, mj_data, action[0, 0].item())
                viewer.sync()

                if step % 25 == 0:
                    state = serial_runner._read_state()
                    logger.debug(
                        "step", n=step, reward=round(reward.item(), 3),
                        action=round(action[0, 0].item(), 2),
                        ep_reward=round(episode_reward, 1),
                        motor_enc=state["encoder_count"],
                        pend_deg=round(state["pendulum_angle"], 1),
                    )

                # Check for safety / disconnection.
                if info.get("reboot_detected") or info.get("motor_limit_exceeded"):
                    logger.error(
                        "safety_stop",
                        reboot=info.get("reboot_detected", False),
                        motor_limit=info.get("motor_limit_exceeded", False),
                    )
                    serial_runner._send("M0")
                    break

                if terminated.any() or truncated.any():
                    logger.info(
                        "episode_done", episode=episode, steps=step,
                        total_reward=round(episode_reward, 2),
                        reason="terminated" if terminated.any() else "truncated",
                    )
                    # Auto-reset for next episode.
                    obs, _ = serial_runner.reset()
                    step = 0
                    episode += 1
                    episode_reward = 0.0

                # Real-time pacing is handled by serial_runner.step() (dt sleep).
    finally:
        # Best effort: always try to leave the motor off. The link may
        # already be unusable (e.g. after a detected reboot), so a failure
        # here is logged rather than allowed to skip close().
        try:
            serial_runner._send("M0")
        except Exception as err:  # cleanup boundary: close() must still run
            logger.warning("motor_stop_failed", error=repr(err))
        serial_runner.close()
|
|
|
|
|
|
# Script entry point. main() takes no arguments here because the
# @hydra.main decorator supplies ``cfg`` from the composed config and the
# command-line overrides.
if __name__ == "__main__":
    main()
|