diff --git a/assets/rotary_cartpole/rotary_cartpole.urdf b/assets/rotary_cartpole/rotary_cartpole.urdf index b68ed3b..327c023 100644 --- a/assets/rotary_cartpole/rotary_cartpole.urdf +++ b/assets/rotary_cartpole/rotary_cartpole.urdf @@ -36,9 +36,9 @@ - - + + @@ -73,7 +73,7 @@ Tip at (0.07, -0.07, 0) → 45° diagonal in +X/-Y. CoM = (5×0.035+10×0.07)/15 = 0.0583 along both +X and -Y. Inertia tensor rotated 45° to match diagonal rod axis. --> - + @@ -93,13 +93,14 @@ + Joint origin corrected from mesh analysis (Fusion2URDF was 180mm off). + rpy pitch +90° so qpos=0 = pendulum hanging down (gravity-stable). --> - + - + diff --git a/configs/env/cartpole.yaml b/configs/env/cartpole.yaml index f7d0722..4fb9b67 100644 --- a/configs/env/cartpole.yaml +++ b/configs/env/cartpole.yaml @@ -9,3 +9,4 @@ actuators: - joint: cart_joint gear: 10.0 ctrl_range: [-1.0, 1.0] + damping: 0.05 diff --git a/configs/env/rotary_cartpole.yaml b/configs/env/rotary_cartpole.yaml index c4fb61c..4d1a8bd 100644 --- a/configs/env/rotary_cartpole.yaml +++ b/configs/env/rotary_cartpole.yaml @@ -3,5 +3,6 @@ model_path: assets/rotary_cartpole/rotary_cartpole.urdf reward_upright_scale: 1.0 actuators: - joint: motor_joint - gear: 15.0 + gear: 0.5 ctrl_range: [-1.0, 1.0] + damping: 0.1 \ No newline at end of file diff --git a/configs/training/ppo.yaml b/configs/training/ppo.yaml index c078de6..d3d1786 100644 --- a/configs/training/ppo.yaml +++ b/configs/training/ppo.yaml @@ -9,7 +9,8 @@ learning_rate: 0.0003 clip_ratio: 0.2 value_loss_scale: 0.5 entropy_loss_scale: 0.05 -log_interval: 10 +log_interval: 1000 +checkpoint_interval: 50000 # ClearML remote execution (GPU worker) remote: false diff --git a/requirements.txt b/requirements.txt index f3ed7df..4ac07a4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,4 +7,5 @@ skrl[torch] clearml imageio imageio-ffmpeg +structlog pytest \ No newline at end of file diff --git a/src/core/env.py b/src/core/env.py index 80c5d96..d737b07 100644 --- a/src/core/env.py +++ b/src/core/env.py @@ -17,6 +17,7 @@ class ActuatorConfig: joint: str = "" gear: float = 1.0 ctrl_range: tuple[float, float] = (-1.0, 1.0) + damping: float = 0.05 # joint damping — limits max speed: vel_max ≈ torque / damping @dataclasses.dataclass diff --git a/src/envs/rotary_cartpole.py b/src/envs/rotary_cartpole.py index d70dcf2..402e3ce 100644 --- a/src/envs/rotary_cartpole.py +++ b/src/envs/rotary_cartpole.py @@ -66,21 +66,14 @@ class RotaryCartPoleEnv(BaseEnv[RotaryCartPoleConfig]): ], dim=-1) def compute_rewards(self, state: RotaryCartPoleState, actions: torch.Tensor) -> torch.Tensor: - # height: sin(θ) → -1 (down) to +1 (up) - height = torch.sin(state.pendulum_angle) + # Upright reward: -cos(θ) ∈ [-1, +1] + upright = -torch.cos(state.pendulum_angle) - # Upright reward: strongly rewards being near vertical. - # Uses cos(θ - π/2) = sin(θ), squared and scaled so: - # down (h=-1): 0.0 - # horiz (h= 0): 0.0 - # up (h=+1): 1.0 - # Only kicks in above horizontal, so swing-up isn't penalised. - upright_reward = torch.clamp(height, 0.0, 1.0) ** 2 + # Velocity penalties — make spinning expensive but allow swing-up + pend_vel_penalty = 0.01 * state.pendulum_vel ** 2 + motor_vel_penalty = 0.01 * state.motor_vel ** 2 - # Motor effort penalty: small cost to avoid bang-bang control. - effort_penalty = 0.001 * actions.squeeze(-1) ** 2 - - return 5.0 * upright_reward - effort_penalty + return upright - pend_vel_penalty - motor_vel_penalty def compute_terminations(self, state: RotaryCartPoleState) -> torch.Tensor: # No early termination — episode runs for max_steps (truncation only). @@ -88,7 +81,5 @@ class RotaryCartPoleEnv(BaseEnv[RotaryCartPoleConfig]): return torch.zeros_like(state.motor_angle, dtype=torch.bool) def get_default_qpos(self, nq: int) -> list[float] | None: - # The STL mesh is horizontal at qpos=0. - # Pendulum hangs down at θ = -π/2 (sin(-π/2) = -1). - import math - return [0.0, -math.pi / 2] + # qpos=0 = pendulum hanging down (joint frame rotated in URDF). + return None diff --git a/src/runners/mujoco.py b/src/runners/mujoco.py index 7cf7755..f8c79b5 100644 --- a/src/runners/mujoco.py +++ b/src/runners/mujoco.py @@ -1,11 +1,13 @@ import dataclasses -import os import xml.etree.ElementTree as ET +from pathlib import Path + +import mujoco +import numpy as np +import torch + from src.core.env import BaseEnv, ActuatorConfig from src.core.runner import BaseRunner, BaseRunnerConfig -import torch -import numpy as np -import mujoco @dataclasses.dataclass class MuJoCoRunnerConfig(BaseRunnerConfig): @@ -39,9 +41,9 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]): This keeps the URDF clean and standard — actuator config lives in the env config (Isaac Lab pattern), not in the robot file. """ - abs_path = os.path.abspath(model_path) - model_dir = os.path.dirname(abs_path) - is_urdf = abs_path.lower().endswith(".urdf") + abs_path = Path(model_path).resolve() + model_dir = abs_path.parent + is_urdf = abs_path.suffix.lower() == ".urdf" # MuJoCo's URDF parser strips directory prefixes from mesh filenames, # so we inject a block into a @@ -53,9 +55,9 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]): meshdir = None for mesh_el in root.iter("mesh"): fn = mesh_el.get("filename", "") - dirname = os.path.dirname(fn) - if dirname: - meshdir = dirname + parent = str(Path(fn).parent) + if parent and parent != ".": + meshdir = parent break if meshdir: mj_ext = ET.SubElement(root, "mujoco") @@ -63,25 +65,24 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]): "meshdir": meshdir, "balanceinertia": "true", }) - tmp_urdf = os.path.join(model_dir, "_tmp_mujoco_load.urdf") - tree.write(tmp_urdf, xml_declaration=True, encoding="unicode") + tmp_urdf = model_dir / "_tmp_mujoco_load.urdf" + tree.write(str(tmp_urdf), xml_declaration=True, encoding="unicode") try: - model_raw = mujoco.MjModel.from_xml_path(tmp_urdf) + model_raw = mujoco.MjModel.from_xml_path(str(tmp_urdf)) finally: - os.unlink(tmp_urdf) + tmp_urdf.unlink() else: - model_raw = mujoco.MjModel.from_xml_path(abs_path) + model_raw = mujoco.MjModel.from_xml_path(str(abs_path)) if not actuators: return model_raw # Step 2: Export internal MJCF representation (save next to original # model so relative mesh/asset paths resolve correctly on reload) - tmp_mjcf = os.path.join(model_dir, "_tmp_actuator_inject.xml") + tmp_mjcf = model_dir / "_tmp_actuator_inject.xml" try: - mujoco.mj_saveLastXML(tmp_mjcf, model_raw) - with open(tmp_mjcf) as f: - mjcf_str = f.read() + mujoco.mj_saveLastXML(str(tmp_mjcf), model_raw) + mjcf_str = tmp_mjcf.read_text() # Step 3: Inject actuators into the MJCF XML # Use torque actuator. Speed is limited by joint damping: @@ -98,12 +99,13 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]): # Add damping to actuated joints to limit max speed and # mimic real motor friction / back-EMF. - # vel_max ≈ max_torque / damping (e.g. 1.0 / 0.05 = 20 rad/s) - actuated_joints = {a.joint for a in actuators} + # vel_max ≈ max_torque / damping + joint_damping = {a.joint: a.damping for a in actuators} for body in root.iter("body"): for jnt in body.findall("joint"): - if jnt.get("name") in actuated_joints: - jnt.set("damping", "0.05") + name = jnt.get("name") + if name in joint_damping: + jnt.set("damping", str(joint_damping[name])) # Disable self-collision on all geoms. # URDF mesh convex hulls often overlap at joints (especially @@ -116,12 +118,10 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]): # Step 4: Write modified MJCF and reload from file path # (from_xml_path resolves mesh paths relative to the file location) modified_xml = ET.tostring(root, encoding="unicode") - with open(tmp_mjcf, "w") as f: - f.write(modified_xml) - return mujoco.MjModel.from_xml_path(tmp_mjcf) + tmp_mjcf.write_text(modified_xml) + return mujoco.MjModel.from_xml_path(str(tmp_mjcf)) finally: - if os.path.exists(tmp_mjcf): - os.unlink(tmp_mjcf) + tmp_mjcf.unlink(missing_ok=True) def _sim_initialize(self, config: MuJoCoRunnerConfig) -> None: model_path = self.env.config.model_path diff --git a/src/training/trainer.py b/src/training/trainer.py index 3315156..5230049 100644 --- a/src/training/trainer.py +++ b/src/training/trainer.py @@ -35,6 +35,7 @@ class TrainerConfig: # Training total_timesteps: int = 1_000_000 log_interval: int = 10 + checkpoint_interval: int = 50_000 # Video recording (uploaded to ClearML) record_video_every: int = 10_000 # 0 = disabled @@ -196,7 +197,7 @@ class Trainer: # log_interval=1 → log every PPO update (= every rollout_steps timesteps). agent_cfg["experiment"]["write_interval"] = self.config.log_interval agent_cfg["experiment"]["checkpoint_interval"] = max( - self.config.total_timesteps // 10, self.config.rollout_steps + self.config.checkpoint_interval, self.config.rollout_steps ) self.agent = PPO( diff --git a/train.py b/train.py index b133bbc..8d16e80 100644 --- a/train.py +++ b/train.py @@ -1,4 +1,8 @@ +import pathlib + import hydra +import hydra.utils as hydra_utils +import structlog from clearml import Task from hydra.core.hydra_config import HydraConfig from omegaconf import DictConfig, OmegaConf @@ -9,6 +13,8 @@ from src.envs.rotary_cartpole import RotaryCartPoleEnv, RotaryCartPoleConfig from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig from src.training.trainer import Trainer, TrainerConfig +logger = structlog.get_logger() + # ── env registry ────────────────────────────────────────────────────── # Maps Hydra config-group name → (EnvClass, ConfigClass) ENV_REGISTRY: dict[str, tuple[type[BaseEnv], type[BaseEnvConfig]]] = { @@ -52,9 +58,15 @@ def _init_clearml(choices: dict[str, str], remote: bool = False) -> Task: tags = [env_name, runner_name, training_name] task = Task.init(project_name=project, task_name=task_name, tags=tags) + task.set_base_docker("registry.kube.optimize/worker-image:latest") - if remote: - task.execute_remotely(queue_name="default") + req_file = pathlib.Path(hydra_utils.get_original_cwd()) / "requirements.txt" + task.set_packages(str(req_file)) + + # Execute remotely if requested and running locally + if remote and task.running_locally(): + logger.info("executing_task_remotely", queue="gpu-queue") + task.execute_remotely(queue_name="gpu-queue", exit_process=True) return task diff --git a/viz.py b/viz.py new file mode 100644 index 0000000..d4187d4 --- /dev/null +++ b/viz.py @@ -0,0 +1,137 @@ +"""Interactive visualization — control any env with keyboard in MuJoCo viewer. + +Usage: + mjpython viz.py env=rotary_cartpole + mjpython viz.py env=cartpole +com=true + +Controls: + Left/Right arrows — apply torque to first actuator + R — reset environment + Esc / close window — quit +""" +import math +import time + +import hydra +import mujoco +import mujoco.viewer +import structlog +import torch +from hydra.core.hydra_config import HydraConfig +from omegaconf import DictConfig, OmegaConf + +from src.core.env import ActuatorConfig, BaseEnv, BaseEnvConfig +from src.envs.cartpole import CartPoleConfig, CartPoleEnv +from src.envs.rotary_cartpole import RotaryCartPoleConfig, RotaryCartPoleEnv +from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig + +logger = structlog.get_logger() + +# ── registry (same as train.py) ────────────────────────────────────── +ENV_REGISTRY: dict[str, tuple[type[BaseEnv], type[BaseEnvConfig]]] = { + "cartpole": (CartPoleEnv, CartPoleConfig), + "rotary_cartpole": (RotaryCartPoleEnv, RotaryCartPoleConfig), +} + + +def _build_env(env_name: str, cfg: DictConfig) -> BaseEnv: + if env_name not in ENV_REGISTRY: + raise ValueError(f"Unknown env '{env_name}'. Registered: {list(ENV_REGISTRY)}") + env_cls, config_cls = ENV_REGISTRY[env_name] + env_dict = OmegaConf.to_container(cfg.env, resolve=True) + if "actuators" in env_dict: + for a in env_dict["actuators"]: + if "ctrl_range" in a: + a["ctrl_range"] = tuple(a["ctrl_range"]) + env_dict["actuators"] = [ActuatorConfig(**a) for a in env_dict["actuators"]] + return env_cls(config_cls(**env_dict)) + + +# ── keyboard state ─────────────────────────────────────────────────── +_action_val = [0.0] # mutable container shared with callback +_action_time = [0.0] # timestamp of last key press +_reset_flag = [False] +_ACTION_HOLD_S = 0.25 # seconds the action stays active after last key event + + +def _key_callback(keycode: int) -> None: + """Called by MuJoCo on key press & repeat (not release).""" + if keycode == 263: # GLFW_KEY_LEFT + _action_val[0] = -1.0 + _action_time[0] = time.time() + elif keycode == 262: # GLFW_KEY_RIGHT + _action_val[0] = 1.0 + _action_time[0] = time.time() + elif keycode == 82: # GLFW_KEY_R + _reset_flag[0] = True + + +@hydra.main(version_base=None, config_path="configs", config_name="config") +def main(cfg: DictConfig) -> None: + choices = HydraConfig.get().runtime.choices + env_name = choices.get("env", "cartpole") + + # Build env + runner (single env for viz) + env = _build_env(env_name, cfg) + runner_dict = OmegaConf.to_container(cfg.runner, resolve=True) + runner_dict["num_envs"] = 1 + runner = MuJoCoRunner(env=env, config=MuJoCoRunnerConfig(**runner_dict)) + + model = runner._model + data = runner._data[0] + + # Control period + dt_ctrl = runner.config.dt * runner.config.substeps + + # Launch viewer + with mujoco.viewer.launch_passive(model, data, key_callback=_key_callback) as viewer: + # Show CoM / inertia if requested via Hydra override: viz.py +com=true + show_com = cfg.get("com", False) + if show_com: + viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_COM] = True + viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_INERTIA] = True + + obs, _ = runner.reset() + step = 0 + + logger.info("viewer_started", env=env_name, + controls="Left/Right arrows = torque, R = reset") + + while viewer.is_running(): + # Read action from callback (expires after _ACTION_HOLD_S) + if time.time() - _action_time[0] < _ACTION_HOLD_S: + action_val = _action_val[0] + else: + action_val = 0.0 + + # Reset on R press + if _reset_flag[0]: + _reset_flag[0] = False + obs, _ = runner.reset() + step = 0 + logger.info("reset") + + # Step through runner + action = torch.tensor([[action_val]]) + obs, reward, terminated, truncated, info = runner.step(action) + + # Sync viewer + mujoco.mj_forward(model, data) + viewer.sync() + + # Print state + if step % 25 == 0: + joints = {model.jnt(i).name: round(math.degrees(data.qpos[i]), 1) + for i in range(model.njnt)} + logger.debug("step", n=step, reward=round(reward.item(), 3), + action=round(action_val, 1), **joints) + + # Real-time pacing + time.sleep(dt_ctrl) + step += 1 + + runner.close() + + +if __name__ == "__main__": + main()