Files
RL-Sim-Framework/scripts/viz.py
2026-03-11 22:52:01 +01:00

255 lines
8.5 KiB
Python

"""Interactive visualization — control any env with keyboard in MuJoCo viewer.
Usage (simulation):
mjpython scripts/viz.py env=rotary_cartpole
mjpython scripts/viz.py env=cartpole +com=true
Usage (real hardware — digital twin):
mjpython scripts/viz.py env=rotary_cartpole runner=serial
mjpython scripts/viz.py env=rotary_cartpole runner=serial runner.port=/dev/ttyUSB0
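Controls:
Left/Right arrows — apply torque to first actuator (serial: send a motor command to the real robot)
R — reset environment (serial: stop the motor, drive to center, wait for the pendulum to settle)
Esc / close window — quit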
"""
import math
import sys
import time
from pathlib import Path
# Put the repository root (two directories above this script) on the import
# path so the `src.*` packages resolve when the script is launched directly.
_PROJECT_ROOT = str(Path(__file__).resolve().parents[1])
if not any(entry == _PROJECT_ROOT for entry in sys.path):
    sys.path.insert(0, _PROJECT_ROOT)
import hydra
import mujoco
import mujoco.viewer
import numpy as np
import structlog
import torch
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf
from src.core.registry import build_env
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
# Module-level structured logger shared by the simulation and serial viewers.
logger = structlog.get_logger()
# ── keyboard state ───────────────────────────────────────────────────
# One-element lists act as mutable cells: _key_callback updates them in place
# (no `global` statement needed) and the main loops read / clear them.
# NOTE(review): the viewer presumably invokes the callback from its own thread;
# these cells are read without synchronisation — fine for UI state, confirm.
_action_val = [0.0]  # last arrow direction: -1.0 (left) / +1.0 (right); 0.0 initially
_action_time = [0.0]  # time.time() of the last arrow press or key repeat
_reset_flag = [False]  # set by the R key, consumed (reset to False) by the main loop
_ACTION_HOLD_S = 0.25  # seconds the action stays active after last key event
def _key_callback(keycode: int) -> None:
    """Translate a MuJoCo viewer key event into shared keyboard state.

    MuJoCo calls this on key press and auto-repeat, never on release. Holding
    an arrow therefore keeps refreshing ``_action_time``, and the main loop
    treats the action as released once ``_ACTION_HOLD_S`` has elapsed.

    Args:
        keycode: GLFW key code of the key that was pressed or repeated.
    """
    glfw_key_left, glfw_key_right, glfw_key_r = 263, 262, 82
    direction = {glfw_key_left: -1.0, glfw_key_right: 1.0}.get(keycode)
    if direction is not None:
        _action_val[0] = direction
        _action_time[0] = time.time()
    elif keycode == glfw_key_r:
        _reset_flag[0] = True
def _add_action_arrow(viewer, model, data, action_val: float) -> None:
"""Draw an arrow on the motor joint showing applied torque direction."""
if abs(action_val) < 0.01 or model.nu == 0:
return
# Get the body that the first actuator's joint belongs to
jnt_id = model.actuator_trnid[0, 0]
body_id = model.jnt_bodyid[jnt_id]
# Arrow origin: body position
pos = data.xpos[body_id].copy()
pos[2] += 0.02 # lift slightly above the body
# Arrow direction: along joint axis in world frame, scaled by action
axis = data.xmat[body_id].reshape(3, 3) @ model.jnt_axis[jnt_id]
arrow_len = 0.08 * action_val
direction = axis * np.sign(arrow_len)
# Build rotation matrix: arrow rendered along local z-axis
z = direction / (np.linalg.norm(direction) + 1e-8)
up = np.array([0, 0, 1]) if abs(z[2]) < 0.99 else np.array([0, 1, 0])
x = np.cross(up, z)
x /= np.linalg.norm(x) + 1e-8
y = np.cross(z, x)
mat = np.column_stack([x, y, z]).flatten()
# Color: green = positive, red = negative
rgba = np.array(
[0.2, 0.8, 0.2, 0.8] if action_val > 0 else [0.8, 0.2, 0.2, 0.8],
dtype=np.float32,
)
geom = viewer.user_scn.geoms[viewer.user_scn.ngeom]
mujoco.mjv_initGeom(
geom,
type=mujoco.mjtGeom.mjGEOM_ARROW,
size=np.array([0.008, 0.008, abs(arrow_len)]),
pos=pos,
mat=mat,
rgba=rgba,
)
viewer.user_scn.ngeom += 1
@hydra.main(version_base=None, config_path="../configs", config_name="config")
def main(cfg: DictConfig) -> None:
    """Hydra entry point: choose the simulated or the hardware viewer.

    The env and runner names come from the config-group selections Hydra
    records in ``runtime.choices``. Defaults are ``cartpole`` and ``mujoco``.
    Selecting ``runner=serial`` opens the digital-twin viewer. Any other
    runner opens the pure-simulation viewer.
    """
    selections = HydraConfig.get().runtime.choices
    env_choice = selections.get("env", "cartpole")
    use_hardware = selections.get("runner", "mujoco") == "serial"
    launch = _main_serial if use_hardware else _main_sim
    launch(cfg, env_choice)
def _joint_readout(model, data) -> dict:
    """Return ``{joint_name: position}`` for the model's hinge and slide joints.

    Positions are read through ``model.jnt_qposadr``. A joint's index equals
    its ``qpos`` address only when every preceding joint is 1-DoF.

    Values are reported as follows:

    * Hinge angles are in degrees, rounded to 0.1.
    * Slide positions stay in model length units, rounded to 1e-3, rather
      than being converted to degrees.
    * Free and ball joints are skipped, since they have no single scalar
      coordinate.
    * Unnamed joints are reported as ``joint<index>``.
    """
    hinge = int(mujoco.mjtJoint.mjJNT_HINGE)
    slide = int(mujoco.mjtJoint.mjJNT_SLIDE)
    readout = {}
    for j in range(model.njnt):
        jtype = int(model.jnt_type[j])
        if jtype not in (hinge, slide):
            continue
        q = float(data.qpos[model.jnt_qposadr[j]])
        name = model.jnt(j).name or f"joint{j}"
        readout[name] = round(math.degrees(q), 1) if jtype == hinge else round(q, 3)
    return readout


def _main_sim(cfg: DictConfig, env_name: str) -> None:
    """Simulation visualization — step MuJoCo physics with keyboard control.

    Builds a single-env ``MuJoCoRunner`` and opens a passive MuJoCo viewer on
    its model/data. It then steps the runner once per control period, applying
    the current keyboard action. Each frame it draws an action arrow and
    syncs the viewer, and every 25 steps it logs reward and joint positions.

    ``viewer.lock()`` is held while physics state and ``user_scn`` are
    mutated, as the passive viewer renders from its own thread. The runner is
    closed on every exit path.
    """
    # Build env + runner (single env for viz)
    env = build_env(env_name, cfg)
    runner_dict = OmegaConf.to_container(cfg.runner, resolve=True)
    runner_dict["num_envs"] = 1
    runner = MuJoCoRunner(env=env, config=MuJoCoRunnerConfig(**runner_dict))
    try:
        model = runner._model
        data = runner._data[0]
        # Control period: one runner.step advances `substeps` physics steps.
        dt_ctrl = runner.config.dt * runner.config.substeps

        with mujoco.viewer.launch_passive(
            model, data, key_callback=_key_callback
        ) as viewer:
            with viewer.lock():
                # Show CoM / inertia if requested via Hydra override: viz.py +com=true
                if cfg.get("com", False):
                    viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_COM] = True
                    viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_INERTIA] = True
                runner.reset()

            step = 0
            logger.info("viewer_started", env=env_name,
                        controls="Left/Right arrows = torque, R = reset")

            while viewer.is_running():
                tick_start = time.perf_counter()

                # Read action from callback (expires after _ACTION_HOLD_S)
                if time.time() - _action_time[0] < _ACTION_HOLD_S:
                    action_val = _action_val[0]
                else:
                    action_val = 0.0

                with viewer.lock():
                    # Reset on R press
                    if _reset_flag[0]:
                        _reset_flag[0] = False
                        runner.reset()
                        step = 0
                        logger.info("reset")

                    action = torch.tensor([[action_val]])
                    _, reward, _, _, _ = runner.step(action)

                    # Recompute derived quantities (xpos, xanchor, ...) for
                    # the post-step qpos before drawing the overlay.
                    mujoco.mj_forward(model, data)
                    viewer.user_scn.ngeom = 0  # clear previous frame's overlays
                    _add_action_arrow(viewer, model, data, action_val)
                viewer.sync()

                if step % 25 == 0:
                    logger.debug("step", n=step, reward=round(reward.item(), 3),
                                 action=round(action_val, 1),
                                 **_joint_readout(model, data))
                step += 1

                # Real-time pacing: sleep only for what is left of the period,
                # so stepping/rendering time does not slow the sim below 1x.
                remaining = dt_ctrl - (time.perf_counter() - tick_start)
                if remaining > 0:
                    time.sleep(remaining)
    finally:
        runner.close()
def _main_serial(cfg: DictConfig, env_name: str) -> None:
    """Digital-twin visualization — mirror real hardware in MuJoCo viewer.

    The MuJoCo model is loaded for rendering only. Joint positions are
    read from the ESP32 over serial and applied to the model each frame.
    Keyboard arrows send motor commands to the real robot.

    On every exit path (window closed, exception, Ctrl-C) the motor is sent a
    best-effort stop command (``M0``) and the serial runner is closed. This
    prevents the hardware from being left driving at the last keyboard
    command. ``viewer.lock()`` is held while the viz data and ``user_scn`` are
    mutated, because the passive viewer renders from its own thread.
    """
    from src.runners.serial import SerialRunner, SerialRunnerConfig

    env = build_env(env_name, cfg)
    runner_dict = OmegaConf.to_container(cfg.runner, resolve=True)
    serial_runner = SerialRunner(
        env=env, config=SerialRunnerConfig(**runner_dict)
    )
    try:
        # Load MuJoCo model for visualisation (same URDF the sim uses).
        serial_runner._ensure_viz_model()
        model = serial_runner._viz_model
        data = serial_runner._viz_data
        period = serial_runner.config.dt

        with mujoco.viewer.launch_passive(
            model, data, key_callback=_key_callback
        ) as viewer:
            # Show CoM / inertia if requested.
            if cfg.get("com", False):
                with viewer.lock():
                    viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_COM] = True
                    viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_INERTIA] = True

            logger.info(
                "viewer_started",
                env=env_name,
                mode="serial (digital twin)",
                port=serial_runner.config.port,
                controls="Left/Right arrows = motor command, R = reset",
            )

            while viewer.is_running():
                tick_start = time.perf_counter()

                # Read action from keyboard callback.
                if time.time() - _action_time[0] < _ACTION_HOLD_S:
                    action_val = _action_val[0]
                else:
                    action_val = 0.0

                # Reset on R press: stop, re-center, wait for the pendulum.
                if _reset_flag[0]:
                    _reset_flag[0] = False
                    serial_runner._send("M0")
                    serial_runner._drive_to_center()
                    serial_runner._wait_for_pendulum_still()
                    logger.info("reset (drive-to-center + settle)")

                # Send motor command to real hardware ([-1, 1] -> [-255, 255]).
                motor_speed = int(np.clip(action_val, -1.0, 1.0) * 255)
                serial_runner._send(f"M{motor_speed}")

                with viewer.lock():
                    # Sync MuJoCo model with real sensor data.
                    # NOTE(review): if _sync_viz blocks on serial I/O, this
                    # lock stalls rendering meanwhile — confirm it is cheap.
                    serial_runner._sync_viz()
                    viewer.user_scn.ngeom = 0
                    _add_action_arrow(viewer, model, data, action_val)
                viewer.sync()

                # Real-time pacing at the serial runner's dt: sleep only for
                # what is left of the period after I/O and rendering.
                remaining = period - (time.perf_counter() - tick_start)
                if remaining > 0:
                    time.sleep(remaining)
    finally:
        try:
            serial_runner._send("M0")  # best effort: never leave the motor running
        except Exception:
            logger.warning("motor_stop_failed", exc_info=True)
        serial_runner.close()
if __name__ == "__main__":
    # Called without arguments: the @hydra.main decorator builds `cfg` from
    # ../configs plus any command-line overrides.
    main()