♻️ crazy refactor

This commit is contained in:
2026-03-11 22:52:01 +01:00
parent 35223b3560
commit 4115447022
34 changed files with 4255 additions and 102 deletions

135
viz.py
View File

@@ -1,9 +1,13 @@
"""Interactive visualization — control any env with keyboard in MuJoCo viewer.
Usage:
Usage (simulation):
mjpython viz.py env=rotary_cartpole
mjpython viz.py env=cartpole +com=true
Usage (real hardware — digital twin):
mjpython viz.py env=rotary_cartpole runner=serial
mjpython viz.py env=rotary_cartpole runner=serial runner.port=/dev/ttyUSB0
Controls:
Left/Right arrows — apply torque to first actuator
R — reset environment
@@ -15,6 +19,7 @@ import time
import hydra
import mujoco
import mujoco.viewer
import numpy as np
import structlog
import torch
from hydra.core.hydra_config import HydraConfig
@@ -45,10 +50,64 @@ def _key_callback(keycode: int) -> None:
_reset_flag[0] = True
def _add_action_arrow(viewer, model, data, action_val: float) -> None:
"""Draw an arrow on the motor joint showing applied torque direction."""
if abs(action_val) < 0.01 or model.nu == 0:
return
# Get the body that the first actuator's joint belongs to
jnt_id = model.actuator_trnid[0, 0]
body_id = model.jnt_bodyid[jnt_id]
# Arrow origin: body position
pos = data.xpos[body_id].copy()
pos[2] += 0.02 # lift slightly above the body
# Arrow direction: along joint axis in world frame, scaled by action
axis = data.xmat[body_id].reshape(3, 3) @ model.jnt_axis[jnt_id]
arrow_len = 0.08 * action_val
direction = axis * np.sign(arrow_len)
# Build rotation matrix: arrow rendered along local z-axis
z = direction / (np.linalg.norm(direction) + 1e-8)
up = np.array([0, 0, 1]) if abs(z[2]) < 0.99 else np.array([0, 1, 0])
x = np.cross(up, z)
x /= np.linalg.norm(x) + 1e-8
y = np.cross(z, x)
mat = np.column_stack([x, y, z]).flatten()
# Color: green = positive, red = negative
rgba = np.array(
[0.2, 0.8, 0.2, 0.8] if action_val > 0 else [0.8, 0.2, 0.2, 0.8],
dtype=np.float32,
)
geom = viewer.user_scn.geoms[viewer.user_scn.ngeom]
mujoco.mjv_initGeom(
geom,
type=mujoco.mjtGeom.mjGEOM_ARROW,
size=np.array([0.008, 0.008, abs(arrow_len)]),
pos=pos,
mat=mat,
rgba=rgba,
)
viewer.user_scn.ngeom += 1
@hydra.main(version_base=None, config_path="configs", config_name="config")
def main(cfg: DictConfig) -> None:
choices = HydraConfig.get().runtime.choices
env_name = choices.get("env", "cartpole")
runner_name = choices.get("runner", "mujoco")
if runner_name == "serial":
_main_serial(cfg, env_name)
else:
_main_sim(cfg, env_name)
def _main_sim(cfg: DictConfig, env_name: str) -> None:
"""Simulation visualization — step MuJoCo physics with keyboard control."""
# Build env + runner (single env for viz)
env = build_env(env_name, cfg)
@@ -94,8 +153,10 @@ def main(cfg: DictConfig) -> None:
action = torch.tensor([[action_val]])
obs, reward, terminated, truncated, info = runner.step(action)
# Sync viewer
# Sync viewer with action arrow overlay
mujoco.mj_forward(model, data)
viewer.user_scn.ngeom = 0 # clear previous frame's overlays
_add_action_arrow(viewer, model, data, action_val)
viewer.sync()
# Print state
@@ -112,5 +173,75 @@ def main(cfg: DictConfig) -> None:
runner.close()
def _main_serial(cfg: DictConfig, env_name: str) -> None:
"""Digital-twin visualization — mirror real hardware in MuJoCo viewer.
The MuJoCo model is loaded for rendering only. Joint positions are
read from the ESP32 over serial and applied to the model each frame.
Keyboard arrows send motor commands to the real robot.
"""
from src.runners.serial import SerialRunner, SerialRunnerConfig
env = build_env(env_name, cfg)
runner_dict = OmegaConf.to_container(cfg.runner, resolve=True)
serial_runner = SerialRunner(
env=env, config=SerialRunnerConfig(**runner_dict)
)
# Load MuJoCo model for visualisation (same URDF the sim uses).
serial_runner._ensure_viz_model()
model = serial_runner._viz_model
data = serial_runner._viz_data
with mujoco.viewer.launch_passive(
model, data, key_callback=_key_callback
) as viewer:
# Show CoM / inertia if requested.
show_com = cfg.get("com", False)
if show_com:
viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_COM] = True
viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_INERTIA] = True
logger.info(
"viewer_started",
env=env_name,
mode="serial (digital twin)",
port=serial_runner.config.port,
controls="Left/Right arrows = motor command, R = reset",
)
while viewer.is_running():
# Read action from keyboard callback.
if time.time() - _action_time[0] < _ACTION_HOLD_S:
action_val = _action_val[0]
else:
action_val = 0.0
# Reset on R press.
if _reset_flag[0]:
_reset_flag[0] = False
serial_runner._send("M0")
serial_runner._drive_to_center()
serial_runner._wait_for_pendulum_still()
logger.info("reset (drive-to-center + settle)")
# Send motor command to real hardware.
motor_speed = int(np.clip(action_val, -1.0, 1.0) * 255)
serial_runner._send(f"M{motor_speed}")
# Sync MuJoCo model with real sensor data.
serial_runner._sync_viz()
# Render overlays and sync viewer.
viewer.user_scn.ngeom = 0
_add_action_arrow(viewer, model, data, action_val)
viewer.sync()
# Real-time pacing (~50 Hz, matches serial dt).
time.sleep(serial_runner.config.dt)
serial_runner.close()
if __name__ == "__main__":
main()