diff --git a/assets/rotary_cartpole/rotary_cartpole.urdf b/assets/rotary_cartpole/rotary_cartpole.urdf
index b68ed3b..327c023 100644
--- a/assets/rotary_cartpole/rotary_cartpole.urdf
+++ b/assets/rotary_cartpole/rotary_cartpole.urdf
@@ -36,9 +36,9 @@
-
-
+
+
@@ -73,7 +73,7 @@
Tip at (0.07, -0.07, 0) → 45° diagonal in +X/-Y.
CoM = (5×0.035+10×0.07)/15 = 0.0583 along both +X and -Y.
Inertia tensor rotated 45° to match diagonal rod axis. -->
-
+
@@ -93,13 +93,14 @@
+ Joint origin corrected from mesh analysis (Fusion2URDF was 180mm off).
+ rpy pitch +90° so qpos=0 = pendulum hanging down (gravity-stable). -->
-
+
-
+
diff --git a/configs/env/cartpole.yaml b/configs/env/cartpole.yaml
index f7d0722..4fb9b67 100644
--- a/configs/env/cartpole.yaml
+++ b/configs/env/cartpole.yaml
@@ -9,3 +9,4 @@ actuators:
- joint: cart_joint
gear: 10.0
ctrl_range: [-1.0, 1.0]
+ damping: 0.05
diff --git a/configs/env/rotary_cartpole.yaml b/configs/env/rotary_cartpole.yaml
index c4fb61c..4d1a8bd 100644
--- a/configs/env/rotary_cartpole.yaml
+++ b/configs/env/rotary_cartpole.yaml
@@ -3,5 +3,6 @@ model_path: assets/rotary_cartpole/rotary_cartpole.urdf
reward_upright_scale: 1.0
actuators:
- joint: motor_joint
- gear: 15.0
+ gear: 0.5
ctrl_range: [-1.0, 1.0]
+ damping: 0.1
\ No newline at end of file
diff --git a/configs/training/ppo.yaml b/configs/training/ppo.yaml
index c078de6..d3d1786 100644
--- a/configs/training/ppo.yaml
+++ b/configs/training/ppo.yaml
@@ -9,7 +9,8 @@ learning_rate: 0.0003
clip_ratio: 0.2
value_loss_scale: 0.5
entropy_loss_scale: 0.05
-log_interval: 10
+log_interval: 1000
+checkpoint_interval: 50000
# ClearML remote execution (GPU worker)
remote: false
diff --git a/requirements.txt b/requirements.txt
index f3ed7df..4ac07a4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,4 +7,5 @@ skrl[torch]
clearml
imageio
imageio-ffmpeg
+structlog
pytest
\ No newline at end of file
diff --git a/src/core/env.py b/src/core/env.py
index 80c5d96..d737b07 100644
--- a/src/core/env.py
+++ b/src/core/env.py
@@ -17,6 +17,7 @@ class ActuatorConfig:
joint: str = ""
gear: float = 1.0
ctrl_range: tuple[float, float] = (-1.0, 1.0)
+ damping: float = 0.05 # joint damping — limits max speed: vel_max ≈ torque / damping
@dataclasses.dataclass
diff --git a/src/envs/rotary_cartpole.py b/src/envs/rotary_cartpole.py
index d70dcf2..402e3ce 100644
--- a/src/envs/rotary_cartpole.py
+++ b/src/envs/rotary_cartpole.py
@@ -66,21 +66,14 @@ class RotaryCartPoleEnv(BaseEnv[RotaryCartPoleConfig]):
], dim=-1)
def compute_rewards(self, state: RotaryCartPoleState, actions: torch.Tensor) -> torch.Tensor:
- # height: sin(θ) → -1 (down) to +1 (up)
- height = torch.sin(state.pendulum_angle)
+ # Upright reward: -cos(θ) ∈ [-1, +1]
+ upright = -torch.cos(state.pendulum_angle)
- # Upright reward: strongly rewards being near vertical.
- # Uses cos(θ - π/2) = sin(θ), squared and scaled so:
- # down (h=-1): 0.0
- # horiz (h= 0): 0.0
- # up (h=+1): 1.0
- # Only kicks in above horizontal, so swing-up isn't penalised.
- upright_reward = torch.clamp(height, 0.0, 1.0) ** 2
+ # Velocity penalties — make spinning expensive but allow swing-up
+ pend_vel_penalty = 0.01 * state.pendulum_vel ** 2
+ motor_vel_penalty = 0.01 * state.motor_vel ** 2
- # Motor effort penalty: small cost to avoid bang-bang control.
- effort_penalty = 0.001 * actions.squeeze(-1) ** 2
-
- return 5.0 * upright_reward - effort_penalty
+ return upright - pend_vel_penalty - motor_vel_penalty
def compute_terminations(self, state: RotaryCartPoleState) -> torch.Tensor:
# No early termination — episode runs for max_steps (truncation only).
@@ -88,7 +81,5 @@ class RotaryCartPoleEnv(BaseEnv[RotaryCartPoleConfig]):
return torch.zeros_like(state.motor_angle, dtype=torch.bool)
def get_default_qpos(self, nq: int) -> list[float] | None:
- # The STL mesh is horizontal at qpos=0.
- # Pendulum hangs down at θ = -π/2 (sin(-π/2) = -1).
- import math
- return [0.0, -math.pi / 2]
+ # qpos=0 = pendulum hanging down (joint frame rotated in URDF).
+ return None
diff --git a/src/runners/mujoco.py b/src/runners/mujoco.py
index 7cf7755..f8c79b5 100644
--- a/src/runners/mujoco.py
+++ b/src/runners/mujoco.py
@@ -1,11 +1,13 @@
import dataclasses
-import os
import xml.etree.ElementTree as ET
+from pathlib import Path
+
+import mujoco
+import numpy as np
+import torch
+
from src.core.env import BaseEnv, ActuatorConfig
from src.core.runner import BaseRunner, BaseRunnerConfig
-import torch
-import numpy as np
-import mujoco
@dataclasses.dataclass
class MuJoCoRunnerConfig(BaseRunnerConfig):
@@ -39,9 +41,9 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
This keeps the URDF clean and standard — actuator config lives in
the env config (Isaac Lab pattern), not in the robot file.
"""
- abs_path = os.path.abspath(model_path)
- model_dir = os.path.dirname(abs_path)
- is_urdf = abs_path.lower().endswith(".urdf")
+ abs_path = Path(model_path).resolve()
+ model_dir = abs_path.parent
+ is_urdf = abs_path.suffix.lower() == ".urdf"
# MuJoCo's URDF parser strips directory prefixes from mesh filenames,
# so we inject a <mujoco><compiler meshdir="..."/></mujoco> block into a
@@ -53,9 +55,9 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
meshdir = None
for mesh_el in root.iter("mesh"):
fn = mesh_el.get("filename", "")
- dirname = os.path.dirname(fn)
- if dirname:
- meshdir = dirname
+ parent = str(Path(fn).parent)
+ if parent and parent != ".":
+ meshdir = parent
break
if meshdir:
mj_ext = ET.SubElement(root, "mujoco")
@@ -63,25 +65,24 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
"meshdir": meshdir,
"balanceinertia": "true",
})
- tmp_urdf = os.path.join(model_dir, "_tmp_mujoco_load.urdf")
- tree.write(tmp_urdf, xml_declaration=True, encoding="unicode")
+ tmp_urdf = model_dir / "_tmp_mujoco_load.urdf"
+ tree.write(str(tmp_urdf), xml_declaration=True, encoding="unicode")
try:
- model_raw = mujoco.MjModel.from_xml_path(tmp_urdf)
+ model_raw = mujoco.MjModel.from_xml_path(str(tmp_urdf))
finally:
- os.unlink(tmp_urdf)
+ tmp_urdf.unlink()
else:
- model_raw = mujoco.MjModel.from_xml_path(abs_path)
+ model_raw = mujoco.MjModel.from_xml_path(str(abs_path))
if not actuators:
return model_raw
# Step 2: Export internal MJCF representation (save next to original
# model so relative mesh/asset paths resolve correctly on reload)
- tmp_mjcf = os.path.join(model_dir, "_tmp_actuator_inject.xml")
+ tmp_mjcf = model_dir / "_tmp_actuator_inject.xml"
try:
- mujoco.mj_saveLastXML(tmp_mjcf, model_raw)
- with open(tmp_mjcf) as f:
- mjcf_str = f.read()
+ mujoco.mj_saveLastXML(str(tmp_mjcf), model_raw)
+ mjcf_str = tmp_mjcf.read_text()
# Step 3: Inject actuators into the MJCF XML
# Use torque actuator. Speed is limited by joint damping:
@@ -98,12 +99,13 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
# Add damping to actuated joints to limit max speed and
# mimic real motor friction / back-EMF.
- # vel_max ≈ max_torque / damping (e.g. 1.0 / 0.05 = 20 rad/s)
- actuated_joints = {a.joint for a in actuators}
+ # vel_max ≈ max_torque / damping
+ joint_damping = {a.joint: a.damping for a in actuators}
for body in root.iter("body"):
for jnt in body.findall("joint"):
- if jnt.get("name") in actuated_joints:
- jnt.set("damping", "0.05")
+ name = jnt.get("name")
+ if name in joint_damping:
+ jnt.set("damping", str(joint_damping[name]))
# Disable self-collision on all geoms.
# URDF mesh convex hulls often overlap at joints (especially
@@ -116,12 +118,10 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
# Step 4: Write modified MJCF and reload from file path
# (from_xml_path resolves mesh paths relative to the file location)
modified_xml = ET.tostring(root, encoding="unicode")
- with open(tmp_mjcf, "w") as f:
- f.write(modified_xml)
- return mujoco.MjModel.from_xml_path(tmp_mjcf)
+ tmp_mjcf.write_text(modified_xml)
+ return mujoco.MjModel.from_xml_path(str(tmp_mjcf))
finally:
- if os.path.exists(tmp_mjcf):
- os.unlink(tmp_mjcf)
+ tmp_mjcf.unlink(missing_ok=True)
def _sim_initialize(self, config: MuJoCoRunnerConfig) -> None:
model_path = self.env.config.model_path
diff --git a/src/training/trainer.py b/src/training/trainer.py
index 3315156..5230049 100644
--- a/src/training/trainer.py
+++ b/src/training/trainer.py
@@ -35,6 +35,7 @@ class TrainerConfig:
# Training
total_timesteps: int = 1_000_000
log_interval: int = 10
+ checkpoint_interval: int = 50_000
# Video recording (uploaded to ClearML)
record_video_every: int = 10_000 # 0 = disabled
@@ -196,7 +197,7 @@ class Trainer:
# log_interval=1 → log every PPO update (= every rollout_steps timesteps).
agent_cfg["experiment"]["write_interval"] = self.config.log_interval
agent_cfg["experiment"]["checkpoint_interval"] = max(
- self.config.total_timesteps // 10, self.config.rollout_steps
+ self.config.checkpoint_interval, self.config.rollout_steps
)
self.agent = PPO(
diff --git a/train.py b/train.py
index b133bbc..8d16e80 100644
--- a/train.py
+++ b/train.py
@@ -1,4 +1,8 @@
+import pathlib
+
import hydra
+import hydra.utils as hydra_utils
+import structlog
from clearml import Task
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf
@@ -9,6 +13,8 @@ from src.envs.rotary_cartpole import RotaryCartPoleEnv, RotaryCartPoleConfig
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
from src.training.trainer import Trainer, TrainerConfig
+logger = structlog.get_logger()
+
# ── env registry ──────────────────────────────────────────────────────
# Maps Hydra config-group name → (EnvClass, ConfigClass)
ENV_REGISTRY: dict[str, tuple[type[BaseEnv], type[BaseEnvConfig]]] = {
@@ -52,9 +58,15 @@ def _init_clearml(choices: dict[str, str], remote: bool = False) -> Task:
tags = [env_name, runner_name, training_name]
task = Task.init(project_name=project, task_name=task_name, tags=tags)
+ task.set_base_docker("registry.kube.optimize/worker-image:latest")
- if remote:
- task.execute_remotely(queue_name="default")
+ req_file = pathlib.Path(hydra_utils.get_original_cwd()) / "requirements.txt"
+ task.set_packages(str(req_file))
+
+ # Execute remotely if requested and running locally
+ if remote and task.running_locally():
+ logger.info("executing_task_remotely", queue="gpu-queue")
+ task.execute_remotely(queue_name="gpu-queue", exit_process=True)
return task
diff --git a/viz.py b/viz.py
new file mode 100644
index 0000000..d4187d4
--- /dev/null
+++ b/viz.py
@@ -0,0 +1,137 @@
+"""Interactive visualization — control any env with keyboard in MuJoCo viewer.
+
+Usage:
+ mjpython viz.py env=rotary_cartpole
+ mjpython viz.py env=cartpole +com=true
+
+Controls:
+ Left/Right arrows — apply torque to first actuator
+ R — reset environment
+ Esc / close window — quit
+"""
+import math
+import time
+
+import hydra
+import mujoco
+import mujoco.viewer
+import structlog
+import torch
+from hydra.core.hydra_config import HydraConfig
+from omegaconf import DictConfig, OmegaConf
+
+from src.core.env import ActuatorConfig, BaseEnv, BaseEnvConfig
+from src.envs.cartpole import CartPoleConfig, CartPoleEnv
+from src.envs.rotary_cartpole import RotaryCartPoleConfig, RotaryCartPoleEnv
+from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
+
+logger = structlog.get_logger()
+
+# ── registry (same as train.py) ──────────────────────────────────────
+# Maps the Hydra ``env`` config-group choice to (EnvClass, ConfigClass);
+# _build_env() looks the pair up by name and instantiates both.
+ENV_REGISTRY: dict[str, tuple[type[BaseEnv], type[BaseEnvConfig]]] = {
+    "cartpole": (CartPoleEnv, CartPoleConfig),
+    "rotary_cartpole": (RotaryCartPoleEnv, RotaryCartPoleConfig),
+}
+
+
+def _build_env(env_name: str, cfg: DictConfig) -> BaseEnv:
+    """Instantiate the environment registered under *env_name*.
+
+    Args:
+        env_name: Key into ``ENV_REGISTRY``.
+        cfg: Composed Hydra config; only ``cfg.env`` is read.
+
+    Returns:
+        The environment built from its config dataclass.
+
+    Raises:
+        ValueError: If *env_name* is not a registered environment.
+    """
+    entry = ENV_REGISTRY.get(env_name)
+    if entry is None:
+        raise ValueError(f"Unknown env '{env_name}'. Registered: {list(ENV_REGISTRY)}")
+    env_cls, config_cls = entry
+
+    settings = OmegaConf.to_container(cfg.env, resolve=True)
+    if "actuators" in settings:
+        actuator_cfgs = []
+        for spec in settings["actuators"]:
+            # The config container yields lists; ActuatorConfig declares a tuple.
+            if "ctrl_range" in spec:
+                spec = {**spec, "ctrl_range": tuple(spec["ctrl_range"])}
+            actuator_cfgs.append(ActuatorConfig(**spec))
+        settings["actuators"] = actuator_cfgs
+
+    env_config = config_cls(**settings)
+    return env_cls(env_config)
+
+
+# ── keyboard state ───────────────────────────────────────────────────
+# Single-element lists so _key_callback() can update them in place; the main
+# loop reads them once per control step.
+_action_val = [0.0] # last commanded action: -1.0 (Left) or +1.0 (Right)
+_action_time = [0.0] # time.time() timestamp of the last Left/Right key event
+_reset_flag = [False] # set by the R key, cleared by the main loop after resetting
+_ACTION_HOLD_S = 0.25 # seconds the action stays active after last key event
+
+
+def _key_callback(keycode: int) -> None:
+    """Record a viewer key event in the module-level keyboard state.
+
+    Called by MuJoCo on key press & repeat (not release), so the main loop
+    expires the action itself via ``_ACTION_HOLD_S``.
+
+    Args:
+        keycode: GLFW key code of the pressed key.
+    """
+    if keycode == 82:  # GLFW_KEY_R
+        _reset_flag[0] = True
+        return
+
+    # GLFW_KEY_LEFT (263) pushes negative, GLFW_KEY_RIGHT (262) positive.
+    direction = {263: -1.0, 262: 1.0}.get(keycode)
+    if direction is not None:
+        _action_val[0] = direction
+        _action_time[0] = time.time()
+
+
+def _joint_readout(model: mujoco.MjModel, data: mujoco.MjData) -> dict[str, float]:
+    """Return a ``{joint_name: position}`` mapping for the periodic state log.
+
+    Hinge joints are reported in degrees (1 decimal); slide joints are reported
+    as the raw qpos value (3 decimals). Free and ball joints have no single
+    scalar position and are omitted.
+    """
+    hinge = int(mujoco.mjtJoint.mjJNT_HINGE)
+    slide = int(mujoco.mjtJoint.mjJNT_SLIDE)
+    readout: dict[str, float] = {}
+    for i in range(model.njnt):
+        jnt_type = int(model.jnt_type[i])
+        # qpos must be addressed through jnt_qposadr: joint id and qpos index
+        # only coincide while every preceding joint has exactly one coordinate.
+        value = float(data.qpos[model.jnt_qposadr[i]])
+        if jnt_type == hinge:
+            readout[model.jnt(i).name] = round(math.degrees(value), 1)
+        elif jnt_type == slide:
+            readout[model.jnt(i).name] = round(value, 3)
+    return readout
+
+
+@hydra.main(version_base=None, config_path="configs", config_name="config")
+def main(cfg: DictConfig) -> None:
+    """Run one environment in the MuJoCo passive viewer under keyboard control.
+
+    Args:
+        cfg: Composed Hydra config. ``cfg.env`` and ``cfg.runner`` are read, plus
+            the optional ``com`` flag (``+com=true``) that turns on CoM / inertia
+            rendering.
+    """
+    choices = HydraConfig.get().runtime.choices
+    env_name = choices.get("env", "cartpole")
+
+    # Build env + runner (single env for viz)
+    env = _build_env(env_name, cfg)
+    runner_dict = OmegaConf.to_container(cfg.runner, resolve=True)
+    runner_dict["num_envs"] = 1
+    runner = MuJoCoRunner(env=env, config=MuJoCoRunnerConfig(**runner_dict))
+
+    try:
+        model = runner._model
+        data = runner._data[0]
+
+        # Control period
+        dt_ctrl = runner.config.dt * runner.config.substeps
+
+        # Launch viewer
+        with mujoco.viewer.launch_passive(model, data, key_callback=_key_callback) as viewer:
+            # Show CoM / inertia if requested via Hydra override: viz.py +com=true
+            show_com = cfg.get("com", False)
+            if show_com:
+                viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_COM] = True
+                viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_INERTIA] = True
+
+            obs, _ = runner.reset()
+            step = 0
+
+            logger.info("viewer_started", env=env_name,
+                        controls="Left/Right arrows = torque, R = reset")
+
+            while viewer.is_running():
+                tick_start = time.perf_counter()
+
+                # Read action from callback (expires after _ACTION_HOLD_S)
+                if time.time() - _action_time[0] < _ACTION_HOLD_S:
+                    action_val = _action_val[0]
+                else:
+                    action_val = 0.0
+
+                # Reset on R press
+                if _reset_flag[0]:
+                    _reset_flag[0] = False
+                    obs, _ = runner.reset()
+                    step = 0
+                    logger.info("reset")
+
+                # Step through runner
+                action = torch.tensor([[action_val]])
+                obs, reward, terminated, truncated, info = runner.step(action)
+
+                # Sync viewer
+                mujoco.mj_forward(model, data)
+                viewer.sync()
+
+                # Print state
+                if step % 25 == 0:
+                    joints = _joint_readout(model, data)
+                    logger.debug("step", n=step, reward=round(reward.item(), 3),
+                                 action=round(action_val, 1), **joints)
+
+                # Real-time pacing: sleep only for what is left of the control
+                # period after stepping and rendering, so playback does not
+                # run slower than the simulated clock.
+                elapsed = time.perf_counter() - tick_start
+                time.sleep(max(0.0, dt_ctrl - elapsed))
+                step += 1
+    finally:
+        # Release the runner even if the viewer loop raises.
+        runner.close()
+
+
+if __name__ == "__main__":
+ main()