diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 8491f2647a4..5d99c35f5f0 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -89,6 +89,7 @@ Guidelines for modifications: * Kourosh Darvish * Kousheek Chakraborty * Lionel Gulich +* Litian Gong * Lotus Li * Louis Le Lay * Lorenz Wellhausen diff --git a/scripts/imitation_learning/robomimic/train.py b/scripts/imitation_learning/robomimic/train.py index c97df13260f..9c645773318 100644 --- a/scripts/imitation_learning/robomimic/train.py +++ b/scripts/imitation_learning/robomimic/train.py @@ -87,6 +87,7 @@ import isaaclab_tasks # noqa: F401 import isaaclab_tasks.manager_based.locomanipulation.pick_place # noqa: F401 import isaaclab_tasks.manager_based.manipulation.pick_place # noqa: F401 +from isaaclab_tasks.utils.parse_cfg import load_cfg_from_registry def normalize_hdf5_actions(config: Config, log_dir: str) -> str: @@ -362,28 +363,19 @@ def main(args: argparse.Namespace): print(f"Loading configuration for task: {task_name}") print(gym.envs.registry.keys()) print(" ") - cfg_entry_point_file = gym.spec(task_name).kwargs.pop(cfg_entry_point_key) - # check if entry point exists - if cfg_entry_point_file is None: - raise ValueError( - f"Could not find configuration for the environment: '{task_name}'." - f" Please check that the gym registry has the entry point: '{cfg_entry_point_key}'." + + # use the unified configuration loading utility (supports YAML, JSON, and Python classes) + ext_cfg = load_cfg_from_registry(task_name, cfg_entry_point_key) + + # ensure the configuration is a dictionary (robomimic expects JSON/YAML dict format) + if not isinstance(ext_cfg, dict): + raise TypeError( + f"Expected robomimic configuration to be a dictionary, but got {type(ext_cfg)}." + " Please ensure the configuration file is in JSON or YAML format." ) - # resolve module path if needed - if ":" in cfg_entry_point_file: - mod_name, file_name = cfg_entry_point_file.split(":") - mod = importlib.import_module(mod_name) - if mod.__file__ is None: - raise ValueError(f"Could not find module file for: '{mod_name}'") - mod_path = os.path.dirname(mod.__file__) - config_file = os.path.join(mod_path, file_name) - else: - config_file = cfg_entry_point_file - - with open(config_file) as f: - ext_cfg = json.load(f) - config = config_factory(ext_cfg["algo_name"]) + # create robomimic config from the loaded dictionary + config = config_factory(ext_cfg["algo_name"]) # update config with external json - this will throw errors if # the external config has keys not present in the base algo config with config.values_unlocked(): diff --git a/source/isaaclab_tasks/isaaclab_tasks/utils/parse_cfg.py b/source/isaaclab_tasks/isaaclab_tasks/utils/parse_cfg.py index b4f788a9bcb..374ae6b2228 100644 --- a/source/isaaclab_tasks/isaaclab_tasks/utils/parse_cfg.py +++ b/source/isaaclab_tasks/isaaclab_tasks/utils/parse_cfg.py @@ -80,20 +80,28 @@ def load_cfg_from_registry(task_name: str, entry_point_key: str) -> dict | objec f"{msg if agents else ''}" ) # parse the default config file - if isinstance(cfg_entry_point, str) and cfg_entry_point.endswith(".yaml"): + if isinstance(cfg_entry_point, str) and (cfg_entry_point.endswith(".yaml") or cfg_entry_point.endswith(".json")): if os.path.exists(cfg_entry_point): # absolute path for the config file config_file = cfg_entry_point else: # resolve path to the module location mod_name, file_name = cfg_entry_point.split(":") - mod_path = os.path.dirname(importlib.import_module(mod_name).__file__) + mod = importlib.import_module(mod_name) + if mod.__file__ is None: + raise ValueError(f"Could not determine file path for module: {mod_name}") + mod_path = os.path.dirname(mod.__file__) # obtain the configuration file path config_file = os.path.join(mod_path, file_name) # load the configuration print(f"[INFO]: Parsing configuration from: {config_file}") with open(config_file, encoding="utf-8") as f: - cfg = yaml.full_load(f) + if cfg_entry_point.endswith(".yaml"): + cfg = yaml.full_load(f) + else: # .json + import json + + cfg = json.load(f) else: if callable(cfg_entry_point): # resolve path to the module location