223 lines
9.2 KiB
Python
223 lines
9.2 KiB
Python
import dataclasses
|
||
import xml.etree.ElementTree as ET
|
||
from pathlib import Path
|
||
|
||
import mujoco
|
||
import numpy as np
|
||
import torch
|
||
|
||
from src.core.env import BaseEnv
|
||
from src.core.robot import RobotConfig
|
||
from src.core.runner import BaseRunner, BaseRunnerConfig
|
||
|
||
@dataclasses.dataclass
class MuJoCoRunnerConfig(BaseRunnerConfig):
    """Settings consumed by :class:`MuJoCoRunner`.

    Extends ``BaseRunnerConfig`` with the MuJoCo-specific fields below.
    """

    # Number of simulated environments; the runner allocates one MjData each.
    num_envs: int = 16

    # Torch device string; state tensors returned by the runner are moved here.
    device: str = "cpu"

    # Written to ``model.opt.timestep`` when the simulation is initialised.
    dt: float = 0.02

    # Number of ``mj_step`` calls executed per runner step.
    substeps: int = 2

    # EMA smoothing on ctrl (0=frozen, 1=instant)
    action_ema_alpha: float = 0.2
|
||
|
||
class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
    """Runner that simulates ``num_envs`` copies of one robot with MuJoCo.

    A single ``MjModel`` is shared by all environments; each environment owns
    its own ``MjData`` and is stepped sequentially in a Python loop.

    Timing: ``_sim_initialize`` writes ``config.dt`` to ``model.opt.timestep``
    and ``_sim_step`` calls ``mj_step`` ``config.substeps`` times, so one
    runner step advances simulated time by ``dt * substeps``.
    """

    def __init__(self, env: BaseEnv, config: MuJoCoRunnerConfig):
        super().__init__(env, config)

    @property
    def num_envs(self) -> int:
        """Number of simulated environments, taken from the config."""
        return self.config.num_envs

    @property
    def device(self) -> torch.device:
        """Torch device on which returned state tensors are placed."""
        return torch.device(self.config.device)

    @staticmethod
    def _reserve_tmp_path(directory: Path, prefix: str, suffix: str) -> Path:
        """Create an empty, uniquely named file in *directory* and return its path.

        Unique names keep concurrent loaders of the same robot from
        overwriting or deleting each other's scratch files. The caller is
        responsible for removing the file.
        """
        import os
        import tempfile

        fd, raw_path = tempfile.mkstemp(
            prefix=prefix, suffix=suffix, dir=str(directory),
        )
        # Only the name is needed; the file is reopened by path later.
        os.close(fd)
        return Path(raw_path)

    @staticmethod
    def _load_model(robot: RobotConfig) -> mujoco.MjModel:
        """Load a URDF (or MJCF) and apply robot.yaml settings.

        Two-step approach required because MuJoCo's URDF parser ignores
        <actuator> in the <mujoco> extension block:
          1. Load the URDF → MuJoCo converts it to internal MJCF
          2. Export the MJCF XML, inject actuators + joint overrides, reload

        This keeps the URDF clean (re-exportable from CAD) — all hardware
        tuning lives in robot.yaml.

        Args:
            robot: Robot description. Reads ``urdf_path``, ``actuators``
                (each with ``joint``, ``type``, ``gear``, ``ctrl_range``,
                ``kp``, ``kv``, ``damping``) and ``joints`` (a mapping of
                joint name to an object with an optional ``damping``).

        Returns:
            The compiled model. When ``robot`` has neither actuators nor
            joint overrides the model from step 1 is returned unmodified.
        """
        abs_path = robot.urdf_path.resolve()
        # Scratch files live next to the model so relative mesh paths resolve.
        model_dir = abs_path.parent
        is_urdf = abs_path.suffix.lower() == ".urdf"

        # MuJoCo's URDF parser strips directory prefixes from mesh filenames,
        # so we inject a <mujoco><compiler meshdir="..."/> block into a
        # temporary copy. The original URDF stays clean and simulator-agnostic.
        if is_urdf:
            tree = ET.parse(abs_path)
            root = tree.getroot()

            # Detect the mesh subdirectory from the first mesh filename
            # that actually carries a directory component.
            meshdir = None
            for mesh_el in root.iter("mesh"):
                parent = str(Path(mesh_el.get("filename", "")).parent)
                if parent and parent != ".":
                    meshdir = parent
                    break

            if meshdir:
                mj_ext = ET.SubElement(root, "mujoco")
                ET.SubElement(mj_ext, "compiler", attrib={
                    "meshdir": meshdir,
                    "balanceinertia": "true",
                })

            tmp_urdf = MuJoCoRunner._reserve_tmp_path(
                model_dir, "_tmp_mujoco_load_", ".urdf",
            )
            try:
                # Write inside the try so a failed write is cleaned up too.
                tree.write(str(tmp_urdf), xml_declaration=True, encoding="unicode")
                model_raw = mujoco.MjModel.from_xml_path(str(tmp_urdf))
            finally:
                tmp_urdf.unlink(missing_ok=True)
        else:
            model_raw = mujoco.MjModel.from_xml_path(str(abs_path))

        if not robot.actuators and not robot.joints:
            return model_raw

        # Step 2: Export internal MJCF, inject actuators + joint overrides, reload
        tmp_mjcf = MuJoCoRunner._reserve_tmp_path(
            model_dir, "_tmp_actuator_inject_", ".xml",
        )
        try:
            mujoco.mj_saveLastXML(str(tmp_mjcf), model_raw)
            root = ET.fromstring(tmp_mjcf.read_text())

            # ── Inject actuators ────────────────────────────────────
            if robot.actuators:
                act_elem = ET.SubElement(root, "actuator")
                for act in robot.actuators:
                    attribs = {
                        "name": f"{act.joint}_{act.type}",
                        "joint": act.joint,
                        "gear": str(act.gear),
                        "ctrlrange": f"{act.ctrl_range[0]} {act.ctrl_range[1]}",
                    }
                    if act.type == "position":
                        attribs["kp"] = str(act.kp)
                        if act.kv > 0:
                            attribs["kv"] = str(act.kv)
                        ET.SubElement(act_elem, "position", attrib=attribs)
                    elif act.type == "velocity":
                        # MJCF <velocity> takes its feedback gain as "kv" and
                        # has no "kp" attribute; the parser rejects unknown
                        # attributes, so emitting "kp" here made the reload fail.
                        # NOTE(review): the gain value is still read from
                        # act.kp, as before — confirm whether robot.yaml
                        # should supply it via kv instead.
                        attribs["kv"] = str(act.kp)
                        ET.SubElement(act_elem, "velocity", attrib=attribs)
                    else:  # motor (default)
                        ET.SubElement(act_elem, "motor", attrib=attribs)

            # ── Apply joint overrides from robot.yaml ───────────────
            # Merge actuator damping + explicit joint overrides; an explicit
            # per-joint damping wins over the actuator's value.
            joint_damping = {a.joint: a.damping for a in robot.actuators}
            for name, jcfg in robot.joints.items():
                if jcfg.damping is not None:
                    joint_damping[name] = jcfg.damping

            # Only <joint> elements directly under a <body> are touched, so
            # same-named tags in other sections are left alone.
            for body in root.iter("body"):
                for jnt in body.findall("joint"):
                    name = jnt.get("name")
                    if name in joint_damping:
                        jnt.set("damping", str(joint_damping[name]))

                    # Harden joint limits: MuJoCo's default soft limits are too
                    # weak and allow overshoot. Negative solref = hard constraint
                    # (direct stiffness/damping instead of impedance match).
                    if jnt.get("limited") == "true" or jnt.get("range"):
                        jnt.set("solreflimit", "-1000 -100")
                        jnt.set("solimplimit", "0.95 0.99 0.001")

            # Disable collisions on all geoms.
            # URDF mesh convex hulls often overlap at joints (especially
            # grandparent↔grandchild bodies that MuJoCo does NOT auto-exclude),
            # causing phantom contact forces. Zeroing both contype and
            # conaffinity on every geom turns off ALL contacts in the model,
            # not only self-collision.
            for geom in root.iter("geom"):
                geom.set("contype", "0")
                geom.set("conaffinity", "0")

            # Write modified MJCF and reload from file path
            # (from_xml_path resolves mesh paths relative to the file location)
            tmp_mjcf.write_text(ET.tostring(root, encoding="unicode"))
            return mujoco.MjModel.from_xml_path(str(tmp_mjcf))
        finally:
            tmp_mjcf.unlink(missing_ok=True)

    def _sim_initialize(self, config: MuJoCoRunnerConfig) -> None:
        """Build the shared model and allocate per-environment state.

        Args:
            config: Runner configuration; reads ``dt`` and ``num_envs``.
        """
        self._model = self._load_model(self.env.robot)
        self._model.opt.timestep = config.dt
        self._data: list[mujoco.MjData] = [
            mujoco.MjData(self._model) for _ in range(config.num_envs)
        ]

        self._nq = self._model.nq
        self._nv = self._model.nv

        # Per-env smoothed ctrl state for EMA action filtering.
        # Models real motor inertia: ctrl can't reverse instantly.
        nu = self._model.nu
        self._smooth_ctrl = [
            np.zeros(nu, dtype=np.float64) for _ in range(config.num_envs)
        ]

    def _sim_step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply EMA-filtered actions and advance every environment.

        Args:
            actions: Raw control targets; row ``i`` is written to environment
                ``i``'s ``ctrl`` after filtering, so the expected shape is
                presumably ``(num_envs, nu)`` — confirm against the caller.

        Returns:
            ``(qpos, qvel)`` float32 tensors of shape ``(num_envs, nq)`` and
            ``(num_envs, nv)`` on ``self.device``.
        """
        # detach() so tensors that still carry autograd history can be converted.
        actions_np: np.ndarray = actions.detach().cpu().numpy()
        alpha = self.config.action_ema_alpha
        substeps = self.config.substeps

        qpos_batch = np.zeros((self.num_envs, self._nq), dtype=np.float32)
        qvel_batch = np.zeros((self.num_envs, self._nv), dtype=np.float32)

        for i, data in enumerate(self._data):
            # EMA filter: smooth_ctrl ← α·raw + (1-α)·smooth_ctrl
            self._smooth_ctrl[i] = (
                alpha * actions_np[i] + (1 - alpha) * self._smooth_ctrl[i]
            )
            data.ctrl[:] = self._smooth_ctrl[i]
            for _ in range(substeps):
                mujoco.mj_step(self._model, data)

            qpos_batch[i] = data.qpos
            qvel_batch[i] = data.qvel

        return (
            torch.from_numpy(qpos_batch).to(self.device),
            torch.from_numpy(qvel_batch).to(self.device),
        )

    def _sim_reset(self, env_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Reset the selected environments to a perturbed initial state.

        Args:
            env_ids: Indices of the environments to reset.

        Returns:
            ``(qpos, qvel)`` float32 tensors with one row per entry of
            ``env_ids`` (in the same order), on ``self.device``.
        """
        ids = env_ids.cpu().numpy()
        n = len(ids)

        qpos_batch = np.zeros((n, self._nq), dtype=np.float32)
        qvel_batch = np.zeros((n, self._nv), dtype=np.float32)

        default_qpos = self.env.get_default_qpos(self._nq)

        for i, env_id in enumerate(ids):
            data = self._data[env_id]
            mujoco.mj_resetData(self._model, data)

            # Set initial pose (env-specific, e.g. pendulum hanging down)
            if default_qpos is not None:
                data.qpos[:] = default_qpos

            # Add small random perturbation for exploration
            # (drawn from NumPy's global RNG).
            data.qpos[:] += np.random.uniform(-0.05, 0.05, size=self._nq)
            data.qvel[:] += np.random.uniform(-0.05, 0.05, size=self._nv)

            # Reset smoothed ctrl so motor starts from rest
            self._smooth_ctrl[env_id][:] = 0.0

            qpos_batch[i] = data.qpos
            qvel_batch[i] = data.qvel

        return (
            torch.from_numpy(qpos_batch).to(self.device),
            torch.from_numpy(qvel_batch).to(self.device),
        )

    def render(self, env_idx: int = 0) -> np.ndarray:
        """Offscreen render of a single environment.

        The 640x480 renderer is created on first use and then reused.

        Args:
            env_idx: Index of the environment to draw.

        Returns:
            The image produced by ``mujoco.Renderer.render()``.
        """
        if not hasattr(self, "_offscreen_renderer"):
            self._offscreen_renderer = mujoco.Renderer(
                self._model, width=640, height=480,
            )
        data = self._data[env_idx]
        # Recompute derived quantities so the scene reflects the current qpos.
        mujoco.mj_forward(self._model, data)
        self._offscreen_renderer.update_scene(data)
        return self._offscreen_renderer.render()