11import pickle
2- from unittest .mock import patch
2+ from unittest .mock import patch , Mock
33
44import numpy as np
55import torch as th
1919VENVS = 2
2020
2121
22- def test_state_entropy_reward_returns_entropy (rng ):
22+ def test_pebble_entropy_reward_returns_entropy (rng ):
2323 obs_shape = get_obs_shape (SPACE )
2424 all_observations = rng .random ((BUFFER_SIZE , VENVS , * obs_shape ))
2525
26-
27- reward_fn = PebbleStateEntropyReward (K , SPACE )
28- reward_fn .set_replay_buffer (ReplayBufferView (all_observations , lambda : slice (None )), obs_shape )
26+ reward_fn = PebbleStateEntropyReward (Mock (), SPACE , K )
27+ reward_fn .set_replay_buffer (
28+ ReplayBufferView (all_observations , lambda : slice (None )), obs_shape
29+ )
2930
3031 # Act
3132 observations = rng .random ((BATCH_SIZE , * obs_shape ))
@@ -41,16 +42,16 @@ def test_state_entropy_reward_returns_entropy(rng):
4142 np .testing .assert_allclose (reward , expected_normalized )
4243
4344
44- def test_state_entropy_reward_returns_normalized_values ():
45+ def test_pebble_entropy_reward_returns_normalized_values ():
4546 with patch ("imitation.util.util.compute_state_entropy" ) as m :
4647 # mock entropy computation so that we can test only stats collection in this test
4748 m .side_effect = lambda obs , all_obs , k : obs
4849
49- reward_fn = PebbleStateEntropyReward (K , SPACE )
50+ reward_fn = PebbleStateEntropyReward (Mock () , SPACE , K )
5051 all_observations = np .empty ((BUFFER_SIZE , VENVS , * get_obs_shape (SPACE )))
5152 reward_fn .set_replay_buffer (
5253 ReplayBufferView (all_observations , lambda : slice (None )),
53- get_obs_shape (SPACE )
54+ get_obs_shape (SPACE ),
5455 )
5556
5657 dim = 8
@@ -75,12 +76,12 @@ def test_state_entropy_reward_returns_normalized_values():
7576 )
7677
7778
78- def test_state_entropy_reward_can_pickle ():
79+ def test_pebble_entropy_reward_can_pickle ():
7980 all_observations = np .empty ((BUFFER_SIZE , VENVS , * get_obs_shape (SPACE )))
8081 replay_buffer = ReplayBufferView (all_observations , lambda : slice (None ))
8182
8283 obs1 = np .random .rand (VENVS , * get_obs_shape (SPACE ))
83- reward_fn = PebbleStateEntropyReward (K , SPACE )
84+ reward_fn = PebbleStateEntropyReward (reward_fn_stub , SPACE , K )
8485 reward_fn .set_replay_buffer (replay_buffer , get_obs_shape (SPACE ))
8586 reward_fn (obs1 , PLACEHOLDER , PLACEHOLDER , PLACEHOLDER )
8687
@@ -94,3 +95,33 @@ def test_state_entropy_reward_can_pickle():
9495 expected_result = reward_fn (obs2 , PLACEHOLDER , PLACEHOLDER , PLACEHOLDER )
9596 actual_result = reward_fn_deserialized (obs2 , PLACEHOLDER , PLACEHOLDER , PLACEHOLDER )
9697 np .testing .assert_allclose (actual_result , expected_result )
98+
99+
def test_pebble_entropy_reward_function_switches_to_inner():
    """Reward delegates to the inner function once exploration finishes.

    After ``on_unsupervised_exploration_finished`` is called, invoking the
    reward object is expected to forward its four arguments unchanged to the
    wrapped reward function and to return that function's result.
    """
    # Arrange
    obs_shape = get_obs_shape(SPACE)

    expected_reward = np.ones(1)
    reward_fn_mock = Mock()
    reward_fn_mock.return_value = expected_reward
    reward_fn = PebbleStateEntropyReward(reward_fn_mock, SPACE)

    # Act
    reward_fn.on_unsupervised_exploration_finished()
    observations = np.ones((BATCH_SIZE, *obs_shape))
    reward = reward_fn(observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)

    # Assert
    # Element-wise comparison: a bare ``assert reward == expected_reward``
    # only works while both arrays hold exactly one element.
    np.testing.assert_array_equal(reward, expected_reward)
    reward_fn_mock.assert_called_once_with(
        observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER
    )
118+
119+
def reward_fn_stub(
    state: np.ndarray,
    action: np.ndarray,
    next_state: np.ndarray,
    done: np.ndarray,
) -> np.ndarray:
    """Minimal stand-in for a reward function: echo ``state`` back.

    Takes the same four positional arguments the tests in this module pass to
    a reward function (state, action, next_state, done). It is a plain
    module-level function and not a method, so it has no ``self`` parameter.

    NOTE(review): presumably used in the pickle test instead of ``Mock``
    because that test needs a picklable callable -- confirm.

    Args:
        state: Value that is returned unchanged.
        action: Ignored.
        next_state: Ignored.
        done: Ignored.

    Returns:
        The ``state`` argument itself (same object, no copy).
    """
    return state
0 commit comments