class PebbleRewardPhase(Enum):
    """States representing different behaviors for PebbleStateEntropyReward"""

-    # Collecting samples so that we have something for entropy calculation
-    LEARNING_START = auto()
-    # Entropy based reward
-    UNSUPERVISED_EXPLORATION = auto()
-    # Learned reward
-    POLICY_AND_REWARD_LEARNING = auto()
+    UNSUPERVISED_EXPLORATION = auto()  # Entropy based reward
+    POLICY_AND_REWARD_LEARNING = auto()  # Learned reward


class PebbleStateEntropyReward(ReplayBufferAwareRewardFn):
    """
    Reward function for implementation of the PEBBLE learning algorithm
    (https://arxiv.org/pdf/2106.05091.pdf).

-    The rewards returned by this function go through the three phases
-    defined in PebbleRewardPhase. To transition between these phases,
-    unsupervised_exploration_start() and unsupervised_exploration_finish()
-    need to be called.
+    The rewards returned by this function go through three phases:
+    1. Before enough samples are collected for the entropy calculation, the
+       underlying learned function is returned. This shouldn't matter because
+       OffPolicyAlgorithms have an initialization period of `learning_starts`
+       timesteps.
+    2. During the unsupervised exploration phase, an entropy-based reward is
+       returned.
+    3. After the unsupervised exploration phase is finished, the underlying
+       learned reward is returned.

-    The second phase (UNSUPERVISED_EXPLORATION) also requires that a buffer
-    with observations to compare against is supplied with set_replay_buffer()
-    or on_replay_buffer_initialized().
+    The second phase requires that a buffer with observations to compare
+    against is supplied with set_replay_buffer() or on_replay_buffer_initialized().
+    To transition to the last phase, unsupervised_exploration_finish() needs
+    to be called.

    Args:
        learned_reward_fn: The learned reward function used after unsupervised
@@ -51,11 +52,10 @@ def __init__(
        learned_reward_fn: RewardFn,
        nearest_neighbor_k: int = 5,
    ):
-        self.trained_reward_fn = learned_reward_fn
+        self.learned_reward_fn = learned_reward_fn
        self.nearest_neighbor_k = nearest_neighbor_k
-        # TODO support n_envs > 1
        self.entropy_stats = RunningNorm(1)
-        self.state = PebbleRewardPhase.LEARNING_START
+        self.state = PebbleRewardPhase.UNSUPERVISED_EXPLORATION

        # These two need to be set with set_replay_buffer():
        self.replay_buffer_view = None
@@ -68,10 +68,6 @@ def set_replay_buffer(self, replay_buffer: ReplayBufferView, obs_shape: Tuple):
        self.replay_buffer_view = replay_buffer
        self.obs_shape = obs_shape

-    def unsupervised_exploration_start(self):
-        assert self.state == PebbleRewardPhase.LEARNING_START
-        self.state = PebbleRewardPhase.UNSUPERVISED_EXPLORATION
-
    def unsupervised_exploration_finish(self):
        assert self.state == PebbleRewardPhase.UNSUPERVISED_EXPLORATION
        self.state = PebbleRewardPhase.POLICY_AND_REWARD_LEARNING
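For orientation, here is a rough sketch of how the simplified two-phase lifecycle above is meant to be driven from training code. It only uses the surface shown in this diff; `learned_fn`, `replay_buffer_view`, `obs_shape` and the observation/action arrays are placeholders for illustration, not part of the change.

# Hypothetical driver code; all lowercase variables are placeholders.
reward_fn = PebbleStateEntropyReward(learned_reward_fn=learned_fn, nearest_neighbor_k=5)

# Prerequisite for the entropy reward: access to buffered observations.
reward_fn.set_replay_buffer(replay_buffer_view, obs_shape)

# Starts in UNSUPERVISED_EXPLORATION: returns normalized entropy rewards, or
# falls back to `learned_fn` while the buffer holds fewer than
# `nearest_neighbor_k` observations.
intrinsic = reward_fn(obs, acts, next_obs, dones)

# Switch to the final phase: from here on the learned reward is returned.
reward_fn.unsupervised_exploration_finish()
learned = reward_fn(obs, acts, next_obs, dones)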
@@ -84,26 +80,30 @@ def __call__(
        done: np.ndarray,
    ) -> np.ndarray:
        if self.state == PebbleRewardPhase.UNSUPERVISED_EXPLORATION:
-            return self._entropy_reward(state)
+            return self._entropy_reward(state, action, next_state, done)
        else:
-            return self.trained_reward_fn(state, action, next_state, done)
+            return self.learned_reward_fn(state, action, next_state, done)

-    def _entropy_reward(self, state):
+    def _entropy_reward(self, state, action, next_state, done):
        if self.replay_buffer_view is None:
            raise ValueError(
                "Replay buffer must be supplied before entropy reward can be used"
            )
-
        all_observations = self.replay_buffer_view.observations
        # ReplayBuffer sampling flattens the venv dimension, let's adapt to that
        all_observations = all_observations.reshape((-1, *self.obs_shape))
-        # TODO #625: deal with the conversion back and forth between np and torch
-        entropies = util.compute_state_entropy(
-            th.tensor(state),
-            th.tensor(all_observations),
-            self.nearest_neighbor_k,
-        )
-        normalized_entropies = self.entropy_stats.forward(entropies)
+
+        if all_observations.shape[0] < self.nearest_neighbor_k:
+            # not enough observations to compare to, fall back to the learned function
+            return self.learned_reward_fn(state, action, next_state, done)
+        else:
+            # TODO #625: deal with the conversion back and forth between np and torch
+            entropies = util.compute_state_entropy(
+                th.tensor(state),
+                th.tensor(all_observations),
+                self.nearest_neighbor_k,
+            )
+            normalized_entropies = self.entropy_stats.forward(entropies)
        return normalized_entropies.numpy()

    def __getstate__(self):
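As a rough aide to reading the diff: PEBBLE's intrinsic reward is a particle-based entropy estimate, essentially the distance from each batch state to its k-th nearest neighbor among the buffered observations. The sketch below is an illustrative stand-in for the `util.compute_state_entropy` call used above, assuming flat `(batch, features)` inputs; it is not the library implementation.

import torch as th

def knn_state_entropy(obs: th.Tensor, all_obs: th.Tensor, k: int) -> th.Tensor:
    """Illustrative k-NN entropy proxy, not the library's helper."""
    # Pairwise Euclidean distances between the batch and all buffered observations.
    dists = th.cdist(obs.flatten(1).float(), all_obs.flatten(1).float())
    # Distance to the k-th nearest neighbor, used as the (unnormalized) intrinsic reward.
    return th.kthvalue(dists, k, dim=1).values

In the class above, the raw entropies are additionally normalized with running statistics via `self.entropy_stats` (a `RunningNorm`) before being returned as the reward.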