1111from imitation .util import util
1212
# Observation space shared by every test in this module.
SPACE = Discrete(4)
OBS_SHAPE = get_obs_shape(SPACE)
# Dummy array for reward-fn arguments whose values are irrelevant to a test.
PLACEHOLDER = np.empty(OBS_SHAPE)

BUFFER_SIZE = 20  # replay-buffer capacity (time steps)
K = 4  # nearest-neighbor count for the state-entropy estimate
BATCH_SIZE = 8
VENVS = 2  # number of parallel environments in the buffer layout
2021
2122
def test_pebble_entropy_reward_function_returns_learned_reward_initially():
    """Before unsupervised exploration starts, the wrapper delegates to the learned reward."""
    expected_reward = np.ones(1)
    learned_reward_mock = Mock(return_value=expected_reward)
    reward_fn = PebbleStateEntropyReward(learned_reward_mock, SPACE)

    # Act
    observations = np.ones((BATCH_SIZE, *OBS_SHAPE))
    reward = reward_fn(observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)

    # Assert: the learned reward was called once, with our arguments, and its
    # value was passed through unchanged.
    assert reward == expected_reward
    learned_reward_mock.assert_called_once_with(
        observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER
    )
38+
39+
def test_pebble_entropy_reward_function_returns_learned_reward_after_pre_training():
    """Once exploration has started and finished, delegate to the learned reward again."""
    expected_reward = np.ones(1)
    learned = Mock(return_value=expected_reward)
    reward_fn = PebbleStateEntropyReward(learned, SPACE)
    # Drive the internal state machine all the way to its final state.
    reward_fn.unsupervised_exploration_start()
    reward_fn.unsupervised_exploration_finish()

    # Act
    obs = np.ones((BATCH_SIZE, *OBS_SHAPE))
    reward = reward_fn(obs, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)

    # Assert
    assert reward == expected_reward
    learned.assert_called_once_with(obs, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)
58+
59+
def test_pebble_entropy_reward_returns_entropy_for_pretraining(rng):
    """During unsupervised exploration the reward is the normalized state entropy."""
    all_observations = rng.random((BUFFER_SIZE, VENVS, *OBS_SHAPE))

    reward_fn = PebbleStateEntropyReward(Mock(), SPACE, K)
    reward_fn.set_replay_buffer(
        ReplayBufferView(all_observations, lambda: slice(None)),
        OBS_SHAPE,
    )
    reward_fn.unsupervised_exploration_start()

    # Act
    observations = th.rand((BATCH_SIZE, *OBS_SHAPE))
    reward = reward_fn(observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)

    # Assert: k-NN entropy of the batch against the flattened replay buffer,
    # normalized by the running statistics kept in reward_fn.entropy_stats.
    raw_entropy = util.compute_state_entropy(
        observations,
        all_observations.reshape(-1, *OBS_SHAPE),
        K,
    )
    expected_normalized = reward_fn.entropy_stats.normalize(
        th.as_tensor(raw_entropy)
    ).numpy()
    np.testing.assert_allclose(reward, expected_normalized)
4381
4482
45- def test_pebble_entropy_reward_returns_normalized_values ():
83+ def test_pebble_entropy_reward_returns_normalized_values_for_pretraining ():
4684 with patch ("imitation.util.util.compute_state_entropy" ) as m :
4785 # mock entropy computation so that we can test only stats collection in this test
4886 m .side_effect = lambda obs , all_obs , k : obs
4987
5088 reward_fn = PebbleStateEntropyReward (Mock (), SPACE , K )
51- all_observations = np .empty ((BUFFER_SIZE , VENVS , * get_obs_shape ( SPACE ) ))
89+ all_observations = np .empty ((BUFFER_SIZE , VENVS , * OBS_SHAPE ))
5290 reward_fn .set_replay_buffer (
5391 ReplayBufferView (all_observations , lambda : slice (None )),
54- get_obs_shape ( SPACE ) ,
92+ OBS_SHAPE ,
5593 )
94+ reward_fn .unsupervised_exploration_start ()
5695
5796 dim = 8
5897 shift = 3
@@ -77,51 +116,25 @@ def test_pebble_entropy_reward_returns_normalized_values():
77116
78117
def test_pebble_entropy_reward_can_pickle():
    """A pickling round-trip preserves the reward function's behavior."""
    all_observations = np.empty((BUFFER_SIZE, VENVS, *OBS_SHAPE))
    replay_buffer = ReplayBufferView(all_observations, lambda: slice(None))

    reward_fn = PebbleStateEntropyReward(reward_fn_stub, SPACE, K)
    reward_fn.set_replay_buffer(replay_buffer, OBS_SHAPE)
    first_obs = np.random.rand(VENVS, *OBS_SHAPE)
    reward_fn(first_obs, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)

    # Act: round-trip through pickle. The replay buffer is not serialized,
    # so it must be re-attached after loading.
    roundtripped = pickle.loads(pickle.dumps(reward_fn))
    roundtripped.set_replay_buffer(replay_buffer, OBS_SHAPE)

    # Assert: original and deserialized functions agree on fresh observations.
    second_obs = np.random.rand(VENVS, *OBS_SHAPE)
    expected_result = reward_fn(second_obs, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)
    actual_result = roundtripped(second_obs, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)
    np.testing.assert_allclose(actual_result, expected_result)
98137
99138
100- def test_pebble_entropy_reward_function_switches_to_inner ():
101- obs_shape = get_obs_shape (SPACE )
102-
103- expected_reward = np .ones (1 )
104- reward_fn_mock = Mock ()
105- reward_fn_mock .return_value = expected_reward
106- reward_fn = PebbleStateEntropyReward (reward_fn_mock , SPACE )
107-
108- # Act
109- reward_fn .on_unsupervised_exploration_finished ()
110- observations = np .ones ((BATCH_SIZE , * obs_shape ))
111- reward = reward_fn (observations , PLACEHOLDER , PLACEHOLDER , PLACEHOLDER )
112-
113- # Assert
114- assert reward == expected_reward
115- reward_fn_mock .assert_called_once_with (
116- observations , PLACEHOLDER , PLACEHOLDER , PLACEHOLDER
117- )
118-
119-
def reward_fn_stub(state, action, next_state, done):
    """Trivial picklable reward function: echoes the state batch as the reward."""
    return state
0 commit comments