diff --git a/benchmarking/README.md b/benchmarking/README.md index 3f5114545..ba89da69d 100644 --- a/benchmarking/README.md +++ b/benchmarking/README.md @@ -15,5 +15,27 @@ python -m imitation.scripts. with benchmarking/.json') +from imitation.scripts. import +.run(command_name="", named_configs=["benchmarking/.json"]) ``` + +# Tuning Hyperparameters + +The hyperparameters of any algorithm in imitation can be tuned using the `tuning.py` script. +The benchmarking hyperparameter configs were generated by tuning the hyperparameters using +the search space defined in the `tuning_config.py` script. The tuning script proceeds in two +phases: 1) the hyperparameters are tuned using the search space provided, and 2) the best +hyperparameter config found in the first phase based on the maximum mean return is +re-evaluated on a separate set of seeds, and the mean and standard deviation of these trials +are reported. + +To tune the hyperparameters of an algorithm using the default search space provided: +```bash +python tuning.py with {algo} 'parallel_run_config.base_named_configs=["{env}"]' +``` + +In this command, `{algo}` provides the default search space and settings to be used for +the specific algorithm, which is defined in the `tuning_config.py` script, and +`'parallel_run_config.base_named_configs=["{env}"]'` sets the environment to tune the algorithm in. +See the documentation of the `tuning.py` and `parallel.py` scripts for many other arguments that can be +provided through the command line to change the tuning behavior. 
diff --git a/benchmarking/example_airl_seals_ant_best_hp_eval.json b/benchmarking/airl_seals_ant_best_hp_eval.json similarity index 98% rename from benchmarking/example_airl_seals_ant_best_hp_eval.json rename to benchmarking/airl_seals_ant_best_hp_eval.json index 17f969ff0..d4131433e 100644 --- a/benchmarking/example_airl_seals_ant_best_hp_eval.json +++ b/benchmarking/airl_seals_ant_best_hp_eval.json @@ -62,6 +62,6 @@ "n_episodes_eval": 50 }, "environment": { - "gym_id": "seals/Ant-v0" + "gym_id": "seals/Ant-v1" } } diff --git a/benchmarking/example_airl_seals_half_cheetah_best_hp_eval.json b/benchmarking/airl_seals_half_cheetah_best_hp_eval.json similarity index 97% rename from benchmarking/example_airl_seals_half_cheetah_best_hp_eval.json rename to benchmarking/airl_seals_half_cheetah_best_hp_eval.json index 754ba6736..f69ba5cb5 100644 --- a/benchmarking/example_airl_seals_half_cheetah_best_hp_eval.json +++ b/benchmarking/airl_seals_half_cheetah_best_hp_eval.json @@ -62,6 +62,6 @@ "n_episodes_eval": 50 }, "environment": { - "gym_id": "seals/HalfCheetah-v0" + "gym_id": "seals/HalfCheetah-v1" } } diff --git a/benchmarking/example_airl_seals_hopper_best_hp_eval.json b/benchmarking/airl_seals_hopper_best_hp_eval.json similarity index 98% rename from benchmarking/example_airl_seals_hopper_best_hp_eval.json rename to benchmarking/airl_seals_hopper_best_hp_eval.json index 91080d7ce..58c2475f5 100644 --- a/benchmarking/example_airl_seals_hopper_best_hp_eval.json +++ b/benchmarking/airl_seals_hopper_best_hp_eval.json @@ -75,6 +75,6 @@ "n_episodes_eval": 50 }, "environment": { - "gym_id": "seals/Hopper-v0" + "gym_id": "seals/Hopper-v1" } } diff --git a/benchmarking/example_airl_seals_swimmer_best_hp_eval.json b/benchmarking/airl_seals_swimmer_best_hp_eval.json similarity index 96% rename from benchmarking/example_airl_seals_swimmer_best_hp_eval.json rename to benchmarking/airl_seals_swimmer_best_hp_eval.json index fcca8e6b3..8529c58b5 100644 --- 
a/benchmarking/example_airl_seals_swimmer_best_hp_eval.json +++ b/benchmarking/airl_seals_swimmer_best_hp_eval.json @@ -12,7 +12,7 @@ }, "expert": { "loader_kwargs": { - "gym_id": "seals/Swimmer-v0", + "gym_id": "seals/Swimmer-v1", "organization": "HumanCompatibleAI" } }, @@ -81,6 +81,6 @@ "n_episodes_eval": 50 }, "environment": { - "gym_id": "seals/Swimmer-v0" + "gym_id": "seals/Swimmer-v1" } } diff --git a/benchmarking/example_airl_seals_walker_best_hp_eval.json b/benchmarking/airl_seals_walker_best_hp_eval.json similarity index 96% rename from benchmarking/example_airl_seals_walker_best_hp_eval.json rename to benchmarking/airl_seals_walker_best_hp_eval.json index c63070751..edd99806d 100644 --- a/benchmarking/example_airl_seals_walker_best_hp_eval.json +++ b/benchmarking/airl_seals_walker_best_hp_eval.json @@ -12,7 +12,7 @@ }, "expert": { "loader_kwargs": { - "gym_id": "seals/Walker2d-v0", + "gym_id": "seals/Walker2d-v1", "organization": "HumanCompatibleAI" } }, @@ -81,6 +81,6 @@ "n_episodes_eval": 50 }, "environment": { - "gym_id": "seals/Walker2d-v0" + "gym_id": "seals/Walker2d-v1" } } diff --git a/benchmarking/example_bc_seals_ant_best_hp_eval.json b/benchmarking/bc_seals_ant_best_hp_eval.json similarity index 97% rename from benchmarking/example_bc_seals_ant_best_hp_eval.json rename to benchmarking/bc_seals_ant_best_hp_eval.json index 108a93ce7..e9baa8fc1 100644 --- a/benchmarking/example_bc_seals_ant_best_hp_eval.json +++ b/benchmarking/bc_seals_ant_best_hp_eval.json @@ -43,6 +43,6 @@ "n_episodes_eval": 50 }, "environment": { - "gym_id": "seals/Ant-v0" + "gym_id": "seals/Ant-v1" } } diff --git a/benchmarking/example_bc_seals_half_cheetah_best_hp_eval.json b/benchmarking/bc_seals_half_cheetah_best_hp_eval.json similarity index 96% rename from benchmarking/example_bc_seals_half_cheetah_best_hp_eval.json rename to benchmarking/bc_seals_half_cheetah_best_hp_eval.json index ecaff2eb0..041f159b0 100644 --- 
a/benchmarking/example_bc_seals_half_cheetah_best_hp_eval.json +++ b/benchmarking/bc_seals_half_cheetah_best_hp_eval.json @@ -43,6 +43,6 @@ "n_episodes_eval": 50 }, "environment": { - "gym_id": "seals/HalfCheetah-v0" + "gym_id": "seals/HalfCheetah-v1" } } diff --git a/benchmarking/example_bc_seals_hopper_best_hp_eval.json b/benchmarking/bc_seals_hopper_best_hp_eval.json similarity index 96% rename from benchmarking/example_bc_seals_hopper_best_hp_eval.json rename to benchmarking/bc_seals_hopper_best_hp_eval.json index e8c821841..9a7872d37 100644 --- a/benchmarking/example_bc_seals_hopper_best_hp_eval.json +++ b/benchmarking/bc_seals_hopper_best_hp_eval.json @@ -43,6 +43,6 @@ "n_episodes_eval": 50 }, "environment": { - "gym_id": "seals/Hopper-v0" + "gym_id": "seals/Hopper-v1" } } diff --git a/benchmarking/example_bc_seals_swimmer_best_hp_eval.json b/benchmarking/bc_seals_swimmer_best_hp_eval.json similarity index 96% rename from benchmarking/example_bc_seals_swimmer_best_hp_eval.json rename to benchmarking/bc_seals_swimmer_best_hp_eval.json index 30884c9c4..8a8f2456a 100644 --- a/benchmarking/example_bc_seals_swimmer_best_hp_eval.json +++ b/benchmarking/bc_seals_swimmer_best_hp_eval.json @@ -43,6 +43,6 @@ "n_episodes_eval": 50 }, "environment": { - "gym_id": "seals/Swimmer-v0" + "gym_id": "seals/Swimmer-v1" } } diff --git a/benchmarking/example_bc_seals_walker_best_hp_eval.json b/benchmarking/bc_seals_walker_best_hp_eval.json similarity index 96% rename from benchmarking/example_bc_seals_walker_best_hp_eval.json rename to benchmarking/bc_seals_walker_best_hp_eval.json index 0ca30120e..f33e6c5a2 100644 --- a/benchmarking/example_bc_seals_walker_best_hp_eval.json +++ b/benchmarking/bc_seals_walker_best_hp_eval.json @@ -43,6 +43,6 @@ "n_episodes_eval": 50 }, "environment": { - "gym_id": "seals/Walker2d-v0" + "gym_id": "seals/Walker2d-v1" } } diff --git a/benchmarking/example_dagger_seals_ant_best_hp_eval.json b/benchmarking/dagger_seals_ant_best_hp_eval.json similarity 
index 97% rename from benchmarking/example_dagger_seals_ant_best_hp_eval.json rename to benchmarking/dagger_seals_ant_best_hp_eval.json index de75b80f1..e02828667 100644 --- a/benchmarking/example_dagger_seals_ant_best_hp_eval.json +++ b/benchmarking/dagger_seals_ant_best_hp_eval.json @@ -47,6 +47,6 @@ "n_episodes_eval": 50 }, "environment": { - "gym_id": "seals/Ant-v0" + "gym_id": "seals/Ant-v1" } } diff --git a/benchmarking/example_dagger_seals_half_cheetah_best_hp_eval.json b/benchmarking/dagger_seals_half_cheetah_best_hp_eval.json similarity index 96% rename from benchmarking/example_dagger_seals_half_cheetah_best_hp_eval.json rename to benchmarking/dagger_seals_half_cheetah_best_hp_eval.json index 7f42bfdf9..d1c9e5923 100644 --- a/benchmarking/example_dagger_seals_half_cheetah_best_hp_eval.json +++ b/benchmarking/dagger_seals_half_cheetah_best_hp_eval.json @@ -47,6 +47,6 @@ "n_episodes_eval": 50 }, "environment": { - "gym_id": "seals/HalfCheetah-v0" + "gym_id": "seals/HalfCheetah-v1" } } diff --git a/benchmarking/example_dagger_seals_hopper_best_hp_eval.json b/benchmarking/dagger_seals_hopper_best_hp_eval.json similarity index 97% rename from benchmarking/example_dagger_seals_hopper_best_hp_eval.json rename to benchmarking/dagger_seals_hopper_best_hp_eval.json index 1cf29a1a4..b91f66298 100644 --- a/benchmarking/example_dagger_seals_hopper_best_hp_eval.json +++ b/benchmarking/dagger_seals_hopper_best_hp_eval.json @@ -47,6 +47,6 @@ "n_episodes_eval": 50 }, "environment": { - "gym_id": "seals/Hopper-v0" + "gym_id": "seals/Hopper-v1" } } diff --git a/benchmarking/example_dagger_seals_swimmer_best_hp_eval.json b/benchmarking/dagger_seals_swimmer_best_hp_eval.json similarity index 97% rename from benchmarking/example_dagger_seals_swimmer_best_hp_eval.json rename to benchmarking/dagger_seals_swimmer_best_hp_eval.json index c112db680..545761cbc 100644 --- a/benchmarking/example_dagger_seals_swimmer_best_hp_eval.json +++ 
b/benchmarking/dagger_seals_swimmer_best_hp_eval.json @@ -47,6 +47,6 @@ "n_episodes_eval": 50 }, "environment": { - "gym_id": "seals/Swimmer-v0" + "gym_id": "seals/Swimmer-v1" } } diff --git a/benchmarking/example_dagger_seals_walker_best_hp_eval.json b/benchmarking/dagger_seals_walker_best_hp_eval.json similarity index 97% rename from benchmarking/example_dagger_seals_walker_best_hp_eval.json rename to benchmarking/dagger_seals_walker_best_hp_eval.json index e59bef464..7b694c8d2 100644 --- a/benchmarking/example_dagger_seals_walker_best_hp_eval.json +++ b/benchmarking/dagger_seals_walker_best_hp_eval.json @@ -47,6 +47,6 @@ "n_episodes_eval": 50 }, "environment": { - "gym_id": "seals/Walker2d-v0" + "gym_id": "seals/Walker2d-v1" } } diff --git a/benchmarking/example_gail_seals_ant_best_hp_eval.json b/benchmarking/gail_seals_ant_best_hp_eval.json similarity index 98% rename from benchmarking/example_gail_seals_ant_best_hp_eval.json rename to benchmarking/gail_seals_ant_best_hp_eval.json index 81399b00c..3d43b34ba 100644 --- a/benchmarking/example_gail_seals_ant_best_hp_eval.json +++ b/benchmarking/gail_seals_ant_best_hp_eval.json @@ -62,6 +62,6 @@ "n_episodes_eval": 50 }, "environment": { - "gym_id": "seals/Ant-v0" + "gym_id": "seals/Ant-v1" } } diff --git a/benchmarking/example_gail_seals_half_cheetah_best_hp_eval.json b/benchmarking/gail_seals_half_cheetah_best_hp_eval.json similarity index 97% rename from benchmarking/example_gail_seals_half_cheetah_best_hp_eval.json rename to benchmarking/gail_seals_half_cheetah_best_hp_eval.json index 1d2f26648..914f3712a 100644 --- a/benchmarking/example_gail_seals_half_cheetah_best_hp_eval.json +++ b/benchmarking/gail_seals_half_cheetah_best_hp_eval.json @@ -62,6 +62,6 @@ "n_episodes_eval": 50 }, "environment": { - "gym_id": "seals/HalfCheetah-v0" + "gym_id": "seals/HalfCheetah-v1" } } diff --git a/benchmarking/example_gail_seals_hopper_best_hp_eval.json b/benchmarking/gail_seals_hopper_best_hp_eval.json similarity index 98% 
rename from benchmarking/example_gail_seals_hopper_best_hp_eval.json rename to benchmarking/gail_seals_hopper_best_hp_eval.json index 70787ff7e..cebdae71c 100644 --- a/benchmarking/example_gail_seals_hopper_best_hp_eval.json +++ b/benchmarking/gail_seals_hopper_best_hp_eval.json @@ -75,6 +75,6 @@ "n_episodes_eval": 50 }, "environment": { - "gym_id": "seals/Hopper-v0" + "gym_id": "seals/Hopper-v1" } } diff --git a/benchmarking/example_gail_seals_swimmer_best_hp_eval.json b/benchmarking/gail_seals_swimmer_best_hp_eval.json similarity index 96% rename from benchmarking/example_gail_seals_swimmer_best_hp_eval.json rename to benchmarking/gail_seals_swimmer_best_hp_eval.json index 650c5f46a..b0bd0e645 100644 --- a/benchmarking/example_gail_seals_swimmer_best_hp_eval.json +++ b/benchmarking/gail_seals_swimmer_best_hp_eval.json @@ -12,7 +12,7 @@ }, "expert": { "loader_kwargs": { - "gym_id": "seals/Swimmer-v0", + "gym_id": "seals/Swimmer-v1", "organization": "HumanCompatibleAI" } }, @@ -81,6 +81,6 @@ "n_episodes_eval": 50 }, "environment": { - "gym_id": "seals/Swimmer-v0" + "gym_id": "seals/Swimmer-v1" } } diff --git a/benchmarking/example_gail_seals_walker_best_hp_eval.json b/benchmarking/gail_seals_walker_best_hp_eval.json similarity index 96% rename from benchmarking/example_gail_seals_walker_best_hp_eval.json rename to benchmarking/gail_seals_walker_best_hp_eval.json index d85eb46d5..2626b4c43 100644 --- a/benchmarking/example_gail_seals_walker_best_hp_eval.json +++ b/benchmarking/gail_seals_walker_best_hp_eval.json @@ -12,7 +12,7 @@ }, "expert": { "loader_kwargs": { - "gym_id": "seals/Walker2d-v0", + "gym_id": "seals/Walker2d-v1", "organization": "HumanCompatibleAI" } }, @@ -81,6 +81,6 @@ "n_episodes_eval": 50 }, "environment": { - "gym_id": "seals/Walker2d-v0" + "gym_id": "seals/Walker2d-v1" } } diff --git a/benchmarking/tuning.py b/benchmarking/tuning.py new file mode 100644 index 000000000..9c3f52498 --- /dev/null +++ b/benchmarking/tuning.py @@ -0,0 +1,174 @@ 
+"""Tunes the hyperparameters of the algorithms.""" + +import copy +import pathlib +from typing import Any, Dict + +import numpy as np +import ray +from pandas.api import types as pd_types +from ray.tune.search import optuna +from sacred.observers import FileStorageObserver +from tuning_config import parallel_ex, tuning_ex + + +@tuning_ex.main +def tune( + parallel_run_config: Dict[str, Any], + eval_best_trial_resource_multiplier: int = 1, + num_eval_seeds: int = 5, +) -> None: + """Tune hyperparameters of imitation algorithms using parallel script. + + Args: + parallel_run_config: Dictionary of arguments to pass to the parallel script. + eval_best_trial_resource_multiplier: Factor by which to multiply the + number of cpus per trial in `resources_per_trial`. This is useful for + allocating more resources per trial to the evaluation trials than the + resources for hyperparameter tuning since number of evaluation trials + is usually much smaller than the number of tuning trials. + num_eval_seeds: Number of distinct seeds to evaluate the best trial on. + Set to 0 to disable evaluation. + + Raises: + ValueError: If no trials are returned by the parallel run of tuning. + """ + updated_parallel_run_config = copy.deepcopy(parallel_run_config) + search_alg = optuna.OptunaSearch() + if "tune_run_kwargs" in updated_parallel_run_config: + updated_parallel_run_config["tune_run_kwargs"]["search_alg"] = search_alg + else: + updated_parallel_run_config["tune_run_kwargs"] = dict(search_alg=search_alg) + run = parallel_ex.run(config_updates=updated_parallel_run_config) + experiment_analysis = run.result + if not experiment_analysis.trials: + raise ValueError( + "No trials found. 
Please ensure that the `experiment_checkpoint_path` " + "in `parallel_run_config` is passed correctly " + "or that the tuning run finished properly.", + ) + + return_key = "imit_stats/monitor_return_mean" + if updated_parallel_run_config["sacred_ex_name"] == "train_rl": + return_key = "monitor_return_mean" + best_trial = find_best_trial(experiment_analysis, return_key, print_return=True) + + if num_eval_seeds > 0: # evaluate the best trial + resources_per_trial_eval = copy.deepcopy( + updated_parallel_run_config["resources_per_trial"], + ) + # update cpus per trial only if it is provided in `resources_per_trial` + # Uses the default values (cpu=1) if it is not provided + if "cpu" in updated_parallel_run_config["resources_per_trial"]: + resources_per_trial_eval["cpu"] *= eval_best_trial_resource_multiplier + evaluate_trial( + best_trial, + num_eval_seeds, + updated_parallel_run_config["run_name"] + "_best_hp_eval", + updated_parallel_run_config, + resources_per_trial_eval, + return_key, + ) + + +def find_best_trial( + experiment_analysis: ray.tune.analysis.ExperimentAnalysis, + return_key: str, + print_return: bool = False, +) -> ray.tune.experiment.Trial: + """Find the trial with the best mean return across all seeds. + + Args: + experiment_analysis: The result of a parallel/tuning experiment. + return_key: The key of the return metric in the results dataframe. + print_return: Whether to print the mean and std of the returns + of the best trial. + + Returns: + best_trial: The trial with the best mean return across all seeds. 
+ """ + df = experiment_analysis.results_df + # convert object dtype to str required by df.groupby + for col in df.columns: + if pd_types.is_object_dtype(df[col]): + df[col] = df[col].astype("str") + # group into separate HP configs + grp_keys = [c for c in df.columns if c.startswith("config") and "seed" not in c] + grps = df.groupby(grp_keys) + # store mean return of runs across all seeds in a group + df["mean_return"] = grps[return_key].transform(lambda x: x.mean()) + best_config_df = df[df["mean_return"] == df["mean_return"].max()] + row = best_config_df.iloc[0] + best_config_tag = row["experiment_tag"] + assert experiment_analysis.trials is not None # for mypy + best_trial = [ + t for t in experiment_analysis.trials if best_config_tag in t.experiment_tag + ][0] + + if print_return: + all_returns = df[df["mean_return"] == row["mean_return"]][return_key] + all_returns = all_returns.to_numpy() + print("All returns:", all_returns) + print("Mean return:", row["mean_return"]) + print("Std return:", np.std(all_returns)) + print("Total seeds:", len(all_returns)) + return best_trial + + +def evaluate_trial( + trial: ray.tune.experiment.Trial, + num_eval_seeds: int, + run_name: str, + parallel_run_config: Dict[str, Any], + resources_per_trial: Dict[str, int], + return_key: str, + print_return: bool = False, +): + """Evaluate a given trial of a parallel run on a separate set of seeds. + + Args: + trial: The trial to evaluate. + num_eval_seeds: Number of distinct seeds to evaluate the best trial on. + run_name: The name of the evaluation run. + parallel_run_config: Dictionary of arguments passed to the parallel + script to get best_trial. + resources_per_trial: Resources to be used for each evaluation trial. + return_key: The key of the return metric in the results dataframe. + print_return: Whether to print the mean and std of the evaluation returns. + + Returns: + eval_run: The result of the evaluation run. 
+ """ + config = trial.config + config["config_updates"].update( + seed=ray.tune.grid_search(list(range(100, 100 + num_eval_seeds))), + ) + eval_config_updates = parallel_run_config.copy() + eval_config_updates.update( + run_name=run_name, + num_samples=1, + search_space=config, + resources_per_trial=resources_per_trial, + search_alg=None, + repeat=1, + experiment_checkpoint_path="", + ) + eval_run = parallel_ex.run(config_updates=eval_config_updates) + eval_result = eval_run.result + returns = eval_result.results_df[return_key].to_numpy() + if print_return: + print("All returns:", returns) + print("Mean:", np.mean(returns)) + print("Std:", np.std(returns)) + return eval_run + + +def main_console(): + observer_path = pathlib.Path.cwd() / "output" / "sacred" / "tuning" + observer = FileStorageObserver(observer_path) + tuning_ex.observers.append(observer) + tuning_ex.run_commandline() + + +if __name__ == "__main__": # pragma: no cover + main_console() diff --git a/benchmarking/tuning_config.py b/benchmarking/tuning_config.py new file mode 100644 index 000000000..239537406 --- /dev/null +++ b/benchmarking/tuning_config.py @@ -0,0 +1,232 @@ +"""Config files for tuning experiments.""" + +import ray.tune as tune +import sacred +from torch import nn + +from imitation.algorithms import dagger as dagger_alg +from imitation.scripts.parallel import parallel_ex + +tuning_ex = sacred.Experiment("tuning", ingredients=[parallel_ex]) + + +@tuning_ex.named_config +def rl(): + parallel_run_config = dict( + sacred_ex_name="train_rl", + run_name="rl_tuning", + base_named_configs=["logging.wandb_logging"], + base_config_updates={"environment": {"num_vec": 1}}, + search_space={ + "config_updates": { + "rl": { + "batch_size": tune.choice([512, 1024, 2048, 4096, 8192]), + "rl_kwargs": { + "learning_rate": tune.loguniform(1e-5, 1e-2), + "batch_size": tune.choice([64, 128, 256, 512]), + "n_epochs": tune.choice([5, 10, 20]), + }, + }, + }, + }, + num_samples=100, + repeat=1, + 
resources_per_trial=dict(cpu=1), + ) + num_eval_seeds = 5 + + +@tuning_ex.named_config +def bc(): + parallel_run_config = dict( + sacred_ex_name="train_imitation", + run_name="bc_tuning", + base_named_configs=["logging.wandb_logging"], + base_config_updates={ + "environment": {"num_vec": 1}, + "demonstrations": {"source": "huggingface"}, + }, + search_space={ + "config_updates": { + "bc": dict( + batch_size=tune.choice([8, 16, 32, 64]), + l2_weight=tune.loguniform(1e-6, 1e-2), # L2 regularization weight + optimizer_kwargs=dict( + lr=tune.loguniform(1e-5, 1e-2), + ), + train_kwargs=dict( + n_epochs=tune.choice([1, 5, 10, 20]), + ), + ), + }, + "command_name": "bc", + }, + num_samples=64, + repeat=3, + resources_per_trial=dict(cpu=1), + ) + + num_eval_seeds = 5 + eval_best_trial_resource_multiplier = 1 + + +@tuning_ex.named_config +def dagger(): + parallel_run_config = dict( + sacred_ex_name="train_imitation", + run_name="dagger_tuning", + base_named_configs=["logging.wandb_logging"], + base_config_updates={ + "environment": {"num_vec": 1}, + "demonstrations": {"source": "huggingface"}, + "dagger": {"total_timesteps": 1e5}, + "bc": { + "batch_size": 16, + "l2_weight": 1e-4, + "optimizer_kwargs": {"lr": 1e-3}, + }, + }, + search_space={ + "config_updates": { + "bc": dict( + train_kwargs=dict( + n_epochs=tune.choice([1, 5, 10]), + ), + ), + "dagger": dict( + beta_schedule=tune.choice( + [dagger_alg.LinearBetaSchedule(i) for i in [1, 5, 15]] + + [ + dagger_alg.ExponentialBetaSchedule(i) + for i in [0.3, 0.5, 0.7] + ], + ), + rollout_round_min_episodes=tune.choice([3, 5, 10]), + ), + }, + "command_name": "dagger", + }, + num_samples=50, + repeat=3, + resources_per_trial=dict(cpu=1), + ) + num_eval_seeds = 5 + + +@tuning_ex.named_config +def gail(): + parallel_run_config = dict( + sacred_ex_name="train_adversarial", + run_name="gail_tuning_hc", + base_named_configs=["logging.wandb_logging"], + base_config_updates={ + "environment": {"num_vec": 1}, + "demonstrations": 
{"source": "huggingface"}, + "total_timesteps": 1e7, + }, + search_space={ + "config_updates": { + "algorithm_kwargs": dict( + demo_batch_size=tune.choice([32, 128, 512, 2048, 8192]), + n_disc_updates_per_round=tune.choice([8, 16]), + ), + "rl": { + "batch_size": tune.choice([4096, 8192, 16384]), + "rl_kwargs": { + "ent_coef": tune.loguniform(1e-7, 1e-3), + "learning_rate": tune.loguniform(1e-5, 1e-2), + }, + }, + "algorithm_specific": {}, + }, + "command_name": "gail", + }, + num_samples=100, + repeat=3, + resources_per_trial=dict(cpu=1), + ) + num_eval_seeds = 5 + + +@tuning_ex.named_config +def airl(): + parallel_run_config = dict( + sacred_ex_name="train_adversarial", + run_name="airl_tuning", + base_named_configs=["logging.wandb_logging"], + base_config_updates={ + "environment": {"num_vec": 1}, + "demonstrations": {"source": "huggingface"}, + "total_timesteps": 1e7, + }, + search_space={ + "config_updates": { + "algorithm_kwargs": dict( + demo_batch_size=tune.choice([32, 128, 512, 2048, 8192]), + n_disc_updates_per_round=tune.choice([8, 16]), + ), + "rl": { + "batch_size": tune.choice([4096, 8192, 16384]), + "rl_kwargs": { + "ent_coef": tune.loguniform(1e-7, 1e-3), + "learning_rate": tune.loguniform(1e-5, 1e-2), + }, + }, + "algorithm_specific": {}, + }, + "command_name": "airl", + }, + num_samples=100, + repeat=3, + resources_per_trial=dict(cpu=1), + ) + num_eval_seeds = 5 + + +@tuning_ex.named_config +def pc(): + parallel_run_config = dict( + sacred_ex_name="train_preference_comparisons", + run_name="pc_tuning", + base_named_configs=["logging.wandb_logging"], + base_config_updates={ + "environment": {"num_vec": 1}, + "demonstrations": {"source": "huggingface"}, + "total_timesteps": 2e7, + "total_comparisons": 5000, + "query_schedule": "hyperbolic", + "gatherer_kwargs": {"sample": True}, + }, + search_space={ + "named_configs": [ + ["reward.normalize_output_disable"], + ], + "config_updates": { + "train": { + "policy_kwargs": { + "activation_fn": 
tune.choice( + [ + nn.ReLU, + ], + ), + }, + }, + "num_iterations": tune.choice([25, 50]), + "initial_comparison_frac": tune.choice([0.1, 0.25]), + "reward_trainer_kwargs": { + "epochs": tune.choice([1, 3, 6]), + }, + "rl": { + "batch_size": tune.choice([512, 2048, 8192]), + "rl_kwargs": { + "learning_rate": tune.loguniform(1e-5, 1e-2), + "ent_coef": tune.loguniform(1e-7, 1e-3), + }, + }, + }, + }, + num_samples=100, + repeat=3, + resources_per_trial=dict(cpu=1), + ) + num_eval_seeds = 5 diff --git a/benchmarking/util.py b/benchmarking/util.py index 408f0d812..88416344d 100644 --- a/benchmarking/util.py +++ b/benchmarking/util.py @@ -79,7 +79,7 @@ def clean_config_file(file: pathlib.Path, write_path: pathlib.Path, /) -> None: remove_empty_dicts(config) # files are of the format - # /path/to/file/example___best_hp_eval//sacred/1/config.json + # /path/to/file/__best_hp_eval//sacred/1/config.json # we want to write to //_.json with open(write_path / f"{file.parents[3].name}.json", "w") as f: json.dump(config, f, indent=4) diff --git a/experiments/commands.py b/experiments/commands.py index 2ac737e06..738a55011 100644 --- a/experiments/commands.py +++ b/experiments/commands.py @@ -22,13 +22,13 @@ python -m imitation.scripts.train_adversarial airl \ --capture=sys --name=run0 \ --file_storage=output/sacred/$USER-cmd-run0-airl-0-a3531726 \ - with ../benchmarking/example_airl_seals_walker_best_hp_eval.json \ + with ../benchmarking/airl_seals_walker_best_hp_eval.json \ seed=0 logging.log_root=output python -m imitation.scripts.train_adversarial gail \ --capture=sys --name=run0 \ --file_storage=output/sacred/$USER-cmd-run0-gail-0-a1ec171b \ - with ../benchmarking/example_gail_seals_walker_best_hp_eval.json \ + with ../benchmarking/gail_seals_walker_best_hp_eval.json \ seed=0 logging.log_root=output We can execute commands in parallel by piping them to GNU parallel: @@ -42,7 +42,7 @@ python commands.py \ --name=run0 \ - 
--cfg_pattern=../benchmarking/example_bc_seals_half_cheetah_best_hp_eval.json \ + --cfg_pattern=../benchmarking/bc_seals_half_cheetah_best_hp_eval.json \ --output_dir=/data/output \ --remote @@ -52,7 +52,7 @@ --command "python -m imitation.scripts.train_imitation bc \ --capture=sys --name=run0 \ --file_storage=/data/output/sacred/$USER-cmd-run0-bc-0-72cb1df3 \ - with /data/imitation/benchmarking/example_bc_seals_half_cheetah_best_hp_eval.json \ + with /data/imitation/benchmarking/bc_seals_half_cheetah_best_hp_eval.json \ seed=0 logging.log_root=/data/output" \ --container hacobe/devbox:imitation \ --login --force-pull --never-restart --gpu 0 --shared-host-dir-mount /data @@ -85,7 +85,7 @@ def _get_algo_name(cfg_file: str) -> str: """Get the algorithm name from the given config filename.""" algo_names = set() for key in _ALGO_NAME_TO_SCRIPT_NAME: - if cfg_file.find("_" + key + "_") != -1: + if cfg_file.find(key + "_") != -1: algo_names.add(key) if len(algo_names) == 0: @@ -177,19 +177,19 @@ def parse() -> argparse.Namespace: parser.add_argument( "--cfg_pattern", type=str, - default="example_bc_seals_half_cheetah_best_hp_eval.json", + default="bc_seals_half_cheetah_best_hp_eval.json", help="""Generate a command for every file that matches this glob pattern. \ Each matching file should be a config file that has its algorithm name \ (bc, dagger, airl or gail) bookended by underscores in the filename. \ If the --remote flag is enabled, then generate a command for every file in the \ --remote_cfg_dir directory that has the same filename as a file that matches this \ glob pattern. E.g., suppose the current, local working directory is 'foo' and \ -the subdirectory 'foo/bar' contains the config files 'example_bc_best.json' and \ -'example_dagger_best.json'. If the pattern 'bar/*.json' is supplied, then globbing \ -will return ['bar/example_bc_best.json', 'bar/example_dagger_best.json']. 
\ +the subdirectory 'foo/bar' contains the config files 'bc_best.json' and \ +'dagger_best.json'. If the pattern 'bar/*.json' is supplied, then globbing \ +will return ['bar/bc_best.json', 'bar/dagger_best.json']. \ If the --remote flag is enabled, 'bar' will be replaced with `remote_cfg_dir` and \ commands will be created for the following configs: \ -[`remote_cfg_dir`/example_bc_best.json, `remote_cfg_dir`/example_dagger_best.json] \ +[`remote_cfg_dir`/bc_best.json, `remote_cfg_dir`/dagger_best.json] \ Why not just supply the pattern '`remote_cfg_dir`/*.json' directly? \ Because the `remote_cfg_dir` directory may not exist on the local machine.""", ) diff --git a/setup.cfg b/setup.cfg index 2fa805d49..95f2223d9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -7,6 +7,7 @@ per-file-ignores = # F841 local variable unused [for Sacred config scopes] src/imitation/scripts/config/*.py:F841 ../src/imitation/scripts/config/*.py:F841 + benchmarking/tuning_config.py:F841 src/imitation/envs/examples/airl_envs/*.py:D [darglint] @@ -41,6 +42,8 @@ source = imitation include= src/* tests/* +omit = + src/imitation/scripts/config/* [coverage:report] exclude_lines = diff --git a/setup.py b/setup.py index 5fc3354ad..0384014ee 100644 --- a/setup.py +++ b/setup.py @@ -187,7 +187,7 @@ def get_local_version(version: "ScmVersion", time_format="%Y%m%d") -> str: # encode only known incompatibilities here. This prevents nasty dependency issues # for our users. 
install_requires=[ - "gymnasium[classic-control]~=0.28.1", + "gymnasium[classic-control]~=0.29", "matplotlib", "numpy>=1.15", "torch>=1.4.0", @@ -199,6 +199,7 @@ def get_local_version(version: "ScmVersion", time_format="%Y%m%d") -> str: "sacred>=0.8.4", "tensorboard>=1.14", "huggingface_sb3~=3.0", + "optuna>=3.0.1", "datasets>=2.8.0", ], tests_require=TESTS_REQUIRE, @@ -219,7 +220,7 @@ def get_local_version(version: "ScmVersion", time_format="%Y%m%d") -> str: "docs": DOCS_REQUIRE, "parallel": PARALLEL_REQUIRE, "mujoco": [ - "gymnasium[classic-control,mujoco]~=0.28.1", + "gymnasium[classic-control,mujoco]~=0.29", ], "atari": ATARI_REQUIRE, }, diff --git a/src/imitation/algorithms/adversarial/common.py b/src/imitation/algorithms/adversarial/common.py index ece30b011..c9e880c07 100644 --- a/src/imitation/algorithms/adversarial/common.py +++ b/src/imitation/algorithms/adversarial/common.py @@ -2,7 +2,7 @@ import abc import dataclasses import logging -from typing import Callable, Iterable, Iterator, Mapping, Optional, Type, overload +from typing import Callable, Iterable, Iterator, List, Mapping, Optional, Type, overload import numpy as np import torch as th @@ -10,7 +10,9 @@ import tqdm from stable_baselines3.common import ( base_class, + callbacks, distributions, + off_policy_algorithm, on_policy_algorithm, policies, vec_env, @@ -20,6 +22,7 @@ from imitation.algorithms import base from imitation.data import buffer, rollout, types, wrappers +from imitation.policies import replay_buffer_wrapper from imitation.rewards import reward_nets, reward_wrapper from imitation.util import logger, networks, util @@ -92,6 +95,55 @@ def compute_train_stats( } +class TrainDiscriminatorCallback(callbacks.BaseCallback): + """Callback for training discriminator after collecting rollouts.""" + + def __init__(self, adversarial_trainer, *args, **kwargs): + """Builds TrainDiscriminatorCallback. 
+ + Args: + adversarial_trainer: The AdversarialTrainer instance in which + this callback will be called. + *args: Passed through to `callbacks.BaseCallback`. + **kwargs: Passed through to `callbacks.BaseCallback`. + """ + self.adversarial_trainer = adversarial_trainer + self.gen_ctx_manager = None + super().__init__(*args, **kwargs) + + def _on_step(self) -> bool: + return True + + def _on_rollout_end(self) -> None: + if self.gen_ctx_manager is not None: + self.exit_gen_ctx_manager() + gen_trajs, ep_lens = self.adversarial_trainer.venv_buffering.pop_trajectories() + self.adversarial_trainer._check_fixed_horizon(ep_lens) + gen_samples = rollout.flatten_trajectories_with_rew(gen_trajs) + self.adversarial_trainer._gen_replay_buffer.store(gen_samples) + + for _ in range(self.adversarial_trainer.n_disc_updates_per_round): + with networks.training(self.adversarial_trainer.reward_train): + # switch to training mode (affects dropout, normalization) + self.adversarial_trainer.train_disc() + + # update the rollouts with the reward of the latest discriminator + self.adversarial_trainer.update_rewards_of_rollouts() + + # This is a hacky way to enable logger.accumulate_means for generator + # This is done to avoid nested loggers of discriminator and generator + self.gen_ctx_manager = self.adversarial_trainer.logger.accumulate_means("gen") + self.gen_ctx_manager.__enter__() + + def exit_gen_ctx_manager(self) -> None: + assert self.gen_ctx_manager is not None + self.gen_ctx_manager.__exit__(None, None, None) + self.gen_ctx_manager = None + + def _on_training_end(self) -> None: + self.exit_gen_ctx_manager() + + class AdversarialTrainer(base.DemonstrationAlgorithm[types.Transitions]): """Base class for adversarial imitation learning algorithms like GAIL and AIRL.""" @@ -228,16 +280,22 @@ def __init__( self.venv_buffering = wrappers.BufferingWrapper(self.venv) + self.disc_trainer_callback = TrainDiscriminatorCallback(self) if debug_use_ground_truth: # Would use an identity reward 
fn here, but RewardFns can't see rewards. self.venv_wrapped = self.venv_buffering - self.gen_callback = None + self.gen_callback: List[callbacks.BaseCallback] = [ + self.disc_trainer_callback, + ] else: self.venv_wrapped = reward_wrapper.RewardVecEnvWrapper( self.venv_buffering, reward_fn=self.reward_train.predict_processed, ) - self.gen_callback = self.venv_wrapped.make_log_callback() + self.gen_callback = [ + self.venv_wrapped.make_log_callback(), + self.disc_trainer_callback, + ] self.venv_train = self.venv_wrapped self.gen_algo.set_env(self.venv_train) @@ -314,6 +372,35 @@ def _next_expert_batch(self) -> Mapping: assert self._endless_expert_iterator is not None return next(self._endless_expert_iterator) + def update_rewards_of_rollouts(self) -> None: + """Updates the rewards of the rollouts using the latest discriminator.""" + if isinstance(self.gen_algo, on_policy_algorithm.OnPolicyAlgorithm): + buffer = self.gen_algo.rollout_buffer + assert buffer is not None + reward_fn_inputs = replay_buffer_wrapper._rollout_buffer_to_reward_fn_input( + self.gen_algo.rollout_buffer, + ) + rewards = self._reward_net.predict(**reward_fn_inputs) + rewards = rewards.reshape(buffer.rewards.shape) + last_values = buffer.advantages[-1] - buffer.rewards[-1] + buffer.values[-1] + last_values = last_values / buffer.gamma + # here we assume that the actual last_values cannot exactly be 0.0 and so if + # last_values is 0.0 then we know that the episode terminated + last_dones = last_values == 0.0 + self.gen_algo.rollout_buffer.rewards[:] = rewards + self.gen_algo.rollout_buffer.compute_returns_and_advantage( + th.tensor(last_values), + last_dones, + ) + elif isinstance(self.gen_algo, off_policy_algorithm.OffPolicyAlgorithm): + buffer = self.gen_algo.replay_buffer + assert buffer is not None + reward_fn_inputs = replay_buffer_wrapper._replay_buffer_to_reward_fn_input( + buffer, + ) + rewards = self._reward_net.predict(**reward_fn_inputs) + buffer.rewards[:] = 
rewards.reshape(buffer.rewards.shape) + def train_disc( self, *, @@ -388,13 +475,15 @@ def train_disc( return train_stats - def train_gen( + def train_gen_with_disc( self, total_timesteps: Optional[int] = None, learn_kwargs: Optional[Mapping] = None, ) -> None: """Trains the generator to maximize the discriminator loss. + The discriminator is also trained after the rollouts are collected and before + the generator is trained. After the end of training populates the generator replay buffer (used in discriminator training) with `self.disc_batch_size` transitions. @@ -410,19 +499,13 @@ def train_gen( if learn_kwargs is None: learn_kwargs = {} - with self.logger.accumulate_means("gen"): - self.gen_algo.learn( - total_timesteps=total_timesteps, - reset_num_timesteps=False, - callback=self.gen_callback, - **learn_kwargs, - ) - self._global_step += 1 - - gen_trajs, ep_lens = self.venv_buffering.pop_trajectories() - self._check_fixed_horizon(ep_lens) - gen_samples = rollout.flatten_trajectories_with_rew(gen_trajs) - self._gen_replay_buffer.store(gen_samples) + self.gen_algo.learn( + total_timesteps=total_timesteps, + reset_num_timesteps=False, + callback=self.gen_callback, + **learn_kwargs, + ) + self._global_step += 1 def train( self, @@ -431,8 +514,8 @@ def train( ) -> None: """Alternates between training the generator and discriminator. - Every "round" consists of a call to `train_gen(self.gen_train_timesteps)`, - a call to `train_disc`, and finally a call to `callback(round)`. + Every "round" consists of a call to + `train_gen_with_disc(self.gen_train_timesteps)` and a call to `callback(round)`. Training ends once an additional "round" would cause the number of transitions sampled from the environment to exceed `total_timesteps`. @@ -451,11 +534,7 @@ def train( f"total_timesteps={total_timesteps})!" 
) for r in tqdm.tqdm(range(0, n_rounds), desc="round"): - self.train_gen(self.gen_train_timesteps) - for _ in range(self.n_disc_updates_per_round): - with networks.training(self.reward_train): - # switch to training mode (affects dropout, normalization) - self.train_disc() + self.train_gen_with_disc(self.gen_train_timesteps) if callback: callback(r) self.logger.dump(self._global_step) @@ -547,7 +626,8 @@ def _make_disc_train_batches( if gen_samples is None: if self._gen_replay_buffer.size() == 0: raise RuntimeError( - "No generator samples for training. " "Call `train_gen()` first.", + "No generator samples for training. " + "Call `train_gen_with_disc()` first.", ) gen_samples_dataclass = self._gen_replay_buffer.sample(batch_size) gen_samples = types.dataclass_quick_asdict(gen_samples_dataclass) diff --git a/src/imitation/policies/replay_buffer_wrapper.py b/src/imitation/policies/replay_buffer_wrapper.py index 7177e2dc1..a8649f78f 100644 --- a/src/imitation/policies/replay_buffer_wrapper.py +++ b/src/imitation/policies/replay_buffer_wrapper.py @@ -4,7 +4,7 @@ import numpy as np from gymnasium import spaces -from stable_baselines3.common.buffers import ReplayBuffer +from stable_baselines3.common.buffers import ReplayBuffer, RolloutBuffer from stable_baselines3.common.type_aliases import ReplayBufferSamples from imitation.rewards.reward_function import RewardFn @@ -23,6 +23,48 @@ def _samples_to_reward_fn_input( ) +def _rollout_buffer_to_reward_fn_input( + buffer: RolloutBuffer, +) -> Mapping[str, np.ndarray]: + """Convert a sample from a rollout buffer to a numpy array.""" + assert buffer.observations is not None + assert buffer.actions is not None + obs = buffer.observations + next_obs = obs[1:] + next_obs = np.concatenate([next_obs, obs[-1:]], axis=0) # last obs not available + actions = buffer.actions + dones = buffer.episode_starts + dones = np.roll(dones, -1, axis=0) + dones[-1] = np.ones_like(dones[-1]) # last dones not available + + return dict( + 
state=obs.reshape(-1, *obs.shape[2:]), + action=actions.reshape(-1, *actions.shape[2:]), + next_state=next_obs.reshape(-1, *next_obs.shape[2:]), + done=dones.reshape(-1), + ) + + +def _replay_buffer_to_reward_fn_input( + buffer: ReplayBuffer, +) -> Mapping[str, np.ndarray]: + """Convert a sample from a replay buffer to a numpy array.""" + assert buffer.observations is not None + assert buffer.next_observations is not None + assert buffer.actions is not None + obs = buffer.observations + next_obs = buffer.next_observations + actions = buffer.actions + dones = buffer.dones + + return dict( + state=obs.reshape(-1, *obs.shape[2:]), + action=actions.reshape(-1, *actions.shape[2:]), + next_state=next_obs.reshape(-1, *next_obs.shape[2:]), + done=dones.reshape(-1), + ) + + class ReplayBufferRewardWrapper(ReplayBuffer): """Relabel the rewards in transitions sampled from a ReplayBuffer.""" diff --git a/src/imitation/scripts/analyze.py b/src/imitation/scripts/analyze.py index df8ad6b79..96b34bd6e 100644 --- a/src/imitation/scripts/analyze.py +++ b/src/imitation/scripts/analyze.py @@ -262,38 +262,50 @@ def analyze_imitation( csv_output_path: If provided, then save a CSV output file to this path. tex_output_path: If provided, then save a LaTeX-format table to this path. print_table: If True, then print the dataframe to stdout. - table_verbosity: Increasing levels of verbosity, from 0 to 2, increase the - number of columns in the table. + table_verbosity: Increasing levels of verbosity, from 0 to 3, increase the + number of columns in the table. Level 3 prints all of the columns available. Returns: The DataFrame generated from the Sacred logs. """ - table_entry_fns_subset = _get_table_entry_fns_subset(table_verbosity) + if table_verbosity == 3: + # Get column names for which we have get value using make_entry_fn + # These are same across Level 2 & 3. In Level 3, we additionally add remaining + # config columns. 
+ table_entry_fns_subset = _get_table_entry_fns_subset(2) + else: + table_entry_fns_subset = _get_table_entry_fns_subset(table_verbosity) - rows = [] + output_table = pd.DataFrame() for sd in _gather_sacred_dicts(): - row = {} + if table_verbosity == 3: + # gets all config columns + row = pd.json_normalize(sd.config) + else: + # create an empty dataframe with a single row + row = pd.DataFrame(index=[0]) + for col_name, make_entry_fn in table_entry_fns_subset.items(): row[col_name] = make_entry_fn(sd) - rows.append(row) - df = pd.DataFrame(rows) - if len(df) > 0: - df.sort_values(by=["algo", "env_name"], inplace=True) + output_table = pd.concat([output_table, row]) + + if len(output_table) > 0: + output_table.sort_values(by=["algo", "env_name"], inplace=True) - display_options = dict(index=False) + display_options: Mapping[str, Any] = dict(index=False) if csv_output_path is not None: - df.to_csv(csv_output_path, **display_options) + output_table.to_csv(csv_output_path, **display_options) print(f"Wrote CSV file to {csv_output_path}") if tex_output_path is not None: - s: str = df.to_latex(**display_options) + s: str = output_table.to_latex(**display_options) with open(tex_output_path, "w") as f: f.write(s) print(f"Wrote TeX file to {tex_output_path}") if print_table: - print(df.to_string(**display_options)) - return df + print(output_table.to_string(**display_options)) + return output_table def _make_return_summary(stats: dict, prefix="") -> str: diff --git a/src/imitation/scripts/config/analyze.py b/src/imitation/scripts/config/analyze.py index 5213a875d..01cc2d035 100644 --- a/src/imitation/scripts/config/analyze.py +++ b/src/imitation/scripts/config/analyze.py @@ -18,7 +18,7 @@ def config(): tex_output_path = None # Write LaTex output to this path print_table = True # Set to True to print analysis to stdout split_str = "," # str used to split source_dir_str into multiple source dirs - table_verbosity = 1 # Choose from 0, 1, or 2 + table_verbosity = 1 # Choose from 
0, 1, 2 or 3 source_dirs = None diff --git a/src/imitation/scripts/config/parallel.py b/src/imitation/scripts/config/parallel.py index 8ea76f522..c9c898feb 100644 --- a/src/imitation/scripts/config/parallel.py +++ b/src/imitation/scripts/config/parallel.py @@ -5,7 +5,10 @@ `@parallel_ex.named_config` to define a new parallel experiment. Adding custom named configs is necessary because the CLI interface can't add -search spaces to the config like `"seed": tune.grid_search([0, 1, 2, 3])`. +search spaces to the config like `"seed": tune.choice([0, 1, 2, 3])`. + +For tuning hyperparameters of an algorithm on a given environment, +check out the benchmarking/tuning.py script. """ import numpy as np @@ -31,19 +34,10 @@ def config(): "config_updates": {}, } # `config` argument to `ray.tune.run(trainable, config)` - local_dir = None # `local_dir` arg for `ray.tune.run` - upload_dir = None # `upload_dir` arg for `ray.tune.run` - n_seeds = 3 # Number of seeds to search over by default - - -@parallel_ex.config -def seeds(n_seeds): - search_space = {"config_updates": {"seed": tune.grid_search(list(range(n_seeds)))}} - - -@parallel_ex.named_config -def s3(): - upload_dir = "s3://shwang-chai/private" + num_samples = 1 # Number of samples per grid search configuration + repeat = 1 # Number of times to repeat a sampled configuration + experiment_checkpoint_path = "" # Path to checkpoint of experiment + tune_run_kwargs = {} # Additional kwargs to pass to `tune.run` # Debug named configs @@ -58,12 +52,12 @@ def generate_test_data(): """ sacred_ex_name = "train_rl" run_name = "TEST" - n_seeds = 1 + repeat = 1 search_space = { "config_updates": { "rl": { "rl_kwargs": { - "learning_rate": tune.grid_search( + "learning_rate": tune.choice( [3e-4 * x for x in (1 / 3, 1 / 2)], ), }, @@ -86,63 +80,16 @@ def generate_test_data(): def example_cartpole_rl(): sacred_ex_name = "train_rl" run_name = "example-cartpole" - n_seeds = 2 + repeat = 2 search_space = { "config_updates": { "rl": { 
"rl_kwargs": { - "learning_rate": tune.grid_search(np.logspace(3e-6, 1e-1, num=3)), - "nminibatches": tune.grid_search([16, 32, 64]), + "learning_rate": tune.choice(np.logspace(3e-6, 1e-1, num=3)), + "nminibatches": tune.choice([16, 32, 64]), }, }, }, } base_named_configs = ["cartpole"] resources_per_trial = dict(cpu=4) - - -EASY_ENVS = ["cartpole", "pendulum", "mountain_car"] - - -@parallel_ex.named_config -def example_rl_easy(): - sacred_ex_name = "train_rl" - run_name = "example-rl-easy" - n_seeds = 2 - search_space = { - "named_configs": tune.grid_search([[env] for env in EASY_ENVS]), - "config_updates": { - "rl": { - "rl_kwargs": { - "learning_rate": tune.grid_search(np.logspace(3e-6, 1e-1, num=3)), - "nminibatches": tune.grid_search([16, 32, 64]), - }, - }, - }, - } - resources_per_trial = dict(cpu=4) - - -@parallel_ex.named_config -def example_gail_easy(): - sacred_ex_name = "train_adversarial" - run_name = "example-gail-easy" - n_seeds = 1 - search_space = { - "named_configs": tune.grid_search([[env] for env in EASY_ENVS]), - "config_updates": { - "init_trainer_kwargs": { - "rl": { - "rl_kwargs": { - "learning_rate": tune.grid_search( - np.logspace(3e-6, 1e-1, num=3), - ), - "nminibatches": tune.grid_search([16, 32, 64]), - }, - }, - }, - }, - } - search_space = { - "command_name": "gail", - } diff --git a/src/imitation/scripts/config/train_adversarial.py b/src/imitation/scripts/config/train_adversarial.py index 55e6effec..acc842095 100644 --- a/src/imitation/scripts/config/train_adversarial.py +++ b/src/imitation/scripts/config/train_adversarial.py @@ -1,12 +1,17 @@ """Configuration for imitation.scripts.train_adversarial.""" import sacred +from torch import nn from imitation.rewards import reward_nets from imitation.scripts.ingredients import demonstrations, environment, expert from imitation.scripts.ingredients import logging as logging_ingredient from imitation.scripts.ingredients import policy_evaluation, reward, rl +# Note: All the hyperparameter 
configs in the file are of the tuned +# hyperparameters of the RL algorithm of the respective environment. +# Taken from imitation/scripts/config/train_rl.py + train_adversarial_ex = sacred.Experiment( "train_adversarial", ingredients=[ @@ -101,6 +106,22 @@ def seals_ant(): locals().update(**MUJOCO_SHARED_LOCALS) locals().update(**ANT_SHARED_LOCALS) environment = dict(gym_id="seals/Ant-v0") + demonstrations = dict(rollout_type="ppo-huggingface") + rl = dict( + batch_size=2048, + rl_kwargs=dict( + batch_size=16, + clip_range=0.3, + ent_coef=3.1441389214159857e-06, + gae_lambda=0.8, + gamma=0.995, + learning_rate=0.00017959211641976886, + max_grad_norm=0.9, + n_epochs=10, + # policy_kwargs are same as the defaults + vf_coef=0.4351450387648799, + ), + ) CHEETAH_SHARED_LOCALS = dict( @@ -139,40 +160,126 @@ def half_cheetah(): @train_adversarial_ex.named_config def seals_half_cheetah(): - locals().update(**CHEETAH_SHARED_LOCALS) environment = dict(gym_id="seals/HalfCheetah-v0") + demonstrations = dict(rollout_type="ppo-huggingface") + rl = dict( + batch_size=512, + rl_kwargs=dict( + batch_size=64, + clip_range=0.1, + ent_coef=3.794797423594763e-06, + gae_lambda=0.95, + gamma=0.95, + learning_rate=0.0003286871805949382, + max_grad_norm=0.8, + n_epochs=5, + vf_coef=0.11483689492120866, + ), + ) + algorithm_kwargs = dict( + # Number of discriminator updates after each round of generator updates + n_disc_updates_per_round=16, + # Equivalent to no replay buffer if batch size is the same + gen_replay_buffer_capacity=512, + demo_batch_size=8192, + ) @train_adversarial_ex.named_config def seals_hopper(): - locals().update(**MUJOCO_SHARED_LOCALS) environment = dict(gym_id="seals/Hopper-v0") + demonstrations = dict(rollout_type="ppo-huggingface") + policy = dict( + policy_cls="MlpPolicy", + policy_kwargs=dict( + activation_fn=nn.ReLU, + net_arch=[dict(pi=[64, 64], vf=[64, 64])], + ), + ) + rl = dict( + batch_size=2048, + rl_kwargs=dict( + batch_size=512, + clip_range=0.1, + 
ent_coef=0.0010159833764878474, + gae_lambda=0.98, + gamma=0.995, + learning_rate=0.0003904770450788824, + max_grad_norm=0.9, + n_epochs=20, + vf_coef=0.20315938606555833, + ), + ) @train_adversarial_ex.named_config -def seals_humanoid(): - locals().update(**MUJOCO_SHARED_LOCALS) - environment = dict(gym_id="seals/Humanoid-v0") - total_timesteps = int(4e6) +def seals_swimmer(): + environment = dict(gym_id="seals/Swimmer-v0") + total_timesteps = int(2e6) + demonstrations = dict(rollout_type="ppo-huggingface") + policy = dict( + policy_cls="MlpPolicy", + policy_kwargs=dict( + activation_fn=nn.ReLU, + net_arch=[dict(pi=[64, 64], vf=[64, 64])], + ), + ) + rl = dict( + batch_size=2048, + rl_kwargs=dict( + batch_size=64, + clip_range=0.1, + ent_coef=5.167107294612664e-08, + gae_lambda=0.95, + gamma=0.999, + learning_rate=0.000414936134792374, + max_grad_norm=2, + n_epochs=5, + # policy_kwargs are same as the defaults + vf_coef=0.6162112311062333, + ), + ) @train_adversarial_ex.named_config -def reacher(): - environment = dict(gym_id="Reacher-v2") - algorithm_kwargs = {"allow_variable_horizon": True} +def seals_walker(): + environment = dict(gym_id="seals/Walker2d-v0") + demonstrations = dict(rollout_type="ppo-huggingface") + policy = dict( + policy_cls="MlpPolicy", + policy_kwargs=dict( + activation_fn=nn.ReLU, + net_arch=[dict(pi=[64, 64], vf=[64, 64])], + ), + ) + rl = dict( + batch_size=8192, + rl_kwargs=dict( + batch_size=128, + clip_range=0.4, + ent_coef=0.00013057334805552262, + gae_lambda=0.92, + gamma=0.98, + learning_rate=0.000138575372312869, + max_grad_norm=0.6, + n_epochs=20, + # policy_kwargs are same as the defaults + vf_coef=0.6167177795726859, + ), + ) @train_adversarial_ex.named_config -def seals_swimmer(): +def seals_humanoid(): locals().update(**MUJOCO_SHARED_LOCALS) - environment = dict(gym_id="seals/Swimmer-v0") - total_timesteps = int(2e6) + environment = dict(gym_id="seals/Humanoid-v0") + total_timesteps = int(4e6) 
@train_adversarial_ex.named_config -def seals_walker(): - locals().update(**MUJOCO_SHARED_LOCALS) - environment = dict(gym_id="seals/Walker2d-v0") +def reacher(): + environment = dict(gym_id="Reacher-v2") + algorithm_kwargs = {"allow_variable_horizon": True} # Debug configs diff --git a/src/imitation/scripts/config/train_imitation.py b/src/imitation/scripts/config/train_imitation.py index 88bc4888c..4f3a8a415 100644 --- a/src/imitation/scripts/config/train_imitation.py +++ b/src/imitation/scripts/config/train_imitation.py @@ -70,6 +70,8 @@ def ant(): @train_imitation_ex.named_config def seals_ant(): environment = dict(gym_id="seals/Ant-v0") + demonstrations = dict(rollout_type="ppo-huggingface") + expert = {"policy_type": "ppo-huggingface"} @train_imitation_ex.named_config @@ -84,6 +86,29 @@ def seals_half_cheetah(): environment = dict(gym_id="seals/HalfCheetah-v0") bc = dict(l2_weight=0.0) dagger = dict(total_timesteps=60000) + demonstrations = dict(rollout_type="ppo-huggingface") + expert = {"policy_type": "ppo-huggingface"} + + +@train_imitation_ex.named_config +def seals_hopper(): + environment = dict(gym_id="seals/Hopper-v0") + demonstrations = dict(rollout_type="ppo-huggingface") + expert = {"policy_type": "ppo-huggingface"} + + +@train_imitation_ex.named_config +def seals_swimmer(): + environment = dict(gym_id="seals/Swimmer-v0") + demonstrations = dict(rollout_type="ppo-huggingface") + expert = {"policy_type": "ppo-huggingface"} + + +@train_imitation_ex.named_config +def seals_walker(): + environment = dict(gym_id="seals/Walker2d-v0") + demonstrations = dict(rollout_type="ppo-huggingface") + expert = {"policy_type": "ppo-huggingface"} @train_imitation_ex.named_config diff --git a/src/imitation/scripts/config/train_preference_comparisons.py b/src/imitation/scripts/config/train_preference_comparisons.py index 28890bf33..4d8531732 100644 --- a/src/imitation/scripts/config/train_preference_comparisons.py +++ 
b/src/imitation/scripts/config/train_preference_comparisons.py @@ -1,12 +1,17 @@ """Configuration for imitation.scripts.train_preference_comparisons.""" import sacred +from torch import nn from imitation.algorithms import preference_comparisons from imitation.scripts.ingredients import environment from imitation.scripts.ingredients import logging as logging_ingredient from imitation.scripts.ingredients import policy_evaluation, reward, rl +# Note: All the hyperparameter configs in the file are of the tuned +# hyperparameters of the RL algorithm of the respective environment. +# Taken from imitation/scripts/config/train_rl.py + train_preference_comparisons_ex = sacred.Experiment( "train_preference_comparisons", ingredients=[ @@ -72,9 +77,22 @@ def cartpole(): @train_preference_comparisons_ex.named_config def seals_ant(): - locals().update(**MUJOCO_SHARED_LOCALS) - locals().update(**ANT_SHARED_LOCALS) environment = dict(gym_id="seals/Ant-v0") + rl = dict( + batch_size=2048, + rl_kwargs=dict( + batch_size=16, + clip_range=0.3, + ent_coef=3.1441389214159857e-06, + gae_lambda=0.8, + gamma=0.995, + learning_rate=0.00017959211641976886, + max_grad_norm=0.9, + n_epochs=10, + # policy_kwargs are same as the defaults + vf_coef=0.4351450387648799, + ), + ) @train_preference_comparisons_ex.named_config @@ -84,10 +102,105 @@ def half_cheetah(): rl = dict(batch_size=16384, rl_kwargs=dict(batch_size=1024)) +@train_preference_comparisons_ex.named_config +def seals_half_cheetah(): + environment = dict(gym_id="seals/HalfCheetah-v0") + rl = dict( + batch_size=512, + rl_kwargs=dict( + batch_size=64, + clip_range=0.1, + ent_coef=3.794797423594763e-06, + gae_lambda=0.95, + gamma=0.95, + learning_rate=0.0003286871805949382, + max_grad_norm=0.8, + n_epochs=5, + vf_coef=0.11483689492120866, + ), + ) + num_iterations = 50 + total_timesteps = 20000000 + + @train_preference_comparisons_ex.named_config def seals_hopper(): - locals().update(**MUJOCO_SHARED_LOCALS) environment = 
dict(gym_id="seals/Hopper-v0") + policy = dict( + policy_cls="MlpPolicy", + policy_kwargs=dict( + activation_fn=nn.ReLU, + net_arch=[dict(pi=[64, 64], vf=[64, 64])], + ), + ) + rl = dict( + batch_size=2048, + rl_kwargs=dict( + batch_size=512, + clip_range=0.1, + ent_coef=0.0010159833764878474, + gae_lambda=0.98, + gamma=0.995, + learning_rate=0.0003904770450788824, + max_grad_norm=0.9, + n_epochs=20, + vf_coef=0.20315938606555833, + ), + ) + + +@train_preference_comparisons_ex.named_config +def seals_swimmer(): + environment = dict(gym_id="seals/Swimmer-v0") + policy = dict( + policy_cls="MlpPolicy", + policy_kwargs=dict( + activation_fn=nn.ReLU, + net_arch=[dict(pi=[64, 64], vf=[64, 64])], + ), + ) + rl = dict( + batch_size=2048, + rl_kwargs=dict( + batch_size=64, + clip_range=0.1, + ent_coef=5.167107294612664e-08, + gae_lambda=0.95, + gamma=0.999, + learning_rate=0.000414936134792374, + max_grad_norm=2, + n_epochs=5, + # policy_kwargs are same as the defaults + vf_coef=0.6162112311062333, + ), + ) + + +@train_preference_comparisons_ex.named_config +def seals_walker(): + environment = dict(gym_id="seals/Walker2d-v0") + policy = dict( + policy_cls="MlpPolicy", + policy_kwargs=dict( + activation_fn=nn.ReLU, + net_arch=[dict(pi=[64, 64], vf=[64, 64])], + ), + ) + rl = dict( + batch_size=8192, + rl_kwargs=dict( + batch_size=128, + clip_range=0.4, + ent_coef=0.00013057334805552262, + gae_lambda=0.92, + gamma=0.98, + learning_rate=0.000138575372312869, + max_grad_norm=0.6, + n_epochs=20, + # policy_kwargs are same as the defaults + vf_coef=0.6167177795726859, + ), + ) @train_preference_comparisons_ex.named_config diff --git a/src/imitation/scripts/config/train_rl.py b/src/imitation/scripts/config/train_rl.py index d24d9492d..e4ab71da1 100644 --- a/src/imitation/scripts/config/train_rl.py +++ b/src/imitation/scripts/config/train_rl.py @@ -1,11 +1,18 @@ """Configuration settings for train_rl, training a policy with RL.""" + import sacred +from torch import nn from 
imitation.scripts.ingredients import environment from imitation.scripts.ingredients import logging as logging_ingredient from imitation.scripts.ingredients import policy_evaluation, rl +# Note: All the hyperparameter configs in the file are tuned +# for the PPO algorithm on the respective environment using the +# RL Baselines Zoo library: +# https://github.com/HumanCompatibleAI/rl-baselines3-zoo/ + train_rl_ex = sacred.Experiment( "train_rl", ingredients=[ @@ -70,8 +77,30 @@ def cartpole(): @train_rl_ex.named_config def seals_cartpole(): - environment = dict(gym_id="seals/CartPole-v0") - total_timesteps = int(1e6) + environment = dict(gym_id="seals/CartPole-v0", num_vec=8) + total_timesteps = int(1e5) + policy = dict( + policy_cls="MlpPolicy", + policy_kwargs=dict( + activation_fn=nn.ReLU, + net_arch=[dict(pi=[64, 64], vf=[64, 64])], + ), + ) + normalize_reward = False + rl = dict( + batch_size=4096, + rl_kwargs=dict( + batch_size=256, + clip_range=0.4, + ent_coef=0.008508727919228772, + gae_lambda=0.9, + gamma=0.9999, + learning_rate=0.0012403278189645594, + max_grad_norm=0.8, + n_epochs=10, + vf_coef=0.489343896591493, + ), + ) @train_rl_ex.named_config @@ -80,9 +109,69 @@ def half_cheetah(): total_timesteps = int(5e6) # does OK after 1e6, but continues improving +@train_rl_ex.named_config +def seals_half_cheetah(): + environment = dict( + gym_id="seals/HalfCheetah-v0", + num_vec=1, + ) + + policy = dict( + policy_cls="MlpPolicy", + policy_kwargs=dict( + activation_fn=nn.Tanh, + net_arch=[dict(pi=[64, 64], vf=[64, 64])], + ), + ) + # total_timesteps = int(5e6) # does OK after 1e6, but continues improving + total_timesteps = 1e6 + normalize_reward = False + + rl = dict( + batch_size=512, + rl_kwargs=dict( + batch_size=64, + clip_range=0.1, + ent_coef=3.794797423594763e-06, + gae_lambda=0.95, + gamma=0.95, + learning_rate=0.0003286871805949382, + max_grad_norm=0.8, + n_epochs=5, + vf_coef=0.11483689492120866, + ), + ) + + @train_rl_ex.named_config def 
seals_hopper(): - environment = dict(gym_id="seals/Hopper-v0") + environment = dict(gym_id="seals/Hopper-v0", num_vec=1) + policy = dict( + policy_cls="MlpPolicy", + policy_kwargs=dict( + activation_fn=nn.ReLU, + net_arch=[dict(pi=[64, 64], vf=[64, 64])], + ), + ) + + total_timesteps = 1e6 + normalize_reward = False + + rl = dict( + batch_size=2048, + rl_kwargs=dict( + batch_size=512, + clip_range=0.1, + ent_coef=0.0010159833764878474, + gae_lambda=0.98, + gamma=0.995, + learning_rate=0.0003904770450788824, + max_grad_norm=0.9, + n_epochs=20, + # policy_kwargs are same as the defaults + vf_coef=0.20315938606555833, + ), + ) @train_rl_ex.named_config @@ -122,17 +211,99 @@ def reacher(): @train_rl_ex.named_config def seals_ant(): - environment = dict(gym_id="seals/Ant-v0") + environment = dict( + gym_id="seals/Ant-v0", + num_vec=1, + ) + + policy = dict( + policy_cls="MlpPolicy", + policy_kwargs=dict( + activation_fn=nn.Tanh, + net_arch=[dict(pi=[64, 64], vf=[64, 64])], + ), + ) + + total_timesteps = 1e6 + normalize_reward = False + + rl = dict( + batch_size=2048, + rl_kwargs=dict( + batch_size=16, + clip_range=0.3, + ent_coef=3.1441389214159857e-06, + gae_lambda=0.8, + gamma=0.995, + learning_rate=0.00017959211641976886, + max_grad_norm=0.9, + n_epochs=10, + # policy_kwargs are same as the defaults + vf_coef=0.4351450387648799, + ), + ) @train_rl_ex.named_config def seals_swimmer(): - environment = dict(gym_id="seals/Swimmer-v0") + environment = dict(gym_id="seals/Swimmer-v0", num_vec=1) + policy = dict( + policy_cls="MlpPolicy", + policy_kwargs=dict( + activation_fn=nn.ReLU, + net_arch=[dict(pi=[64, 64], vf=[64, 64])], + ), + ) + + total_timesteps = 1e6 + normalize_reward = False + + rl = dict( + batch_size=2048, + rl_kwargs=dict( + batch_size=64, + clip_range=0.1, + ent_coef=5.167107294612664e-08, + gae_lambda=0.95, + gamma=0.999, + learning_rate=0.000414936134792374, + max_grad_norm=2, + n_epochs=5, + # policy_kwargs are same as the defaults + 
vf_coef=0.6162112311062333, + ), + ) @train_rl_ex.named_config def seals_walker(): - environment = dict(gym_id="seals/Walker2d-v0") + environment = dict(gym_id="seals/Walker2d-v0", num_vec=1) + policy = dict( + policy_cls="MlpPolicy", + policy_kwargs=dict( + activation_fn=nn.ReLU, + net_arch=[dict(pi=[64, 64], vf=[64, 64])], + ), + ) + + total_timesteps = 1e6 + normalize_reward = False + + rl = dict( + batch_size=8192, + rl_kwargs=dict( + batch_size=128, + clip_range=0.4, + ent_coef=0.00013057334805552262, + gae_lambda=0.92, + gamma=0.98, + learning_rate=0.000138575372312869, + max_grad_norm=0.6, + n_epochs=20, + # policy_kwargs are same as the defaults + vf_coef=0.6167177795726859, + ), + ) # Debug configs diff --git a/src/imitation/scripts/ingredients/reward.py b/src/imitation/scripts/ingredients/reward.py index 2e2b67022..6b2e0195e 100644 --- a/src/imitation/scripts/ingredients/reward.py +++ b/src/imitation/scripts/ingredients/reward.py @@ -46,6 +46,11 @@ def normalize_output_running(): normalize_output_layer = networks.RunningNorm # noqa: F841 +@reward_ingredient.named_config +def normalize_output_ema(): + normalize_output_layer = networks.EMANorm # noqa: F841 + + @reward_ingredient.named_config def reward_ensemble(): net_cls = reward_nets.RewardEnsemble diff --git a/src/imitation/scripts/parallel.py b/src/imitation/scripts/parallel.py index 6014a08b6..38881ee2b 100644 --- a/src/imitation/scripts/parallel.py +++ b/src/imitation/scripts/parallel.py @@ -2,12 +2,15 @@ import collections.abc import copy +import glob import pathlib -from typing import Any, Callable, Dict, Mapping, Optional, Sequence +from typing import Any, Callable, Dict, Mapping, Sequence import ray import ray.tune import sacred +from ray.tune import search +from ray.tune.search import optuna from sacred.observers import FileStorageObserver from imitation.scripts.config.parallel import parallel_ex @@ -17,29 +20,33 @@ def parallel( sacred_ex_name: str, run_name: str, + num_samples: int, 
search_space: Mapping[str, Any], base_named_configs: Sequence[str], base_config_updates: Mapping[str, Any], resources_per_trial: Mapping[str, Any], init_kwargs: Mapping[str, Any], - local_dir: Optional[str], - upload_dir: Optional[str], -) -> None: + repeat: int, + experiment_checkpoint_path: str, + tune_run_kwargs: Dict[str, Any], +) -> ray.tune.ExperimentAnalysis: """Parallelize multiple runs of another Sacred Experiment using Ray Tune. A Sacred FileObserver is attached to the inner experiment and writes Sacred logs to "{RAY_LOCAL_DIR}/sacred/". These files are automatically copied over - to `upload_dir` if that argument is provided. + to `upload_dir` if that argument is provided in `tune_run_kwargs`. Args: sacred_ex_name: The Sacred experiment to tune. Either "train_rl" or - "train_adversarial". + "train_imitation" or "train_adversarial" or "train_preference_comparisons". run_name: A name describing this parallelizing experiment. This argument is also passed to `ray.tune.run` as the `name` argument. It is also saved in 'sacred/run.json' of each inner Sacred experiment under the 'experiment.name' key. This is equivalent to using the Sacred CLI '--name' option on the inner experiment. Offline analysis jobs can use this argument to group similar data. + num_samples: Number of times to sample from the hyperparameter space without + considering repetition using `repeat`. search_space: A dictionary which can contain Ray Tune search objects like `ray.tune.grid_search` and `ray.tune.sample_from`, and is passed as the `config` argument to `ray.tune.run()`. After the @@ -60,11 +67,22 @@ def parallel( generated Ray directory name, unlike config updates from `search_space`. resources_per_trial: Argument to `ray.tune.run()`. init_kwargs: Arguments to pass to `ray.init`. - local_dir: `local_dir` argument to `ray.tune.run()`. - upload_dir: `upload_dir` argument to `ray.tune.run()`. + repeat: Number of runs to repeat each trial for. 
+ If `repeat` > 1, then optuna is used as the default search algorithm + unless specified otherwise in `tune_run_kwargs`. + experiment_checkpoint_path: Path containing the checkpoints of a previous + experiment ran using this script. Useful for evaluating the best trial + of the experiment. + tune_run_kwargs: Other arguments to pass to `ray.tune.run()`. Raises: TypeError: Named configs not string sequences or config updates not mappings. + ValueError: `repeat` > 1 but `search_alg` is not an instance of + `ray.tune.search.SearchAlgorithm`. + + Returns: + The result of running the parallel experiment with `ray.tune.run()`. + Useful for fetching the configs and results dataframe of all the trials. """ # Basic validation for config options before we enter parallel jobs. if not isinstance(base_named_configs, collections.abc.Sequence): @@ -95,15 +113,45 @@ def parallel( ) ray.init(**init_kwargs) + updated_tune_run_kwargs = copy.deepcopy(tune_run_kwargs) + if repeat > 1: + try: + # Use optuna as the default search algorithm for repeat runs. 
+ algo = tune_run_kwargs.get("search_alg", optuna.OptunaSearch()) + updated_tune_run_kwargs["search_alg"] = search.Repeater(algo, repeat) + except AttributeError as e: + raise ValueError( + "repeat > 1 but search_alg is not an instance of " + "ray.tune.search.SearchAlgorithm", + ) from e + + if sacred_ex_name == "train_rl": + return_key = "monitor_return_mean" + else: + return_key = "imit_stats/monitor_return_mean" + try: - ray.tune.run( - trainable, - config=search_space, - name=run_name, - local_dir=local_dir, - resources_per_trial=resources_per_trial, - sync_config=ray.tune.syncer.SyncConfig(upload_dir=upload_dir), - ) + if experiment_checkpoint_path: + # load experiment analysis results + result = ray.tune.ExperimentAnalysis(experiment_checkpoint_path) + result._load_checkpoints_from_latest( + glob.glob(experiment_checkpoint_path + "/experiment_state*.json"), + ) + # update result.trials using all the experiment_state json files + result.trials = None + result.fetch_trial_dataframes() + else: + result = ray.tune.run( + trainable, + config=search_space, + num_samples=num_samples * repeat, + name=run_name, + resources_per_trial=resources_per_trial, + metric=return_key, + mode="max", + **updated_tune_run_kwargs, + ) + return result finally: ray.shutdown() @@ -113,7 +161,7 @@ def _ray_tune_sacred_wrapper( run_name: str, base_named_configs: list, base_config_updates: Mapping[str, Any], -) -> Callable[[Mapping[str, Any], Any], Mapping[str, Any]]: +) -> Callable[[Dict[str, Any], Any], Mapping[str, Any]]: """From an Experiment build a wrapped run function suitable for Ray Tune. `ray.tune.run(...)` expects a trainable function that takes a dict @@ -164,16 +212,22 @@ def inner(config: Mapping[str, Any], reporter) -> Mapping[str, Any]: # TODO(shwang): Stop modifying CAPTURE_MODE once the issue is fixed. 
sacred.SETTINGS.CAPTURE_MODE = "sys" - run_kwargs = config + run_kwargs = dict(**config) updated_run_kwargs: Dict[str, Any] = {} # Import inside function rather than in module because Sacred experiments # are not picklable, and Ray requires this function to be picklable. from imitation.scripts.train_adversarial import train_adversarial_ex + from imitation.scripts.train_imitation import train_imitation_ex + from imitation.scripts.train_preference_comparisons import ( + train_preference_comparisons_ex, + ) from imitation.scripts.train_rl import train_rl_ex experiments = { "train_rl": train_rl_ex, "train_adversarial": train_adversarial_ex, + "train_imitation": train_imitation_ex, + "train_preference_comparisons": train_preference_comparisons_ex, } ex = experiments[sacred_ex_name] @@ -181,23 +235,23 @@ def inner(config: Mapping[str, Any], reporter) -> Mapping[str, Any]: named_configs = base_named_configs + run_kwargs["named_configs"] updated_run_kwargs["named_configs"] = named_configs - config_updates = {**base_config_updates, **run_kwargs["config_updates"]} + config_updates: Dict[str, Any] = {} + config_updates.update(base_config_updates) + config_updates.update(run_kwargs["config_updates"]) + # for repeat runs, set the seed using their trial index + if "__trial_index__" in run_kwargs: + config_updates.update(seed=run_kwargs.pop("__trial_index__")) updated_run_kwargs["config_updates"] = config_updates # Add other run_kwargs items to updated_run_kwargs. for k, v in run_kwargs.items(): if k not in updated_run_kwargs: updated_run_kwargs[k] = v - run = ex.run( **updated_run_kwargs, options={"--run": run_name, "--file_storage": "sacred"}, ) - # Ray Tune has a string formatting error if raylet completes without - # any calls to `reporter`. 
- reporter(done=True) - assert run.status == "COMPLETED" return run.result diff --git a/src/imitation/scripts/train_imitation.py b/src/imitation/scripts/train_imitation.py index 58dae3484..c47ed29bd 100644 --- a/src/imitation/scripts/train_imitation.py +++ b/src/imitation/scripts/train_imitation.py @@ -131,6 +131,7 @@ def dagger( expert_policy=expert_policy, custom_logger=custom_logger, bc_trainer=bc_trainer, + beta_schedule=dagger["beta_schedule"], rng=_rnd, ) diff --git a/src/imitation/scripts/train_preference_comparisons.py b/src/imitation/scripts/train_preference_comparisons.py index 79ee4c136..867a666a4 100644 --- a/src/imitation/scripts/train_preference_comparisons.py +++ b/src/imitation/scripts/train_preference_comparisons.py @@ -280,7 +280,7 @@ def save_callback(iteration_num): # Storing and evaluating policy only useful if we generated trajectory data if bool(trajectory_path is None): results = dict(results) - results["rollout"] = policy_evaluation.eval_policy(agent, venv) + results["imit_stats"] = policy_evaluation.eval_policy(agent, venv) if save_preferences: main_trainer.dataset.save(log_dir / "preferences.pkl") diff --git a/src/imitation/scripts/train_rl.py b/src/imitation/scripts/train_rl.py index 96d35122c..6780a557b 100644 --- a/src/imitation/scripts/train_rl.py +++ b/src/imitation/scripts/train_rl.py @@ -158,7 +158,8 @@ def train_rl( policies_serialize.save_stable_model(output_dir, rl_algo) # Final evaluation of expert policy. 
- return policy_evaluation.eval_policy(rl_algo, venv) + eval_stats = policy_evaluation.eval_policy(rl_algo, venv) + return eval_stats def main_console(): diff --git a/tests/algorithms/test_adversarial.py b/tests/algorithms/test_adversarial.py index d3609efaa..3a53e35ca 100644 --- a/tests/algorithms/test_adversarial.py +++ b/tests/algorithms/test_adversarial.py @@ -231,8 +231,9 @@ def test_train_gen_train_disc_no_crash( trainer_parametrized: common.AdversarialTrainer, n_updates: int = 2, ) -> None: - trainer_parametrized.train_gen(n_updates * trainer_parametrized.gen_train_timesteps) - trainer_parametrized.train_disc() + trainer_parametrized.train_gen_with_disc( + n_updates * trainer_parametrized.gen_train_timesteps, + ) @pytest.fixture diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py index 124667eca..a44639cef 100644 --- a/tests/scripts/test_scripts.py +++ b/tests/scripts/test_scripts.py @@ -824,10 +824,10 @@ def test_train_rl_cnn_policy(tmpdir: str, rng): dict( sacred_ex_name="train_rl", base_named_configs=["seals_cartpole"] + ALGO_FAST_CONFIGS["rl"], - n_seeds=2, + repeat=2, search_space={ "config_updates": { - "rl": {"rl_kwargs": {"learning_rate": tune.grid_search([3e-4, 1e-4])}}, + "rl": {"rl_kwargs": {"learning_rate": tune.choice([3e-4, 1e-4])}}, }, "meta_info": {"asdf": "I exist for coverage purposes"}, }, @@ -840,7 +840,8 @@ def test_train_rl_cnn_policy(tmpdir: str, rng): "demonstrations.path": CARTPOLE_TEST_ROLLOUT_PATH.absolute(), }, search_space={ - "command_name": tune.grid_search(["gail", "airl"]), + "command_name": "airl", + "config_updates": {"total_timesteps": tune.choice([5, 10])}, }, ), ] @@ -919,13 +920,16 @@ def test_parallel_train_adversarial_custom_env(tmpdir): config_updates = dict( sacred_ex_name="train_adversarial", - n_seeds=1, + repeat=2, base_named_configs=[env_named_config] + ALGO_FAST_CONFIGS["adversarial"], base_config_updates=dict( logging=dict(log_root=tmpdir), demonstrations=dict(path=path), ), - 
search_space=dict(command_name="gail"), + # specifying repeat=2 uses the optuna search algorithm which + # requires the search space to be non-empty. So we provide + # the command name using tune.choice. + search_space=dict(command_name=tune.choice(["gail"])), ) config_updates.update(PARALLEL_CONFIG_LOW_RESOURCE) run = parallel.parallel_ex.run(config_updates=config_updates) @@ -977,7 +981,7 @@ def test_analyze_imitation(tmpdir: str, run_names: List[str], run_sacred_fn): assert run.status == "COMPLETED" # Check that analyze script finds the correct number of logs. - def check(run_name: Optional[str], count: int) -> None: + def check(run_name: Optional[str], count: int, table_verbosity=1) -> None: run = analyze.analysis_ex.run( command_name="analyze_imitation", config_updates=dict( @@ -987,6 +991,7 @@ def check(run_name: Optional[str], count: int) -> None: csv_output_path=tmpdir_path / "analysis.csv", tex_output_path=tmpdir_path / "analysis.tex", print_table=True, + table_verbosity=table_verbosity, ), ) assert run.status == "COMPLETED" @@ -996,15 +1001,19 @@ def check(run_name: Optional[str], count: int) -> None: for run_name, count in Counter(run_names).items(): check(run_name, count) - check(None, len(run_names)) # Check total number of logs. + check(None, len(run_names), table_verbosity=3) # Check total number of logs. 
def test_analyze_gather_tb(tmpdir: str): if os.name == "nt": # pragma: no cover pytest.skip("gather_tb uses symlinks: not supported by Windows") - - config_updates: Dict[str, Any] = dict(local_dir=tmpdir, run_name="test") + num_runs = 2 + config_updates: Dict[str, Any] = dict( + tune_run_kwargs=dict(local_dir=tmpdir), + run_name="test", + ) config_updates.update(PARALLEL_CONFIG_LOW_RESOURCE) + config_updates.update(num_samples=num_runs) parallel_run = parallel.parallel_ex.run( named_configs=["generate_test_data"], config_updates=config_updates, @@ -1019,7 +1028,7 @@ def test_analyze_gather_tb(tmpdir: str): ) assert run.status == "COMPLETED" assert isinstance(run.result, dict) - assert run.result["n_tb_dirs"] == 2 + assert run.result["n_tb_dirs"] == num_runs def test_pickle_fmt_rollout_test_data_is_pickle(): diff --git a/tests/test_benchmarking.py b/tests/test_benchmarking.py index ba01b38a2..18d4f12cf 100644 --- a/tests/test_benchmarking.py +++ b/tests/test_benchmarking.py @@ -1,5 +1,7 @@ """Tests for config files in benchmarking/ folder.""" import pathlib +import subprocess +import sys import pytest @@ -35,7 +37,7 @@ def test_benchmarks_print_config_succeeds(algorithm: str, environment: str): config_name = f"{algorithm}_{environment}" config_file = str( - BENCHMARKING_DIR / f"example_{algorithm}_{environment}_best_hp_eval.json", + BENCHMARKING_DIR / f"{algorithm}_{environment}_best_hp_eval.json", ) # WHEN @@ -44,3 +46,28 @@ def test_benchmarks_print_config_succeeds(algorithm: str, environment: str): # THEN assert run.status == "COMPLETED" + + +@pytest.mark.parametrize("algorithm", ALGORITHMS) +def test_tuning_print_config_succeeds(algorithm: str): + # We test the configs using the print_config command, + # because running the configs requires MuJoCo. + # Requiring MuJoCo to run the tests adds too much complexity. 
+ + # We need to use sys.executable, not just "python", on Windows as + # subprocess.call ignores PATH (unless shell=True) so runs a + # system-wide Python interpreter outside of our venv. See: + # https://stackoverflow.com/questions/5658622/ + tuning_path = str(BENCHMARKING_DIR / "tuning.py") + env = 'parallel_run_config.base_named_configs=["seals_cartpole"]' + exit_code = subprocess.call( + [ + sys.executable, + tuning_path, + "print_config", + "with", + f"{algorithm}", + env, + ], + ) + assert exit_code == 0 diff --git a/tests/test_experiments.py b/tests/test_experiments.py index 0f6d314fe..b2417a9f9 100644 --- a/tests/test_experiments.py +++ b/tests/test_experiments.py @@ -245,13 +245,13 @@ def test_commands_hofvarpnir_config_with_special_characters_in_flags(tmpdir): def test_commands_bc_config(): if os.name == "nt": # pragma: no cover pytest.skip("commands.py not ported to Windows.") - cfg_pattern = _get_benchmarking_path("example_bc_seals_ant_best_hp_eval.json") + cfg_pattern = _get_benchmarking_path("bc_seals_ant_best_hp_eval.json") commands = _run_commands_from_flags(cfg_pattern=cfg_pattern) assert len(commands) == 1 expected = """python -m imitation.scripts.train_imitation bc \ --capture=sys --name=run0 --file_storage=output/sacred/\ -$USER-cmd-run0-bc-0-138a1475 \ -with benchmarking/example_bc_seals_ant_best_hp_eval.json \ +$USER-cmd-run0-bc-0-78e5112a \ +with benchmarking/bc_seals_ant_best_hp_eval.json \ seed=0 logging.log_root=output""" assert commands[0] == expected @@ -259,13 +259,13 @@ def test_commands_bc_config(): def test_commands_dagger_config(): if os.name == "nt": # pragma: no cover pytest.skip("commands.py not ported to Windows.") - cfg_pattern = _get_benchmarking_path("example_dagger_seals_ant_best_hp_eval.json") + cfg_pattern = _get_benchmarking_path("dagger_seals_ant_best_hp_eval.json") commands = _run_commands_from_flags(cfg_pattern=cfg_pattern) assert len(commands) == 1 expected = """python -m imitation.scripts.train_imitation dagger \ 
--capture=sys --name=run0 --file_storage=output/sacred/\ -$USER-cmd-run0-dagger-0-6a49161a \ -with benchmarking/example_dagger_seals_ant_best_hp_eval.json \ +$USER-cmd-run0-dagger-0-c27812cf \ +with benchmarking/dagger_seals_ant_best_hp_eval.json \ seed=0 logging.log_root=output""" assert commands[0] == expected @@ -273,13 +273,13 @@ def test_commands_dagger_config(): def test_commands_gail_config(): if os.name == "nt": # pragma: no cover pytest.skip("commands.py not ported to Windows.") - cfg_pattern = _get_benchmarking_path("example_gail_seals_ant_best_hp_eval.json") + cfg_pattern = _get_benchmarking_path("gail_seals_ant_best_hp_eval.json") commands = _run_commands_from_flags(cfg_pattern=cfg_pattern) assert len(commands) == 1 expected = """python -m imitation.scripts.train_adversarial gail \ --capture=sys --name=run0 --file_storage=output/sacred/\ -$USER-cmd-run0-gail-0-3ec8154d \ -with benchmarking/example_gail_seals_ant_best_hp_eval.json \ +$USER-cmd-run0-gail-0-9d8d1202 \ +with benchmarking/gail_seals_ant_best_hp_eval.json \ seed=0 logging.log_root=output""" assert commands[0] == expected @@ -287,13 +287,13 @@ def test_commands_gail_config(): def test_commands_airl_config(): if os.name == "nt": # pragma: no cover pytest.skip("commands.py not ported to Windows.") - cfg_pattern = _get_benchmarking_path("example_airl_seals_ant_best_hp_eval.json") + cfg_pattern = _get_benchmarking_path("airl_seals_ant_best_hp_eval.json") commands = _run_commands_from_flags(cfg_pattern=cfg_pattern) assert len(commands) == 1 expected = """python -m imitation.scripts.train_adversarial airl \ --capture=sys --name=run0 \ ---file_storage=output/sacred/$USER-cmd-run0-airl-0-400e1558 \ -with benchmarking/example_airl_seals_ant_best_hp_eval.json \ +--file_storage=output/sacred/$USER-cmd-run0-airl-0-9ed3120d \ +with benchmarking/airl_seals_ant_best_hp_eval.json \ seed=0 logging.log_root=output""" assert commands[0] == expected