✨ better robot joint loading
This commit is contained in:
10
assets/cartpole/robot.yaml
Normal file
10
assets/cartpole/robot.yaml
Normal file
@@ -0,0 +1,10 @@
|
||||
# Classic cartpole — robot hardware config.
|
||||
|
||||
urdf: cartpole.urdf
|
||||
|
||||
actuators:
|
||||
- joint: cart_joint
|
||||
type: motor
|
||||
gear: 10.0
|
||||
ctrl_range: [-1.0, 1.0]
|
||||
damping: 0.05
|
||||
15
assets/rotary_cartpole/robot.yaml
Normal file
15
assets/rotary_cartpole/robot.yaml
Normal file
@@ -0,0 +1,15 @@
|
||||
# Rotary cartpole (Furuta pendulum) — robot hardware config.
|
||||
# Lives next to the URDF so all robot-specific settings are in one place.
|
||||
|
||||
urdf: rotary_cartpole.urdf
|
||||
|
||||
actuators:
|
||||
- joint: motor_joint
|
||||
type: motor # direct torque control
|
||||
gear: 0.5 # torque multiplier
|
||||
ctrl_range: [-1.0, 1.0]
|
||||
damping: 0.1 # motor friction / back-EMF
|
||||
|
||||
joints:
|
||||
pendulum_joint:
|
||||
damping: 0.0001 # bearing friction
|
||||
7
configs/env/cartpole.yaml
vendored
7
configs/env/cartpole.yaml
vendored
@@ -1,12 +1,7 @@
|
||||
max_steps: 500
|
||||
robot_path: assets/cartpole
|
||||
angle_threshold: 0.418
|
||||
cart_limit: 2.4
|
||||
reward_alive: 1.0
|
||||
reward_pole_upright_scale: 1.0
|
||||
reward_action_penalty_scale: 0.01
|
||||
model_path: assets/cartpole/cartpole.urdf
|
||||
actuators:
|
||||
- joint: cart_joint
|
||||
gear: 10.0
|
||||
ctrl_range: [-1.0, 1.0]
|
||||
damping: 0.05
|
||||
|
||||
9
configs/env/rotary_cartpole.yaml
vendored
9
configs/env/rotary_cartpole.yaml
vendored
@@ -1,8 +1,3 @@
|
||||
max_steps: 1000
|
||||
model_path: assets/rotary_cartpole/rotary_cartpole.urdf
|
||||
reward_upright_scale: 1.0
|
||||
actuators:
|
||||
- joint: motor_joint
|
||||
gear: 0.5
|
||||
ctrl_range: [-1.0, 1.0]
|
||||
damping: 0.1
|
||||
robot_path: assets/rotary_cartpole
|
||||
reward_upright_scale: 1.0
|
||||
@@ -10,4 +10,5 @@ clearml
|
||||
imageio
|
||||
imageio-ffmpeg
|
||||
structlog
|
||||
pyyaml
|
||||
pytest
|
||||
@@ -5,30 +5,20 @@ from gymnasium import spaces
|
||||
import torch
|
||||
import pathlib
|
||||
|
||||
from src.core.robot import RobotConfig, load_robot_config
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ActuatorConfig:
|
||||
"""Actuator definition — maps a joint to a motor with gear ratio and control limits.
|
||||
Kept in the env config (not runner config) because actuators define what the robot
|
||||
can do, which determines action space — a task-level concept.
|
||||
This mirrors Isaac Lab's pattern of separating actuator config from the robot file."""
|
||||
joint: str = ""
|
||||
gear: float = 1.0
|
||||
ctrl_range: tuple[float, float] = (-1.0, 1.0)
|
||||
damping: float = 0.05 # joint damping — limits max speed: vel_max ≈ torque / damping
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class BaseEnvConfig:
|
||||
max_steps: int = 1000
|
||||
model_path: pathlib.Path | None = None
|
||||
actuators: list[ActuatorConfig] = dataclasses.field(default_factory=list)
|
||||
robot_path: str = "" # directory containing robot.yaml + URDF
|
||||
|
||||
class BaseEnv(abc.ABC, Generic[T]):
|
||||
def __init__(self, config: BaseEnvConfig):
|
||||
self.config = config
|
||||
self.robot: RobotConfig = load_robot_config(config.robot_path)
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
|
||||
23
src/core/registry.py
Normal file
23
src/core/registry.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Shared env registry and builder used by train.py and viz.py."""
|
||||
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
from src.core.env import BaseEnv, BaseEnvConfig
|
||||
from src.envs.cartpole import CartPoleEnv, CartPoleConfig
|
||||
from src.envs.rotary_cartpole import RotaryCartPoleEnv, RotaryCartPoleConfig
|
||||
|
||||
# Env lookup table: Hydra config-group name → (EnvClass, ConfigClass).
# Both entries of each pair are handed to build_env(), which instantiates
# the config class from the Hydra env section and passes it to the env class.
ENV_REGISTRY: dict[str, tuple[type[BaseEnv], type[BaseEnvConfig]]] = dict(
    cartpole=(CartPoleEnv, CartPoleConfig),
    rotary_cartpole=(RotaryCartPoleEnv, RotaryCartPoleConfig),
)
|
||||
|
||||
|
||||
def build_env(env_name: str, cfg: DictConfig) -> BaseEnv:
    """Build the environment registered under a Hydra config-group name.

    Args:
        env_name: Key into ``ENV_REGISTRY`` (the Hydra env config-group name).
        cfg: Full Hydra config; its ``env`` section supplies the keyword
            arguments for the env's config class.

    Returns:
        The constructed environment instance.

    Raises:
        ValueError: If ``env_name`` is not a registered environment.
    """
    entry = ENV_REGISTRY.get(env_name)
    if entry is None:
        raise ValueError(f"Unknown env '{env_name}'. Registered: {list(ENV_REGISTRY)}")

    env_cls, config_cls = entry
    # Resolve interpolations and convert to plain Python containers so the
    # section can be splatted into the config class constructor.
    env_kwargs = OmegaConf.to_container(cfg.env, resolve=True)
    env_config = config_cls(**env_kwargs)
    return env_cls(env_config)
|
||||
101
src/core/robot.py
Normal file
101
src/core/robot.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""Robot hardware configuration — loaded from robot.yaml next to the URDF.
|
||||
|
||||
Separates robot hardware (actuators, joint tuning) from task config
|
||||
(rewards, episode length) and from the URDF (clean CAD export).
|
||||
|
||||
Usage:
|
||||
robot = load_robot_config(Path("assets/rotary_cartpole"))
|
||||
# robot.urdf_path → resolved absolute path to the URDF
|
||||
# robot.actuators → list of ActuatorConfig
|
||||
# robot.joints → dict of per-joint overrides
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
from pathlib import Path
|
||||
|
||||
import structlog
|
||||
import yaml
|
||||
|
||||
log = structlog.get_logger()
|
||||
|
||||
|
||||
@dataclasses.dataclass
class ActuatorConfig:
    """Motor/actuator attached to a joint.

    type:
        motor    — direct torque control (ctrl = normalised torque)
        position — PD position servo (ctrl = target angle, needs kp)
        velocity — P velocity servo (ctrl = target velocity, needs kp)

    ``ctrl_range`` may be given as any two-element sequence (YAML and
    OmegaConf produce lists); it is stored as a tuple so the field always
    matches its declared type regardless of how the config was built.
    """

    joint: str = ""  # name of the joint this actuator drives
    type: str = "motor"  # one of the actuator kinds listed above
    gear: float = 1.0
    ctrl_range: tuple[float, float] = (-1.0, 1.0)  # (lower, upper) control limits
    damping: float = 0.05
    kp: float = 0.0  # proportional gain (position / velocity actuators)
    kv: float = 0.0  # derivative gain (position actuators)

    def __post_init__(self) -> None:
        # Normalise list-like input (e.g. straight from a parsed YAML file)
        # so callers do not each have to convert ctrl_range by hand.
        self.ctrl_range = tuple(self.ctrl_range)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class JointConfig:
|
||||
"""Per-joint overrides applied on top of the URDF values."""
|
||||
damping: float | None = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
class RobotConfig:
    """Full hardware description of one robot: URDF location plus tuning."""

    # Path to the robot's URDF; defaults to an empty Path().
    urdf_path: Path = dataclasses.field(default_factory=Path)
    # Actuator definitions, one entry per driven joint.
    actuators: list[ActuatorConfig] = dataclasses.field(default_factory=list)
    # Per-joint overrides keyed by joint name.
    joints: dict[str, JointConfig] = dataclasses.field(default_factory=dict)
|
||||
|
||||
|
||||
def load_robot_config(robot_dir: str | Path) -> RobotConfig:
    """Load robot.yaml from a directory and resolve the URDF path.

    Expected layout:
        robot_dir/
            robot.yaml          ← hardware config
            some_robot.urdf     ← CAD export
            meshes/             ← optional mesh files

    Args:
        robot_dir: Directory containing ``robot.yaml`` and the URDF it names.

    Returns:
        A ``RobotConfig`` with the absolute URDF path, the parsed actuators
        and the per-joint overrides.

    Raises:
        FileNotFoundError: If ``robot.yaml`` or the referenced URDF is missing.
        ValueError: If ``robot.yaml`` is not a mapping or has no ``urdf`` entry.
    """
    robot_dir = Path(robot_dir).resolve()
    yaml_path = robot_dir / "robot.yaml"

    if not yaml_path.exists():
        raise FileNotFoundError(f"Robot config not found: {yaml_path}")

    # Read as UTF-8 explicitly: the config files contain non-ASCII characters
    # and the platform default encoding is not guaranteed to be UTF-8.
    # safe_load returns None for an empty document, so fall back to {} and
    # let the 'urdf' check below report the problem clearly.
    raw = yaml.safe_load(yaml_path.read_text(encoding="utf-8")) or {}
    if not isinstance(raw, dict):
        raise ValueError(f"robot.yaml must contain a mapping at top level: {yaml_path}")

    # Resolve URDF path relative to robot.yaml directory
    urdf_filename = raw.get("urdf", "")
    if not urdf_filename:
        raise ValueError(f"robot.yaml must specify 'urdf' filename: {yaml_path}")
    urdf_path = robot_dir / urdf_filename
    if not urdf_path.exists():
        raise FileNotFoundError(f"URDF not found: {urdf_path}")

    # Parse actuators. `or []` also covers a key that is present but null
    # (e.g. "actuators:" with nothing under it).
    actuators = []
    for entry in raw.get("actuators") or []:
        fields = dict(entry)  # copy so the parsed YAML data is not mutated
        if "ctrl_range" in fields:
            fields["ctrl_range"] = tuple(fields["ctrl_range"])
        actuators.append(ActuatorConfig(**fields))

    # Parse joint overrides. A joint listed with no settings parses as None;
    # treat it as "no overrides" rather than crashing.
    joints = {
        name: JointConfig(**(jcfg or {}))
        for name, jcfg in (raw.get("joints") or {}).items()
    }

    config = RobotConfig(
        urdf_path=urdf_path,
        actuators=actuators,
        joints=joints,
    )

    log.debug("robot_config_loaded", robot_dir=str(robot_dir),
              urdf=urdf_filename, num_actuators=len(actuators),
              joint_overrides=list(joints.keys()))

    return config
|
||||
@@ -53,9 +53,10 @@ class BaseRunner(abc.ABC, Generic[T]):
|
||||
def _sim_reset(self, env_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def _sim_close(self) -> None:
|
||||
...
|
||||
"""Release simulator resources. Override for extra cleanup."""
|
||||
if hasattr(self, "_offscreen_renderer") and self._offscreen_renderer is not None:
|
||||
self._offscreen_renderer.close()
|
||||
|
||||
def reset(self) -> tuple[torch.Tensor, dict[str, Any]]:
|
||||
all_ids = torch.arange(self.num_envs, device=self.device)
|
||||
|
||||
@@ -29,7 +29,7 @@ import numpy as np
|
||||
|
||||
from src.core.env import BaseEnv
|
||||
from src.core.runner import BaseRunner, BaseRunnerConfig
|
||||
from src.runners.mujoco import MuJoCoRunner # reuse _load_model_with_actuators
|
||||
from src.runners.mujoco import MuJoCoRunner # reuse _load_model
|
||||
|
||||
log = structlog.get_logger()
|
||||
|
||||
@@ -64,14 +64,8 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
|
||||
# ── Initialization ───────────────────────────────────────────────
|
||||
|
||||
def _sim_initialize(self, config: MJXRunnerConfig) -> None:
|
||||
model_path = self.env.config.model_path
|
||||
if model_path is None:
|
||||
raise ValueError("model_path must be specified")
|
||||
|
||||
# Step 1: Load CPU model (reuses URDF → MJCF → actuator injection)
|
||||
self._mj_model = MuJoCoRunner._load_model_with_actuators(
|
||||
str(model_path), self.env.config.actuators,
|
||||
)
|
||||
self._mj_model = MuJoCoRunner._load_model(self.env.robot)
|
||||
self._mj_model.opt.timestep = config.dt
|
||||
self._nq = self._mj_model.nq
|
||||
self._nv = self._mj_model.nv
|
||||
@@ -207,10 +201,6 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
|
||||
rv = self._batch_data.qvel[ids_np].astype(jnp.float32)
|
||||
return torch.from_dlpack(rq), torch.from_dlpack(rv)
|
||||
|
||||
def _sim_close(self) -> None:
|
||||
if hasattr(self, "_offscreen_renderer") and self._offscreen_renderer is not None:
|
||||
self._offscreen_renderer.close()
|
||||
|
||||
# ── Rendering ────────────────────────────────────────────────────
|
||||
|
||||
def render(self, env_idx: int = 0) -> np.ndarray:
|
||||
|
||||
@@ -6,7 +6,8 @@ import mujoco
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from src.core.env import BaseEnv, ActuatorConfig
|
||||
from src.core.env import BaseEnv
|
||||
from src.core.robot import RobotConfig
|
||||
from src.core.runner import BaseRunner, BaseRunnerConfig
|
||||
|
||||
@dataclasses.dataclass
|
||||
@@ -30,18 +31,18 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
||||
return torch.device(self.config.device)
|
||||
|
||||
@staticmethod
|
||||
def _load_model_with_actuators(model_path: str, actuators: list[ActuatorConfig]) -> mujoco.MjModel:
|
||||
"""Load a URDF (or MJCF) file and programmatically inject actuators.
|
||||
def _load_model(robot: RobotConfig) -> mujoco.MjModel:
|
||||
"""Load a URDF (or MJCF) and apply robot.yaml settings.
|
||||
|
||||
Two-step approach required because MuJoCo's URDF parser ignores
|
||||
<actuator> in the <mujoco> extension block:
|
||||
1. Load the URDF → MuJoCo converts it to internal MJCF
|
||||
2. Export the MJCF XML, add <actuator> elements, reload
|
||||
2. Export the MJCF XML, inject actuators + joint overrides, reload
|
||||
|
||||
This keeps the URDF clean and standard — actuator config lives in
|
||||
the env config (Isaac Lab pattern), not in the robot file.
|
||||
This keeps the URDF clean (re-exportable from CAD) — all hardware
|
||||
tuning lives in robot.yaml.
|
||||
"""
|
||||
abs_path = Path(model_path).resolve()
|
||||
abs_path = robot.urdf_path.resolve()
|
||||
model_dir = abs_path.parent
|
||||
is_urdf = abs_path.suffix.lower() == ".urdf"
|
||||
|
||||
@@ -74,33 +75,45 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
||||
else:
|
||||
model_raw = mujoco.MjModel.from_xml_path(str(abs_path))
|
||||
|
||||
if not actuators:
|
||||
if not robot.actuators and not robot.joints:
|
||||
return model_raw
|
||||
|
||||
# Step 2: Export internal MJCF representation (save next to original
|
||||
# model so relative mesh/asset paths resolve correctly on reload)
|
||||
# Step 2: Export internal MJCF, inject actuators + joint overrides, reload
|
||||
tmp_mjcf = model_dir / "_tmp_actuator_inject.xml"
|
||||
try:
|
||||
mujoco.mj_saveLastXML(str(tmp_mjcf), model_raw)
|
||||
mjcf_str = tmp_mjcf.read_text()
|
||||
|
||||
# Step 3: Inject actuators into the MJCF XML
|
||||
# Use torque actuator. Speed is limited by joint damping:
|
||||
# at steady state, vel_max = torque / damping.
|
||||
root = ET.fromstring(mjcf_str)
|
||||
act_elem = ET.SubElement(root, "actuator")
|
||||
for act in actuators:
|
||||
ET.SubElement(act_elem, "motor", attrib={
|
||||
"name": f"{act.joint}_motor",
|
||||
"joint": act.joint,
|
||||
"gear": str(act.gear),
|
||||
"ctrlrange": f"{act.ctrl_range[0]} {act.ctrl_range[1]}",
|
||||
})
|
||||
|
||||
# Add damping to actuated joints to limit max speed and
|
||||
# mimic real motor friction / back-EMF.
|
||||
# vel_max ≈ max_torque / damping
|
||||
joint_damping = {a.joint: a.damping for a in actuators}
|
||||
# ── Inject actuators ────────────────────────────────────
|
||||
if robot.actuators:
|
||||
act_elem = ET.SubElement(root, "actuator")
|
||||
for act in robot.actuators:
|
||||
attribs = {
|
||||
"name": f"{act.joint}_{act.type}",
|
||||
"joint": act.joint,
|
||||
"gear": str(act.gear),
|
||||
"ctrlrange": f"{act.ctrl_range[0]} {act.ctrl_range[1]}",
|
||||
}
|
||||
if act.type == "position":
|
||||
attribs["kp"] = str(act.kp)
|
||||
if act.kv > 0:
|
||||
attribs["kv"] = str(act.kv)
|
||||
ET.SubElement(act_elem, "position", attrib=attribs)
|
||||
elif act.type == "velocity":
|
||||
attribs["kp"] = str(act.kp)
|
||||
ET.SubElement(act_elem, "velocity", attrib=attribs)
|
||||
else: # motor (default)
|
||||
ET.SubElement(act_elem, "motor", attrib=attribs)
|
||||
|
||||
# ── Apply joint overrides from robot.yaml ───────────────
|
||||
# Merge actuator damping + explicit joint overrides
|
||||
joint_damping = {a.joint: a.damping for a in robot.actuators}
|
||||
for name, jcfg in robot.joints.items():
|
||||
if jcfg.damping is not None:
|
||||
joint_damping[name] = jcfg.damping
|
||||
|
||||
for body in root.iter("body"):
|
||||
for jnt in body.findall("joint"):
|
||||
name = jnt.get("name")
|
||||
@@ -115,6 +128,15 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
||||
geom.set("contype", "0")
|
||||
geom.set("conaffinity", "0")
|
||||
|
||||
# Harden joint limits: MuJoCo's default soft limits are too
|
||||
# weak and allow overshoot. Negative solref = hard constraint
|
||||
# (direct stiffness/damping instead of impedance match).
|
||||
for body in root.iter("body"):
|
||||
for jnt in body.findall("joint"):
|
||||
if jnt.get("limited") == "true" or jnt.get("range"):
|
||||
jnt.set("solreflimit", "-1000 -100")
|
||||
jnt.set("solimplimit", "0.95 0.99 0.001")
|
||||
|
||||
# Step 4: Write modified MJCF and reload from file path
|
||||
# (from_xml_path resolves mesh paths relative to the file location)
|
||||
modified_xml = ET.tostring(root, encoding="unicode")
|
||||
@@ -124,12 +146,7 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
||||
tmp_mjcf.unlink(missing_ok=True)
|
||||
|
||||
def _sim_initialize(self, config: MuJoCoRunnerConfig) -> None:
|
||||
model_path = self.env.config.model_path
|
||||
if model_path is None:
|
||||
raise ValueError("model_path must be specified in the environment config")
|
||||
|
||||
actuators = self.env.config.actuators
|
||||
self._model = self._load_model_with_actuators(str(model_path), actuators)
|
||||
self._model = self._load_model(self.env.robot)
|
||||
self._model.opt.timestep = config.dt
|
||||
self._data: list[mujoco.MjData] = [mujoco.MjData(self._model) for _ in range(config.num_envs)]
|
||||
|
||||
@@ -195,10 +212,6 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
||||
torch.from_numpy(qvel_batch).to(self.device),
|
||||
)
|
||||
|
||||
def _sim_close(self) -> None:
|
||||
if hasattr(self, "_offscreen_renderer") and self._offscreen_renderer is not None:
|
||||
self._offscreen_renderer.close()
|
||||
|
||||
def render(self, env_idx: int = 0) -> np.ndarray:
|
||||
"""Offscreen render of a single environment."""
|
||||
if not hasattr(self, "_offscreen_renderer"):
|
||||
|
||||
32
train.py
32
train.py
@@ -13,39 +13,13 @@ from clearml import Task
|
||||
from hydra.core.hydra_config import HydraConfig
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
from src.core.env import BaseEnv, BaseEnvConfig, ActuatorConfig
|
||||
from src.core.env import BaseEnv
|
||||
from src.core.registry import build_env
|
||||
from src.core.runner import BaseRunner
|
||||
from src.envs.cartpole import CartPoleEnv, CartPoleConfig
|
||||
from src.envs.rotary_cartpole import RotaryCartPoleEnv, RotaryCartPoleConfig
|
||||
from src.training.trainer import Trainer, TrainerConfig
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# ── env registry ──────────────────────────────────────────────────────
|
||||
# Maps Hydra config-group name → (EnvClass, ConfigClass)
|
||||
ENV_REGISTRY: dict[str, tuple[type[BaseEnv], type[BaseEnvConfig]]] = {
|
||||
"cartpole": (CartPoleEnv, CartPoleConfig),
|
||||
"rotary_cartpole": (RotaryCartPoleEnv, RotaryCartPoleConfig),
|
||||
}
|
||||
|
||||
|
||||
def _build_env(env_name: str, cfg: DictConfig) -> BaseEnv:
|
||||
"""Instantiate the right env + config from the Hydra config-group name."""
|
||||
if env_name not in ENV_REGISTRY:
|
||||
raise ValueError(f"Unknown env '{env_name}'. Registered: {list(ENV_REGISTRY)}")
|
||||
|
||||
env_cls, config_cls = ENV_REGISTRY[env_name]
|
||||
env_dict = OmegaConf.to_container(cfg.env, resolve=True)
|
||||
|
||||
# Convert actuator dicts → ActuatorConfig objects
|
||||
if "actuators" in env_dict:
|
||||
for a in env_dict["actuators"]:
|
||||
if "ctrl_range" in a:
|
||||
a["ctrl_range"] = tuple(a["ctrl_range"])
|
||||
env_dict["actuators"] = [ActuatorConfig(**a) for a in env_dict["actuators"]]
|
||||
|
||||
return env_cls(config_cls(**env_dict))
|
||||
|
||||
|
||||
# ── runner registry ───────────────────────────────────────────────────
|
||||
# Maps Hydra config-group name → (RunnerClass, ConfigClass)
|
||||
@@ -123,7 +97,7 @@ def main(cfg: DictConfig) -> None:
|
||||
task = _init_clearml(choices, remote=remote)
|
||||
|
||||
env_name = choices.get("env", "cartpole")
|
||||
env = _build_env(env_name, cfg)
|
||||
env = build_env(env_name, cfg)
|
||||
runner = _build_runner(choices.get("runner", "mujoco"), env, cfg)
|
||||
trainer_config = TrainerConfig(**training_dict)
|
||||
trainer = Trainer(runner=runner, config=trainer_config)
|
||||
|
||||
25
viz.py
25
viz.py
@@ -20,32 +20,11 @@ import torch
|
||||
from hydra.core.hydra_config import HydraConfig
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
from src.core.env import ActuatorConfig, BaseEnv, BaseEnvConfig
|
||||
from src.envs.cartpole import CartPoleConfig, CartPoleEnv
|
||||
from src.envs.rotary_cartpole import RotaryCartPoleConfig, RotaryCartPoleEnv
|
||||
from src.core.registry import build_env
|
||||
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# ── registry (same as train.py) ──────────────────────────────────────
|
||||
ENV_REGISTRY: dict[str, tuple[type[BaseEnv], type[BaseEnvConfig]]] = {
|
||||
"cartpole": (CartPoleEnv, CartPoleConfig),
|
||||
"rotary_cartpole": (RotaryCartPoleEnv, RotaryCartPoleConfig),
|
||||
}
|
||||
|
||||
|
||||
def _build_env(env_name: str, cfg: DictConfig) -> BaseEnv:
|
||||
if env_name not in ENV_REGISTRY:
|
||||
raise ValueError(f"Unknown env '{env_name}'. Registered: {list(ENV_REGISTRY)}")
|
||||
env_cls, config_cls = ENV_REGISTRY[env_name]
|
||||
env_dict = OmegaConf.to_container(cfg.env, resolve=True)
|
||||
if "actuators" in env_dict:
|
||||
for a in env_dict["actuators"]:
|
||||
if "ctrl_range" in a:
|
||||
a["ctrl_range"] = tuple(a["ctrl_range"])
|
||||
env_dict["actuators"] = [ActuatorConfig(**a) for a in env_dict["actuators"]]
|
||||
return env_cls(config_cls(**env_dict))
|
||||
|
||||
|
||||
# ── keyboard state ───────────────────────────────────────────────────
|
||||
_action_val = [0.0] # mutable container shared with callback
|
||||
@@ -72,7 +51,7 @@ def main(cfg: DictConfig) -> None:
|
||||
env_name = choices.get("env", "cartpole")
|
||||
|
||||
# Build env + runner (single env for viz)
|
||||
env = _build_env(env_name, cfg)
|
||||
env = build_env(env_name, cfg)
|
||||
runner_dict = OmegaConf.to_container(cfg.runner, resolve=True)
|
||||
runner_dict["num_envs"] = 1
|
||||
runner = MuJoCoRunner(env=env, config=MuJoCoRunnerConfig(**runner_dict))
|
||||
|
||||
Reference in New Issue
Block a user