Files
RL-Sim-Framework/src/runners/mujoco.py
2026-03-09 22:17:28 +01:00

223 lines
9.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import dataclasses
import tempfile
import xml.etree.ElementTree as ET
from pathlib import Path

import mujoco
import numpy as np
import torch

from src.core.env import BaseEnv
from src.core.robot import RobotConfig
from src.core.runner import BaseRunner, BaseRunnerConfig
@dataclasses.dataclass
class MuJoCoRunnerConfig(BaseRunnerConfig):
    """Settings for the MuJoCo runner.

    Attributes:
        num_envs: Number of independent MjData instances that are simulated.
        device: Torch device string; returned state tensors are moved there.
        dt: Value written to ``model.opt.timestep`` (one ``mj_step``).
        substeps: Number of ``mj_step`` calls per action, so one action is
            held for ``dt * substeps`` of simulated time.
        action_ema_alpha: EMA smoothing on ctrl (0=frozen, 1=instant).

    Raises:
        ValueError: On construction, if any field is outside its valid range.
    """

    num_envs: int = 16
    device: str = "cpu"
    dt: float = 0.02
    substeps: int = 2
    action_ema_alpha: float = 0.2  # EMA smoothing on ctrl (0=frozen, 1=instant)

    def __post_init__(self) -> None:
        """Validate fields after the dataclass ``__init__`` has run."""
        # Chain to the base config's hook, if it defines one, so base-class
        # validation/normalization still happens.
        parent_hook = getattr(super(), "__post_init__", None)
        if parent_hook is not None:
            parent_hook()
        if self.num_envs < 1:
            raise ValueError(f"num_envs must be >= 1, got {self.num_envs}")
        if self.dt <= 0:
            raise ValueError(f"dt must be > 0, got {self.dt}")
        if self.substeps < 1:
            # substeps == 0 would silently never advance the physics.
            raise ValueError(f"substeps must be >= 1, got {self.substeps}")
        if not 0.0 <= self.action_ema_alpha <= 1.0:
            # Outside [0, 1] the EMA recursion s <- a*x + (1-a)*s has
            # |1 - a| > 1 or a negative weight: it oscillates or diverges.
            raise ValueError(
                f"action_ema_alpha must be in [0, 1], got {self.action_ema_alpha}"
            )
class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
    """Simulate ``num_envs`` independent copies of one robot in MuJoCo.

    All environments share a single ``MjModel`` and each owns its own
    ``MjData``. Environments are stepped one after another in a Python loop,
    and their joint state is returned as torch tensors on ``config.device``.
    """

    def __init__(self, env: BaseEnv, config: MuJoCoRunnerConfig):
        super().__init__(env, config)

    @property
    def num_envs(self) -> int:
        """Number of environments, taken from the runner config."""
        return self.config.num_envs

    @property
    def device(self) -> torch.device:
        """Torch device that returned state tensors are moved to."""
        return torch.device(self.config.device)

    # ── Model loading ───────────────────────────────────────────────

    @staticmethod
    def _load_model(robot: RobotConfig) -> mujoco.MjModel:
        """Load a URDF (or MJCF) and apply robot.yaml settings.

        Two-step approach required because MuJoCo's URDF parser ignores
        <actuator> in the <mujoco> extension block:

        1. Load the URDF → MuJoCo converts it to internal MJCF.
        2. Export the MJCF XML, inject actuators + joint overrides, reload.

        This keeps the URDF clean (re-exportable from CAD) — all hardware
        tuning lives in robot.yaml.

        Intermediate XML is written to uniquely named temporary files next
        to the model, so relative mesh paths still resolve. The files are
        always deleted afterwards, and concurrent loads of the same robot
        cannot clobber each other's files.

        Args:
            robot: Robot configuration. Reads ``urdf_path``, ``actuators``
                and ``joints``.

        Returns:
            The compiled model. When robot.yaml configures neither actuators
            nor joints, the model is returned exactly as loaded, without the
            damping/collision/joint-limit post-processing.
        """
        abs_path = robot.urdf_path.resolve()
        model_dir = abs_path.parent

        # Step 1: load the source file.
        if abs_path.suffix.lower() == ".urdf":
            model_raw = MuJoCoRunner._load_urdf(abs_path)
        else:
            model_raw = mujoco.MjModel.from_xml_path(str(abs_path))

        # NOTE(review): because of this early return, collision disabling and
        # limit hardening only apply when robot.yaml configures actuators or
        # joints — confirm that is intended.
        if not robot.actuators and not robot.joints:
            return model_raw

        # Step 2: export the internal MJCF, patch it, and reload it.
        root = ET.fromstring(MuJoCoRunner._export_mjcf(model_raw, model_dir))
        if robot.actuators:
            MuJoCoRunner._inject_actuators(root, robot)
        MuJoCoRunner._apply_joint_damping(root, robot)
        MuJoCoRunner._disable_collisions(root)
        MuJoCoRunner._harden_joint_limits(root)
        # Reload from a file path (not a string) so mesh paths resolve
        # relative to the model directory.
        return MuJoCoRunner._load_xml_text(
            ET.tostring(root, encoding="unicode"), model_dir, ".xml",
        )

    @staticmethod
    def _reserve_tmp_path(directory: Path, suffix: str) -> Path:
        """Create an empty, uniquely named temp file in *directory* and return its path."""
        # delete=False: the caller owns the file and unlinks it. The handle is
        # closed on exit so MuJoCo can reopen the path (required on Windows).
        with tempfile.NamedTemporaryFile(
            prefix="_tmp_mujoco_", suffix=suffix, dir=directory, delete=False,
        ) as fh:
            return Path(fh.name)

    @staticmethod
    def _load_xml_text(xml_text: str, directory: Path, suffix: str) -> mujoco.MjModel:
        """Compile *xml_text* by writing it to a temp file in *directory* and loading that file."""
        tmp_path = MuJoCoRunner._reserve_tmp_path(directory, suffix)
        try:
            # UTF-8 is the XML default when no declaration is present, and
            # ET.tostring(..., encoding="unicode") emits no declaration.
            tmp_path.write_text(xml_text, encoding="utf-8")
            return mujoco.MjModel.from_xml_path(str(tmp_path))
        finally:
            tmp_path.unlink(missing_ok=True)

    @staticmethod
    def _export_mjcf(model: mujoco.MjModel, directory: Path) -> str:
        """Return MJCF text produced by ``mj_saveLastXML`` for *model*.

        ``mj_saveLastXML`` writes to a file path, so the output goes through
        a temp file in *directory* and is read back.
        """
        tmp_path = MuJoCoRunner._reserve_tmp_path(directory, ".xml")
        try:
            mujoco.mj_saveLastXML(str(tmp_path), model)
            return tmp_path.read_text(encoding="utf-8")
        finally:
            tmp_path.unlink(missing_ok=True)

    @staticmethod
    def _load_urdf(urdf_path: Path) -> mujoco.MjModel:
        """Load a URDF, adding a ``<mujoco><compiler meshdir=.../>`` hint when needed.

        MuJoCo's URDF parser strips directory prefixes from mesh filenames,
        so the hint is injected into a temporary copy. The original URDF
        stays clean and simulator-agnostic.
        """
        tree = ET.parse(urdf_path)
        root = tree.getroot()

        # Detect the mesh subdirectory from the first mesh filename that has one.
        meshdir = None
        for mesh_el in root.iter("mesh"):
            parent = str(Path(mesh_el.get("filename", "")).parent)
            if parent and parent != ".":
                meshdir = parent
                break

        if meshdir is None:
            # Nothing to patch; the original file loads as-is.
            return mujoco.MjModel.from_xml_path(str(urdf_path))

        mj_ext = ET.SubElement(root, "mujoco")
        ET.SubElement(mj_ext, "compiler", attrib={
            "meshdir": meshdir,
            "balanceinertia": "true",
        })
        return MuJoCoRunner._load_xml_text(
            ET.tostring(root, encoding="unicode"), urdf_path.parent, ".urdf",
        )

    @staticmethod
    def _inject_actuators(root: ET.Element, robot: RobotConfig) -> None:
        """Append an ``<actuator>`` section built from ``robot.actuators``."""
        act_elem = ET.SubElement(root, "actuator")
        for act in robot.actuators:
            attribs = {
                "name": f"{act.joint}_{act.type}",
                "joint": act.joint,
                "gear": str(act.gear),
                # Explicit so the range is enforced regardless of the
                # compiler's autolimits setting.
                "ctrllimited": "true",
                "ctrlrange": f"{act.ctrl_range[0]} {act.ctrl_range[1]}",
            }
            if act.type == "position":
                attribs["kp"] = str(act.kp)
                if act.kv > 0:
                    attribs["kv"] = str(act.kv)
                tag = "position"
            elif act.type == "velocity":
                # MJCF <velocity> has no "kp" attribute (MuJoCo rejects it);
                # its velocity-feedback gain is "kv". robot.yaml's kp is used
                # as that gain.
                attribs["kv"] = str(act.kp)
                tag = "velocity"
            else:  # motor (default)
                tag = "motor"
            ET.SubElement(act_elem, tag, attrib=attribs)

    @staticmethod
    def _body_joints(root: ET.Element):
        """Yield every ``<joint>`` element that is a direct child of a ``<body>``.

        Restricting to body children skips other elements that reuse the
        ``<joint>`` tag, e.g. entries inside ``<default>``.
        """
        for body in root.iter("body"):
            yield from body.findall("joint")

    @staticmethod
    def _apply_joint_damping(root: ET.Element, robot: RobotConfig) -> None:
        """Set joint damping from actuator entries; explicit joint entries take precedence."""
        joint_damping = {
            a.joint: a.damping for a in robot.actuators if a.damping is not None
        }
        joint_damping.update(
            (name, jcfg.damping)
            for name, jcfg in robot.joints.items()
            if jcfg.damping is not None
        )
        for jnt in MuJoCoRunner._body_joints(root):
            name = jnt.get("name")
            if name in joint_damping:
                jnt.set("damping", str(joint_damping[name]))

    @staticmethod
    def _disable_collisions(root: ET.Element) -> None:
        """Turn off contact generation for every geom in the model.

        Motivation: URDF mesh convex hulls often overlap at joints
        (especially grandparent↔grandchild bodies that MuJoCo does NOT
        auto-exclude), causing phantom contact forces. With
        contype=conaffinity=0 a geom collides with nothing at all, so this
        also removes contacts with any other geoms in the model, not only
        self-collisions.
        """
        for geom in root.iter("geom"):
            geom.set("contype", "0")
            geom.set("conaffinity", "0")

    @staticmethod
    def _harden_joint_limits(root: ET.Element) -> None:
        """Stiffen the limit constraint of every limited body joint.

        MuJoCo's default soft limits are too weak and allow overshoot.
        Negative solref = hard constraint (direct stiffness/damping instead
        of impedance match).
        """
        for jnt in MuJoCoRunner._body_joints(root):
            if jnt.get("limited") == "true" or jnt.get("range"):
                jnt.set("solreflimit", "-1000 -100")
                jnt.set("solimplimit", "0.95 0.99 0.001")

    # ── Simulation ──────────────────────────────────────────────────

    def _sim_initialize(self, config: MuJoCoRunnerConfig) -> None:
        """Build the shared model, one MjData per env, and the EMA ctrl state.

        ``model.opt.timestep`` is set to ``config.dt``. ``_sim_step`` calls
        ``mj_step`` ``config.substeps`` times per action.
        NOTE(review): one control period is therefore ``dt * substeps`` —
        confirm that matches how ``dt`` is interpreted elsewhere.
        """
        self._model = self._load_model(self.env.robot)
        self._model.opt.timestep = config.dt
        self._data: list[mujoco.MjData] = [mujoco.MjData(self._model) for _ in range(config.num_envs)]
        self._nq = self._model.nq
        self._nv = self._model.nv
        # Per-env smoothed ctrl state for EMA action filtering.
        # Models real motor inertia: ctrl can't reverse instantly.
        nu = self._model.nu
        self._smooth_ctrl = [np.zeros(nu, dtype=np.float64) for _ in range(config.num_envs)]

    def _sim_step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply one batch of actions and advance every environment.

        Args:
            actions: Raw controls, where ``actions[i]`` is the control for
                env ``i`` (presumably shape ``(num_envs, nu)`` — TODO confirm).

        Returns:
            ``(qpos, qvel)`` as float32 tensors of shape ``(num_envs, nq)``
            and ``(num_envs, nv)``, on ``self.device``.
        """
        # detach(): .numpy() raises on tensors that require grad.
        actions_np: np.ndarray = actions.detach().cpu().numpy()
        alpha = self.config.action_ema_alpha
        substeps = self.config.substeps
        qpos_batch = np.zeros((self.num_envs, self._nq), dtype=np.float32)
        qvel_batch = np.zeros((self.num_envs, self._nv), dtype=np.float32)
        for i, data in enumerate(self._data):
            # EMA filter: smooth_ctrl ← α·raw + (1-α)·smooth_ctrl
            self._smooth_ctrl[i] = alpha * actions_np[i] + (1 - alpha) * self._smooth_ctrl[i]
            data.ctrl[:] = self._smooth_ctrl[i]
            for _ in range(substeps):
                mujoco.mj_step(self._model, data)
            qpos_batch[i] = data.qpos
            qvel_batch[i] = data.qvel
        return (
            torch.from_numpy(qpos_batch).to(self.device),
            torch.from_numpy(qvel_batch).to(self.device),
        )

    def _sim_reset(self, env_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Reset the selected environments and return their new state.

        Each selected env is processed in four steps:

        1. Reset with ``mj_resetData``.
        2. Set to the env's default pose, when ``get_default_qpos`` returns one.
        3. Perturb every qpos and qvel component with uniform noise in
           [-0.05, 0.05).
        4. Zero its smoothed ctrl state.

        Args:
            env_ids: Indices of the environments to reset (presumably a 1-D
                integer tensor — TODO confirm).

        Returns:
            ``(qpos, qvel)`` as float32 tensors with one row per entry of
            ``env_ids``, on ``self.device``.
        """
        ids = env_ids.cpu().numpy()
        n = len(ids)
        qpos_batch = np.zeros((n, self._nq), dtype=np.float32)
        qvel_batch = np.zeros((n, self._nv), dtype=np.float32)
        default_qpos = self.env.get_default_qpos(self._nq)
        for i, env_id in enumerate(ids):
            data = self._data[env_id]
            mujoco.mj_resetData(self._model, data)
            # Set initial pose (env-specific, e.g. pendulum hanging down)
            if default_qpos is not None:
                data.qpos[:] = default_qpos
            # Add small random perturbation for exploration.
            # NOTE(review): uses the global NumPy RNG, and would also perturb
            # quaternion components (without renormalizing) if the model has
            # free/ball joints.
            data.qpos[:] += np.random.uniform(-0.05, 0.05, size=self._nq)
            data.qvel[:] += np.random.uniform(-0.05, 0.05, size=self._nv)
            # Reset smoothed ctrl so motor starts from rest
            self._smooth_ctrl[env_id][:] = 0.0
            qpos_batch[i] = data.qpos
            qvel_batch[i] = data.qvel
        return (
            torch.from_numpy(qpos_batch).to(self.device),
            torch.from_numpy(qvel_batch).to(self.device),
        )

    def render(self, env_idx: int = 0) -> np.ndarray:
        """Offscreen-render a single environment at 640x480.

        The renderer is created lazily on first call and reused. ``mj_forward``
        is run first so derived quantities match the current qpos/qvel.
        """
        if not hasattr(self, "_offscreen_renderer"):
            self._offscreen_renderer = mujoco.Renderer(
                self._model, width=640, height=480,
            )
        mujoco.mj_forward(self._model, self._data[env_idx])
        self._offscreen_renderer.update_scene(self._data[env_idx])
        return self._offscreen_renderer.render()