♻️ crazy refactor
This commit is contained in:
7
train.py
7
train.py
@@ -26,8 +26,10 @@ logger = structlog.get_logger()
|
||||
# Imports are deferred so JAX is only loaded when runner=mjx is chosen.
|
||||
|
||||
RUNNER_REGISTRY: dict[str, tuple[str, str, str]] = {
|
||||
"mujoco": ("src.runners.mujoco", "MuJoCoRunner", "MuJoCoRunnerConfig"),
|
||||
"mjx": ("src.runners.mjx", "MJXRunner", "MJXRunnerConfig"),
|
||||
"mujoco": ("src.runners.mujoco", "MuJoCoRunner", "MuJoCoRunnerConfig"),
|
||||
"mujoco_single": ("src.runners.mujoco", "MuJoCoRunner", "MuJoCoRunnerConfig"),
|
||||
"mjx": ("src.runners.mjx", "MJXRunner", "MJXRunnerConfig"),
|
||||
"serial": ("src.runners.serial", "SerialRunner", "SerialRunnerConfig"),
|
||||
}
|
||||
|
||||
|
||||
@@ -94,6 +96,7 @@ def main(cfg: DictConfig) -> None:
|
||||
# execute_remotely() is a no-op on the worker side.
|
||||
training_dict = OmegaConf.to_container(cfg.training, resolve=True)
|
||||
remote = training_dict.pop("remote", False)
|
||||
training_dict.pop("hpo", None) # HPO range metadata — not a TrainerConfig field
|
||||
task = _init_clearml(choices, remote=remote)
|
||||
|
||||
env_name = choices.get("env", "cartpole")
|
||||
|
||||
Reference in New Issue
Block a user