77from stable_baselines3 .common .buffers import ReplayBuffer
88from stable_baselines3 .common .type_aliases import ReplayBufferSamples
99
10- from imitation .rewards .reward_function import RewardFn
10+ from imitation .rewards .reward_function import RewardFn , ReplayBufferAwareRewardFn
1111from imitation .util import util
1212
1313
@@ -37,13 +37,13 @@ def __init__(
3737 observations_buffer : np .ndarray ,
3838 buffer_slice_provider : Callable [[], slice ],
3939 ):
40- self ._observations_buffer = observations_buffer .view ()
41- self ._observations_buffer .flags .writeable = False
40+ self ._observations_buffer_view = observations_buffer .view ()
41+ self ._observations_buffer_view .flags .writeable = False
4242 self ._buffer_slice_provider = buffer_slice_provider
4343
4444 @property
4545 def observations (self ):
46- return self ._observations_buffer [self ._buffer_slice_provider ()]
46+ return self ._observations_buffer_view [self ._buffer_slice_provider ()]
4747
4848
4949class ReplayBufferRewardWrapper (ReplayBuffer ):
@@ -57,7 +57,6 @@ def __init__(
5757 * ,
5858 replay_buffer_class : Type [ReplayBuffer ],
5959 reward_fn : RewardFn ,
60- on_initialized_callback : Callable [["ReplayBufferRewardWrapper" ], None ] = None ,
6160 ** kwargs ,
6261 ):
6362 """Builds ReplayBufferRewardWrapper.
@@ -88,8 +87,8 @@ def __init__(
8887 self .reward_fn = reward_fn
8988 _base_kwargs = {k : v for k , v in kwargs .items () if k in ["device" , "n_envs" ]}
9089 super ().__init__ (buffer_size , observation_space , action_space , ** _base_kwargs )
91- if on_initialized_callback is not None :
92- on_initialized_callback (self )
90+ if isinstance ( reward_fn , ReplayBufferAwareRewardFn ) :
91+ reward_fn . on_replay_buffer_initialized (self )
9392
9493 # TODO(juan) remove the type ignore once the merged PR
9594 # https://github.com/python/mypy/pull/13475
0 commit comments