Skip to content

Commit 88371e1

Browse files
author
Jan Michelfeit
committed
#625 PebbleStateEntropyReward supports the initial phase before replay buffer is filled
1 parent f957baf commit 88371e1

File tree

2 files changed

+109
-67
lines changed

2 files changed

+109
-67
lines changed
Lines changed: 54 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
1+
from enum import Enum, auto
12
from typing import Tuple
23

34
import numpy as np
45
import torch as th
5-
from gym.vector.utils import spaces
6-
from stable_baselines3.common.preprocessing import get_obs_shape
76

87
from imitation.policies.replay_buffer_wrapper import (
98
ReplayBufferView,
@@ -14,27 +13,53 @@
1413
from imitation.util.networks import RunningNorm
1514

1615

16+
class PebbleRewardPhase(Enum):
    """Phase of the PEBBLE algorithm a PebbleStateEntropyReward is in.

    Each member selects a different reward behavior; transitions between
    them are driven by PebbleStateEntropyReward's explicit state-change
    methods.
    """

    LEARNING_START = auto()  # gathering samples so entropy can be estimated
    UNSUPERVISED_EXPLORATION = auto()  # reward is the state entropy estimate
    POLICY_AND_REWARD_LEARNING = auto()  # reward comes from the learned model
25+
26+
1727
class PebbleStateEntropyReward(ReplayBufferAwareRewardFn):
18-
# TODO #625: get rid of the observation_space parameter
28+
"""
29+
Reward function for implementation of the PEBBLE learning algorithm
30+
(https://arxiv.org/pdf/2106.05091.pdf).
31+
32+
The rewards returned by this function go through the three phases
33+
defined in PebbleRewardPhase. To transition between these phases,
34+
unsupervised_exploration_start() and unsupervised_exploration_finish()
35+
need to be called.
36+
37+
The second phase (UNSUPERVISED_EXPLORATION) also requires that a buffer
38+
with observations to compare against is supplied with set_replay_buffer()
39+
or on_replay_buffer_initialized().
40+
41+
Args:
42+
learned_reward_fn: The learned reward function used after unsupervised
43+
exploration is finished
44+
nearest_neighbor_k: Parameter for entropy computation (see
45+
compute_state_entropy())
46+
"""
47+
1948
# TODO #625: parametrize nearest_neighbor_k
2049
def __init__(
    self,
    learned_reward_fn: RewardFn,
    nearest_neighbor_k: int = 5,
):
    """Create the reward function, starting in the LEARNING_START phase.

    Args:
        learned_reward_fn: reward function to delegate to once unsupervised
            exploration is finished
        nearest_neighbor_k: k used for the entropy computation (see
            compute_state_entropy())
    """
    self.state = PebbleRewardPhase.LEARNING_START
    self.trained_reward_fn = learned_reward_fn
    self.nearest_neighbor_k = nearest_neighbor_k
    # Running statistics used to normalize entropy rewards.
    # TODO support n_envs > 1
    self.entropy_stats = RunningNorm(1)
    # Both remain None until set_replay_buffer() (or
    # on_replay_buffer_initialized()) supplies them; entropy rewards
    # cannot be computed before that.
    self.replay_buffer_view = None
    self.obs_shape = None
3863

3964
def on_replay_buffer_initialized(self, replay_buffer: ReplayBufferRewardWrapper):
4065
self.set_replay_buffer(replay_buffer.buffer_view, replay_buffer.obs_shape)
@@ -43,8 +68,13 @@ def set_replay_buffer(self, replay_buffer: ReplayBufferView, obs_shape: Tuple):
4368
self.replay_buffer_view = replay_buffer
4469
self.obs_shape = obs_shape
4570

46-
def on_unsupervised_exploration_finished(self):
47-
self.unsupervised_exploration_active = False
71+
def unsupervised_exploration_start(self):
72+
assert self.state == PebbleRewardPhase.LEARNING_START
73+
self.state = PebbleRewardPhase.UNSUPERVISED_EXPLORATION
74+
75+
def unsupervised_exploration_finish(self):
76+
assert self.state == PebbleRewardPhase.UNSUPERVISED_EXPLORATION
77+
self.state = PebbleRewardPhase.POLICY_AND_REWARD_LEARNING
4878

4979
def __call__(
5080
self,
@@ -53,19 +83,20 @@ def __call__(
5383
next_state: np.ndarray,
5484
done: np.ndarray,
5585
) -> np.ndarray:
56-
if self.unsupervised_exploration_active:
86+
if self.state == PebbleRewardPhase.UNSUPERVISED_EXPLORATION:
5787
return self._entropy_reward(state)
5888
else:
5989
return self.trained_reward_fn(state, action, next_state, done)
6090

6191
def _entropy_reward(self, state):
62-
# TODO: should this work with torch instead of numpy internally?
63-
# (The RewardFn protocol requires numpy)
92+
if self.replay_buffer_view is None:
93+
raise ValueError(
94+
"Replay buffer must be supplied before entropy reward can be used"
95+
)
96+
6497
all_observations = self.replay_buffer_view.observations
6598
# ReplayBuffer sampling flattens the venv dimension, let's adapt to that
66-
all_observations = all_observations.reshape(
67-
(-1, *state.shape[1:]) # TODO #625: fix self.obs_shape
68-
)
99+
all_observations = all_observations.reshape((-1, *self.obs_shape))
69100
# TODO #625: deal with the conversion back and forth between np and torch
70101
entropies = util.compute_state_entropy(
71102
th.tensor(state),
@@ -82,6 +113,4 @@ def __getstate__(self):
82113

83114
def __setstate__(self, state):
    """Restore from pickle.

    The replay buffer view is reset to None; it must be re-attached with
    set_replay_buffer() before entropy rewards can be computed again.
    """
    self.__dict__.update(state)
    self.replay_buffer_view = None

tests/algorithms/pebble/test_entropy_reward.py

Lines changed: 55 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -11,48 +11,87 @@
1111
from imitation.util import util
1212

1313
SPACE = Discrete(4)
14-
PLACEHOLDER = np.empty(get_obs_shape(SPACE))
14+
OBS_SHAPE = get_obs_shape(SPACE)
15+
PLACEHOLDER = np.empty(OBS_SHAPE)
1516

1617
BUFFER_SIZE = 20
1718
K = 4
1819
BATCH_SIZE = 8
1920
VENVS = 2
2021

2122

22-
def test_pebble_entropy_reward_returns_entropy(rng):
23-
obs_shape = get_obs_shape(SPACE)
24-
all_observations = rng.random((BUFFER_SIZE, VENVS, *obs_shape))
23+
def test_pebble_entropy_reward_function_returns_learned_reward_initially():
    """In the initial LEARNING_START phase, the learned reward is returned."""
    expected_reward = np.ones(1)
    learned_reward_mock = Mock()
    learned_reward_mock.return_value = expected_reward
    # Fix: the constructor signature is (learned_reward_fn, nearest_neighbor_k);
    # passing SPACE would have been silently bound to nearest_neighbor_k.
    reward_fn = PebbleStateEntropyReward(learned_reward_mock)

    # Act
    observations = np.ones((BATCH_SIZE, *OBS_SHAPE))
    reward = reward_fn(observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)

    # Assert
    assert reward == expected_reward
    learned_reward_mock.assert_called_once_with(
        observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER
    )
38+
39+
40+
def test_pebble_entropy_reward_function_returns_learned_reward_after_pre_training():
    """After both phase transitions, the learned reward is returned again."""
    expected_reward = np.ones(1)
    learned_reward_mock = Mock()
    learned_reward_mock.return_value = expected_reward
    # Fix: the constructor signature is (learned_reward_fn, nearest_neighbor_k);
    # passing SPACE would have been silently bound to nearest_neighbor_k.
    reward_fn = PebbleStateEntropyReward(learned_reward_mock)
    # move all the way to the last state
    reward_fn.unsupervised_exploration_start()
    reward_fn.unsupervised_exploration_finish()

    # Act
    observations = np.ones((BATCH_SIZE, *OBS_SHAPE))
    reward = reward_fn(observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)

    # Assert
    assert reward == expected_reward
    learned_reward_mock.assert_called_once_with(
        observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER
    )
58+
59+
60+
def test_pebble_entropy_reward_returns_entropy_for_pretraining(rng):
    """During unsupervised exploration the reward is normalized state entropy."""
    all_observations = rng.random((BUFFER_SIZE, VENVS, *OBS_SHAPE))

    # Fix: __init__ takes (learned_reward_fn, nearest_neighbor_k) only; the
    # stray SPACE argument made this call raise TypeError.
    reward_fn = PebbleStateEntropyReward(Mock(), K)
    reward_fn.set_replay_buffer(
        ReplayBufferView(all_observations, lambda: slice(None)), OBS_SHAPE
    )
    reward_fn.unsupervised_exploration_start()

    # Act
    observations = th.rand((BATCH_SIZE, *OBS_SHAPE))
    reward = reward_fn(observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)

    # Assert
    expected = util.compute_state_entropy(
        observations, all_observations.reshape(-1, *OBS_SHAPE), K
    )
    expected_normalized = reward_fn.entropy_stats.normalize(
        th.as_tensor(expected)
    ).numpy()
    np.testing.assert_allclose(reward, expected_normalized)
4381

4482

45-
def test_pebble_entropy_reward_returns_normalized_values():
83+
def test_pebble_entropy_reward_returns_normalized_values_for_pretraining():
4684
with patch("imitation.util.util.compute_state_entropy") as m:
4785
# mock entropy computation so that we can test only stats collection in this test
4886
m.side_effect = lambda obs, all_obs, k: obs
4987

5088
reward_fn = PebbleStateEntropyReward(Mock(), SPACE, K)
51-
all_observations = np.empty((BUFFER_SIZE, VENVS, *get_obs_shape(SPACE)))
89+
all_observations = np.empty((BUFFER_SIZE, VENVS, *OBS_SHAPE))
5290
reward_fn.set_replay_buffer(
5391
ReplayBufferView(all_observations, lambda: slice(None)),
54-
get_obs_shape(SPACE),
92+
OBS_SHAPE,
5593
)
94+
reward_fn.unsupervised_exploration_start()
5695

5796
dim = 8
5897
shift = 3
@@ -77,51 +116,25 @@ def test_pebble_entropy_reward_returns_normalized_values():
77116

78117

79118
def test_pebble_entropy_reward_can_pickle():
    """A pickled and restored reward function produces identical rewards."""
    all_observations = np.empty((BUFFER_SIZE, VENVS, *OBS_SHAPE))
    replay_buffer = ReplayBufferView(all_observations, lambda: slice(None))

    obs1 = np.random.rand(VENVS, *OBS_SHAPE)
    # Fix: __init__ takes (learned_reward_fn, nearest_neighbor_k) only; the
    # stray SPACE argument made this call raise TypeError.
    reward_fn = PebbleStateEntropyReward(reward_fn_stub, K)
    reward_fn.set_replay_buffer(replay_buffer, OBS_SHAPE)
    reward_fn(obs1, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)

    # Act
    pickled = pickle.dumps(reward_fn)
    reward_fn_deserialized = pickle.loads(pickled)
    # The buffer view is dropped on pickling, so it must be re-attached.
    reward_fn_deserialized.set_replay_buffer(replay_buffer, OBS_SHAPE)

    # Assert
    obs2 = np.random.rand(VENVS, *OBS_SHAPE)
    expected_result = reward_fn(obs2, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)
    actual_result = reward_fn_deserialized(obs2, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)
    np.testing.assert_allclose(actual_result, expected_result)
98137

99138

100-
def test_pebble_entropy_reward_function_switches_to_inner():
101-
obs_shape = get_obs_shape(SPACE)
102-
103-
expected_reward = np.ones(1)
104-
reward_fn_mock = Mock()
105-
reward_fn_mock.return_value = expected_reward
106-
reward_fn = PebbleStateEntropyReward(reward_fn_mock, SPACE)
107-
108-
# Act
109-
reward_fn.on_unsupervised_exploration_finished()
110-
observations = np.ones((BATCH_SIZE, *obs_shape))
111-
reward = reward_fn(observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)
112-
113-
# Assert
114-
assert reward == expected_reward
115-
reward_fn_mock.assert_called_once_with(
116-
observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER
117-
)
118-
119-
120-
def reward_fn_stub(
121-
self,
122-
state: np.ndarray,
123-
action: np.ndarray,
124-
next_state: np.ndarray,
125-
done: np.ndarray,
126-
) -> np.ndarray:
139+
def reward_fn_stub(state, action, next_state, done):
    """Trivial RewardFn used as a picklable stand-in: echoes the state back."""
    del action, next_state, done  # accepted only for signature compatibility
    return state

0 commit comments

Comments
 (0)