✨ add rotary cartpole env
This commit is contained in:
94
src/envs/rotary_cartpole.py
Normal file
94
src/envs/rotary_cartpole.py
Normal file
@@ -0,0 +1,94 @@
|
||||
import dataclasses
|
||||
import torch
|
||||
from src.core.env import BaseEnv, BaseEnvConfig
|
||||
from gymnasium import spaces
|
||||
|
||||
|
||||
@dataclasses.dataclass
class RotaryCartPoleState:
    """Batched joint state of the rotary cart-pole (Furuta pendulum).

    Every field is a 1-D tensor of shape (num_envs,) — one entry per
    parallel environment instance.
    """

    motor_angle: torch.Tensor     # horizontal arm (motor joint) angle, radians, (num_envs,)
    motor_vel: torch.Tensor       # motor joint angular velocity, (num_envs,)
    pendulum_angle: torch.Tensor  # free pendulum joint angle, radians, (num_envs,)
    pendulum_vel: torch.Tensor    # pendulum joint angular velocity, (num_envs,)
|
||||
|
||||
|
||||
@dataclasses.dataclass
class RotaryCartPoleConfig(BaseEnvConfig):
    """Configuration for the rotary inverted pendulum (Furuta pendulum) task.

    A motor spins the arm in the horizontal plane; the pendulum hangs
    freely from the arm tip. The goal is to swing the pendulum up and
    keep it balanced upright.
    """

    # --- Reward shaping ---
    # Multiplier on the upright bonus; cos(pendulum) is +1 when upright.
    reward_upright_scale: float = 1.0
|
||||
|
||||
|
||||
class RotaryCartPoleEnv(BaseEnv[RotaryCartPoleConfig]):
    """Furuta pendulum / rotary inverted pendulum environment.

    Kinematic chain: base_link ─(motor_joint, z)─► arm ─(pendulum_joint, y)─► pendulum

    Observations (6):
        [sin(motor), cos(motor), sin(pendulum), cos(pendulum), motor_vel, pendulum_vel]
        Using sin/cos avoids discontinuities at ±π for continuous joints.

    Actions (1):
        Torque applied to the motor_joint.
    """

    def __init__(self, config: RotaryCartPoleConfig):
        super().__init__(config)
        # Effective upright-bonus weight. The original reward hard-coded 5.0
        # and left config.reward_upright_scale unused; folding the config
        # value in as a multiplier keeps the default (1.0) behaviour
        # identical while making the knob actually do something.
        self._upright_scale = 5.0 * config.reward_upright_scale

    @property
    def observation_space(self) -> spaces.Space:
        # Unbounded box: the four sin/cos entries live in [-1, 1], but the
        # two velocity entries have no a-priori bound.
        return spaces.Box(low=-torch.inf, high=torch.inf, shape=(6,))

    @property
    def action_space(self) -> spaces.Space:
        # Single normalised motor torque command.
        return spaces.Box(low=-1.0, high=1.0, shape=(1,))

    def build_state(self, qpos: torch.Tensor, qvel: torch.Tensor) -> RotaryCartPoleState:
        """Slice generalized coordinates into a named state.

        Column 0 is the motor joint, column 1 the pendulum joint
        (assumes qpos/qvel are (num_envs, 2) — matches the two-joint chain).
        """
        return RotaryCartPoleState(
            motor_angle=qpos[:, 0],
            motor_vel=qvel[:, 0],
            pendulum_angle=qpos[:, 1],
            pendulum_vel=qvel[:, 1],
        )

    def compute_observations(self, state: RotaryCartPoleState) -> torch.Tensor:
        """Return the (num_envs, 6) observation described in the class docstring."""
        return torch.stack([
            torch.sin(state.motor_angle),
            torch.cos(state.motor_angle),
            torch.sin(state.pendulum_angle),
            torch.cos(state.pendulum_angle),
            state.motor_vel,
            state.pendulum_vel,
        ], dim=-1)

    def compute_rewards(self, state: RotaryCartPoleState, actions: torch.Tensor) -> torch.Tensor:
        """Per-env reward: scaled upright bonus minus a small effort penalty."""
        # height: sin(θ) → -1 (down) to +1 (up)
        height = torch.sin(state.pendulum_angle)

        # Upright reward: clamp(h, 0, 1)², so
        #   down  (h=-1): 0.0
        #   horiz (h= 0): 0.0
        #   up    (h=+1): 1.0
        # Only kicks in above horizontal, so swing-up isn't penalised.
        upright_reward = torch.clamp(height, 0.0, 1.0) ** 2

        # Motor effort penalty: small cost to avoid bang-bang control.
        effort_penalty = 0.001 * actions.squeeze(-1) ** 2

        # _upright_scale == 5.0 * config.reward_upright_scale (see __init__);
        # identical to the old hard-coded 5.0 at the default scale of 1.0.
        return self._upright_scale * upright_reward - effort_penalty

    def compute_terminations(self, state: RotaryCartPoleState) -> torch.Tensor:
        # No early termination — episode runs for max_steps (truncation only).
        # The agent must learn to swing up AND balance continuously.
        return torch.zeros_like(state.motor_angle, dtype=torch.bool)

    def get_default_qpos(self, nq: int) -> list[float] | None:
        """Initial joint positions: arm at zero, pendulum hanging straight down.

        The STL mesh is horizontal at qpos=0, so the rest pose puts the
        pendulum at θ = -π/2 (sin(-π/2) = -1, the reward's "down" state).
        """
        import math
        return [0.0, -math.pi / 2]
|
||||
Reference in New Issue
Block a user