Skip to content

Commit e725991

Browse files
committed
Refactor around @AbstractMethod
1 parent f37f098 commit e725991

File tree

4 files changed

+16
-11
lines changed

4 files changed

+16
-11
lines changed

rlpy/agents/agent.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def learn(self, s, p_actions, a, r, ns, np_actions, na, terminal):
9797
:param na: action taken in the next state
9898
:param terminal: boolean indicating whether next state (ns) is terminal
9999
"""
100-
return NotImplementedError
100+
pass
101101

102102
def episode_terminated(self):
103103
"""

rlpy/domains/cartpole_base.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -251,13 +251,15 @@ def _getReward(self, a, s=None):
251251
(e.g. for swinging the pendulum up vs. balancing).
252252
253253
"""
254-
raise NotImplementedError
254+
raise NotImplementedError("_getReward should be implemented in child classes")
255255

256+
@abstractmethod
256257
def show_domain(self, a=0):
257-
raise NotImplementedError
258+
pass
258259

260+
@abstractmethod
259261
def show_learning(self, representation):
260-
raise NotImplementedError
262+
pass
261263

262264
def possible_actions(self, s=None):
263265
"""
@@ -266,9 +268,11 @@ def possible_actions(self, s=None):
266268
"""
267269
return np.arange(self.actions_num)
268270

271+
@abstractmethod
269272
def step(self):
270-
errMsg = "Implemented in child classes which call _stepFourState()"
271-
raise NotImplementedError(errMsg)
273+
"""Implemented in child classes which call _stepFourState()
274+
"""
275+
pass
272276

273277
def _stepFourState(self, s, a):
274278
"""
@@ -656,7 +660,7 @@ def euler_int(self, df, x0, times):
656660
return [ns]
657661

658662

659-
class StateIndex(object):
663+
class StateIndex:
660664

661665
"""
662666
Flexible way to index states in the CartPole Domain.

rlpy/domains/domain.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ def show_learning(self, representation):
164164
"""
165165
pass
166166

167+
@abstractmethod
167168
def s0(self):
168169
"""
169170
Begins a new episode and returns the initial observed state of the Domain.
@@ -172,7 +173,7 @@ def s0(self):
172173
:return: A numpy array that defines the initial domain state.
173174
174175
"""
175-
raise NotImplementedError("Children need to implement this method")
176+
pass
176177

177178
def possible_actions(self, s=None):
178179
"""

rlpy/policies/policy.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def pi(self, s, terminal, p_actions):
6666
:param terminal: boolean, whether or not the *s* is a terminal state.
6767
:param p_actions: a list / array of all possible actions in *s*.
6868
"""
69-
raise NotImplementedError
69+
pass
7070

7171
def turnOffExploration(self):
7272
"""
@@ -105,7 +105,7 @@ def pi(self, s, terminal, p_actions):
105105
@abstractmethod
106106
def dlogpi(self, s, a):
107107
"""derivative of the log probabilities of the policy"""
108-
return NotImplementedError
108+
pass
109109

110110
def prob(self, s, a):
111111
"""
@@ -128,4 +128,4 @@ def probabilities(self, s, terminal):
128128
returns a vector of num_actions length containing the normalized
129129
probabilities for taking each action given the state s
130130
"""
131-
return NotImplementedError
131+
pass

0 commit comments

Comments
 (0)