1+ """Reward function for the PEBBLE training algorithm."""
2+
13from enum import Enum , auto
2- from typing import Tuple
4+ from typing import Dict , Optional , Tuple , Union
35
46import numpy as np
57import torch as th
68
79from imitation .policies .replay_buffer_wrapper import (
8- ReplayBufferView ,
910 ReplayBufferRewardWrapper ,
11+ ReplayBufferView ,
1012)
1113from imitation .rewards .reward_function import ReplayBufferAwareRewardFn , RewardFn
1214from imitation .util import util
1315from imitation .util .networks import RunningNorm
1416
 
 class PebbleRewardPhase(Enum):
-    """States representing different behaviors for PebbleStateEntropyReward"""
+    """States representing different behaviors for PebbleStateEntropyReward."""
 
     UNSUPERVISED_EXPLORATION = auto()  # Entropy based reward
     POLICY_AND_REWARD_LEARNING = auto()  # Learned reward
 
 
 class PebbleStateEntropyReward(ReplayBufferAwareRewardFn):
-    """
-    Reward function for implementation of the PEBBLE learning algorithm
-    (https://arxiv.org/pdf/2106.05091.pdf).
+    """Reward function for implementation of the PEBBLE learning algorithm.
+
+    See https://arxiv.org/pdf/2106.05091.pdf.
 
     The rewards returned by this function go through the three phases:
     1. Before enough samples are collected for entropy calculation, the
@@ -38,33 +40,38 @@ class PebbleStateEntropyReward(ReplayBufferAwareRewardFn):
     supplied with set_replay_buffer() or on_replay_buffer_initialized().
     To transition to the last phase, unsupervised_exploration_finish() needs
     to be called.
-
-    Args:
-        learned_reward_fn: The learned reward function used after unsupervised
-            exploration is finished
-        nearest_neighbor_k: Parameter for entropy computation (see
-            compute_state_entropy())
     """
 
-    # TODO #625: parametrize nearest_neighbor_k
     def __init__(
         self,
         learned_reward_fn: RewardFn,
         nearest_neighbor_k: int = 5,
     ):
+        """Builds this class.
+
+        Args:
+            learned_reward_fn: The learned reward function used after unsupervised
+                exploration is finished
+            nearest_neighbor_k: Parameter for entropy computation (see
+                compute_state_entropy())
+        """
         self.learned_reward_fn = learned_reward_fn
         self.nearest_neighbor_k = nearest_neighbor_k
         self.entropy_stats = RunningNorm(1)
         self.state = PebbleRewardPhase.UNSUPERVISED_EXPLORATION
 
         # These two need to be set with set_replay_buffer():
-        self.replay_buffer_view = None
-        self.obs_shape = None
+        self.replay_buffer_view: Optional[ReplayBufferView] = None
+        self.obs_shape: Union[Tuple[int, ...], Dict[str, Tuple[int, ...]], None] = None
 
     def on_replay_buffer_initialized(self, replay_buffer: ReplayBufferRewardWrapper):
         self.set_replay_buffer(replay_buffer.buffer_view, replay_buffer.obs_shape)
 
-    def set_replay_buffer(self, replay_buffer: ReplayBufferView, obs_shape: Tuple):
+    def set_replay_buffer(
+        self,
+        replay_buffer: ReplayBufferView,
+        obs_shape: Union[Tuple[int, ...], Dict[str, Tuple[int, ...]]],
+    ):
         self.replay_buffer_view = replay_buffer
         self.obs_shape = obs_shape
 
@@ -87,7 +94,7 @@ def __call__(
     def _entropy_reward(self, state, action, next_state, done):
         if self.replay_buffer_view is None:
             raise ValueError(
-                "Replay buffer must be supplied before entropy reward can be used"
+                "Replay buffer must be supplied before entropy reward can be used",
             )
         all_observations = self.replay_buffer_view.observations
         # ReplayBuffer sampling flattens the venv dimension, let's adapt to that
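
Below is a minimal usage sketch of the phased behavior described in the class docstring. It assumes the class is importable from this module; the dummy learned reward, array shapes, and variable names are placeholders rather than library code, and during real training the replay buffer view is supplied automatically when ReplayBufferRewardWrapper calls on_replay_buffer_initialized().

import numpy as np

# Hypothetical stand-in for a reward model learned from human preferences.
def dummy_learned_reward(state, action, next_state, done):
    return np.zeros(len(state), dtype=np.float32)

reward_fn = PebbleStateEntropyReward(
    learned_reward_fn=dummy_learned_reward,
    nearest_neighbor_k=5,
)

# Phases 1-2: while the function is in UNSUPERVISED_EXPLORATION, rewards are
# entropy based once observations are available through set_replay_buffer()
# (normally called on our behalf via on_replay_buffer_initialized()).

# Phase 3: switch to the learned reward for policy and reward learning.
reward_fn.unsupervised_exploration_finish()

obs = np.zeros((4, 3), dtype=np.float32)   # placeholder (batch, obs_dim) observations
acts = np.zeros((4, 1), dtype=np.float32)  # placeholder actions
dones = np.zeros(4, dtype=bool)
rewards = reward_fn(obs, acts, obs, dones)  # delegates to dummy_learned_reward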
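
During the exploration phase the intrinsic reward is a particle-based state-entropy estimate: for each observation it is driven by the distance to its k-th nearest neighbor among the replay buffer observations (the nearest_neighbor_k parameter, used via the compute_state_entropy() referenced in the docstring), and the class normalizes the result with its RunningNorm entropy_stats. The sketch below is an illustrative re-derivation of that estimator, not the library's implementation; the function name, the flattening, and the +1.0 stabilizer inside the log are assumptions.

import torch as th

def knn_state_entropy(obs: th.Tensor, all_obs: th.Tensor, k: int = 5) -> th.Tensor:
    # Pairwise Euclidean distances between the query batch and the buffer
    # observations, shape (batch, buffer_size).
    dists = th.cdist(obs.flatten(1), all_obs.flatten(1))
    # The distance to the k-th nearest neighbor is a proxy for local state
    # density; its log grows as the state becomes more novel, which is
    # PEBBLE's intrinsic exploration reward.
    knn_dists = th.kthvalue(dists, k=k, dim=1).values
    return th.log(knn_dists + 1.0)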