|
7 | 7 | "[download this notebook here](https://github.com/HumanCompatibleAI/imitation/blob/master/docs/tutorials/8a_train_sqil_sac.ipynb)\n", |
8 | 8 | "# Train an Agent using Soft Q Imitation Learning with SAC\n", |
9 | 9 | "\n", |
10 | | - "In the previous tutorial, we used Soft Q Imitation Learning ([SQIL](https://arxiv.org/abs/1905.11108)) on top of the DQN base algorithm. In fact, SQIL can be combined with any off-policy algorithm from `stable_baselines3`. Here, we train a HalfCheetah agent using SQIL + SAC." |
| 10 | + "In the previous tutorial, we used Soft Q Imitation Learning ([SQIL](https://arxiv.org/abs/1905.11108)) on top of the DQN base algorithm. In fact, SQIL can be combined with any off-policy algorithm from `stable_baselines3`. Here, we train a Pendulum agent using SQIL + SAC." |
11 | 11 | ] |
12 | 12 | }, |
13 | 13 | { |
14 | 14 | "cell_type": "markdown", |
15 | 15 | "metadata": {}, |
16 | 16 | "source": [ |
17 | | - "First, we need some expert trajectories in our environment (`seals/HalfCheetah-v0`).\n", |
| 17 | + "First, we need some expert trajectories in our environment (`Pendulum-v1`).\n", |
18 | 18 | "Note that you can use other environments, but the action space must be continuous." |
19 | 19 | ] |
20 | 20 | }, |
|
28 | 28 | "from imitation.data import huggingface_utils\n", |
29 | 29 | "\n", |
30 | 30 | "# Download some expert trajectories from the HuggingFace Datasets Hub.\n", |
31 | | - "dataset = datasets.load_dataset(\"HumanCompatibleAI/ppo-seals-HalfCheetah-v0\")\n", |
| 31 | + "dataset = datasets.load_dataset(\"HumanCompatibleAI/ppo-Pendulum-v1\")\n", |
32 | 32 | "\n", |
33 | 33 | "# Convert the dataset to a format usable by the imitation library.\n", |
34 | 34 | "expert_trajectories = huggingface_utils.TrajectoryDatasetSequence(dataset[\"train\"])" |
|
75 | 75 | "from imitation.util.util import make_vec_env\n", |
76 | 76 | "import numpy as np\n", |
77 | 77 | "from stable_baselines3 import sac\n", |
78 | | - "import seals # noqa: F401 # needed to load \"seals/\" environments\n", |
79 | 78 | "\n", |
80 | 79 | "SEED = 42\n", |
81 | 80 | "\n", |
82 | 81 | "venv = make_vec_env(\n", |
83 | | - " \"seals/HalfCheetah-v1\",\n", |
| 82 | + " \"Pendulum-v1\",\n", |
84 | 83 | " rng=np.random.default_rng(seed=SEED),\n", |
85 | 84 | ")\n", |
86 | 85 | "\n", |
|
0 commit comments