 from stable_baselines3.common import buffers, off_policy_algorithm, policies
 from stable_baselines3.common.buffers import ReplayBuffer
 from stable_baselines3.common.policies import BasePolicy
-from stable_baselines3.common.preprocessing import get_obs_shape, get_action_dim
+from stable_baselines3.common.preprocessing import get_action_dim, get_obs_shape
 from stable_baselines3.common.save_util import load_from_pkl
-from stable_baselines3.common.vec_env import DummyVecEnv
 
-from imitation.policies.replay_buffer_wrapper import (
-    ReplayBufferEntropyRewardWrapper,
-    ReplayBufferRewardWrapper,
-)
+from imitation.policies.replay_buffer_wrapper import ReplayBufferRewardWrapper
 from imitation.util import util
 
 
@@ -123,54 +119,6 @@ def test_wrapper_class(tmpdir, rng):
         replay_buffer_wrapper._get_samples()
 
 
-# Combine this with the above test via parameterization over the buffer class
-def test_entropy_wrapper_class_no_op(tmpdir, rng):
-    buffer_size = 15
-    total_timesteps = 20
-    entropy_samples = 0
-
-    venv = util.make_vec_env("Pendulum-v1", n_envs=1, rng=rng)
-    rl_algo = sb3.SAC(
-        policy=sb3.sac.policies.SACPolicy,
-        policy_kwargs=dict(),
-        env=venv,
-        seed=42,
-        replay_buffer_class=ReplayBufferEntropyRewardWrapper,
-        replay_buffer_kwargs=dict(
-            replay_buffer_class=buffers.ReplayBuffer,
-            reward_fn=zero_reward_fn,
-            entropy_as_reward_samples=entropy_samples,
-        ),
-        buffer_size=buffer_size,
-    )
-
-    rl_algo.learn(total_timesteps=total_timesteps)
-
-    buffer_path = osp.join(tmpdir, "buffer.pkl")
-    rl_algo.save_replay_buffer(buffer_path)
-    replay_buffer_wrapper = load_from_pkl(buffer_path)
-    replay_buffer = replay_buffer_wrapper.replay_buffer
-
-    # replay_buffer_wrapper.sample(...) should return zero-reward transitions
-    assert buffer_size == replay_buffer_wrapper.size() == replay_buffer.size()
-    assert (replay_buffer_wrapper.sample(total_timesteps).rewards == 0.0).all()
-    assert (replay_buffer.sample(total_timesteps).rewards != 0.0).all()  # seed=42
-
-    # replay_buffer_wrapper.pos, replay_buffer_wrapper.full
-    assert replay_buffer_wrapper.pos == total_timesteps - buffer_size
-    assert replay_buffer_wrapper.full
-
-    # reset()
-    replay_buffer_wrapper.reset()
-    assert 0 == replay_buffer_wrapper.size() == replay_buffer.size()
-    assert replay_buffer_wrapper.pos == 0
-    assert not replay_buffer_wrapper.full
-
-    # to_torch()
-    tensor = replay_buffer_wrapper.to_torch(np.ones(42))
-    assert type(tensor) is th.Tensor
-
-
 class ActionIsObsEnv(gym.Env):
     """Simple environment where the obs is the action."""
 
@@ -191,45 +139,6 @@ def reset(self):
         return np.array([0])
 
 
-def test_entropy_wrapper_class(tmpdir, rng):
-    buffer_size = 20
-    entropy_samples = 500
-    k = 4
-
-    venv = DummyVecEnv([ActionIsObsEnv])
-    rl_algo = sb3.SAC(
-        policy=sb3.sac.policies.SACPolicy,
-        policy_kwargs=dict(),
-        env=venv,
-        seed=42,
-        replay_buffer_class=ReplayBufferEntropyRewardWrapper,
-        replay_buffer_kwargs=dict(
-            replay_buffer_class=buffers.ReplayBuffer,
-            reward_fn=zero_reward_fn,
-            entropy_as_reward_samples=entropy_samples,
-            k=k,
-        ),
-        buffer_size=buffer_size,
-    )
-
-    rl_algo.learn(total_timesteps=buffer_size)
-    initial_entropy = util.compute_state_entropy(
-        th.Tensor(rl_algo.replay_buffer.replay_buffer.observations),
-        th.Tensor(rl_algo.replay_buffer.replay_buffer.observations),
-        k=k,
-    )
-
-    rl_algo.learn(total_timesteps=entropy_samples - buffer_size)
-    # Expect that the entropy of our replay buffer is now higher,
-    # since we trained with that as the reward.
-    trained_entropy = util.compute_state_entropy(
-        th.Tensor(rl_algo.replay_buffer.replay_buffer.observations),
-        th.Tensor(rl_algo.replay_buffer.replay_buffer.observations),
-        k=k,
-    )
-    assert trained_entropy.mean() > initial_entropy.mean()
-
-
 def test_replay_buffer_view_provides_buffered_observations():
     space = spaces.Box(np.array([0]), np.array([5]))
     n_envs = 2
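
The diff keeps ReplayBufferRewardWrapper as the only wrapper under test. Below is a minimal, illustrative sketch (not part of the diff) of how that wrapper is wired into an SB3 off-policy algorithm, mirroring the constructor pattern used by the removed entropy tests: the zero_reward_fn here is a stand-in for the test module's helper of the same name, and the exact keyword arguments should be checked against imitation's current ReplayBufferRewardWrapper signature.

import numpy as np
import stable_baselines3 as sb3
from stable_baselines3.common import buffers

from imitation.policies.replay_buffer_wrapper import ReplayBufferRewardWrapper
from imitation.util import util


def zero_reward_fn(state, action, next_state, done):
    """Relabel every transition with reward 0.0 (stand-in for the test helper)."""
    del action, next_state, done
    return np.zeros(state.shape[0], dtype=np.float32)


venv = util.make_vec_env("Pendulum-v1", n_envs=1, rng=np.random.default_rng(42))
rl_algo = sb3.SAC(
    policy=sb3.sac.policies.SACPolicy,
    env=venv,
    seed=42,
    # The wrapper replaces SAC's replay buffer: it delegates storage to the
    # inner ReplayBuffer and relabels rewards with reward_fn when sampling.
    replay_buffer_class=ReplayBufferRewardWrapper,
    replay_buffer_kwargs=dict(
        replay_buffer_class=buffers.ReplayBuffer,
        reward_fn=zero_reward_fn,
    ),
    buffer_size=15,
)
rl_algo.learn(total_timesteps=20)

# Sampling through the wrapper yields the relabeled (zero) rewards,
# as asserted by the tests retained in this file.
assert (rl_algo.replay_buffer.sample(10).rewards == 0.0).all()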