Files
RL-Sim-Framework/train.py
2026-03-09 21:18:19 +01:00

139 lines
5.3 KiB
Python

import os
import pathlib
import sys
# Linux machines without a display get the OSMesa software backend. This has to
# run before mujoco is imported; an explicitly exported MUJOCO_GL is respected.
if "DISPLAY" not in os.environ and sys.platform == "linux":
    os.environ.setdefault("MUJOCO_GL", "osmesa")
import hydra
import hydra.utils as hydra_utils
import structlog
from clearml import Task
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf
from src.core.env import BaseEnv, BaseEnvConfig, ActuatorConfig
from src.core.runner import BaseRunner
from src.envs.cartpole import CartPoleEnv, CartPoleConfig
from src.envs.rotary_cartpole import RotaryCartPoleEnv, RotaryCartPoleConfig
from src.training.trainer import Trainer, TrainerConfig
logger = structlog.get_logger()
# ── env registry ──────────────────────────────────────────────────────
# Maps Hydra config-group name → (EnvClass, ConfigClass)
ENV_REGISTRY: dict[str, tuple[type[BaseEnv], type[BaseEnvConfig]]] = {
    # Each key is looked up by ``_build_env``; the value pairs the environment
    # class with the config class used to construct it.
    "cartpole": (
        CartPoleEnv,
        CartPoleConfig,
    ),
    "rotary_cartpole": (
        RotaryCartPoleEnv,
        RotaryCartPoleConfig,
    ),
}
def _build_env(env_name: str, cfg: DictConfig) -> BaseEnv:
    """Create the environment registered under ``env_name``.

    Args:
        env_name: Key into ``ENV_REGISTRY``.
        cfg: Composed Hydra config; only ``cfg.env`` is read.

    Returns:
        An instance of the registered env class, constructed from an instance
        of its registered config class.

    Raises:
        ValueError: If ``env_name`` is not present in ``ENV_REGISTRY``.
    """
    registered = ENV_REGISTRY.get(env_name)
    if registered is None:
        raise ValueError(f"Unknown env '{env_name}'. Registered: {list(ENV_REGISTRY)}")
    env_cls, config_cls = registered

    # Plain, fully-resolved Python containers so they can be splatted as kwargs.
    env_kwargs = OmegaConf.to_container(cfg.env, resolve=True)

    if "actuators" in env_kwargs:
        actuator_specs = env_kwargs["actuators"]
        # Pass 1: turn each ctrl_range into a tuple
        # (presumably what ActuatorConfig declares; confirm against its definition).
        for spec in actuator_specs:
            if "ctrl_range" not in spec:
                continue
            spec["ctrl_range"] = tuple(spec["ctrl_range"])
        # Pass 2: replace the raw dicts with ActuatorConfig objects.
        env_kwargs["actuators"] = [ActuatorConfig(**spec) for spec in actuator_specs]

    env_config = config_cls(**env_kwargs)
    return env_cls(env_config)
# ── runner registry ───────────────────────────────────────────────────
# Maps Hydra config-group name → (module path, runner class name, config class name)
# Imports are deferred so JAX is only loaded when runner=mjx is chosen.
RUNNER_REGISTRY: dict[str, tuple[str, str, str]] = {
    # Entries are strings rather than classes so the runner module is only
    # imported once that runner has actually been selected.
    "mujoco": (
        "src.runners.mujoco",  # module to import
        "MuJoCoRunner",  # runner class name inside that module
        "MuJoCoRunnerConfig",  # config class name inside that module
    ),
    "mjx": (
        "src.runners.mjx",
        "MJXRunner",
        "MJXRunnerConfig",
    ),
}
def _build_runner(runner_name: str, env: BaseEnv, cfg: DictConfig) -> BaseRunner:
    """Import and construct the runner registered under ``runner_name``.

    Args:
        runner_name: Key into ``RUNNER_REGISTRY``.
        env: Environment handed to the runner constructor.
        cfg: Composed Hydra config; only ``cfg.runner`` is read.

    Returns:
        The runner instance, built with ``env`` and its own config object.

    Raises:
        ValueError: If ``runner_name`` is not present in ``RUNNER_REGISTRY``.
    """
    entry = RUNNER_REGISTRY.get(runner_name)
    if entry is None:
        raise ValueError(
            f"Unknown runner '{runner_name}'. Registered: {list(RUNNER_REGISTRY)}"
        )
    module_path, cls_name, cfg_cls_name = entry

    # Imported here, by name, so only the chosen backend's module gets loaded.
    from importlib import import_module

    runner_module = import_module(module_path)
    runner_cls = getattr(runner_module, cls_name)
    config_cls = getattr(runner_module, cfg_cls_name)

    runner_kwargs = OmegaConf.to_container(cfg.runner, resolve=True)
    return runner_cls(env=env, config=config_cls(**runner_kwargs))
def _init_clearml(choices: dict[str, str], remote: bool = False) -> Task:
    """Create the ClearML task for this run and optionally enqueue it remotely.

    The task is created in the ``RL-Framework`` project, named
    ``<env>-<runner>-<training>`` and tagged with those three Hydra choices.

    Args:
        choices: Hydra config-group choices. The ``env``, ``runner`` and
            ``training`` entries are read, falling back to ``cartpole``,
            ``mujoco`` and ``ppo`` respectively.
        remote: When true and the task is running locally, the task is sent
            to the ``gpu-queue`` queue with ``exit_process=True``.

    Returns:
        The initialized ClearML task.
    """
    # Keep torch out of the package requirements ClearML records for the task.
    Task.ignore_requirements("torch")

    env_name, runner_name, training_name = (
        choices.get(group, fallback)
        for group, fallback in (
            ("env", "cartpole"),
            ("runner", "mujoco"),
            ("training", "ppo"),
        )
    )
    task = Task.init(
        project_name="RL-Framework",
        task_name=f"{env_name}-{runner_name}-{training_name}",
        tags=[env_name, runner_name, training_name],
    )

    # Container image and bootstrap script for the remote worker.
    setup_script = (
        "apt-get update && apt-get install -y --no-install-recommends "
        "libosmesa6 libgl1-mesa-glx libglfw3 && rm -rf /var/lib/apt/lists/* "
        "&& pip install 'jax[cuda12]' mujoco-mjx"
    )
    task.set_base_docker(
        "registry.kube.optimize/worker-image:latest",
        docker_setup_bash_script=setup_script,
    )

    # Resolve requirements.txt against the directory the script was launched
    # from, as reported by Hydra, rather than the current working directory.
    requirements_path = pathlib.Path(hydra_utils.get_original_cwd()) / "requirements.txt"
    task.set_packages(str(requirements_path))

    # Nothing more to do unless a remote run was requested from a local process.
    if not (remote and task.running_locally()):
        return task

    logger.info("executing_task_remotely", queue="gpu-queue")
    task.execute_remotely(queue_name="gpu-queue", exit_process=True)
    return task
@hydra.main(version_base=None, config_path="configs", config_name="config")
def main(cfg: DictConfig) -> None:
    """Build the env, runner and trainer from the Hydra config and run training.

    Args:
        cfg: Composed Hydra config; the ``training``, ``env`` and ``runner``
            groups are read.
    """
    selected = HydraConfig.get().runtime.choices

    # ``remote`` is consumed here and removed from the dict, so it is not
    # forwarded to TrainerConfig.
    trainer_kwargs = OmegaConf.to_container(cfg.training, resolve=True)
    run_remotely = trainer_kwargs.pop("remote", False)

    # ClearML comes first so a remote hand-off happens before any heavy work.
    # The remote worker re-runs this script; there the hand-off is skipped.
    clearml_task = _init_clearml(selected, remote=run_remotely)

    environment = _build_env(selected.get("env", "cartpole"), cfg)
    sim_runner = _build_runner(selected.get("runner", "mujoco"), environment, cfg)
    trainer = Trainer(runner=sim_runner, config=TrainerConfig(**trainer_kwargs))

    try:
        trainer.train()
    finally:
        # Release trainer resources and close the task even if training raised.
        trainer.close()
        clearml_task.close()
# Script entry point: run the Hydra-decorated ``main`` only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()