Skip to content

Commit 038c3a0

Browse files
committed
[Pinball] Use json instead of original format
1 parent 2533f9f commit 038c3a0

13 files changed

+134
-127
lines changed

.gitignore

+2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ Config.py
1717
dist
1818
*.egg-info
1919
*.bak
20+
Untitled.ipynb
2021
.eggs
22+
/**/.ipynb_checkpoints
2123
/**/Results
2224
/**/Result
2325
MANIFEST

examples/pinball.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
1+
import click
12
from rlpy.domains import Pinball
23
from rlpy.tools.cli import run_experiment
34

45
import methods
56

67

7-
def select_domain(noise=0.1):
8-
return Pinball(noise=noise)
8+
def select_domain(cfg, noise=0.1):
9+
if not cfg.startswith("pinball_"):
10+
cfg = "pinball_" + cfg
11+
cfg = Pinball.default_cfg(cfg + ".json")
12+
return Pinball(noise=noise, config_file=cfg)
913

1014

11-
def select_agent(name, domain, max_steps, seed):
15+
def select_agent(name, domain, max_steps, seed, **kwargs):
1216
if name is None or name == "fourier-q":
1317
return methods.fourier_q(domain, order=5)
1418
elif name == "fourier-sarsa":
@@ -36,4 +40,7 @@ def select_agent(name, domain, max_steps, seed):
3640
default_max_steps=100000,
3741
default_num_policy_checks=30,
3842
default_checks_per_policy=1,
43+
other_options=[
44+
click.Option(["--cfg"], type=str, default="pinball_simple_single")
45+
],
3946
)

rlpy/domains/PinballConfigs/pinball_box.cfg

-10
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
{
2+
"ball_rad": 0.02,
3+
"obstacles": [
4+
[[0.0, 0.0], [0.0, 0.01], [1.0, 0.01], [1.0, 0.0]],
5+
[[0.0, 0.0], [0.01, 0.0], [0.01, 1.0], [0.0, 1.0]],
6+
[[0.0, 1.0], [0.0, 0.99], [1.0, 0.99], [1.0, 1.0]],
7+
[[1.0, 1.0], [0.99, 1.0], [0.99, 0.0], [1.0, 0.0]],
8+
[[0.45, 0.45], [0.55, 0.45], [0.55, 0.55], [0.45, 0.55]]
9+
],
10+
"start_pos": [[0.2, 0.9]],
11+
"target_pos": [0.9, 0.2],
12+
"target_rad": 0.04
13+
}

rlpy/domains/PinballConfigs/pinball_empty.cfg

-8
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
{
2+
"ball_rad": 0.02,
3+
"obstacles": [
4+
[[0.0, 0.0], [0.0, 0.01], [1.0, 0.01], [1.0, 0.0]],
5+
[[0.0, 0.0], [0.01, 0.0], [0.01, 1.0], [0.0, 1.0]],
6+
[[0.0, 1.0], [0.0, 0.99], [1.0, 0.99], [1.0, 1.0]],
7+
[[1.0, 1.0], [0.99, 1.0], [0.99, 0.0], [1.0, 0.0]]
8+
],
9+
"start_pos": [[0.2, 0.9]],
10+
"target_pos": [0.9, 0.2],
11+
"target_rad": 0.04
12+
}

rlpy/domains/PinballConfigs/pinball_hard_single.cfg

-22
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
{
2+
"ball_rad": 0.015,
3+
"obstacles": [
4+
[[0.0, 0.0], [0.0, 0.01], [1.0, 0.01], [1.0, 0.0]],
5+
[[0.0, 0.0], [0.01, 0.0], [0.01, 1.0], [0.0, 1.0]],
6+
[[0.0, 1.0], [0.0, 0.99], [1.0, 0.99], [1.0, 1.0]],
7+
[[1.0, 1.0], [0.99, 1.0], [0.99, 0.0], [1.0, 0.0]],
8+
[[0.034, 0.852], [0.106, 0.708], [0.33199999999999996, 0.674], [0.17599999999999996, 0.618], [0.028, 0.718]],
9+
[[0.15, 0.7559999999999999], [0.142, 0.93], [0.232, 0.894], [0.238, 0.99], [0.498, 0.722]],
10+
[[0.8079999999999999, 0.91], [0.904, 0.784], [0.7799999999999999, 0.572], [0.942, 0.562], [0.952, 0.82], [0.874, 0.934]],
11+
[[0.768, 0.814], [0.692, 0.548], [0.594, 0.47], [0.606, 0.804], [0.648, 0.626]],
12+
[[0.22799999999999998, 0.5760000000000001], [0.39, 0.322], [0.3400000000000001, 0.31400000000000006], [0.184, 0.456]],
13+
[[0.09, 0.228], [0.242, 0.076], [0.106, 0.03], [0.022, 0.178]],
14+
[[0.11, 0.278], [0.24600000000000002, 0.262], [0.108, 0.454], [0.16, 0.566], [0.064, 0.626], [0.016, 0.438]],
15+
[[0.772, 0.1], [0.71, 0.20599999999999996], [0.77, 0.322], [0.894, 0.09600000000000002], [0.8039999999999999, 0.17600000000000002]],
16+
[[0.698, 0.476], [0.984, 0.27199999999999996], [0.908, 0.512]],
17+
[[0.45, 0.39199999999999996], [0.614, 0.25799999999999995], [0.7340000000000001, 0.438]],
18+
[[0.476, 0.868], [0.552, 0.8119999999999999], [0.62, 0.902], [0.626, 0.972], [0.49, 0.958]],
19+
[[0.61, 0.014000000000000002], [0.58, 0.094], [0.774, 0.05000000000000001], [0.63, 0.054000000000000006]],
20+
[[0.33399999999999996, 0.014], [0.27799999999999997, 0.03799999999999998], [0.368, 0.254], [0.7, 0.20000000000000004], [0.764, 0.108], [0.526, 0.158]],
21+
[[0.294, 0.584], [0.478, 0.626], [0.482, 0.574], [0.324, 0.434], [0.35, 0.39], [0.572, 0.52], [0.588, 0.722], [0.456, 0.668]]],
22+
"start_pos": [[0.055, 0.95]],
23+
"target_pos": [0.5, 0.06],
24+
"target_rad": 0.04
25+
}

rlpy/domains/PinballConfigs/pinball_medium.cfg

-15
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
{
2+
"ball_rad": 0.02,
3+
"obstacles": [
4+
[[0.0, 0.0], [0.0, 0.01], [1.0, 0.01], [1.0, 0.0]],
5+
[[0.0, 0.0], [0.01, 0.0], [0.01, 1.0], [0.0, 1.0]],
6+
[[0.0, 1.0], [0.0, 0.99], [1.0, 0.99], [1.0, 1.0]],
7+
[[1.0, 1.0], [0.99, 1.0], [0.99, 0.0], [1.0, 0.0]],
8+
[[0.09, 0.228], [0.242, 0.076], [0.106, 0.03], [0.022, 0.178]],
9+
[[0.33399999999999996, 0.014], [0.27799999999999997, 0.03799999999999998], [0.368, 0.254], [0.7, 0.20000000000000004], [0.764, 0.108], [0.526, 0.158]],
10+
[[0.034, 0.852], [0.106, 0.708], [0.33199999999999996, 0.674], [0.17599999999999996, 0.618], [0.028, 0.718]],
11+
[[0.45, 0.39199999999999996], [0.614, 0.25799999999999995], [0.7340000000000001, 0.438]],
12+
[[0.33399999999999996, 0.014], [0.27799999999999997, 0.03799999999999998], [0.368, 0.254], [0.7, 0.20000000000000004], [0.764, 0.108], [0.526, 0.158]],
13+
[[0.294, 0.584], [0.478, 0.626], [0.482, 0.574], [0.324, 0.434], [0.35, 0.39], [0.572, 0.52], [0.588, 0.722], [0.456, 0.668]]
14+
],
15+
"start_pos": [[0.2, 0.9]],
16+
"target_pos": [0.9, 0.2],
17+
"target_rad": 0.04
18+
}

rlpy/domains/PinballConfigs/pinball_simple_single.cfg

-15
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
{
2+
"ball_rad": 0.02,
3+
"obstacles": [
4+
[[0.0, 0.0], [0.0, 0.01], [1.0, 0.01], [1.0, 0.0]],
5+
[[0.0, 0.0], [0.01, 0.0], [0.01, 1.0], [0.0, 1.0]],
6+
[[0.0, 1.0], [0.0, 0.99], [1.0, 0.99], [1.0, 1.0]],
7+
[[1.0, 1.0], [0.99, 1.0], [0.99, 0.0], [1.0, 0.0]],
8+
[[0.35, 0.4], [0.45, 0.55], [0.43, 0.65], [0.3, 0.7], [0.45, 0.7], [0.5, 0.6], [0.45, 0.35]],
9+
[[0.2, 0.6], [0.25, 0.55], [0.15, 0.5], [0.15, 0.45], [0.2, 0.3], [0.12, 0.27], [0.075, 0.35], [0.09, 0.55]],
10+
[[0.3, 0.8], [0.6, 0.75], [0.8, 0.8], [0.8, 0.9], [0.6, 0.85], [0.3, 0.9]],
11+
[[0.8, 0.7], [0.975, 0.65], [0.75, 0.5], [0.9, 0.3], [0.7, 0.35], [0.63, 0.65]],
12+
[[0.6, 0.25], [0.3, 0.07], [0.15, 0.175], [0.15, 0.2], [0.3, 0.175], [0.6, 0.3]],
13+
[[0.75, 0.025], [0.8, 0.24], [0.725, 0.27], [0.7, 0.025]]
14+
],
15+
"start_pos": [[0.2, 0.9]],
16+
"target_pos": [0.9, 0.2],
17+
"target_rad": 0.04
18+
}

rlpy/domains/pinball.py

+36-54
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy as np
55
from itertools import tee
66
import itertools
7-
import os
7+
from pathlib import Path
88

99
try:
1010
from tkinter import Tk, Canvas
@@ -57,24 +57,24 @@ class Pinball(Domain):
5757
"""
5858

5959
#: default location of config files shipped with rlpy
60-
default_config_dir = os.path.join(__rlpy_location__, "domains", "PinballConfigs")
60+
DEFAULT_CONFIG_DIR = Path(__rlpy_location__).joinpath("domains/PinballConfigs")
61+
62+
@classmethod
63+
def default_cfg(cls, name="pinball_simple_single.json"):
64+
return cls.DEFAULT_CONFIG_DIR.joinpath(name)
6165

6266
def __init__(
6367
self,
6468
noise=0.1,
6569
episode_cap=1000,
66-
configuration=os.path.join(default_config_dir, "pinball_simple_single.cfg"),
70+
config_file=DEFAULT_CONFIG_DIR.joinpath("pinball_simple_single.json"),
6771
):
6872
"""
69-
configuration:
70-
location of the configuration file
71-
episode_cap:
72-
maximum length of an episode
73-
noise:
74-
with probability noise, a uniformly random action is executed
73+
:param config_file: Location of the configuration file.
74+
:param episode_cap: Maximum length of an episode
75+
:param noise: With probability noise, a uniformly random action is executed
7576
"""
7677
self.NOISE = noise
77-
self.configuration = configuration
7878
self.screen = None
7979
self.actions = [
8080
PinballModel.ACC_X,
@@ -91,14 +91,12 @@ def __init__(
9191
continuous_dims=[4],
9292
episode_cap=episode_cap,
9393
)
94-
self.environment = PinballModel(
95-
self.configuration, random_state=self.random_state
96-
)
94+
self.environment = PinballModel(config_file, random_state=self.random_state)
9795

9896
def show_domain(self, a):
9997
if self.screen is None:
10098
master = Tk()
101-
master.title("RLPY Pinball")
99+
master.title("RLPy Pinball")
102100
self.screen = Canvas(master, width=500.0, height=500.0)
103101
self.screen.configure(background="LightGray")
104102
self.screen.pack()
@@ -149,7 +147,7 @@ def is_terminal(self):
149147
return self.environment.episode_ended()
150148

151149

152-
class BallModel(object):
150+
class BallModel:
153151

154152
""" This class maintains the state of the ball
155153
in the pinball domain. It takes care of moving
@@ -202,7 +200,7 @@ def _clip(self, val, low=-2, high=2):
202200
return val
203201

204202

205-
class PinballObstacle(object):
203+
class PinballObstacle:
206204

207205
""" This class represents a single polygon obstacle in the
208206
pinball domain and detects when a :class:`BallModel` hits it.
@@ -216,7 +214,7 @@ def __init__(self, points):
216214
:param points: A list of points defining the polygon
217215
:type points: list of lists
218216
"""
219-
self.points = points
217+
self.points = np.array(points)
220218
self.min_x = min(self.points, key=lambda pt: pt[0])[0]
221219
self.max_x = max(self.points, key=lambda pt: pt[0])[0]
222220
self.min_y = min(self.points, key=lambda pt: pt[1])[1]
@@ -375,7 +373,7 @@ def _intercept_edge(self, pt_pair, ball):
375373
return False
376374

377375

378-
class PinballModel(object):
376+
class PinballModel:
379377

380378
""" This class is a self-contained model of the pinball
381379
domain for reinforcement learning.
@@ -395,13 +393,8 @@ class PinballModel(object):
395393
THRUST_PENALTY = -5
396394
END_EPISODE = 10000
397395

398-
def __init__(self, configuration, random_state=np.random.RandomState()):
396+
def __init__(self, config_file, random_state):
399397
""" Read a configuration file for Pinball and draw the domain to screen
400-
401-
:param configuration: a configuration file containing the polygons,
402-
source(s) and target location.
403-
:type configuration: str
404-
405398
"""
406399

407400
self.random_state = random_state
@@ -412,33 +405,23 @@ def __init__(self, configuration, random_state=np.random.RandomState()):
412405
self.DEC_Y: (0, -1),
413406
self.ACC_NONE: (0, 0),
414407
}
408+
import json
415409

416410
# Set up the environment according to the configuration
417-
self.obstacles = []
418-
self.target_pos = []
419-
self.target_rad = 0.01
420-
421-
ball_rad = 0.01
422-
start_pos = []
423-
with open(configuration) as fp:
424-
for line in fp.readlines():
425-
tokens = line.strip().split()
426-
if not len(tokens):
427-
continue
428-
elif tokens[0] == "polygon":
429-
self.obstacles.append(
430-
PinballObstacle(list(zip(*[iter(map(float, tokens[1:]))] * 2)))
431-
)
432-
elif tokens[0] == "target":
433-
self.target_pos = [float(tokens[1]), float(tokens[2])]
434-
self.target_rad = float(tokens[3])
435-
elif tokens[0] == "start":
436-
start_pos = list(zip(*[iter(map(float, tokens[1:]))] * 2))
437-
elif tokens[0] == "ball":
438-
ball_rad = float(tokens[1])
411+
with config_file.open() as f:
412+
config = json.load(f)
413+
try:
414+
self.obstacles = list(map(PinballObstacle, config["obstacles"]))
415+
self.target_pos = config["target_pos"]
416+
self.target_rad = config["target_rad"]
417+
start_pos = config["start_pos"]
418+
ball_rad = config["ball_rad"]
419+
except KeyError as e:
420+
raise KeyError(f"Pinball config doesn't have a key: {e}")
421+
439422
self.start_pos = start_pos[0]
440-
a = self.random_state.randint(len(start_pos))
441-
self.ball = BallModel(list(start_pos[a]), ball_rad)
423+
start_idx = self.random_state.randint(len(start_pos))
424+
self.ball = BallModel(list(start_pos[start_idx]), ball_rad)
442425

443426
def get_state(self):
444427
""" Access the current 4-dimensional state vector
@@ -520,7 +503,7 @@ def _check_bounds(self):
520503
self.ball.position[1] = 0.05
521504

522505

523-
class PinballView(object):
506+
class PinballView:
524507

525508
""" This class displays a :class:`PinballModel`
526509
@@ -592,14 +575,13 @@ def blit(self):
592575

593576
def run_pinballview(width, height, configuration):
594577
"""
595-
596-
Changed from original Pierre-Luc Bacon implementation to reflect
597-
the visualization changes in the PinballView Class.
598-
578+
Changed from original Pierre-Luc Bacon implementation to reflect
579+
the visualization changes in the PinballView Class.
599580
"""
581+
600582
width, height = float(width), float(height)
601583
master = Tk()
602-
master.title("RLPY Pinball")
584+
master.title("RLPy Pinball")
603585
screen = Canvas(master, width=500.0, height=500.0)
604586
screen.configure(background="LightGray")
605587
screen.pack()

0 commit comments

Comments
 (0)