✨ update hpo: attach full resolved config to the base task and reuse the base task across services-worker re-runs
This commit is contained in:
@@ -132,6 +132,21 @@ def _build_hyper_parameters(config: dict) -> list:
|
||||
return params
|
||||
|
||||
|
||||
def _flatten_dict(d: dict, parent_key: str = "", sep: str = ".") -> dict:
|
||||
"""Flatten a nested dict into dot-separated keys.
|
||||
|
||||
Example: {"a": {"b": 1}} → {"a.b": 1}
|
||||
"""
|
||||
items = {}
|
||||
for k, v in d.items():
|
||||
new_key = f"{parent_key}{sep}{k}" if parent_key else k
|
||||
if isinstance(v, dict):
|
||||
items.update(_flatten_dict(v, new_key, sep=sep))
|
||||
else:
|
||||
items[new_key] = v
|
||||
return items
|
||||
|
||||
|
||||
def _create_base_task(
|
||||
env: str, runner: str, training: str, queue: str
|
||||
) -> str:
|
||||
@@ -139,6 +154,8 @@ def _create_base_task(
|
||||
|
||||
Uses Task.create() to register a task pointing at scripts/train.py
|
||||
with the correct Hydra overrides. The HPO optimizer will clone this.
|
||||
The full resolved OmegaConf config is attached as Hydra/* parameters
|
||||
so cloned trial tasks inherit the complete configuration.
|
||||
"""
|
||||
script_path = str(Path(__file__).resolve().parent / "train.py")
|
||||
project_root = str(Path(__file__).resolve().parent.parent)
|
||||
@@ -157,14 +174,44 @@ def _create_base_task(
|
||||
add_task_init_call=False,
|
||||
)
|
||||
|
||||
# Explicitly set Hydra config-group choices so cloned tasks
|
||||
# pick up the correct env / runner / training groups.
|
||||
# Task.create() does not populate the Hydra parameter section
|
||||
# because Hydra never actually runs during creation.
|
||||
# ── Attach full resolved OmegaConf config ─────────────────────
|
||||
# ClearML's Hydra binding normally does this when the script runs,
|
||||
# but Task.create() never executes Hydra. We replicate the binding
|
||||
# manually: config group choices + all resolved values.
|
||||
base_task.set_parameter("Hydra/env", env)
|
||||
base_task.set_parameter("Hydra/runner", runner)
|
||||
base_task.set_parameter("Hydra/training", training)
|
||||
|
||||
# Load and resolve the full config for each group
|
||||
configs_dir = Path(__file__).resolve().parent.parent / "configs"
|
||||
for section, name in [("training", training), ("env", env), ("runner", runner)]:
|
||||
cfg_path = configs_dir / section / f"{name}.yaml"
|
||||
if not cfg_path.exists():
|
||||
continue
|
||||
cfg = OmegaConf.load(cfg_path)
|
||||
# Handle Hydra defaults: inheritance (e.g. ppo_single → ppo)
|
||||
if "defaults" in cfg:
|
||||
defaults = OmegaConf.to_container(cfg.defaults)
|
||||
base_cfg = OmegaConf.create({})
|
||||
for d in defaults:
|
||||
if isinstance(d, str):
|
||||
base_path = configs_dir / section / f"{d}.yaml"
|
||||
if base_path.exists():
|
||||
loaded = OmegaConf.load(base_path)
|
||||
base_cfg = OmegaConf.merge(base_cfg, loaded)
|
||||
cfg_no_defaults = {
|
||||
k: v for k, v in OmegaConf.to_container(cfg).items()
|
||||
if k != "defaults"
|
||||
}
|
||||
cfg = OmegaConf.merge(base_cfg, OmegaConf.create(cfg_no_defaults))
|
||||
|
||||
resolved = OmegaConf.to_container(cfg, resolve=True)
|
||||
# Remove hpo metadata — not a real config value
|
||||
resolved.pop("hpo", None)
|
||||
flat = _flatten_dict(resolved)
|
||||
for key, value in flat.items():
|
||||
base_task.set_parameter(f"Hydra/{section}.{key}", value)
|
||||
|
||||
# Set docker config
|
||||
base_task.set_base_docker(
|
||||
"registry.kube.optimize/worker-image:latest",
|
||||
@@ -267,15 +314,6 @@ def main() -> None:
|
||||
print(f"\nObjective: {args.objective_metric} ({objective_sign})")
|
||||
return
|
||||
|
||||
# ── Create or reuse base task ─────────────────────────────────
|
||||
if args.base_task_id:
|
||||
base_task_id = args.base_task_id
|
||||
logger.info("using_existing_base_task", task_id=base_task_id)
|
||||
else:
|
||||
base_task_id = _create_base_task(
|
||||
args.env, args.runner, args.training, args.queue
|
||||
)
|
||||
|
||||
# ── Initialize ClearML HPO task ───────────────────────────────
|
||||
Task.ignore_requirements("torch")
|
||||
task = Task.init(
|
||||
@@ -295,6 +333,24 @@ def main() -> None:
|
||||
req_file = Path(__file__).resolve().parent.parent / "requirements.txt"
|
||||
task.set_packages(str(req_file))
|
||||
|
||||
# ── Create or reuse base task ─────────────────────────────────
|
||||
# Store the base_task_id on the HPO task so that when the services
|
||||
# worker re-runs this script it reuses the same base task instead
|
||||
# of creating a duplicate.
|
||||
if args.base_task_id:
|
||||
base_task_id = args.base_task_id
|
||||
logger.info("using_existing_base_task", task_id=base_task_id)
|
||||
else:
|
||||
existing = task.get_parameter("General/base_task_id")
|
||||
if existing:
|
||||
base_task_id = existing
|
||||
logger.info("reusing_base_task_from_param", task_id=base_task_id)
|
||||
else:
|
||||
base_task_id = _create_base_task(
|
||||
args.env, args.runner, args.training, args.queue
|
||||
)
|
||||
task.set_parameter("General/base_task_id", base_task_id)
|
||||
|
||||
# ── Build objective metric ────────────────────────────────────
|
||||
# skrl's SequentialTrainer logs "Reward / Total reward (mean)" by default
|
||||
objective_title = args.objective_metric
|
||||
|
||||
@@ -181,6 +181,12 @@ class OptimizerSMAC(SearchStrategy):
|
||||
"budget_param_name", "Hydra/training.total_timesteps"
|
||||
)
|
||||
|
||||
# Pop our custom kwargs BEFORE passing smac_kwargs to SuccessiveHalving
|
||||
self.max_consecutive_failures = int(
|
||||
smac_kwargs.pop("max_consecutive_failures", 3)
|
||||
)
|
||||
self._consecutive_failures = 0
|
||||
|
||||
# build the Successive Halving intensifier (NOT Hyperband!)
|
||||
# Hyperband runs multiple brackets with different starting budgets - wasteful
|
||||
# Successive Halving: ALL configs start at min_budget, only best get promoted
|
||||
@@ -204,12 +210,6 @@ class OptimizerSMAC(SearchStrategy):
|
||||
self.best_score_so_far = float("-inf") if self.maximize_metric else float("inf")
|
||||
self.time_limit_per_job = time_limit_per_job # Store time limit (minutes)
|
||||
|
||||
# Consecutive-failure abort: stop HPO if N trials in a row crash
|
||||
self.max_consecutive_failures = int(
|
||||
smac_kwargs.pop("max_consecutive_failures", 3)
|
||||
)
|
||||
self._consecutive_failures = 0
|
||||
|
||||
# Checkpoint continuation tracking: config_key -> {budget: task_id}
|
||||
# Used to find the previous task's checkpoint when promoting a config
|
||||
self.config_to_tasks = {} # config_key -> {budget: task_id}
|
||||
|
||||
126
train.py
126
train.py
@@ -1,126 +0,0 @@
|
||||
import os
|
||||
import pathlib
|
||||
import sys
|
||||
|
||||
# Headless rendering: use OSMesa on Linux servers (must be set before mujoco import).
|
||||
# Always default on Linux — Docker containers may have DISPLAY set without a real X server.
|
||||
if sys.platform == "linux":
|
||||
os.environ.setdefault("MUJOCO_GL", "osmesa")
|
||||
|
||||
import hydra
|
||||
import hydra.utils as hydra_utils
|
||||
import structlog
|
||||
from clearml import Task
|
||||
from hydra.core.hydra_config import HydraConfig
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
from src.core.env import BaseEnv
|
||||
from src.core.registry import build_env
|
||||
from src.core.runner import BaseRunner
|
||||
from src.training.trainer import Trainer, TrainerConfig
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
# ── runner registry ───────────────────────────────────────────────────
# Maps Hydra config-group name → (module_path, runner class name, config
# class name). All three entries are strings: imports are deferred (see
# _build_runner) so JAX is only loaded when runner=mjx is chosen.

RUNNER_REGISTRY: dict[str, tuple[str, str, str]] = {
    "mujoco": ("src.runners.mujoco", "MuJoCoRunner", "MuJoCoRunnerConfig"),
    # "mujoco_single" is an alias config group backed by the same runner.
    "mujoco_single": ("src.runners.mujoco", "MuJoCoRunner", "MuJoCoRunnerConfig"),
    "mjx": ("src.runners.mjx", "MJXRunner", "MJXRunnerConfig"),
    "serial": ("src.runners.serial", "SerialRunner", "SerialRunnerConfig"),
}
|
||||
|
||||
|
||||
def _build_runner(runner_name: str, env: BaseEnv, cfg: DictConfig) -> BaseRunner:
    """Instantiate the right runner from the Hydra config-group name.

    Looks the name up in RUNNER_REGISTRY, imports the runner module lazily
    (so heavy backends like JAX load only when selected), builds the runner
    config from ``cfg.runner``, and returns the constructed runner.
    """
    if runner_name not in RUNNER_REGISTRY:
        raise ValueError(
            f"Unknown runner '{runner_name}'. Registered: {list(RUNNER_REGISTRY)}"
        )

    import importlib

    module_path, runner_cls_name, config_cls_name = RUNNER_REGISTRY[runner_name]
    module = importlib.import_module(module_path)
    runner_cls = getattr(module, runner_cls_name)
    runner_config_cls = getattr(module, config_cls_name)

    # Resolve the OmegaConf node into plain kwargs for the dataclass config.
    runner_kwargs = OmegaConf.to_container(cfg.runner, resolve=True)
    return runner_cls(env=env, config=runner_config_cls(**runner_kwargs))
|
||||
|
||||
|
||||
def _init_clearml(choices: dict[str, str], remote: bool = False) -> Task:
    """Register the ClearML task for this training run.

    The task lives in the "RL-Framework" project, is named
    "<env>-<runner>-<training>", and is tagged with the three Hydra
    config-group choices. The base docker image is configured with a
    headless-rendering setup so remote workers can run MuJoCo without a
    display. If *remote* is True and we are running locally, the task is
    shipped to the "gpu-queue" and this process exits.
    """
    Task.ignore_requirements("torch")

    chosen = (
        choices.get("env", "cartpole"),
        choices.get("runner", "mujoco"),
        choices.get("training", "ppo"),
    )

    task = Task.init(
        project_name="RL-Framework",
        task_name="-".join(chosen),
        tags=list(chosen),
    )
    task.set_base_docker(
        "registry.kube.optimize/worker-image:latest",
        docker_setup_bash_script=(
            "apt-get update && apt-get install -y --no-install-recommends "
            "libosmesa6-dev libgl1-mesa-glx libglfw3 && rm -rf /var/lib/apt/lists/* "
            "&& pip install 'jax[cuda12]' mujoco-mjx PyOpenGL PyOpenGL-accelerate"
        ),
        docker_arguments=[
            "-e", "MUJOCO_GL=osmesa",
        ],
    )

    requirements = pathlib.Path(hydra_utils.get_original_cwd()) / "requirements.txt"
    task.set_packages(str(requirements))

    # Ship to the GPU queue when requested; running_locally() makes this a
    # no-op when the script is re-run by the remote worker itself.
    if remote and task.running_locally():
        logger.info("executing_task_remotely", queue="gpu-queue")
        task.execute_remotely(queue_name="gpu-queue", exit_process=True)

    return task
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path="configs", config_name="config")
def main(cfg: DictConfig) -> None:
    """Entry point: set up ClearML, build env and runner, run training."""
    choices = HydraConfig.get().runtime.choices

    # ClearML init — must happen before heavy work so remote execution
    # can take over early. The remote worker re-runs the full script;
    # execute_remotely() is a no-op on the worker side.
    training_kwargs = OmegaConf.to_container(cfg.training, resolve=True)
    remote = training_kwargs.pop("remote", False)
    training_kwargs.pop("hpo", None)  # HPO range metadata — not a TrainerConfig field
    task = _init_clearml(choices, remote=remote)

    # Keep only keys TrainerConfig actually declares (drops ClearML-injected
    # extras such as resume_from_task_id, and any future additions).
    import dataclasses

    known_fields = {field.name for field in dataclasses.fields(TrainerConfig)}
    training_kwargs = {
        key: value for key, value in training_kwargs.items() if key in known_fields
    }

    env = build_env(choices.get("env", "cartpole"), cfg)
    runner = _build_runner(choices.get("runner", "mujoco"), env, cfg)
    trainer = Trainer(runner=runner, config=TrainerConfig(**training_kwargs))

    try:
        trainer.train()
    finally:
        # Always release runner/env resources and flush the ClearML task,
        # even when training raises.
        trainer.close()
        task.close()


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user