diff --git a/configs/training/ppo_single.yaml b/configs/training/ppo_single.yaml
index 914a1f9..7534433 100644
--- a/configs/training/ppo_single.yaml
+++ b/configs/training/ppo_single.yaml
@@ -8,15 +8,17 @@ defaults:
   - _self_
 
 hidden_sizes: [256, 256]
-total_timesteps: 500000
-learning_epochs: 5
-learning_rate: 0.001
-entropy_loss_scale: 0.0001
-log_interval: 1024
+total_timesteps: 1000000
+learning_epochs: 10
+learning_rate: 0.0003
+entropy_loss_scale: 0.01
+rollout_steps: 2048
+mini_batches: 8
+log_interval: 2048
 checkpoint_interval: 10000
 initial_log_std: -0.5
 min_log_std: -4.0
-max_log_std: 0.0
+max_log_std: 2.0
 record_video_every: 50000
 
 
diff --git a/scripts/hpo.py b/scripts/hpo.py
index 1bca7ac..aaa1ff5 100644
--- a/scripts/hpo.py
+++ b/scripts/hpo.py
@@ -5,14 +5,17 @@ search ranges from the Hydra config's `training.hpo` and `env.hpo` blocks,
 and launches SMAC3 Successive Halving optimization.
 
 Usage:
-    python scripts/hpo.py \
-        --env rotary_cartpole \
-        --runner mujoco_single \
-        --training ppo_single \
-        --queue gpu-queue
+    python scripts/hpo.py env=rotary_cartpole runner=mujoco_single training=ppo_single
+
+    # With HPO-specific options:
+    python scripts/hpo.py env=rotary_cartpole runner=mujoco_single training=ppo_single \\
+        --queue gpu-queue --total-trials 100
 
     # Or use an existing base task:
     python scripts/hpo.py --base-task-id
+
+    # Dry run (print search space only):
+    python scripts/hpo.py env=rotary_cartpole --dry-run
 """
 
 from __future__ import annotations
@@ -233,9 +236,34 @@ def _create_base_task(
     return task_id
 
 
+def _parse_overrides(argv: list[str]) -> dict[str, str]:
+    """Parse Hydra-style key=value overrides from argv.
+
+    Tokens of the form ``key=value`` that do not start with ``-`` are
+    removed from ``argv`` (which is modified in place) and returned as a
+    dict; everything else is left in ``argv`` for argparse to handle.
+    """
+    overrides = {}
+    remaining = []
+    for arg in argv:
+        if "=" in arg and not arg.startswith("-"):
+            key, value = arg.split("=", 1)
+            overrides[key] = value
+        else:
+            remaining.append(arg)
+    argv.clear()
+    argv.extend(remaining)
+    return overrides
+
+
 def main() -> None:
+    # First pass: extract Hydra-style key=value overrides from sys.argv
+    raw_args = sys.argv[1:]
+    overrides = _parse_overrides(raw_args)
+
     parser = argparse.ArgumentParser(
-        description="Hyperparameter optimization for RL-Framework"
+        description="Hyperparameter optimization for RL-Framework",
+        usage="%(prog)s env=<env> runner=<runner> training=<training> [options]",
     )
     parser.add_argument(
         "--base-task-id",
@@ -243,9 +271,6 @@ def main() -> None:
         default=None,
         help="Existing ClearML task ID to use as base (skip auto-creation)",
     )
-    parser.add_argument("--env", type=str, default="rotary_cartpole")
-    parser.add_argument("--runner", type=str, default="mujoco_single")
-    parser.add_argument("--training", type=str, default="ppo_single")
     parser.add_argument("--queue", type=str, default="gpu-queue")
     parser.add_argument(
         "--max-concurrent", type=int, default=2,
@@ -292,12 +317,24 @@ def main() -> None:
         "--dry-run", action="store_true",
         help="Print search space and exit without running",
     )
-    args = parser.parse_args()
+    args = parser.parse_args(raw_args)
+
+    # Resolve env/runner/training from Hydra-style overrides (same as train.py)
+    env = overrides.get("env", "rotary_cartpole")
+    runner = overrides.get("runner", "mujoco_single")
+    training = overrides.get("training", "ppo_single")
+
+    # Only the three config-group selectors are consumed below; reject any
+    # other key so a typo (e.g. `trainng=...`) cannot silently fall back to
+    # the defaults for an expensive HPO run.
+    unsupported = sorted(set(overrides) - {"env", "runner", "training"})
+    if unsupported:
+        parser.error(f"unsupported overrides: {', '.join(unsupported)}")
 
     objective_sign = "min" if args.minimize else "max"
 
     # ── Load config and build search space ────────────────────────
-    config = _load_hydra_config(args.env, args.runner, args.training)
+    config = _load_hydra_config(env, runner, training)
     hyper_parameters = _build_hyper_parameters(config)
 
     if not hyper_parameters:
@@ -318,7 +355,7 @@ def main() -> None:
     Task.ignore_requirements("torch")
     task = Task.init(
         project_name="RL-Framework",
-        task_name=f"HPO {args.env}-{args.runner}-{args.training}",
+        task_name=f"HPO {env}-{runner}-{training}",
         task_type=Task.TaskTypes.optimizer,
         reuse_last_task_id=False,
     )
@@ -347,7 +384,7 @@ def main() -> None:
         logger.info("reusing_base_task_from_param", task_id=base_task_id)
     else:
         base_task_id = _create_base_task(
-            args.env, args.runner, args.training, args.queue
+            env, runner, training, args.queue
         )
 
     task.set_parameter("General/base_task_id", base_task_id)