Commit 8c2aea2

refactor a2c, acer, acktr, ppo2, deepq, and trpo_mpi (#490)
* exported rl-algs
* more stuff from rl-algs
* run slow tests
* re-exported rl_algs
* re-exported rl_algs - fixed problems with serialization test and test_cartpole
* replaced atari_arg_parser with common_arg_parser
* run.py can run algos from both baselines and rl_algs
* added approximate humanoid reward with ppo2 into the README for reference
* dummy commit to RUN BENCHMARKS
* dummy commit to RUN BENCHMARKS
* dummy commit to RUN BENCHMARKS
* dummy commit to RUN BENCHMARKS
* very dummy commit to RUN BENCHMARKS
* serialize variables as a dict, not as a list
* running_mean_std uses tensorflow variables
* fixed import in vec_normalize
* dummy commit to RUN BENCHMARKS
* dummy commit to RUN BENCHMARKS
* flake8 complaints
* save all variables to make sure we save the vec_normalize normalization
* benchmarks on ppo2 only RUN BENCHMARKS
* make_atari_env compatible with mpi
* run ppo_mpi benchmarks only RUN BENCHMARKS
* hardcode names of retro environments
* add defaults
* changed default ppo2 lr schedule to linear RUN BENCHMARKS
* non-tf normalization benchmark RUN BENCHMARKS
* use ncpu=1 for mujoco sessions - gives a bit of a performance speedup
* reverted running_mean_std to use property decorators for mean, var, count
* reverted VecNormalize to use RunningMeanStd (no tf)
* reverted VecNormalize to use RunningMeanStd (no tf)
* profiling wip
* use VecNormalize with regular RunningMeanStd
* added acer runner (missing import)
* flake8 complaints
* added a note in README about TfRunningMeanStd and serialization of VecNormalize
* dummy commit to RUN BENCHMARKS
* merged benchmarks branch
1 parent 366f486 commit 8c2aea2


71 files changed: +2944, -1072 lines

.benchmark_pattern

+1
@@ -0,0 +1 @@
+

.gitignore

-2
@@ -34,5 +34,3 @@ src
 .cache
 
 MUJOCO_LOG.TXT
-
-

.travis.yml

+2 -2
@@ -10,5 +10,5 @@ install:
   - docker build . -t baselines-test
 
 script:
-  - flake8 --select=F baselines/common
-  - docker run baselines-test pytest
+  - flake8 --select=F,E999 baselines/common baselines/trpo_mpi baselines/ppo2 baselines/a2c baselines/deepq baselines/acer
+  - docker run baselines-test pytest --runslow
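
The CI now runs `pytest --runslow`, matching the "run slow tests" item in the commit message. For context, this is the usual way such a flag is wired up in a `conftest.py` - the standard recipe from the pytest documentation, shown here as an illustration rather than the exact conftest used by baselines:

```python
# conftest.py - typical pattern for a --runslow flag (standard pytest recipe).
# Illustration only; baselines' actual conftest may differ.
import pytest

def pytest_addoption(parser):
    parser.addoption("--runslow", action="store_true", default=False,
                     help="run tests marked as slow")

def pytest_collection_modifyitems(config, items):
    if config.getoption("--runslow"):
        return  # --runslow given: do not skip anything
    skip_slow = pytest.mark.skip(reason="need --runslow option to run")
    for item in items:
        if "slow" in item.keywords:
            item.add_marker(skip_slow)
```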

Dockerfile

+10 -6
@@ -1,20 +1,24 @@
 FROM ubuntu:16.04
 
-RUN apt-get -y update && apt-get -y install git wget python-dev python3-dev libopenmpi-dev python-pip zlib1g-dev cmake
+RUN apt-get -y update && apt-get -y install git wget python-dev python3-dev libopenmpi-dev python-pip zlib1g-dev cmake python-opencv
 ENV CODE_DIR /root/code
 ENV VENV /root/venv
 
-COPY . $CODE_DIR/baselines
 RUN \
     pip install virtualenv && \
     virtualenv $VENV --python=python3 && \
     . $VENV/bin/activate && \
-    cd $CODE_DIR && \
-    pip install --upgrade pip && \
-    pip install -e baselines && \
-    pip install pytest
+    pip install --upgrade pip
 
 ENV PATH=$VENV/bin:$PATH
+
+COPY . $CODE_DIR/baselines
 WORKDIR $CODE_DIR/baselines
 
+# Clean up pycache and pyc files
+RUN rm -rf __pycache__ && \
+    find . -name "*.pyc" -delete && \
+    pip install -e .[test]
+
+
 CMD /bin/bash

README.md

+55
@@ -62,6 +62,60 @@ pip install pytest
 pytest
 ```
 
+## Subpackages
+
+## Testing the installation
+All unit tests in baselines can be run using the pytest runner:
+```
+pip install pytest
+pytest
+```
+
+## Training models
+Most of the algorithms in the baselines repo are used as follows:
+```bash
+python -m baselines.run --alg=<name of the algorithm> --env=<environment_id> [additional arguments]
+```
+### Example 1. PPO with MuJoCo Humanoid
+For instance, to train a fully-connected network controlling the MuJoCo humanoid using a2c for 20M timesteps:
+```bash
+python -m baselines.run --alg=a2c --env=Humanoid-v2 --network=mlp --num_timesteps=2e7
+```
+Note that for MuJoCo environments the fully-connected network is the default, so we can omit `--network=mlp`.
+The hyperparameters for both the network and the learning algorithm can be controlled via the command line, for instance:
+```bash
+python -m baselines.run --alg=a2c --env=Humanoid-v2 --network=mlp --num_timesteps=2e7 --ent_coef=0.1 --num_hidden=32 --num_layers=3 --value_network=copy
+```
+will set the entropy coefficient to 0.1, construct a fully connected network with 3 layers of 32 hidden units each, and create a separate network for value function estimation (so that its parameters are not shared with the policy network, but the structure is the same).
+
+See the docstrings in [common/models.py](common/models.py) for a description of the network parameters for each type of model, and
+the docstring for [baselines/ppo2/ppo2.py/learn()](ppo2/ppo2.py) for a description of the ppo2 hyperparameters.
+
+### Example 2. DQN on Atari
+DQN on Atari is by now a classic benchmark. To run the baselines implementation of DQN on Atari Pong:
+```
+python -m baselines.run --alg=deepq --env=PongNoFrameskip-v4 --num_timesteps=1e6
+```
+
+## Saving, loading and visualizing models
+The algorithms' serialization API is not properly unified yet; however, there is a simple method to save / restore trained models.
+The `--save_path` and `--load_path` command-line options save the tensorflow state to a given path after training, and load it from a given path before training, respectively.
+Let's imagine you'd like to train ppo2 on Atari Pong, save the model, and then later visualize what it has learnt.
+```bash
+python -m baselines.run --alg=ppo2 --env=PongNoFrameskip-v4 --num_timesteps=2e7 --save_path=~/models/pong_20M_ppo2
+```
+This should get the mean reward per episode to about 20. To load and visualize the model, we'll do the following: load the model, train it for 0 steps, and then visualize:
+```bash
+python -m baselines.run --alg=ppo2 --env=PongNoFrameskip-v4 --num_timesteps=0 --load_path=~/models/pong_20M_ppo2 --play
+```
+
+*NOTE:* At the moment MuJoCo training uses the VecNormalize wrapper for the environment, which is not saved correctly, so loading models trained on MuJoCo will not work well if the environment is recreated. If necessary, you can work around that by replacing RunningMeanStd with TfRunningMeanStd in [baselines/common/vec_env/vec_normalize.py](baselines/common/vec_env/vec_normalize.py#L12). This way, the mean and std of the environment-normalizing wrapper will be saved in tensorflow variables and included in the model file; however, training is slower that way, hence it is not enabled by default.
+
+
+
+
+
+
 ## Subpackages
 
 - [A2C](baselines/a2c)
@@ -85,3 +139,4 @@ To cite this repository in publications:
   journal = {GitHub repository},
   howpublished = {\url{https://github.com/openai/baselines}},
 }
+
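
The `--save_path` / `--load_path` behaviour described above is backed by the new `tf_util.save_variables` / `tf_util.load_variables` helpers used in the a2c diff below. As a rough sketch of the "serialize variables as a dict, not as a list" idea from the commit message - not the actual `baselines.common.tf_util` implementation, which may differ in variable selection and defaults (the commit notes it saves all variables so that VecNormalize statistics are included) - saving and restoring could look like this:

```python
# Standalone sketch of dict-based variable serialization; illustration only,
# not the real baselines.common.tf_util code.
import joblib
import tensorflow as tf

def save_variables(save_path, variables=None, sess=None):
    sess = sess or tf.get_default_session()
    variables = variables or tf.trainable_variables()
    values = sess.run(variables)
    # Keying by variable name makes restoring independent of variable ordering,
    # unlike the old joblib list-of-arrays format removed from a2c.py below.
    joblib.dump({v.name: value for v, value in zip(variables, values)}, save_path)

def load_variables(load_path, variables=None, sess=None):
    sess = sess or tf.get_default_session()
    variables = variables or tf.trainable_variables()
    loaded = joblib.load(load_path)
    sess.run([v.assign(loaded[v.name]) for v in variables])
```

Keying the dump by variable name decouples restoring from graph construction order, which is what makes the dict format more robust than a plain list of arrays.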

baselines/a2c/a2c.py

+99 -81
@@ -1,42 +1,48 @@
-import os.path as osp
 import time
-import joblib
-import numpy as np
+import functools
 import tensorflow as tf
+
 from baselines import logger
 
 from baselines.common import set_global_seeds, explained_variance
-from baselines.common.runners import AbstractEnvRunner
 from baselines.common import tf_util
+from baselines.common.policies import build_policy
+
 
-from baselines.a2c.utils import discount_with_dones
-from baselines.a2c.utils import Scheduler, make_path, find_trainable_variables
-from baselines.a2c.utils import cat_entropy, mse
+from baselines.a2c.utils import Scheduler, find_trainable_variables
+from baselines.a2c.runner import Runner
+
+from tensorflow import losses
 
 class Model(object):
 
-    def __init__(self, policy, ob_space, ac_space, nenvs, nsteps,
+    def __init__(self, policy, env, nsteps,
             ent_coef=0.01, vf_coef=0.5, max_grad_norm=0.5, lr=7e-4,
             alpha=0.99, epsilon=1e-5, total_timesteps=int(80e6), lrschedule='linear'):
 
-        sess = tf_util.make_session()
+        sess = tf_util.get_session()
+        nenvs = env.num_envs
         nbatch = nenvs*nsteps
 
-        A = tf.placeholder(tf.int32, [nbatch])
+
+        with tf.variable_scope('a2c_model', reuse=tf.AUTO_REUSE):
+            step_model = policy(nenvs, 1, sess)
+            train_model = policy(nbatch, nsteps, sess)
+
+        A = tf.placeholder(train_model.action.dtype, train_model.action.shape)
         ADV = tf.placeholder(tf.float32, [nbatch])
         R = tf.placeholder(tf.float32, [nbatch])
         LR = tf.placeholder(tf.float32, [])
 
-        step_model = policy(sess, ob_space, ac_space, nenvs, 1, reuse=False)
-        train_model = policy(sess, ob_space, ac_space, nenvs*nsteps, nsteps, reuse=True)
+        neglogpac = train_model.pd.neglogp(A)
+        entropy = tf.reduce_mean(train_model.pd.entropy())
 
-        neglogpac = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=train_model.pi, labels=A)
         pg_loss = tf.reduce_mean(ADV * neglogpac)
-        vf_loss = tf.reduce_mean(mse(tf.squeeze(train_model.vf), R))
-        entropy = tf.reduce_mean(cat_entropy(train_model.pi))
+        vf_loss = losses.mean_squared_error(tf.squeeze(train_model.vf), R)
+
         loss = pg_loss - entropy*ent_coef + vf_loss * vf_coef
 
-        params = find_trainable_variables("model")
+        params = find_trainable_variables("a2c_model")
         grads = tf.gradients(loss, params)
         if max_grad_norm is not None:
             grads, grad_norm = tf.clip_by_global_norm(grads, max_grad_norm)
@@ -50,6 +56,7 @@ def train(obs, states, rewards, masks, actions, values):
             advs = rewards - values
             for step in range(len(obs)):
                 cur_lr = lr.value()
+
             td_map = {train_model.X:obs, A:actions, ADV:advs, R:rewards, LR:cur_lr}
             if states is not None:
                 td_map[train_model.S] = states
@@ -60,84 +67,94 @@ def train(obs, states, rewards, masks, actions, values):
             )
             return policy_loss, value_loss, policy_entropy
 
-        def save(save_path):
-            ps = sess.run(params)
-            make_path(osp.dirname(save_path))
-            joblib.dump(ps, save_path)
-
-        def load(load_path):
-            loaded_params = joblib.load(load_path)
-            restores = []
-            for p, loaded_p in zip(params, loaded_params):
-                restores.append(p.assign(loaded_p))
-            sess.run(restores)
 
         self.train = train
         self.train_model = train_model
         self.step_model = step_model
         self.step = step_model.step
         self.value = step_model.value
         self.initial_state = step_model.initial_state
-        self.save = save
-        self.load = load
+        self.save = functools.partial(tf_util.save_variables, sess=sess)
+        self.load = functools.partial(tf_util.load_variables, sess=sess)
         tf.global_variables_initializer().run(session=sess)
 
-class Runner(AbstractEnvRunner):
-
-    def __init__(self, env, model, nsteps=5, gamma=0.99):
-        super().__init__(env=env, model=model, nsteps=nsteps)
-        self.gamma = gamma
-
-    def run(self):
-        mb_obs, mb_rewards, mb_actions, mb_values, mb_dones = [],[],[],[],[]
-        mb_states = self.states
-        for n in range(self.nsteps):
-            actions, values, states, _ = self.model.step(self.obs, self.states, self.dones)
-            mb_obs.append(np.copy(self.obs))
-            mb_actions.append(actions)
-            mb_values.append(values)
-            mb_dones.append(self.dones)
-            obs, rewards, dones, _ = self.env.step(actions)
-            self.states = states
-            self.dones = dones
-            for n, done in enumerate(dones):
-                if done:
-                    self.obs[n] = self.obs[n]*0
-            self.obs = obs
-            mb_rewards.append(rewards)
-        mb_dones.append(self.dones)
-        #batch of steps to batch of rollouts
-        mb_obs = np.asarray(mb_obs, dtype=np.uint8).swapaxes(1, 0).reshape(self.batch_ob_shape)
-        mb_rewards = np.asarray(mb_rewards, dtype=np.float32).swapaxes(1, 0)
-        mb_actions = np.asarray(mb_actions, dtype=np.int32).swapaxes(1, 0)
-        mb_values = np.asarray(mb_values, dtype=np.float32).swapaxes(1, 0)
-        mb_dones = np.asarray(mb_dones, dtype=np.bool).swapaxes(1, 0)
-        mb_masks = mb_dones[:, :-1]
-        mb_dones = mb_dones[:, 1:]
-        last_values = self.model.value(self.obs, self.states, self.dones).tolist()
-        #discount/bootstrap off value fn
-        for n, (rewards, dones, value) in enumerate(zip(mb_rewards, mb_dones, last_values)):
-            rewards = rewards.tolist()
-            dones = dones.tolist()
-            if dones[-1] == 0:
-                rewards = discount_with_dones(rewards+[value], dones+[0], self.gamma)[:-1]
-            else:
-                rewards = discount_with_dones(rewards, dones, self.gamma)
-            mb_rewards[n] = rewards
-        mb_rewards = mb_rewards.flatten()
-        mb_actions = mb_actions.flatten()
-        mb_values = mb_values.flatten()
-        mb_masks = mb_masks.flatten()
-        return mb_obs, mb_states, mb_rewards, mb_masks, mb_actions, mb_values
-
-def learn(policy, env, seed, nsteps=5, total_timesteps=int(80e6), vf_coef=0.5, ent_coef=0.01, max_grad_norm=0.5, lr=7e-4, lrschedule='linear', epsilon=1e-5, alpha=0.99, gamma=0.99, log_interval=100):
+
+def learn(
+    network,
+    env,
+    seed=None,
+    nsteps=5,
+    total_timesteps=int(80e6),
+    vf_coef=0.5,
+    ent_coef=0.01,
+    max_grad_norm=0.5,
+    lr=7e-4,
+    lrschedule='linear',
+    epsilon=1e-5,
+    alpha=0.99,
+    gamma=0.99,
+    log_interval=100,
+    load_path=None,
+    **network_kwargs):
+
+    '''
+    Main entrypoint for the A2C algorithm. Trains a policy with the given network architecture on the given environment using the a2c algorithm.
+
+    Parameters:
+    -----------
+
+    network: policy network architecture. Either a string (mlp, lstm, lnlstm, cnn_lstm, cnn, cnn_small, conv_only - see baselines.common/models.py for the full list)
+        specifying a standard network architecture, or a function that takes a tensorflow tensor as input and returns
+        a tuple (output_tensor, extra_feed), where output_tensor is the last network layer output, and extra_feed is None for feed-forward
+        neural nets or a dictionary describing how to feed state into the network for recurrent neural nets.
+        See baselines.common/policies.py/lstm for more details on using recurrent nets in policies
+
+
+    env: RL environment. Should implement an interface similar to VecEnv (baselines.common/vec_env) or be wrapped with DummyVecEnv (baselines.common/vec_env/dummy_vec_env.py)
+
+
+    seed: seed to make the random number sequence in the algorithm reproducible. Defaults to None, which means the seed is taken from the system noise generator (not reproducible)
+
+    nsteps: int, number of steps of the vectorized environment per update (i.e. batch size is nsteps * nenv where
+        nenv is the number of environment copies simulated in parallel)
+
+    total_timesteps: int, total number of timesteps to train on (default: 80M)
+
+    vf_coef: float, coefficient in front of the value function loss in the total loss function (default: 0.5)
+
+    ent_coef: float, coefficient in front of the policy entropy in the total loss function (default: 0.01)
+
+    max_grad_norm: float, gradient is clipped to have global L2 norm no more than this value (default: 0.5)
+
+    lr: float, learning rate for RMSProp (current implementation has RMSProp hardcoded in) (default: 7e-4)
+
+    lrschedule: schedule of learning rate. Can be 'linear', 'constant', or a function [0..1] -> [0..1] that takes the fraction of the training progress as input and
+        returns the fraction of the learning rate (specified as lr) as output
+
+    epsilon: float, RMSProp epsilon (stabilizes square root computation in denominator of RMSProp update) (default: 1e-5)
+
+    alpha: float, RMSProp decay parameter (default: 0.99)
+
+    gamma: float, reward discounting parameter (default: 0.99)
+
+    log_interval: int, specifies how frequently the logs are printed out (default: 100)
+
+    **network_kwargs: keyword arguments to the policy / network builder. See baselines.common/policies.py/build_policy and the arguments to a particular type of network.
+        For instance, the 'mlp' network architecture has arguments num_hidden and num_layers.
+
+    '''
+
+
+
     set_global_seeds(seed)
 
     nenvs = env.num_envs
-    ob_space = env.observation_space
-    ac_space = env.action_space
-    model = Model(policy=policy, ob_space=ob_space, ac_space=ac_space, nenvs=nenvs, nsteps=nsteps, ent_coef=ent_coef, vf_coef=vf_coef,
+    policy = build_policy(env, network, **network_kwargs)
+
+    model = Model(policy=policy, env=env, nsteps=nsteps, ent_coef=ent_coef, vf_coef=vf_coef,
         max_grad_norm=max_grad_norm, lr=lr, alpha=alpha, epsilon=epsilon, total_timesteps=total_timesteps, lrschedule=lrschedule)
+    if load_path is not None:
+        model.load(load_path)
     runner = Runner(env, model, nsteps=nsteps, gamma=gamma)
 
     nbatch = nenvs*nsteps
@@ -158,3 +175,4 @@ def learn(policy, env, seed, nsteps=5, total_timesteps=int(80e6), vf_coef=0.5, e
             logger.dump_tabular()
     env.close()
     return model
+
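
For reference, a hedged sketch of how the refactored `learn()` entrypoint above might be driven directly from Python instead of through `python -m baselines.run`. It assumes `gym` is installed and that wrapping a single CartPole environment in `DummyVecEnv` (the wrapper referenced in the docstring) is sufficient; the hyperparameter values are illustrative only, not taken from this commit.

```python
# Hypothetical driver script; not part of this commit.
import gym
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from baselines.a2c.a2c import learn

def main():
    # learn() expects a vectorized env; DummyVecEnv wraps one (or more) gym envs.
    env = DummyVecEnv([lambda: gym.make('CartPole-v0')])

    # network='mlp' selects the fully-connected builder; extra **network_kwargs
    # such as num_hidden / num_layers are forwarded to it via build_policy.
    model = learn(
        network='mlp',
        env=env,
        seed=0,
        nsteps=5,
        total_timesteps=20000,
        num_hidden=32,
        num_layers=2,
    )

    # save() is functools.partial(tf_util.save_variables, sess=sess), so it takes
    # a path; the model can be restored later via learn(..., load_path=...).
    model.save('/tmp/a2c_cartpole_model')
    env.close()

if __name__ == '__main__':
    main()
```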
