import dataclasses

import torch
from gymnasium import spaces

from src.core.env import BaseEnv, BaseEnvConfig


@dataclasses.dataclass
class RotaryCartPoleState:
    """Per-environment joint state; every tensor is shaped (num_envs,)."""

    motor_angle: torch.Tensor     # horizontal arm angle (rad)
    motor_vel: torch.Tensor       # arm angular velocity (rad/s)
    pendulum_angle: torch.Tensor  # pendulum angle (rad); qpos=0 is hanging down
    pendulum_vel: torch.Tensor    # pendulum angular velocity (rad/s)


@dataclasses.dataclass
class RotaryCartPoleConfig(BaseEnvConfig):
    """Rotary inverted pendulum (Furuta pendulum) task config.

    The motor rotates the arm horizontally; the pendulum swings freely at
    the arm tip. Goal: swing the pendulum up and balance it upright.
    """

    # Reward shaping: multiplier on the upright term -cos(pendulum_angle),
    # which is +1 when the pendulum is upright (qpos=0 is hanging down).
    reward_upright_scale: float = 1.0


class RotaryCartPoleEnv(BaseEnv[RotaryCartPoleConfig]):
    """Furuta pendulum / rotary inverted pendulum environment.

    Kinematic chain:
        base_link ─(motor_joint, z)─► arm ─(pendulum_joint, y)─► pendulum

    Observations (6):
        [sin(motor), cos(motor), sin(pendulum), cos(pendulum),
         motor_vel, pendulum_vel]
        Using sin/cos avoids discontinuities at ±π for continuous joints.

    Actions (1):
        Torque applied to the motor_joint.
    """

    def __init__(self, config: RotaryCartPoleConfig):
        super().__init__(config)
        # Keep our own reference so reward shaping does not depend on how
        # BaseEnv stores (or renames) the config internally.
        self._cfg = config

    @property
    def observation_space(self) -> spaces.Space:
        # Unbounded: velocities are unclipped and sin/cos are already in [-1, 1].
        return spaces.Box(low=-torch.inf, high=torch.inf, shape=(6,))

    @property
    def action_space(self) -> spaces.Space:
        # Normalized motor torque command.
        return spaces.Box(low=-1.0, high=1.0, shape=(1,))

    def build_state(
        self, qpos: torch.Tensor, qvel: torch.Tensor
    ) -> RotaryCartPoleState:
        """Slice the (num_envs, nq) position/velocity buffers into named fields.

        Column 0 is the motor joint, column 1 the pendulum joint — this
        ordering must match the model's joint layout.
        """
        return RotaryCartPoleState(
            motor_angle=qpos[:, 0],
            motor_vel=qvel[:, 0],
            pendulum_angle=qpos[:, 1],
            pendulum_vel=qvel[:, 1],
        )

    def compute_observations(self, state: RotaryCartPoleState) -> torch.Tensor:
        """Return the (num_envs, 6) observation tensor described on the class."""
        return torch.stack(
            [
                torch.sin(state.motor_angle),
                torch.cos(state.motor_angle),
                torch.sin(state.pendulum_angle),
                torch.cos(state.pendulum_angle),
                state.motor_vel,
                state.pendulum_vel,
            ],
            dim=-1,
        )

    def compute_rewards(
        self, state: RotaryCartPoleState, actions: torch.Tensor
    ) -> torch.Tensor:
        """Dense shaping reward, one scalar per environment.

        upright: -cos(θ) ∈ [-1, +1], +1 when the pendulum points up
        (qpos=0 is hanging down), scaled by config.reward_upright_scale.
        Previously the scale field existed but was never applied; with the
        default of 1.0 this change is behavior-preserving.
        """
        upright = self._cfg.reward_upright_scale * -torch.cos(state.pendulum_angle)
        # Small quadratic velocity penalties — make sustained spinning
        # expensive while still allowing an energetic swing-up.
        pend_vel_penalty = 0.01 * state.pendulum_vel**2
        motor_vel_penalty = 0.01 * state.motor_vel**2
        return upright - pend_vel_penalty - motor_vel_penalty

    def compute_terminations(self, state: RotaryCartPoleState) -> torch.Tensor:
        """Never terminate early; episodes end by truncation at max_steps.

        The agent must learn to swing up AND balance continuously.
        """
        return torch.zeros_like(state.motor_angle, dtype=torch.bool)

    def get_default_qpos(self, nq: int) -> list[float] | None:
        # qpos=0 = pendulum hanging down (joint frame rotated in URDF),
        # so the all-zeros default from the base class is the right start.
        return None