11"""Reward function for the PEBBLE training algorithm."""
22
33import enum
4- from typing import Optional , Tuple
4+ from typing import Any , Callable , Optional , Tuple
55
66import gym
77import numpy as np
1818
1919
2020class InsufficientObservations (RuntimeError ):
21+ """Error signifying not enough observations for entropy calculation."""
22+
2123 pass
2224
2325
2426class EntropyRewardNet (RewardNet , ReplayBufferAwareRewardFn ):
27+ """RewardNet wrapping entropy reward function."""
28+
29+ __call__ : Callable [..., Any ] # Needed to appease pytype
30+
2531 def __init__ (
2632 self ,
2733 nearest_neighbor_k : int ,
@@ -53,6 +59,9 @@ def on_replay_buffer_initialized(self, replay_buffer: ReplayBufferRewardWrapper)

        This method needs to be called, e.g., after unpickling.
        See also __getstate__() / __setstate__().
+
+        Args:
+            replay_buffer: replay buffer with history of observations
        """
        assert self.observation_space == replay_buffer.observation_space
        assert self.action_space == replay_buffer.action_space
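
# Illustrative sketch (not part of this diff): the docstring above notes that the
# replay buffer must be re-attached after unpickling via
# on_replay_buffer_initialized(). A minimal, hypothetical version of that
# __getstate__/__setstate__ pattern could look like this; the real class may
# store and drop different attributes.
class _ReplayBufferAwareExample:
    def __init__(self) -> None:
        self._replay_buffer_view = None  # live buffer reference, not picklable

    def on_replay_buffer_initialized(self, replay_buffer) -> None:
        self._replay_buffer_view = replay_buffer

    def __getstate__(self):
        state = self.__dict__.copy()
        state["_replay_buffer_view"] = None  # drop the live reference when pickling
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        # on_replay_buffer_initialized() must be called again after loading.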
@@ -72,16 +81,18 @@ def forward(
        all_observations = self._replay_buffer_view.observations
        # ReplayBuffer sampling flattens the venv dimension, let's adapt to that
        all_observations = all_observations.reshape(
-            (-1,) + self.observation_space.shape
+            (-1,) + self.observation_space.shape,
        )

        if all_observations.shape[0] < self.nearest_neighbor_k:
            raise InsufficientObservations(
-                "Insufficient observations for entropy calculation"
+                "Insufficient observations for entropy calculation",
            )

        return util.compute_state_entropy(
-            state, all_observations, self.nearest_neighbor_k
+            state,
+            all_observations,
+            self.nearest_neighbor_k,
        )

    def preprocess(
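
# Illustrative sketch (not this library's util.compute_state_entropy): forward()
# above estimates state entropy with a particle-based estimator, i.e. the log of
# the distance to the k-th nearest neighbor among the replay-buffer observations,
# as in PEBBLE. The real implementation may batch the computation and operate on
# torch tensors; this numpy version only shows the idea.
import numpy as np

def knn_state_entropy(states: np.ndarray, all_obs: np.ndarray, k: int) -> np.ndarray:
    states = states.reshape(len(states), -1)
    all_obs = all_obs.reshape(len(all_obs), -1)
    # pairwise Euclidean distances, shape (n_states, n_buffer)
    dists = np.linalg.norm(states[:, None, :] - all_obs[None, :, :], axis=-1)
    # distance to the k-th nearest neighbor for each query state
    kth_dist = np.sort(dists, axis=1)[:, k - 1]
    return np.log(kth_dist + 1e-8)  # small epsilon guards against log(0)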
@@ -95,6 +106,15 @@ def preprocess(

        We also know forward() only works with state, so no need to convert
        other tensors.
+
+        Args:
+            state: The observation input.
+            action: The action input.
+            next_state: The observation input.
+            done: Whether the episode has terminated.
+
+        Returns:
+            Observations preprocessed by converting them to Tensor.
        """
        state_th = util.safe_to_tensor(state).to(self.device)
        action_th = next_state_th = done_th = th.empty(0)
@@ -172,8 +192,8 @@ def __call__(
            try:
                return self.entropy_reward_fn(state, action, next_state, done)
            except InsufficientObservations:
-                # not enough observations to compare to, fall back to the learned function;
-                # (falling back to a constant may also be ok)
+                # not enough observations to compare to, fall back to the learned
+                # function; (falling back to a constant may also be ok)
                return self.learned_reward_fn(state, action, next_state, done)
        else:
            return self.learned_reward_fn(state, action, next_state, done)
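
# Illustrative sketch (hypothetical helper, not this library's API): the
# __call__ above prefers the entropy-based reward and falls back to the learned
# reward when there are not yet enough observations to compare against. The same
# idea expressed as a plain function composing two reward callables:
def make_fallback_reward_fn(entropy_reward_fn, learned_reward_fn):
    def reward_fn(state, action, next_state, done):
        try:
            return entropy_reward_fn(state, action, next_state, done)
        except InsufficientObservations:
            # fall back to the learned reward (a constant would also work)
            return learned_reward_fn(state, action, next_state, done)

    return reward_fn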