2026-03-06 22:19:44 +01:00

47 lines
1.7 KiB
Python

import hydra
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf
from src.envs.cartpole import CartPoleEnv, CartPoleConfig
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
from src.training.trainer import Trainer, TrainerConfig
from src.core.env import ActuatorConfig
def _build_env_config(cfg: DictConfig) -> CartPoleConfig:
    """Convert the Hydra ``env`` config group into a typed ``CartPoleConfig``.

    The ``env`` node is resolved (interpolations expanded) into plain Python
    containers. If it contains an ``actuators`` entry, each actuator mapping
    is turned into an ``ActuatorConfig``. A list-valued ``ctrl_range`` is
    first converted to a tuple.

    Args:
        cfg: The root Hydra config; only ``cfg.env`` is read.

    Returns:
        A ``CartPoleConfig`` constructed from the keys of ``cfg.env``.
    """
    # to_container returns fresh plain dicts/lists, so mutating them in place
    # does not touch the original OmegaConf tree.
    env_dict = OmegaConf.to_container(cfg.env, resolve=True)
    if "actuators" in env_dict:
        actuators = []
        for spec in env_dict["actuators"]:
            ctrl_range = spec.get("ctrl_range")
            # OmegaConf yields lists for YAML sequences; the config expects a
            # tuple. Only convert actual sequences, so an explicit
            # `ctrl_range: null` is passed through as None instead of failing
            # inside tuple(None).
            if isinstance(ctrl_range, (list, tuple)):
                spec["ctrl_range"] = tuple(ctrl_range)
            actuators.append(ActuatorConfig(**spec))
        env_dict["actuators"] = actuators
    return CartPoleConfig(**env_dict)
@hydra.main(version_base=None, config_path="configs", config_name="config")
def main(cfg: DictConfig) -> None:
    """Hydra entry point: build the CartPole env, MuJoCo runner and trainer, then train.

    If ``training.clearml_task`` is missing or empty, it is filled in as
    ``"<env>-<runner>-<training>"``. The three names are the config-group
    choices Hydra selected for this run, falling back to ``"env"``,
    ``"runner"`` and ``"algo"`` when a group has no recorded choice.

    The trainer is always closed via ``trainer.close()``, even if
    ``trainer.train()`` raises.

    Args:
        cfg: The composed Hydra config, injected by ``@hydra.main``.
    """
    env_cfg = _build_env_config(cfg)
    runner_cfg = MuJoCoRunnerConfig(
        **OmegaConf.to_container(cfg.runner, resolve=True)
    )

    train_opts = OmegaConf.to_container(cfg.training, resolve=True)
    # No explicit ClearML task name: derive one from the Hydra config-group
    # choices so runs are labelled by the env/runner/training combination.
    if not train_opts.get("clearml_task"):
        picked = HydraConfig.get().runtime.choices
        name_parts = (
            picked.get("env", "env"),
            picked.get("runner", "runner"),
            picked.get("training", "algo"),
        )
        train_opts["clearml_task"] = "-".join(str(part) for part in name_parts)
    trainer_cfg = TrainerConfig(**train_opts)

    # Construction order is env -> runner -> trainer (inner calls run first).
    trainer = Trainer(
        runner=MuJoCoRunner(env=CartPoleEnv(env_cfg), config=runner_cfg),
        config=trainer_cfg,
    )
    try:
        trainer.train()
    finally:
        trainer.close()


if __name__ == "__main__":
    # Called without arguments: @hydra.main composes the config from the
    # command line and supplies ``cfg``.
    main()