add rotary cartpole (Furuta pendulum) env

This commit is contained in:
2026-03-08 22:58:32 +01:00
parent c8f28ffbcc
commit c753c369b4
15 changed files with 464 additions and 171 deletions

View File

@@ -0,0 +1,94 @@
import dataclasses
import math

import torch
from gymnasium import spaces

from src.core.env import BaseEnv, BaseEnvConfig
@dataclasses.dataclass
class RotaryCartPoleState:
    """Batched joint state for the rotary cart-pole.

    Each field is a 1-D tensor of shape (num_envs,). Angles are in radians
    (they are fed directly to sin/cos in the observation builder).
    """

    motor_angle: torch.Tensor  # (num_envs,) motor joint angle
    motor_vel: torch.Tensor  # (num_envs,) motor joint angular velocity
    pendulum_angle: torch.Tensor  # (num_envs,) pendulum joint angle
    pendulum_vel: torch.Tensor  # (num_envs,) pendulum joint angular velocity
@dataclasses.dataclass
class RotaryCartPoleConfig(BaseEnvConfig):
    """Rotary inverted pendulum (Furuta pendulum) task config.

    The motor rotates the arm horizontally; the pendulum swings freely
    at the arm tip. Goal: swing the pendulum up and balance it upright.
    """

    # Reward shaping
    # Scale for the upright shaping term (cos(pendulum) is +1 when upright).
    reward_upright_scale: float = 1.0
class RotaryCartPoleEnv(BaseEnv[RotaryCartPoleConfig]):
    """Furuta pendulum / rotary inverted pendulum environment.

    Kinematic chain:
        base_link ─(motor_joint, z)─► arm ─(pendulum_joint, y)─► pendulum

    Observations (6):
        [sin(motor), cos(motor), sin(pendulum), cos(pendulum), motor_vel, pendulum_vel]
        The sin/cos encoding avoids the ±π discontinuity of raw continuous angles.

    Actions (1):
        Torque applied to the motor_joint, normalized to [-1, 1].
    """

    def __init__(self, config: RotaryCartPoleConfig):
        super().__init__(config)
        # Cache the reward scale locally so compute_rewards does not depend on
        # how BaseEnv stores (or doesn't store) the config object.
        self._upright_scale = config.reward_upright_scale

    @property
    def observation_space(self) -> spaces.Space:
        # Unbounded box: velocities have no hard limit; sin/cos components are
        # naturally within [-1, 1] anyway, so a loose bound is simplest.
        return spaces.Box(low=-torch.inf, high=torch.inf, shape=(6,))

    @property
    def action_space(self) -> spaces.Space:
        return spaces.Box(low=-1.0, high=1.0, shape=(1,))

    def build_state(self, qpos: torch.Tensor, qvel: torch.Tensor) -> RotaryCartPoleState:
        """Split batched (num_envs, 2) position/velocity tensors into named fields.

        Column 0 is the motor joint, column 1 the pendulum joint.
        """
        return RotaryCartPoleState(
            motor_angle=qpos[:, 0],
            motor_vel=qvel[:, 0],
            pendulum_angle=qpos[:, 1],
            pendulum_vel=qvel[:, 1],
        )

    def compute_observations(self, state: RotaryCartPoleState) -> torch.Tensor:
        """Return the (num_envs, 6) observation tensor described in the class docstring."""
        return torch.stack(
            [
                torch.sin(state.motor_angle),
                torch.cos(state.motor_angle),
                torch.sin(state.pendulum_angle),
                torch.cos(state.pendulum_angle),
                state.motor_vel,
                state.pendulum_vel,
            ],
            dim=-1,
        )

    def compute_rewards(self, state: RotaryCartPoleState, actions: torch.Tensor) -> torch.Tensor:
        """Per-env reward: upright shaping minus a small motor-effort penalty.

        height = sin(pendulum_angle) maps -1 (hanging down) .. +1 (upright).
        Clamping at 0 before squaring means the shaping term only kicks in
        above horizontal, so the swing-up phase is not penalised:
            down  (h=-1): 0.0
            horiz (h= 0): 0.0
            up    (h=+1): 1.0
        """
        height = torch.sin(state.pendulum_angle)
        upright_reward = torch.clamp(height, 0.0, 1.0) ** 2
        # Small quadratic torque cost to discourage bang-bang control.
        effort_penalty = 0.001 * actions.squeeze(-1) ** 2
        # 5.0 is the base magnitude; reward_upright_scale (default 1.0) lets
        # the config tune the shaping without changing default behavior.
        return 5.0 * self._upright_scale * upright_reward - effort_penalty

    def compute_terminations(self, state: RotaryCartPoleState) -> torch.Tensor:
        """Never terminate early — episodes end by truncation at max_steps only.

        The agent must learn to swing up AND balance continuously.
        """
        return torch.zeros_like(state.motor_angle, dtype=torch.bool)

    def get_default_qpos(self, nq: int) -> list[float] | None:
        """Initial joint positions: arm at 0, pendulum hanging straight down.

        The STL mesh is horizontal at qpos=0, so "down" is θ = -π/2
        (sin(-π/2) = -1, matching the reward's height convention).
        """
        return [0.0, -math.pi / 2]