✨ update urdf and dependencies
This commit is contained in:
@@ -36,9 +36,9 @@
|
||||
<link name="arm">
|
||||
<inertial>
|
||||
<origin xyz="0.00005 0.0065 0.00563" rpy="0 0 0"/>
|
||||
<mass value="0.150"/>
|
||||
<inertia ixx="4.05e-05" iyy="1.17e-05" izz="3.66e-05"
|
||||
ixy="0.0" iyz="1.08e-07" ixz="0.0"/>
|
||||
<mass value="0.010"/>
|
||||
<inertia ixx="2.70e-06" iyy="7.80e-07" izz="2.44e-06"
|
||||
ixy="0.0" iyz="7.20e-08" ixz="0.0"/>
|
||||
</inertial>
|
||||
<visual>
|
||||
<origin xyz="0.006947 -0.00395 -0.14796" rpy="0 0 0"/>
|
||||
@@ -73,7 +73,7 @@
|
||||
Tip at (0.07, -0.07, 0) → 45° diagonal in +X/-Y.
|
||||
CoM = (5×0.035+10×0.07)/15 = 0.0583 along both +X and -Y.
|
||||
Inertia tensor rotated 45° to match diagonal rod axis. -->
|
||||
<origin xyz="0.0583 -0.0583 0.0" rpy="0 0 0"/>
|
||||
<origin xyz="0.1583 -0.0983 -0.0" rpy="0 0 0"/>
|
||||
<mass value="0.015"/>
|
||||
<inertia ixx="6.16e-06" iyy="6.16e-06" izz="1.23e-05"
|
||||
ixy="6.10e-06" iyz="0.0" ixz="0.0"/>
|
||||
@@ -93,13 +93,14 @@
|
||||
</link>
|
||||
|
||||
<!-- Pendulum joint: arm → pendulum, bearing axis along Y.
|
||||
Joint origin corrected from mesh analysis (Fusion2URDF was 180mm off). -->
|
||||
Joint origin corrected from mesh analysis (Fusion2URDF was 180mm off).
|
||||
rpy pitch +90° so qpos=0 = pendulum hanging down (gravity-stable). -->
|
||||
<joint name="pendulum_joint" type="continuous">
|
||||
<origin xyz="0.000052 0.019274 0.014993" rpy="0 0 0"/>
|
||||
<origin xyz="0.000052 0.019274 0.014993" rpy="0 1.5708 0"/>
|
||||
<parent link="arm"/>
|
||||
<child link="pendulum"/>
|
||||
<axis xyz="0 -1 0"/>
|
||||
<dynamics damping="0.0005"/>
|
||||
<dynamics damping="0.0001"/>
|
||||
</joint>
|
||||
|
||||
</robot>
|
||||
|
||||
1
configs/env/cartpole.yaml
vendored
1
configs/env/cartpole.yaml
vendored
@@ -9,3 +9,4 @@ actuators:
|
||||
- joint: cart_joint
|
||||
gear: 10.0
|
||||
ctrl_range: [-1.0, 1.0]
|
||||
damping: 0.05
|
||||
|
||||
3
configs/env/rotary_cartpole.yaml
vendored
3
configs/env/rotary_cartpole.yaml
vendored
@@ -3,5 +3,6 @@ model_path: assets/rotary_cartpole/rotary_cartpole.urdf
|
||||
reward_upright_scale: 1.0
|
||||
actuators:
|
||||
- joint: motor_joint
|
||||
gear: 15.0
|
||||
gear: 0.5
|
||||
ctrl_range: [-1.0, 1.0]
|
||||
damping: 0.1
|
||||
@@ -9,7 +9,8 @@ learning_rate: 0.0003
|
||||
clip_ratio: 0.2
|
||||
value_loss_scale: 0.5
|
||||
entropy_loss_scale: 0.05
|
||||
log_interval: 10
|
||||
log_interval: 1000
|
||||
checkpoint_interval: 50000
|
||||
|
||||
# ClearML remote execution (GPU worker)
|
||||
remote: false
|
||||
|
||||
@@ -7,4 +7,5 @@ skrl[torch]
|
||||
clearml
|
||||
imageio
|
||||
imageio-ffmpeg
|
||||
structlog
|
||||
pytest
|
||||
@@ -17,6 +17,7 @@ class ActuatorConfig:
|
||||
joint: str = ""
|
||||
gear: float = 1.0
|
||||
ctrl_range: tuple[float, float] = (-1.0, 1.0)
|
||||
damping: float = 0.05 # joint damping — limits max speed: vel_max ≈ torque / damping
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
|
||||
@@ -66,21 +66,14 @@ class RotaryCartPoleEnv(BaseEnv[RotaryCartPoleConfig]):
|
||||
], dim=-1)
|
||||
|
||||
def compute_rewards(self, state: RotaryCartPoleState, actions: torch.Tensor) -> torch.Tensor:
|
||||
# height: sin(θ) → -1 (down) to +1 (up)
|
||||
height = torch.sin(state.pendulum_angle)
|
||||
# Upright reward: -cos(θ) ∈ [-1, +1]
|
||||
upright = -torch.cos(state.pendulum_angle)
|
||||
|
||||
# Upright reward: strongly rewards being near vertical.
|
||||
# Uses cos(θ - π/2) = sin(θ), squared and scaled so:
|
||||
# down (h=-1): 0.0
|
||||
# horiz (h= 0): 0.0
|
||||
# up (h=+1): 1.0
|
||||
# Only kicks in above horizontal, so swing-up isn't penalised.
|
||||
upright_reward = torch.clamp(height, 0.0, 1.0) ** 2
|
||||
# Velocity penalties — make spinning expensive but allow swing-up
|
||||
pend_vel_penalty = 0.01 * state.pendulum_vel ** 2
|
||||
motor_vel_penalty = 0.01 * state.motor_vel ** 2
|
||||
|
||||
# Motor effort penalty: small cost to avoid bang-bang control.
|
||||
effort_penalty = 0.001 * actions.squeeze(-1) ** 2
|
||||
|
||||
return 5.0 * upright_reward - effort_penalty
|
||||
return upright - pend_vel_penalty - motor_vel_penalty
|
||||
|
||||
def compute_terminations(self, state: RotaryCartPoleState) -> torch.Tensor:
|
||||
# No early termination — episode runs for max_steps (truncation only).
|
||||
@@ -88,7 +81,5 @@ class RotaryCartPoleEnv(BaseEnv[RotaryCartPoleConfig]):
|
||||
return torch.zeros_like(state.motor_angle, dtype=torch.bool)
|
||||
|
||||
def get_default_qpos(self, nq: int) -> list[float] | None:
|
||||
# The STL mesh is horizontal at qpos=0.
|
||||
# Pendulum hangs down at θ = -π/2 (sin(-π/2) = -1).
|
||||
import math
|
||||
return [0.0, -math.pi / 2]
|
||||
# qpos=0 = pendulum hanging down (joint frame rotated in URDF).
|
||||
return None
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import dataclasses
|
||||
import os
|
||||
import xml.etree.ElementTree as ET
|
||||
from pathlib import Path
|
||||
|
||||
import mujoco
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from src.core.env import BaseEnv, ActuatorConfig
|
||||
from src.core.runner import BaseRunner, BaseRunnerConfig
|
||||
import torch
|
||||
import numpy as np
|
||||
import mujoco
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MuJoCoRunnerConfig(BaseRunnerConfig):
|
||||
@@ -39,9 +41,9 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
||||
This keeps the URDF clean and standard — actuator config lives in
|
||||
the env config (Isaac Lab pattern), not in the robot file.
|
||||
"""
|
||||
abs_path = os.path.abspath(model_path)
|
||||
model_dir = os.path.dirname(abs_path)
|
||||
is_urdf = abs_path.lower().endswith(".urdf")
|
||||
abs_path = Path(model_path).resolve()
|
||||
model_dir = abs_path.parent
|
||||
is_urdf = abs_path.suffix.lower() == ".urdf"
|
||||
|
||||
# MuJoCo's URDF parser strips directory prefixes from mesh filenames,
|
||||
# so we inject a <mujoco><compiler meshdir="..."/> block into a
|
||||
@@ -53,9 +55,9 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
||||
meshdir = None
|
||||
for mesh_el in root.iter("mesh"):
|
||||
fn = mesh_el.get("filename", "")
|
||||
dirname = os.path.dirname(fn)
|
||||
if dirname:
|
||||
meshdir = dirname
|
||||
parent = str(Path(fn).parent)
|
||||
if parent and parent != ".":
|
||||
meshdir = parent
|
||||
break
|
||||
if meshdir:
|
||||
mj_ext = ET.SubElement(root, "mujoco")
|
||||
@@ -63,25 +65,24 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
||||
"meshdir": meshdir,
|
||||
"balanceinertia": "true",
|
||||
})
|
||||
tmp_urdf = os.path.join(model_dir, "_tmp_mujoco_load.urdf")
|
||||
tree.write(tmp_urdf, xml_declaration=True, encoding="unicode")
|
||||
tmp_urdf = model_dir / "_tmp_mujoco_load.urdf"
|
||||
tree.write(str(tmp_urdf), xml_declaration=True, encoding="unicode")
|
||||
try:
|
||||
model_raw = mujoco.MjModel.from_xml_path(tmp_urdf)
|
||||
model_raw = mujoco.MjModel.from_xml_path(str(tmp_urdf))
|
||||
finally:
|
||||
os.unlink(tmp_urdf)
|
||||
tmp_urdf.unlink()
|
||||
else:
|
||||
model_raw = mujoco.MjModel.from_xml_path(abs_path)
|
||||
model_raw = mujoco.MjModel.from_xml_path(str(abs_path))
|
||||
|
||||
if not actuators:
|
||||
return model_raw
|
||||
|
||||
# Step 2: Export internal MJCF representation (save next to original
|
||||
# model so relative mesh/asset paths resolve correctly on reload)
|
||||
tmp_mjcf = os.path.join(model_dir, "_tmp_actuator_inject.xml")
|
||||
tmp_mjcf = model_dir / "_tmp_actuator_inject.xml"
|
||||
try:
|
||||
mujoco.mj_saveLastXML(tmp_mjcf, model_raw)
|
||||
with open(tmp_mjcf) as f:
|
||||
mjcf_str = f.read()
|
||||
mujoco.mj_saveLastXML(str(tmp_mjcf), model_raw)
|
||||
mjcf_str = tmp_mjcf.read_text()
|
||||
|
||||
# Step 3: Inject actuators into the MJCF XML
|
||||
# Use torque actuator. Speed is limited by joint damping:
|
||||
@@ -98,12 +99,13 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
||||
|
||||
# Add damping to actuated joints to limit max speed and
|
||||
# mimic real motor friction / back-EMF.
|
||||
# vel_max ≈ max_torque / damping (e.g. 1.0 / 0.05 = 20 rad/s)
|
||||
actuated_joints = {a.joint for a in actuators}
|
||||
# vel_max ≈ max_torque / damping
|
||||
joint_damping = {a.joint: a.damping for a in actuators}
|
||||
for body in root.iter("body"):
|
||||
for jnt in body.findall("joint"):
|
||||
if jnt.get("name") in actuated_joints:
|
||||
jnt.set("damping", "0.05")
|
||||
name = jnt.get("name")
|
||||
if name in joint_damping:
|
||||
jnt.set("damping", str(joint_damping[name]))
|
||||
|
||||
# Disable self-collision on all geoms.
|
||||
# URDF mesh convex hulls often overlap at joints (especially
|
||||
@@ -116,12 +118,10 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
||||
# Step 4: Write modified MJCF and reload from file path
|
||||
# (from_xml_path resolves mesh paths relative to the file location)
|
||||
modified_xml = ET.tostring(root, encoding="unicode")
|
||||
with open(tmp_mjcf, "w") as f:
|
||||
f.write(modified_xml)
|
||||
return mujoco.MjModel.from_xml_path(tmp_mjcf)
|
||||
tmp_mjcf.write_text(modified_xml)
|
||||
return mujoco.MjModel.from_xml_path(str(tmp_mjcf))
|
||||
finally:
|
||||
if os.path.exists(tmp_mjcf):
|
||||
os.unlink(tmp_mjcf)
|
||||
tmp_mjcf.unlink(missing_ok=True)
|
||||
|
||||
def _sim_initialize(self, config: MuJoCoRunnerConfig) -> None:
|
||||
model_path = self.env.config.model_path
|
||||
|
||||
@@ -35,6 +35,7 @@ class TrainerConfig:
|
||||
# Training
|
||||
total_timesteps: int = 1_000_000
|
||||
log_interval: int = 10
|
||||
checkpoint_interval: int = 50_000
|
||||
|
||||
# Video recording (uploaded to ClearML)
|
||||
record_video_every: int = 10_000 # 0 = disabled
|
||||
@@ -196,7 +197,7 @@ class Trainer:
|
||||
# log_interval=1 → log every PPO update (= every rollout_steps timesteps).
|
||||
agent_cfg["experiment"]["write_interval"] = self.config.log_interval
|
||||
agent_cfg["experiment"]["checkpoint_interval"] = max(
|
||||
self.config.total_timesteps // 10, self.config.rollout_steps
|
||||
self.config.checkpoint_interval, self.config.rollout_steps
|
||||
)
|
||||
|
||||
self.agent = PPO(
|
||||
|
||||
16
train.py
16
train.py
@@ -1,4 +1,8 @@
|
||||
import pathlib
|
||||
|
||||
import hydra
|
||||
import hydra.utils as hydra_utils
|
||||
import structlog
|
||||
from clearml import Task
|
||||
from hydra.core.hydra_config import HydraConfig
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
@@ -9,6 +13,8 @@ from src.envs.rotary_cartpole import RotaryCartPoleEnv, RotaryCartPoleConfig
|
||||
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
|
||||
from src.training.trainer import Trainer, TrainerConfig
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# ── env registry ──────────────────────────────────────────────────────
|
||||
# Maps Hydra config-group name → (EnvClass, ConfigClass)
|
||||
ENV_REGISTRY: dict[str, tuple[type[BaseEnv], type[BaseEnvConfig]]] = {
|
||||
@@ -52,9 +58,15 @@ def _init_clearml(choices: dict[str, str], remote: bool = False) -> Task:
|
||||
tags = [env_name, runner_name, training_name]
|
||||
|
||||
task = Task.init(project_name=project, task_name=task_name, tags=tags)
|
||||
task.set_base_docker("registry.kube.optimize/worker-image:latest")
|
||||
|
||||
if remote:
|
||||
task.execute_remotely(queue_name="default")
|
||||
req_file = pathlib.Path(hydra_utils.get_original_cwd()) / "requirements.txt"
|
||||
task.set_packages(str(req_file))
|
||||
|
||||
# Execute remotely if requested and running locally
|
||||
if remote and task.running_locally():
|
||||
logger.info("executing_task_remotely", queue="gpu-queue")
|
||||
task.execute_remotely(queue_name="gpu-queue", exit_process=True)
|
||||
|
||||
return task
|
||||
|
||||
|
||||
137
viz.py
Normal file
137
viz.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""Interactive visualization — control any env with keyboard in MuJoCo viewer.
|
||||
|
||||
Usage:
|
||||
mjpython viz.py env=rotary_cartpole
|
||||
mjpython viz.py env=cartpole +com=true
|
||||
|
||||
Controls:
|
||||
Left/Right arrows — apply torque to first actuator
|
||||
R — reset environment
|
||||
Esc / close window — quit
|
||||
"""
|
||||
import math
|
||||
import time
|
||||
|
||||
import hydra
|
||||
import mujoco
|
||||
import mujoco.viewer
|
||||
import structlog
|
||||
import torch
|
||||
from hydra.core.hydra_config import HydraConfig
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
from src.core.env import ActuatorConfig, BaseEnv, BaseEnvConfig
|
||||
from src.envs.cartpole import CartPoleConfig, CartPoleEnv
|
||||
from src.envs.rotary_cartpole import RotaryCartPoleConfig, RotaryCartPoleEnv
|
||||
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# ── registry (same as train.py) ──────────────────────────────────────
# Maps the Hydra `env` config-group choice to the (EnvClass, ConfigClass)
# pair used to construct the environment. Keep in sync with train.py.
ENV_REGISTRY: dict[str, tuple[type[BaseEnv], type[BaseEnvConfig]]] = {
    "cartpole": (CartPoleEnv, CartPoleConfig),
    "rotary_cartpole": (RotaryCartPoleEnv, RotaryCartPoleConfig),
}
|
||||
|
||||
|
||||
def _build_env(env_name: str, cfg: DictConfig) -> BaseEnv:
    """Instantiate the environment registered under *env_name*.

    ``cfg.env`` is converted to a plain container; any ``actuators``
    entries are turned into ``ActuatorConfig`` objects (with
    ``ctrl_range`` coerced from list to tuple) before the env config
    class is constructed.

    Raises:
        ValueError: if *env_name* is not a key of ``ENV_REGISTRY``.
    """
    if env_name not in ENV_REGISTRY:
        raise ValueError(f"Unknown env '{env_name}'. Registered: {list(ENV_REGISTRY)}")

    env_cls, config_cls = ENV_REGISTRY[env_name]
    settings = OmegaConf.to_container(cfg.env, resolve=True)

    if "actuators" in settings:
        # YAML lists arrive as Python lists; the dataclass field is a tuple.
        normalised = [
            {**spec, "ctrl_range": tuple(spec["ctrl_range"])}
            if "ctrl_range" in spec
            else spec
            for spec in settings["actuators"]
        ]
        settings["actuators"] = [ActuatorConfig(**spec) for spec in normalised]

    env_config = config_cls(**settings)
    return env_cls(env_config)
|
||||
|
||||
|
||||
# ── keyboard state ───────────────────────────────────────────────────
# Single-element lists are used as mutable cells: the key callback writes
# index 0 in place and the main loop reads it, so no `global` is needed.
_action_val = [0.0]  # mutable container shared with callback
_action_time = [0.0]  # time.time() timestamp of last Left/Right key event
_reset_flag = [False]  # set when R is pressed; cleared by the main loop
_ACTION_HOLD_S = 0.25  # seconds the action stays active after last key event
|
||||
|
||||
|
||||
def _key_callback(keycode: int) -> None:
    """Called by MuJoCo on key press & repeat (not release).

    Left/Right arrows store a -1.0 / +1.0 action together with the time of
    the event; R raises the reset flag. Any other key is ignored.
    """
    # GLFW key codes: 82 = R, 262 = RIGHT, 263 = LEFT.
    if keycode == 82:
        _reset_flag[0] = True
        return

    direction = {263: -1.0, 262: 1.0}.get(keycode)
    if direction is None:
        return

    _action_val[0] = direction
    _action_time[0] = time.time()
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path="configs", config_name="config")
def main(cfg: DictConfig) -> None:
    """Drive a single environment interactively in the MuJoCo passive viewer.

    Builds the env selected by the Hydra ``env`` choice (default
    ``"cartpole"``), wraps it in a one-env ``MuJoCoRunner`` and steps it
    once per control period while the viewer window is open. Left/Right
    arrows set the action to -1/+1 for ``_ACTION_HOLD_S`` seconds after the
    last key event; ``R`` resets the runner.

    Args:
        cfg: Composed Hydra config. ``cfg.env`` and ``cfg.runner`` are read,
            plus the optional ``com`` flag (``+com=true``) which enables
            CoM / inertia rendering in the viewer.
    """
    choices = HydraConfig.get().runtime.choices
    env_name = choices.get("env", "cartpole")

    # Build env + runner (single env for viz)
    env = _build_env(env_name, cfg)
    runner_dict = OmegaConf.to_container(cfg.runner, resolve=True)
    runner_dict["num_envs"] = 1
    runner = MuJoCoRunner(env=env, config=MuJoCoRunnerConfig(**runner_dict))

    # Everything after construction is wrapped so the runner is closed even
    # if the viewer loop raises (e.g. KeyboardInterrupt, a step error).
    try:
        model = runner._model
        data = runner._data[0]

        # Control period
        dt_ctrl = runner.config.dt * runner.config.substeps

        hinge_type = int(mujoco.mjtJoint.mjJNT_HINGE)

        # Launch viewer
        with mujoco.viewer.launch_passive(model, data, key_callback=_key_callback) as viewer:
            # Show CoM / inertia if requested via Hydra override: viz.py +com=true
            if cfg.get("com", False):
                viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_COM] = True
                viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_INERTIA] = True

            runner.reset()
            step = 0

            logger.info("viewer_started", env=env_name,
                        controls="Left/Right arrows = torque, R = reset")

            while viewer.is_running():
                tick_start = time.perf_counter()

                # Read action from callback (expires after _ACTION_HOLD_S)
                if time.time() - _action_time[0] < _ACTION_HOLD_S:
                    action_val = _action_val[0]
                else:
                    action_val = 0.0

                # Reset on R press
                if _reset_flag[0]:
                    _reset_flag[0] = False
                    runner.reset()
                    step = 0
                    logger.info("reset")

                # Step through runner
                action = torch.tensor([[action_val]])
                _, reward, *_ = runner.step(action)

                # Sync viewer
                mujoco.mj_forward(model, data)
                viewer.sync()

                # Print state
                if step % 25 == 0:
                    joints = {}
                    for j in range(model.njnt):
                        # qpos is addressed via jnt_qposadr, not the joint
                        # index (they differ once any joint has >1 qpos).
                        value = float(data.qpos[model.jnt_qposadr[j]])
                        if model.jnt_type[j] == hinge_type:
                            joints[model.jnt(j).name] = round(math.degrees(value), 1)
                        else:
                            # Non-hinge joints are not angles; report the
                            # raw qpos value instead of a bogus "degrees".
                            joints[model.jnt(j).name] = round(value, 3)
                    logger.debug("step", n=step, reward=round(reward.item(), 3),
                                 action=round(action_val, 1), **joints)

                # Real-time pacing: sleep only for what is left of the
                # control period after stepping and rendering.
                remaining = dt_ctrl - (time.perf_counter() - tick_start)
                if remaining > 0:
                    time.sleep(remaining)
                step += 1
    finally:
        runner.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user