255 lines
8.5 KiB
Python
255 lines
8.5 KiB
Python
"""Interactive visualization — control any env with keyboard in MuJoCo viewer.
|
|
|
|
Usage (simulation):
|
|
mjpython scripts/viz.py env=rotary_cartpole
|
|
mjpython scripts/viz.py env=cartpole +com=true
|
|
|
|
Usage (real hardware — digital twin):
|
|
mjpython scripts/viz.py env=rotary_cartpole runner=serial
|
|
mjpython scripts/viz.py env=rotary_cartpole runner=serial runner.port=/dev/ttyUSB0
|
|
|
|
Controls:
|
|
Left/Right arrows — apply torque to the first actuator (sim) or send a motor command to the real robot (serial)
|
|
R — reset the environment (serial: stop the motor, drive to center, and wait for the pendulum to settle)
|
|
Esc / close window — quit
|
|
"""
|
|
import math
|
|
import sys
|
|
import time
|
|
from pathlib import Path
|
|
|
|
# Make the project root (the parent of this script's directory) importable,
# so the `src.*` imports below resolve when the script is run directly.
_PROJECT_ROOT = str(Path(__file__).resolve().parents[1])
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)
|
|
|
|
import hydra
|
|
import mujoco
|
|
import mujoco.viewer
|
|
import numpy as np
|
|
import structlog
|
|
import torch
|
|
from hydra.core.hydra_config import HydraConfig
|
|
from omegaconf import DictConfig, OmegaConf
|
|
|
|
from src.core.registry import build_env
|
|
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
|
|
|
|
# Module-wide structlog logger used by the viewer entry points below.
logger = structlog.get_logger()
|
|
|
|
|
|
# ── keyboard state ───────────────────────────────────────────────────
# Single-element lists act as mutable cells: _key_callback writes into them
# and the control loops read them, without needing `global` statements.
_ACTION_HOLD_S = 0.25  # seconds an arrow action stays active after its last key event
_reset_flag = [False]  # set by R, cleared by the control loop once handled
_action_time = [0.0]  # time.time() of the most recent arrow-key event
_action_val = [0.0]  # last arrow direction: -1.0 (left) or +1.0 (right)
|
|
|
|
|
|
def _key_callback(keycode: int) -> None:
    """Keyboard hook handed to ``mujoco.viewer.launch_passive``.

    MuJoCo calls it on key press and auto-repeat, but never on release. An
    arrow key therefore records a direction plus the time of the event, and
    the control loop drops the action once ``_ACTION_HOLD_S`` seconds pass
    without a new event. R only raises a flag that the control loop consumes.
    """
    # GLFW key codes: 263 = LEFT, 262 = RIGHT, 82 = R.
    direction = {263: -1.0, 262: 1.0}.get(keycode)
    if direction is not None:
        _action_val[0] = direction
        _action_time[0] = time.time()
    elif keycode == 82:
        _reset_flag[0] = True
|
|
|
|
|
|
def _add_action_arrow(viewer, model, data, action_val: float) -> None:
    """Draw an arrow on the first actuator's joint showing the applied torque direction.

    Appends one ``mjGEOM_ARROW`` to ``viewer.user_scn``. Callers clear
    ``user_scn.ngeom`` before each frame. The arrow:

    * starts slightly above the origin of the joint's body,
    * points along the joint axis in world frame, flipped by the sign of
      ``action_val``,
    * has length ``0.08 * |action_val|``,
    * is green for positive actions and red for negative ones.

    Nothing is drawn when:

    * ``|action_val| < 0.01``,
    * the model has no actuators,
    * the first actuator's transmission is not a joint, or
    * ``user_scn`` has no free geom slot left.

    Args:
        viewer: Passive viewer handle from ``mujoco.viewer.launch_passive``.
        model: The ``mujoco.MjModel`` being rendered.
        data: The matching ``mujoco.MjData``. ``xpos``/``xmat`` must be
            current (the callers in this script run ``mj_forward`` or sync
            from hardware before calling).
        action_val: Action applied to the first actuator.
    """
    if abs(action_val) < 0.01 or model.nu == 0:
        return

    scn = viewer.user_scn
    # user_scn.geoms is a fixed-capacity buffer; never write past maxgeom.
    if scn.ngeom >= scn.maxgeom:
        return

    # actuator_trnid[0, 0] is a joint id only for joint transmissions. For
    # tendon/site/slider-crank/body transmissions it indexes another table.
    joint_trn = (int(mujoco.mjtTrn.mjTRN_JOINT), int(mujoco.mjtTrn.mjTRN_JOINTINPARENT))
    if int(model.actuator_trntype[0]) not in joint_trn:
        return

    jnt_id = model.actuator_trnid[0, 0]
    body_id = model.jnt_bodyid[jnt_id]

    # Arrow origin: body position, lifted slightly above the body.
    pos = data.xpos[body_id].copy()
    pos[2] += 0.02

    # Arrow direction: joint axis (body frame) rotated into world frame,
    # flipped to match the sign of the action.
    axis = data.xmat[body_id].reshape(3, 3) @ model.jnt_axis[jnt_id]
    arrow_len = 0.08 * action_val
    direction = axis * np.sign(arrow_len)

    # Build an orthonormal frame whose local z is the arrow direction (arrows
    # render along local z). The helper "up" vector is chosen so that it is
    # not parallel to z.
    z = direction / (np.linalg.norm(direction) + 1e-8)
    up = np.array([0, 0, 1]) if abs(z[2]) < 0.99 else np.array([0, 1, 0])
    x = np.cross(up, z)
    x /= np.linalg.norm(x) + 1e-8
    y = np.cross(z, x)
    mat = np.column_stack([x, y, z]).flatten()

    # Color: green = positive, red = negative
    rgba = np.array(
        [0.2, 0.8, 0.2, 0.8] if action_val > 0 else [0.8, 0.2, 0.2, 0.8],
        dtype=np.float32,
    )

    mujoco.mjv_initGeom(
        scn.geoms[scn.ngeom],
        type=mujoco.mjtGeom.mjGEOM_ARROW,
        size=np.array([0.008, 0.008, abs(arrow_len)]),
        pos=pos,
        mat=mat,
        rgba=rgba,
    )
    scn.ngeom += 1
|
|
|
|
|
|
@hydra.main(version_base=None, config_path="../configs", config_name="config")
def main(cfg: DictConfig) -> None:
    """Hydra entry point: dispatch to the serial (digital-twin) or simulation viewer.

    The env and runner names come from Hydra's resolved config-group choices
    (``env=...``, ``runner=...``). They default to ``cartpole`` and ``mujoco``.
    Only ``runner=serial`` selects the hardware path.
    """
    selected = HydraConfig.get().runtime.choices
    env_name = selected.get("env", "cartpole")
    entry = _main_serial if selected.get("runner", "mujoco") == "serial" else _main_sim
    entry(cfg, env_name)
|
|
|
|
|
|
def _sim_joint_readout(model, data) -> dict:
    """Return ``{joint_name: position}`` for hinge and slide joints, for debug logging.

    Each value is read from ``data.qpos`` at the joint's ``model.jnt_qposadr``
    entry, because qpos is not indexed by joint id once free or ball joints
    are present. Hinge angles are converted to degrees (rounded to 0.1).
    Slide positions are reported unconverted (rounded to 0.001). Free and
    ball joints are skipped, and unnamed joints are reported as ``joint<i>``.
    """
    hinge = int(mujoco.mjtJoint.mjJNT_HINGE)
    slide = int(mujoco.mjtJoint.mjJNT_SLIDE)
    readout = {}
    for j in range(model.njnt):
        jtype = int(model.jnt_type[j])
        if jtype not in (hinge, slide):
            continue
        q = float(data.qpos[model.jnt_qposadr[j]])
        name = model.jnt(j).name or f"joint{j}"
        readout[name] = round(math.degrees(q), 1) if jtype == hinge else round(q, 3)
    return readout


def _main_sim(cfg: DictConfig, env_name: str) -> None:
    """Simulation visualization — step MuJoCo physics with keyboard control.

    Builds a single-env ``MuJoCoRunner`` for ``env_name`` and opens a passive
    viewer on its model/data. Each control period it:

    * reads the arrow-key action, which is zeroed ``_ACTION_HOLD_S`` seconds
      after the last key event,
    * resets the runner if R was pressed,
    * steps the runner with that action,
    * redraws the viewer with a torque-direction arrow overlay,
    * logs joint positions every 25 steps at debug level.

    The loop is paced to real time. The runner is closed when the viewer
    exits, including when an exception is raised.
    """
    env = build_env(env_name, cfg)
    runner_dict = OmegaConf.to_container(cfg.runner, resolve=True)
    runner_dict["num_envs"] = 1  # the viewer drives exactly one env
    runner = MuJoCoRunner(env=env, config=MuJoCoRunnerConfig(**runner_dict))

    try:
        model = runner._model
        data = runner._data[0]

        # One control period = `substeps` physics steps of `dt` each.
        dt_ctrl = runner.config.dt * runner.config.substeps

        with mujoco.viewer.launch_passive(model, data, key_callback=_key_callback) as viewer:
            # Show CoM / inertia if requested via Hydra override: viz.py +com=true
            if cfg.get("com", False):
                viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_COM] = True
                viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_INERTIA] = True

            runner.reset()
            step = 0

            logger.info("viewer_started", env=env_name,
                        controls="Left/Right arrows = torque, R = reset")

            next_tick = time.perf_counter()
            while viewer.is_running():
                # Read action from callback (expires after _ACTION_HOLD_S)
                if time.time() - _action_time[0] < _ACTION_HOLD_S:
                    action_val = _action_val[0]
                else:
                    action_val = 0.0

                # Reset on R press
                if _reset_flag[0]:
                    _reset_flag[0] = False
                    runner.reset()
                    step = 0
                    logger.info("reset")

                # Step through runner
                action = torch.tensor([[action_val]])
                _, reward, _, _, _ = runner.step(action)

                # Sync viewer with action arrow overlay
                mujoco.mj_forward(model, data)
                viewer.user_scn.ngeom = 0  # clear previous frame's overlays
                _add_action_arrow(viewer, model, data, action_val)
                viewer.sync()

                # Print state
                if step % 25 == 0:
                    logger.debug("step", n=step, reward=round(reward.item(), 3),
                                 action=round(action_val, 1),
                                 **_sim_joint_readout(model, data))

                # Real-time pacing. Sleep only for what is left of this control
                # period, so step/render time does not slow the sim below real
                # time.
                next_tick += dt_ctrl
                remaining = next_tick - time.perf_counter()
                if remaining > 0:
                    time.sleep(remaining)
                else:
                    # Fell behind (slow step or render). Resync rather than
                    # bursting steps to catch up.
                    next_tick = time.perf_counter()
                step += 1
    finally:
        runner.close()
|
|
|
|
|
|
def _main_serial(cfg: DictConfig, env_name: str) -> None:
    """Digital-twin visualization — mirror real hardware in MuJoCo viewer.

    The MuJoCo model is loaded for rendering only. Joint positions are
    read from the ESP32 over serial and applied to the model each frame.
    Keyboard arrows send motor commands to the real robot.

    Each loop iteration:

    * sends ``M<speed>``, with speed in [-255, 255] derived from the
      arrow-key action,
    * syncs the viz model via the runner,
    * redraws the viewer.

    R sends ``M0``, then runs the runner's drive-to-center and
    wait-for-pendulum-still routines.

    On exit, whether normal or via an exception, a best-effort ``M0`` is
    sent so the motor is not left at the last commanded speed (this code does
    not assume ``SerialRunner.close()`` stops it). The runner is then closed.
    """
    from src.runners.serial import SerialRunner, SerialRunnerConfig

    env = build_env(env_name, cfg)
    runner_dict = OmegaConf.to_container(cfg.runner, resolve=True)
    serial_runner = SerialRunner(
        env=env, config=SerialRunnerConfig(**runner_dict)
    )

    try:
        # Load MuJoCo model for visualisation (same URDF the sim uses).
        serial_runner._ensure_viz_model()
        model = serial_runner._viz_model
        data = serial_runner._viz_data
        dt = serial_runner.config.dt

        with mujoco.viewer.launch_passive(
            model, data, key_callback=_key_callback
        ) as viewer:
            # Show CoM / inertia if requested.
            if cfg.get("com", False):
                viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_COM] = True
                viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_INERTIA] = True

            logger.info(
                "viewer_started",
                env=env_name,
                mode="serial (digital twin)",
                port=serial_runner.config.port,
                controls="Left/Right arrows = motor command, R = reset",
            )

            next_tick = time.perf_counter()
            while viewer.is_running():
                # Read action from keyboard callback.
                if time.time() - _action_time[0] < _ACTION_HOLD_S:
                    action_val = _action_val[0]
                else:
                    action_val = 0.0

                # Reset on R press. This blocks while the hardware re-centers;
                # the pacing below resyncs afterwards instead of catching up.
                if _reset_flag[0]:
                    _reset_flag[0] = False
                    serial_runner._send("M0")
                    serial_runner._drive_to_center()
                    serial_runner._wait_for_pendulum_still()
                    logger.info("reset (drive-to-center + settle)")

                # Send motor command to real hardware.
                motor_speed = int(np.clip(action_val, -1.0, 1.0) * 255)
                serial_runner._send(f"M{motor_speed}")

                # Sync MuJoCo model with real sensor data.
                serial_runner._sync_viz()

                # Render overlays and sync viewer.
                viewer.user_scn.ngeom = 0
                _add_action_arrow(viewer, model, data, action_val)
                viewer.sync()

                # Real-time pacing at the serial dt. Sleep only for the
                # remainder of the period so serial I/O and rendering time do
                # not lower the loop rate.
                next_tick += dt
                remaining = next_tick - time.perf_counter()
                if remaining > 0:
                    time.sleep(remaining)
                else:
                    next_tick = time.perf_counter()
    finally:
        # Safety: stop the real motor before shutting down. This is
        # best-effort because the serial link itself may be the thing that
        # failed.
        try:
            serial_runner._send("M0")
        except Exception:
            logger.exception("motor_stop_failed")
        serial_runner.close()
|
|
|
|
|
|
# Script entry point; the Hydra-decorated main() parses command-line overrides.
if __name__ == "__main__":
    main()
|