431 lines
15 KiB
Python
431 lines
15 KiB
Python
"""Capture a real-robot trajectory under random excitation (PRBS-style).
|
||
|
||
Connects to the ESP32 over serial, sends random PWM commands to excite
the system, and records motor + pendulum angles and velocities at ~50 Hz.
Saves a compressed numpy archive (.npz) that the optimizer can replay
in simulation to fit physics parameters.

Serial protocol (same as SerialRunner):
    S,<ms>,<motor_rad>,<motor_vel>,<pend_rad>,<pend_vel>,<motor_speed>
    (7 comma-separated fields — firmware sends SI units)

Usage:
    python -m src.sysid.capture \\
        --robot-path assets/rotary_cartpole \\
        --port /dev/cu.usbserial-0001 \\
        --duration 20
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import math
|
||
import random
|
||
import threading
|
||
import time
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
from typing import Any
|
||
|
||
import numpy as np
|
||
import structlog
|
||
|
||
# Module-level structlog logger, used by the serial reader thread and the capture loop.
log = structlog.get_logger()
|
||
|
||
|
||
# ── Serial protocol helpers (mirrored from SerialRunner) ─────────────
|
||
|
||
|
||
def _parse_state_line(line: str) -> dict[str, Any] | None:
    """Decode one ``S,...`` telemetry line from the ESP32 firmware.

    Expected layout (comma-separated, SI units from the firmware)::

        S,<ms>,<motor_rad>,<motor_vel>,<pend_rad>,<pend_vel>,<motor_speed>

    Returns a dict with keys ``timestamp_ms`` (int), ``motor_rad``,
    ``motor_vel``, ``pend_rad``, ``pend_vel`` (floats) and ``motor_speed``
    (int). Returns ``None`` if the line is not a state line, has fewer than
    seven fields, or any field fails numeric conversion. Fields beyond the
    seventh are ignored.
    """
    fields = line.split(",")
    # A leading "S" field plus >= 7 fields is equivalent to startswith("S,")
    # combined with the field-count check.
    if fields[0] != "S" or len(fields) < 7:
        return None

    _, ms, m_pos, m_vel, p_pos, p_vel, speed = fields[:7]
    try:
        decoded = {
            "timestamp_ms": int(ms),
            "motor_rad": float(m_pos),
            "motor_vel": float(m_vel),
            "pend_rad": float(p_pos),
            "pend_vel": float(p_vel),
            "motor_speed": int(speed),
        }
    except ValueError:
        return None  # malformed numeric field
    return decoded
|
||
|
||
|
||
# ── Background serial reader ─────────────────────────────────────────
|
||
|
||
|
||
class _SerialReader:
    """Minimal background reader for the ESP32 serial stream.

    A daemon thread reads lines from the port, parses ``S,...`` state lines
    with ``_parse_state_line`` and publishes the most recent one together
    with an increasing sequence number. ``read_blocking()`` remembers the
    sequence number it last handed out, so each call returns a state line
    that no earlier call returned (never a stale repeat). A line that
    arrived while the caller was busy is returned immediately instead of
    being skipped. This keeps the capture loop locked to the firmware's
    tick.
    """

    def __init__(self, port: str, baud: int = 115200):
        import serial as _serial  # pyserial, imported lazily

        self._serial_mod = _serial
        self.ser = _serial.Serial(port, baud, timeout=0.05)
        time.sleep(2)  # Wait for ESP32 boot.
        self.ser.reset_input_buffer()  # discard anything buffered so far

        self._latest: dict[str, Any] = {}
        self._seq: int = 0  # incremented on every new state line
        # Sequence number of the state last returned by read_blocking().
        self._consumed_seq: int = 0
        self._lock = threading.Lock()
        self._cond = threading.Condition(self._lock)
        self._running = True

        self._thread = threading.Thread(target=self._reader_loop, daemon=True)
        self._thread.start()

    def _reader_loop(self) -> None:
        """Thread body: read lines, publish parsed state, stop on serial loss."""
        debug_count = 0
        while self._running:
            try:
                if not self.ser.in_waiting:
                    time.sleep(0.001)  # nothing pending; avoid a busy spin
                    continue
                line = (
                    self.ser.readline()
                    .decode("utf-8", errors="ignore")
                    .strip()
                )
                # Debug: log first 10 raw lines so we can see what the firmware sends.
                if debug_count < 10 and line:
                    log.info("serial_raw_line", line=repr(line), count=debug_count)
                    debug_count += 1
                parsed = _parse_state_line(line)
                if parsed is not None:
                    with self._cond:
                        self._latest = parsed
                        self._seq += 1
                        self._cond.notify_all()
            except (OSError, self._serial_mod.SerialException):
                log.critical("serial_lost")
                break

    def send(self, cmd: str) -> None:
        """Write ``cmd`` plus a newline; failures are logged, not raised."""
        try:
            self.ser.write(f"{cmd}\n".encode())
        except (OSError, self._serial_mod.SerialException):
            log.critical("serial_send_failed", cmd=cmd)

    def read_blocking(self, timeout: float = 0.1) -> dict[str, Any]:
        """Return the newest state line not yet returned by this method.

        If a line arrived since the previous call, the newest one is returned
        immediately. Otherwise waits up to ``timeout`` seconds for one.

        Returns a copy of the parsed state dict, or ``{}`` on timeout.
        """
        with self._cond:
            if not self._cond.wait_for(
                lambda: self._seq > self._consumed_seq, timeout=timeout
            ):
                return {}  # timeout — no new data
            self._consumed_seq = self._seq
            return dict(self._latest)

    def close(self) -> None:
        """Stop the reader thread, send "H" and "M0", and close the port."""
        self._running = False
        self.send("H")
        self.send("M0")
        time.sleep(0.1)
        self._thread.join(timeout=1.0)
        self.ser.close()
|
||
|
||
|
||
# ── PRBS excitation signal ───────────────────────────────────────────
|
||
|
||
|
||
class _PRBSExcitation:
    """Piecewise-constant random PWM signal used to excite the system.

    Calling the instance returns the PWM level currently being held. Each
    level is an integer drawn uniformly from ``[-amplitude, +amplitude]``
    and is held for a random integer number of milliseconds in
    ``[hold_min_ms, hold_max_ms]``. The hold is measured with
    ``time.monotonic()``; the first call after it expires draws a fresh
    level and a fresh hold.
    """

    def __init__(
        self,
        amplitude: int = 150,
        hold_min_ms: int = 50,
        hold_max_ms: int = 300,
    ):
        self.amplitude = amplitude
        self.hold_min_ms = hold_min_ms
        self.hold_max_ms = hold_max_ms
        self._current: int = 0
        self._switch_time: float = 0.0
        self._new_value()  # draw the first level right away

    def _new_value(self) -> None:
        """Draw a new level, then schedule when it should be replaced."""
        lo, hi = -self.amplitude, self.amplitude
        self._current = random.randint(lo, hi)
        hold_s = random.randint(self.hold_min_ms, self.hold_max_ms) / 1000.0
        self._switch_time = time.monotonic() + hold_s

    def __call__(self) -> int:
        hold_expired = time.monotonic() >= self._switch_time
        if hold_expired:
            self._new_value()
        return self._current
|
||
|
||
|
||
# ── Main capture loop ────────────────────────────────────────────────
|
||
|
||
|
||
def _apply_motor_limit(
    pwm: int,
    motor_angle_rad: float,
    max_motor_rad: float,
    amplitude: int,
) -> int:
    """Override ``pwm`` so the arm stays within ``±max_motor_rad``.

    ``max_motor_rad <= 0`` disables the limit and returns ``pwm`` unchanged.
    Beyond 90 % of the limit, a restoring command of 50–100 % of
    ``amplitude`` is forced; it grows linearly from 90 % to 100 % of the
    limit. Between 70 % and 90 %, a command pointing away from centre has
    its sign flipped. Otherwise ``pwm`` is returned unchanged.
    """
    if max_motor_rad <= 0:
        return pwm
    ratio = motor_angle_rad / max_motor_rad  # signed; |ratio| == 1 at the limit
    abs_ratio = abs(ratio)
    if abs_ratio > 0.90:
        # Deep in the danger zone — force a strong return.
        brake_strength = min(1.0, (abs_ratio - 0.90) / 0.10)  # 0 -> 1
        brake_pwm = int(amplitude * (0.5 + 0.5 * brake_strength))
        return -brake_pwm if ratio > 0 else brake_pwm
    if abs_ratio > 0.70:
        # Soft zone — only allow actions pointing back to centre.
        if ratio > 0 and pwm > 0:
            return -abs(pwm)
        if ratio < 0 and pwm < 0:
            return abs(pwm)
    return pwm


def capture(
    robot_path: str | Path,
    port: str = "/dev/cu.usbserial-0001",
    baud: int = 115200,
    duration: float = 20.0,
    amplitude: int = 150,
    hold_min_ms: int = 50,
    hold_max_ms: int = 300,
    dt: float = 0.02,
    motor_angle_limit_deg: float = 90.0,
) -> Path:
    """Run the capture procedure and return the path to the saved .npz file.

    The capture loop is **stream-driven**: it blocks on each incoming
    state line from the firmware (expected at ~50 Hz), sends the next
    motor command immediately, and records the state together with the
    PWM the firmware reports as applied.

    Parameters
    ----------
    robot_path : path to robot asset directory
    port : serial port for ESP32
    baud : baud rate
    duration : capture duration in seconds (host clock)
    amplitude : max PWM magnitude for excitation
    hold_min_ms / hold_max_ms : random hold time range (ms)
    dt : nominal sample period for buffer sizing (seconds)
    motor_angle_limit_deg : safety limit for motor angle (<= 0 disables it)

    Returns
    -------
    Path of ``<robot_path>/recordings/capture_<YYYYmmdd_HHMMSS>.npz``. The
    archive holds float64 arrays ``time``, ``action`` (applied PWM clamped
    to ±255 and divided by 255), ``motor_angle``, ``motor_vel``,
    ``pendulum_angle`` and ``pendulum_vel``.

    Raises
    ------
    ValueError
        If ``dt <= 0``, ``amplitude < 0`` or ``hold_min_ms > hold_max_ms``.
        These are checked before the serial port is opened.
    """
    # Fail fast, before paying for the serial connection and boot wait.
    if dt <= 0:
        raise ValueError(f"dt must be positive, got {dt!r}")
    if amplitude < 0:
        raise ValueError(f"amplitude must be >= 0, got {amplitude!r}")
    if hold_min_ms > hold_max_ms:
        raise ValueError(
            f"hold_min_ms ({hold_min_ms}) must not exceed hold_max_ms ({hold_max_ms})"
        )

    robot_path = Path(robot_path).resolve()
    max_motor_rad = (
        math.radians(motor_angle_limit_deg) if motor_angle_limit_deg > 0 else 0.0
    )

    # Prepare recording buffers (generous headroom over duration / dt).
    max_samples = int(duration / dt) + 500
    rec_time = np.zeros(max_samples, dtype=np.float64)
    rec_action = np.zeros(max_samples, dtype=np.float64)
    rec_motor_angle = np.zeros(max_samples, dtype=np.float64)
    rec_motor_vel = np.zeros(max_samples, dtype=np.float64)
    rec_pend_angle = np.zeros(max_samples, dtype=np.float64)
    rec_pend_vel = np.zeros(max_samples, dtype=np.float64)

    # Connect. Everything after this runs under try/finally, so the motor is
    # stopped and the port closed even if setup or the loop fails.
    reader = _SerialReader(port, baud)
    idx = 0
    try:
        excitation = _PRBSExcitation(amplitude, hold_min_ms, hold_max_ms)

        # Start streaming.
        reader.send("G")
        time.sleep(0.1)

        log.info(
            "capture_starting",
            port=port,
            duration=duration,
            amplitude=amplitude,
            hold_range_ms=f"{hold_min_ms}–{hold_max_ms}",
            mode="stream-driven (firmware clock)",
        )

        pwm = 0
        last_esp_ms = -1  # firmware timestamp of last recorded sample
        no_data_count = 0  # consecutive read timeouts with no data
        t0 = time.monotonic()

        while True:
            # Block until the firmware sends the next state line (~20 ms).
            # The 100 ms timeout prevents hanging if the ESP32 disconnects.
            state = reader.read_blocking(timeout=0.1)
            if not state:
                no_data_count += 1
                if no_data_count == 30:  # 3 seconds with no data
                    log.warning(
                        "no_data_received",
                        msg="No state lines from firmware after 3s. "
                        "Check: is the ESP32 powered? Is it running the right firmware? "
                        "Try pressing the RESET button.",
                    )
                if no_data_count == 100:  # 10 seconds
                    log.critical(
                        "no_data_timeout",
                        msg="No data for 10s — aborting capture.",
                    )
                    break
                continue  # no data yet — retry
            no_data_count = 0

            # Deduplicate: the firmware may send multiple state lines per
            # tick (e.g. M-command echo + tick). Only record one sample
            # per unique firmware timestamp.
            esp_ms = state.get("timestamp_ms", 0)
            if esp_ms == last_esp_ms:
                continue
            last_esp_ms = esp_ms

            elapsed = time.monotonic() - t0
            if elapsed >= duration:
                break

            # Excitation PWM for the NEXT tick, overridden near the angle
            # limit. Firmware sends motor angle in radians.
            pwm = _apply_motor_limit(
                excitation(),
                state.get("motor_rad", 0.0),
                max_motor_rad,
                amplitude,
            )

            # Send command immediately — it will take effect on the next tick.
            reader.send(f"M{pwm}")

            if idx >= max_samples:
                break  # buffer full

            # Record this tick's state + the action the motor *actually*
            # received. Firmware sends SI units and constrains the PWM to
            # ±255; normalise it to [-1, 1].
            applied = state.get("motor_speed", 0)
            rec_time[idx] = elapsed
            rec_action[idx] = max(-255, min(255, applied)) / 255.0
            rec_motor_angle[idx] = state.get("motor_rad", 0.0)
            rec_motor_vel[idx] = state.get("motor_vel", 0.0)
            rec_pend_angle[idx] = state.get("pend_rad", 0.0)
            rec_pend_vel[idx] = state.get("pend_vel", 0.0)
            idx += 1

            # Progress (every 50 samples ≈ once per second at 50 Hz).
            if idx % 50 == 0:
                log.info(
                    "capture_progress",
                    elapsed=f"{elapsed:.1f}/{duration:.0f}s",
                    samples=idx,
                    pwm=pwm,
                )
    finally:
        reader.send("M0")
        reader.close()

    # Save, trimmed to the actual sample count.
    recordings_dir = robot_path / "recordings"
    recordings_dir.mkdir(parents=True, exist_ok=True)
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    out_path = recordings_dir / f"capture_{stamp}.npz"
    np.savez_compressed(
        out_path,
        time=rec_time[:idx],
        action=rec_action[:idx],
        motor_angle=rec_motor_angle[:idx],
        motor_vel=rec_motor_vel[:idx],
        pendulum_angle=rec_pend_angle[:idx],
        pendulum_vel=rec_pend_vel[:idx],
    )

    log.info(
        "capture_saved",
        path=str(out_path),
        samples=idx,
        duration_actual=f"{rec_time[idx - 1]:.2f}s" if idx > 0 else "0s",
    )
    return out_path
|
||
|
||
|
||
# ── CLI entry point ──────────────────────────────────────────────────
|
||
|
||
|
||
def main() -> None:
    """Command-line entry point: parse options and hand them to ``capture``."""
    parser = argparse.ArgumentParser(
        description="Capture a real-robot trajectory for system identification."
    )
    add = parser.add_argument

    # Connection / output location.
    add("--robot-path", type=str, default="assets/rotary_cartpole",
        help="Path to robot asset directory")
    add("--port", type=str, default="/dev/cu.usbserial-0001",
        help="Serial port for ESP32")
    add("--baud", type=int, default=115200)

    # Excitation signal.
    add("--duration", type=float, default=20.0, help="Capture duration (s)")
    add("--amplitude", type=int, default=150,
        help="Max PWM magnitude (should not exceed firmware MAX_MOTOR_SPEED=150)")
    add("--hold-min-ms", type=int, default=50, help="Min hold time (ms)")
    add("--hold-max-ms", type=int, default=300, help="Max hold time (ms)")

    # Buffering and safety.
    add("--dt", type=float, default=0.02,
        help="Nominal sample period for buffer sizing (s)")
    add("--motor-angle-limit", type=float, default=90.0,
        help="Motor angle safety limit in degrees (0 = disabled)")

    opts = parser.parse_args()
    capture(
        robot_path=opts.robot_path,
        port=opts.port,
        baud=opts.baud,
        duration=opts.duration,
        amplitude=opts.amplitude,
        hold_min_ms=opts.hold_min_ms,
        hold_max_ms=opts.hold_max_ms,
        dt=opts.dt,
        motor_angle_limit_deg=opts.motor_angle_limit,
    )
|
||
|
||
|
||
if __name__ == "__main__":
    # Run the CLI when executed as a script or via ``python -m``.
    main()
|