From c8f28ffbcceb487e3fd47869ccd4aeaf554ab38c Mon Sep 17 00:00:00 2001 From: Victor Mylle Date: Fri, 6 Mar 2026 22:19:44 +0100 Subject: [PATCH] :sparkles: initial commit --- .gitignore | 3 + .python-version | 1 + assets/cartpole/cartpole.urdf | 64 +++++++++ configs/config.yaml | 5 + configs/env/cartpole.yaml | 11 ++ configs/runner/mujoco.yaml | 4 + configs/training/ppo.yaml | 13 ++ requirements.txt | 8 ++ src/core/__init__.py | 0 src/core/env.py | 59 +++++++++ src/core/runner.py | 97 ++++++++++++++ src/envs/cartpole.py | 53 ++++++++ src/models/__init__.py | 0 src/models/mlp.py | 48 +++++++ src/runners/mujoco.py | 155 ++++++++++++++++++++++ src/training/trainer.py | 243 ++++++++++++++++++++++++++++++++++ train.py | 47 +++++++ 17 files changed, 811 insertions(+) create mode 100644 .gitignore create mode 100644 .python-version create mode 100644 assets/cartpole/cartpole.urdf create mode 100644 configs/config.yaml create mode 100644 configs/env/cartpole.yaml create mode 100644 configs/runner/mujoco.yaml create mode 100644 configs/training/ppo.yaml create mode 100644 requirements.txt create mode 100644 src/core/__init__.py create mode 100644 src/core/env.py create mode 100644 src/core/runner.py create mode 100644 src/envs/cartpole.py create mode 100644 src/models/__init__.py create mode 100644 src/models/mlp.py create mode 100644 src/runners/mujoco.py create mode 100644 src/training/trainer.py create mode 100644 train.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..25335e9 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +outputs/ +.vscode/ +runs/ \ No newline at end of file diff --git a/.python-version b/.python-version new file mode 100644 index 0000000..0bc7efb --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ +RL-Framework-7914bb diff --git a/assets/cartpole/cartpole.urdf b/assets/cartpole/cartpole.urdf new file mode 100644 index 0000000..82ed29b --- /dev/null +++ b/assets/cartpole/cartpole.urdf @@ -0,0 +1,64 @@ + + + + + + + + + + + + + + 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/configs/config.yaml b/configs/config.yaml new file mode 100644 index 0000000..bbbbe20 --- /dev/null +++ b/configs/config.yaml @@ -0,0 +1,5 @@ +defaults: + - env: cartpole + - runner: mujoco + - training: ppo + - _self_ \ No newline at end of file diff --git a/configs/env/cartpole.yaml b/configs/env/cartpole.yaml new file mode 100644 index 0000000..f7d0722 --- /dev/null +++ b/configs/env/cartpole.yaml @@ -0,0 +1,11 @@ +max_steps: 500 +angle_threshold: 0.418 +cart_limit: 2.4 +reward_alive: 1.0 +reward_pole_upright_scale: 1.0 +reward_action_penalty_scale: 0.01 +model_path: assets/cartpole/cartpole.urdf +actuators: + - joint: cart_joint + gear: 10.0 + ctrl_range: [-1.0, 1.0] diff --git a/configs/runner/mujoco.yaml b/configs/runner/mujoco.yaml new file mode 100644 index 0000000..42a54b5 --- /dev/null +++ b/configs/runner/mujoco.yaml @@ -0,0 +1,4 @@ +num_envs: 16 +device: cpu +dt: 0.02 +substeps: 2 diff --git a/configs/training/ppo.yaml b/configs/training/ppo.yaml new file mode 100644 index 0000000..8025216 --- /dev/null +++ b/configs/training/ppo.yaml @@ -0,0 +1,13 @@ +hidden_sizes: [128, 128] +total_timesteps: 1000000 +rollout_steps: 1024 +learning_epochs: 4 +mini_batches: 4 +discount_factor: 0.99 +gae_lambda: 0.95 +learning_rate: 0.0003 +clip_ratio: 0.2 +value_loss_scale: 0.5 +entropy_loss_scale: 0.01 +log_interval: 10 +clearml_project: RL-Framework diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..f40d2e8 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,8 @@ +torch +gymnasium +hydra-core +omegaconf +mujoco +skrl[torch] +clearml +pytest \ No newline at end of file diff --git a/src/core/__init__.py b/src/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/core/env.py b/src/core/env.py new file mode 100644 index 0000000..2654898 --- /dev/null +++ b/src/core/env.py @@ -0,0 
# ===== file: src/core/env.py =================================================
"""Environment abstraction: task logic (obs/reward/termination) decoupled from sim."""
import abc
import dataclasses
import pathlib
from typing import Any, Generic, TypeVar

import torch
from gymnasium import spaces

T = TypeVar("T")


@dataclasses.dataclass
class ActuatorConfig:
    """Actuator definition — maps a joint to a motor with gear ratio and control limits.

    Kept in the env config (not runner config) because actuators define what the
    robot can do, which determines the action space — a task-level concept. This
    mirrors Isaac Lab's pattern of separating actuator config from the robot file.
    """

    joint: str = ""                                 # joint name in the model file
    gear: float = 1.0                               # motor gear ratio
    ctrl_range: tuple[float, float] = (-1.0, 1.0)   # control signal clamp range


@dataclasses.dataclass
class BaseEnvConfig:
    """Config shared by every task: episode length, model file, actuator set."""

    max_steps: int = 1000                           # time-limit truncation horizon
    model_path: pathlib.Path | None = None          # URDF/MJCF robot description
    actuators: list[ActuatorConfig] = dataclasses.field(default_factory=list)


class BaseEnv(abc.ABC, Generic[T]):
    """Task definition. Subclasses map raw sim state (qpos/qvel) to RL quantities.

    ``T`` is the concrete config type (e.g. ``CartPoleConfig``) so subclasses can
    parameterize ``BaseEnv[MyConfig]``.
    """

    def __init__(self, config: BaseEnvConfig):
        self.config = config

    @property
    @abc.abstractmethod
    def observation_space(self) -> spaces.Space:
        """Gymnasium observation space for a single environment."""

    @property
    @abc.abstractmethod
    def action_space(self) -> spaces.Space:
        """Gymnasium action space for a single environment."""

    @abc.abstractmethod
    def build_state(self, qpos: torch.Tensor, qvel: torch.Tensor) -> Any:
        """Convert batched generalized coordinates into a task-specific state."""

    @abc.abstractmethod
    def compute_observations(self, state: Any) -> torch.Tensor:
        """Return batched observations for the given state."""

    @abc.abstractmethod
    def compute_rewards(self, state: Any, actions: torch.Tensor) -> torch.Tensor:
        """Return per-env rewards for the given state/action pair."""

    @abc.abstractmethod
    def compute_terminations(self, state: Any) -> torch.Tensor:
        """Return per-env boolean terminal flags (task failure/success)."""

    def compute_truncations(self, step_counts: torch.Tensor) -> torch.Tensor:
        """Return per-env boolean time-limit truncation flags."""
        return step_counts >= self.config.max_steps


# ===== file: src/core/runner.py ==============================================
# (in the repo this file imports: from src.core.env import BaseEnv — defined above
# in this consolidated view, so the import is omitted here)


@dataclasses.dataclass
class BaseRunnerConfig:
    """Config shared by every runner (simulation backend)."""

    num_envs: int = 1
    device: str = "cpu"


class BaseRunner(abc.ABC, Generic[T]):
    """Vectorized sim wrapper exposing a skrl-compatible environment interface.

    Orchestrates the step/reset/auto-reset loop; delegates task logic to ``env``
    and physics to the ``_sim_*`` hooks implemented by concrete backends.
    ``T`` is the concrete runner-config type.
    """

    def __init__(self, env: BaseEnv, config: T) -> None:
        self.env = env
        self.config = config

        self._sim_initialize(config)

        self.observation_space = self.env.observation_space
        self.action_space = self.env.action_space
        self.num_agents: int = 1  # single-agent RL (required by skrl)

        # Per-env step counters driving time-limit truncation.
        self.step_counts = torch.zeros(
            self.config.num_envs, dtype=torch.long, device=self.config.device
        )

    @property
    @abc.abstractmethod
    def num_envs(self) -> int:
        """Number of parallel environments."""

    @property
    @abc.abstractmethod
    def device(self) -> torch.device:
        """Device all returned tensors live on."""

    @abc.abstractmethod
    def _sim_initialize(self, config: T) -> None:
        """Build simulator structures (model, per-env data)."""

    @abc.abstractmethod
    def _sim_step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Advance physics one control step; return batched (qpos, qvel)."""

    @abc.abstractmethod
    def _sim_reset(self, env_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Reset the given envs; return their batched (qpos, qvel)."""

    @abc.abstractmethod
    def _sim_close(self) -> None:
        """Release simulator resources."""

    def reset(self) -> tuple[torch.Tensor, dict[str, Any]]:
        """Reset every environment and return (observations, info)."""
        all_ids = torch.arange(self.num_envs, device=self.device)
        qpos, qvel = self._sim_reset(all_ids)
        self.step_counts.zero_()

        state = self.env.build_state(qpos, qvel)
        obs = self.env.compute_observations(state)
        return obs, {}

    # bugfix: the original annotation declared a 4-tuple, but five values are
    # returned (obs, rewards, terminated, truncated, info).
    def step(
        self, actions: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict[str, Any]]:
        """Step all envs, auto-resetting any that finished this step.

        Final observations of finished envs are stashed in
        ``info["final_observations"]`` / ``info["final_env_ids"]`` before being
        overwritten by the post-reset observations.
        """
        qpos, qvel = self._sim_step(actions)
        self.step_counts += 1

        state = self.env.build_state(qpos, qvel)
        obs = self.env.compute_observations(state)
        rewards = self.env.compute_rewards(state, actions)
        terminated = self.env.compute_terminations(state)
        truncated = self.env.compute_truncations(self.step_counts)

        info: dict[str, Any] = {}

        done = terminated | truncated
        done_ids = done.nonzero(as_tuple=False).squeeze(-1)

        if done_ids.numel() > 0:
            info["final_observations"] = obs[done_ids].clone()
            info["final_env_ids"] = done_ids.clone()

            reset_qpos, reset_qvel = self._sim_reset(done_ids)
            self.step_counts[done_ids] = 0

            reset_state = self.env.build_state(reset_qpos, reset_qvel)
            obs[done_ids] = self.env.compute_observations(reset_state)

        # skrl expects (num_envs, 1) for rewards/terminated/truncated
        return obs, rewards.unsqueeze(-1), terminated.unsqueeze(-1), truncated.unsqueeze(-1), info

    def render(self, env_idx: int = 0, mode: str = "human") -> torch.Tensor | None:
        raise NotImplementedError("Render method not implemented for this runner.")

    def close(self) -> None:
        self._sim_close()


# ===== file: src/envs/cartpole.py ============================================
# (in the repo this file imports: from src.core.env import BaseEnv, BaseEnvConfig
# — defined above in this consolidated view)


@dataclasses.dataclass
class CartPoleState:
    """Batched cart-pole state; every field is a (num_envs,) tensor."""

    # bugfix: fields were annotated ``torch.float`` — that is a dtype object,
    # not a type; the actual values are tensors.
    cart_pos: torch.Tensor    # (num_envs,)
    cart_vel: torch.Tensor    # (num_envs,)
    pole_angle: torch.Tensor  # (num_envs,)
    pole_vel: torch.Tensor    # (num_envs,)


@dataclasses.dataclass
class CartPoleConfig(BaseEnvConfig):
    """CartPole task config. All values come from Hydra YAML."""

    angle_threshold: float = 0.418   # ~24 degrees
    cart_limit: float = 2.4
    reward_alive: float = 1.0
    reward_pole_upright_scale: float = 1.0
    reward_action_penalty_scale: float = 0.01


class CartPoleEnv(BaseEnv[CartPoleConfig]):
    """Classic cart-pole balancing task on top of the generic BaseEnv contract."""

    def __init__(self, config: CartPoleConfig):
        super().__init__(config)

    # bugfix: these properties were annotated ``-> torch.Tensor`` but return
    # gymnasium spaces.
    @property
    def observation_space(self) -> spaces.Box:
        return spaces.Box(low=-torch.inf, high=torch.inf, shape=(4,))

    @property
    def action_space(self) -> spaces.Box:
        return spaces.Box(low=-1.0, high=1.0, shape=(1,))

    def build_state(self, qpos: torch.Tensor, qvel: torch.Tensor) -> CartPoleState:
        # assumes joint ordering (cart slide, pole hinge) from the robot file —
        # TODO(review): confirm against the URDF joint order
        return CartPoleState(
            cart_pos=qpos[:, 0],
            cart_vel=qvel[:, 0],
            pole_angle=qpos[:, 1],
            pole_vel=qvel[:, 1],
        )

    def compute_observations(self, state: CartPoleState) -> torch.Tensor:
        return torch.stack(
            [state.cart_pos, state.cart_vel, state.pole_angle, state.pole_vel], dim=-1
        )

    def compute_rewards(self, state: CartPoleState, actions: torch.Tensor) -> torch.Tensor:
        # alive bonus + uprightness (cos(angle) peaks at 0 rad) - action cost
        upright = self.config.reward_pole_upright_scale * torch.cos(state.pole_angle)
        action_penalty = self.config.reward_action_penalty_scale * torch.sum(actions**2, dim=-1)
        return self.config.reward_alive + upright - action_penalty

    def compute_terminations(self, state: CartPoleState) -> torch.Tensor:
        pole_fallen = torch.abs(state.pole_angle) > self.config.angle_threshold
        cart_out_of_bounds = torch.abs(state.cart_pos) > self.config.cart_limit
        return pole_fallen | cart_out_of_bounds


# ===== file: src/models/mlp.py ===============================================
import torch.nn as nn
from skrl.models.torch import Model, GaussianMixin, DeterministicMixin


class SharedMLP(GaussianMixin, DeterministicMixin, Model):
    """Shared-trunk actor-critic MLP for skrl PPO.

    One MLP body feeds two heads: a Gaussian policy (mean layer + learned
    log-std parameter) and a scalar value head. The trunk output computed for
    the "policy" role is cached and reused for the immediately following
    "value" call to avoid a second forward pass.
    """

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        device: torch.device,
        hidden_sizes: tuple[int, ...] = (32, 32),
        clip_actions: bool = False,
        clip_log_std: bool = True,
        min_log_std: float = -20.0,
        max_log_std: float = 2.0,
        initial_log_std: float = 0.0,
    ):
        Model.__init__(self, observation_space, action_space, device)
        GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std)
        DeterministicMixin.__init__(self, clip_actions)

        # Shared trunk: Linear + ELU per hidden layer.
        layers = []
        in_dim: int = self.num_observations
        for hidden_size in hidden_sizes:
            layers.append(nn.Linear(in_dim, hidden_size))
            layers.append(nn.ELU())
            in_dim = hidden_size
        self.net: nn.Sequential = nn.Sequential(*layers)

        # Policy head
        self.mean_layer = nn.Linear(in_dim, self.num_actions)
        self.log_std_parameter: nn.Parameter = nn.Parameter(
            torch.full((self.num_actions,), initial_log_std)
        )

        # Value head
        self.value_layer = nn.Linear(in_dim, 1)
        self._shared_output: torch.Tensor | None = None  # trunk cache between roles

    def act(self, inputs: dict[str, torch.Tensor], role: str = "") -> tuple[torch.Tensor, ...]:
        """Dispatch to the mixin matching the requested role."""
        if role == "policy":
            return GaussianMixin.act(self, inputs, role)
        if role == "value":
            return DeterministicMixin.act(self, inputs, role)
        # bugfix: previously fell through and implicitly returned None
        raise ValueError(f"Unknown role: {role!r}")

    def compute(
        self, inputs: dict[str, torch.Tensor], role: str = ""
    ) -> tuple[torch.Tensor, ...]:
        """Forward pass for one role; caches/reuses the shared trunk output."""
        if role == "policy":
            self._shared_output = self.net(inputs["states"])
            return self.mean_layer(self._shared_output), self.log_std_parameter, {}
        if role == "value":
            shared_output = (
                self._shared_output
                if self._shared_output is not None
                else self.net(inputs["states"])
            )
            self._shared_output = None  # cache is single-use
            return self.value_layer(shared_output), {}
        # bugfix: previously fell through and implicitly returned None
        raise ValueError(f"Unknown role: {role!r}")
# ===== file: src/runners/mujoco.py ===========================================
"""CPU MuJoCo backend: one MjData per environment, stepped serially."""
import dataclasses
import os
import tempfile
import xml.etree.ElementTree as ET

import numpy as np
import torch

import mujoco
import mujoco.viewer

from src.core.env import BaseEnv, ActuatorConfig
from src.core.runner import BaseRunner, BaseRunnerConfig


@dataclasses.dataclass
class MuJoCoRunnerConfig(BaseRunnerConfig):
    """MuJoCo runner settings (values come from Hydra YAML)."""

    num_envs: int = 16
    device: str = "cpu"
    dt: float = 0.02      # physics timestep per mj_step
    substeps: int = 2     # physics steps per control step


class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
    """BaseRunner backed by plain (non-batched) MuJoCo on CPU."""

    def __init__(self, env: BaseEnv, config: MuJoCoRunnerConfig):
        super().__init__(env, config)

    @property
    def num_envs(self) -> int:
        return self.config.num_envs

    @property
    def device(self) -> torch.device:
        return torch.device(self.config.device)

    @staticmethod
    def _load_model_with_actuators(
        model_path: str, actuators: list[ActuatorConfig]
    ) -> mujoco.MjModel:
        """Load a URDF (or MJCF) file and programmatically inject actuators.

        Two-step approach required because MuJoCo's URDF parser ignores
        actuators in the extension block:
        1. Load the URDF -> MuJoCo converts it to internal MJCF
        2. Export the MJCF XML, add actuator elements, reload

        This keeps the URDF clean and standard — actuator config lives in
        the env config (Isaac Lab pattern), not in the robot file.
        """
        # Step 1: Load URDF/MJCF as-is (no actuators)
        model_raw = mujoco.MjModel.from_xml_path(model_path)

        if not actuators:
            return model_raw

        # Step 2: Export internal MJCF representation.
        # bugfix: was tempfile.mktemp(), which is deprecated and race-prone —
        # mkstemp() creates the file securely; close the fd since
        # mj_saveLastXML wants a path, not an open descriptor.
        fd, tmp_mjcf = tempfile.mkstemp(suffix=".xml")
        os.close(fd)
        try:
            mujoco.mj_saveLastXML(tmp_mjcf, model_raw)
            with open(tmp_mjcf) as f:
                mjcf_str = f.read()
        finally:
            os.unlink(tmp_mjcf)

        # Step 3: Inject actuators into the MJCF XML
        root = ET.fromstring(mjcf_str)
        act_elem = ET.SubElement(root, "actuator")
        for act in actuators:
            ET.SubElement(act_elem, "motor", attrib={
                "name": f"{act.joint}_motor",
                "joint": act.joint,
                "gear": str(act.gear),
                "ctrlrange": f"{act.ctrl_range[0]} {act.ctrl_range[1]}",
            })

        # Step 4: Reload from modified MJCF
        modified_xml = ET.tostring(root, encoding="unicode")
        return mujoco.MjModel.from_xml_string(modified_xml)

    def _sim_initialize(self, config: MuJoCoRunnerConfig) -> None:
        """Build the shared model and one MjData per environment."""
        model_path = self.env.config.model_path
        if model_path is None:
            raise ValueError("model_path must be specified in the environment config")

        actuators = self.env.config.actuators
        self._model = self._load_model_with_actuators(str(model_path), actuators)
        self._model.opt.timestep = config.dt
        # One independent MjData per env — model is shared, state is not.
        self._data: list[mujoco.MjData] = [
            mujoco.MjData(self._model) for _ in range(config.num_envs)
        ]

        self._nq = self._model.nq
        self._nv = self._model.nv

    def _sim_step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply controls and advance each env by ``substeps`` physics steps."""
        actions_np: np.ndarray = actions.cpu().numpy()

        qpos_batch = np.zeros((self.num_envs, self._nq), dtype=np.float32)
        qvel_batch = np.zeros((self.num_envs, self._nv), dtype=np.float32)

        for i, data in enumerate(self._data):
            data.ctrl[:] = actions_np[i]
            for _ in range(self.config.substeps):
                mujoco.mj_step(self._model, data)

            qpos_batch[i] = data.qpos
            qvel_batch[i] = data.qvel

        return (
            torch.from_numpy(qpos_batch).to(self.device),
            torch.from_numpy(qvel_batch).to(self.device),
        )

    def _sim_reset(self, env_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Reset the given envs to the model's default state plus small noise."""
        ids = env_ids.cpu().numpy()
        n = len(ids)

        qpos_batch = np.zeros((n, self._nq), dtype=np.float32)
        qvel_batch = np.zeros((n, self._nv), dtype=np.float32)

        for i, env_id in enumerate(ids):
            data = self._data[env_id]
            mujoco.mj_resetData(self._model, data)

            # Add small random perturbation so the pole doesn't start perfectly upright
            data.qpos[:] += np.random.uniform(-0.05, 0.05, size=self._nq)
            data.qvel[:] += np.random.uniform(-0.05, 0.05, size=self._nv)

            qpos_batch[i] = data.qpos
            qvel_batch[i] = data.qvel

        return (
            torch.from_numpy(qpos_batch).to(self.device),
            torch.from_numpy(qvel_batch).to(self.device),
        )

    def _sim_close(self) -> None:
        """Tear down viewer/renderer (if ever created) and per-env data."""
        if hasattr(self, "_viewer") and self._viewer is not None:
            self._viewer.close()
            self._viewer = None

        if hasattr(self, "_offscreen_renderer") and self._offscreen_renderer is not None:
            self._offscreen_renderer.close()
            self._offscreen_renderer = None

        self._data.clear()

    def render(self, env_idx: int = 0, mode: str = "human") -> torch.Tensor | None:
        """Render env ``env_idx``: interactive viewer ("human") or pixels ("rgb_array")."""
        if mode == "human":
            if not hasattr(self, "_viewer") or self._viewer is None:
                self._viewer = mujoco.viewer.launch_passive(
                    self._model, self._data[env_idx]
                )
            # Update visual geometry from current physics state
            mujoco.mj_forward(self._model, self._data[env_idx])
            self._viewer.sync()
            return None
        elif mode == "rgb_array":
            # Cache the offscreen renderer to avoid create/destroy overhead
            if not hasattr(self, "_offscreen_renderer") or self._offscreen_renderer is None:
                self._offscreen_renderer = mujoco.Renderer(self._model, height=480, width=640)
            self._offscreen_renderer.update_scene(self._data[env_idx])
            pixels = self._offscreen_renderer.render().copy()  # copy since buffer is reused
            return torch.from_numpy(pixels)
        # bugfix: previously fell through and silently returned None
        raise ValueError(f"Unsupported render mode: {mode!r}")
# ===== file: src/training/trainer.py =========================================
"""PPO training wrapper around skrl, with ClearML logging and periodic videos."""
import dataclasses
import sys
import tempfile
from pathlib import Path

import numpy as np
import torch
import tqdm
from clearml import Task, Logger
from gymnasium import spaces
from skrl.agents.torch.ppo import PPO, PPO_DEFAULT_CONFIG
from skrl.memories.torch import RandomMemory
from skrl.trainers.torch import SequentialTrainer

from src.core.runner import BaseRunner
from src.models.mlp import SharedMLP


@dataclasses.dataclass
class TrainerConfig:
    """PPO hyperparameters + logging/video settings (values come from Hydra YAML)."""

    rollout_steps: int = 2048
    learning_epochs: int = 8
    mini_batches: int = 4
    discount_factor: float = 0.99
    gae_lambda: float = 0.95
    learning_rate: float = 3e-4
    clip_ratio: float = 0.2
    value_loss_scale: float = 0.5
    entropy_loss_scale: float = 0.01

    hidden_sizes: tuple[int, ...] = (64, 64)

    total_timesteps: int = 1_000_000
    log_interval: int = 10

    # Video recording
    record_video_every: int = 10000        # record a video every N timesteps (0 = disabled)
    record_video_min_seconds: float = 10.0  # minimum video duration in seconds
    record_video_fps: int = 0               # 0 = auto-derive from simulation rate

    clearml_project: str | None = None
    clearml_task: str | None = None


class VideoRecordingTrainer(SequentialTrainer):
    """Subclass of skrl's SequentialTrainer that records videos periodically."""

    def __init__(self, env, agents, cfg=None, trainer_config: TrainerConfig | None = None):
        super().__init__(env=env, agents=agents, cfg=cfg)
        self._trainer_config = trainer_config
        # Videos are written to a throwaway dir; ClearML uploads from there.
        self._video_dir = Path(tempfile.mkdtemp(prefix="rl_videos_"))

    def single_agent_train(self) -> None:
        """Override to add periodic video recording."""
        assert self.num_simultaneous_agents == 1
        assert self.env.num_agents == 1

        states, infos = self.env.reset()

        for timestep in tqdm.tqdm(
            range(self.initial_timestep, self.timesteps),
            disable=self.disable_progressbar,
            file=sys.stdout,
        ):
            # Pre-interaction
            self.agents.pre_interaction(timestep=timestep, timesteps=self.timesteps)

            with torch.no_grad():
                actions = self.agents.act(states, timestep=timestep, timesteps=self.timesteps)[0]
                next_states, rewards, terminated, truncated, infos = self.env.step(actions)

                if not self.headless:
                    self.env.render()

                self.agents.record_transition(
                    states=states,
                    actions=actions,
                    rewards=rewards,
                    next_states=next_states,
                    terminated=terminated,
                    truncated=truncated,
                    infos=infos,
                    timestep=timestep,
                    timesteps=self.timesteps,
                )

                # Surface scalar env diagnostics into the agent's tracking.
                if self.environment_info in infos:
                    for k, v in infos[self.environment_info].items():
                        if isinstance(v, torch.Tensor) and v.numel() == 1:
                            self.agents.track_data(f"Info / {k}", v.item())

            self.agents.post_interaction(timestep=timestep, timesteps=self.timesteps)

            # Reset environments — vectorized runners auto-reset inside step().
            if self.env.num_envs > 1:
                states = next_states
            else:
                if terminated.any() or truncated.any():
                    with torch.no_grad():
                        states, infos = self.env.reset()
                else:
                    states = next_states

            # Record video at intervals
            cfg = self._trainer_config
            if (
                cfg
                and cfg.record_video_every > 0
                and (timestep + 1) % cfg.record_video_every == 0
            ):
                self._record_video(timestep + 1)

    def _get_video_fps(self) -> int:
        """Derive video fps from the simulation rate, or use configured value."""
        cfg = self._trainer_config
        # bugfix: guard against a missing trainer_config before dereferencing.
        if cfg is not None and cfg.record_video_fps > 0:
            return cfg.record_video_fps
        # Auto-derive from runner's simulation parameters: control dt = dt * substeps.
        runner = self.env
        dt = getattr(runner.config, "dt", 0.02)
        substeps = getattr(runner.config, "substeps", 1)
        return max(1, int(round(1.0 / (dt * substeps))))

    def _record_video(self, timestep: int) -> None:
        """Record evaluation episodes and upload to ClearML."""
        # bugfix: the legacy-imageio fallback called ``imwrite`` on a frame
        # list, but imageio v2's imwrite writes a single image — sequences
        # need ``mimsave``. Each branch now binds an appropriate writer.
        try:
            import imageio.v3 as iio

            def write_video(path: str, frames: list, fps: int) -> None:
                iio.imwrite(path, frames, fps=fps)
        except ImportError:
            try:
                import imageio

                def write_video(path: str, frames: list, fps: int) -> None:
                    imageio.mimsave(path, frames, fps=fps)
            except ImportError:
                return  # imageio not installed — skip recording (best-effort)

        cfg = self._trainer_config
        if cfg is None:
            return
        fps = self._get_video_fps()
        min_frames = int(cfg.record_video_min_seconds * fps)
        max_frames = min_frames * 3  # hard cap to prevent runaway recording
        frames: list[np.ndarray] = []

        while len(frames) < min_frames and len(frames) < max_frames:
            obs, _ = self.env.reset()
            done = False
            steps = 0
            max_episode_steps = getattr(self.env.env.config, "max_steps", 500)
            while not done and steps < max_episode_steps:
                with torch.no_grad():
                    action = self.agents.act(obs, timestep=timestep, timesteps=self.timesteps)[0]
                obs, _, terminated, truncated, _ = self.env.step(action)
                frame = self.env.render(mode="rgb_array")
                if frame is not None:
                    frames.append(frame.cpu().numpy() if isinstance(frame, torch.Tensor) else frame)
                done = (terminated | truncated).any().item()
                steps += 1
                if len(frames) >= max_frames:
                    break

        if frames:
            video_path = str(self._video_dir / f"step_{timestep}.mp4")
            write_video(video_path, frames, fps)

            logger = Logger.current_logger()
            if logger:
                logger.report_media(
                    title="Training Video",
                    series=f"step_{timestep}",
                    local_path=video_path,
                    iteration=timestep,
                )

        # Reset back to training state after recording
        self.env.reset()


class Trainer:
    """Wires runner + PPO agent + ClearML and runs training to completion."""

    def __init__(self, runner: BaseRunner, config: TrainerConfig):
        self.runner = runner
        self.config = config

        self._init_clearml()
        self._init_agent()

    def _init_clearml(self) -> None:
        """Start a ClearML task only when both project and task names are set."""
        if self.config.clearml_project and self.config.clearml_task:
            self.clearml_task = Task.init(
                project_name=self.config.clearml_project,
                task_name=self.config.clearml_task,
            )
        else:
            self.clearml_task = None

    def _init_agent(self) -> None:
        """Build memory, shared actor-critic model, and the PPO agent."""
        device: torch.device = self.runner.device
        obs_space: spaces.Space = self.runner.observation_space
        act_space: spaces.Space = self.runner.action_space
        num_envs: int = self.runner.num_envs

        self.memory: RandomMemory = RandomMemory(
            memory_size=self.config.rollout_steps, num_envs=num_envs, device=device
        )

        self.model: SharedMLP = SharedMLP(
            observation_space=obs_space,
            action_space=act_space,
            device=device,
            hidden_sizes=self.config.hidden_sizes,
        )

        # Shared network serves both roles (skrl shared-model pattern).
        models = {
            "policy": self.model,
            "value": self.model,
        }

        agent_cfg = PPO_DEFAULT_CONFIG.copy()
        agent_cfg.update({
            "rollouts": self.config.rollout_steps,
            "learning_epochs": self.config.learning_epochs,
            "mini_batches": self.config.mini_batches,
            "discount_factor": self.config.discount_factor,
            "lambda": self.config.gae_lambda,
            "learning_rate": self.config.learning_rate,
            "ratio_clip": self.config.clip_ratio,
            "value_loss_scale": self.config.value_loss_scale,
            "entropy_loss_scale": self.config.entropy_loss_scale,
        })

        self.agent: PPO = PPO(
            models=models,
            memory=self.memory,
            observation_space=obs_space,
            action_space=act_space,
            device=device,
            cfg=agent_cfg,
        )

    def train(self) -> None:
        """Run the full training loop for ``total_timesteps`` steps."""
        trainer = VideoRecordingTrainer(
            env=self.runner,
            agents=self.agent,
            cfg={"timesteps": self.config.total_timesteps},
            trainer_config=self.config,
        )
        trainer.train()

    def close(self) -> None:
        """Release the runner and finish the ClearML task, if any."""
        self.runner.close()
        if self.clearml_task:
            self.clearml_task.close()


# ===== file: train.py ========================================================
"""Hydra entry point: builds typed configs, wires env/runner/trainer, trains."""
import pathlib

import hydra
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf

from src.core.env import ActuatorConfig
from src.envs.cartpole import CartPoleEnv, CartPoleConfig
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
# (in the repo this file also imports Trainer and TrainerConfig from
# src.training.trainer — defined above in this consolidated view)


def _build_env_config(cfg: DictConfig) -> CartPoleConfig:
    """Materialize the Hydra env config group into a typed CartPoleConfig."""
    env_dict = OmegaConf.to_container(cfg.env, resolve=True)
    # bugfix: BaseEnvConfig declares model_path as pathlib.Path — convert
    # from the YAML string instead of passing a str through.
    if env_dict.get("model_path") is not None:
        env_dict["model_path"] = pathlib.Path(env_dict["model_path"])
    if "actuators" in env_dict:
        for a in env_dict["actuators"]:
            if "ctrl_range" in a:
                a["ctrl_range"] = tuple(a["ctrl_range"])
        env_dict["actuators"] = [ActuatorConfig(**a) for a in env_dict["actuators"]]
    return CartPoleConfig(**env_dict)


@hydra.main(version_base=None, config_path="configs", config_name="config")
def main(cfg: DictConfig) -> None:
    env_config = _build_env_config(cfg)
    runner_config = MuJoCoRunnerConfig(**OmegaConf.to_container(cfg.runner, resolve=True))

    training_dict = OmegaConf.to_container(cfg.training, resolve=True)
    # Build ClearML task name dynamically from Hydra config group choices
    if not training_dict.get("clearml_task"):
        choices = HydraConfig.get().runtime.choices
        env_name = choices.get("env", "env")
        runner_name = choices.get("runner", "runner")
        training_name = choices.get("training", "algo")
        training_dict["clearml_task"] = f"{env_name}-{runner_name}-{training_name}"
    trainer_config = TrainerConfig(**training_dict)

    env = CartPoleEnv(env_config)
    runner = MuJoCoRunner(env=env, config=runner_config)
    trainer = Trainer(runner=runner, config=trainer_config)

    try:
        trainer.train()
    finally:
        trainer.close()


if __name__ == "__main__":
    main()