1+ import pickle
12from unittest .mock import patch
23
34import numpy as np
@@ -33,7 +34,9 @@ def test_state_entropy_reward_returns_entropy(rng):
3334 expected = util .compute_state_entropy (
3435 observations , all_observations .reshape (- 1 , * obs_shape ), K
3536 )
36- expected_normalized = reward_fn .entropy_stats .normalize (th .as_tensor (expected )).numpy ()
37+ expected_normalized = reward_fn .entropy_stats .normalize (
38+ th .as_tensor (expected )
39+ ).numpy ()
3740 np .testing .assert_allclose (reward , expected_normalized )
3841
3942
@@ -44,7 +47,7 @@ def test_state_entropy_reward_returns_normalized_values():
4447
4548 reward_fn = StateEntropyReward (K , SPACE )
4649 all_observations = np .empty ((BUFFER_SIZE , VENVS , * get_obs_shape (SPACE )))
47- reward_fn .set_buffer_view (
50+ reward_fn .set_replay_buffer (
4851 ReplayBufferView (all_observations , lambda : slice (None ))
4952 )
5053
@@ -68,3 +71,24 @@ def test_state_entropy_reward_returns_normalized_values():
6871 rtol = 0.05 ,
6972 atol = 0.05 ,
7073 )
74+
75+
76+ def test_state_entropy_reward_can_pickle ():
77+ all_observations = np .empty ((BUFFER_SIZE , VENVS , * get_obs_shape (SPACE )))
78+ replay_buffer = ReplayBufferView (all_observations , lambda : slice (None ))
79+
80+ obs1 = np .random .rand (VENVS , * get_obs_shape (SPACE ))
81+ reward_fn = StateEntropyReward (K , SPACE )
82+ reward_fn .set_replay_buffer (replay_buffer )
83+ reward_fn (obs1 , PLACEHOLDER , PLACEHOLDER , PLACEHOLDER )
84+
85+ # Act
86+ pickled = pickle .dumps (reward_fn )
87+ reward_fn_deserialized = pickle .loads (pickled )
88+ reward_fn_deserialized .set_replay_buffer (replay_buffer )
89+
90+ # Assert
91+ obs2 = np .random .rand (VENVS , * get_obs_shape (SPACE ))
92+ expected_result = reward_fn (obs2 , PLACEHOLDER , PLACEHOLDER , PLACEHOLDER )
93+ actual_result = reward_fn_deserialized (obs2 , PLACEHOLDER , PLACEHOLDER , PLACEHOLDER )
94+ np .testing .assert_allclose (actual_result , expected_result )
0 commit comments