Files
RL-Sim-Framework/src/sysid/capture.py
2026-03-22 15:49:13 +01:00

431 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Capture a real-robot trajectory under random excitation (PRBS-style).
Connects to the ESP32 over serial, sends random PWM commands to excite
the system, and records motor + pendulum angles and velocities at ~50 Hz.
Saves a compressed numpy archive (.npz) that the optimizer can replay
in simulation to fit physics parameters.
Serial protocol (same as SerialRunner):
S,<ms>,<motor_rad>,<motor_vel>,<pend_rad>,<pend_vel>,<motor_speed>
(7 comma-separated fields — firmware sends SI units)
Usage:
python -m src.sysid.capture \
--robot-path assets/rotary_cartpole \
--port /dev/cu.usbserial-0001 \
--duration 20
"""
from __future__ import annotations
import argparse
import math
import random
import threading
import time
from datetime import datetime
from pathlib import Path
from typing import Any
import numpy as np
import structlog
# Module-level structured logger used by the serial reader and the capture loop.
log = structlog.get_logger()
# ── Serial protocol helpers (mirrored from SerialRunner) ─────────────
def _parse_state_line(line: str) -> dict[str, Any] | None:
"""Parse an ``S,…`` state line from the ESP32.
Format: S,<ms>,<motor_rad>,<motor_vel>,<pend_rad>,<pend_vel>,<motor_speed>
(7 comma-separated fields, firmware sends SI units)
"""
if not line.startswith("S,"):
return None
parts = line.split(",")
if len(parts) < 7:
return None
try:
return {
"timestamp_ms": int(parts[1]),
"motor_rad": float(parts[2]),
"motor_vel": float(parts[3]),
"pend_rad": float(parts[4]),
"pend_vel": float(parts[5]),
"motor_speed": int(parts[6]),
}
except (ValueError, IndexError):
return None
# ── Background serial reader ─────────────────────────────────────────
class _SerialReader:
    """Minimal background reader for the ESP32 serial stream.

    A daemon thread reads newline-terminated lines from the port, parses
    ``S,…`` state lines with :func:`_parse_state_line`, and publishes the
    most recent one together with a sequence counter. ``read_blocking()``
    uses that counter to guarantee it returns a *new* state line (not a
    stale repeat), which keeps the capture loop locked to the firmware's
    50 Hz tick.

    The port is opened with a read timeout, so ``readline()`` can return a
    fragment that lacks the trailing newline. Such fragments are buffered
    and joined with the rest of the line rather than parsed on their own,
    because a truncated final field (e.g. ``12`` instead of ``120``) would
    otherwise be accepted as a valid but wrong reading.
    """

    # Upper bound on buffered bytes of an unterminated line; a longer run
    # without a newline is treated as garbage and discarded.
    _MAX_PENDING_BYTES = 1024

    def __init__(self, port: str, baud: int = 115200):
        """Open *port* at *baud*, wait for the board, and start the reader thread."""
        import serial as _serial

        self._serial_mod = _serial
        self.ser = _serial.Serial(port, baud, timeout=0.05)
        time.sleep(2)  # Wait for ESP32 boot.
        self.ser.reset_input_buffer()
        self._latest: dict[str, Any] = {}
        self._seq: int = 0  # incremented on every new state line
        self._lock = threading.Lock()
        self._cond = threading.Condition(self._lock)
        self._running = True
        self._thread = threading.Thread(target=self._reader_loop, daemon=True)
        self._thread.start()

    def _reader_loop(self) -> None:
        """Thread body: read complete lines, publish parsed state, stop on I/O error."""
        debug_count = 0
        pending = b""  # start of a line whose newline has not arrived yet
        while self._running:
            try:
                if not self.ser.in_waiting:
                    time.sleep(0.001)
                    continue
                chunk = self.ser.readline()
                if not chunk.endswith(b"\n"):
                    # readline() timed out mid-line: keep the fragment and
                    # complete it on a later read instead of parsing it now.
                    pending += chunk
                    if len(pending) > self._MAX_PENDING_BYTES:
                        pending = b""
                    continue
                raw, pending = pending + chunk, b""
                line = raw.decode("utf-8", errors="ignore").strip()
                # Debug: log first 10 raw lines so we can see what the firmware sends.
                if debug_count < 10 and line:
                    log.info("serial_raw_line", line=repr(line), count=debug_count)
                    debug_count += 1
                parsed = _parse_state_line(line)
                if parsed is not None:
                    with self._cond:
                        self._latest = parsed
                        self._seq += 1
                        self._cond.notify_all()
            except (OSError, self._serial_mod.SerialException):
                log.critical("serial_lost")
                break

    def send(self, cmd: str) -> None:
        """Write *cmd* plus a newline to the port; log (don't raise) on failure."""
        try:
            self.ser.write(f"{cmd}\n".encode())
        except (OSError, self._serial_mod.SerialException):
            log.critical("serial_send_failed", cmd=cmd)

    def read_blocking(self, timeout: float = 0.1) -> dict[str, Any]:
        """Wait until a *new* state line arrives, then return a copy of it.

        Uses the sequence counter to guarantee the returned state is
        different from whatever was available before this call. Returns an
        empty dict if nothing new arrives within *timeout* seconds.
        """
        with self._cond:
            seq_before = self._seq
            if not self._cond.wait_for(
                lambda: self._seq > seq_before, timeout=timeout
            ):
                return {}  # timeout — no new data
            return dict(self._latest)

    def close(self) -> None:
        """Stop the reader thread, send ``H`` and ``M0``, and close the port.

        NOTE(review): ``H`` presumably stops the firmware's state stream
        (``G`` starts it) — confirm against the firmware.
        """
        self._running = False
        self.send("H")
        self.send("M0")
        time.sleep(0.1)
        self._thread.join(timeout=1.0)
        self.ser.close()
# ── PRBS excitation signal ───────────────────────────────────────────
class _PRBSExcitation:
"""Random hold-value excitation with configurable amplitude and hold time.
At each call to ``__call__``, returns the current PWM value.
The value is held for a random duration (``hold_min````hold_max`` ms),
then a new random value is drawn uniformly from ``[-amplitude, +amplitude]``.
"""
def __init__(
self,
amplitude: int = 150,
hold_min_ms: int = 50,
hold_max_ms: int = 300,
):
self.amplitude = amplitude
self.hold_min_ms = hold_min_ms
self.hold_max_ms = hold_max_ms
self._current: int = 0
self._switch_time: float = 0.0
self._new_value()
def _new_value(self) -> None:
self._current = random.randint(-self.amplitude, self.amplitude)
hold_ms = random.randint(self.hold_min_ms, self.hold_max_ms)
self._switch_time = time.monotonic() + hold_ms / 1000.0
def __call__(self) -> int:
if time.monotonic() >= self._switch_time:
self._new_value()
return self._current
# ── Main capture loop ────────────────────────────────────────────────
def _apply_motor_safety(
    pwm: int, motor_angle_rad: float, max_motor_rad: float, amplitude: int
) -> int:
    """Return the PWM to actually send, overriding *pwm* near the angle limit.

    ``ratio = motor_angle_rad / max_motor_rad`` is the signed fraction of
    the safety limit currently used by the arm.

    * ``max_motor_rad <= 0`` (limit disabled) or ``|ratio| <= 0.70``:
      *pwm* is returned unchanged.
    * ``0.70 < |ratio| <= 0.90`` (soft zone): a command pushing further
      outward is sign-flipped; a command pointing back is kept.
    * ``|ratio| > 0.90`` (hard zone): *pwm* is ignored and the arm is
      driven back toward centre. The magnitude ramps linearly from
      ``amplitude / 2`` at 0.90 to ``amplitude`` at (and beyond) the limit.
    """
    if max_motor_rad <= 0:
        return pwm
    ratio = motor_angle_rad / max_motor_rad  # signed; |ratio| == 1 at the limit
    abs_ratio = abs(ratio)
    if abs_ratio > 0.90:
        # Deep in the danger zone — force a strong return.
        brake_strength = min(1.0, (abs_ratio - 0.90) / 0.10)  # 0 -> 1
        brake_pwm = int(amplitude * (0.5 + 0.5 * brake_strength))
        return -brake_pwm if ratio > 0 else brake_pwm
    if abs_ratio > 0.70 and ratio * pwm > 0:
        # Soft zone, outward command: reverse it toward centre.
        return -pwm
    return pwm


def capture(
    robot_path: str | Path,
    port: str = "/dev/cu.usbserial-0001",
    baud: int = 115200,
    duration: float = 20.0,
    amplitude: int = 150,
    hold_min_ms: int = 50,
    hold_max_ms: int = 300,
    dt: float = 0.02,
    motor_angle_limit_deg: float = 90.0,
) -> Path:
    """Run the capture procedure and return the path to the saved .npz file.

    The capture loop is **stream-driven**. It blocks on each incoming
    state line from the firmware (which arrives at 50 Hz), sends the
    next motor command immediately, and records both.

    The archive ``<robot_path>/recordings/capture_<YYYYmmdd_HHMMSS>.npz``
    holds equal-length float64 arrays:

    * ``time``: host seconds since the capture started.
    * ``action``: the PWM the firmware reports as applied (``motor_speed``),
      clamped to ±255 and normalised to [-1, 1].
    * ``motor_angle``, ``motor_vel``, ``pendulum_angle``, ``pendulum_vel``:
      the values as sent by the firmware.

    Parameters
    ----------
    robot_path : path to robot asset directory (must already exist)
    port : serial port for ESP32
    baud : baud rate
    duration : capture duration in seconds
    amplitude : max PWM magnitude for excitation
    hold_min_ms / hold_max_ms : random hold time range (ms)
    dt : nominal sample period for buffer sizing (seconds), must be > 0
    motor_angle_limit_deg : safety limit for motor angle (<= 0 disables it)

    Raises
    ------
    ValueError
        If ``dt`` is not positive, or if the excitation arguments are
        invalid (negative ``amplitude``, ``hold_min_ms > hold_max_ms``).
    FileNotFoundError
        If ``robot_path`` does not exist. This is raised before the serial
        port is opened.
    """
    if dt <= 0:
        raise ValueError(f"dt must be positive, got {dt}")

    robot_path = Path(robot_path).resolve()
    max_motor_rad = (
        math.radians(motor_angle_limit_deg) if motor_angle_limit_deg > 0 else 0.0
    )

    # Everything that can fail without hardware happens before the port is
    # opened. A bad robot path or excitation argument then aborts at once,
    # instead of leaving the serial reader open or discarding a finished
    # recording at save time.
    recordings_dir = robot_path / "recordings"
    recordings_dir.mkdir(exist_ok=True)
    excitation = _PRBSExcitation(amplitude, hold_min_ms, hold_max_ms)

    # Recording buffers, keyed by their name in the saved archive
    # (generous headroom over the nominal sample count).
    max_samples = int(duration / dt) + 500
    rec = {
        name: np.zeros(max_samples, dtype=np.float64)
        for name in (
            "time",
            "action",
            "motor_angle",
            "motor_vel",
            "pendulum_angle",
            "pendulum_vel",
        )
    }

    # Connect.
    reader = _SerialReader(port, baud)
    idx = 0
    try:
        # Start streaming.
        reader.send("G")
        time.sleep(0.1)

        log.info(
            "capture_starting",
            port=port,
            duration=duration,
            amplitude=amplitude,
            hold_range_ms=f"{hold_min_ms}-{hold_max_ms}",
            mode="stream-driven (firmware clock)",
        )

        pwm = 0
        last_esp_ms = -1  # firmware timestamp of last recorded sample
        no_data_count = 0  # consecutive timeouts with no data
        t0 = time.monotonic()

        while True:
            # Block until the firmware sends the next state line (~20 ms).
            # Timeout at 100 ms prevents hanging if the ESP32 disconnects.
            state = reader.read_blocking(timeout=0.1)
            if not state:
                no_data_count += 1
                if no_data_count == 30:  # 3 seconds with no data
                    log.warning(
                        "no_data_received",
                        msg="No state lines from firmware after 3s. "
                        "Check: is the ESP32 powered? Is it running the right firmware? "
                        "Try pressing the RESET button.",
                    )
                if no_data_count == 100:  # 10 seconds
                    log.critical(
                        "no_data_timeout",
                        msg="No data for 10s — aborting capture.",
                    )
                    break
                continue  # no data yet — retry
            no_data_count = 0

            # Deduplicate: the firmware may send multiple state lines per
            # tick (e.g. M-command echo + tick). Only record one sample
            # per unique firmware timestamp.
            esp_ms = state.get("timestamp_ms", 0)
            if esp_ms == last_esp_ms:
                continue
            last_esp_ms = esp_ms

            elapsed = time.monotonic() - t0
            if elapsed >= duration:
                break

            # Excitation PWM for the NEXT tick, overridden near the angle
            # limit. Firmware sends the motor angle in radians.
            motor_angle = state.get("motor_rad", 0.0)
            pwm = _apply_motor_safety(
                excitation(), motor_angle, max_motor_rad, amplitude
            )
            # Send command immediately — it will take effect on the next tick.
            reader.send(f"M{pwm}")

            if idx >= max_samples:
                break  # buffer full

            # Record this tick's state + the action the motor *actually*
            # received (firmware constrains to ±255; normalise to [-1, 1]).
            applied = state.get("motor_speed", 0)
            rec["time"][idx] = elapsed
            rec["action"][idx] = max(-255, min(255, applied)) / 255.0
            rec["motor_angle"][idx] = motor_angle
            rec["motor_vel"][idx] = state.get("motor_vel", 0.0)
            rec["pendulum_angle"][idx] = state.get("pend_rad", 0.0)
            rec["pendulum_vel"][idx] = state.get("pend_vel", 0.0)
            idx += 1

            # Progress (every 50 samples ≈ once per second at 50 Hz).
            if idx % 50 == 0:
                log.info(
                    "capture_progress",
                    elapsed=f"{elapsed:.1f}/{duration:.0f}s",
                    samples=idx,
                    pwm=pwm,
                )
    finally:
        # Stop the motor first, then shut the reader down.
        reader.send("M0")
        reader.close()

    # Trim to the actual sample count and save.
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    out_path = recordings_dir / f"capture_{stamp}.npz"
    trimmed = {name: buf[:idx] for name, buf in rec.items()}
    np.savez_compressed(out_path, **trimmed)
    log.info(
        "capture_saved",
        path=str(out_path),
        samples=idx,
        duration_actual=f"{trimmed['time'][-1]:.2f}s" if idx > 0 else "0s",
    )
    return out_path
# ── CLI entry point ──────────────────────────────────────────────────
def main(argv: list[str] | None = None) -> None:
    """Command-line entry point: parse arguments and run :func:`capture`.

    Parameters
    ----------
    argv : arguments to parse instead of ``sys.argv[1:]``. ``None`` (the
        default) keeps the normal command-line behaviour. Passing a list
        lets the CLI be invoked programmatically.

    Obviously invalid values are rejected via ``parser.error`` (usage
    message, exit status 2) before the serial port is opened:

    * a non-positive ``--dt``;
    * a negative ``--amplitude``;
    * ``--hold-min-ms`` greater than ``--hold-max-ms``.
    """
    parser = argparse.ArgumentParser(
        description="Capture a real-robot trajectory for system identification."
    )
    parser.add_argument(
        "--robot-path",
        type=str,
        default="assets/rotary_cartpole",
        help="Path to robot asset directory",
    )
    parser.add_argument(
        "--port",
        type=str,
        default="/dev/cu.usbserial-0001",
        help="Serial port for ESP32",
    )
    parser.add_argument("--baud", type=int, default=115200)
    parser.add_argument(
        "--duration", type=float, default=20.0, help="Capture duration (s)"
    )
    parser.add_argument(
        "--amplitude", type=int, default=150,
        help="Max PWM magnitude (should not exceed firmware MAX_MOTOR_SPEED=150)",
    )
    parser.add_argument(
        "--hold-min-ms", type=int, default=50, help="Min hold time (ms)"
    )
    parser.add_argument(
        "--hold-max-ms", type=int, default=300, help="Max hold time (ms)"
    )
    parser.add_argument(
        "--dt", type=float, default=0.02, help="Nominal sample period for buffer sizing (s)"
    )
    parser.add_argument(
        "--motor-angle-limit", type=float, default=90.0,
        help="Motor angle safety limit in degrees (0 = disabled)",
    )
    args = parser.parse_args(argv)

    # Fail fast on nonsensical values, before any hardware is touched.
    if args.dt <= 0:
        parser.error("--dt must be positive")
    if args.amplitude < 0:
        parser.error("--amplitude must be >= 0")
    if args.hold_min_ms > args.hold_max_ms:
        parser.error("--hold-min-ms must not exceed --hold-max-ms")

    capture(
        robot_path=args.robot_path,
        port=args.port,
        baud=args.baud,
        duration=args.duration,
        amplitude=args.amplitude,
        hold_min_ms=args.hold_min_ms,
        hold_max_ms=args.hold_max_ms,
        dt=args.dt,
        motor_angle_limit_deg=args.motor_angle_limit,
    )


if __name__ == "__main__":
    main()