update hpo

This commit is contained in:
2026-03-11 23:28:39 +01:00
parent 23801857f4
commit 3b2d6d08f9
3 changed files with 75 additions and 145 deletions

View File

@@ -132,6 +132,21 @@ def _build_hyper_parameters(config: dict) -> list:
return params
def _flatten_dict(d: dict, parent_key: str = "", sep: str = ".") -> dict:
"""Flatten a nested dict into dot-separated keys.
Example: {"a": {"b": 1}} → {"a.b": 1}
"""
items = {}
for k, v in d.items():
new_key = f"{parent_key}{sep}{k}" if parent_key else k
if isinstance(v, dict):
items.update(_flatten_dict(v, new_key, sep=sep))
else:
items[new_key] = v
return items
def _create_base_task(
env: str, runner: str, training: str, queue: str
) -> str:
@@ -139,6 +154,8 @@ def _create_base_task(
Uses Task.create() to register a task pointing at scripts/train.py
with the correct Hydra overrides. The HPO optimizer will clone this.
The full resolved OmegaConf config is attached as Hydra/* parameters
so cloned trial tasks inherit the complete configuration.
"""
script_path = str(Path(__file__).resolve().parent / "train.py")
project_root = str(Path(__file__).resolve().parent.parent)
@@ -157,14 +174,44 @@ def _create_base_task(
add_task_init_call=False,
)
# Explicitly set Hydra config-group choices so cloned tasks
# pick up the correct env / runner / training groups.
# Task.create() does not populate the Hydra parameter section
# because Hydra never actually runs during creation.
# ── Attach full resolved OmegaConf config ─────────────────────
# ClearML's Hydra binding normally does this when the script runs,
# but Task.create() never executes Hydra. We replicate the binding
# manually: config group choices + all resolved values.
base_task.set_parameter("Hydra/env", env)
base_task.set_parameter("Hydra/runner", runner)
base_task.set_parameter("Hydra/training", training)
# Load and resolve the full config for each group
configs_dir = Path(__file__).resolve().parent.parent / "configs"
for section, name in [("training", training), ("env", env), ("runner", runner)]:
cfg_path = configs_dir / section / f"{name}.yaml"
if not cfg_path.exists():
continue
cfg = OmegaConf.load(cfg_path)
# Handle Hydra defaults: inheritance (e.g. ppo_single → ppo)
if "defaults" in cfg:
defaults = OmegaConf.to_container(cfg.defaults)
base_cfg = OmegaConf.create({})
for d in defaults:
if isinstance(d, str):
base_path = configs_dir / section / f"{d}.yaml"
if base_path.exists():
loaded = OmegaConf.load(base_path)
base_cfg = OmegaConf.merge(base_cfg, loaded)
cfg_no_defaults = {
k: v for k, v in OmegaConf.to_container(cfg).items()
if k != "defaults"
}
cfg = OmegaConf.merge(base_cfg, OmegaConf.create(cfg_no_defaults))
resolved = OmegaConf.to_container(cfg, resolve=True)
# Remove hpo metadata — not a real config value
resolved.pop("hpo", None)
flat = _flatten_dict(resolved)
for key, value in flat.items():
base_task.set_parameter(f"Hydra/{section}.{key}", value)
# Set docker config
base_task.set_base_docker(
"registry.kube.optimize/worker-image:latest",
@@ -267,15 +314,6 @@ def main() -> None:
print(f"\nObjective: {args.objective_metric} ({objective_sign})")
return
# ── Create or reuse base task ─────────────────────────────────
if args.base_task_id:
base_task_id = args.base_task_id
logger.info("using_existing_base_task", task_id=base_task_id)
else:
base_task_id = _create_base_task(
args.env, args.runner, args.training, args.queue
)
# ── Initialize ClearML HPO task ───────────────────────────────
Task.ignore_requirements("torch")
task = Task.init(
@@ -295,6 +333,24 @@ def main() -> None:
req_file = Path(__file__).resolve().parent.parent / "requirements.txt"
task.set_packages(str(req_file))
# ── Create or reuse base task ─────────────────────────────────
# Store the base_task_id on the HPO task so that when the services
# worker re-runs this script it reuses the same base task instead
# of creating a duplicate.
if args.base_task_id:
base_task_id = args.base_task_id
logger.info("using_existing_base_task", task_id=base_task_id)
else:
existing = task.get_parameter("General/base_task_id")
if existing:
base_task_id = existing
logger.info("reusing_base_task_from_param", task_id=base_task_id)
else:
base_task_id = _create_base_task(
args.env, args.runner, args.training, args.queue
)
task.set_parameter("General/base_task_id", base_task_id)
# ── Build objective metric ────────────────────────────────────
# skrl's SequentialTrainer logs "Reward / Total reward (mean)" by default
objective_title = args.objective_metric

View File

@@ -181,6 +181,12 @@ class OptimizerSMAC(SearchStrategy):
"budget_param_name", "Hydra/training.total_timesteps"
)
# Pop our custom kwargs BEFORE passing smac_kwargs to SuccessiveHalving
self.max_consecutive_failures = int(
smac_kwargs.pop("max_consecutive_failures", 3)
)
self._consecutive_failures = 0
# build the Successive Halving intensifier (NOT Hyperband!)
# Hyperband runs multiple brackets with different starting budgets - wasteful
# Successive Halving: ALL configs start at min_budget, only best get promoted
@@ -204,12 +210,6 @@ class OptimizerSMAC(SearchStrategy):
self.best_score_so_far = float("-inf") if self.maximize_metric else float("inf")
self.time_limit_per_job = time_limit_per_job # Store time limit (minutes)
# Consecutive-failure abort: stop HPO if N trials in a row crash
self.max_consecutive_failures = int(
smac_kwargs.pop("max_consecutive_failures", 3)
)
self._consecutive_failures = 0
# Checkpoint continuation tracking: config_key -> {budget: task_id}
# Used to find the previous task's checkpoint when promoting a config
self.config_to_tasks = {} # config_key -> {budget: task_id}

126
train.py
View File

@@ -1,126 +0,0 @@
import os
import pathlib
import sys
# Headless rendering: use OSMesa on Linux servers (must be set before mujoco import).
# Always default on Linux — Docker containers may have DISPLAY set without a real X server.
if sys.platform == "linux":
os.environ.setdefault("MUJOCO_GL", "osmesa")
import hydra
import hydra.utils as hydra_utils
import structlog
from clearml import Task
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf
from src.core.env import BaseEnv
from src.core.registry import build_env
from src.core.runner import BaseRunner
from src.training.trainer import Trainer, TrainerConfig
logger = structlog.get_logger()
# ── runner registry ───────────────────────────────────────────────────
# Maps Hydra config-group name → (RunnerClass, ConfigClass)
# Imports are deferred so JAX is only loaded when runner=mjx is chosen.
RUNNER_REGISTRY: dict[str, tuple[str, str, str]] = {
"mujoco": ("src.runners.mujoco", "MuJoCoRunner", "MuJoCoRunnerConfig"),
"mujoco_single": ("src.runners.mujoco", "MuJoCoRunner", "MuJoCoRunnerConfig"),
"mjx": ("src.runners.mjx", "MJXRunner", "MJXRunnerConfig"),
"serial": ("src.runners.serial", "SerialRunner", "SerialRunnerConfig"),
}
def _build_runner(runner_name: str, env: BaseEnv, cfg: DictConfig) -> BaseRunner:
"""Instantiate the right runner from the Hydra config-group name."""
if runner_name not in RUNNER_REGISTRY:
raise ValueError(
f"Unknown runner '{runner_name}'. Registered: {list(RUNNER_REGISTRY)}"
)
module_path, cls_name, cfg_cls_name = RUNNER_REGISTRY[runner_name]
import importlib
mod = importlib.import_module(module_path)
runner_cls = getattr(mod, cls_name)
config_cls = getattr(mod, cfg_cls_name)
runner_config = config_cls(**OmegaConf.to_container(cfg.runner, resolve=True))
return runner_cls(env=env, config=runner_config)
def _init_clearml(choices: dict[str, str], remote: bool = False) -> Task:
"""Initialize ClearML task with project structure and tags.
Project: RL-Trainings/<EnvName> (e.g. RL-Trainings/Rotary Cartpole)
Tags: env, runner, training algo choices from Hydra.
"""
Task.ignore_requirements("torch")
env_name = choices.get("env", "cartpole")
runner_name = choices.get("runner", "mujoco")
training_name = choices.get("training", "ppo")
project = "RL-Framework"
task_name = f"{env_name}-{runner_name}-{training_name}"
tags = [env_name, runner_name, training_name]
task = Task.init(project_name=project, task_name=task_name, tags=tags)
task.set_base_docker(
"registry.kube.optimize/worker-image:latest",
docker_setup_bash_script=(
"apt-get update && apt-get install -y --no-install-recommends "
"libosmesa6-dev libgl1-mesa-glx libglfw3 && rm -rf /var/lib/apt/lists/* "
"&& pip install 'jax[cuda12]' mujoco-mjx PyOpenGL PyOpenGL-accelerate"
),
docker_arguments=[
"-e", "MUJOCO_GL=osmesa",
],
)
req_file = pathlib.Path(hydra_utils.get_original_cwd()) / "requirements.txt"
task.set_packages(str(req_file))
# Execute remotely if requested and running locally
if remote and task.running_locally():
logger.info("executing_task_remotely", queue="gpu-queue")
task.execute_remotely(queue_name="gpu-queue", exit_process=True)
return task
@hydra.main(version_base=None, config_path="configs", config_name="config")
def main(cfg: DictConfig) -> None:
choices = HydraConfig.get().runtime.choices
# ClearML init — must happen before heavy work so remote execution
# can take over early. The remote worker re-runs the full script;
# execute_remotely() is a no-op on the worker side.
training_dict = OmegaConf.to_container(cfg.training, resolve=True)
remote = training_dict.pop("remote", False)
training_dict.pop("hpo", None) # HPO range metadata — not a TrainerConfig field
task = _init_clearml(choices, remote=remote)
# Drop keys not recognised by TrainerConfig (e.g. ClearML-injected
# resume_from_task_id or any future additions)
import dataclasses as _dc
_valid_keys = {f.name for f in _dc.fields(TrainerConfig)}
training_dict = {k: v for k, v in training_dict.items() if k in _valid_keys}
env_name = choices.get("env", "cartpole")
env = build_env(env_name, cfg)
runner = _build_runner(choices.get("runner", "mujoco"), env, cfg)
trainer_config = TrainerConfig(**training_dict)
trainer = Trainer(runner=runner, config=trainer_config)
try:
trainer.train()
finally:
trainer.close()
task.close()
if __name__ == "__main__":
main()