47 lines
1.7 KiB
Python
47 lines
1.7 KiB
Python
import hydra
|
|
from hydra.core.hydra_config import HydraConfig
|
|
from omegaconf import DictConfig, OmegaConf
|
|
|
|
from src.envs.cartpole import CartPoleEnv, CartPoleConfig
|
|
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
|
|
from src.training.trainer import Trainer, TrainerConfig
|
|
from src.core.env import ActuatorConfig
|
|
|
|
|
|
def _build_env_config(cfg: DictConfig) -> CartPoleConfig:
    """Convert the ``env`` section of the Hydra config into a ``CartPoleConfig``.

    ``cfg.env`` is resolved into plain Python containers. If it contains an
    ``actuators`` list, each entry becomes an ``ActuatorConfig``. All other keys
    are passed to ``CartPoleConfig`` unchanged.

    Args:
        cfg: The root Hydra config; only ``cfg.env`` is read.

    Returns:
        A ``CartPoleConfig`` built from ``cfg.env``.
    """
    env_dict = OmegaConf.to_container(cfg.env, resolve=True)

    actuators = env_dict.get("actuators")
    # A missing key is left absent. An explicit ``actuators: null`` is passed
    # through as-is instead of failing with "NoneType is not iterable".
    if actuators is not None:
        env_dict["actuators"] = [_build_actuator_config(a) for a in actuators]

    return CartPoleConfig(**env_dict)


def _build_actuator_config(actuator: dict) -> ActuatorConfig:
    """Build one ``ActuatorConfig`` from a plain actuator mapping.

    The input mapping is not mutated. ``OmegaConf.to_container`` yields lists,
    so a ``ctrl_range`` list is converted to a tuple. A missing or null
    ``ctrl_range`` is left as-is.

    Args:
        actuator: Keyword arguments for ``ActuatorConfig`` as a plain dict.

    Returns:
        The constructed ``ActuatorConfig``.
    """
    kwargs = dict(actuator)
    ctrl_range = kwargs.get("ctrl_range")
    if ctrl_range is not None:
        kwargs["ctrl_range"] = tuple(ctrl_range)
    return ActuatorConfig(**kwargs)
|
|
|
|
|
|
@hydra.main(version_base=None, config_path="configs", config_name="config")
def main(cfg: DictConfig) -> None:
    """Build the CartPole env, MuJoCo runner and trainer from *cfg*, then train.

    Hydra composes ``cfg`` from ``configs/config`` plus any command-line
    overrides. ``trainer.close()`` is always called, even if training raises.

    Args:
        cfg: Root config with ``env``, ``runner`` and ``training`` sections.
    """
    env_config = _build_env_config(cfg)
    runner_config = MuJoCoRunnerConfig(**OmegaConf.to_container(cfg.runner, resolve=True))

    training_dict = OmegaConf.to_container(cfg.training, resolve=True)
    # With no explicit ClearML task name, derive one from the Hydra config
    # group choices so each env/runner/algorithm combination is distinguishable.
    if not training_dict.get("clearml_task"):
        training_dict["clearml_task"] = _default_clearml_task(
            HydraConfig.get().runtime.choices
        )
    trainer_config = TrainerConfig(**training_dict)

    env = CartPoleEnv(env_config)
    runner = MuJoCoRunner(env=env, config=runner_config)
    trainer = Trainer(runner=runner, config=trainer_config)

    try:
        trainer.train()
    finally:
        # Runs on success, on exceptions and on KeyboardInterrupt alike.
        trainer.close()


def _default_clearml_task(choices) -> str:
    """Return ``"<env>-<runner>-<training>"`` from Hydra config group choices.

    Args:
        choices: Mapping of config group name to selected option, as found in
            ``HydraConfig.get().runtime.choices``.

    Returns:
        The task name. Any group that is missing, null or empty falls back to
        ``"env"``, ``"runner"`` or ``"algo"`` respectively.
    """
    # Use ``or`` rather than ``get(key, default)``: Hydra may record a group
    # as present with a None choice (e.g. ``env=null``), and ``get`` would
    # then return None and produce names like "None-runner-algo".
    env_name = choices.get("env") or "env"
    runner_name = choices.get("runner") or "runner"
    training_name = choices.get("training") or "algo"
    return f"{env_name}-{runner_name}-{training_name}"
|
|
|
|
|
|
if __name__ == "__main__":
    # @hydra.main parses the command line and supplies the composed config
    # itself, so the entry point is invoked without arguments.
    main()