better robot joint loading

This commit is contained in:
2026-03-09 22:17:28 +01:00
parent 9be07d9186
commit 70cd2cdd7d
13 changed files with 215 additions and 128 deletions

View File

@@ -0,0 +1,10 @@
# Classic cartpole — robot hardware config.
urdf: cartpole.urdf
actuators:
- joint: cart_joint
type: motor
gear: 10.0
ctrl_range: [-1.0, 1.0]
damping: 0.05

View File

@@ -0,0 +1,15 @@
# Rotary cartpole (Furuta pendulum) — robot hardware config.
# Lives next to the URDF so all robot-specific settings are in one place.
urdf: rotary_cartpole.urdf
actuators:
- joint: motor_joint
type: motor # direct torque control
gear: 0.5 # torque multiplier
ctrl_range: [-1.0, 1.0]
damping: 0.1 # motor friction / back-EMF
joints:
pendulum_joint:
damping: 0.0001 # bearing friction

View File

@@ -1,12 +1,7 @@
max_steps: 500
robot_path: assets/cartpole
angle_threshold: 0.418
cart_limit: 2.4
reward_alive: 1.0
reward_pole_upright_scale: 1.0
reward_action_penalty_scale: 0.01
model_path: assets/cartpole/cartpole.urdf
actuators:
- joint: cart_joint
gear: 10.0
ctrl_range: [-1.0, 1.0]
damping: 0.05

View File

@@ -1,8 +1,3 @@
max_steps: 1000
model_path: assets/rotary_cartpole/rotary_cartpole.urdf
reward_upright_scale: 1.0
actuators:
- joint: motor_joint
gear: 0.5
ctrl_range: [-1.0, 1.0]
damping: 0.1
robot_path: assets/rotary_cartpole
reward_upright_scale: 1.0

View File

@@ -10,4 +10,5 @@ clearml
imageio
imageio-ffmpeg
structlog
pyyaml
pytest

View File

@@ -5,30 +5,20 @@ from gymnasium import spaces
import torch
import pathlib
from src.core.robot import RobotConfig, load_robot_config
T = TypeVar("T")
@dataclasses.dataclass
class ActuatorConfig:
"""Actuator definition — maps a joint to a motor with gear ratio and control limits.
Kept in the env config (not runner config) because actuators define what the robot
can do, which determines action space — a task-level concept.
This mirrors Isaac Lab's pattern of separating actuator config from the robot file."""
joint: str = ""
gear: float = 1.0
ctrl_range: tuple[float, float] = (-1.0, 1.0)
damping: float = 0.05 # joint damping — limits max speed: vel_max ≈ torque / damping
@dataclasses.dataclass
class BaseEnvConfig:
max_steps: int = 1000
model_path: pathlib.Path | None = None
actuators: list[ActuatorConfig] = dataclasses.field(default_factory=list)
robot_path: str = "" # directory containing robot.yaml + URDF
class BaseEnv(abc.ABC, Generic[T]):
def __init__(self, config: BaseEnvConfig):
self.config = config
self.robot: RobotConfig = load_robot_config(config.robot_path)
@property
@abc.abstractmethod

23
src/core/registry.py Normal file
View File

@@ -0,0 +1,23 @@
"""Shared env registry and builder used by train.py and viz.py."""
from omegaconf import DictConfig, OmegaConf
from src.core.env import BaseEnv, BaseEnvConfig
from src.envs.cartpole import CartPoleEnv, CartPoleConfig
from src.envs.rotary_cartpole import RotaryCartPoleEnv, RotaryCartPoleConfig
# Maps Hydra config-group name → (EnvClass, ConfigClass).
# build_env() looks the requested env name up in this table and rejects
# names that are not registered here, so every new environment must be
# added as an entry.
ENV_REGISTRY: dict[str, tuple[type[BaseEnv], type[BaseEnvConfig]]] = {
"cartpole": (CartPoleEnv, CartPoleConfig),
"rotary_cartpole": (RotaryCartPoleEnv, RotaryCartPoleConfig),
}
def build_env(env_name: str, cfg: DictConfig) -> BaseEnv:
    """Create the environment registered under *env_name*.

    The ``env`` section of the Hydra config is resolved into a plain
    container and unpacked as keyword arguments into the registered config
    class; the resulting config object is then handed to the env class.

    Raises:
        ValueError: if *env_name* has no entry in ``ENV_REGISTRY``.
    """
    # Registry values are always 2-tuples, so None reliably means "missing".
    registered = ENV_REGISTRY.get(env_name)
    if registered is None:
        raise ValueError(f"Unknown env '{env_name}'. Registered: {list(ENV_REGISTRY)}")

    env_cls, config_cls = registered
    env_kwargs = OmegaConf.to_container(cfg.env, resolve=True)
    env_config = config_cls(**env_kwargs)
    return env_cls(env_config)

101
src/core/robot.py Normal file
View File

@@ -0,0 +1,101 @@
"""Robot hardware configuration — loaded from robot.yaml next to the URDF.
Separates robot hardware (actuators, joint tuning) from task config
(rewards, episode length) and from the URDF (clean CAD export).
Usage:
robot = load_robot_config(Path("assets/rotary_cartpole"))
# robot.urdf_path → resolved absolute path to the URDF
# robot.actuators → list of ActuatorConfig
# robot.joints → dict of per-joint overrides
"""
import dataclasses
from pathlib import Path
import structlog
import yaml
log = structlog.get_logger()
@dataclasses.dataclass
class ActuatorConfig:
    """Motor/actuator attached to a joint.

    One instance corresponds to one entry of the ``actuators`` list in
    robot.yaml.

    type:
        motor    — direct torque control (ctrl = normalised torque)
        position — PD position servo (ctrl = target angle, needs kp)
        velocity — P velocity servo (ctrl = target velocity, needs kp)

    Raises:
        ValueError: if ``ctrl_range`` does not hold exactly two values.
    """

    joint: str = ""  # name of the joint this actuator drives
    type: str = "motor"  # one of the actuator kinds listed above
    gear: float = 1.0
    ctrl_range: tuple[float, float] = (-1.0, 1.0)  # (lower, upper) control limits
    damping: float = 0.05
    kp: float = 0.0  # proportional gain (position / velocity actuators)
    kv: float = 0.0  # derivative gain (position actuators)

    def __post_init__(self) -> None:
        """Normalise ``ctrl_range`` to a 2-tuple and reject malformed ranges."""
        # YAML sequences arrive as lists; store a tuple so the field always
        # matches its annotation no matter how the instance was built.
        limits = tuple(self.ctrl_range)
        if len(limits) != 2:
            raise ValueError(
                f"ctrl_range must contain exactly two values (lower, upper), got {limits!r}"
            )
        self.ctrl_range = limits
@dataclasses.dataclass
class JointConfig:
"""Per-joint overrides applied on top of the URDF values."""
damping: float | None = None
@dataclasses.dataclass
class RobotConfig:
    """Full hardware description of one robot.

    Bundles the URDF location with the actuator list and the per-joint
    overrides; every field has an empty default.
    """

    # Path to the robot's URDF file (an empty Path when not set).
    urdf_path: Path = dataclasses.field(default_factory=Path)
    # Actuators to attach to the model.
    actuators: list[ActuatorConfig] = dataclasses.field(default_factory=list)
    # Joint name -> overrides for that joint.
    joints: dict[str, JointConfig] = dataclasses.field(default_factory=dict)
def load_robot_config(robot_dir: str | Path) -> RobotConfig:
    """Load robot.yaml from a directory and resolve the URDF path.

    Expected layout:
        robot_dir/
            robot.yaml       ← hardware config
            some_robot.urdf  ← CAD export
            meshes/          ← optional mesh files

    Args:
        robot_dir: Directory containing robot.yaml and the URDF it names.

    Returns:
        A RobotConfig whose ``urdf_path`` lies inside the resolved
        *robot_dir*, with the actuators and joint overrides from the YAML.

    Raises:
        FileNotFoundError: if robot.yaml or the referenced URDF is missing.
        ValueError: if robot.yaml is not a mapping or lacks a 'urdf' entry.
    """
    robot_dir = Path(robot_dir).resolve()
    yaml_path = robot_dir / "robot.yaml"
    if not yaml_path.exists():
        raise FileNotFoundError(f"Robot config not found: {yaml_path}")

    # Read as UTF-8 explicitly: the config comments contain non-ASCII
    # characters and the platform default encoding is not reliable.
    # An empty file parses to None; treat it as an empty mapping so the
    # missing-'urdf' check below reports the problem clearly.
    raw = yaml.safe_load(yaml_path.read_text(encoding="utf-8")) or {}
    if not isinstance(raw, dict):
        raise ValueError(f"robot.yaml must contain a mapping at top level: {yaml_path}")

    # Resolve URDF path relative to robot.yaml directory
    urdf_filename = raw.get("urdf", "")
    if not urdf_filename:
        raise ValueError(f"robot.yaml must specify 'urdf' filename: {yaml_path}")
    urdf_path = robot_dir / urdf_filename
    if not urdf_path.exists():
        raise FileNotFoundError(f"URDF not found: {urdf_path}")

    # Parse actuators. `or []` also covers an `actuators:` key with no value.
    actuators = []
    for entry in raw.get("actuators") or []:
        params = dict(entry)  # copy so the parsed YAML data is left untouched
        if "ctrl_range" in params:
            params["ctrl_range"] = tuple(params["ctrl_range"])
        actuators.append(ActuatorConfig(**params))

    # Parse joint overrides. A joint listed with an empty body yields a
    # JointConfig with all defaults instead of crashing on `**None`.
    joints = {
        name: JointConfig(**(overrides or {}))
        for name, overrides in (raw.get("joints") or {}).items()
    }

    config = RobotConfig(
        urdf_path=urdf_path,
        actuators=actuators,
        joints=joints,
    )
    log.debug(
        "robot_config_loaded",
        robot_dir=str(robot_dir),
        urdf=urdf_filename,
        num_actuators=len(actuators),
        joint_overrides=list(joints),
    )
    return config

View File

@@ -53,9 +53,10 @@ class BaseRunner(abc.ABC, Generic[T]):
def _sim_reset(self, env_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
...
@abc.abstractmethod
def _sim_close(self) -> None:
...
"""Release simulator resources. Override for extra cleanup."""
if hasattr(self, "_offscreen_renderer") and self._offscreen_renderer is not None:
self._offscreen_renderer.close()
def reset(self) -> tuple[torch.Tensor, dict[str, Any]]:
all_ids = torch.arange(self.num_envs, device=self.device)

View File

@@ -29,7 +29,7 @@ import numpy as np
from src.core.env import BaseEnv
from src.core.runner import BaseRunner, BaseRunnerConfig
from src.runners.mujoco import MuJoCoRunner # reuse _load_model_with_actuators
from src.runners.mujoco import MuJoCoRunner # reuse _load_model
log = structlog.get_logger()
@@ -64,14 +64,8 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
# ── Initialization ───────────────────────────────────────────────
def _sim_initialize(self, config: MJXRunnerConfig) -> None:
model_path = self.env.config.model_path
if model_path is None:
raise ValueError("model_path must be specified")
# Step 1: Load CPU model (reuses URDF → MJCF → actuator injection)
self._mj_model = MuJoCoRunner._load_model_with_actuators(
str(model_path), self.env.config.actuators,
)
self._mj_model = MuJoCoRunner._load_model(self.env.robot)
self._mj_model.opt.timestep = config.dt
self._nq = self._mj_model.nq
self._nv = self._mj_model.nv
@@ -207,10 +201,6 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
rv = self._batch_data.qvel[ids_np].astype(jnp.float32)
return torch.from_dlpack(rq), torch.from_dlpack(rv)
def _sim_close(self) -> None:
if hasattr(self, "_offscreen_renderer") and self._offscreen_renderer is not None:
self._offscreen_renderer.close()
# ── Rendering ────────────────────────────────────────────────────
def render(self, env_idx: int = 0) -> np.ndarray:

View File

@@ -6,7 +6,8 @@ import mujoco
import numpy as np
import torch
from src.core.env import BaseEnv, ActuatorConfig
from src.core.env import BaseEnv
from src.core.robot import RobotConfig
from src.core.runner import BaseRunner, BaseRunnerConfig
@dataclasses.dataclass
@@ -30,18 +31,18 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
return torch.device(self.config.device)
@staticmethod
def _load_model_with_actuators(model_path: str, actuators: list[ActuatorConfig]) -> mujoco.MjModel:
"""Load a URDF (or MJCF) file and programmatically inject actuators.
def _load_model(robot: RobotConfig) -> mujoco.MjModel:
"""Load a URDF (or MJCF) and apply robot.yaml settings.
Two-step approach required because MuJoCo's URDF parser ignores
<actuator> in the <mujoco> extension block:
1. Load the URDF → MuJoCo converts it to internal MJCF
2. Export the MJCF XML, add <actuator> elements, reload
2. Export the MJCF XML, inject actuators + joint overrides, reload
This keeps the URDF clean and standard — actuator config lives in
the env config (Isaac Lab pattern), not in the robot file.
This keeps the URDF clean (re-exportable from CAD) — all hardware
tuning lives in robot.yaml.
"""
abs_path = Path(model_path).resolve()
abs_path = robot.urdf_path.resolve()
model_dir = abs_path.parent
is_urdf = abs_path.suffix.lower() == ".urdf"
@@ -74,33 +75,45 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
else:
model_raw = mujoco.MjModel.from_xml_path(str(abs_path))
if not actuators:
if not robot.actuators and not robot.joints:
return model_raw
# Step 2: Export internal MJCF representation (save next to original
# model so relative mesh/asset paths resolve correctly on reload)
# Step 2: Export internal MJCF, inject actuators + joint overrides, reload
tmp_mjcf = model_dir / "_tmp_actuator_inject.xml"
try:
mujoco.mj_saveLastXML(str(tmp_mjcf), model_raw)
mjcf_str = tmp_mjcf.read_text()
# Step 3: Inject actuators into the MJCF XML
# Use torque actuator. Speed is limited by joint damping:
# at steady state, vel_max = torque / damping.
root = ET.fromstring(mjcf_str)
act_elem = ET.SubElement(root, "actuator")
for act in actuators:
ET.SubElement(act_elem, "motor", attrib={
"name": f"{act.joint}_motor",
"joint": act.joint,
"gear": str(act.gear),
"ctrlrange": f"{act.ctrl_range[0]} {act.ctrl_range[1]}",
})
# Add damping to actuated joints to limit max speed and
# mimic real motor friction / back-EMF.
# vel_max ≈ max_torque / damping
joint_damping = {a.joint: a.damping for a in actuators}
# ── Inject actuators ────────────────────────────────────
if robot.actuators:
act_elem = ET.SubElement(root, "actuator")
for act in robot.actuators:
attribs = {
"name": f"{act.joint}_{act.type}",
"joint": act.joint,
"gear": str(act.gear),
"ctrlrange": f"{act.ctrl_range[0]} {act.ctrl_range[1]}",
}
if act.type == "position":
attribs["kp"] = str(act.kp)
if act.kv > 0:
attribs["kv"] = str(act.kv)
ET.SubElement(act_elem, "position", attrib=attribs)
elif act.type == "velocity":
attribs["kp"] = str(act.kp)
ET.SubElement(act_elem, "velocity", attrib=attribs)
else: # motor (default)
ET.SubElement(act_elem, "motor", attrib=attribs)
# ── Apply joint overrides from robot.yaml ───────────────
# Merge actuator damping + explicit joint overrides
joint_damping = {a.joint: a.damping for a in robot.actuators}
for name, jcfg in robot.joints.items():
if jcfg.damping is not None:
joint_damping[name] = jcfg.damping
for body in root.iter("body"):
for jnt in body.findall("joint"):
name = jnt.get("name")
@@ -115,6 +128,15 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
geom.set("contype", "0")
geom.set("conaffinity", "0")
# Harden joint limits: MuJoCo's default soft limits are too
# weak and allow overshoot. Negative solref = hard constraint
# (direct stiffness/damping instead of impedance match).
for body in root.iter("body"):
for jnt in body.findall("joint"):
if jnt.get("limited") == "true" or jnt.get("range"):
jnt.set("solreflimit", "-1000 -100")
jnt.set("solimplimit", "0.95 0.99 0.001")
# Step 4: Write modified MJCF and reload from file path
# (from_xml_path resolves mesh paths relative to the file location)
modified_xml = ET.tostring(root, encoding="unicode")
@@ -124,12 +146,7 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
tmp_mjcf.unlink(missing_ok=True)
def _sim_initialize(self, config: MuJoCoRunnerConfig) -> None:
model_path = self.env.config.model_path
if model_path is None:
raise ValueError("model_path must be specified in the environment config")
actuators = self.env.config.actuators
self._model = self._load_model_with_actuators(str(model_path), actuators)
self._model = self._load_model(self.env.robot)
self._model.opt.timestep = config.dt
self._data: list[mujoco.MjData] = [mujoco.MjData(self._model) for _ in range(config.num_envs)]
@@ -195,10 +212,6 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
torch.from_numpy(qvel_batch).to(self.device),
)
def _sim_close(self) -> None:
if hasattr(self, "_offscreen_renderer") and self._offscreen_renderer is not None:
self._offscreen_renderer.close()
def render(self, env_idx: int = 0) -> np.ndarray:
"""Offscreen render of a single environment."""
if not hasattr(self, "_offscreen_renderer"):

View File

@@ -13,39 +13,13 @@ from clearml import Task
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf
from src.core.env import BaseEnv, BaseEnvConfig, ActuatorConfig
from src.core.env import BaseEnv
from src.core.registry import build_env
from src.core.runner import BaseRunner
from src.envs.cartpole import CartPoleEnv, CartPoleConfig
from src.envs.rotary_cartpole import RotaryCartPoleEnv, RotaryCartPoleConfig
from src.training.trainer import Trainer, TrainerConfig
logger = structlog.get_logger()
# ── env registry ──────────────────────────────────────────────────────
# Maps Hydra config-group name → (EnvClass, ConfigClass)
ENV_REGISTRY: dict[str, tuple[type[BaseEnv], type[BaseEnvConfig]]] = {
"cartpole": (CartPoleEnv, CartPoleConfig),
"rotary_cartpole": (RotaryCartPoleEnv, RotaryCartPoleConfig),
}
def _build_env(env_name: str, cfg: DictConfig) -> BaseEnv:
"""Instantiate the right env + config from the Hydra config-group name."""
if env_name not in ENV_REGISTRY:
raise ValueError(f"Unknown env '{env_name}'. Registered: {list(ENV_REGISTRY)}")
env_cls, config_cls = ENV_REGISTRY[env_name]
env_dict = OmegaConf.to_container(cfg.env, resolve=True)
# Convert actuator dicts → ActuatorConfig objects
if "actuators" in env_dict:
for a in env_dict["actuators"]:
if "ctrl_range" in a:
a["ctrl_range"] = tuple(a["ctrl_range"])
env_dict["actuators"] = [ActuatorConfig(**a) for a in env_dict["actuators"]]
return env_cls(config_cls(**env_dict))
# ── runner registry ───────────────────────────────────────────────────
# Maps Hydra config-group name → (RunnerClass, ConfigClass)
@@ -123,7 +97,7 @@ def main(cfg: DictConfig) -> None:
task = _init_clearml(choices, remote=remote)
env_name = choices.get("env", "cartpole")
env = _build_env(env_name, cfg)
env = build_env(env_name, cfg)
runner = _build_runner(choices.get("runner", "mujoco"), env, cfg)
trainer_config = TrainerConfig(**training_dict)
trainer = Trainer(runner=runner, config=trainer_config)

25
viz.py
View File

@@ -20,32 +20,11 @@ import torch
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf
from src.core.env import ActuatorConfig, BaseEnv, BaseEnvConfig
from src.envs.cartpole import CartPoleConfig, CartPoleEnv
from src.envs.rotary_cartpole import RotaryCartPoleConfig, RotaryCartPoleEnv
from src.core.registry import build_env
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
logger = structlog.get_logger()
# ── registry (same as train.py) ──────────────────────────────────────
ENV_REGISTRY: dict[str, tuple[type[BaseEnv], type[BaseEnvConfig]]] = {
"cartpole": (CartPoleEnv, CartPoleConfig),
"rotary_cartpole": (RotaryCartPoleEnv, RotaryCartPoleConfig),
}
def _build_env(env_name: str, cfg: DictConfig) -> BaseEnv:
if env_name not in ENV_REGISTRY:
raise ValueError(f"Unknown env '{env_name}'. Registered: {list(ENV_REGISTRY)}")
env_cls, config_cls = ENV_REGISTRY[env_name]
env_dict = OmegaConf.to_container(cfg.env, resolve=True)
if "actuators" in env_dict:
for a in env_dict["actuators"]:
if "ctrl_range" in a:
a["ctrl_range"] = tuple(a["ctrl_range"])
env_dict["actuators"] = [ActuatorConfig(**a) for a in env_dict["actuators"]]
return env_cls(config_cls(**env_dict))
# ── keyboard state ───────────────────────────────────────────────────
_action_val = [0.0] # mutable container shared with callback
@@ -72,7 +51,7 @@ def main(cfg: DictConfig) -> None:
env_name = choices.get("env", "cartpole")
# Build env + runner (single env for viz)
env = _build_env(env_name, cfg)
env = build_env(env_name, cfg)
runner_dict = OmegaConf.to_container(cfg.runner, resolve=True)
runner_dict["num_envs"] = 1
runner = MuJoCoRunner(env=env, config=MuJoCoRunnerConfig(**runner_dict))