♻️ crazy refactor
This commit is contained in:
23
assets/rotary_cartpole/hardware.yaml
Normal file
23
assets/rotary_cartpole/hardware.yaml
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
# Rotary cartpole (Furuta pendulum) — real hardware config.
|
||||||
|
# Describes the physical device for the SerialRunner.
|
||||||
|
# Robot-specific constants that don't belong in the runner config
|
||||||
|
# (which is machine-specific: port, baud) or the env config
|
||||||
|
# (which is task-specific: rewards, max_steps).
|
||||||
|
|
||||||
|
encoder:
|
||||||
|
ppr: 11 # pulses per revolution (before quadrature)
|
||||||
|
gear_ratio: 30.0 # gearbox ratio
|
||||||
|
# counts_per_rev = ppr × gear_ratio × 4 (quadrature) = 1320
|
||||||
|
|
||||||
|
safety:
|
||||||
|
max_motor_angle_deg: 90.0 # hard termination limit (0 = disabled)
|
||||||
|
soft_limit_deg: 40.0 # progressive penalty ramp starts here
|
||||||
|
|
||||||
|
reset:
|
||||||
|
drive_speed: 80 # PWM magnitude for bang-bang drive-to-center
|
||||||
|
deadband: 15 # encoder count threshold to consider "centered"
|
||||||
|
drive_timeout: 3.0 # seconds before giving up on drive-to-center
|
||||||
|
settle_angle_deg: 2.0 # pendulum angle threshold for "still" (degrees)
|
||||||
|
settle_vel_dps: 5.0 # pendulum velocity threshold (deg/s)
|
||||||
|
settle_duration: 0.5 # how long pendulum must stay still (seconds)
|
||||||
|
settle_timeout: 30.0 # give up waiting after this (seconds)
|
||||||
BIN
assets/rotary_cartpole/recordings/capture_20260311_215608.npz
Normal file
BIN
assets/rotary_cartpole/recordings/capture_20260311_215608.npz
Normal file
Binary file not shown.
@@ -1,19 +1,20 @@
|
|||||||
# Rotary cartpole (Furuta pendulum) — robot hardware config.
|
# Tuned robot config — generated by src.sysid.optimize
|
||||||
# Lives next to the URDF so all robot-specific settings are in one place.
|
# Original: robot.yaml
|
||||||
|
# Run `python -m src.sysid.visualize` to compare real vs sim.
|
||||||
|
|
||||||
urdf: rotary_cartpole.urdf
|
urdf: rotary_cartpole.urdf
|
||||||
|
|
||||||
actuators:
|
actuators:
|
||||||
- joint: motor_joint
|
- joint: motor_joint
|
||||||
type: motor # direct torque control
|
type: motor
|
||||||
gear: 0.064 # stall torque @ 58.8% PWM: 0.108 × 150/255 = 0.064 N·m
|
gear: 0.176692
|
||||||
ctrl_range: [-1.0, 1.0]
|
ctrl_range:
|
||||||
damping: 0.003 # viscous back-EMF only (small)
|
- -1.0
|
||||||
filter_tau: 0.03 # mechanical time constant ~30ms (37mm gearmotor)
|
- 1.0
|
||||||
|
damping: 0.009505
|
||||||
|
filter_tau: 0.040906
|
||||||
joints:
|
joints:
|
||||||
motor_joint:
|
motor_joint:
|
||||||
armature: 0.0001 # reflected rotor inertia: ~1e-7 × 30² = 9e-5 kg·m²
|
armature: 0.001389
|
||||||
frictionloss: 0.03 # disabled — may slow MJX (constraint-based solver)
|
frictionloss: 0.002179
|
||||||
pendulum_joint:
|
pendulum_joint:
|
||||||
damping: 0.0001 # bearing friction
|
damping: 6.1e-05
|
||||||
|
|||||||
@@ -1,16 +1,11 @@
|
|||||||
<?xml version="1.0" encoding="utf-8"?>
|
<?xml version='1.0' encoding='utf-8'?>
|
||||||
<robot name="rotary_cartpole">
|
<robot name="rotary_cartpole">
|
||||||
|
|
||||||
<!-- Fixed world frame -->
|
|
||||||
<link name="world" />
|
<link name="world" />
|
||||||
|
|
||||||
<!-- Base: motor housing, fixed to world -->
|
|
||||||
<link name="base_link">
|
<link name="base_link">
|
||||||
<inertial>
|
<inertial>
|
||||||
<origin xyz="-0.00011 0.00117 0.06055" rpy="0 0 0" />
|
<origin xyz="-0.00011 0.00117 0.06055" rpy="0 0 0" />
|
||||||
<mass value="0.921" />
|
<mass value="0.921" />
|
||||||
<inertia ixx="0.002385" iyy="0.002484" izz="0.000559"
|
<inertia ixx="0.002385" iyy="0.002484" izz="0.000559" ixy="0.0" iyz="-0.000149" ixz="6e-06" />
|
||||||
ixy="0.0" iyz="-0.000149" ixz="6e-06"/>
|
|
||||||
</inertial>
|
</inertial>
|
||||||
<visual>
|
<visual>
|
||||||
<origin xyz="0 0 0" rpy="0 0 0" />
|
<origin xyz="0 0 0" rpy="0 0 0" />
|
||||||
@@ -25,20 +20,15 @@
|
|||||||
</geometry>
|
</geometry>
|
||||||
</collision>
|
</collision>
|
||||||
</link>
|
</link>
|
||||||
|
|
||||||
<joint name="base_joint" type="fixed">
|
<joint name="base_joint" type="fixed">
|
||||||
<parent link="world" />
|
<parent link="world" />
|
||||||
<child link="base_link" />
|
<child link="base_link" />
|
||||||
</joint>
|
</joint>
|
||||||
|
|
||||||
<!-- Arm: horizontal rotating arm driven by motor.
|
|
||||||
Real mass ~10g (Fusion assumed dense material, exported 279g). -->
|
|
||||||
<link name="arm">
|
<link name="arm">
|
||||||
<inertial>
|
<inertial>
|
||||||
<origin xyz="0.00005 0.0065 0.00563" rpy="0 0 0"/>
|
<origin xyz="0.014950488360794875 0.006089886527968399 0.004470745447817278" rpy="0 0 0" />
|
||||||
<mass value="0.010"/>
|
<mass value="0.012391951282440451" />
|
||||||
<inertia ixx="2.70e-06" iyy="7.80e-07" izz="2.44e-06"
|
<inertia ixx="2.70e-06" iyy="7.80e-07" izz="2.44e-06" ixy="0.0" iyz="7.20e-08" ixz="0.0" />
|
||||||
ixy="0.0" iyz="7.20e-08" ixz="0.0"/>
|
|
||||||
</inertial>
|
</inertial>
|
||||||
<visual>
|
<visual>
|
||||||
<origin xyz="0.006947 -0.00395 -0.14796" rpy="0 0 0" />
|
<origin xyz="0.006947 -0.00395 -0.14796" rpy="0 0 0" />
|
||||||
@@ -53,8 +43,6 @@
|
|||||||
</geometry>
|
</geometry>
|
||||||
</collision>
|
</collision>
|
||||||
</link>
|
</link>
|
||||||
|
|
||||||
<!-- Motor joint: base → arm, rotates around vertical z-axis -->
|
|
||||||
<joint name="motor_joint" type="revolute">
|
<joint name="motor_joint" type="revolute">
|
||||||
<origin xyz="-0.006947 0.00395 0.14796" rpy="0 0 0" />
|
<origin xyz="-0.006947 0.00395 0.14796" rpy="0 0 0" />
|
||||||
<parent link="base_link" />
|
<parent link="base_link" />
|
||||||
@@ -63,20 +51,11 @@
|
|||||||
<limit lower="-1.5708" upper="1.5708" effort="10.0" velocity="200.0" />
|
<limit lower="-1.5708" upper="1.5708" effort="10.0" velocity="200.0" />
|
||||||
<dynamics damping="0.001" />
|
<dynamics damping="0.001" />
|
||||||
</joint>
|
</joint>
|
||||||
|
|
||||||
<!-- Pendulum: swings freely at the end of the arm.
|
|
||||||
Real mass: 5g pendulum + 10g weight at the tip (70mm from bearing) = 15g total.
|
|
||||||
(Fusion assumed dense material, exported 57g for the pendulum alone.) -->
|
|
||||||
<link name="pendulum">
|
<link name="pendulum">
|
||||||
<inertial>
|
<inertial>
|
||||||
<!-- Combined CoM: 5g rod (CoM ~35mm) + 10g tip weight at 70mm from pivot.
|
<origin xyz="0.06432778588634695 -0.05999895841669392 0.0008769789937631209" rpy="0 0 0" />
|
||||||
Tip at (0.07, -0.07, 0) → 45° diagonal in +X/-Y.
|
<mass value="0.035508993892747365" />
|
||||||
CoM = (5×0.035+10×0.07)/15 = 0.0583 along both +X and -Y.
|
<inertia ixx="3.139576982078822e-05" iyy="9.431951659638859e-06" izz="4.07315891863556e-05" ixy="-1.8892943833253423e-06" iyz="0.0" ixz="0.0" />
|
||||||
Inertia tensor rotated 45° to match diagonal rod axis. -->
|
|
||||||
<origin xyz="0.1583 -0.0983 -0.0" rpy="0 0 0"/>
|
|
||||||
<mass value="0.015"/>
|
|
||||||
<inertia ixx="6.16e-06" iyy="6.16e-06" izz="1.23e-05"
|
|
||||||
ixy="6.10e-06" iyz="0.0" ixz="0.0"/>
|
|
||||||
</inertial>
|
</inertial>
|
||||||
<visual>
|
<visual>
|
||||||
<origin xyz="0.006895 -0.023224 -0.162953" rpy="0 0 0" />
|
<origin xyz="0.006895 -0.023224 -0.162953" rpy="0 0 0" />
|
||||||
@@ -91,10 +70,6 @@
|
|||||||
</geometry>
|
</geometry>
|
||||||
</collision>
|
</collision>
|
||||||
</link>
|
</link>
|
||||||
|
|
||||||
<!-- Pendulum joint: arm → pendulum, bearing axis along Y.
|
|
||||||
Joint origin corrected from mesh analysis (Fusion2URDF was 180mm off).
|
|
||||||
rpy pitch +90° so qpos=0 = pendulum hanging down (gravity-stable). -->
|
|
||||||
<joint name="pendulum_joint" type="continuous">
|
<joint name="pendulum_joint" type="continuous">
|
||||||
<origin xyz="0.000052 0.019274 0.014993" rpy="0 1.5708 0" />
|
<origin xyz="0.000052 0.019274 0.014993" rpy="0 1.5708 0" />
|
||||||
<parent link="arm" />
|
<parent link="arm" />
|
||||||
@@ -102,5 +77,4 @@
|
|||||||
<axis xyz="0 -1 0" />
|
<axis xyz="0 -1 0" />
|
||||||
<dynamics damping="0.0001" />
|
<dynamics damping="0.0001" />
|
||||||
</joint>
|
</joint>
|
||||||
|
|
||||||
</robot>
|
</robot>
|
||||||
70
assets/rotary_cartpole/sysid_result.json
Normal file
70
assets/rotary_cartpole/sysid_result.json
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
{
|
||||||
|
"best_params": {
|
||||||
|
"arm_mass": 0.012391951282440451,
|
||||||
|
"arm_com_x": 0.014950488360794875,
|
||||||
|
"arm_com_y": 0.006089886527968399,
|
||||||
|
"arm_com_z": 0.004470745447817278,
|
||||||
|
"pendulum_mass": 0.035508993892747365,
|
||||||
|
"pendulum_com_x": 0.06432778588634695,
|
||||||
|
"pendulum_com_y": -0.05999895841669392,
|
||||||
|
"pendulum_com_z": 0.0008769789937631209,
|
||||||
|
"pendulum_ixx": 3.139576982078822e-05,
|
||||||
|
"pendulum_iyy": 9.431951659638859e-06,
|
||||||
|
"pendulum_izz": 4.07315891863556e-05,
|
||||||
|
"pendulum_ixy": -1.8892943833253423e-06,
|
||||||
|
"actuator_gear": 0.17669161390939517,
|
||||||
|
"actuator_filter_tau": 0.040905643692382504,
|
||||||
|
"motor_damping": 0.009504542103348917,
|
||||||
|
"pendulum_damping": 6.128535042404019e-05,
|
||||||
|
"motor_armature": 0.0013894759540138252,
|
||||||
|
"motor_frictionloss": 0.002179448047511452
|
||||||
|
},
|
||||||
|
"best_cost": 0.7471380533090072,
|
||||||
|
"recording": "/Users/victormylle/Library/CloudStorage/SeaDrive-VictorMylle(cloud.optimize-it.be)/My Libraries/Projects/AI/RL-Framework/assets/rotary_cartpole/recordings/capture_20260311_215608.npz",
|
||||||
|
"param_names": [
|
||||||
|
"arm_mass",
|
||||||
|
"arm_com_x",
|
||||||
|
"arm_com_y",
|
||||||
|
"arm_com_z",
|
||||||
|
"pendulum_mass",
|
||||||
|
"pendulum_com_x",
|
||||||
|
"pendulum_com_y",
|
||||||
|
"pendulum_com_z",
|
||||||
|
"pendulum_ixx",
|
||||||
|
"pendulum_iyy",
|
||||||
|
"pendulum_izz",
|
||||||
|
"pendulum_ixy",
|
||||||
|
"actuator_gear",
|
||||||
|
"actuator_filter_tau",
|
||||||
|
"motor_damping",
|
||||||
|
"pendulum_damping",
|
||||||
|
"motor_armature",
|
||||||
|
"motor_frictionloss"
|
||||||
|
],
|
||||||
|
"defaults": {
|
||||||
|
"arm_mass": 0.01,
|
||||||
|
"arm_com_x": 5e-05,
|
||||||
|
"arm_com_y": 0.0065,
|
||||||
|
"arm_com_z": 0.00563,
|
||||||
|
"pendulum_mass": 0.015,
|
||||||
|
"pendulum_com_x": 0.1583,
|
||||||
|
"pendulum_com_y": -0.0983,
|
||||||
|
"pendulum_com_z": 0.0,
|
||||||
|
"pendulum_ixx": 6.16e-06,
|
||||||
|
"pendulum_iyy": 6.16e-06,
|
||||||
|
"pendulum_izz": 1.23e-05,
|
||||||
|
"pendulum_ixy": 6.1e-06,
|
||||||
|
"actuator_gear": 0.064,
|
||||||
|
"actuator_filter_tau": 0.03,
|
||||||
|
"motor_damping": 0.003,
|
||||||
|
"pendulum_damping": 0.0001,
|
||||||
|
"motor_armature": 0.0001,
|
||||||
|
"motor_frictionloss": 0.03
|
||||||
|
},
|
||||||
|
"timestamp": "2026-03-11T22:08:04.782736",
|
||||||
|
"history_summary": {
|
||||||
|
"first_cost": 3.909456214944022,
|
||||||
|
"final_cost": 0.7471380533090072,
|
||||||
|
"generations": 200
|
||||||
|
}
|
||||||
|
}
|
||||||
7
configs/env/rotary_cartpole.yaml
vendored
7
configs/env/rotary_cartpole.yaml
vendored
@@ -1,3 +1,10 @@
|
|||||||
max_steps: 1000
|
max_steps: 1000
|
||||||
robot_path: assets/rotary_cartpole
|
robot_path: assets/rotary_cartpole
|
||||||
reward_upright_scale: 1.0
|
reward_upright_scale: 1.0
|
||||||
|
speed_penalty_scale: 0.1
|
||||||
|
|
||||||
|
# ── HPO search ranges ────────────────────────────────────────────────
|
||||||
|
hpo:
|
||||||
|
reward_upright_scale: {min: 0.5, max: 5.0}
|
||||||
|
speed_penalty_scale: {min: 0.01, max: 1.0}
|
||||||
|
max_steps: {values: [500, 1000, 2000]}
|
||||||
7
configs/runner/mujoco_single.yaml
Normal file
7
configs/runner/mujoco_single.yaml
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
# Single-env MuJoCo runner — mimics real hardware timing.
|
||||||
|
# dt × substeps = 0.002 × 10 = 0.02 s → 50 Hz control, same as serial runner.
|
||||||
|
|
||||||
|
num_envs: 1
|
||||||
|
device: cpu
|
||||||
|
dt: 0.002
|
||||||
|
substeps: 10
|
||||||
11
configs/runner/serial.yaml
Normal file
11
configs/runner/serial.yaml
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
# Serial runner — communicates with real hardware over USB/serial.
|
||||||
|
# Always single-env, CPU-only. Override port on CLI:
|
||||||
|
# python train.py runner=serial runner.port=/dev/ttyUSB0
|
||||||
|
|
||||||
|
num_envs: 1
|
||||||
|
device: cpu
|
||||||
|
port: /dev/cu.usbserial-0001
|
||||||
|
baud: 115200
|
||||||
|
dt: 0.02 # control loop period (50 Hz)
|
||||||
|
no_data_timeout: 2.0 # seconds of silence before declaring disconnect
|
||||||
|
encoder_jump_threshold: 200 # encoder tick jump → reboot detection
|
||||||
25
configs/sysid.yaml
Normal file
25
configs/sysid.yaml
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
# System identification defaults.
|
||||||
|
# Override via CLI: python -m src.sysid.optimize sysid.max_generations=50
|
||||||
|
#
|
||||||
|
# These are NOT Hydra config groups — the sysid scripts use argparse.
|
||||||
|
# This file serves as documentation and can be loaded by custom wrappers.
|
||||||
|
|
||||||
|
capture:
|
||||||
|
port: /dev/cu.usbserial-0001
|
||||||
|
baud: 115200
|
||||||
|
duration: 20.0 # seconds
|
||||||
|
amplitude: 180 # max PWM magnitude (0–255)
|
||||||
|
hold_min_ms: 50 # PRBS min hold time
|
||||||
|
hold_max_ms: 300 # PRBS max hold time
|
||||||
|
dt: 0.02 # sample period (50 Hz)
|
||||||
|
|
||||||
|
optimize:
|
||||||
|
sigma0: 0.3 # CMA-ES initial step size (in [0,1] normalised space)
|
||||||
|
population_size: 20 # candidates per generation
|
||||||
|
max_generations: 200 # total generations (~4000 evaluations)
|
||||||
|
sim_dt: 0.002 # MuJoCo physics timestep
|
||||||
|
substeps: 10 # physics substeps per control step (ctrl_dt = 0.02s)
|
||||||
|
pos_weight: 1.0 # MSE weight for angle errors
|
||||||
|
vel_weight: 0.1 # MSE weight for velocity errors
|
||||||
|
window_duration: 0.5 # multiple-shooting window length (s); 0 = open-loop
|
||||||
|
seed: 42
|
||||||
@@ -12,5 +12,23 @@ entropy_loss_scale: 0.05
|
|||||||
log_interval: 1000
|
log_interval: 1000
|
||||||
checkpoint_interval: 50000
|
checkpoint_interval: 50000
|
||||||
|
|
||||||
|
initial_log_std: 0.5
|
||||||
|
min_log_std: -2.0
|
||||||
|
max_log_std: 2.0
|
||||||
|
|
||||||
|
record_video_every: 10000
|
||||||
|
|
||||||
# ClearML remote execution (GPU worker)
|
# ClearML remote execution (GPU worker)
|
||||||
remote: false
|
remote: false
|
||||||
|
|
||||||
|
# ── HPO search ranges ────────────────────────────────────────────────
|
||||||
|
# Read by scripts/hpo.py — ignored by TrainerConfig during training.
|
||||||
|
hpo:
|
||||||
|
learning_rate: {min: 0.00005, max: 0.001}
|
||||||
|
clip_ratio: {min: 0.1, max: 0.3}
|
||||||
|
discount_factor: {min: 0.98, max: 0.999}
|
||||||
|
gae_lambda: {min: 0.9, max: 0.99}
|
||||||
|
entropy_loss_scale: {min: 0.0001, max: 0.1}
|
||||||
|
value_loss_scale: {min: 0.1, max: 1.0}
|
||||||
|
learning_epochs: {min: 2, max: 8, type: int}
|
||||||
|
mini_batches: {values: [2, 4, 8, 16]}
|
||||||
|
|||||||
@@ -1,22 +1,18 @@
|
|||||||
# PPO tuned for MJX (1024+ parallel envs on GPU).
|
# PPO tuned for MJX (1024+ parallel envs on GPU).
|
||||||
|
# Inherits defaults + HPO ranges from ppo.yaml.
|
||||||
# With 1024 envs, each timestep collects 1024 samples, so total_timesteps
|
# With 1024 envs, each timestep collects 1024 samples, so total_timesteps
|
||||||
# can be much lower than the CPU config.
|
# can be much lower than the CPU config.
|
||||||
|
|
||||||
hidden_sizes: [128, 128]
|
defaults:
|
||||||
|
- ppo
|
||||||
|
- _self_
|
||||||
|
|
||||||
total_timesteps: 300000 # 300K × 1024 envs ≈ 307M env steps
|
total_timesteps: 300000 # 300K × 1024 envs ≈ 307M env steps
|
||||||
rollout_steps: 1024 # PPO batch = 1024 envs × 1024 steps = 1M samples
|
mini_batches: 32 # keep mini-batch size similar (~32K)
|
||||||
learning_epochs: 4
|
|
||||||
mini_batches: 32 # keep mini-batch size similar to CPU config (~32K)
|
|
||||||
discount_factor: 0.99
|
|
||||||
gae_lambda: 0.95
|
|
||||||
learning_rate: 0.001 # ~3x higher LR for 16x larger batch (sqrt scaling)
|
learning_rate: 0.001 # ~3x higher LR for 16x larger batch (sqrt scaling)
|
||||||
clip_ratio: 0.2
|
log_interval: 100
|
||||||
value_loss_scale: 0.5
|
|
||||||
entropy_loss_scale: 0.05
|
|
||||||
log_interval: 100 # log more often (shorter run)
|
|
||||||
checkpoint_interval: 10000
|
checkpoint_interval: 10000
|
||||||
|
|
||||||
record_video_every: 10000
|
record_video_every: 10000
|
||||||
|
|
||||||
# ClearML remote execution (GPU worker)
|
|
||||||
remote: false
|
remote: false
|
||||||
|
|||||||
27
configs/training/ppo_real.yaml
Normal file
27
configs/training/ppo_real.yaml
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
# PPO tuned for single-env real-time training on real hardware.
|
||||||
|
# Inherits defaults + HPO ranges from ppo.yaml.
|
||||||
|
# ~50 Hz control × 1 env = ~50 timesteps/s.
|
||||||
|
# 100k timesteps ≈ 33 minutes of wall-clock training.
|
||||||
|
|
||||||
|
defaults:
|
||||||
|
- ppo
|
||||||
|
- _self_
|
||||||
|
|
||||||
|
hidden_sizes: [256, 256]
|
||||||
|
total_timesteps: 100000
|
||||||
|
learning_epochs: 5
|
||||||
|
learning_rate: 0.001 # conservative — can't undo real-world damage
|
||||||
|
entropy_loss_scale: 0.0001
|
||||||
|
log_interval: 1024
|
||||||
|
checkpoint_interval: 5000 # frequent saves — can't rewind real hardware
|
||||||
|
initial_log_std: -0.5 # moderate initial exploration
|
||||||
|
min_log_std: -4.0
|
||||||
|
max_log_std: 0.0 # cap σ at 1.0
|
||||||
|
|
||||||
|
# Never run real-hardware training remotely
|
||||||
|
remote: false
|
||||||
|
|
||||||
|
# Tighter HPO ranges for real hardware (override base ppo.yaml ranges)
|
||||||
|
hpo:
|
||||||
|
entropy_loss_scale: {min: 0.00005, max: 0.001}
|
||||||
|
learning_rate: {min: 0.0003, max: 0.003}
|
||||||
23
configs/training/ppo_single.yaml
Normal file
23
configs/training/ppo_single.yaml
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
# PPO tuned for single-env simulation — mimics real hardware training.
|
||||||
|
# Inherits defaults + HPO ranges from ppo.yaml.
|
||||||
|
# Same 50 Hz control (runner=mujoco_single), 1 env, conservative hypers.
|
||||||
|
# Sim runs ~100× faster than real time, so we can afford more timesteps.
|
||||||
|
|
||||||
|
defaults:
|
||||||
|
- ppo
|
||||||
|
- _self_
|
||||||
|
|
||||||
|
hidden_sizes: [256, 256]
|
||||||
|
total_timesteps: 500000
|
||||||
|
learning_epochs: 5
|
||||||
|
learning_rate: 0.001
|
||||||
|
entropy_loss_scale: 0.0001
|
||||||
|
log_interval: 1024
|
||||||
|
checkpoint_interval: 10000
|
||||||
|
initial_log_std: -0.5
|
||||||
|
min_log_std: -4.0
|
||||||
|
max_log_std: 0.0
|
||||||
|
|
||||||
|
record_video_every: 50000
|
||||||
|
|
||||||
|
remote: false
|
||||||
@@ -11,4 +11,10 @@ imageio
|
|||||||
imageio-ffmpeg
|
imageio-ffmpeg
|
||||||
structlog
|
structlog
|
||||||
pyyaml
|
pyyaml
|
||||||
|
pyserial
|
||||||
|
cmaes
|
||||||
|
matplotlib
|
||||||
|
smac>=2.0.0
|
||||||
|
ConfigSpace
|
||||||
|
hpbandster
|
||||||
pytest
|
pytest
|
||||||
340
scripts/hpo.py
Normal file
340
scripts/hpo.py
Normal file
@@ -0,0 +1,340 @@
|
|||||||
|
"""Hyperparameter optimization for RL-Framework using ClearML + SMAC3.
|
||||||
|
|
||||||
|
Automatically creates a base training task (via Task.create), reads HPO
|
||||||
|
search ranges from the Hydra config's `training.hpo` and `env.hpo` blocks,
|
||||||
|
and launches SMAC3 Successive Halving optimization.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python scripts/hpo.py \
|
||||||
|
--env rotary_cartpole \
|
||||||
|
--runner mujoco_single \
|
||||||
|
--training ppo_single \
|
||||||
|
--queue gpu-queue
|
||||||
|
|
||||||
|
# Or use an existing base task:
|
||||||
|
python scripts/hpo.py --base-task-id <TASK_ID>
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Ensure project root is on sys.path
|
||||||
|
_PROJECT_ROOT = str(Path(__file__).resolve().parent.parent)
|
||||||
|
if _PROJECT_ROOT not in sys.path:
|
||||||
|
sys.path.insert(0, _PROJECT_ROOT)
|
||||||
|
|
||||||
|
import structlog
|
||||||
|
from clearml import Task
|
||||||
|
from clearml.automation import (
|
||||||
|
DiscreteParameterRange,
|
||||||
|
HyperParameterOptimizer,
|
||||||
|
UniformIntegerParameterRange,
|
||||||
|
UniformParameterRange,
|
||||||
|
)
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
logger = structlog.get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def _load_hydra_config(
|
||||||
|
env: str, runner: str, training: str
|
||||||
|
) -> dict:
|
||||||
|
"""Load and merge Hydra configs to extract HPO ranges.
|
||||||
|
|
||||||
|
We read the YAML files directly (without running Hydra) so this script
|
||||||
|
doesn't need @hydra.main — it's a ClearML optimizer, not a training job.
|
||||||
|
"""
|
||||||
|
configs_dir = Path(__file__).resolve().parent.parent / "configs"
|
||||||
|
|
||||||
|
# Load training config (handles defaults: [ppo] inheritance)
|
||||||
|
training_path = configs_dir / "training" / f"{training}.yaml"
|
||||||
|
training_cfg = OmegaConf.load(training_path)
|
||||||
|
|
||||||
|
# If the training config has defaults pointing to a base, load + merge
|
||||||
|
if "defaults" in training_cfg:
|
||||||
|
defaults = OmegaConf.to_container(training_cfg.defaults)
|
||||||
|
base_cfg = OmegaConf.create({})
|
||||||
|
for d in defaults:
|
||||||
|
if isinstance(d, str):
|
||||||
|
base_path = configs_dir / "training" / f"{d}.yaml"
|
||||||
|
if base_path.exists():
|
||||||
|
loaded = OmegaConf.load(base_path)
|
||||||
|
base_cfg = OmegaConf.merge(base_cfg, loaded)
|
||||||
|
# Remove defaults key and merge
|
||||||
|
training_no_defaults = {
|
||||||
|
k: v for k, v in OmegaConf.to_container(training_cfg).items()
|
||||||
|
if k != "defaults"
|
||||||
|
}
|
||||||
|
training_cfg = OmegaConf.merge(base_cfg, OmegaConf.create(training_no_defaults))
|
||||||
|
|
||||||
|
# Load env config
|
||||||
|
env_path = configs_dir / "env" / f"{env}.yaml"
|
||||||
|
env_cfg = OmegaConf.load(env_path) if env_path.exists() else OmegaConf.create({})
|
||||||
|
|
||||||
|
return {
|
||||||
|
"training": OmegaConf.to_container(training_cfg, resolve=True),
|
||||||
|
"env": OmegaConf.to_container(env_cfg, resolve=True),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _build_hyper_parameters(config: dict) -> list:
|
||||||
|
"""Build ClearML parameter ranges from hpo: blocks in config.
|
||||||
|
|
||||||
|
Reads training.hpo and env.hpo dicts and creates appropriate
|
||||||
|
ClearML parameter range objects.
|
||||||
|
|
||||||
|
Each hpo entry can have:
|
||||||
|
{min, max} → UniformParameterRange (float)
|
||||||
|
{min, max, type: int} → UniformIntegerParameterRange
|
||||||
|
{min, max, log: true} → UniformParameterRange with log scale
|
||||||
|
{values: [...]} → DiscreteParameterRange
|
||||||
|
"""
|
||||||
|
params = []
|
||||||
|
|
||||||
|
for section in ("training", "env"):
|
||||||
|
hpo_ranges = config.get(section, {}).get("hpo", {})
|
||||||
|
if not hpo_ranges:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for param_name, spec in hpo_ranges.items():
|
||||||
|
hydra_key = f"Hydra/{section}.{param_name}"
|
||||||
|
|
||||||
|
if "values" in spec:
|
||||||
|
params.append(
|
||||||
|
DiscreteParameterRange(hydra_key, values=spec["values"])
|
||||||
|
)
|
||||||
|
elif "min" in spec and "max" in spec:
|
||||||
|
if spec.get("type") == "int":
|
||||||
|
params.append(
|
||||||
|
UniformIntegerParameterRange(
|
||||||
|
hydra_key,
|
||||||
|
min_value=int(spec["min"]),
|
||||||
|
max_value=int(spec["max"]),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
step = spec.get("step", None)
|
||||||
|
params.append(
|
||||||
|
UniformParameterRange(
|
||||||
|
hydra_key,
|
||||||
|
min_value=float(spec["min"]),
|
||||||
|
max_value=float(spec["max"]),
|
||||||
|
step_size=step,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning("skipping_unknown_hpo_spec", param=param_name, spec=spec)
|
||||||
|
|
||||||
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
def _create_base_task(
|
||||||
|
env: str, runner: str, training: str, queue: str
|
||||||
|
) -> str:
|
||||||
|
"""Create a base ClearML task without executing it.
|
||||||
|
|
||||||
|
Uses Task.create() to register a task pointing at scripts/train.py
|
||||||
|
with the correct Hydra overrides. The HPO optimizer will clone this.
|
||||||
|
"""
|
||||||
|
script_path = str(Path(__file__).resolve().parent / "train.py")
|
||||||
|
project_root = str(Path(__file__).resolve().parent.parent)
|
||||||
|
|
||||||
|
base_task = Task.create(
|
||||||
|
project_name="RL-Framework",
|
||||||
|
task_name=f"{env}-{runner}-{training} (HPO base)",
|
||||||
|
task_type=Task.TaskTypes.training,
|
||||||
|
script=script_path,
|
||||||
|
working_directory=project_root,
|
||||||
|
argparse_args=[
|
||||||
|
f"env={env}",
|
||||||
|
f"runner={runner}",
|
||||||
|
f"training={training}",
|
||||||
|
],
|
||||||
|
add_task_init_call=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set docker config
|
||||||
|
base_task.set_base_docker(
|
||||||
|
"registry.kube.optimize/worker-image:latest",
|
||||||
|
docker_setup_bash_script=(
|
||||||
|
"apt-get update && apt-get install -y --no-install-recommends "
|
||||||
|
"libosmesa6 libgl1-mesa-glx libglfw3 && rm -rf /var/lib/apt/lists/* "
|
||||||
|
"&& pip install 'jax[cuda12]' mujoco-mjx"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
req_file = Path(__file__).resolve().parent.parent / "requirements.txt"
|
||||||
|
base_task.set_packages(str(req_file))
|
||||||
|
|
||||||
|
task_id = base_task.id
|
||||||
|
logger.info("base_task_created", task_id=task_id, task_name=base_task.name)
|
||||||
|
return task_id
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Hyperparameter optimization for RL-Framework"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--base-task-id",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Existing ClearML task ID to use as base (skip auto-creation)",
|
||||||
|
)
|
||||||
|
parser.add_argument("--env", type=str, default="rotary_cartpole")
|
||||||
|
parser.add_argument("--runner", type=str, default="mujoco_single")
|
||||||
|
parser.add_argument("--training", type=str, default="ppo_single")
|
||||||
|
parser.add_argument("--queue", type=str, default="gpu-queue")
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-concurrent", type=int, default=2,
|
||||||
|
help="Maximum concurrent trial tasks",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--total-trials", type=int, default=200,
|
||||||
|
help="Total HPO trial budget",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--min-budget", type=int, default=3,
|
||||||
|
help="Minimum budget (epochs) per trial",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-budget", type=int, default=81,
|
||||||
|
help="Maximum budget (epochs) for promoted trials",
|
||||||
|
)
|
||||||
|
parser.add_argument("--eta", type=int, default=3, help="Successive halving reduction factor")
|
||||||
|
parser.add_argument(
|
||||||
|
"--time-limit-hours", type=float, default=72,
|
||||||
|
help="Total wall-clock time limit in hours",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--objective-metric", type=str, default="Reward / Total reward (mean)",
|
||||||
|
help="ClearML scalar metric title to optimize",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--objective-series", type=str, default=None,
|
||||||
|
help="ClearML scalar metric series (default: same as title)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--maximize", action="store_true", default=True,
|
||||||
|
help="Maximize the objective (default)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--minimize", action="store_true", default=False,
|
||||||
|
help="Minimize the objective",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dry-run", action="store_true",
|
||||||
|
help="Print search space and exit without running",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
objective_sign = "min" if args.minimize else "max"
|
||||||
|
|
||||||
|
# ── Load config and build search space ────────────────────────
|
||||||
|
config = _load_hydra_config(args.env, args.runner, args.training)
|
||||||
|
hyper_parameters = _build_hyper_parameters(config)
|
||||||
|
|
||||||
|
if not hyper_parameters:
|
||||||
|
logger.error(
|
||||||
|
"no_hpo_ranges_found",
|
||||||
|
hint="Add 'hpo:' blocks to your training and/or env YAML configs",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if args.dry_run:
|
||||||
|
print(f"\nSearch space ({len(hyper_parameters)} parameters):")
|
||||||
|
for p in hyper_parameters:
|
||||||
|
print(f" {p.name}: {p}")
|
||||||
|
print(f"\nObjective: {args.objective_metric} ({objective_sign})")
|
||||||
|
return
|
||||||
|
|
||||||
|
# ── Create or reuse base task ─────────────────────────────────
|
||||||
|
if args.base_task_id:
|
||||||
|
base_task_id = args.base_task_id
|
||||||
|
logger.info("using_existing_base_task", task_id=base_task_id)
|
||||||
|
else:
|
||||||
|
base_task_id = _create_base_task(
|
||||||
|
args.env, args.runner, args.training, args.queue
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Initialize ClearML HPO task ───────────────────────────────
|
||||||
|
Task.ignore_requirements("torch")
|
||||||
|
task = Task.init(
|
||||||
|
project_name="RL-Framework",
|
||||||
|
task_name=f"HPO {args.env}-{args.runner}-{args.training}",
|
||||||
|
task_type=Task.TaskTypes.optimizer,
|
||||||
|
reuse_last_task_id=False,
|
||||||
|
)
|
||||||
|
task.set_base_docker(
|
||||||
|
docker_image="registry.kube.optimize/worker-image:latest",
|
||||||
|
docker_arguments=[
|
||||||
|
"-e", "CLEARML_AGENT_SKIP_PYTHON_ENV_INSTALL=1",
|
||||||
|
"-e", "CLEARML_AGENT_SKIP_PIP_VENV_INSTALL=1",
|
||||||
|
"-e", "CLEARML_AGENT_FORCE_SYSTEM_SITE_PACKAGES=1",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
req_file = Path(__file__).resolve().parent.parent / "requirements.txt"
|
||||||
|
task.set_packages(str(req_file))
|
||||||
|
|
||||||
|
# ── Build objective metric ────────────────────────────────────
|
||||||
|
# skrl's SequentialTrainer logs "Reward / Total reward (mean)" by default
|
||||||
|
objective_title = args.objective_metric
|
||||||
|
objective_series = args.objective_series or objective_title
|
||||||
|
|
||||||
|
# ── Launch optimizer ──────────────────────────────────────────
|
||||||
|
from src.hpo.smac3 import OptimizerSMAC
|
||||||
|
|
||||||
|
optimizer = HyperParameterOptimizer(
|
||||||
|
base_task_id=base_task_id,
|
||||||
|
hyper_parameters=hyper_parameters,
|
||||||
|
objective_metric_title=objective_title,
|
||||||
|
objective_metric_series=objective_series,
|
||||||
|
objective_metric_sign=objective_sign,
|
||||||
|
optimizer_class=OptimizerSMAC,
|
||||||
|
execution_queue=args.queue,
|
||||||
|
max_number_of_concurrent_tasks=args.max_concurrent,
|
||||||
|
total_max_jobs=args.total_trials,
|
||||||
|
min_iteration_per_job=args.min_budget,
|
||||||
|
max_iteration_per_job=args.max_budget,
|
||||||
|
pool_period_min=1,
|
||||||
|
time_limit_per_job=240, # 4 hours per trial max
|
||||||
|
eta=args.eta,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Send this HPO controller to a remote services worker
|
||||||
|
task.execute_remotely(queue_name="services", exit_process=True)
|
||||||
|
|
||||||
|
# Reporting and time limits
|
||||||
|
optimizer.set_report_period(1)
|
||||||
|
optimizer.set_time_limit(in_minutes=int(args.time_limit_hours * 60))
|
||||||
|
|
||||||
|
# Start and wait
|
||||||
|
optimizer.start()
|
||||||
|
optimizer.wait()
|
||||||
|
|
||||||
|
# Get top experiments
|
||||||
|
max_retries = 5
|
||||||
|
for attempt in range(max_retries):
|
||||||
|
try:
|
||||||
|
top_exp = optimizer.get_top_experiments(top_k=10)
|
||||||
|
logger.info("top_experiments_retrieved", count=len(top_exp))
|
||||||
|
for i, t in enumerate(top_exp):
|
||||||
|
logger.info("top_experiment", rank=i + 1, task_id=t.id, name=t.name)
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("retry_get_top_experiments", attempt=attempt + 1, error=str(e))
|
||||||
|
if attempt < max_retries - 1:
|
||||||
|
time.sleep(5.0 * (2 ** attempt))
|
||||||
|
else:
|
||||||
|
logger.error("could_not_retrieve_top_experiments")
|
||||||
|
|
||||||
|
optimizer.stop()
|
||||||
|
logger.info("hpo_complete")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
57
scripts/sysid.py
Normal file
57
scripts/sysid.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
"""Unified CLI for system identification tools.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python scripts/sysid.py capture --robot-path assets/rotary_cartpole --duration 20
|
||||||
|
python scripts/sysid.py optimize --robot-path assets/rotary_cartpole --recording <file>.npz
|
||||||
|
python scripts/sysid.py visualize --recording <file>.npz
|
||||||
|
python scripts/sysid.py export --robot-path assets/rotary_cartpole --result <result>.json
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Ensure project root is on sys.path
|
||||||
|
_PROJECT_ROOT = str(Path(__file__).resolve().parent.parent)
|
||||||
|
if _PROJECT_ROOT not in sys.path:
|
||||||
|
sys.path.insert(0, _PROJECT_ROOT)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
    """Dispatch ``python scripts/sysid.py <command>`` to the matching src.sysid module.

    Prints usage and exits 0 for no/help arguments; exits 1 for an unknown
    command. Otherwise rewrites ``sys.argv`` and hands control to the
    subcommand's own ``main()`` (which runs its own argparse).
    """
    usage = (
        "Usage: python scripts/sysid.py <command> [options]\n"
        "\n"
        "Commands:\n"
        "  capture    Record real robot trajectory under PRBS excitation\n"
        "  optimize   Run CMA-ES parameter optimization\n"
        "  visualize  Plot real vs simulated trajectories\n"
        "  export     Write tuned URDF + robot.yaml files\n"
        "\n"
        "Run 'python scripts/sysid.py <command> --help' for command-specific options."
    )
    if len(sys.argv) < 2 or sys.argv[1] in ("-h", "--help"):
        print(usage)
        sys.exit(0)

    command = sys.argv[1]
    # Strip the subcommand so the target module's argparse sees a clean argv.
    sys.argv = [f"sysid {command}"] + sys.argv[2:]

    if command == "capture":
        from src.sysid.capture import main as cmd_main
    elif command == "optimize":
        from src.sysid.optimize import main as cmd_main
    elif command == "visualize":
        from src.sysid.visualize import main as cmd_main
    elif command == "export":
        from src.sysid.export import main as cmd_main
    else:
        print(f"Unknown command: {command}")
        print("Available commands: capture, optimize, visualize, export")
        sys.exit(1)

    cmd_main()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
118
scripts/train.py
Normal file
118
scripts/train.py
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
import os
|
||||||
|
import pathlib
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# Ensure project root is on sys.path so `src.*` imports work
|
||||||
|
# regardless of which directory the script is invoked from.
|
||||||
|
_PROJECT_ROOT = str(pathlib.Path(__file__).resolve().parent.parent)
|
||||||
|
if _PROJECT_ROOT not in sys.path:
|
||||||
|
sys.path.insert(0, _PROJECT_ROOT)
|
||||||
|
|
||||||
|
# Headless rendering: use OSMesa on Linux servers (must be set before mujoco import)
|
||||||
|
if sys.platform == "linux" and "DISPLAY" not in os.environ:
|
||||||
|
os.environ.setdefault("MUJOCO_GL", "osmesa")
|
||||||
|
|
||||||
|
import hydra
|
||||||
|
import hydra.utils as hydra_utils
|
||||||
|
import structlog
|
||||||
|
from clearml import Task
|
||||||
|
from hydra.core.hydra_config import HydraConfig
|
||||||
|
from omegaconf import DictConfig, OmegaConf
|
||||||
|
|
||||||
|
from src.core.env import BaseEnv
|
||||||
|
from src.core.registry import build_env
|
||||||
|
from src.core.runner import BaseRunner
|
||||||
|
from src.training.trainer import Trainer, TrainerConfig
|
||||||
|
|
||||||
|
logger = structlog.get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
# ── runner registry ───────────────────────────────────────────────────
# Maps Hydra config-group name → (module path, runner class name, config
# class name) — all strings, resolved lazily in _build_runner via importlib
# so heavy backends (JAX) are only imported when that runner is selected.
RUNNER_REGISTRY: dict[str, tuple[str, str, str]] = {
    "mujoco": ("src.runners.mujoco", "MuJoCoRunner", "MuJoCoRunnerConfig"),
    # Alias: same single-process MuJoCo runner under a second config name.
    "mujoco_single": ("src.runners.mujoco", "MuJoCoRunner", "MuJoCoRunnerConfig"),
    "mjx": ("src.runners.mjx", "MJXRunner", "MJXRunnerConfig"),
    "serial": ("src.runners.serial", "SerialRunner", "SerialRunnerConfig"),
}
|
||||||
|
|
||||||
|
|
||||||
|
def _build_runner(runner_name: str, env: BaseEnv, cfg: DictConfig) -> BaseRunner:
    """Instantiate the right runner from the Hydra config-group name.

    Looks the name up in RUNNER_REGISTRY, imports the runner module lazily,
    builds its config dataclass from ``cfg.runner``, and returns the runner.

    Raises:
        ValueError: if ``runner_name`` is not registered.
    """
    try:
        module_path, runner_cls_name, config_cls_name = RUNNER_REGISTRY[runner_name]
    except KeyError:
        raise ValueError(
            f"Unknown runner '{runner_name}'. Registered: {list(RUNNER_REGISTRY)}"
        ) from None

    # Deferred import keeps heavy backends (e.g. JAX for mjx) off the cold path.
    import importlib

    runner_module = importlib.import_module(module_path)
    config_kwargs = OmegaConf.to_container(cfg.runner, resolve=True)
    runner_config = getattr(runner_module, config_cls_name)(**config_kwargs)
    return getattr(runner_module, runner_cls_name)(env=env, config=runner_config)
|
||||||
|
|
||||||
|
|
||||||
|
def _init_clearml(choices: dict[str, str], remote: bool = False) -> Task:
    """Initialize ClearML task with project structure and tags.

    Args:
        choices: Hydra runtime config-group choices (env / runner / training).
        remote: When True and running locally, re-launch the task on the
            remote "gpu-queue"; that call exits the local process.

    Returns:
        The initialized ClearML Task.
    """
    # Torch comes from the worker docker image — don't pin the local version.
    Task.ignore_requirements("torch")

    env_name = choices.get("env", "cartpole")
    runner_name = choices.get("runner", "mujoco")
    training_name = choices.get("training", "ppo")

    task = Task.init(
        project_name="RL-Framework",
        task_name=f"{env_name}-{runner_name}-{training_name}",
        tags=[env_name, runner_name, training_name],
    )
    # Worker image plus the extra system/python packages headless MuJoCo needs.
    task.set_base_docker(
        "registry.kube.optimize/worker-image:latest",
        docker_setup_bash_script=(
            "apt-get update && apt-get install -y --no-install-recommends "
            "libosmesa6 libgl1-mesa-glx libglfw3 && rm -rf /var/lib/apt/lists/* "
            "&& pip install 'jax[cuda12]' mujoco-mjx"
        ),
    )

    # Pin the experiment's package set from the repo's requirements file.
    req_path = pathlib.Path(hydra_utils.get_original_cwd()) / "requirements.txt"
    task.set_packages(str(req_path))

    # Hand off to a remote worker when requested (no-op on the worker side).
    if remote and task.running_locally():
        logger.info("executing_task_remotely", queue="gpu-queue")
        task.execute_remotely(queue_name="gpu-queue", exit_process=True)

    return task
|
||||||
|
|
||||||
|
|
||||||
|
@hydra.main(version_base=None, config_path="../configs", config_name="config")
def main(cfg: DictConfig) -> None:
    """Train a policy: build env + runner from the Hydra config, run the Trainer."""
    hydra_choices = HydraConfig.get().runtime.choices

    # ClearML init must happen before heavy work so remote execution can take
    # over early. The remote worker re-runs this full script; execute_remotely()
    # is a no-op on the worker side.
    training_kwargs = OmegaConf.to_container(cfg.training, resolve=True)
    run_remote = training_kwargs.pop("remote", False)
    training_kwargs.pop("hpo", None)  # HPO range metadata — not a TrainerConfig field
    task = _init_clearml(hydra_choices, remote=run_remote)

    env = build_env(hydra_choices.get("env", "cartpole"), cfg)
    runner = _build_runner(hydra_choices.get("runner", "mujoco"), env, cfg)
    trainer = Trainer(runner=runner, config=TrainerConfig(**training_kwargs))

    try:
        trainer.train()
    finally:
        # Always release the runner and flush/close the ClearML task.
        trainer.close()
        task.close()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
254
scripts/viz.py
Normal file
254
scripts/viz.py
Normal file
@@ -0,0 +1,254 @@
|
|||||||
|
"""Interactive visualization — control any env with keyboard in MuJoCo viewer.
|
||||||
|
|
||||||
|
Usage (simulation):
|
||||||
|
mjpython scripts/viz.py env=rotary_cartpole
|
||||||
|
mjpython scripts/viz.py env=cartpole +com=true
|
||||||
|
|
||||||
|
Usage (real hardware — digital twin):
|
||||||
|
mjpython scripts/viz.py env=rotary_cartpole runner=serial
|
||||||
|
mjpython scripts/viz.py env=rotary_cartpole runner=serial runner.port=/dev/ttyUSB0
|
||||||
|
|
||||||
|
Controls:
|
||||||
|
Left/Right arrows — apply torque to first actuator
|
||||||
|
R — reset environment
|
||||||
|
Esc / close window — quit
|
||||||
|
"""
|
||||||
|
import math
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Ensure project root is on sys.path
|
||||||
|
_PROJECT_ROOT = str(Path(__file__).resolve().parent.parent)
|
||||||
|
if _PROJECT_ROOT not in sys.path:
|
||||||
|
sys.path.insert(0, _PROJECT_ROOT)
|
||||||
|
|
||||||
|
import hydra
|
||||||
|
import mujoco
|
||||||
|
import mujoco.viewer
|
||||||
|
import numpy as np
|
||||||
|
import structlog
|
||||||
|
import torch
|
||||||
|
from hydra.core.hydra_config import HydraConfig
|
||||||
|
from omegaconf import DictConfig, OmegaConf
|
||||||
|
|
||||||
|
from src.core.registry import build_env
|
||||||
|
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
|
||||||
|
|
||||||
|
logger = structlog.get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
# ── keyboard state ───────────────────────────────────────────────────
# Single-element lists act as mutable cells shared between the GLFW key
# callback (runs on the viewer thread) and the main loop.
_action_val = [0.0]    # last commanded action in {-1, 0, +1}
_action_time = [0.0]   # time.time() of the last arrow-key event
_reset_flag = [False]  # set by 'R'; consumed (cleared) by the main loop
_ACTION_HOLD_S = 0.25  # seconds the action stays active after last key event

# GLFW keycode → normalized action for the first actuator.
_KEY_TO_ACTION = {263: -1.0, 262: 1.0}  # LEFT, RIGHT


def _key_callback(keycode: int) -> None:
    """Called by MuJoCo on key press & repeat (not release)."""
    if keycode in _KEY_TO_ACTION:
        _action_val[0] = _KEY_TO_ACTION[keycode]
        _action_time[0] = time.time()
    elif keycode == 82:  # GLFW_KEY_R
        _reset_flag[0] = True
|
||||||
|
|
||||||
|
|
||||||
|
def _add_action_arrow(viewer, model, data, action_val: float) -> None:
    """Draw an arrow on the motor joint showing applied torque direction.

    Appends one mjGEOM_ARROW to ``viewer.user_scn`` and increments ``ngeom``;
    the caller clears ``viewer.user_scn.ngeom`` each frame before calling.
    NOTE(review): no bounds check against the user_scn geom capacity —
    presumably fine with one overlay per frame; confirm.

    Args:
        viewer: MuJoCo passive-viewer handle (provides ``user_scn``).
        model: mjModel of the scene.
        data: mjData with current kinematics (xpos/xmat must be up to date).
        action_val: normalized action; sign picks direction and color.
    """
    # Nothing to draw for near-zero actions or models without actuators.
    if abs(action_val) < 0.01 or model.nu == 0:
        return

    # Get the body that the first actuator's joint belongs to
    jnt_id = model.actuator_trnid[0, 0]
    body_id = model.jnt_bodyid[jnt_id]

    # Arrow origin: body position
    pos = data.xpos[body_id].copy()
    pos[2] += 0.02  # lift slightly above the body

    # Arrow direction: joint axis rotated into the world frame, scaled by action
    axis = data.xmat[body_id].reshape(3, 3) @ model.jnt_axis[jnt_id]
    arrow_len = 0.08 * action_val
    direction = axis * np.sign(arrow_len)

    # Build an orthonormal frame whose local z-axis is the arrow direction
    # (mjGEOM_ARROW renders along local z). Epsilons guard a zero-length axis.
    z = direction / (np.linalg.norm(direction) + 1e-8)
    # Pick an 'up' vector not parallel to z so the cross product is well-defined.
    up = np.array([0, 0, 1]) if abs(z[2]) < 0.99 else np.array([0, 1, 0])
    x = np.cross(up, z)
    x /= np.linalg.norm(x) + 1e-8
    y = np.cross(z, x)
    mat = np.column_stack([x, y, z]).flatten()  # 3x3 rotation as flat 9-vector

    # Color: green = positive, red = negative
    rgba = np.array(
        [0.2, 0.8, 0.2, 0.8] if action_val > 0 else [0.8, 0.2, 0.2, 0.8],
        dtype=np.float32,
    )

    # Initialize the next free overlay geom in place and publish it.
    geom = viewer.user_scn.geoms[viewer.user_scn.ngeom]
    mujoco.mjv_initGeom(
        geom,
        type=mujoco.mjtGeom.mjGEOM_ARROW,
        size=np.array([0.008, 0.008, abs(arrow_len)]),  # shaft radii + length
        pos=pos,
        mat=mat,
        rgba=rgba,
    )
    viewer.user_scn.ngeom += 1
|
||||||
|
|
||||||
|
|
||||||
|
@hydra.main(version_base=None, config_path="../configs", config_name="config")
def main(cfg: DictConfig) -> None:
    """Entry point: route to hardware or simulation viz based on the runner choice."""
    hydra_choices = HydraConfig.get().runtime.choices
    env_name = hydra_choices.get("env", "cartpole")
    use_hardware = hydra_choices.get("runner", "mujoco") == "serial"
    (_main_serial if use_hardware else _main_sim)(cfg, env_name)
|
||||||
|
|
||||||
|
|
||||||
|
def _main_sim(cfg: DictConfig, env_name: str) -> None:
    """Simulation visualization — step MuJoCo physics with keyboard control.

    Builds a single-env MuJoCo runner, opens a passive viewer, and loops:
    read keyboard action → step runner → redraw overlays → sleep to pace
    roughly real-time. Returns when the viewer window is closed.
    """

    # Build env + runner (forced to a single env for visualization).
    env = build_env(env_name, cfg)
    runner_dict = OmegaConf.to_container(cfg.runner, resolve=True)
    runner_dict["num_envs"] = 1
    runner = MuJoCoRunner(env=env, config=MuJoCoRunnerConfig(**runner_dict))

    # Private runner internals: the mjModel and the (single) mjData instance.
    model = runner._model
    data = runner._data[0]

    # Control period: one runner.step() advances physics by dt * substeps.
    dt_ctrl = runner.config.dt * runner.config.substeps

    # Launch viewer; _key_callback feeds the module-level keyboard state.
    with mujoco.viewer.launch_passive(model, data, key_callback=_key_callback) as viewer:
        # Show CoM / inertia if requested via Hydra override: viz.py +com=true
        show_com = cfg.get("com", False)
        if show_com:
            viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_COM] = True
            viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_INERTIA] = True

        obs, _ = runner.reset()
        step = 0

        logger.info("viewer_started", env=env_name,
                    controls="Left/Right arrows = torque, R = reset")

        while viewer.is_running():
            # Read action from callback; it expires _ACTION_HOLD_S after the
            # last key event so releasing the key returns the action to 0.
            if time.time() - _action_time[0] < _ACTION_HOLD_S:
                action_val = _action_val[0]
            else:
                action_val = 0.0

            # Reset on R press (flag is set by the key callback).
            if _reset_flag[0]:
                _reset_flag[0] = False
                obs, _ = runner.reset()
                step = 0
                logger.info("reset")

            # Step through the runner; shape (1, 1) = one env, one actuator.
            action = torch.tensor([[action_val]])
            obs, reward, terminated, truncated, info = runner.step(action)

            # Recompute derived quantities, then redraw with the action arrow.
            mujoco.mj_forward(model, data)
            viewer.user_scn.ngeom = 0  # clear previous frame's overlays
            _add_action_arrow(viewer, model, data, action_val)
            viewer.sync()

            # Periodic state log: joint angles in degrees, keyed by joint name.
            if step % 25 == 0:
                joints = {model.jnt(i).name: round(math.degrees(data.qpos[i]), 1)
                          for i in range(model.njnt)}
                logger.debug("step", n=step, reward=round(reward.item(), 3),
                             action=round(action_val, 1), **joints)

            # Real-time pacing (sleep the full control period; ignores compute time).
            time.sleep(dt_ctrl)
            step += 1

    runner.close()
|
||||||
|
|
||||||
|
|
||||||
|
def _main_serial(cfg: DictConfig, env_name: str) -> None:
    """Digital-twin visualization — mirror real hardware in MuJoCo viewer.

    The MuJoCo model is loaded for rendering only. Joint positions are
    read from the ESP32 over serial and applied to the model each frame.
    Keyboard arrows send motor commands to the real robot.

    Uses several private SerialRunner helpers (_send, _drive_to_center,
    _wait_for_pendulum_still, _sync_viz) rather than the step() API, since
    this loop bypasses the RL interface entirely.
    """
    # Local import: keep pyserial-dependent code off the sim-only path.
    from src.runners.serial import SerialRunner, SerialRunnerConfig

    env = build_env(env_name, cfg)
    runner_dict = OmegaConf.to_container(cfg.runner, resolve=True)
    serial_runner = SerialRunner(
        env=env, config=SerialRunnerConfig(**runner_dict)
    )

    # Load MuJoCo model for visualisation (same URDF the sim uses).
    serial_runner._ensure_viz_model()
    model = serial_runner._viz_model
    data = serial_runner._viz_data

    with mujoco.viewer.launch_passive(
        model, data, key_callback=_key_callback
    ) as viewer:
        # Show CoM / inertia if requested via Hydra override (+com=true).
        show_com = cfg.get("com", False)
        if show_com:
            viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_COM] = True
            viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_INERTIA] = True

        logger.info(
            "viewer_started",
            env=env_name,
            mode="serial (digital twin)",
            port=serial_runner.config.port,
            controls="Left/Right arrows = motor command, R = reset",
        )

        while viewer.is_running():
            # Read action from keyboard callback (expires after _ACTION_HOLD_S).
            if time.time() - _action_time[0] < _ACTION_HOLD_S:
                action_val = _action_val[0]
            else:
                action_val = 0.0

            # Reset on R press: stop the motor, then run the physical reset
            # procedure (drive arm to center, wait for the pendulum to settle).
            if _reset_flag[0]:
                _reset_flag[0] = False
                serial_runner._send("M0")
                serial_runner._drive_to_center()
                serial_runner._wait_for_pendulum_still()
                logger.info("reset (drive-to-center + settle)")

            # Send motor command to real hardware: action in [-1, 1] → PWM in
            # [-255, 255], encoded as e.g. "M-128" on the serial line.
            motor_speed = int(np.clip(action_val, -1.0, 1.0) * 255)
            serial_runner._send(f"M{motor_speed}")

            # Pull the latest sensor readings into the MuJoCo model state.
            serial_runner._sync_viz()

            # Render overlays and sync viewer.
            viewer.user_scn.ngeom = 0
            _add_action_arrow(viewer, model, data, action_val)
            viewer.sync()

            # Real-time pacing (~50 Hz, matches serial dt).
            time.sleep(serial_runner.config.dt)

    serial_runner.close()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
91
src/core/hardware.py
Normal file
91
src/core/hardware.py
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
"""Real-hardware configuration — loaded from hardware.yaml next to robot.yaml.
|
||||||
|
|
||||||
|
Provides robot-specific constants for the SerialRunner: encoder specs,
|
||||||
|
safety limits, and reset behaviour. Simulation-only robots simply don't
|
||||||
|
have a hardware.yaml (the loader returns None).
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
hw = load_hardware_config("assets/rotary_cartpole")
|
||||||
|
if hw is not None:
|
||||||
|
counts_per_rev = hw.encoder.ppr * hw.encoder.gear_ratio * 4.0
|
||||||
|
"""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import structlog
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
log = structlog.get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
class EncoderConfig:
    """Rotary encoder parameters for the motor shaft."""

    # Pulses per revolution of the raw encoder disc (before quadrature decode).
    ppr: int = 11
    # Gearbox reduction between encoder shaft and output shaft.
    gear_ratio: float = 30.0

    @property
    def counts_per_rev(self) -> float:
        """Total encoder counts per output-shaft revolution (quadrature)."""
        # Quadrature decoding yields 4 countable edges per encoder pulse.
        quadrature_edges = 4.0
        return self.ppr * self.gear_ratio * quadrature_edges
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
class SafetyConfig:
    """Safety limits enforced by the runner (not the env)."""

    # Hard termination limit on the motor arm angle; 0 disables the check.
    max_motor_angle_deg: float = 90.0
    # Angle at which the progressive penalty ramp starts (soft limit).
    soft_limit_deg: float = 40.0
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
class ResetConfig:
    """Parameters for the physical reset procedure (center, then settle)."""

    # Bang-bang drive-to-center phase.
    drive_speed: int = 80        # PWM magnitude used while centering
    deadband: int = 15           # encoder-count window considered "centered"
    drive_timeout: float = 3.0   # seconds before giving up on centering

    # Settle phase: wait until the pendulum hangs still.
    settle_angle_deg: float = 2.0   # angle threshold for "still" (degrees)
    settle_vel_dps: float = 5.0     # velocity threshold (deg/s)
    settle_duration: float = 0.5    # seconds it must stay within thresholds
    settle_timeout: float = 30.0    # seconds before giving up waiting
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
class HardwareConfig:
    """Complete real-hardware description for a robot."""

    # Encoder geometry (pulses, gearing) — drives count→angle conversion.
    encoder: EncoderConfig = dataclasses.field(default_factory=EncoderConfig)
    # Runner-enforced safety limits (hard + soft motor-angle bounds).
    safety: SafetyConfig = dataclasses.field(default_factory=SafetyConfig)
    # Physical reset (drive-to-center + settle) parameters.
    reset: ResetConfig = dataclasses.field(default_factory=ResetConfig)
|
||||||
|
|
||||||
|
|
||||||
|
def load_hardware_config(robot_dir: str | Path) -> HardwareConfig | None:
|
||||||
|
"""Load hardware.yaml from a directory.
|
||||||
|
|
||||||
|
Returns None if the file doesn't exist (simulation-only robot).
|
||||||
|
"""
|
||||||
|
robot_dir = Path(robot_dir).resolve()
|
||||||
|
yaml_path = robot_dir / "hardware.yaml"
|
||||||
|
|
||||||
|
if not yaml_path.exists():
|
||||||
|
return None
|
||||||
|
|
||||||
|
raw = yaml.safe_load(yaml_path.read_text()) or {}
|
||||||
|
|
||||||
|
encoder = EncoderConfig(**raw.get("encoder", {}))
|
||||||
|
safety = SafetyConfig(**raw.get("safety", {}))
|
||||||
|
reset = ResetConfig(**raw.get("reset", {}))
|
||||||
|
|
||||||
|
config = HardwareConfig(encoder=encoder, safety=safety, reset=reset)
|
||||||
|
|
||||||
|
log.debug(
|
||||||
|
"hardware_config_loaded",
|
||||||
|
robot_dir=str(robot_dir),
|
||||||
|
counts_per_rev=encoder.counts_per_rev,
|
||||||
|
max_motor_angle_deg=safety.max_motor_angle_deg,
|
||||||
|
)
|
||||||
|
return config
|
||||||
@@ -21,6 +21,7 @@ class RotaryCartPoleConfig(BaseEnvConfig):
|
|||||||
"""
|
"""
|
||||||
# Reward shaping
|
# Reward shaping
|
||||||
reward_upright_scale: float = 1.0 # cos(pendulum) when upright = +1
|
reward_upright_scale: float = 1.0 # cos(pendulum) when upright = +1
|
||||||
|
speed_penalty_scale: float = 0.1 # penalty for high pendulum velocity near top
|
||||||
|
|
||||||
|
|
||||||
class RotaryCartPoleEnv(BaseEnv[RotaryCartPoleConfig]):
|
class RotaryCartPoleEnv(BaseEnv[RotaryCartPoleConfig]):
|
||||||
@@ -69,11 +70,12 @@ class RotaryCartPoleEnv(BaseEnv[RotaryCartPoleConfig]):
|
|||||||
# Upright reward: -cos(θ) ∈ [-1, +1]
|
# Upright reward: -cos(θ) ∈ [-1, +1]
|
||||||
upright = -torch.cos(state.pendulum_angle)
|
upright = -torch.cos(state.pendulum_angle)
|
||||||
|
|
||||||
# Velocity penalties — make spinning expensive but allow swing-up
|
# Penalise high pendulum velocity when near the top (upright).
|
||||||
pend_vel_penalty = 0.01 * state.pendulum_vel ** 2
|
# "nearness" is weighted by how upright the pendulum is (0 at bottom, 1 at top).
|
||||||
motor_vel_penalty = 0.01 * state.motor_vel ** 2
|
near_top = torch.clamp(-torch.cos(state.pendulum_angle), min=0.0) # 0‥1
|
||||||
|
speed_penalty = self.config.speed_penalty_scale * near_top * state.pendulum_vel.abs()
|
||||||
|
|
||||||
return upright - pend_vel_penalty - motor_vel_penalty
|
return upright * self.config.reward_upright_scale - speed_penalty
|
||||||
|
|
||||||
def compute_terminations(self, state: RotaryCartPoleState) -> torch.Tensor:
|
def compute_terminations(self, state: RotaryCartPoleState) -> torch.Tensor:
|
||||||
# No early termination — episode runs for max_steps (truncation only).
|
# No early termination — episode runs for max_steps (truncation only).
|
||||||
|
|||||||
1
src/hpo/__init__.py
Normal file
1
src/hpo/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Hyperparameter optimization — SMAC3 + ClearML Successive Halving."""
|
||||||
636
src/hpo/smac3.py
Normal file
636
src/hpo/smac3.py
Normal file
@@ -0,0 +1,636 @@
|
|||||||
|
# Requires: pip install smac==2.0.0 ConfigSpace==0.4.20
|
||||||
|
import contextlib
|
||||||
|
import time
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from functools import wraps
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from clearml import Task
|
||||||
|
from clearml.automation.optimization import Objective, SearchStrategy
|
||||||
|
from clearml.automation.parameters import Parameter
|
||||||
|
from clearml.backend_interface.session import SendError
|
||||||
|
from ConfigSpace import (
|
||||||
|
CategoricalHyperparameter,
|
||||||
|
ConfigurationSpace,
|
||||||
|
UniformFloatHyperparameter,
|
||||||
|
UniformIntegerHyperparameter,
|
||||||
|
)
|
||||||
|
from smac import MultiFidelityFacade
|
||||||
|
from smac.intensifier.successive_halving import SuccessiveHalving
|
||||||
|
from smac.runhistory.dataclasses import TrialInfo, TrialValue
|
||||||
|
from smac.scenario import Scenario
|
||||||
|
|
||||||
|
|
||||||
|
def retry_on_error(max_retries=5, initial_delay=2.0, backoff=2.0, exceptions=(Exception,)):
    """Decorator: retry the wrapped call on exception with exponential backoff.

    Args:
        max_retries: total number of attempts.
        initial_delay: seconds to wait after the first failure.
        backoff: multiplier applied to the delay after each failure.
        exceptions: exception types that trigger a retry.

    Returns:
        A decorator whose wrapper returns the function's result, or None if
        every attempt raised (errors are deliberately swallowed, not re-raised).
    """

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            wait = initial_delay
            attempts_left = max_retries
            while attempts_left > 0:
                attempts_left -= 1
                try:
                    return func(*args, **kwargs)
                except exceptions:
                    if attempts_left == 0:
                        # Out of attempts: swallow the error, signal via None.
                        return None
                    time.sleep(wait)
                    wait *= backoff
            return None  # only reached when max_retries <= 0

        return wrapper

    return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def _encode_param_name(name: str) -> str:
    """Make a parameter name safe for ConfigSpace by masking '/' as __SLASH__."""
    return "__SLASH__".join(name.split("/"))
|
||||||
|
|
||||||
|
|
||||||
|
def _decode_param_name(name: str) -> str:
    """Restore an encoded parameter name by unmasking __SLASH__ back to '/'."""
    return "/".join(name.split("__SLASH__"))
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_param_to_cs(param: Parameter):
    """
    Convert a ClearML Parameter into a ConfigSpace hyperparameter,
    adapted to ConfigSpace>=1.x (no more 'q' argument).

    Dispatch is duck-typed on the Parameter's attributes:
      * has ``values``                    → CategoricalHyperparameter
      * has ``min_value`` + ``max_value`` → Uniform{Integer,Float}Hyperparameter
        (integer when both bounds are ints; a non-unit ``step_size`` is
        emulated as an explicit categorical list since 'q' was removed)
      * anything else                     → ValueError
    """
    # Encode the name to avoid ConfigSpace issues with special chars like '/'
    name = _encode_param_name(param.name)

    # Categorical / discrete list
    if hasattr(param, "values"):
        return CategoricalHyperparameter(name=name, choices=list(param.values))

    # Numeric range (float or int)
    if hasattr(param, "min_value") and hasattr(param, "max_value"):
        min_val = param.min_value
        max_val = param.max_value

        # Integer range only when BOTH bounds are ints (bools also pass
        # isinstance(..., int) — presumably never used as bounds; confirm).
        if isinstance(min_val, int) and isinstance(max_val, int):
            log = getattr(param, "log_scale", False)

            # Check for step_size for quantization
            # NOTE(review): int(param.step_size) raises TypeError if
            # step_size is present but None — assumed always numeric; confirm.
            if hasattr(param, "step_size"):
                sv = int(param.step_size)
                if sv != 1:
                    # emulate quantization by explicit list of values
                    # (inclusive upper bound, matching ClearML range semantics)
                    choices = list(range(min_val, max_val + 1, sv))
                    return CategoricalHyperparameter(name=name, choices=choices)

            # Simple uniform integer range
            return UniformIntegerHyperparameter(name=name, lower=min_val, upper=max_val, log=log)
        else:
            # Treat as float
            lower, upper = float(min_val), float(max_val)
            log = getattr(param, "log_scale", False)
            return UniformFloatHyperparameter(name=name, lower=lower, upper=upper, log=log)

    raise ValueError(f"Unsupported Parameter type: {type(param)}")
|
||||||
|
|
||||||
|
|
||||||
|
class OptimizerSMAC(SearchStrategy):
|
||||||
|
"""
|
||||||
|
SMAC3-based hyperparameter optimizer, matching OptimizerBOHB interface.
|
||||||
|
"""
|
||||||
|
|
||||||
|
    def __init__(
        self,
        base_task_id: str,
        hyper_parameters: Sequence[Parameter],
        objective_metric: Objective,
        execution_queue: str,
        num_concurrent_workers: int,
        min_iteration_per_job: int,
        max_iteration_per_job: int,
        total_max_jobs: int,
        pool_period_min: float = 2.0,
        time_limit_per_job: float | None = None,
        compute_time_limit: float | None = None,
        **smac_kwargs: Any,
    ):
        """Set up the SMAC3 Successive-Halving strategy behind ClearML's interface.

        Args:
            base_task_id: ClearML task to clone for each trial.
            hyper_parameters: ClearML Parameter definitions for the search space.
            objective_metric: ClearML Objective (title/series/sign) to optimize.
            execution_queue: queue trials are enqueued to.
            num_concurrent_workers: max trials running at once.
            min_iteration_per_job / max_iteration_per_job: budget bounds (epochs).
            total_max_jobs: total trial budget across all rungs.
            pool_period_min: polling period of the base strategy, minutes.
            time_limit_per_job: per-trial wall-clock limit, minutes.
            compute_time_limit: overall compute limit, minutes.
            **smac_kwargs: extra kwargs for SuccessiveHalving; 'eta' is popped
                here and used as the rung reduction factor (default 3).
        """
        # Initialize base SearchStrategy (stores args on private attributes).
        super().__init__(
            base_task_id=base_task_id,
            hyper_parameters=hyper_parameters,
            objective_metric=objective_metric,
            execution_queue=execution_queue,
            num_concurrent_workers=num_concurrent_workers,
            pool_period_min=pool_period_min,
            time_limit_per_job=time_limit_per_job,
            compute_time_limit=compute_time_limit,
            min_iteration_per_job=min_iteration_per_job,
            max_iteration_per_job=max_iteration_per_job,
            total_max_jobs=total_max_jobs,
        )

        # Re-expose base-class private attributes under public names for
        # internal use (the base stores them with leading underscores).
        self.execution_queue = self._execution_queue
        self.min_iterations = min_iteration_per_job
        self.max_iterations = max_iteration_per_job
        self.num_concurrent_workers = self._num_concurrent_workers

        # Objective details — ClearML may store title/series as a list for
        # multi-objective setups; only the first objective is used here.
        if isinstance(self._objective_metric.title, list):
            self.metric_title = self._objective_metric.title[0]
        else:
            self.metric_title = self._objective_metric.title

        if isinstance(self._objective_metric.series, list):
            self.metric_series = self._objective_metric.series[0]
        else:
            self.metric_series = self._objective_metric.series

        # Objective direction: ClearML versions differ on attribute name
        # ('sign' vs 'order') and may wrap the value in a list; default 'max'.
        objective_sign = getattr(self._objective_metric, "sign", None) or getattr(self._objective_metric, "order", None)
        if isinstance(objective_sign, list):
            objective_sign = objective_sign[0] if objective_sign else "max"
        if objective_sign is None:
            objective_sign = "max"
        self.maximize_metric = str(objective_sign).lower() in ("max", "max_global")

        # Build the ConfigSpace search space from the ClearML parameters
        # (names are encoded to survive '/' characters).
        self.config_space = ConfigurationSpace(seed=42)
        for p in self._hyper_parameters:
            cs_hp = _convert_param_to_cs(p)
            self.config_space.add(cs_hp)

        # Configure SMAC Scenario.
        # NOTE(review): walltime_limit=None relies on the installed SMAC
        # accepting None when compute_time_limit is unset — confirm.
        scenario = Scenario(
            configspace=self.config_space,
            n_trials=self.total_max_jobs,
            min_budget=float(self.min_iterations),
            max_budget=float(self.max_iterations),
            walltime_limit=(self.compute_time_limit * 60) if self.compute_time_limit else None,
            deterministic=True,
        )

        # Build the Successive Halving intensifier (NOT Hyperband).
        # Hyperband runs multiple brackets with different starting budgets;
        # Successive Halving starts ALL configs at min_budget and promotes
        # only the top 1/eta each round. 'eta' may be overridden via kwargs.
        eta = smac_kwargs.pop("eta", 3)
        intensifier = SuccessiveHalving(scenario=scenario, eta=eta, **smac_kwargs)

        # SMAC is driven in ask/tell mode by this strategy, so the facade's
        # target_function is a dummy that is never actually executed.
        self.smac = MultiFidelityFacade(
            scenario=scenario,
            target_function=lambda config, budget, seed: 0.0,
            intensifier=intensifier,
            overwrite=True,
        )

        # Bookkeeping for in-flight and finished ClearML trials.
        self.running_tasks = {}      # task_id -> trial info
        self.task_start_times = {}   # task_id -> start time (for timeout)
        self.completed_results = []
        # Seed with the worst possible score for the chosen direction.
        self.best_score_so_far = float("-inf") if self.maximize_metric else float("inf")
        self.time_limit_per_job = time_limit_per_job  # minutes

        # Checkpoint continuation tracking: config_key -> {budget: task_id}.
        # Used to find the previous task's checkpoint when promoting a config.
        self.config_to_tasks = {}

        # Manual Successive Halving control state.
        self.eta = eta
        self.current_budget = float(self.min_iterations)
        self.configs_at_budget = {}      # budget -> list of (config, score, trial)
        self.pending_configs = []        # (trial, prev_task_id) awaiting launch at current_budget
        self.evaluated_at_budget = []    # (config, score, trial, task_id) for current budget
        self.smac_asked_configs = set()  # configs SMAC has already proposed

        # Size the first rung so the whole SH schedule fits in total_max_jobs.
        # With eta=3 the rung sizes are n, n/3, n/9, ... — count the rungs
        # that fit between min and max budget, then solve
        # total = n * sum(1/eta^i) for the initial rung size n.
        num_rungs = 1
        b = float(self.min_iterations)
        while b * eta <= self.max_iterations:
            num_rungs += 1
            b *= eta

        series_sum = sum(1.0 / (eta**i) for i in range(num_rungs))
        self.initial_rung_size = int(self.total_max_jobs / series_sum)
        # Never start a rung smaller than the worker pool.
        self.initial_rung_size = max(self.initial_rung_size, self.num_concurrent_workers)
        self.configs_needed_for_rung = self.initial_rung_size  # still to collect
        self.rung_closed = False  # all configs for the current rung collected?
|
||||||
|
|
||||||
|
    @retry_on_error(max_retries=3, initial_delay=5.0, exceptions=(ValueError, SendError, ConnectionError))
    def _get_task_safe(self, task_id: str):
        """Fetch a ClearML task by id, retrying on transient server errors.

        The retry decorator re-attempts up to 3 times (5 s initial backoff)
        on connection-level failures, so callers can treat a ``None`` or
        raised result as a genuine lookup failure rather than a blip.

        Args:
            task_id: ClearML task identifier to look up.

        Returns:
            The ``Task`` object, or whatever ``Task.get_task`` returns when
            the id is unknown (callers check for ``None``).
        """
        return Task.get_task(task_id=task_id)
|
||||||
|
|
||||||
|
    @retry_on_error(max_retries=3, initial_delay=5.0, exceptions=(ValueError, SendError, ConnectionError))
    def _launch_task(self, config: dict, budget: float, prev_task_id: str | None = None):
        """Launch a task with retry logic for robustness.

        Clones the base training task, overrides its hyperparameters with
        ``config``, sets the epoch budget, optionally points the worker at a
        previous task's checkpoint, and enqueues the clone for execution.

        Args:
            config: Hyperparameter configuration dict
            budget: Number of epochs to train
            prev_task_id: Optional task ID from previous budget to continue from (checkpoint)

        Returns:
            The cloned ``Task`` on success, or ``None`` when the base task
            could not be fetched (callers treat this as a failed launch).
        """
        base = self._get_task_safe(task_id=self._base_task_id)
        if base is None:
            return None

        clone = Task.clone(
            source_task=base,
            name=f"HPO Trial - {base.name}",
            parent=Task.current_task().id,  # Set the current HPO task as parent
        )
        # Override hyperparameters
        for k, v in config.items():
            # Decode parameter name back to original (with slashes)
            original_name = _decode_param_name(k)
            # Convert numpy types to Python built-in types so ClearML can
            # serialize them (numpy scalars expose .item()).
            if hasattr(v, "item"):  # numpy scalar
                param_value = v.item()
            elif isinstance(v, int | float | str | bool):
                param_value = type(v)(v)  # Ensure it's the built-in type
            else:
                param_value = v
            clone.set_parameter(original_name, param_value)
        # Override epochs budget if multi-fidelity; otherwise pin the full
        # training length so every trial is comparable.
        if self.max_iterations != self.min_iterations:
            clone.set_parameter("Hydra/training.max_epochs", int(budget))
        else:
            clone.set_parameter("Hydra/training.max_epochs", int(self.max_iterations))

        # If we have a previous task, pass its ID so the worker can download the checkpoint
        if prev_task_id:
            clone.set_parameter("Hydra/training.resume_from_task_id", prev_task_id)

        Task.enqueue(task=clone, queue_name=self.execution_queue)
        # Track start time for timeout enforcement
        self.task_start_times[clone.id] = time.time()
        return clone
|
||||||
|
|
||||||
|
    def start(self):
        """Run the manual Successive Halving optimization loop.

        Repeats until ``total_max_jobs`` is exhausted or all budget rungs
        complete:

        1. When a rung finishes (no running/pending work and results exist),
           sort the rung's results and promote the top ``1/eta`` configs to
           the next budget, carrying each one's previous task id for
           checkpoint continuation.
        2. Fill the first rung by asking SMAC for new configurations until
           ``initial_rung_size`` is reached.
        3. Launch pending configs up to ``num_concurrent_workers``.
        4. Poll running tasks: collect completed/failed/stopped ones, stop
           and collect timed-out ones, and tolerate up to 5 consecutive
           status-check failures per task before declaring it lost.
        5. Report scores/progress scalars to the controller task.

        Returns:
            ``self.completed_results`` — one dict per finished trial.
        """
        controller = Task.current_task()
        total_launched = 0

        # Keep launching & collecting until budget exhausted
        while total_launched < self.total_max_jobs:
            # Check if current budget rung is complete BEFORE asking for new trials
            # (no running tasks, no pending configs, and we have results for this budget)
            if not self.running_tasks and not self.pending_configs and self.evaluated_at_budget:
                # Rung complete! Promote top performers to next budget

                # Store results for this budget
                self.configs_at_budget[self.current_budget] = self.evaluated_at_budget.copy()

                # Sort by score (best first)
                sorted_configs = sorted(
                    self.evaluated_at_budget,
                    key=lambda x: x[1],  # score
                    reverse=self.maximize_metric,
                )

                # Print rung results
                # NOTE(review): this loop is a no-op placeholder — the print
                # body appears to have been stripped; confirm intent.
                for _i, (_cfg, _score, _tri, _task_id) in enumerate(sorted_configs[:5], 1):
                    pass

                # Move to next budget?
                next_budget = self.current_budget * self.eta
                if next_budget <= self.max_iterations:
                    # How many to promote (top 1/eta)
                    n_promote = max(1, len(sorted_configs) // self.eta)
                    promoted = sorted_configs[:n_promote]

                    # Update budget and reset for next rung
                    self.current_budget = next_budget
                    self.evaluated_at_budget = []
                    self.configs_needed_for_rung = 0  # promoted configs are all we need
                    self.rung_closed = True  # rung is pre-filled with promoted configs

                    # Re-queue promoted configs with new budget
                    # Include the previous task ID for checkpoint continuation
                    for _cfg, _score, old_trial, prev_task_id in promoted:
                        # TrialInfo is frozen, so build a fresh one with the
                        # promoted budget instead of mutating the old trial.
                        new_trial = TrialInfo(
                            config=old_trial.config,
                            instance=old_trial.instance,
                            seed=old_trial.seed,
                            budget=self.current_budget,
                        )
                        # Store as tuple: (trial, prev_task_id)
                        self.pending_configs.append((new_trial, prev_task_id))
                else:
                    # All budgets complete
                    break

            # Fill pending_configs with new trials ONLY if we haven't closed this rung yet
            # For the first rung: ask SMAC for initial_rung_size configs total
            # For subsequent rungs: only use promoted configs (rung is already closed)
            while (
                not self.rung_closed
                and len(self.pending_configs) + len(self.evaluated_at_budget) + len(self.running_tasks)
                < self.initial_rung_size
                and total_launched < self.total_max_jobs
            ):
                trial = self.smac.ask()
                if trial is None:
                    self.rung_closed = True
                    break
                # Create new trial with forced budget (TrialInfo is frozen, can't modify)
                trial_with_budget = TrialInfo(
                    config=trial.config,
                    instance=trial.instance,
                    seed=trial.seed,
                    budget=self.current_budget,
                )
                # Dedupe on the sorted config items so SMAC can't enqueue the
                # same configuration twice in one rung.
                cfg_key = str(sorted(trial.config.items()))
                if cfg_key not in self.smac_asked_configs:
                    self.smac_asked_configs.add(cfg_key)
                    # Store as tuple: (trial, None) - no previous task for new configs
                    self.pending_configs.append((trial_with_budget, None))

            # Check if we've collected enough configs for this rung
            if (
                not self.rung_closed
                and len(self.pending_configs) + len(self.evaluated_at_budget) + len(self.running_tasks)
                >= self.initial_rung_size
            ):
                self.rung_closed = True

            # Launch pending configs up to concurrent limit
            while self.pending_configs and len(self.running_tasks) < self.num_concurrent_workers:
                # Unpack tuple: (trial, prev_task_id)
                trial, prev_task_id = self.pending_configs.pop(0)
                t = self._launch_task(trial.config, self.current_budget, prev_task_id=prev_task_id)
                if t is None:
                    # Launch failed, mark trial as failed and continue
                    # Tell SMAC this trial failed with worst possible score
                    # NOTE(review): SMAC minimizes cost, so the WORST cost is
                    # +inf in both directions. The minimize branch here reports
                    # -inf (the best possible cost) — likely should be +inf.
                    # Compare the failed_to_check handling below, which derives
                    # cost via `-res if maximize else res` and yields +inf.
                    cost = float("inf") if self.maximize_metric else float("-inf")
                    self.smac.tell(trial, TrialValue(cost=cost))
                    total_launched += 1
                    continue
                self.running_tasks[t.id] = trial

                # Track which task ID was used for this config at this budget
                cfg_key = str(sorted(trial.config.items()))
                if cfg_key not in self.config_to_tasks:
                    self.config_to_tasks[cfg_key] = {}
                self.config_to_tasks[cfg_key][self.current_budget] = t.id

                total_launched += 1

            if not self.running_tasks and not self.pending_configs:
                break

            # Poll for finished or timed out
            done = []
            timed_out = []
            failed_to_check = []
            for tid, _tri in self.running_tasks.items():
                try:
                    task = self._get_task_safe(task_id=tid)
                    if task is None:
                        failed_to_check.append(tid)
                        continue

                    st = task.get_status()

                    # Check if task completed normally
                    if st == Task.TaskStatusEnum.completed or st in (
                        Task.TaskStatusEnum.failed,
                        Task.TaskStatusEnum.stopped,
                    ):
                        done.append(tid)
                    # Check for timeout (if time limit is set)
                    elif self.time_limit_per_job and tid in self.task_start_times:
                        elapsed_minutes = (time.time() - self.task_start_times[tid]) / 60.0
                        if elapsed_minutes > self.time_limit_per_job:
                            with contextlib.suppress(Exception):
                                task.mark_stopped(force=True)
                            timed_out.append(tid)
                except Exception:
                    # Don't mark as failed immediately, might be transient
                    # Only mark failed after multiple consecutive failures
                    if not hasattr(self, "_task_check_failures"):
                        self._task_check_failures = {}
                    self._task_check_failures[tid] = self._task_check_failures.get(tid, 0) + 1
                    if self._task_check_failures[tid] >= 5:  # 5 consecutive failures
                        failed_to_check.append(tid)
                        del self._task_check_failures[tid]

            # Process tasks that failed to check
            for tid in failed_to_check:
                tri = self.running_tasks.pop(tid)
                if tid in self.task_start_times:
                    del self.task_start_times[tid]
                # Tell SMAC this trial failed with worst possible score
                res = float("-inf") if self.maximize_metric else float("inf")
                cost = -res if self.maximize_metric else res
                self.smac.tell(tri, TrialValue(cost=cost))
                self.completed_results.append(
                    {
                        "task_id": tid,
                        "config": tri.config,
                        "budget": tri.budget,
                        "value": res,
                        "failed": True,
                    }
                )
                # Store result with task_id for checkpoint tracking
                self.evaluated_at_budget.append((tri.config, res, tri, tid))

            # Process completed tasks
            for tid in done:
                tri = self.running_tasks.pop(tid)
                if tid in self.task_start_times:
                    del self.task_start_times[tid]

                # Clear any accumulated failures for this task
                if hasattr(self, "_task_check_failures") and tid in self._task_check_failures:
                    del self._task_check_failures[tid]

                task = self._get_task_safe(task_id=tid)
                if task is None:
                    res = float("-inf") if self.maximize_metric else float("inf")
                else:
                    res = self._get_objective(task)

                # Normalize missing/degenerate objectives to the worst score
                # for the optimization direction.
                if res is None or res == float("-inf") or res == float("inf"):
                    res = float("-inf") if self.maximize_metric else float("inf")

                cost = -res if self.maximize_metric else res
                self.smac.tell(tri, TrialValue(cost=cost))
                self.completed_results.append(
                    {
                        "task_id": tid,
                        "config": tri.config,
                        "budget": tri.budget,
                        "value": res,
                    }
                )

                # Store result for this budget rung with task_id for checkpoint tracking
                self.evaluated_at_budget.append((tri.config, res, tri, tid))

                iteration = len(self.completed_results)

                # Always report the trial score (even if it's bad)
                if res is not None and res != float("-inf") and res != float("inf"):
                    controller.get_logger().report_scalar(
                        title="Optimization", series="trial_score", value=res, iteration=iteration
                    )
                controller.get_logger().report_scalar(
                    title="Optimization",
                    series="trial_budget",
                    value=tri.budget or self.max_iterations,
                    iteration=iteration,
                )

                # Update best score tracking based on actual results
                if res is not None and res != float("-inf") and res != float("inf"):
                    if self.maximize_metric:
                        self.best_score_so_far = max(self.best_score_so_far, res)
                    elif res < self.best_score_so_far:
                        self.best_score_so_far = res

                # Always report best score so far (shows flat line when no improvement)
                if self.best_score_so_far != float("-inf") and self.best_score_so_far != float("inf"):
                    controller.get_logger().report_scalar(
                        title="Optimization", series="best_score", value=self.best_score_so_far, iteration=iteration
                    )

                # Report running statistics
                valid_scores = [
                    r["value"]
                    for r in self.completed_results
                    if r["value"] is not None and r["value"] != float("-inf") and r["value"] != float("inf")
                ]
                if valid_scores:
                    controller.get_logger().report_scalar(
                        title="Optimization",
                        series="mean_score",
                        value=sum(valid_scores) / len(valid_scores),
                        iteration=iteration,
                    )
                controller.get_logger().report_scalar(
                    title="Progress",
                    series="completed_trials",
                    value=len(self.completed_results),
                    iteration=iteration,
                )
                controller.get_logger().report_scalar(
                    title="Progress", series="running_tasks", value=len(self.running_tasks), iteration=iteration
                )

            # Process timed out tasks (treat as failed with current objective value)
            for tid in timed_out:
                tri = self.running_tasks.pop(tid)
                if tid in self.task_start_times:
                    del self.task_start_times[tid]

                # Clear any accumulated failures for this task
                if hasattr(self, "_task_check_failures") and tid in self._task_check_failures:
                    del self._task_check_failures[tid]

                # Try to get the last objective value before timeout
                task = self._get_task_safe(task_id=tid)
                if task is None:
                    res = float("-inf") if self.maximize_metric else float("inf")
                else:
                    res = self._get_objective(task)

                if res is None:
                    res = float("-inf") if self.maximize_metric else float("inf")
                cost = -res if self.maximize_metric else res
                self.smac.tell(tri, TrialValue(cost=cost))
                self.completed_results.append(
                    {
                        "task_id": tid,
                        "config": tri.config,
                        "budget": tri.budget,
                        "value": res,
                        "timed_out": True,
                    }
                )

                # Store timed out result for this budget rung with task_id
                self.evaluated_at_budget.append((tri.config, res, tri, tid))

            time.sleep(self.pool_period_minutes * 60)  # Fix: use correct attribute name from base class
            if self.compute_time_limit and controller.get_runtime() > self.compute_time_limit * 60:
                break

        # Finalize
        self._finalize()
        return self.completed_results
|
||||||
|
|
||||||
|
@retry_on_error(max_retries=3, initial_delay=2.0, exceptions=(SendError, ConnectionError, KeyError))
|
||||||
|
def _get_objective(self, task: Task):
|
||||||
|
"""Get objective metric value with retry logic for robustness."""
|
||||||
|
if task is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
m = task.get_last_scalar_metrics()
|
||||||
|
if not m:
|
||||||
|
return None
|
||||||
|
|
||||||
|
metric_data = m[self.metric_title][self.metric_series]
|
||||||
|
|
||||||
|
# ClearML returns dict with 'last', 'min', 'max' keys representing
|
||||||
|
# the last/min/max values of this series over ALL logged iterations.
|
||||||
|
# For snake_length/train_max: 'last' is the last logged train_max value,
|
||||||
|
# 'max' is the highest train_max ever logged during training.
|
||||||
|
|
||||||
|
# Use 'max' if maximizing (we want the best performance achieved),
|
||||||
|
# 'min' if minimizing, fallback to 'last'
|
||||||
|
if self.maximize_metric and "max" in metric_data:
|
||||||
|
result = metric_data["max"]
|
||||||
|
elif not self.maximize_metric and "min" in metric_data:
|
||||||
|
result = metric_data["min"]
|
||||||
|
else:
|
||||||
|
result = metric_data["last"]
|
||||||
|
return result
|
||||||
|
except (KeyError, Exception):
|
||||||
|
return None
|
||||||
|
|
||||||
|
    def _finalize(self):
        """Report the final best score and upload the winning configuration.

        Tries to cross-check our own best-score tracking against SMAC's
        incumbent. The broad ``except`` blocks are deliberate: finalization
        must never crash the controller, so any SMAC/runhistory failure
        degrades to uploading just ``our_best_score``.
        """
        controller = Task.current_task()
        # Report final best score
        controller.get_logger().report_text(f"Final best score: {self.best_score_so_far}")

        # Also try to get SMAC's incumbent for comparison
        try:
            incumbent = self.smac.intensifier.get_incumbent()
            if incumbent is not None:
                runhistory = self.smac.runhistory
                # Try different ways to get the cost
                incumbent_cost = None
                try:
                    incumbent_cost = runhistory.get_cost(incumbent)
                except Exception:
                    # Fallback: search through runhistory manually for the
                    # lowest cost recorded against the incumbent config.
                    for trial_key, trial_value in runhistory.items():
                        trial_config = runhistory.get_config(trial_key.config_id)
                        if trial_config == incumbent and (incumbent_cost is None or trial_value.cost < incumbent_cost):
                            incumbent_cost = trial_value.cost

                if incumbent_cost is not None:
                    # SMAC minimizes cost; undo the sign flip used when telling
                    # results so the reported score matches the metric's sense.
                    score = -incumbent_cost if self.maximize_metric else incumbent_cost
                    controller.get_logger().report_text(f"SMAC incumbent: {incumbent}, score: {score}")
                    controller.upload_artifact(
                        "best_config",
                        {"config": dict(incumbent), "score": score, "our_best_score": self.best_score_so_far},
                    )
                else:
                    controller.upload_artifact("best_config", {"our_best_score": self.best_score_so_far})
        except Exception as e:
            controller.get_logger().report_text(f"Error getting SMAC incumbent: {e}")
            controller.upload_artifact("best_config", {"our_best_score": self.best_score_so_far})
|
||||||
@@ -214,6 +214,7 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
|
|||||||
"""Offscreen render — copies one env's state from GPU to CPU."""
|
"""Offscreen render — copies one env's state from GPU to CPU."""
|
||||||
self._render_data.qpos[:] = np.asarray(self._batch_data.qpos[env_idx])
|
self._render_data.qpos[:] = np.asarray(self._batch_data.qpos[env_idx])
|
||||||
self._render_data.qvel[:] = np.asarray(self._batch_data.qvel[env_idx])
|
self._render_data.qvel[:] = np.asarray(self._batch_data.qvel[env_idx])
|
||||||
|
self._render_data.ctrl[:] = np.asarray(self._batch_data.ctrl[env_idx])
|
||||||
mujoco.mj_forward(self._mj_model, self._render_data)
|
mujoco.mj_forward(self._mj_model, self._render_data)
|
||||||
|
|
||||||
if not hasattr(self, "_offscreen_renderer"):
|
if not hasattr(self, "_offscreen_renderer"):
|
||||||
@@ -221,4 +222,10 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
|
|||||||
self._mj_model, width=640, height=480,
|
self._mj_model, width=640, height=480,
|
||||||
)
|
)
|
||||||
self._offscreen_renderer.update_scene(self._render_data)
|
self._offscreen_renderer.update_scene(self._render_data)
|
||||||
return self._offscreen_renderer.render()
|
frame = self._offscreen_renderer.render().copy()
|
||||||
|
|
||||||
|
# Import shared overlay helper from mujoco runner
|
||||||
|
from src.runners.mujoco import _draw_action_overlay
|
||||||
|
ctrl_val = float(self._render_data.ctrl[0]) if self._mj_model.nu > 0 else 0.0
|
||||||
|
_draw_action_overlay(frame, ctrl_val)
|
||||||
|
return frame
|
||||||
|
|||||||
@@ -283,4 +283,43 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
|||||||
)
|
)
|
||||||
mujoco.mj_forward(self._model, self._data[env_idx])
|
mujoco.mj_forward(self._model, self._data[env_idx])
|
||||||
self._offscreen_renderer.update_scene(self._data[env_idx])
|
self._offscreen_renderer.update_scene(self._data[env_idx])
|
||||||
return self._offscreen_renderer.render()
|
frame = self._offscreen_renderer.render().copy()
|
||||||
|
|
||||||
|
# Draw action bar overlay — shows ctrl[0] as a horizontal bar
|
||||||
|
ctrl_val = float(self._data[env_idx].ctrl[0]) if self._model.nu > 0 else 0.0
|
||||||
|
_draw_action_overlay(frame, ctrl_val)
|
||||||
|
return frame
|
||||||
|
|
||||||
|
|
||||||
|
def _draw_action_overlay(frame: np.ndarray, action: float) -> None:
|
||||||
|
"""Draw an action bar + text on a rendered frame (no OpenCV needed).
|
||||||
|
|
||||||
|
Bar is centered horizontally: green to the right (+), red to the left (-).
|
||||||
|
"""
|
||||||
|
h, w = frame.shape[:2]
|
||||||
|
|
||||||
|
# Bar geometry
|
||||||
|
bar_y = h - 30
|
||||||
|
bar_h = 16
|
||||||
|
bar_x_center = w // 2
|
||||||
|
bar_half_w = w // 4 # max half-width of the bar
|
||||||
|
bar_x_left = bar_x_center - bar_half_w
|
||||||
|
bar_x_right = bar_x_center + bar_half_w
|
||||||
|
|
||||||
|
# Background (dark grey)
|
||||||
|
frame[bar_y:bar_y + bar_h, bar_x_left:bar_x_right] = [40, 40, 40]
|
||||||
|
|
||||||
|
# Filled bar
|
||||||
|
fill_len = int(abs(action) * bar_half_w)
|
||||||
|
if action > 0:
|
||||||
|
color = [60, 200, 60] # green
|
||||||
|
x0 = bar_x_center
|
||||||
|
x1 = min(bar_x_center + fill_len, bar_x_right)
|
||||||
|
else:
|
||||||
|
color = [200, 60, 60] # red
|
||||||
|
x1 = bar_x_center
|
||||||
|
x0 = max(bar_x_center - fill_len, bar_x_left)
|
||||||
|
frame[bar_y:bar_y + bar_h, x0:x1] = color
|
||||||
|
|
||||||
|
# Center tick mark (white)
|
||||||
|
frame[bar_y:bar_y + bar_h, bar_x_center - 1:bar_x_center + 1] = [255, 255, 255]
|
||||||
571
src/runners/serial.py
Normal file
571
src/runners/serial.py
Normal file
@@ -0,0 +1,571 @@
|
|||||||
|
"""Serial runner — real hardware over USB/serial (ESP32).
|
||||||
|
|
||||||
|
Implements the BaseRunner interface for a single physical robot.
|
||||||
|
All physics come from the real world; the runner translates between
|
||||||
|
the ESP32 serial protocol and the qpos/qvel tensors that BaseRunner
|
||||||
|
and BaseEnv expect.
|
||||||
|
|
||||||
|
Serial protocol (ESP32 firmware):
|
||||||
|
Commands sent TO the ESP32:
|
||||||
|
G — start streaming state lines
|
||||||
|
H — stop streaming
|
||||||
|
M<int> — set motor PWM speed (-255 … 255)
|
||||||
|
|
||||||
|
State lines received FROM the ESP32:
|
||||||
|
S,<ms>,<enc>,<rpm>,<motor_speed>,<at_limit>,
|
||||||
|
<pend_deg>,<pend_vel>,<target_speed>,<braking>,
|
||||||
|
<enc_vel_cps>,<pendulum_ok>
|
||||||
|
(12 comma-separated fields after the ``S`` prefix)
|
||||||
|
|
||||||
|
A daemon thread continuously reads the serial stream so the control
|
||||||
|
loop never blocks on I/O.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python train.py env=rotary_cartpole runner=serial training=ppo_real
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from src.core.env import BaseEnv
|
||||||
|
from src.core.hardware import HardwareConfig, load_hardware_config
|
||||||
|
from src.core.runner import BaseRunner, BaseRunnerConfig
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
class SerialRunnerConfig(BaseRunnerConfig):
    """Configuration for serial communication with the ESP32.

    Machine-specific knobs for the SerialRunner: which port/baud to open,
    the control-loop period, and the health-check thresholds. Robot-level
    constants (encoder, safety, reset) live in hardware.yaml instead.
    """

    num_envs: int = 1  # always 1 — single physical robot
    device: str = "cpu"

    port: str = "/dev/cu.usbserial-0001"
    baud: int = 115200
    dt: float = 0.02  # control loop period (seconds), 50 Hz
    no_data_timeout: float = 2.0  # seconds of silence → disconnect
    encoder_jump_threshold: int = 200  # encoder tick jump → reboot
|
||||||
|
|
||||||
|
|
||||||
|
class SerialRunner(BaseRunner[SerialRunnerConfig]):
|
||||||
|
"""BaseRunner implementation that talks to real hardware over serial.
|
||||||
|
|
||||||
|
Maps the ESP32 serial protocol to qpos/qvel tensors so the existing
|
||||||
|
RotaryCartPoleEnv (or any compatible env) works unchanged.
|
||||||
|
|
||||||
|
qpos layout: [motor_angle_rad, pendulum_angle_rad]
|
||||||
|
qvel layout: [motor_vel_rad_s, pendulum_vel_rad_s]
|
||||||
|
"""
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# BaseRunner interface
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
    @property
    def num_envs(self) -> int:
        """Always 1: there is a single physical robot, so no batching."""
        return 1
|
||||||
|
|
||||||
|
    @property
    def device(self) -> torch.device:
        """Always CPU — tensors are tiny and come from serial I/O, not a GPU."""
        return torch.device("cpu")
|
||||||
|
|
||||||
|
    def _sim_initialize(self, config: SerialRunnerConfig) -> None:
        """Open the serial link, start the reader thread, and begin streaming.

        Order matters here: hardware config is validated first (fail fast
        before touching the port), then the port is opened and the ESP32
        given time to boot, then the daemon reader thread is started, and
        only then is streaming enabled with ``G``.

        Raises:
            FileNotFoundError: when the robot directory has no hardware.yaml.
        """
        # Load hardware description (encoder, safety, reset params).
        hw = load_hardware_config(self.env.config.robot_path)
        if hw is None:
            raise FileNotFoundError(
                f"hardware.yaml not found in {self.env.config.robot_path}. "
                "The serial runner requires a hardware config for encoder, "
                "safety, and reset parameters."
            )
        self._hw: HardwareConfig = hw
        # Encoder counts per full motor revolution (ppr × gear × quadrature).
        self._counts_per_rev: float = hw.encoder.counts_per_rev
        # 0.0 means the hard angle limit is disabled.
        self._max_motor_angle_rad: float = (
            math.radians(hw.safety.max_motor_angle_deg)
            if hw.safety.max_motor_angle_deg > 0
            else 0.0
        )

        # Joint dimensions for the rotary cartpole (motor + pendulum).
        self._nq = 2
        self._nv = 2

        # Import serial here so it's not a hard dependency for sim-only users.
        import serial as _serial

        self._serial_mod = _serial

        self.ser: _serial.Serial = _serial.Serial(
            config.port, config.baud, timeout=0.05
        )
        time.sleep(2)  # Wait for ESP32 boot.
        self.ser.reset_input_buffer()

        # Internal state tracking.
        self._rebooted: bool = False
        self._serial_disconnected: bool = False
        self._last_esp_ms: int = 0
        self._last_data_time: float = time.monotonic()
        self._last_encoder_count: int = 0
        self._streaming: bool = False

        # Latest parsed state (updated by the reader thread).
        # Keys mirror the 12 fields of the firmware's ``S,...`` state line.
        self._latest_state: dict[str, Any] = {
            "timestamp_ms": 0,
            "encoder_count": 0,
            "rpm": 0.0,
            "motor_speed": 0,
            "at_limit": False,
            "pendulum_angle": 0.0,
            "pendulum_velocity": 0.0,
            "target_speed": 0,
            "braking": False,
            "enc_vel_cps": 0.0,
            "pendulum_ok": False,
        }
        self._state_lock = threading.Lock()
        self._state_event = threading.Event()

        # Start background serial reader (daemon: never blocks shutdown).
        self._reader_running = True
        self._reader_thread = threading.Thread(
            target=self._serial_reader, daemon=True
        )
        self._reader_thread.start()

        # Start streaming.
        self._send("G")
        self._streaming = True
        self._last_data_time = time.monotonic()

        # Track wall-clock time of last step for PPO-gap detection.
        self._last_step_time: float = time.monotonic()
|
||||||
|
|
||||||
|
    def _sim_step(
        self, actions: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Send one motor command, wait one dt, and read back sensor state.

        Args:
            actions: shape (1, nu) tensor; only actions[0, 0] is used and is
                interpreted as a normalised command in [-1, 1].

        Returns:
            (qpos, qvel) tensors of shape (1, 2):
            qpos = [motor_angle_rad, pendulum_angle_rad],
            qvel = [motor_vel_rad_s, pendulum_vel_rad_s].
        """
        now = time.monotonic()

        # Detect PPO update gap: if more than 0.5s since last step,
        # the optimizer was running and no motor commands were sent.
        # Trigger a full reset so the robot starts from a clean state.
        gap = now - self._last_step_time
        if gap > 0.5:
            logger.info(
                "PPO update gap detected (%.1f s) — resetting before resuming.",
                gap,
            )
            self._send("M0")
            all_ids = torch.arange(self.num_envs, device=self.device)
            self._sim_reset(all_ids)
            self.step_counts.zero_()

        step_start = time.monotonic()

        # Map normalised action [-1, 1] → PWM [-255, 255].
        action_val = float(actions[0, 0].clamp(-1.0, 1.0))
        motor_speed = int(action_val * 255)
        self._send(f"M{motor_speed}")

        # Enforce dt wall-clock timing.
        elapsed = time.monotonic() - step_start
        remaining = self.config.dt - elapsed
        if remaining > 0:
            time.sleep(remaining)

        # Read latest sensor data (non-blocking — dt sleep ensures freshness).
        state = self._read_state()

        # Convert encoder counts → radians using counts_per_rev.
        motor_angle = (
            state["encoder_count"] / self._counts_per_rev * 2.0 * math.pi
        )
        motor_vel = (
            state["enc_vel_cps"] / self._counts_per_rev * 2.0 * math.pi
        )
        # Firmware reports pendulum angle/velocity in degrees.
        pendulum_angle = math.radians(state["pendulum_angle"])
        pendulum_vel = math.radians(state["pendulum_velocity"])

        # Cache motor angle for safety check in step() — avoids a second read.
        self._last_motor_angle_rad = motor_angle
        self._last_step_time = time.monotonic()

        qpos = torch.tensor(
            [[motor_angle, pendulum_angle]], dtype=torch.float32
        )
        qvel = torch.tensor(
            [[motor_vel, pendulum_vel]], dtype=torch.float32
        )
        return qpos, qvel
|
||||||
|
|
||||||
|
    def _sim_reset(
        self, env_ids: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Physically reset the robot: re-centre the motor and let the pendulum settle.

        Args:
            env_ids: ignored in practice — there is only one physical env.

        Returns:
            (qpos, qvel) tensors of shape (1, 2) for the settled state.

        Raises:
            RuntimeError: when the ESP32 rebooted or disconnected — the
                encoder zero reference is lost and cannot be recovered in
                software.
        """
        # If ESP32 rebooted or disconnected, we can't recover.
        if self._rebooted or self._serial_disconnected:
            raise RuntimeError(
                "ESP32 rebooted or disconnected during training! "
                "Encoder center is lost. "
                "Please re-center the motor manually and restart."
            )

        # Stop motor and restart streaming (H → pause, G → resume) so the
        # state event reflects only fresh post-reset data.
        self._send("M0")
        self._send("H")
        self._streaming = False
        time.sleep(0.05)
        self._state_event.clear()
        self._send("G")
        self._streaming = True
        self._last_data_time = time.monotonic()
        time.sleep(0.05)

        # Physically return the motor to the centre position.
        self._drive_to_center()

        # Wait until the pendulum settles.
        self._wait_for_pendulum_still()

        # Refresh data timer so health checks don't false-positive.
        self._last_data_time = time.monotonic()

        # Read settled state and return as qpos/qvel.
        state = self._read_state_blocking()
        motor_angle = (
            state["encoder_count"] / self._counts_per_rev * 2.0 * math.pi
        )
        motor_vel = (
            state["enc_vel_cps"] / self._counts_per_rev * 2.0 * math.pi
        )
        pendulum_angle = math.radians(state["pendulum_angle"])
        pendulum_vel = math.radians(state["pendulum_velocity"])

        qpos = torch.tensor(
            [[motor_angle, pendulum_angle]], dtype=torch.float32
        )
        qvel = torch.tensor(
            [[motor_vel, pendulum_vel]], dtype=torch.float32
        )
        return qpos, qvel
|
||||||
|
|
||||||
|
    def _sim_close(self) -> None:
        """Shut down cleanly: stop streaming and the motor, join the reader, close the port."""
        self._reader_running = False
        self._streaming = False
        self._send("H")  # Stop streaming.
        self._send("M0")  # Stop motor.
        time.sleep(0.1)  # Give the ESP32 time to process the commands.
        self._reader_thread.join(timeout=1.0)
        self.ser.close()
        # Renderer only exists if the digital-twin viz was ever used.
        if hasattr(self, "_offscreen_renderer") and self._offscreen_renderer is not None:
            self._offscreen_renderer.close()
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# MuJoCo digital-twin rendering
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _ensure_viz_model(self) -> None:
    """Lazily load the MuJoCo model for visualisation (digital twin).

    Reuses the same URDF + robot.yaml that the MuJoCoRunner would use,
    but only for rendering — no physics stepping.
    """
    # Idempotent: the presence of _viz_model means setup already ran.
    if hasattr(self, "_viz_model"):
        return

    # Imported lazily so serial-only runs never require MuJoCo installed.
    import mujoco
    from src.runners.mujoco import MuJoCoRunner

    self._viz_model = MuJoCoRunner._load_model(self.env.robot)
    self._viz_data = mujoco.MjData(self._viz_model)
    # Renderer itself is created on the first render() call.
    self._offscreen_renderer = None
|
||||||
|
def _sync_viz(self) -> None:
    """Copy current serial sensor state into the MuJoCo viz model."""
    import mujoco  # Lazy: only needed when actually rendering.

    self._ensure_viz_model()
    state = self._read_state()

    # Set joint positions from serial data.
    # Encoder counts → motor joint angle in radians.
    motor_angle = (
        state["encoder_count"] / self._counts_per_rev * 2.0 * math.pi
    )
    pendulum_angle = math.radians(state["pendulum_angle"])
    # NOTE(review): assumes qpos layout [motor, pendulum] — matches the
    # 2-DOF conversions used elsewhere in this runner; confirm vs model.
    self._viz_data.qpos[0] = motor_angle
    self._viz_data.qpos[1] = pendulum_angle

    # Set joint velocities (for any velocity-dependent visuals).
    motor_vel = (
        state["enc_vel_cps"] / self._counts_per_rev * 2.0 * math.pi
    )
    pendulum_vel = math.radians(state["pendulum_velocity"])
    self._viz_data.qvel[0] = motor_vel
    self._viz_data.qvel[1] = pendulum_vel

    # Forward kinematics (updates body positions for rendering).
    mujoco.mj_forward(self._viz_model, self._viz_data)
|
||||||
|
def render(self, env_idx: int = 0) -> np.ndarray:
    """Offscreen render of the digital-twin MuJoCo model.

    Called by VideoRecordingTrainer during training to capture frames.
    Returns a 640×480 RGB frame as a numpy array. `env_idx` is accepted
    for interface compatibility; this runner has a single environment.
    """
    import mujoco  # Lazy: only needed when rendering.

    # Push the latest real-robot sensor state into the viz model first.
    self._sync_viz()

    # Renderer is created once and reused across frames.
    if self._offscreen_renderer is None:
        self._offscreen_renderer = mujoco.Renderer(
            self._viz_model, width=640, height=480,
        )
    self._offscreen_renderer.update_scene(self._viz_data)
    # .copy() so the caller keeps a frame that outlives the render buffer.
    return self._offscreen_renderer.render().copy()
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Override step() for runner-level safety
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def step(
    self, actions: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict[str, Any]]:
    """Step the real robot with runner-level safety checks.

    Wraps the BaseRunner step with three protections:
    1. reboot/disconnect detected BEFORE stepping → immediate terminal
       observation with a fixed -100 penalty,
    2. connection-health check AFTER stepping,
    3. hard motor-angle limit check (uses the value cached by _sim_step).
    The motor is always commanded to stop ("M0") on any episode end.
    """
    # Check for ESP32 reboot / disconnect BEFORE stepping.
    if self._rebooted or self._serial_disconnected:
        self._send("M0")
        # Return a terminal observation with penalty.
        qpos, qvel = self._make_current_state()
        state = self.env.build_state(qpos, qvel)
        obs = self.env.compute_observations(state)
        reward = torch.tensor([[-100.0]])
        terminated = torch.tensor([[True]])
        truncated = torch.tensor([[False]])
        return obs, reward, terminated, truncated, {"reboot_detected": True}

    # Normal step via BaseRunner (calls _sim_step → env logic).
    obs, rewards, terminated, truncated, info = super().step(actions)

    # Check connection health after stepping.
    if not self._check_connection_health():
        self._send("M0")
        terminated = torch.tensor([[True]])
        rewards = torch.tensor([[-100.0]])
        info["reboot_detected"] = True

    # Check motor angle against hard safety limit.
    # Uses the cached value from _sim_step — no extra serial read.
    if self._max_motor_angle_rad > 0:
        motor_angle = abs(getattr(self, "_last_motor_angle_rad", 0.0))
        if motor_angle >= self._max_motor_angle_rad:
            self._send("M0")
            terminated = torch.tensor([[True]])
            rewards = torch.tensor([[-100.0]])
            info["motor_limit_exceeded"] = True

    # Always stop motor on episode end.
    if terminated.any() or truncated.any():
        self._send("M0")

    return obs, rewards, terminated, truncated, info
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Serial helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _send(self, cmd: str) -> None:
    """Send a newline-terminated command to the ESP32.

    Write failures mark the connection as lost instead of raising, so
    safety stops (e.g. "M0") never crash the caller mid-episode.
    """
    try:
        self.ser.write(f"{cmd}\n".encode())
    except (OSError, self._serial_mod.SerialException):
        self._serial_disconnected = True
|
||||||
|
def _serial_reader(self) -> None:
    """Background thread: continuously read and parse serial lines.

    Publishes each parsed ``S,…`` sample into ``self._latest_state``
    (lock-protected) and signals ``self._state_event`` so blocking
    readers wake up. Also performs three reboot-detection heuristics:
    the explicit READY banner, a backwards timestamp jump, and an
    encoder count snapping back to ~0 from far away.
    """
    while self._reader_running:
        try:
            if self.ser.in_waiting:
                line = (
                    self.ser.readline()
                    .decode("utf-8", errors="ignore")
                    .strip()
                )

                # Detect ESP32 reboot: it prints READY on startup.
                if line.startswith("READY"):
                    self._rebooted = True
                    logger.critical("ESP32 reboot detected: %s", line)
                    continue

                if line.startswith("S,"):
                    parts = line.split(",")
                    # Require the full 12-field frame; short frames are dropped.
                    if len(parts) >= 12:
                        esp_ms = int(parts[1])
                        enc = int(parts[2])

                        # Detect reboot: timestamp jumped backwards.
                        # The >5000/-3000 slack avoids false positives at boot.
                        if (
                            self._last_esp_ms > 5000
                            and esp_ms < self._last_esp_ms - 3000
                        ):
                            self._rebooted = True
                            logger.critical(
                                "ESP32 reboot detected: timestamp"
                                " %d -> %d",
                                self._last_esp_ms,
                                esp_ms,
                            )

                        # Detect reboot: encoder snapped to 0 from
                        # a far position.
                        if (
                            abs(self._last_encoder_count)
                            > self.config.encoder_jump_threshold
                            and abs(enc) < 5
                        ):
                            self._rebooted = True
                            logger.critical(
                                "ESP32 reboot detected: encoder"
                                " jumped %d -> %d",
                                self._last_encoder_count,
                                enc,
                            )

                        self._last_esp_ms = esp_ms
                        self._last_encoder_count = enc
                        # Refresh watchdog used by _check_connection_health.
                        self._last_data_time = time.monotonic()

                        parsed: dict[str, Any] = {
                            "timestamp_ms": esp_ms,
                            "encoder_count": enc,
                            "rpm": float(parts[3]),
                            "motor_speed": int(parts[4]),
                            "at_limit": bool(int(parts[5])),
                            "pendulum_angle": float(parts[6]),
                            "pendulum_velocity": float(parts[7]),
                            "target_speed": int(parts[8]),
                            "braking": bool(int(parts[9])),
                            "enc_vel_cps": float(parts[10]),
                            "pendulum_ok": bool(int(parts[11])),
                        }
                        # Publish atomically, then wake blocking readers.
                        with self._state_lock:
                            self._latest_state = parsed
                        self._state_event.set()
            else:
                time.sleep(0.001)  # Avoid busy-spinning.
        except (OSError, self._serial_mod.SerialException) as exc:
            self._serial_disconnected = True
            logger.critical("Serial connection lost: %s", exc)
            break
|
||||||
|
def _check_connection_health(self) -> bool:
    """Return True if the ESP32 connection appears healthy.

    Two failure modes: the writer/reader flagged a hard disconnect, or
    no data has arrived for longer than ``config.no_data_timeout`` while
    streaming is supposedly active (treated like a reboot so step()
    terminates the episode safely).
    """
    if self._serial_disconnected:
        logger.critical("ESP32 serial connection lost.")
        return False
    if (
        self._streaming
        and (time.monotonic() - self._last_data_time)
        > self.config.no_data_timeout
    ):
        logger.critical(
            "No data from ESP32 for %.1f s — possible crash/disconnect.",
            time.monotonic() - self._last_data_time,
        )
        self._rebooted = True
        return False
    return True
|
||||||
|
def _read_state(self) -> dict[str, Any]:
    """Return the most recent state from the reader thread (non-blocking).

    The background thread updates at ~50 Hz and `_sim_step` already
    sleeps for `dt` before calling this, so the data is always fresh.
    Returns a shallow copy so callers cannot mutate the shared dict.
    """
    with self._state_lock:
        return dict(self._latest_state)
|
||||||
|
def _read_state_blocking(self, timeout: float = 0.05) -> dict[str, Any]:
    """Wait for a fresh sample, then return it.

    Used during reset / settling where we need to guarantee we have
    a new reading (no prior dt sleep). On timeout the last-known state
    is returned anyway (wait() returning False is not treated as fatal).
    """
    # Clear first so we only observe a set() that happens after this call.
    self._state_event.clear()
    self._state_event.wait(timeout=timeout)
    with self._state_lock:
        return dict(self._latest_state)
|
||||||
|
def _make_current_state(self) -> tuple[torch.Tensor, torch.Tensor]:
    """Build qpos/qvel from current sensor data (utility).

    Blocks for a fresh sample, then converts raw sensor units to joint
    space: encoder counts → radians for the motor, degrees → radians
    for the pendulum. Returned tensors are shape (1, 2): one env,
    [motor, pendulum].
    """
    state = self._read_state_blocking()
    # Encoder counts → motor joint radians.
    motor_angle = (
        state["encoder_count"] / self._counts_per_rev * 2.0 * math.pi
    )
    motor_vel = (
        state["enc_vel_cps"] / self._counts_per_rev * 2.0 * math.pi
    )
    # Pendulum telemetry arrives in degrees / deg-per-second.
    pendulum_angle = math.radians(state["pendulum_angle"])
    pendulum_vel = math.radians(state["pendulum_velocity"])

    qpos = torch.tensor(
        [[motor_angle, pendulum_angle]], dtype=torch.float32
    )
    qvel = torch.tensor(
        [[motor_vel, pendulum_vel]], dtype=torch.float32
    )
    return qpos, qvel
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Physical reset helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _drive_to_center(self) -> None:
    """Drive the motor back toward encoder=0 using bang-bang control.

    Loops at ~20 Hz until the encoder is within the configured deadband
    or ``drive_timeout`` elapses; always leaves the motor stopped.
    """
    rc = self._hw.reset
    start = time.time()
    while time.time() - start < rc.drive_timeout:
        state = self._read_state_blocking()
        enc = state["encoder_count"]
        if abs(enc) < rc.deadband:
            break
        # Negative count → drive positive, and vice versa (bang-bang).
        speed = rc.drive_speed if enc < 0 else -rc.drive_speed
        self._send(f"M{speed}")
        time.sleep(0.05)
    self._send("M0")
    time.sleep(0.2)  # Let the arm coast to rest before the settle phase.
|
||||||
|
def _wait_for_pendulum_still(self) -> None:
    """Block until the pendulum has settled (angle and velocity near zero).

    Angle AND velocity must stay below their thresholds continuously for
    ``settle_duration``; any excursion restarts the stillness clock.
    Gives up with a warning after ``settle_timeout`` and proceeds anyway.
    """
    rc = self._hw.reset
    stable_since: float | None = None  # Start of the current quiet stretch.
    start = time.monotonic()

    while time.monotonic() - start < rc.settle_timeout:
        state = self._read_state_blocking()
        angle_ok = abs(state["pendulum_angle"]) < rc.settle_angle_deg
        vel_ok = abs(state["pendulum_velocity"]) < rc.settle_vel_dps

        if angle_ok and vel_ok:
            if stable_since is None:
                stable_since = time.monotonic()
            elif time.monotonic() - stable_since >= rc.settle_duration:
                logger.info(
                    "Pendulum settled after %.2f s",
                    time.monotonic() - start,
                )
                return
        else:
            stable_since = None  # Movement detected: restart the timer.
        time.sleep(0.02)

    logger.warning(
        "Pendulum did not fully settle within %.1f s — proceeding anyway.",
        rc.settle_timeout,
    )
|
||||||
1
src/sysid/__init__.py
Normal file
1
src/sysid/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""System identification — tune simulation parameters to match real hardware."""
|
||||||
381
src/sysid/capture.py
Normal file
381
src/sysid/capture.py
Normal file
@@ -0,0 +1,381 @@
|
|||||||
|
"""Capture a real-robot trajectory under random excitation (PRBS-style).
|
||||||
|
|
||||||
|
Connects to the ESP32 over serial, sends random PWM commands to excite
|
||||||
|
the system, and records motor + pendulum angles and velocities at ~50 Hz.
|
||||||
|
|
||||||
|
Saves a compressed numpy archive (.npz) that the optimizer can replay
|
||||||
|
in simulation to fit physics parameters.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python -m src.sysid.capture \
|
||||||
|
--robot-path assets/rotary_cartpole \
|
||||||
|
--port /dev/cu.usbserial-0001 \
|
||||||
|
--duration 20
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import structlog
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
log = structlog.get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Serial protocol helpers (mirrored from SerialRunner) ─────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_state_line(line: str) -> dict[str, Any] | None:
|
||||||
|
"""Parse an ``S,…`` state line from the ESP32."""
|
||||||
|
if not line.startswith("S,"):
|
||||||
|
return None
|
||||||
|
parts = line.split(",")
|
||||||
|
if len(parts) < 12:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return {
|
||||||
|
"timestamp_ms": int(parts[1]),
|
||||||
|
"encoder_count": int(parts[2]),
|
||||||
|
"rpm": float(parts[3]),
|
||||||
|
"motor_speed": int(parts[4]),
|
||||||
|
"at_limit": bool(int(parts[5])),
|
||||||
|
"pendulum_angle": float(parts[6]),
|
||||||
|
"pendulum_velocity": float(parts[7]),
|
||||||
|
"target_speed": int(parts[8]),
|
||||||
|
"braking": bool(int(parts[9])),
|
||||||
|
"enc_vel_cps": float(parts[10]),
|
||||||
|
"pendulum_ok": bool(int(parts[11])),
|
||||||
|
}
|
||||||
|
except (ValueError, IndexError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# ── Background serial reader ─────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class _SerialReader:
    """Minimal background reader for the ESP32 serial stream.

    Opens the port, spawns a daemon thread that parses ``S,…`` lines,
    and exposes the latest parsed sample via read()/read_blocking().
    """

    def __init__(self, port: str, baud: int = 115200):
        # Lazy import: pyserial is only required when hardware is attached.
        import serial as _serial

        self._serial_mod = _serial
        self.ser = _serial.Serial(port, baud, timeout=0.05)
        time.sleep(2)  # Wait for ESP32 boot.
        self.ser.reset_input_buffer()  # Drop boot-time chatter.

        self._latest: dict[str, Any] = {}  # Most recent parsed sample.
        self._lock = threading.Lock()  # Guards _latest.
        self._event = threading.Event()  # Signals a fresh sample arrived.
        self._running = True

        self._thread = threading.Thread(target=self._reader_loop, daemon=True)
        self._thread.start()

    def _reader_loop(self) -> None:
        """Thread body: poll the port and publish parsed state lines."""
        while self._running:
            try:
                if self.ser.in_waiting:
                    line = (
                        self.ser.readline()
                        .decode("utf-8", errors="ignore")
                        .strip()
                    )
                    parsed = _parse_state_line(line)
                    if parsed is not None:
                        with self._lock:
                            self._latest = parsed
                        self._event.set()
                else:
                    time.sleep(0.001)  # Avoid busy-spinning on an idle port.
            except (OSError, self._serial_mod.SerialException):
                log.critical("serial_lost")
                break

    def send(self, cmd: str) -> None:
        """Send a newline-terminated command; log (don't raise) on failure."""
        try:
            self.ser.write(f"{cmd}\n".encode())
        except (OSError, self._serial_mod.SerialException):
            log.critical("serial_send_failed", cmd=cmd)

    def read(self) -> dict[str, Any]:
        """Return a copy of the most recent sample (non-blocking)."""
        with self._lock:
            return dict(self._latest)

    def read_blocking(self, timeout: float = 0.1) -> dict[str, Any]:
        """Wait (up to ``timeout``) for a fresh sample, then return it."""
        self._event.clear()
        self._event.wait(timeout=timeout)
        return self.read()

    def close(self) -> None:
        """Stop streaming, stop the motor, join the thread, close the port."""
        self._running = False
        self.send("H")
        self.send("M0")
        time.sleep(0.1)  # Let the stop commands flush.
        self._thread.join(timeout=1.0)
        self.ser.close()
|
||||||
|
|
||||||
|
# ── PRBS excitation signal ───────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class _PRBSExcitation:
|
||||||
|
"""Random hold-value excitation with configurable amplitude and hold time.
|
||||||
|
|
||||||
|
At each call to ``__call__``, returns the current PWM value.
|
||||||
|
The value is held for a random duration (``hold_min``–``hold_max`` ms),
|
||||||
|
then a new random value is drawn uniformly from ``[-amplitude, +amplitude]``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
amplitude: int = 180,
|
||||||
|
hold_min_ms: int = 50,
|
||||||
|
hold_max_ms: int = 300,
|
||||||
|
):
|
||||||
|
self.amplitude = amplitude
|
||||||
|
self.hold_min_ms = hold_min_ms
|
||||||
|
self.hold_max_ms = hold_max_ms
|
||||||
|
self._current: int = 0
|
||||||
|
self._switch_time: float = 0.0
|
||||||
|
self._new_value()
|
||||||
|
|
||||||
|
def _new_value(self) -> None:
|
||||||
|
self._current = random.randint(-self.amplitude, self.amplitude)
|
||||||
|
hold_ms = random.randint(self.hold_min_ms, self.hold_max_ms)
|
||||||
|
self._switch_time = time.monotonic() + hold_ms / 1000.0
|
||||||
|
|
||||||
|
def __call__(self) -> int:
|
||||||
|
if time.monotonic() >= self._switch_time:
|
||||||
|
self._new_value()
|
||||||
|
return self._current
|
||||||
|
|
||||||
|
|
||||||
|
# ── Main capture loop ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def capture(
    robot_path: str | Path,
    port: str = "/dev/cu.usbserial-0001",
    baud: int = 115200,
    duration: float = 20.0,
    amplitude: int = 180,
    hold_min_ms: int = 50,
    hold_max_ms: int = 300,
    dt: float = 0.02,
) -> Path:
    """Run the capture procedure and return the path to the saved .npz file.

    Excites the real robot with random held PWM levels while logging
    joint state at ~1/dt Hz, then saves the trajectory for the sysid
    optimizer. The motor is always stopped on exit (including Ctrl-C).

    Parameters
    ----------
    robot_path : path to robot asset directory (contains hardware.yaml)
    port : serial port for ESP32
    baud : baud rate
    duration : capture duration in seconds
    amplitude : max PWM magnitude for excitation (0–255)
    hold_min_ms / hold_max_ms : random hold time range (ms)
    dt : target sampling period (seconds), default 50 Hz
    """
    robot_path = Path(robot_path).resolve()

    # Load hardware config for encoder conversion + safety.
    hw_yaml = robot_path / "hardware.yaml"
    if not hw_yaml.exists():
        raise FileNotFoundError(f"hardware.yaml not found in {robot_path}")
    raw_hw = yaml.safe_load(hw_yaml.read_text())
    ppr = raw_hw.get("encoder", {}).get("ppr", 11)
    gear_ratio = raw_hw.get("encoder", {}).get("gear_ratio", 30.0)
    # counts_per_rev = ppr × gear_ratio × 4 (quadrature decoding).
    counts_per_rev: float = ppr * gear_ratio * 4.0
    max_motor_deg = raw_hw.get("safety", {}).get("max_motor_angle_deg", 90.0)
    # 0 (or negative) disables the motor-angle safety margin entirely.
    max_motor_rad = math.radians(max_motor_deg) if max_motor_deg > 0 else 0.0

    # Connect.
    reader = _SerialReader(port, baud)
    excitation = _PRBSExcitation(amplitude, hold_min_ms, hold_max_ms)

    # Prepare recording buffers.
    max_samples = int(duration / dt) + 500  # headroom
    rec_time = np.zeros(max_samples, dtype=np.float64)
    rec_action = np.zeros(max_samples, dtype=np.float64)
    rec_motor_angle = np.zeros(max_samples, dtype=np.float64)
    rec_motor_vel = np.zeros(max_samples, dtype=np.float64)
    rec_pend_angle = np.zeros(max_samples, dtype=np.float64)
    rec_pend_vel = np.zeros(max_samples, dtype=np.float64)

    # Start streaming.
    reader.send("G")
    time.sleep(0.1)

    log.info(
        "capture_starting",
        port=port,
        duration=duration,
        amplitude=amplitude,
        hold_range_ms=f"{hold_min_ms}–{hold_max_ms}",
        dt=dt,
    )

    idx = 0
    t0 = time.monotonic()
    try:
        while True:
            loop_start = time.monotonic()
            elapsed = loop_start - t0
            if elapsed >= duration:
                break

            # Get excitation PWM.
            pwm = excitation()

            # Safety: reverse/zero if near motor limit.
            state = reader.read()
            if state:
                motor_angle_rad = (
                    state.get("encoder_count", 0) / counts_per_rev * 2.0 * math.pi
                )
                if max_motor_rad > 0:
                    margin = max_motor_rad * 0.85  # start braking at 85%
                    if motor_angle_rad > margin and pwm > 0:
                        pwm = -abs(pwm)  # reverse
                    elif motor_angle_rad < -margin and pwm < 0:
                        pwm = abs(pwm)  # reverse

            # Send command.
            reader.send(f"M{pwm}")

            # Wait for fresh data (leave ~5 ms slack for the blocking read).
            time.sleep(max(0, dt - (time.monotonic() - loop_start) - 0.005))
            state = reader.read_blocking(timeout=dt)

            if state:
                enc = state.get("encoder_count", 0)
                motor_angle = enc / counts_per_rev * 2.0 * math.pi
                motor_vel = (
                    state.get("enc_vel_cps", 0.0) / counts_per_rev * 2.0 * math.pi
                )
                pend_angle = math.radians(state.get("pendulum_angle", 0.0))
                pend_vel = math.radians(state.get("pendulum_velocity", 0.0))

                # Normalised action: PWM / 255 → [-1, 1]
                action_norm = pwm / 255.0

                if idx < max_samples:
                    rec_time[idx] = elapsed
                    rec_action[idx] = action_norm
                    rec_motor_angle[idx] = motor_angle
                    rec_motor_vel[idx] = motor_vel
                    rec_pend_angle[idx] = pend_angle
                    rec_pend_vel[idx] = pend_vel
                    idx += 1

            # Progress.
            # NOTE(review): if no data ever arrives, idx stays 0 and this
            # logs every loop iteration — consider guarding on idx > 0.
            if idx % 50 == 0:
                log.info(
                    "capture_progress",
                    elapsed=f"{elapsed:.1f}/{duration:.0f}s",
                    samples=idx,
                    pwm=pwm,
                )

            # Pace to dt.
            remaining = dt - (time.monotonic() - loop_start)
            if remaining > 0:
                time.sleep(remaining)

    finally:
        # Motor must stop even on exceptions / KeyboardInterrupt.
        reader.send("M0")
        reader.close()

    # Trim to actual sample count.
    rec_time = rec_time[:idx]
    rec_action = rec_action[:idx]
    rec_motor_angle = rec_motor_angle[:idx]
    rec_motor_vel = rec_motor_vel[:idx]
    rec_pend_angle = rec_pend_angle[:idx]
    rec_pend_vel = rec_pend_vel[:idx]

    # Save.
    recordings_dir = robot_path / "recordings"
    recordings_dir.mkdir(exist_ok=True)
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    out_path = recordings_dir / f"capture_{stamp}.npz"
    np.savez_compressed(
        out_path,
        time=rec_time,
        action=rec_action,
        motor_angle=rec_motor_angle,
        motor_vel=rec_motor_vel,
        pendulum_angle=rec_pend_angle,
        pendulum_vel=rec_pend_vel,
    )

    log.info(
        "capture_saved",
        path=str(out_path),
        samples=idx,
        duration_actual=f"{rec_time[-1]:.2f}s" if idx > 0 else "0s",
    )
    return out_path
|
||||||
|
|
||||||
|
# ── CLI entry point ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
    """Parse CLI arguments and run one capture session."""
    ap = argparse.ArgumentParser(
        description="Capture a real-robot trajectory for system identification."
    )
    ap.add_argument(
        "--robot-path",
        type=str,
        default="assets/rotary_cartpole",
        help="Path to robot asset directory (contains hardware.yaml)",
    )
    ap.add_argument(
        "--port",
        type=str,
        default="/dev/cu.usbserial-0001",
        help="Serial port for ESP32",
    )
    ap.add_argument("--baud", type=int, default=115200)
    ap.add_argument("--duration", type=float, default=20.0, help="Capture duration (s)")
    ap.add_argument("--amplitude", type=int, default=180, help="Max PWM magnitude (0–255)")
    ap.add_argument("--hold-min-ms", type=int, default=50, help="Min hold time (ms)")
    ap.add_argument("--hold-max-ms", type=int, default=300, help="Max hold time (ms)")
    ap.add_argument("--dt", type=float, default=0.02, help="Sample period (s)")
    ns = ap.parse_args()

    capture(
        robot_path=ns.robot_path,
        port=ns.port,
        baud=ns.baud,
        duration=ns.duration,
        amplitude=ns.amplitude,
        hold_min_ms=ns.hold_min_ms,
        hold_max_ms=ns.hold_max_ms,
        dt=ns.dt,
    )
|
||||||
|
|
||||||
|
# Allow running as a script or via `python -m src.sysid.capture`.
if __name__ == "__main__":
    main()
||||||
186
src/sysid/export.py
Normal file
186
src/sysid/export.py
Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
"""Export tuned parameters to URDF and robot.yaml files.
|
||||||
|
|
||||||
|
Reads the original files, injects the optimised parameter values,
|
||||||
|
and writes ``rotary_cartpole_tuned.urdf`` + ``robot_tuned.yaml``
|
||||||
|
alongside the originals in the robot asset directory.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import json
|
||||||
|
import shutil
|
||||||
|
import xml.etree.ElementTree as ET
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import structlog
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
log = structlog.get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def export_tuned_files(
    robot_path: str | Path,
    params: dict[str, float],
) -> tuple[Path, Path]:
    """Write tuned URDF and robot.yaml files.

    Missing keys in ``params`` leave the corresponding original values
    untouched (the ``_set_*`` helpers and the ``in params`` guards below
    all skip None/absent entries).

    Parameters
    ----------
    robot_path : robot asset directory (contains robot.yaml + *.urdf)
    params : dict of parameter name → tuned value (from optimizer)

    Returns
    -------
    (tuned_urdf_path, tuned_robot_yaml_path)
    """
    robot_path = Path(robot_path).resolve()

    # ── Load originals ───────────────────────────────────────────
    robot_yaml_path = robot_path / "robot.yaml"
    robot_cfg = yaml.safe_load(robot_yaml_path.read_text())
    urdf_path = robot_path / robot_cfg["urdf"]

    # ── Tune URDF ────────────────────────────────────────────────
    tree = ET.parse(urdf_path)
    root = tree.getroot()

    # Only the "arm" and "pendulum" links are tuned; others pass through.
    for link in root.iter("link"):
        link_name = link.get("name", "")
        inertial = link.find("inertial")
        if inertial is None:
            continue

        if link_name == "arm":
            _set_mass(inertial, params.get("arm_mass"))
            _set_com(
                inertial,
                params.get("arm_com_x"),
                params.get("arm_com_y"),
                params.get("arm_com_z"),
            )

        elif link_name == "pendulum":
            _set_mass(inertial, params.get("pendulum_mass"))
            _set_com(
                inertial,
                params.get("pendulum_com_x"),
                params.get("pendulum_com_y"),
                params.get("pendulum_com_z"),
            )
            _set_inertia(
                inertial,
                ixx=params.get("pendulum_ixx"),
                iyy=params.get("pendulum_iyy"),
                izz=params.get("pendulum_izz"),
                ixy=params.get("pendulum_ixy"),
            )

    # Write tuned URDF.
    tuned_urdf_name = urdf_path.stem + "_tuned" + urdf_path.suffix
    tuned_urdf_path = robot_path / tuned_urdf_name

    # Preserve the XML declaration and original formatting as much as possible.
    # NOTE(review): xml_declaration=True together with encoding="unicode"
    # may emit `encoding='unicode'` in the declaration, which is not a
    # valid XML encoding name — verify the written file's first line.
    ET.indent(tree, space=" ")
    tree.write(str(tuned_urdf_path), xml_declaration=True, encoding="unicode")
    log.info("tuned_urdf_written", path=str(tuned_urdf_path))

    # ── Tune robot.yaml ──────────────────────────────────────────
    tuned_cfg = copy.deepcopy(robot_cfg)

    # Point to the tuned URDF.
    tuned_cfg["urdf"] = tuned_urdf_name

    # Update actuator parameters (only the first actuator is tuned).
    if tuned_cfg.get("actuators") and len(tuned_cfg["actuators"]) > 0:
        act = tuned_cfg["actuators"][0]
        if "actuator_gear" in params:
            act["gear"] = round(params["actuator_gear"], 6)
        if "actuator_filter_tau" in params:
            act["filter_tau"] = round(params["actuator_filter_tau"], 6)
        if "motor_damping" in params:
            act["damping"] = round(params["motor_damping"], 6)

    # Update joint overrides.
    if "joints" not in tuned_cfg:
        tuned_cfg["joints"] = {}

    if "motor_joint" not in tuned_cfg["joints"]:
        tuned_cfg["joints"]["motor_joint"] = {}
    mj = tuned_cfg["joints"]["motor_joint"]
    if "motor_armature" in params:
        mj["armature"] = round(params["motor_armature"], 6)
    if "motor_frictionloss" in params:
        mj["frictionloss"] = round(params["motor_frictionloss"], 6)

    if "pendulum_joint" not in tuned_cfg["joints"]:
        tuned_cfg["joints"]["pendulum_joint"] = {}
    pj = tuned_cfg["joints"]["pendulum_joint"]
    if "pendulum_damping" in params:
        pj["damping"] = round(params["pendulum_damping"], 6)

    # Write tuned robot.yaml.
    tuned_yaml_path = robot_path / "robot_tuned.yaml"

    # Add a header comment.
    header = (
        "# Tuned robot config — generated by src.sysid.optimize\n"
        "# Original: robot.yaml\n"
        "# Run `python -m src.sysid.visualize` to compare real vs sim.\n\n"
    )
    tuned_yaml_path.write_text(
        header + yaml.dump(tuned_cfg, default_flow_style=False, sort_keys=False)
    )
    log.info("tuned_robot_yaml_written", path=str(tuned_yaml_path))

    return tuned_urdf_path, tuned_yaml_path
|
||||||
|
|
||||||
|
# ── XML helpers (shared with rollout.py) ─────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _set_mass(inertial: ET.Element, mass: float | None) -> None:
|
||||||
|
if mass is None:
|
||||||
|
return
|
||||||
|
mass_el = inertial.find("mass")
|
||||||
|
if mass_el is not None:
|
||||||
|
mass_el.set("value", str(mass))
|
||||||
|
|
||||||
|
|
||||||
|
def _set_com(
|
||||||
|
inertial: ET.Element,
|
||||||
|
x: float | None,
|
||||||
|
y: float | None,
|
||||||
|
z: float | None,
|
||||||
|
) -> None:
|
||||||
|
origin = inertial.find("origin")
|
||||||
|
if origin is None:
|
||||||
|
return
|
||||||
|
xyz = origin.get("xyz", "0 0 0").split()
|
||||||
|
if x is not None:
|
||||||
|
xyz[0] = str(x)
|
||||||
|
if y is not None:
|
||||||
|
xyz[1] = str(y)
|
||||||
|
if z is not None:
|
||||||
|
xyz[2] = str(z)
|
||||||
|
origin.set("xyz", " ".join(xyz))
|
||||||
|
|
||||||
|
|
||||||
|
def _set_inertia(
|
||||||
|
inertial: ET.Element,
|
||||||
|
ixx: float | None = None,
|
||||||
|
iyy: float | None = None,
|
||||||
|
izz: float | None = None,
|
||||||
|
ixy: float | None = None,
|
||||||
|
iyz: float | None = None,
|
||||||
|
ixz: float | None = None,
|
||||||
|
) -> None:
|
||||||
|
ine = inertial.find("inertia")
|
||||||
|
if ine is None:
|
||||||
|
return
|
||||||
|
for attr, val in [
|
||||||
|
("ixx", ixx), ("iyy", iyy), ("izz", izz),
|
||||||
|
("ixy", ixy), ("iyz", iyz), ("ixz", ixz),
|
||||||
|
]:
|
||||||
|
if val is not None:
|
||||||
|
ine.set(attr, str(val))
|
||||||
376
src/sysid/optimize.py
Normal file
376
src/sysid/optimize.py
Normal file
@@ -0,0 +1,376 @@
|
|||||||
|
"""CMA-ES optimiser — fit simulation parameters to a real-robot recording.
|
||||||
|
|
||||||
|
Minimises the trajectory-matching cost between a MuJoCo rollout and a
|
||||||
|
recorded real-robot sequence. Uses the ``cmaes`` package (pure-Python
|
||||||
|
CMA-ES with native box-constraint support).
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python -m src.sysid.optimize \
|
||||||
|
--robot-path assets/rotary_cartpole \
|
||||||
|
--recording assets/rotary_cartpole/recordings/capture_20260311_120000.npz
|
||||||
|
|
||||||
|
# Shorter run for testing:
|
||||||
|
python -m src.sysid.optimize \
|
||||||
|
--robot-path assets/rotary_cartpole \
|
||||||
|
--recording <file>.npz \
|
||||||
|
--max-generations 10 --population-size 8
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import structlog
|
||||||
|
|
||||||
|
from src.sysid.rollout import (
|
||||||
|
ROTARY_CARTPOLE_PARAMS,
|
||||||
|
ParamSpec,
|
||||||
|
bounds_arrays,
|
||||||
|
defaults_vector,
|
||||||
|
params_to_dict,
|
||||||
|
rollout,
|
||||||
|
windowed_rollout,
|
||||||
|
)
|
||||||
|
|
||||||
|
log = structlog.get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Cost function ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _angle_diff(a: np.ndarray, b: np.ndarray) -> np.ndarray:
|
||||||
|
"""Shortest signed angle difference, handling wrapping."""
|
||||||
|
return np.arctan2(np.sin(a - b), np.cos(a - b))
|
||||||
|
|
||||||
|
|
||||||
|
def _check_inertia_valid(params: dict[str, float]) -> bool:
|
||||||
|
"""Quick reject: pendulum inertia tensor must be positive-definite."""
|
||||||
|
ixx = params.get("pendulum_ixx", 6.16e-06)
|
||||||
|
iyy = params.get("pendulum_iyy", 6.16e-06)
|
||||||
|
izz = params.get("pendulum_izz", 1.23e-05)
|
||||||
|
ixy = params.get("pendulum_ixy", 6.10e-06)
|
||||||
|
det_xy = ixx * iyy - ixy * ixy
|
||||||
|
return det_xy > 0 and ixx > 0 and iyy > 0 and izz > 0
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_trajectory_cost(
    sim: dict[str, np.ndarray],
    recording: dict[str, np.ndarray],
    pos_weight: float = 1.0,
    vel_weight: float = 0.1,
) -> float:
    """Weighted mean-squared error between simulated and recorded trajectories.

    Angle errors are computed wrap-aware via ``_angle_diff``; velocity errors
    are plain differences.  Position terms are weighted by *pos_weight*,
    velocity terms by *vel_weight*.
    """
    total = 0.0
    # Position terms first (same accumulation order as a straight sum).
    for key in ("motor_angle", "pendulum_angle"):
        wrapped = _angle_diff(sim[key], recording[key])
        total += pos_weight * float(np.mean(wrapped**2))
    for key in ("motor_vel", "pendulum_vel"):
        residual = sim[key] - recording[key]
        total += vel_weight * float(np.mean(residual**2))
    return total
|
||||||
|
|
||||||
|
|
||||||
|
def cost_function(
    params_vec: np.ndarray,
    recording: dict[str, np.ndarray],
    robot_path: Path,
    specs: list[ParamSpec],
    sim_dt: float = 0.002,
    substeps: int = 10,
    pos_weight: float = 1.0,
    vel_weight: float = 0.1,
    window_duration: float = 0.5,
) -> float:
    """Trajectory-matching cost for one candidate parameter vector.

    With ``window_duration > 0`` the recording is replayed with multiple
    shooting (``windowed_rollout``): each short window is re-initialised
    from the real qpos/qvel, so early errors cannot compound across the
    full trajectory and the cost landscape stays smooth for CMA-ES.
    ``window_duration == 0`` falls back to a single open-loop ``rollout``
    (not recommended).

    Returns 1e6 for candidates whose pendulum inertia tensor fails the
    positive-definiteness pre-check or whose rollout raises.
    """
    candidate = params_to_dict(params_vec, specs)

    # Cheap reject before paying for a full simulation.
    if not _check_inertia_valid(candidate):
        return 1e6

    try:
        if window_duration > 0:
            sim = windowed_rollout(
                robot_path=robot_path,
                params=candidate,
                recording=recording,
                window_duration=window_duration,
                sim_dt=sim_dt,
                substeps=substeps,
            )
        else:
            sim = rollout(
                robot_path=robot_path,
                params=candidate,
                actions=recording["action"],
                timesteps=recording["time"],
                sim_dt=sim_dt,
                substeps=substeps,
            )
    except Exception as exc:
        # A bad parameter set can make MuJoCo blow up; penalise, don't crash.
        log.warning("rollout_failed", error=str(exc))
        return 1e6

    return _compute_trajectory_cost(sim, recording, pos_weight, vel_weight)
|
||||||
|
|
||||||
|
|
||||||
|
# ── CMA-ES optimisation loop ────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def optimize(
    robot_path: str | Path,
    recording_path: str | Path,
    specs: list[ParamSpec] | None = None,
    sigma0: float = 0.3,
    population_size: int = 20,
    max_generations: int = 1000,
    sim_dt: float = 0.002,
    substeps: int = 10,
    pos_weight: float = 1.0,
    vel_weight: float = 0.1,
    window_duration: float = 0.5,
    seed: int = 42,
) -> dict:
    """Run CMA-ES optimisation against a recording and return results.

    The search runs in a [0, 1]-normalised parameter space (one dimension
    per spec); candidates are mapped back to natural units before each
    cost evaluation.

    Returns a dict with:
        best_params: dict[str, float] — tuned values in natural units
        best_cost: float
        history: list of (generation, best_cost) tuples
        recording: str (path used)
        param_names: list[str] — spec names, in optimisation order
        defaults: dict[str, float] — the untuned starting values
        timestamp: str (ISO-8601)
    """
    from cmaes import CMA

    robot_path = Path(robot_path).resolve()
    recording_path = Path(recording_path).resolve()

    if specs is None:
        specs = ROTARY_CARTPOLE_PARAMS

    # Load recording.
    recording = dict(np.load(recording_path))
    n_samples = len(recording["time"])
    duration = recording["time"][-1] - recording["time"][0]
    # n_windows is informational here; the actual split happens per rollout.
    n_windows = max(1, int(duration / window_duration)) if window_duration > 0 else 1
    log.info(
        "recording_loaded",
        path=str(recording_path),
        samples=n_samples,
        duration=f"{duration:.1f}s",
        window_duration=f"{window_duration}s",
        n_windows=n_windows,
    )

    # Initial point (defaults) — normalised to [0, 1] for CMA-ES.
    lo, hi = bounds_arrays(specs)
    x0 = defaults_vector(specs)

    # Normalise to [0, 1] for the optimizer (better conditioned).
    span = hi - lo
    span[span == 0] = 1.0  # avoid division by zero

    def to_normed(x: np.ndarray) -> np.ndarray:
        return (x - lo) / span

    def from_normed(x_n: np.ndarray) -> np.ndarray:
        return x_n * span + lo

    x0_normed = to_normed(x0)
    bounds_normed = np.column_stack(
        [np.zeros(len(specs)), np.ones(len(specs))]
    )

    optimizer = CMA(
        mean=x0_normed,
        sigma=sigma0,
        bounds=bounds_normed,
        population_size=population_size,
        seed=seed,
    )

    best_cost = float("inf")
    best_params_vec = x0.copy()
    history: list[tuple[int, float]] = []

    log.info(
        "cmaes_starting",
        n_params=len(specs),
        population=population_size,
        max_gens=max_generations,
        sigma0=sigma0,
    )

    t0 = time.monotonic()

    for gen in range(max_generations):
        # ask/tell loop: sample a population, evaluate, feed scores back.
        solutions: list[tuple[np.ndarray, float]] = []
        for _ in range(optimizer.population_size):
            x_normed = optimizer.ask()
            x_natural = from_normed(x_normed)

            # Clip to bounds (CMA-ES can slightly exceed with sampling noise).
            x_natural = np.clip(x_natural, lo, hi)

            c = cost_function(
                x_natural,
                recording,
                robot_path,
                specs,
                sim_dt=sim_dt,
                substeps=substeps,
                pos_weight=pos_weight,
                vel_weight=vel_weight,
                window_duration=window_duration,
            )
            # tell() expects the *normalised* sample it handed out.
            solutions.append((x_normed, c))

            if c < best_cost:
                best_cost = c
                best_params_vec = x_natural.copy()

        optimizer.tell(solutions)
        history.append((gen, best_cost))

        elapsed = time.monotonic() - t0
        if gen % 5 == 0 or gen == max_generations - 1:
            log.info(
                "cmaes_generation",
                gen=gen,
                best_cost=f"{best_cost:.6f}",
                elapsed=f"{elapsed:.1f}s",
                gen_best=f"{min(c for _, c in solutions):.6f}",
            )

    total_time = time.monotonic() - t0
    best_params = params_to_dict(best_params_vec, specs)

    log.info(
        "cmaes_finished",
        best_cost=f"{best_cost:.6f}",
        total_time=f"{total_time:.1f}s",
        evaluations=max_generations * population_size,
    )

    # Log parameter comparison (tuned vs default, with percent change).
    defaults = params_to_dict(defaults_vector(specs), specs)
    for name in best_params:
        d = defaults[name]
        b = best_params[name]
        change_pct = ((b - d) / abs(d) * 100) if abs(d) > 1e-12 else 0.0
        log.info(
            "param_result",
            name=name,
            default=f"{d:.6g}",
            tuned=f"{b:.6g}",
            change=f"{change_pct:+.1f}%",
        )

    return {
        "best_params": best_params,
        "best_cost": best_cost,
        "history": history,
        "recording": str(recording_path),
        "param_names": [s.name for s in specs],
        "defaults": {s.name: s.default for s in specs},
        "timestamp": datetime.now().isoformat(),
    }
|
||||||
|
|
||||||
|
|
||||||
|
# ── CLI entry point ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
    """CLI entry point: run the optimisation, save results, export tuned files.

    Writes ``<robot-path>/sysid_result.json`` (history replaced by a small
    summary) and, unless ``--no-export`` is given, hands the best parameters
    to ``export_tuned_files``.
    """
    parser = argparse.ArgumentParser(
        description="Fit simulation parameters to a real-robot recording (CMA-ES)."
    )
    parser.add_argument(
        "--robot-path",
        type=str,
        default="assets/rotary_cartpole",
        help="Path to robot asset directory",
    )
    parser.add_argument(
        "--recording",
        type=str,
        required=True,
        help="Path to .npz recording file",
    )
    parser.add_argument("--sigma0", type=float, default=0.3)
    parser.add_argument("--population-size", type=int, default=20)
    parser.add_argument("--max-generations", type=int, default=200)
    parser.add_argument("--sim-dt", type=float, default=0.002)
    parser.add_argument("--substeps", type=int, default=10)
    parser.add_argument("--pos-weight", type=float, default=1.0)
    parser.add_argument("--vel-weight", type=float, default=0.1)
    parser.add_argument(
        "--window-duration",
        type=float,
        default=0.5,
        help="Shooting window length in seconds (0 = open-loop, default 0.5)",
    )
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--no-export",
        action="store_true",
        help="Skip exporting tuned files (results JSON only)",
    )
    args = parser.parse_args()

    result = optimize(
        robot_path=args.robot_path,
        recording_path=args.recording,
        sigma0=args.sigma0,
        population_size=args.population_size,
        max_generations=args.max_generations,
        sim_dt=args.sim_dt,
        substeps=args.substeps,
        pos_weight=args.pos_weight,
        vel_weight=args.vel_weight,
        window_duration=args.window_duration,
        seed=args.seed,
    )

    # Save results JSON.
    robot_path = Path(args.robot_path).resolve()
    result_path = robot_path / "sysid_result.json"
    # Convert numpy types for JSON serialisation.
    # (Per-generation history is dropped; only a summary is kept.)
    result_json = {
        k: v for k, v in result.items() if k != "history"
    }
    result_json["history_summary"] = {
        "first_cost": result["history"][0][1] if result["history"] else None,
        "final_cost": result["history"][-1][1] if result["history"] else None,
        "generations": len(result["history"]),
    }
    # default=str stringifies any remaining non-JSON values (e.g. numpy scalars).
    result_path.write_text(json.dumps(result_json, indent=2, default=str))
    log.info("results_saved", path=str(result_path))

    # Export tuned files unless --no-export.
    if not args.no_export:
        from src.sysid.export import export_tuned_files

        export_tuned_files(
            robot_path=args.robot_path,
            params=result["best_params"],
        )
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: `python -m src.sysid.optimize ...`.
if __name__ == "__main__":
    main()
|
||||||
477
src/sysid/rollout.py
Normal file
477
src/sysid/rollout.py
Normal file
@@ -0,0 +1,477 @@
|
|||||||
|
"""Deterministic simulation replay — roll out recorded actions in MuJoCo.
|
||||||
|
|
||||||
|
Given a parameter vector and a recorded action sequence, builds a MuJoCo
|
||||||
|
model with overridden physics parameters, replays the actions, and returns
|
||||||
|
the simulated trajectory for comparison with the real recording.
|
||||||
|
|
||||||
|
This module is the inner loop of the CMA-ES optimizer: it is called once
|
||||||
|
per candidate parameter vector per generation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import dataclasses
|
||||||
|
import math
|
||||||
|
import tempfile
|
||||||
|
import xml.etree.ElementTree as ET
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import mujoco
|
||||||
|
import numpy as np
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tunable parameter specification ──────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
class ParamSpec:
    """Specification for a single tunable physics parameter.

    Each spec is one dimension of the optimiser's search space; default
    and bounds are in the parameter's natural units.
    """

    name: str  # identifier, e.g. "pendulum_mass"
    default: float  # starting value used to seed the optimiser
    lower: float  # lower search bound
    upper: float  # upper search bound
    # NOTE(review): log_scale is not consumed by the visible rollout/optimize
    # code — confirm whether log-space optimisation is actually wired up.
    log_scale: bool = False  # optimise in log-space (masses, inertias)
|
||||||
|
|
||||||
|
|
||||||
|
# Default parameter specs for the rotary cartpole.
|
||||||
|
# Order matters: the optimizer maps a flat vector to these specs.
|
||||||
|
# Each entry: ParamSpec(name, default, lower, upper, log_scale=...).
ROTARY_CARTPOLE_PARAMS: list[ParamSpec] = [
    # ── Arm link (URDF) ──────────────────────────────────────────
    ParamSpec("arm_mass", 0.010, 0.003, 0.05, log_scale=True),
    ParamSpec("arm_com_x", 0.00005, -0.02, 0.02),
    ParamSpec("arm_com_y", 0.0065, -0.01, 0.02),
    ParamSpec("arm_com_z", 0.00563, -0.01, 0.02),
    # ── Pendulum link (URDF) ─────────────────────────────────────
    ParamSpec("pendulum_mass", 0.015, 0.005, 0.05, log_scale=True),
    ParamSpec("pendulum_com_x", 0.1583, 0.05, 0.25),
    ParamSpec("pendulum_com_y", -0.0983, -0.20, 0.0),
    ParamSpec("pendulum_com_z", 0.0, -0.05, 0.05),
    ParamSpec("pendulum_ixx", 6.16e-06, 1e-07, 1e-04, log_scale=True),
    ParamSpec("pendulum_iyy", 6.16e-06, 1e-07, 1e-04, log_scale=True),
    ParamSpec("pendulum_izz", 1.23e-05, 1e-07, 1e-04, log_scale=True),
    ParamSpec("pendulum_ixy", 6.10e-06, -1e-04, 1e-04),
    # ── Actuator / joint dynamics (robot.yaml) ───────────────────
    ParamSpec("actuator_gear", 0.064, 0.01, 0.2, log_scale=True),
    ParamSpec("actuator_filter_tau", 0.03, 0.005, 0.15),
    ParamSpec("motor_damping", 0.003, 1e-4, 0.05, log_scale=True),
    ParamSpec("pendulum_damping", 0.0001, 1e-5, 0.01, log_scale=True),
    ParamSpec("motor_armature", 0.0001, 1e-5, 0.01, log_scale=True),
    ParamSpec("motor_frictionloss", 0.03, 0.001, 0.1, log_scale=True),
]
|
||||||
|
|
||||||
|
|
||||||
|
def params_to_dict(
|
||||||
|
values: np.ndarray, specs: list[ParamSpec] | None = None
|
||||||
|
) -> dict[str, float]:
|
||||||
|
"""Convert a flat parameter vector to a named dict."""
|
||||||
|
if specs is None:
|
||||||
|
specs = ROTARY_CARTPOLE_PARAMS
|
||||||
|
return {s.name: float(values[i]) for i, s in enumerate(specs)}
|
||||||
|
|
||||||
|
|
||||||
|
def defaults_vector(specs: list[ParamSpec] | None = None) -> np.ndarray:
|
||||||
|
"""Return the default parameter vector (in natural space)."""
|
||||||
|
if specs is None:
|
||||||
|
specs = ROTARY_CARTPOLE_PARAMS
|
||||||
|
return np.array([s.default for s in specs], dtype=np.float64)
|
||||||
|
|
||||||
|
|
||||||
|
def bounds_arrays(
|
||||||
|
specs: list[ParamSpec] | None = None,
|
||||||
|
) -> tuple[np.ndarray, np.ndarray]:
|
||||||
|
"""Return (lower, upper) bound arrays."""
|
||||||
|
if specs is None:
|
||||||
|
specs = ROTARY_CARTPOLE_PARAMS
|
||||||
|
lo = np.array([s.lower for s in specs], dtype=np.float64)
|
||||||
|
hi = np.array([s.upper for s in specs], dtype=np.float64)
|
||||||
|
return lo, hi
|
||||||
|
|
||||||
|
|
||||||
|
# ── MuJoCo model building with parameter overrides ──────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _build_model(
    robot_path: Path,
    params: dict[str, float],
) -> mujoco.MjModel:
    """Build a MuJoCo model from URDF + robot.yaml with parameter overrides.

    Follows the same two-step approach as ``MuJoCoRunner._load_model()``:
    1. Parse URDF, inject meshdir, override inertials, load into MuJoCo
    2. Export MJCF, inject actuators + joint overrides, reload

    Parameters
    ----------
    robot_path : asset directory containing robot.yaml and the URDF
    params : named overrides; missing keys leave the URDF/robot.yaml value

    Returns
    -------
    mujoco.MjModel with all overrides applied.

    Temp files are written next to the assets (so relative mesh paths
    resolve — presumably; confirm against the URDF) and removed in
    ``finally`` blocks.
    """
    robot_path = Path(robot_path).resolve()
    robot_yaml = yaml.safe_load((robot_path / "robot.yaml").read_text())
    urdf_path = robot_path / robot_yaml["urdf"]

    # ── Step 1: Load URDF ────────────────────────────────────────
    tree = ET.parse(urdf_path)
    root = tree.getroot()

    # Inject meshdir compiler directive.
    # Derive it from the first mesh filename that has a directory component.
    meshdir = None
    for mesh_el in root.iter("mesh"):
        fn = mesh_el.get("filename", "")
        parent = str(Path(fn).parent)
        if parent and parent != ".":
            meshdir = parent
            break
    if meshdir:
        mj_ext = ET.SubElement(root, "mujoco")
        ET.SubElement(
            mj_ext, "compiler", attrib={"meshdir": meshdir, "balanceinertia": "true"}
        )

    # Override URDF inertial parameters BEFORE MuJoCo loading.
    for link in root.iter("link"):
        link_name = link.get("name", "")
        inertial = link.find("inertial")
        if inertial is None:
            continue

        if link_name == "arm":
            _set_mass(inertial, params.get("arm_mass"))
            _set_com(
                inertial,
                params.get("arm_com_x"),
                params.get("arm_com_y"),
                params.get("arm_com_z"),
            )

        elif link_name == "pendulum":
            _set_mass(inertial, params.get("pendulum_mass"))
            _set_com(
                inertial,
                params.get("pendulum_com_x"),
                params.get("pendulum_com_y"),
                params.get("pendulum_com_z"),
            )
            _set_inertia(
                inertial,
                ixx=params.get("pendulum_ixx"),
                iyy=params.get("pendulum_iyy"),
                izz=params.get("pendulum_izz"),
                ixy=params.get("pendulum_ixy"),
            )

    # Write temp URDF and load.
    tmp_urdf = robot_path / "_tmp_sysid_load.urdf"
    tree.write(str(tmp_urdf), xml_declaration=True, encoding="unicode")
    try:
        model_raw = mujoco.MjModel.from_xml_path(str(tmp_urdf))
    finally:
        tmp_urdf.unlink(missing_ok=True)

    # ── Step 2: Export MJCF, inject actuators + overrides ────────
    tmp_mjcf = robot_path / "_tmp_sysid_inject.xml"
    try:
        mujoco.mj_saveLastXML(str(tmp_mjcf), model_raw)
        mjcf_root = ET.fromstring(tmp_mjcf.read_text())

        # Actuator.  params take precedence over robot.yaml values.
        gear = params.get("actuator_gear", robot_yaml["actuators"][0].get("gear", 0.064))
        filter_tau = params.get(
            "actuator_filter_tau",
            robot_yaml["actuators"][0].get("filter_tau", 0.03),
        )
        act_cfg = robot_yaml["actuators"][0]
        ctrl_lo, ctrl_hi = act_cfg.get("ctrl_range", [-1.0, 1.0])

        act_elem = ET.SubElement(mjcf_root, "actuator")
        attribs: dict[str, str] = {
            "name": f"{act_cfg['joint']}_motor",
            "joint": act_cfg["joint"],
            "gear": str(gear),
            "ctrlrange": f"{ctrl_lo} {ctrl_hi}",
        }
        if filter_tau > 0:
            # First-order filtered actuator: models motor response lag.
            attribs["dyntype"] = "filter"
            attribs["dynprm"] = str(filter_tau)
            attribs["gaintype"] = "fixed"
            attribs["biastype"] = "none"
            ET.SubElement(act_elem, "general", attrib=attribs)
        else:
            ET.SubElement(act_elem, "motor", attrib=attribs)

        # Joint overrides.
        motor_damping = params.get("motor_damping", 0.003)
        pend_damping = params.get("pendulum_damping", 0.0001)
        motor_armature = params.get("motor_armature", 0.0001)
        motor_frictionloss = params.get("motor_frictionloss", 0.03)

        for body in mjcf_root.iter("body"):
            for jnt in body.findall("joint"):
                name = jnt.get("name")
                if name == "motor_joint":
                    jnt.set("damping", str(motor_damping))
                    jnt.set("armature", str(motor_armature))
                    jnt.set("frictionloss", str(motor_frictionloss))
                elif name == "pendulum_joint":
                    jnt.set("damping", str(pend_damping))

        # Disable self-collision.
        for geom in mjcf_root.iter("geom"):
            geom.set("contype", "0")
            geom.set("conaffinity", "0")

        modified_xml = ET.tostring(mjcf_root, encoding="unicode")
        tmp_mjcf.write_text(modified_xml)
        model = mujoco.MjModel.from_xml_path(str(tmp_mjcf))
    finally:
        tmp_mjcf.unlink(missing_ok=True)

    return model
|
||||||
|
|
||||||
|
|
||||||
|
def _set_mass(inertial: ET.Element, mass: float | None) -> None:
    """Set the ``value`` attribute of the <mass> child of *inertial*.

    No-op when *mass* is None or the element has no <mass> child.
    """
    if mass is None:
        return
    mass_el = inertial.find("mass")
    if mass_el is not None:
        mass_el.set("value", str(mass))
|
||||||
|
|
||||||
|
|
||||||
|
def _set_com(
    inertial: ET.Element,
    x: float | None,
    y: float | None,
    z: float | None,
) -> None:
    """Overwrite selected components of the <origin xyz="..."> of *inertial*.

    Components passed as None keep their current value.  No-op when the
    element has no <origin> child; a missing ``xyz`` defaults to "0 0 0".
    """
    origin = inertial.find("origin")
    if origin is None:
        return
    xyz = origin.get("xyz", "0 0 0").split()
    if x is not None:
        xyz[0] = str(x)
    if y is not None:
        xyz[1] = str(y)
    if z is not None:
        xyz[2] = str(z)
    origin.set("xyz", " ".join(xyz))
|
||||||
|
|
||||||
|
|
||||||
|
def _set_inertia(
    inertial: ET.Element,
    ixx: float | None = None,
    iyy: float | None = None,
    izz: float | None = None,
    ixy: float | None = None,
    iyz: float | None = None,
    ixz: float | None = None,
) -> None:
    """Overwrite selected attributes of the <inertia> child of *inertial*.

    Entries left as None are untouched; no-op when <inertia> is missing.
    """
    ine = inertial.find("inertia")
    if ine is None:
        return
    for attr, val in [
        ("ixx", ixx), ("iyy", iyy), ("izz", izz),
        ("ixy", ixy), ("iyz", iyz), ("ixz", ixz),
    ]:
        if val is not None:
            ine.set(attr, str(val))
|
||||||
|
|
||||||
|
|
||||||
|
# ── Simulation rollout ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def rollout(
    robot_path: str | Path,
    params: dict[str, float],
    actions: np.ndarray,
    timesteps: np.ndarray,
    sim_dt: float = 0.002,
    substeps: int = 10,
) -> dict[str, np.ndarray]:
    """Replay recorded actions open-loop in MuJoCo with overridden parameters.

    Parameters
    ----------
    robot_path : asset directory
    params : named parameter overrides
    actions : (N,) normalised actions [-1, 1] from the recording
    timesteps : (N,) wall-clock times (seconds) from the recording.
        Currently unused: each action is simply held for ``substeps``
        physics steps (effective control period ``sim_dt * substeps``),
        so the recording is assumed uniformly sampled at that rate.
        Kept in the signature for interface compatibility.
    sim_dt : MuJoCo physics timestep
    substeps : physics substeps per control step

    Returns
    -------
    dict with keys: motor_angle, motor_vel, pendulum_angle, pendulum_vel
        Each is an (N,) numpy array of simulated values.
    """
    robot_path = Path(robot_path).resolve()
    model = _build_model(robot_path, params)
    model.opt.timestep = sim_dt
    data = mujoco.MjData(model)

    # Start from pendulum hanging down (qpos=0 is down per URDF convention).
    mujoco.mj_resetData(model, data)

    n = len(actions)

    # Pre-allocate output.
    sim_motor_angle = np.zeros(n, dtype=np.float64)
    sim_motor_vel = np.zeros(n, dtype=np.float64)
    sim_pend_angle = np.zeros(n, dtype=np.float64)
    sim_pend_vel = np.zeros(n, dtype=np.float64)

    # Extract actuator limit info for the software limit switch.
    nu = model.nu
    if nu > 0:
        jnt_id = model.actuator_trnid[0, 0]
        jnt_limited = bool(model.jnt_limited[jnt_id])
        jnt_lo = model.jnt_range[jnt_id, 0]
        jnt_hi = model.jnt_range[jnt_id, 1]
        gear_sign = float(np.sign(model.actuator_gear[0, 0]))
    else:
        jnt_limited = False
        jnt_lo = jnt_hi = gear_sign = 0.0

    for i in range(n):
        data.ctrl[0] = actions[i]

        for _ in range(substeps):
            # Software limit switch (mirrors MuJoCoRunner): zero any drive
            # that pushes further past a joint limit.
            if jnt_limited and nu > 0:
                # NOTE(review): indexes qpos by joint id rather than
                # model.jnt_qposadr[jnt_id]; equivalent for an all-hinge
                # model like this one — verify if the model changes.
                pos = data.qpos[jnt_id]
                if pos >= jnt_hi and gear_sign * data.ctrl[0] > 0:
                    data.ctrl[0] = 0.0
                elif pos <= jnt_lo and gear_sign * data.ctrl[0] < 0:
                    data.ctrl[0] = 0.0
            mujoco.mj_step(model, data)

        sim_motor_angle[i] = data.qpos[0]
        sim_motor_vel[i] = data.qvel[0]
        sim_pend_angle[i] = data.qpos[1]
        sim_pend_vel[i] = data.qvel[1]

    return {
        "motor_angle": sim_motor_angle,
        "motor_vel": sim_motor_vel,
        "pendulum_angle": sim_pend_angle,
        "pendulum_vel": sim_pend_vel,
    }
|
||||||
|
|
||||||
|
|
||||||
|
def windowed_rollout(
    robot_path: str | Path,
    params: dict[str, float],
    recording: dict[str, np.ndarray],
    window_duration: float = 0.5,
    sim_dt: float = 0.002,
    substeps: int = 10,
) -> dict[str, np.ndarray | float]:
    """Multiple-shooting rollout — split recording into short windows.

    For each window:
    1. Initialize MuJoCo state from the real qpos/qvel at the window start.
    2. Replay the recorded actions within the window.
    3. Record the simulated output.

    This prevents error accumulation across the full trajectory, giving
    a much smoother cost landscape for the optimizer.

    Parameters
    ----------
    robot_path : asset directory
    params : named parameter overrides
    recording : dict with keys time, action, motor_angle, motor_vel,
        pendulum_angle, pendulum_vel (all 1D arrays of length N)
    window_duration : length of each shooting window in seconds
    sim_dt : MuJoCo physics timestep
    substeps : physics substeps per control step

    Returns
    -------
    dict with:
        motor_angle, motor_vel, pendulum_angle, pendulum_vel — (N,) arrays
        (stitched from per-window simulations)
        n_windows — number of windows used
    """
    robot_path = Path(robot_path).resolve()
    model = _build_model(robot_path, params)
    model.opt.timestep = sim_dt
    data = mujoco.MjData(model)

    times = recording["time"]
    actions = recording["action"]
    real_motor = recording["motor_angle"]
    real_motor_vel = recording["motor_vel"]
    real_pend = recording["pendulum_angle"]
    real_pend_vel = recording["pendulum_vel"]
    n = len(actions)

    # Pre-allocate output (stitched from all windows).
    sim_motor_angle = np.zeros(n, dtype=np.float64)
    sim_motor_vel = np.zeros(n, dtype=np.float64)
    sim_pend_angle = np.zeros(n, dtype=np.float64)
    sim_pend_vel = np.zeros(n, dtype=np.float64)

    # Extract actuator limit info.
    nu = model.nu
    if nu > 0:
        jnt_id = model.actuator_trnid[0, 0]
        jnt_limited = bool(model.jnt_limited[jnt_id])
        jnt_lo = model.jnt_range[jnt_id, 0]
        jnt_hi = model.jnt_range[jnt_id, 1]
        gear_sign = float(np.sign(model.actuator_gear[0, 0]))
    else:
        jnt_limited = False
        jnt_lo = jnt_hi = gear_sign = 0.0

    # Compute window boundaries from recording timestamps.
    # Boundaries are spaced window_duration apart in *time*, then snapped
    # to sample indices, so uneven sampling still yields sensible windows.
    t0 = times[0]
    t_end = times[-1]
    window_starts: list[int] = []  # indices into the recording
    current_t = t0
    while current_t < t_end:
        # Find the index closest to current_t.
        idx = int(np.searchsorted(times, current_t))
        idx = min(idx, n - 1)
        window_starts.append(idx)
        current_t += window_duration

    n_windows = len(window_starts)

    for w, w_start in enumerate(window_starts):
        # Window end: next window start, or end of recording.
        w_end = window_starts[w + 1] if w + 1 < n_windows else n

        # Initialize MuJoCo state from real data at window start.
        mujoco.mj_resetData(model, data)
        data.qpos[0] = real_motor[w_start]
        data.qpos[1] = real_pend[w_start]
        data.qvel[0] = real_motor_vel[w_start]
        data.qvel[1] = real_pend_vel[w_start]
        data.ctrl[:] = 0.0
        # Forward kinematics to make state consistent.
        mujoco.mj_forward(model, data)

        for i in range(w_start, w_end):
            data.ctrl[0] = actions[i]

            for _ in range(substeps):
                # Software limit switch: zero drive pushing past a joint limit.
                if jnt_limited and nu > 0:
                    # NOTE(review): indexes qpos by joint id rather than
                    # model.jnt_qposadr[jnt_id]; equivalent for an all-hinge
                    # model like this one — verify if the model changes.
                    pos = data.qpos[jnt_id]
                    if pos >= jnt_hi and gear_sign * data.ctrl[0] > 0:
                        data.ctrl[0] = 0.0
                    elif pos <= jnt_lo and gear_sign * data.ctrl[0] < 0:
                        data.ctrl[0] = 0.0
                mujoco.mj_step(model, data)

            sim_motor_angle[i] = data.qpos[0]
            sim_motor_vel[i] = data.qvel[0]
            sim_pend_angle[i] = data.qpos[1]
            sim_pend_vel[i] = data.qvel[1]

    return {
        "motor_angle": sim_motor_angle,
        "motor_vel": sim_motor_vel,
        "pendulum_angle": sim_pend_angle,
        "pendulum_vel": sim_pend_vel,
        "n_windows": n_windows,
    }
|
||||||
287
src/sysid/visualize.py
Normal file
287
src/sysid/visualize.py
Normal file
@@ -0,0 +1,287 @@
|
|||||||
|
"""Visualise system identification results — real vs simulated trajectories.
|
||||||
|
|
||||||
|
Loads a recording and runs simulation with both the original and tuned
|
||||||
|
parameters, then plots a 4-panel comparison (motor angle, motor vel,
|
||||||
|
pendulum angle, pendulum vel) over time.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python -m src.sysid.visualize \
|
||||||
|
--robot-path assets/rotary_cartpole \
|
||||||
|
--recording assets/rotary_cartpole/recordings/capture_20260311_120000.npz
|
||||||
|
|
||||||
|
# Also compare with tuned parameters:
|
||||||
|
python -m src.sysid.visualize \
|
||||||
|
--robot-path assets/rotary_cartpole \
|
||||||
|
--recording <file>.npz \
|
||||||
|
--result assets/rotary_cartpole/sysid_result.json
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import math
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import structlog
|
||||||
|
|
||||||
|
log = structlog.get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def visualize(
    robot_path: str | Path,
    recording_path: str | Path,
    result_path: str | Path | None = None,
    sim_dt: float = 0.002,
    substeps: int = 10,
    window_duration: float = 0.5,
    save_path: str | Path | None = None,
    show: bool = True,
) -> None:
    """Generate a real-vs-simulated comparison plot.

    Parameters
    ----------
    robot_path : robot asset directory
    recording_path : .npz file from capture
    result_path : sysid_result.json with best_params (optional; when omitted,
        ``<robot_path>/sysid_result.json`` is auto-detected if present)
    sim_dt / substeps : physics settings for rollout
    window_duration : shooting window length (s); 0 = open-loop
    save_path : if provided, save figure to this path (PNG, PDF, …)
    show : if True, display interactive matplotlib window
    """
    import matplotlib.pyplot as plt

    from src.sysid.rollout import (
        ROTARY_CARTPOLE_PARAMS,
        defaults_vector,
        params_to_dict,
        rollout,
        windowed_rollout,
    )

    robot_path = Path(robot_path).resolve()
    recording = dict(np.load(recording_path))

    t = recording["time"]
    actions = recording["action"]

    def _simulate(params: dict) -> dict:
        """Roll out the sim with *params* — windowed (shooting) or open-loop."""
        if window_duration > 0:
            return windowed_rollout(
                robot_path=robot_path,
                params=params,
                recording=recording,
                window_duration=window_duration,
                sim_dt=sim_dt,
                substeps=substeps,
            )
        return rollout(
            robot_path=robot_path,
            params=params,
            actions=actions,
            timesteps=t,
            sim_dt=sim_dt,
            substeps=substeps,
        )

    # ── Simulate with default parameters ─────────────────────────
    default_params = params_to_dict(
        defaults_vector(ROTARY_CARTPOLE_PARAMS), ROTARY_CARTPOLE_PARAMS
    )
    log.info("simulating_default_params", windowed=window_duration > 0)
    sim_default = _simulate(default_params)

    # ── Simulate with tuned parameters (if available) ────────────
    sim_tuned = None
    tuned_cost = None
    if result_path is not None:
        result_path = Path(result_path)
        if result_path.exists():
            result = json.loads(result_path.read_text())
            tuned_params = result.get("best_params", {})
            tuned_cost = result.get("best_cost")
            log.info("simulating_tuned_params", cost=tuned_cost)
            sim_tuned = _simulate(tuned_params)
        else:
            log.warning("result_file_not_found", path=str(result_path))
    else:
        # Auto-detect sysid_result.json in robot_path.
        auto_result = robot_path / "sysid_result.json"
        if auto_result.exists():
            result = json.loads(auto_result.read_text())
            tuned_params = result.get("best_params", {})
            tuned_cost = result.get("best_cost")
            log.info("auto_detected_tuned_params", cost=tuned_cost)
            sim_tuned = _simulate(tuned_params)

    # ── Plot ─────────────────────────────────────────────────────
    fig, axes = plt.subplots(5, 1, figsize=(14, 12), sharex=True)

    channels = [
        ("motor_angle", "Motor Angle (rad)"),
        ("motor_vel", "Motor Velocity (rad/s)"),
        ("pendulum_angle", "Pendulum Angle (rad)"),
        ("pendulum_vel", "Pendulum Velocity (rad/s)"),
    ]

    for ax, (key, ylabel) in zip(axes[:4], channels):
        real = recording[key]

        ax.plot(t, real, "k-", linewidth=1.2, alpha=0.8, label="Real")
        ax.plot(
            t,
            sim_default[key],
            "--",
            color="#d62728",
            linewidth=1.0,
            alpha=0.7,
            label="Sim (original)",
        )
        if sim_tuned is not None:
            ax.plot(
                t,
                sim_tuned[key],
                "--",
                color="#2ca02c",
                linewidth=1.0,
                alpha=0.7,
                label="Sim (tuned)",
            )

        ax.set_ylabel(ylabel)
        ax.legend(loc="upper right", fontsize=8)
        ax.grid(True, alpha=0.3)

    # Action plot (bottom panel).
    axes[4].plot(t, actions, "b-", linewidth=0.8, alpha=0.6)
    axes[4].set_ylabel("Action (norm)")
    axes[4].set_xlabel("Time (s)")
    axes[4].grid(True, alpha=0.3)
    axes[4].set_ylim(-1.1, 1.1)

    # Title with cost info.
    title = "System Identification — Real vs Simulated Trajectories"
    if tuned_cost is not None:
        # Compute original cost for comparison.
        from src.sysid.optimize import cost_function

        orig_cost = cost_function(
            defaults_vector(ROTARY_CARTPOLE_PARAMS),
            recording,
            robot_path,
            ROTARY_CARTPOLE_PARAMS,
            sim_dt=sim_dt,
            substeps=substeps,
            window_duration=window_duration,
        )
        title += f"\nOriginal cost: {orig_cost:.4f} → Tuned cost: {tuned_cost:.4f}"
        improvement = (1.0 - tuned_cost / orig_cost) * 100 if orig_cost > 0 else 0
        title += f" ({improvement:+.1f}%)"

    fig.suptitle(title, fontsize=12)
    plt.tight_layout()

    if save_path:
        save_path = Path(save_path)
        fig.savefig(str(save_path), dpi=150, bbox_inches="tight")
        log.info("figure_saved", path=str(save_path))

    if show:
        plt.show()
    else:
        plt.close(fig)
|
||||||
|
|
||||||
|
|
||||||
|
# ── CLI entry point ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
    """CLI entry point: parse arguments and produce the comparison figure."""
    ap = argparse.ArgumentParser(
        description="Visualise system identification results."
    )
    ap.add_argument(
        "--robot-path",
        type=str,
        default="assets/rotary_cartpole",
    )
    ap.add_argument(
        "--recording",
        type=str,
        required=True,
        help="Path to .npz recording file",
    )
    ap.add_argument(
        "--result",
        type=str,
        default=None,
        help="Path to sysid_result.json (auto-detected if omitted)",
    )
    ap.add_argument("--sim-dt", type=float, default=0.002)
    ap.add_argument("--substeps", type=int, default=10)
    ap.add_argument(
        "--window-duration",
        type=float,
        default=0.5,
        help="Shooting window length in seconds (0 = open-loop)",
    )
    ap.add_argument(
        "--save",
        type=str,
        default=None,
        help="Save figure to this path (PNG, PDF, …)",
    )
    ap.add_argument(
        "--no-show",
        action="store_true",
        help="Don't show interactive window (useful for CI)",
    )
    ns = ap.parse_args()

    # Forward the parsed options to the plotting routine.
    visualize(
        robot_path=ns.robot_path,
        recording_path=ns.recording,
        result_path=ns.result,
        sim_dt=ns.sim_dt,
        substeps=ns.substeps,
        window_duration=ns.window_duration,
        save_path=ns.save,
        show=not ns.no_show,
    )


if __name__ == "__main__":
    main()
|
||||||
@@ -35,6 +35,11 @@ class TrainerConfig:
|
|||||||
|
|
||||||
hidden_sizes: tuple[int, ...] = (64, 64)
|
hidden_sizes: tuple[int, ...] = (64, 64)
|
||||||
|
|
||||||
|
# Policy
|
||||||
|
initial_log_std: float = 0.5 # initial exploration noise
|
||||||
|
min_log_std: float = -2.0 # minimum exploration noise
|
||||||
|
max_log_std: float = 2.0 # maximum exploration noise (2.0 ≈ σ=7.4)
|
||||||
|
|
||||||
# Training
|
# Training
|
||||||
total_timesteps: int = 1_000_000
|
total_timesteps: int = 1_000_000
|
||||||
log_interval: int = 10
|
log_interval: int = 10
|
||||||
@@ -110,6 +115,7 @@ class VideoRecordingTrainer(SequentialTrainer):
|
|||||||
return self._tcfg.record_video_fps
|
return self._tcfg.record_video_fps
|
||||||
dt = getattr(self.env.config, "dt", 0.02)
|
dt = getattr(self.env.config, "dt", 0.02)
|
||||||
substeps = getattr(self.env.config, "substeps", 1)
|
substeps = getattr(self.env.config, "substeps", 1)
|
||||||
|
# SerialRunner has dt but no substeps — dt *is* the control period.
|
||||||
return max(1, int(round(1.0 / (dt * substeps))))
|
return max(1, int(round(1.0 / (dt * substeps))))
|
||||||
|
|
||||||
def _record_video(self, timestep: int) -> None:
|
def _record_video(self, timestep: int) -> None:
|
||||||
@@ -181,8 +187,9 @@ class Trainer:
|
|||||||
action_space=act_space,
|
action_space=act_space,
|
||||||
device=device,
|
device=device,
|
||||||
hidden_sizes=self.config.hidden_sizes,
|
hidden_sizes=self.config.hidden_sizes,
|
||||||
initial_log_std=0.5,
|
initial_log_std=self.config.initial_log_std,
|
||||||
min_log_std=-2.0,
|
min_log_std=self.config.min_log_std,
|
||||||
|
max_log_std=self.config.max_log_std,
|
||||||
)
|
)
|
||||||
|
|
||||||
models = {"policy": self.model, "value": self.model}
|
models = {"policy": self.model, "value": self.model}
|
||||||
|
|||||||
3
train.py
3
train.py
@@ -27,7 +27,9 @@ logger = structlog.get_logger()
|
|||||||
|
|
||||||
RUNNER_REGISTRY: dict[str, tuple[str, str, str]] = {
|
RUNNER_REGISTRY: dict[str, tuple[str, str, str]] = {
|
||||||
"mujoco": ("src.runners.mujoco", "MuJoCoRunner", "MuJoCoRunnerConfig"),
|
"mujoco": ("src.runners.mujoco", "MuJoCoRunner", "MuJoCoRunnerConfig"),
|
||||||
|
"mujoco_single": ("src.runners.mujoco", "MuJoCoRunner", "MuJoCoRunnerConfig"),
|
||||||
"mjx": ("src.runners.mjx", "MJXRunner", "MJXRunnerConfig"),
|
"mjx": ("src.runners.mjx", "MJXRunner", "MJXRunnerConfig"),
|
||||||
|
"serial": ("src.runners.serial", "SerialRunner", "SerialRunnerConfig"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -94,6 +96,7 @@ def main(cfg: DictConfig) -> None:
|
|||||||
# execute_remotely() is a no-op on the worker side.
|
# execute_remotely() is a no-op on the worker side.
|
||||||
training_dict = OmegaConf.to_container(cfg.training, resolve=True)
|
training_dict = OmegaConf.to_container(cfg.training, resolve=True)
|
||||||
remote = training_dict.pop("remote", False)
|
remote = training_dict.pop("remote", False)
|
||||||
|
training_dict.pop("hpo", None) # HPO range metadata — not a TrainerConfig field
|
||||||
task = _init_clearml(choices, remote=remote)
|
task = _init_clearml(choices, remote=remote)
|
||||||
|
|
||||||
env_name = choices.get("env", "cartpole")
|
env_name = choices.get("env", "cartpole")
|
||||||
|
|||||||
135
viz.py
135
viz.py
@@ -1,9 +1,13 @@
|
|||||||
"""Interactive visualization — control any env with keyboard in MuJoCo viewer.
|
"""Interactive visualization — control any env with keyboard in MuJoCo viewer.
|
||||||
|
|
||||||
Usage:
|
Usage (simulation):
|
||||||
mjpython viz.py env=rotary_cartpole
|
mjpython viz.py env=rotary_cartpole
|
||||||
mjpython viz.py env=cartpole +com=true
|
mjpython viz.py env=cartpole +com=true
|
||||||
|
|
||||||
|
Usage (real hardware — digital twin):
|
||||||
|
mjpython viz.py env=rotary_cartpole runner=serial
|
||||||
|
mjpython viz.py env=rotary_cartpole runner=serial runner.port=/dev/ttyUSB0
|
||||||
|
|
||||||
Controls:
|
Controls:
|
||||||
Left/Right arrows — apply torque to first actuator
|
Left/Right arrows — apply torque to first actuator
|
||||||
R — reset environment
|
R — reset environment
|
||||||
@@ -15,6 +19,7 @@ import time
|
|||||||
import hydra
|
import hydra
|
||||||
import mujoco
|
import mujoco
|
||||||
import mujoco.viewer
|
import mujoco.viewer
|
||||||
|
import numpy as np
|
||||||
import structlog
|
import structlog
|
||||||
import torch
|
import torch
|
||||||
from hydra.core.hydra_config import HydraConfig
|
from hydra.core.hydra_config import HydraConfig
|
||||||
@@ -45,10 +50,64 @@ def _key_callback(keycode: int) -> None:
|
|||||||
_reset_flag[0] = True
|
_reset_flag[0] = True
|
||||||
|
|
||||||
|
|
||||||
|
def _add_action_arrow(viewer, model, data, action_val: float) -> None:
    """Draw an arrow on the motor joint showing applied torque direction."""
    # Nothing to draw for a (near-)zero command or an actuator-less model.
    if abs(action_val) < 0.01 or model.nu == 0:
        return

    # Body carrying the joint driven by the first actuator.
    joint = model.actuator_trnid[0, 0]
    body = model.jnt_bodyid[joint]

    # Arrow origin: the body's world position, lifted slightly above it.
    origin = data.xpos[body].copy()
    origin[2] += 0.02

    # Arrow direction: the joint axis rotated into the world frame,
    # with its sense given by the sign of the action.
    world_axis = data.xmat[body].reshape(3, 3) @ model.jnt_axis[joint]
    arrow_len = 0.08 * action_val
    heading = world_axis * np.sign(arrow_len)

    # Build an orthonormal frame whose z-axis is the heading — MuJoCo
    # renders arrows along local z.
    z_axis = heading / (np.linalg.norm(heading) + 1e-8)
    ref = np.array([0, 0, 1]) if abs(z_axis[2]) < 0.99 else np.array([0, 1, 0])
    x_axis = np.cross(ref, z_axis)
    x_axis /= np.linalg.norm(x_axis) + 1e-8
    y_axis = np.cross(z_axis, x_axis)
    frame = np.column_stack([x_axis, y_axis, z_axis]).flatten()

    # Color: green = positive, red = negative.
    color = np.array(
        [0.2, 0.8, 0.2, 0.8] if action_val > 0 else [0.8, 0.2, 0.2, 0.8],
        dtype=np.float32,
    )

    geom = viewer.user_scn.geoms[viewer.user_scn.ngeom]
    mujoco.mjv_initGeom(
        geom,
        type=mujoco.mjtGeom.mjGEOM_ARROW,
        size=np.array([0.008, 0.008, abs(arrow_len)]),
        pos=origin,
        mat=frame,
        rgba=color,
    )
    viewer.user_scn.ngeom += 1
|
||||||
|
|
||||||
|
|
||||||
@hydra.main(version_base=None, config_path="configs", config_name="config")
def main(cfg: DictConfig) -> None:
    """Dispatch to simulated or serial (digital-twin) visualisation."""
    runtime_choices = HydraConfig.get().runtime.choices
    selected_env = runtime_choices.get("env", "cartpole")
    selected_runner = runtime_choices.get("runner", "mujoco")

    # The serial runner mirrors real hardware; everything else is pure sim.
    handler = _main_serial if selected_runner == "serial" else _main_sim
    handler(cfg, selected_env)
||||||
|
|
||||||
|
def _main_sim(cfg: DictConfig, env_name: str) -> None:
|
||||||
|
"""Simulation visualization — step MuJoCo physics with keyboard control."""
|
||||||
|
|
||||||
# Build env + runner (single env for viz)
|
# Build env + runner (single env for viz)
|
||||||
env = build_env(env_name, cfg)
|
env = build_env(env_name, cfg)
|
||||||
@@ -94,8 +153,10 @@ def main(cfg: DictConfig) -> None:
|
|||||||
action = torch.tensor([[action_val]])
|
action = torch.tensor([[action_val]])
|
||||||
obs, reward, terminated, truncated, info = runner.step(action)
|
obs, reward, terminated, truncated, info = runner.step(action)
|
||||||
|
|
||||||
# Sync viewer
|
# Sync viewer with action arrow overlay
|
||||||
mujoco.mj_forward(model, data)
|
mujoco.mj_forward(model, data)
|
||||||
|
viewer.user_scn.ngeom = 0 # clear previous frame's overlays
|
||||||
|
_add_action_arrow(viewer, model, data, action_val)
|
||||||
viewer.sync()
|
viewer.sync()
|
||||||
|
|
||||||
# Print state
|
# Print state
|
||||||
@@ -112,5 +173,75 @@ def main(cfg: DictConfig) -> None:
|
|||||||
runner.close()
|
runner.close()
|
||||||
|
|
||||||
|
|
||||||
|
def _main_serial(cfg: DictConfig, env_name: str) -> None:
    """Digital-twin visualization — mirror real hardware in MuJoCo viewer.

    The MuJoCo model is loaded for rendering only. Joint positions are
    read from the ESP32 over serial and applied to the model each frame.
    Keyboard arrows send motor commands to the real robot.
    """
    from src.runners.serial import SerialRunner, SerialRunnerConfig

    env = build_env(env_name, cfg)
    runner_cfg = OmegaConf.to_container(cfg.runner, resolve=True)
    twin = SerialRunner(
        env=env, config=SerialRunnerConfig(**runner_cfg)
    )

    # The runner lazily builds a MuJoCo model (same URDF the sim uses);
    # here it serves purely as the rendering target.
    twin._ensure_viz_model()
    model = twin._viz_model
    data = twin._viz_data

    with mujoco.viewer.launch_passive(
        model, data, key_callback=_key_callback
    ) as viewer:
        # Optional CoM / inertia overlays.
        if cfg.get("com", False):
            viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_COM] = True
            viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_INERTIA] = True

        logger.info(
            "viewer_started",
            env=env_name,
            mode="serial (digital twin)",
            port=twin.config.port,
            controls="Left/Right arrows = motor command, R = reset",
        )

        while viewer.is_running():
            # Keyboard command, held for a short window after the keypress.
            held = time.time() - _action_time[0] < _ACTION_HOLD_S
            cmd = _action_val[0] if held else 0.0

            # R key: stop the motor, re-center, wait for the pendulum to settle.
            if _reset_flag[0]:
                _reset_flag[0] = False
                twin._send("M0")
                twin._drive_to_center()
                twin._wait_for_pendulum_still()
                logger.info("reset (drive-to-center + settle)")

            # Forward the motor command to the real hardware (PWM in ±255).
            pwm = int(np.clip(cmd, -1.0, 1.0) * 255)
            twin._send(f"M{pwm}")

            # Pull the real sensor state into the MuJoCo model.
            twin._sync_viz()

            # Redraw overlays and present the frame.
            viewer.user_scn.ngeom = 0
            _add_action_arrow(viewer, model, data, cmd)
            viewer.sync()

            # Real-time pacing (~50 Hz, matches serial dt).
            time.sleep(twin.config.dt)

    twin.close()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
Reference in New Issue
Block a user