
Commit 2dec99f

Author: Jan Michelfeit (committed)
#625 entropy reward as a function
1 parent 27b8a55 commit 2dec99f

7 files changed: +211 additions, -8 deletions
Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+import numpy as np
+import torch as th
+from gym.vector.utils import spaces
+from stable_baselines3.common.preprocessing import get_obs_shape
+
+from imitation.policies.replay_buffer_wrapper import ReplayBufferView
+from imitation.rewards.reward_function import RewardFn
+from imitation.util import util
+from imitation.util.networks import RunningNorm
+
+
+class StateEntropyReward(RewardFn):
+    def __init__(self, nearest_neighbor_k: int, observation_space: spaces.Space):
+        self.nearest_neighbor_k = nearest_neighbor_k
+        # TODO support n_envs > 1
+        self.entropy_stats = RunningNorm(1)
+        self.obs_shape = get_obs_shape(observation_space)
+        self.replay_buffer_view = ReplayBufferView(
+            np.empty(0, dtype=observation_space.dtype), lambda: slice(0)
+        )
+
+    def set_buffer_view(self, replay_buffer_view: ReplayBufferView):
+        self.replay_buffer_view = replay_buffer_view
+
+    def __call__(
+        self,
+        state: np.ndarray,
+        action: np.ndarray,
+        next_state: np.ndarray,
+        done: np.ndarray,
+    ) -> np.ndarray:
+        # TODO: should this work with torch instead of numpy internally?
+        # (The RewardFn protocol requires numpy)
+
+        all_observations = self.replay_buffer_view.observations
+        # ReplayBuffer sampling flattens the venv dimension, let's adapt to that
+        all_observations = all_observations.reshape((-1, *self.obs_shape))
+        entropies = util.compute_state_entropy(
+            state,
+            all_observations,
+            self.nearest_neighbor_k,
+        )
+        normalized_entropies = self.entropy_stats.forward(th.as_tensor(entropies))
+        return normalized_entropies.numpy()
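
Usage sketch (illustrative, not part of the commit): wiring StateEntropyReward to a buffer of observations and scoring a batch of states, following the set_buffer_view / ReplayBufferView API added above. The import path is taken from the test imports further down; the Box space, array shapes, and k value are assumptions made for the example.

import numpy as np
from gym.spaces import Box
from stable_baselines3.common.preprocessing import get_obs_shape

from imitation.algorithms.pebble.entropy_reward import StateEntropyReward
from imitation.policies.replay_buffer_wrapper import ReplayBufferView

space = Box(low=-1.0, high=1.0, shape=(3,))  # illustrative observation space
obs_shape = get_obs_shape(space)
# (buffer_size, n_envs, *obs_shape), mirroring ReplayBuffer.observations
all_observations = np.random.rand(100, 1, *obs_shape)

reward_fn = StateEntropyReward(nearest_neighbor_k=5, observation_space=space)
# During training this would be wrapper.buffer_view from ReplayBufferRewardWrapper below.
reward_fn.set_buffer_view(ReplayBufferView(all_observations, lambda: slice(None)))

batch = np.random.rand(8, *obs_shape)  # states to score
placeholder = np.empty(obs_shape)      # action / next_state / done are unused by this reward
rewards = reward_fn(batch, placeholder, placeholder, placeholder)  # one normalized entropy per state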

src/imitation/policies/replay_buffer_wrapper.py

Lines changed: 31 additions & 1 deletion
@@ -24,6 +24,29 @@ def _samples_to_reward_fn_input(
     )


+class ReplayBufferView:
+    """A read-only view over valid records in a ReplayBuffer.
+
+    Args:
+        observations_buffer: Array buffer holding observations
+        buffer_slice_provider: Function returning slice of buffer
+            with valid observations
+    """
+
+    def __init__(
+        self,
+        observations_buffer: np.ndarray,
+        buffer_slice_provider: Callable[[], slice],
+    ):
+        self._observations_buffer = observations_buffer.view()
+        self._observations_buffer.flags.writeable = False
+        self._buffer_slice_provider = buffer_slice_provider
+
+    @property
+    def observations(self):
+        return self._observations_buffer[self._buffer_slice_provider()]
+
+
 class ReplayBufferRewardWrapper(ReplayBuffer):
     """Relabel the rewards in transitions sampled from a ReplayBuffer."""

@@ -83,6 +106,13 @@ def full(self) -> bool:  # type: ignore[override]
     def full(self, full: bool):
         self.replay_buffer.full = full

+    @property
+    def buffer_view(self) -> ReplayBufferView:
+        def valid_buffer_slice():
+            return slice(None) if self.full else slice(self.pos)
+
+        return ReplayBufferView(self.replay_buffer.observations, valid_buffer_slice)
+
     def sample(self, *args, **kwargs):
         samples = self.replay_buffer.sample(*args, **kwargs)
         rewards = self.reward_fn(**_samples_to_reward_fn_input(samples))

@@ -171,7 +201,7 @@ def sample(self, *args, **kwargs):
         all_obs = all_obs.reshape((-1, *self.obs_shape))
         entropies = util.compute_state_entropy(
             samples.observations,
-            all_obs.reshape((-1, *self.obs_shape)),
+            all_obs,
             self.k,
         )
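
For intuition, a standalone toy sketch (not from the commit) of why valid_buffer_slice is needed: SB3's ReplayBuffer writes into a fixed-size circular array, so only the first pos rows hold real data until the buffer wraps around, after which every row does.

import numpy as np

class ToyCircularBuffer:
    """Toy stand-in for SB3's ReplayBuffer storage (capacity x n_envs x obs_dim)."""

    def __init__(self, capacity: int):
        self.observations = np.zeros((capacity, 1, 1))
        self.pos = 0
        self.full = False

    def add(self, obs: float) -> None:
        self.observations[self.pos] = obs
        self.pos = (self.pos + 1) % len(self.observations)
        self.full = self.full or self.pos == 0

    def valid_buffer_slice(self) -> slice:
        # Same rule as ReplayBufferRewardWrapper.buffer_view above
        return slice(None) if self.full else slice(self.pos)

buf = ToyCircularBuffer(capacity=5)
for i in (1, 2, 3):
    buf.add(i)
print(buf.observations[buf.valid_buffer_slice()].ravel())  # [1. 2. 3.] -- only written rows

for i in (10, 11, 12, 13):
    buf.add(i)
print(buf.observations[buf.valid_buffer_slice()].ravel())  # all 5 slots, oldest entries overwritten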

src/imitation/util/networks.py

Lines changed: 3 additions & 0 deletions
@@ -86,6 +86,9 @@ def forward(self, x: th.Tensor) -> th.Tensor:
         with th.no_grad():
             self.update_stats(x)

+        return self.normalize(x)
+
+    def normalize(self, x: th.Tensor) -> th.Tensor:
         return (x - self.running_mean) / th.sqrt(self.running_var + self.eps)

     @abc.abstractmethod
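
A short sketch (assumed usage, not from the commit) of the split this hunk introduces: forward() refreshes the running statistics and then normalizes, while the new normalize() applies the current statistics without updating them, which is what the entropy reward and its tests rely on.

import torch as th

from imitation.util.networks import RunningNorm

norm = RunningNorm(1)  # one feature, as used by StateEntropyReward above
batch = th.tensor([[1.0], [2.0], [3.0]])

y = norm.forward(batch)    # updates the running mean/variance (in the default training mode), then normalizes
z = norm.normalize(batch)  # applies the current statistics without updating them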

src/imitation/util/util.py

Lines changed: 12 additions & 7 deletions
@@ -362,10 +362,10 @@ def get_first_iter_element(iterable: Iterable[T]) -> Tuple[T, Iterable[T]]:


 def compute_state_entropy(
-    obs: th.Tensor,
-    all_obs: th.Tensor,
+    obs: np.ndarray,
+    all_obs: np.ndarray,
     k: int,
-) -> th.Tensor:
+) -> np.ndarray:
     """Compute the state entropy given by KNN distance.

     Args:

@@ -379,14 +379,19 @@
     assert obs.shape[1:] == all_obs.shape[1:]
     with th.no_grad():
         non_batch_dimensions = tuple(range(2, len(obs.shape) + 1))
-        distances_tensor = th.linalg.vector_norm(
+        distances_tensor = np.linalg.norm(
             obs[:, None] - all_obs[None, :],
-            dim=non_batch_dimensions,
+            axis=non_batch_dimensions,
             ord=2,
         )

         # Note that we take the k+1'th value because the closest neighbor to
         # a point is itself, which we want to skip.
-        knn_dists = th.kthvalue(distances_tensor, k=k + 1, dim=1).values
+        knn_dists = kth_value(distances_tensor, k + 1)
         state_entropy = knn_dists
-        return state_entropy.unsqueeze(1)
+        return np.expand_dims(state_entropy, axis=1)
+
+
+def kth_value(x: np.ndarray, k: int):
+    assert k > 0
+    return np.partition(x, k - 1, axis=-1)[..., k - 1]
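
A small numpy illustration (not part of the commit) of what kth_value computes and why compute_state_entropy asks for k + 1: np.partition(x, k - 1) places the k-th smallest element at index k - 1 of the last axis, and when a state is compared against a buffer that contains it, the closest match is the state itself at distance 0, so the (k+1)-th smallest entry is the distance to the k-th real neighbour.

import numpy as np

distances = np.array([[0.0, 3.0, 1.0, 7.0, 5.0]])  # one state vs. a buffer of 5 states
k = 2

# k-th smallest along the last axis: the 2nd smallest of [0, 3, 1, 7, 5] is 1.0
kth = np.partition(distances, k - 1, axis=-1)[..., k - 1]
print(kth)  # [1.]

# Skipping the self-distance (the 0.0) means taking the (k+1)-th smallest: 3.0
kth_skip_self = np.partition(distances, k, axis=-1)[..., k]
print(kth_skip_self)  # [3.]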
Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
+from unittest.mock import patch
+
+import numpy as np
+import torch as th
+from gym.spaces import Discrete
+from stable_baselines3.common.preprocessing import get_obs_shape
+
+from imitation.algorithms.pebble.entropy_reward import StateEntropyReward
+from imitation.policies.replay_buffer_wrapper import ReplayBufferView
+from imitation.util import util
+
+SPACE = Discrete(4)
+PLACEHOLDER = np.empty(get_obs_shape(SPACE))
+
+BUFFER_SIZE = 20
+K = 4
+BATCH_SIZE = 8
+VENVS = 2
+
+
+def test_state_entropy_reward_returns_entropy(rng):
+    obs_shape = get_obs_shape(SPACE)
+    all_observations = rng.random((BUFFER_SIZE, VENVS, *obs_shape))
+
+    reward_fn = StateEntropyReward(K, SPACE)
+    reward_fn.set_buffer_view(ReplayBufferView(all_observations, lambda: slice(None)))
+
+    # Act
+    observations = rng.random((BATCH_SIZE, *obs_shape))
+    reward = reward_fn(observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)
+
+    # Assert
+    expected = util.compute_state_entropy(
+        observations, all_observations.reshape(-1, *obs_shape), K
+    )
+    expected_normalized = reward_fn.entropy_stats.normalize(th.as_tensor(expected)).numpy()
+    np.testing.assert_allclose(reward, expected_normalized)
+
+
+def test_state_entropy_reward_returns_normalized_values():
+    with patch("imitation.util.util.compute_state_entropy") as m:
+        # mock entropy computation so that we can test only stats collection in this test
+        m.side_effect = lambda obs, all_obs, k: obs
+
+        reward_fn = StateEntropyReward(K, SPACE)
+        all_observations = np.empty((BUFFER_SIZE, VENVS, *get_obs_shape(SPACE)))
+        reward_fn.set_buffer_view(
+            ReplayBufferView(all_observations, lambda: slice(None))
+        )
+
+        dim = 8
+        shift = 3
+        scale = 2
+
+        # Act
+        for _ in range(1000):
+            state = th.randn(dim) * scale + shift
+            reward_fn(state, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)
+
+        normalized_reward = reward_fn(
+            np.zeros(dim), PLACEHOLDER, PLACEHOLDER, PLACEHOLDER
+        )
+
+        # Assert
+        np.testing.assert_allclose(
+            normalized_reward,
+            np.repeat(-shift / scale, dim),
+            rtol=0.05,
+            atol=0.05,
+        )

tests/policies/test_replay_buffer_wrapper.py

Lines changed: 39 additions & 0 deletions
@@ -2,6 +2,7 @@

 import os.path as osp
 from typing import Type
+from unittest.mock import Mock

 import gym
 import numpy as np

@@ -10,7 +11,9 @@
 import torch as th
 from gym import spaces
 from stable_baselines3.common import buffers, off_policy_algorithm, policies
+from stable_baselines3.common.buffers import ReplayBuffer
 from stable_baselines3.common.policies import BasePolicy
+from stable_baselines3.common.preprocessing import get_obs_shape, get_action_dim
 from stable_baselines3.common.save_util import load_from_pkl
 from stable_baselines3.common.vec_env import DummyVecEnv

@@ -225,3 +228,39 @@ def test_entropy_wrapper_class(tmpdir, rng):
         k=k,
     )
     assert trained_entropy.mean() > initial_entropy.mean()
+
+
+def test_replay_buffer_view_provides_buffered_observations():
+    space = spaces.Box(np.array([0]), np.array([5]))
+    n_envs = 2
+    buffer_size = 10
+    action = np.empty((n_envs, get_action_dim(space)))
+
+    obs_shape = get_obs_shape(space)
+    wrapper = ReplayBufferRewardWrapper(
+        buffer_size,
+        space,
+        space,
+        replay_buffer_class=ReplayBuffer,
+        reward_fn=Mock(),
+        n_envs=n_envs,
+        handle_timeout_termination=False,
+    )
+    view = wrapper.buffer_view
+
+    # initially empty
+    assert len(view.observations) == 0
+
+    # after adding observation
+    obs1 = np.random.random((n_envs, *obs_shape))
+    wrapper.add(obs1, obs1, action, np.empty(n_envs), np.empty(n_envs), [])
+    np.testing.assert_allclose(view.observations, np.array([obs1]))
+
+    # after filling buffer
+    observations = np.random.random((buffer_size // n_envs, n_envs, *obs_shape))
+    for obs in observations:
+        wrapper.add(obs, obs, action, np.empty(n_envs), np.empty(n_envs), [])
+
+    # ReplayBuffer internally uses a circular buffer
+    expected = np.roll(observations, 1, axis=0)
+    np.testing.assert_allclose(view.observations, expected)

tests/util/test_util.py

Lines changed: 12 additions & 0 deletions
@@ -11,6 +11,7 @@

 from imitation.util import sacred as sacred_util
 from imitation.util import util
+from imitation.util.util import kth_value


 def test_endless_iter():

@@ -144,3 +145,14 @@ def test_compute_state_entropy_2d():
         util.compute_state_entropy(obs, all_obs, k=3),
         np.sqrt(20**2 + 2**2),
     )
+
+
+def test_kth_value():
+    arr1 = np.arange(0, 10, 1)
+    np.random.shuffle(arr1)
+    arr2 = np.arange(0, 100, 10)
+    np.random.shuffle(arr2)
+    arr = np.stack([arr1, arr2])
+
+    result = kth_value(arr, 3)
+    np.testing.assert_array_equal(result, np.array([2, 20]))
