@@ -6,7 +6,7 @@
 from gym.spaces import Discrete
 from stable_baselines3.common.preprocessing import get_obs_shape

-from imitation.algorithms.pebble.entropy_reward import StateEntropyReward
+from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward
 from imitation.policies.replay_buffer_wrapper import ReplayBufferView
 from imitation.util import util

@@ -24,7 +24,7 @@ def test_state_entropy_reward_returns_entropy(rng):
     all_observations = rng.random((BUFFER_SIZE, VENVS, *obs_shape))


-    reward_fn = StateEntropyReward(K, SPACE)
+    reward_fn = PebbleStateEntropyReward(K, SPACE)
     reward_fn.set_replay_buffer(ReplayBufferView(all_observations, lambda: slice(None)), obs_shape)

     # Act
@@ -46,7 +46,7 @@ def test_state_entropy_reward_returns_normalized_values():
         # mock entropy computation so that we can test only stats collection in this test
         m.side_effect = lambda obs, all_obs, k: obs

-        reward_fn = StateEntropyReward(K, SPACE)
+        reward_fn = PebbleStateEntropyReward(K, SPACE)
         all_observations = np.empty((BUFFER_SIZE, VENVS, *get_obs_shape(SPACE)))
         reward_fn.set_replay_buffer(
             ReplayBufferView(all_observations, lambda: slice(None)),
@@ -80,7 +80,7 @@ def test_state_entropy_reward_can_pickle():
     replay_buffer = ReplayBufferView(all_observations, lambda: slice(None))

     obs1 = np.random.rand(VENVS, *get_obs_shape(SPACE))
-    reward_fn = StateEntropyReward(K, SPACE)
+    reward_fn = PebbleStateEntropyReward(K, SPACE)
     reward_fn.set_replay_buffer(replay_buffer, get_obs_shape(SPACE))
     reward_fn(obs1, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)

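For reviewers trying the renamed class locally, here is a minimal usage sketch assembled only from the calls these tests exercise. The concrete values chosen for K, SPACE, BUFFER_SIZE, and VENVS are illustrative assumptions rather than the fixtures' actual constants, and the placeholder arguments mirror how the tests invoke the reward function.

import numpy as np
from gym.spaces import Discrete
from stable_baselines3.common.preprocessing import get_obs_shape

from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward
from imitation.policies.replay_buffer_wrapper import ReplayBufferView

K = 4                       # nearest-neighbour count for the entropy estimate (illustrative)
SPACE = Discrete(4)         # observation space (illustrative)
BUFFER_SIZE, VENVS = 20, 2  # buffer length and number of venvs (illustrative)

obs_shape = get_obs_shape(SPACE)
all_observations = np.random.rand(BUFFER_SIZE, VENVS, *obs_shape)

# The reward reads past observations through a view of the replay buffer;
# `lambda: slice(None)` exposes the whole buffer, exactly as in the tests above.
reward_fn = PebbleStateEntropyReward(K, SPACE)
reward_fn.set_replay_buffer(
    ReplayBufferView(all_observations, lambda: slice(None)),
    obs_shape,
)

# Query entropy-based rewards for a batch of current observations. The tests
# pass placeholders for the remaining arguments, so empty arrays suffice here.
obs = np.random.rand(VENVS, *obs_shape)
placeholder = np.empty(0)
rewards = reward_fn(obs, placeholder, placeholder, placeholder)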