Skip to content

Commit 573b086

Browse files
authored
Change SQIL SAC to use Pendulum (#800)
1 parent cd76326 commit 573b086

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

docs/tutorials/8a_train_sqil_sac.ipynb

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@
77
"[download this notebook here](https://github.com/HumanCompatibleAI/imitation/blob/master/docs/tutorials/8a_train_sqil_sac.ipynb)\n",
88
"# Train an Agent using Soft Q Imitation Learning with SAC\n",
99
"\n",
10-
"In the previous tutorial, we used Soft Q Imitation Learning ([SQIL](https://arxiv.org/abs/1905.11108)) on top of the DQN base algorithm. In fact, SQIL can be combined with any off-policy algorithm from `stable_baselines3`. Here, we train a HalfCheetah agent using SQIL + SAC."
10+
"In the previous tutorial, we used Soft Q Imitation Learning ([SQIL](https://arxiv.org/abs/1905.11108)) on top of the DQN base algorithm. In fact, SQIL can be combined with any off-policy algorithm from `stable_baselines3`. Here, we train a Pendulum agent using SQIL + SAC."
1111
]
1212
},
1313
{
1414
"cell_type": "markdown",
1515
"metadata": {},
1616
"source": [
17-
"First, we need some expert trajectories in our environment (`seals/HalfCheetah-v0`).\n",
17+
"First, we need some expert trajectories in our environment (`Pendulum-v1`).\n",
1818
"Note that you can use other environments, but the action space must be continuous."
1919
]
2020
},
@@ -28,7 +28,7 @@
2828
"from imitation.data import huggingface_utils\n",
2929
"\n",
3030
"# Download some expert trajectories from the HuggingFace Datasets Hub.\n",
31-
"dataset = datasets.load_dataset(\"HumanCompatibleAI/ppo-seals-HalfCheetah-v0\")\n",
31+
"dataset = datasets.load_dataset(\"HumanCompatibleAI/ppo-Pendulum-v1\")\n",
3232
"\n",
3333
"# Convert the dataset to a format usable by the imitation library.\n",
3434
"expert_trajectories = huggingface_utils.TrajectoryDatasetSequence(dataset[\"train\"])"
@@ -75,12 +75,11 @@
7575
"from imitation.util.util import make_vec_env\n",
7676
"import numpy as np\n",
7777
"from stable_baselines3 import sac\n",
78-
"import seals # noqa: F401 # needed to load \"seals/\" environments\n",
7978
"\n",
8079
"SEED = 42\n",
8180
"\n",
8281
"venv = make_vec_env(\n",
83-
" \"seals/HalfCheetah-v1\",\n",
82+
" \"Pendulum-v1\",\n",
8483
" rng=np.random.default_rng(seed=SEED),\n",
8584
")\n",
8685
"\n",

0 commit comments

Comments
 (0)