diff --git a/assets/cartpole/robot.yaml b/assets/cartpole/robot.yaml new file mode 100644 index 0000000..99a2069 --- /dev/null +++ b/assets/cartpole/robot.yaml @@ -0,0 +1,10 @@ +# Classic cartpole — robot hardware config. + +urdf: cartpole.urdf + +actuators: + - joint: cart_joint + type: motor + gear: 10.0 + ctrl_range: [-1.0, 1.0] + damping: 0.05 diff --git a/assets/rotary_cartpole/robot.yaml b/assets/rotary_cartpole/robot.yaml new file mode 100644 index 0000000..8228c19 --- /dev/null +++ b/assets/rotary_cartpole/robot.yaml @@ -0,0 +1,15 @@ +# Rotary cartpole (Furuta pendulum) — robot hardware config. +# Lives next to the URDF so all robot-specific settings are in one place. + +urdf: rotary_cartpole.urdf + +actuators: + - joint: motor_joint + type: motor # direct torque control + gear: 0.5 # torque multiplier + ctrl_range: [-1.0, 1.0] + damping: 0.1 # motor friction / back-EMF + +joints: + pendulum_joint: + damping: 0.0001 # bearing friction diff --git a/configs/env/cartpole.yaml b/configs/env/cartpole.yaml index 4fb9b67..ffd53dd 100644 --- a/configs/env/cartpole.yaml +++ b/configs/env/cartpole.yaml @@ -1,12 +1,7 @@ max_steps: 500 +robot_path: assets/cartpole angle_threshold: 0.418 cart_limit: 2.4 reward_alive: 1.0 reward_pole_upright_scale: 1.0 reward_action_penalty_scale: 0.01 -model_path: assets/cartpole/cartpole.urdf -actuators: - - joint: cart_joint - gear: 10.0 - ctrl_range: [-1.0, 1.0] - damping: 0.05 diff --git a/configs/env/rotary_cartpole.yaml b/configs/env/rotary_cartpole.yaml index 4d1a8bd..19a666b 100644 --- a/configs/env/rotary_cartpole.yaml +++ b/configs/env/rotary_cartpole.yaml @@ -1,8 +1,3 @@ max_steps: 1000 -model_path: assets/rotary_cartpole/rotary_cartpole.urdf -reward_upright_scale: 1.0 -actuators: - - joint: motor_joint - gear: 0.5 - ctrl_range: [-1.0, 1.0] - damping: 0.1 \ No newline at end of file +robot_path: assets/rotary_cartpole +reward_upright_scale: 1.0 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 8466a62..693a91e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,4 +10,5 @@ clearml imageio imageio-ffmpeg structlog +pyyaml pytest \ No newline at end of file diff --git a/src/core/env.py b/src/core/env.py index d737b07..5ec6e48 100644 --- a/src/core/env.py +++ b/src/core/env.py @@ -5,30 +5,20 @@ from gymnasium import spaces import torch import pathlib +from src.core.robot import RobotConfig, load_robot_config + T = TypeVar("T") -@dataclasses.dataclass -class ActuatorConfig: - """Actuator definition — maps a joint to a motor with gear ratio and control limits. - Kept in the env config (not runner config) because actuators define what the robot - can do, which determines action space — a task-level concept. - This mirrors Isaac Lab's pattern of separating actuator config from the robot file.""" - joint: str = "" - gear: float = 1.0 - ctrl_range: tuple[float, float] = (-1.0, 1.0) - damping: float = 0.05 # joint damping — limits max speed: vel_max ≈ torque / damping - - @dataclasses.dataclass class BaseEnvConfig: max_steps: int = 1000 - model_path: pathlib.Path | None = None - actuators: list[ActuatorConfig] = dataclasses.field(default_factory=list) + robot_path: str = "" # directory containing robot.yaml + URDF class BaseEnv(abc.ABC, Generic[T]): def __init__(self, config: BaseEnvConfig): self.config = config + self.robot: RobotConfig = load_robot_config(config.robot_path) @property @abc.abstractmethod diff --git a/src/core/registry.py b/src/core/registry.py new file mode 100644 index 0000000..99ade3c --- /dev/null +++ b/src/core/registry.py @@ -0,0 +1,23 @@ +"""Shared env registry and builder used by train.py and viz.py.""" + +from omegaconf import DictConfig, OmegaConf + +from src.core.env import BaseEnv, BaseEnvConfig +from src.envs.cartpole import CartPoleEnv, CartPoleConfig +from src.envs.rotary_cartpole import RotaryCartPoleEnv, RotaryCartPoleConfig + +# Maps Hydra config-group name → (EnvClass, ConfigClass) +ENV_REGISTRY: dict[str, tuple[type[BaseEnv], type[BaseEnvConfig]]] = { + "cartpole": (CartPoleEnv, CartPoleConfig), + "rotary_cartpole": (RotaryCartPoleEnv, RotaryCartPoleConfig), +} + + +def build_env(env_name: str, cfg: DictConfig) -> BaseEnv: + """Instantiate the right env + config from the Hydra config-group name.""" + if env_name not in ENV_REGISTRY: + raise ValueError(f"Unknown env '{env_name}'. Registered: {list(ENV_REGISTRY)}") + + env_cls, config_cls = ENV_REGISTRY[env_name] + env_dict = OmegaConf.to_container(cfg.env, resolve=True) + return env_cls(config_cls(**env_dict)) diff --git a/src/core/robot.py b/src/core/robot.py new file mode 100644 index 0000000..73f31e3 --- /dev/null +++ b/src/core/robot.py @@ -0,0 +1,101 @@ +"""Robot hardware configuration — loaded from robot.yaml next to the URDF. + +Separates robot hardware (actuators, joint tuning) from task config +(rewards, episode length) and from the URDF (clean CAD export). + +Usage: + robot = load_robot_config(Path("assets/rotary_cartpole")) + # robot.urdf_path → resolved absolute path to the URDF + # robot.actuators → list of ActuatorConfig + # robot.joints → dict of per-joint overrides +""" + +import dataclasses +from pathlib import Path + +import structlog +import yaml + +log = structlog.get_logger() + + +@dataclasses.dataclass +class ActuatorConfig: + """Motor/actuator attached to a joint. + + type: + motor — direct torque control (ctrl = normalised torque) + position — PD position servo (ctrl = target angle, needs kp) + velocity — P velocity servo (ctrl = target velocity, needs kp) + """ + joint: str = "" + type: str = "motor" + gear: float = 1.0 + ctrl_range: tuple[float, float] = (-1.0, 1.0) + damping: float = 0.05 + kp: float = 0.0 # proportional gain (position / velocity actuators) + kv: float = 0.0 # derivative gain (position actuators) + + +@dataclasses.dataclass +class JointConfig: + """Per-joint overrides applied on top of the URDF values.""" + damping: float | None = None + + +@dataclasses.dataclass +class RobotConfig: + """Complete robot hardware description.""" + urdf_path: Path = dataclasses.field(default_factory=lambda: Path()) + actuators: list[ActuatorConfig] = dataclasses.field(default_factory=list) + joints: dict[str, JointConfig] = dataclasses.field(default_factory=dict) + + +def load_robot_config(robot_dir: str | Path) -> RobotConfig: + """Load robot.yaml from a directory and resolve the URDF path. + + Expected layout: + robot_dir/ + robot.yaml ← hardware config + some_robot.urdf ← CAD export + meshes/ ← optional mesh files + """ + robot_dir = Path(robot_dir).resolve() + yaml_path = robot_dir / "robot.yaml" + + if not yaml_path.exists(): + raise FileNotFoundError(f"Robot config not found: {yaml_path}") + + raw = yaml.safe_load(yaml_path.read_text()) + + # Resolve URDF path relative to robot.yaml directory + urdf_filename = raw.get("urdf", "") + if not urdf_filename: + raise ValueError(f"robot.yaml must specify 'urdf' filename: {yaml_path}") + urdf_path = robot_dir / urdf_filename + if not urdf_path.exists(): + raise FileNotFoundError(f"URDF not found: {urdf_path}") + + # Parse actuators + actuators = [] + for a in raw.get("actuators", []): + if "ctrl_range" in a: + a["ctrl_range"] = tuple(a["ctrl_range"]) + actuators.append(ActuatorConfig(**a)) + + # Parse joint overrides + joints = {} + for name, jcfg in raw.get("joints", {}).items(): + joints[name] = JointConfig(**jcfg) + + config = RobotConfig( + urdf_path=urdf_path, + actuators=actuators, + joints=joints, + ) + + log.debug("robot_config_loaded", robot_dir=str(robot_dir), + urdf=urdf_filename, num_actuators=len(actuators), + joint_overrides=list(joints.keys())) + + return config diff --git a/src/core/runner.py b/src/core/runner.py index b28fe28..f543ff1 100644 --- a/src/core/runner.py +++ b/src/core/runner.py @@ -53,9 +53,10 @@ class BaseRunner(abc.ABC, Generic[T]): def _sim_reset(self, env_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: ... - @abc.abstractmethod def _sim_close(self) -> None: - ... + """Release simulator resources. Override for extra cleanup.""" + if hasattr(self, "_offscreen_renderer") and self._offscreen_renderer is not None: + self._offscreen_renderer.close() def reset(self) -> tuple[torch.Tensor, dict[str, Any]]: all_ids = torch.arange(self.num_envs, device=self.device) diff --git a/src/runners/mjx.py b/src/runners/mjx.py index 63b5f73..3511fd2 100644 --- a/src/runners/mjx.py +++ b/src/runners/mjx.py @@ -29,7 +29,7 @@ import numpy as np from src.core.env import BaseEnv from src.core.runner import BaseRunner, BaseRunnerConfig -from src.runners.mujoco import MuJoCoRunner # reuse _load_model_with_actuators +from src.runners.mujoco import MuJoCoRunner # reuse _load_model log = structlog.get_logger() @@ -64,14 +64,8 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]): # ── Initialization ─────────────────────────────────────────────── def _sim_initialize(self, config: MJXRunnerConfig) -> None: - model_path = self.env.config.model_path - if model_path is None: - raise ValueError("model_path must be specified") - # Step 1: Load CPU model (reuses URDF → MJCF → actuator injection) - self._mj_model = MuJoCoRunner._load_model_with_actuators( - str(model_path), self.env.config.actuators, - ) + self._mj_model = MuJoCoRunner._load_model(self.env.robot) self._mj_model.opt.timestep = config.dt self._nq = self._mj_model.nq self._nv = self._mj_model.nv @@ -207,10 +201,6 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]): rv = self._batch_data.qvel[ids_np].astype(jnp.float32) return torch.from_dlpack(rq), torch.from_dlpack(rv) - def _sim_close(self) -> None: - if hasattr(self, "_offscreen_renderer") and self._offscreen_renderer is not None: - self._offscreen_renderer.close() - # ── Rendering ──────────────────────────────────────────────────── def render(self, env_idx: int = 0) -> np.ndarray: diff --git a/src/runners/mujoco.py b/src/runners/mujoco.py index 5e140c3..c86b391 100644 --- a/src/runners/mujoco.py +++ b/src/runners/mujoco.py @@ -6,7 +6,8 @@ import mujoco import numpy as np import torch -from src.core.env import BaseEnv, ActuatorConfig +from src.core.env import BaseEnv +from src.core.robot import RobotConfig from src.core.runner import BaseRunner, BaseRunnerConfig @dataclasses.dataclass @@ -30,18 +31,18 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]): return torch.device(self.config.device) @staticmethod - def _load_model_with_actuators(model_path: str, actuators: list[ActuatorConfig]) -> mujoco.MjModel: - """Load a URDF (or MJCF) file and programmatically inject actuators. + def _load_model(robot: RobotConfig) -> mujoco.MjModel: + """Load a URDF (or MJCF) and apply robot.yaml settings. Two-step approach required because MuJoCo's URDF parser ignores in the extension block: 1. Load the URDF → MuJoCo converts it to internal MJCF - 2. Export the MJCF XML, add elements, reload + 2. Export the MJCF XML, inject actuators + joint overrides, reload - This keeps the URDF clean and standard — actuator config lives in - the env config (Isaac Lab pattern), not in the robot file. + This keeps the URDF clean (re-exportable from CAD) — all hardware + tuning lives in robot.yaml. """ - abs_path = Path(model_path).resolve() + abs_path = robot.urdf_path.resolve() model_dir = abs_path.parent is_urdf = abs_path.suffix.lower() == ".urdf" @@ -74,33 +75,45 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]): else: model_raw = mujoco.MjModel.from_xml_path(str(abs_path)) - if not actuators: + if not robot.actuators and not robot.joints: return model_raw - # Step 2: Export internal MJCF representation (save next to original - # model so relative mesh/asset paths resolve correctly on reload) + # Step 2: Export internal MJCF, inject actuators + joint overrides, reload tmp_mjcf = model_dir / "_tmp_actuator_inject.xml" try: mujoco.mj_saveLastXML(str(tmp_mjcf), model_raw) mjcf_str = tmp_mjcf.read_text() - # Step 3: Inject actuators into the MJCF XML - # Use torque actuator. Speed is limited by joint damping: - # at steady state, vel_max = torque / damping. root = ET.fromstring(mjcf_str) - act_elem = ET.SubElement(root, "actuator") - for act in actuators: - ET.SubElement(act_elem, "motor", attrib={ - "name": f"{act.joint}_motor", - "joint": act.joint, - "gear": str(act.gear), - "ctrlrange": f"{act.ctrl_range[0]} {act.ctrl_range[1]}", - }) - # Add damping to actuated joints to limit max speed and - # mimic real motor friction / back-EMF. - # vel_max ≈ max_torque / damping - joint_damping = {a.joint: a.damping for a in actuators} + # ── Inject actuators ──────────────────────────────────── + if robot.actuators: + act_elem = ET.SubElement(root, "actuator") + for act in robot.actuators: + attribs = { + "name": f"{act.joint}_{act.type}", + "joint": act.joint, + "gear": str(act.gear), + "ctrlrange": f"{act.ctrl_range[0]} {act.ctrl_range[1]}", + } + if act.type == "position": + attribs["kp"] = str(act.kp) + if act.kv > 0: + attribs["kv"] = str(act.kv) + ET.SubElement(act_elem, "position", attrib=attribs) + elif act.type == "velocity": + attribs["kp"] = str(act.kp) + ET.SubElement(act_elem, "velocity", attrib=attribs) + else: # motor (default) + ET.SubElement(act_elem, "motor", attrib=attribs) + + # ── Apply joint overrides from robot.yaml ─────────────── + # Merge actuator damping + explicit joint overrides + joint_damping = {a.joint: a.damping for a in robot.actuators} + for name, jcfg in robot.joints.items(): + if jcfg.damping is not None: + joint_damping[name] = jcfg.damping + for body in root.iter("body"): for jnt in body.findall("joint"): name = jnt.get("name") @@ -115,6 +128,15 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]): geom.set("contype", "0") geom.set("conaffinity", "0") + # Harden joint limits: MuJoCo's default soft limits are too + # weak and allow overshoot. Negative solref = hard constraint + # (direct stiffness/damping instead of impedance match). + for body in root.iter("body"): + for jnt in body.findall("joint"): + if jnt.get("limited") == "true" or jnt.get("range"): + jnt.set("solreflimit", "-1000 -100") + jnt.set("solimplimit", "0.95 0.99 0.001") + # Step 4: Write modified MJCF and reload from file path # (from_xml_path resolves mesh paths relative to the file location) modified_xml = ET.tostring(root, encoding="unicode") @@ -124,12 +146,7 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]): tmp_mjcf.unlink(missing_ok=True) def _sim_initialize(self, config: MuJoCoRunnerConfig) -> None: - model_path = self.env.config.model_path - if model_path is None: - raise ValueError("model_path must be specified in the environment config") - - actuators = self.env.config.actuators - self._model = self._load_model_with_actuators(str(model_path), actuators) + self._model = self._load_model(self.env.robot) self._model.opt.timestep = config.dt self._data: list[mujoco.MjData] = [mujoco.MjData(self._model) for _ in range(config.num_envs)] @@ -195,10 +212,6 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]): torch.from_numpy(qvel_batch).to(self.device), ) - def _sim_close(self) -> None: - if hasattr(self, "_offscreen_renderer") and self._offscreen_renderer is not None: - self._offscreen_renderer.close() - def render(self, env_idx: int = 0) -> np.ndarray: """Offscreen render of a single environment.""" if not hasattr(self, "_offscreen_renderer"): diff --git a/train.py b/train.py index ddcce35..124c398 100644 --- a/train.py +++ b/train.py @@ -13,39 +13,13 @@ from clearml import Task from hydra.core.hydra_config import HydraConfig from omegaconf import DictConfig, OmegaConf -from src.core.env import BaseEnv, BaseEnvConfig, ActuatorConfig +from src.core.env import BaseEnv +from src.core.registry import build_env from src.core.runner import BaseRunner -from src.envs.cartpole import CartPoleEnv, CartPoleConfig -from src.envs.rotary_cartpole import RotaryCartPoleEnv, RotaryCartPoleConfig from src.training.trainer import Trainer, TrainerConfig logger = structlog.get_logger() -# ── env registry ────────────────────────────────────────────────────── -# Maps Hydra config-group name → (EnvClass, ConfigClass) -ENV_REGISTRY: dict[str, tuple[type[BaseEnv], type[BaseEnvConfig]]] = { - "cartpole": (CartPoleEnv, CartPoleConfig), - "rotary_cartpole": (RotaryCartPoleEnv, RotaryCartPoleConfig), -} - - -def _build_env(env_name: str, cfg: DictConfig) -> BaseEnv: - """Instantiate the right env + config from the Hydra config-group name.""" - if env_name not in ENV_REGISTRY: - raise ValueError(f"Unknown env '{env_name}'. Registered: {list(ENV_REGISTRY)}") - - env_cls, config_cls = ENV_REGISTRY[env_name] - env_dict = OmegaConf.to_container(cfg.env, resolve=True) - - # Convert actuator dicts → ActuatorConfig objects - if "actuators" in env_dict: - for a in env_dict["actuators"]: - if "ctrl_range" in a: - a["ctrl_range"] = tuple(a["ctrl_range"]) - env_dict["actuators"] = [ActuatorConfig(**a) for a in env_dict["actuators"]] - - return env_cls(config_cls(**env_dict)) - # ── runner registry ─────────────────────────────────────────────────── # Maps Hydra config-group name → (RunnerClass, ConfigClass) @@ -123,7 +97,7 @@ def main(cfg: DictConfig) -> None: task = _init_clearml(choices, remote=remote) env_name = choices.get("env", "cartpole") - env = _build_env(env_name, cfg) + env = build_env(env_name, cfg) runner = _build_runner(choices.get("runner", "mujoco"), env, cfg) trainer_config = TrainerConfig(**training_dict) trainer = Trainer(runner=runner, config=trainer_config) diff --git a/viz.py b/viz.py index d4187d4..59d937e 100644 --- a/viz.py +++ b/viz.py @@ -20,32 +20,11 @@ import torch from hydra.core.hydra_config import HydraConfig from omegaconf import DictConfig, OmegaConf -from src.core.env import ActuatorConfig, BaseEnv, BaseEnvConfig -from src.envs.cartpole import CartPoleConfig, CartPoleEnv -from src.envs.rotary_cartpole import RotaryCartPoleConfig, RotaryCartPoleEnv +from src.core.registry import build_env from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig logger = structlog.get_logger() -# ── registry (same as train.py) ────────────────────────────────────── -ENV_REGISTRY: dict[str, tuple[type[BaseEnv], type[BaseEnvConfig]]] = { - "cartpole": (CartPoleEnv, CartPoleConfig), - "rotary_cartpole": (RotaryCartPoleEnv, RotaryCartPoleConfig), -} - - -def _build_env(env_name: str, cfg: DictConfig) -> BaseEnv: - if env_name not in ENV_REGISTRY: - raise ValueError(f"Unknown env '{env_name}'. Registered: {list(ENV_REGISTRY)}") - env_cls, config_cls = ENV_REGISTRY[env_name] - env_dict = OmegaConf.to_container(cfg.env, resolve=True) - if "actuators" in env_dict: - for a in env_dict["actuators"]: - if "ctrl_range" in a: - a["ctrl_range"] = tuple(a["ctrl_range"]) - env_dict["actuators"] = [ActuatorConfig(**a) for a in env_dict["actuators"]] - return env_cls(config_cls(**env_dict)) - # ── keyboard state ─────────────────────────────────────────────────── _action_val = [0.0] # mutable container shared with callback @@ -72,7 +51,7 @@ def main(cfg: DictConfig) -> None: env_name = choices.get("env", "cartpole") # Build env + runner (single env for viz) - env = _build_env(env_name, cfg) + env = build_env(env_name, cfg) runner_dict = OmegaConf.to_container(cfg.runner, resolve=True) runner_dict["num_envs"] = 1 runner = MuJoCoRunner(env=env, config=MuJoCoRunnerConfig(**runner_dict))