Skip to content

Commit 8e1780a

Browse files
committed
Fix some examples and plotting (with Qt)
1 parent f297a5a commit 8e1780a

File tree

9 files changed: +39 -21 lines changed

9 files changed: +39 -21 lines changed

Diff for: examples/bernoulli_gridworld.py

+3 -1
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,9 @@
88
def select_domain(map_, noise, episode_cap, **kwargs):
99
map_ = BernoulliGridWorld.default_map(map_ + ".txt")
1010
return BernoulliGridWorld(
11-
map_, random_start=True, noise=noise, episode_cap=episode_cap,
11+
map_,
12+
noise=noise,
13+
episode_cap=episode_cap,
1214
)
1315

1416

Diff for: examples/deepsea.py

+1
Original file line number | Diff line number | Diff line change
@@ -22,5 +22,6 @@ def select_domain(size, noise, **kwargs):
2222
click.Option(["--epsilon-min"], type=float, default=None),
2323
click.Option(["--beta"], type=float, default=0.05),
2424
click.Option(["--show-reward"], is_flag=True),
25+
click.Option(["--vi-threshold"], type=float, default=0.001),
2526
],
2627
)

Diff for: examples/fr_gridworld.py

-1
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,6 @@ def select_domain(map_, noise, step_penalty, episode_cap, **kwargs):
99
map_ = FixedRewardGridWorld.default_map(map_ + ".txt")
1010
return FixedRewardGridWorld(
1111
map_,
12-
random_start=True,
1312
noise=noise,
1413
step_penalty=step_penalty,
1514
episode_cap=episode_cap,

Diff for: examples/gridworld.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@ def select_domain(map_, noise, **kwargs):
99
random_goal = "RandomGoal" in map_
1010
map_ = GridWorld.default_map(map_ + ".txt")
1111
return GridWorld(
12-
map_, random_start=True, random_goal=random_goal, noise=noise, episode_cap=20
12+
map_, random_goal=random_goal, noise=noise, episode_cap=20
1313
)
1414

1515

Diff for: examples/mdp-solvers/fr_gridworld.py

+1 -3
Original file line number | Diff line number | Diff line change
@@ -12,9 +12,7 @@
1212

1313
def select_domain(map_, step_penalty, **kwargs):
1414
map_ = FixedRewardGridWorld.default_map(map_ + ".txt")
15-
return FixedRewardGridWorld(
16-
map_, random_start=True, noise=0.1, step_penalty=step_penalty
17-
)
15+
return FixedRewardGridWorld(map_, noise=0.1, step_penalty=step_penalty)
1816

1917

2018
def select_agent(name, domain, seed, threshold, **kwargs):

Diff for: examples/mdp-solvers/gridworld.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@
1212

1313
def select_domain(map_="4x5", **kwargs):
1414
map_ = GridWorld.default_map(map_ + ".txt")
15-
return GridWorld(map_, random_start=True, noise=0.1)
15+
return GridWorld(map_, noise=0.1)
1616

1717

1818
def select_agent(name, domain, seed, **kwargs):

Diff for: rlpy/domains/grid_world.py

+3
Original file line number | Diff line number | Diff line change
@@ -245,6 +245,7 @@ def show_domain(self, a=0, s=None, legend=False, noticks=False):
245245
self.agent_fig.remove()
246246
self.agent_fig = self._agent_fig(s)
247247
self.domain_fig.canvas.draw()
248+
self.domain_fig.show()
248249
if JUPYTER_MODE:
249250
if self.domain_display is None:
250251
self.domain_display = display(self.domain_fig, display_id=True) # noqa
@@ -567,6 +568,8 @@ def _init_value_vis(self):
567568
self.vf_fig.show()
568569

569570
def show_learning(self, representation):
571+
import matplotlib as mpl
572+
570573
if self.vf_ax is None:
571574
self._init_value_vis()
572575
self._reset_texts(self.vf_texts)

Diff for: rlpy/domains/pinball.py

+22 -12
Original file line number | Diff line number | Diff line change
@@ -108,7 +108,7 @@ def __init__(
108108
self.screen_width = screen_width
109109
self.screen_height = screen_height
110110

111-
def show_domain(self, _a=None):
111+
def show_domain(self, a=None):
112112
if self.screen is None:
113113
tk_window = Tk()
114114
tk_window.title("RLPy Pinball")
@@ -117,7 +117,10 @@ def show_domain(self, _a=None):
117117
self.screen.configure(background="LightGray")
118118
self.screen.pack()
119119
self.environment_view = PinballView(
120-
self.screen, width, height, self.environment,
120+
self.screen,
121+
width,
122+
height,
123+
self.environment,
121124
)
122125
self.environment_view.blit()
123126
self.screen.pack()
@@ -209,7 +212,11 @@ def __init__(
209212
for i in range(nrows * ncols):
210213
ax = self.fig.add_subplot(nrows, ncols, i + 1)
211214
img = ax.imshow(
212-
dummy_data, cmap=cmap, interpolation="nearest", vmin=vmin, vmax=vmax,
215+
dummy_data,
216+
cmap=cmap,
217+
interpolation="nearest",
218+
vmin=vmin,
219+
vmax=vmax,
213220
)
214221
cbar = ax.figure.colorbar(img, ax=ax)
215222
cbar.ax.set_ylabel("", rotation=-90, va="bottom")
@@ -234,7 +241,7 @@ def draw(self):
234241

235242
class BallModel:
236243

237-
""" This class maintains the state of the ball
244+
"""This class maintains the state of the ball
238245
in the pinball domain. It takes care of moving
239246
it according to the current velocity and drag coefficient.
240247
@@ -269,7 +276,7 @@ def step(self):
269276

270277
class PinballObstacle:
271278

272-
""" This class represents a single polygon obstacle in the
279+
"""This class represents a single polygon obstacle in the
273280
pinball domain and detects when a :class:`BallModel` hits it.
274281
275282
When a collision is detected, it also provides a way to
@@ -291,7 +298,7 @@ def __init__(self, points):
291298
self._intercept = None
292299

293300
def collision(self, ball):
294-
""" Determines if the ball hits this obstacle
301+
"""Determines if the ball hits this obstacle
295302
:param ball: An instance of :class:`BallModel`
296303
:type ball: :class:`BallModel`
297304
"""
@@ -371,7 +378,7 @@ def _select_edge(self, intersect1, intersect2, ball):
371378
return intersect2
372379

373380
def _angle(self, v1, v2):
374-
""" Compute the angle difference between two vectors
381+
"""Compute the angle difference between two vectors
375382
:param v1: The x,y coordinates of the vector
376383
:type: v1: list
377384
:param v2: The x,y coordinates of the vector
@@ -424,7 +431,11 @@ class PinballTarget:
424431
"""
425432

426433
def __init__(
427-
self, target_pos, target_rad, target_color="red", target_reward_scale=1.0,
434+
self,
435+
target_pos,
436+
target_rad,
437+
target_color="red",
438+
target_reward_scale=1.0,
428439
):
429440
if isinstance(target_pos[0], list):
430441
self.num_goals = len(target_pos)
@@ -481,8 +492,7 @@ class _DoubleCollision:
481492
pass
482493

483494
def __init__(self, config_file, random_state):
484-
""" Reads a configuration file for Pinball and draw the domain to screen
485-
"""
495+
"""Reads a configuration file for Pinball and draw the domain to screen"""
486496

487497
self.random_state = random_state
488498
self.action_effects = {
@@ -522,7 +532,7 @@ def sample_start(self):
522532
return self.start_positions[idx].copy()
523533

524534
def get_state(self):
525-
""" Access the current 4-dimensional state vector.
535+
"""Access the current 4-dimensional state vector.
526536
:returns: a list containing the x position, y position, xdot, ydot
527537
:rtype: np.ndarray
528538
"""
@@ -542,7 +552,7 @@ def _detect_collision(self):
542552
return PinballModel._Collision(dxdy)
543553

544554
def take_action(self, action):
545-
""" Take a step in the environment
555+
"""Take a step in the environment
546556
547557
:param action: The action to apply over the ball
548558
:type action: int

Diff for: rlpy/tools/plotting.py

+7 -2
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,6 @@
44
import matplotlib as mpl
55
from matplotlib import cm, colors, lines, rc # noqa
66
from matplotlib import pylab as pl
7-
from matplotlib import pyplot as plt
87
from matplotlib import patches as mpatches # noqa
98
from matplotlib import path as mpath # noqa
109
import numpy as np
@@ -27,6 +26,8 @@ def jupyter_mode(mode=True):
2726

2827

2928
def nogui_mode():
29+
from matplotlib import pyplot as plt
30+
3031
mpl.use("agg")
3132
plt.ioff()
3233

@@ -38,9 +39,13 @@ def _stub(*args, **kwargs):
3839

3940
# Try GUI backend first
4041
try:
41-
mpl.use("tkAgg")
42+
mpl.use("TkAgg")
43+
from matplotlib import pyplot as plt
44+
4245
plt.ion()
4346
except ImportError:
47+
from matplotlib import pyplot as plt
48+
4449
nogui_mode()
4550

4651

0 commit comments

Comments
 (0)