-
Notifications
You must be signed in to change notification settings - Fork 21
/
Copy pathsimple_breakout.py
98 lines (77 loc) · 2.95 KB
/
simple_breakout.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import numpy as np
import random
class SimpleBreakoutVisualizer:
    """Show a stack of greyscale game frames in a pygame window.

    The window is ``screen_size * 10`` pixels along the first surface axis
    and four times that along the second, i.e. room for four square frames
    of ``screen_size`` cells each, with every cell drawn as a 10x10 pixel
    square.

    pygame is imported inside the methods, so importing this module does
    not require pygame to be installed.
    """

    # Edge length, in screen pixels, of one frame cell.
    _PIXEL_SCALE = 10
    # Number of frames the window has room for.
    _FRAME_COUNT = 4

    def __init__(self, algo):
        """Open the display window.

        Args:
            algo: Stored on the instance as ``self.algo``; not otherwise
                used by this class.
        """
        import pygame

        self.screen_size = 12
        self.screen = pygame.display.set_mode(self._window_size())
        # Retained only so existing code touching ``mem`` keeps working; the
        # vectorised conversion in _frames_to_rgb no longer needs a cache.
        self.mem = {}
        self.algo = algo

    def _window_size(self):
        """Return the (first axis, second axis) window size in pixels."""
        side = self.screen_size * self._PIXEL_SCALE
        return (side, side * self._FRAME_COUNT)

    @staticmethod
    def _frames_to_rgb(prev_frames, scale=10):
        """Convert a sequence of 2-D grey frames into one upscaled RGB array.

        The frames are concatenated along their first axis, each grey value
        is replicated into three identical colour channels, the two spatial
        axes are swapped, and every cell is blown up to ``scale`` x ``scale``
        pixels.

        Args:
            prev_frames: Sequence of equally shaped 2-D arrays.
            scale: Pixel edge length of one cell.

        Returns:
            Array of shape ``(cols * scale, n_frames * rows * scale, 3)``
            with the same dtype as the input frames.
        """
        stacked = np.concatenate(prev_frames)
        # Grey -> RGB by repeating the value on a new trailing channel axis.
        # This replaces the former per-pixel Python callback whose result was
        # fed to np.reshape through zip(), which is a lazy iterator on
        # Python 3 and therefore could not be reshaped.
        rgb = np.repeat(stacked[:, :, np.newaxis], 3, axis=2)
        rgb = np.transpose(rgb, [1, 0, 2])
        return np.repeat(np.repeat(rgb, scale, axis=0), scale, axis=1)

    def show(self, prev_frames):
        """Draw ``prev_frames`` to the window, then sleep for 50 ms.

        Args:
            prev_frames: Sequence of 2-D grey frames. The upscaled image must
                match the window size, i.e. four ``screen_size`` x
                ``screen_size`` frames.
        """
        import time

        import pygame

        image = self._frames_to_rgb(prev_frames, self._PIXEL_SCALE)
        rect = pygame.Surface(self._window_size())
        pygame.surfarray.blit_array(rect, image)
        self.screen.blit(rect, (0, 0))
        pygame.display.flip()
        # Fixed pause after every draw so the animation is watchable.
        time.sleep(0.05)

    def next_game(self):
        """Hook called between games; intentionally does nothing."""
        pass
class SimpleBreakout(object):
    """Minimal catch-the-ball game played on a 12x12 grid.

    A ball spawns in row 0 at a random column and drops one row per call to
    :meth:`input`. The player slides a three-cell bar along the bottom row.
    When the ball reaches the bottom row it is caught if the bar centre is
    within one column of it: that yields a reward of 1 and a new ball.
    Otherwise the game ends.

    Positions are ``[x, y]`` lists and frames are indexed ``frame[x, y]``.
    The board is square, so the ``(h, w)`` frame shape and the ``[x, y]``
    indexing agree.

    A new instance starts in the finished state; call :meth:`reset_game`
    before the first :meth:`input`.
    """

    # Number of most recent frames kept as the observable state.
    _HISTORY_LENGTH = 4

    def __init__(self):
        # Only the length of this list is used by this class (see n_actions).
        # NOTE(review): the values look like external action codes - confirm
        # against the caller before relying on them.
        self.action_set = [4, 7, 10]
        self.finished = True
        self.h = 12
        self.w = 12
        # Bar starts on the bottom row, the same row reset_game() uses.
        self.bar = [random.randint(0, self.w - 1), self.h - 1]
        # Placeholders so every attribute exists before the first reset.
        self.ball = [0, 0]
        self.cum_reward = 0
        self.prev_frames = self._blank_history()

    def _blank_history(self):
        """Return a fresh list of all-zero frames, one per history slot."""
        return [np.zeros((self.h, self.w), dtype=np.uint8)
                for _ in range(self._HISTORY_LENGTH)]

    def reset_game(self):
        """Start a new episode with a random ball column and bar position.

        The frame history is cleared as well, so every episode begins with
        the same blank history as the very first one instead of carrying
        over frames from the previous episode.
        """
        self.finished = False
        self.ball = [random.randint(0, self.w - 1), 0]
        self.cum_reward = 0
        self.bar = [random.randint(0, self.w - 1), self.h - 1]
        self.prev_frames = self._blank_history()

    def n_actions(self):
        """Return the number of available actions."""
        return len(self.action_set)

    def input(self, action):
        """Advance the game by one step.

        Args:
            action: ``0`` moves the bar one column towards x = 0, ``1`` moves
                it one column towards the opposite edge, any other value
                leaves it in place. The bar stops at the edges.

        Returns:
            Tuple ``(reward, game_over)``: ``reward`` is 1 when the ball was
            caught on this step and 0 otherwise; ``game_over`` is True when
            the ball reached the bottom row without being caught.

        Raises:
            RuntimeError: If the game is finished (never reset, or already
                lost). Stepping in that state used to fail part-way through
                with AttributeError/IndexError and leave the state corrupt.
        """
        if self.finished:
            raise RuntimeError(
                "game is finished; call reset_game() before input()")

        right_edge = self.w - 1
        bottom_row = self.h - 1

        if action == 0:
            self.bar[0] = max(self.bar[0] - 1, 0)
        elif action == 1:
            self.bar[0] = min(self.bar[0] + 1, right_edge)

        # The ball always drops one row, whatever the action was.
        self.ball[1] += 1

        action_reward = 0
        game_over = False
        if self.ball[1] == bottom_row:
            if abs(self.ball[0] - self.bar[0]) <= 1:
                action_reward = 1
                self.ball = [random.randint(0, right_edge), 0]
            else:
                game_over = True

        self.cum_reward += action_reward
        # Build a new list rather than mutating in place, so a list handed
        # out earlier by get_state() is not changed behind the caller's back.
        self.prev_frames = self.prev_frames[1:] + [self._as_frame()]
        self.finished = game_over
        return action_reward, game_over

    def _as_frame(self):
        """Render the current ball and bar positions as a uint8 frame.

        Intensities: ball 230, the cells directly above/below the ball 125,
        bar centre 200, bar wings 180. The bar is drawn after the ball, so
        it wins where the two overlap.
        """
        frame = np.zeros((self.h, self.w), dtype=np.uint8)
        ball_x, ball_y = self.ball
        bar_x, bar_y = self.bar
        right_edge = self.w - 1
        bottom_row = self.h - 1

        # Halo first, ball last, so a clamped halo cell never hides the ball.
        frame[ball_x, max(0, ball_y - 1)] = 125
        frame[ball_x, min(bottom_row, ball_y + 1)] = 125
        frame[ball_x, ball_y] = 230

        # Wings first, centre last: at either edge one wing is clamped onto
        # the centre cell and must not overwrite the centre intensity.
        frame[max(0, bar_x - 1), bar_y] = 180
        frame[min(right_edge, bar_x + 1), bar_y] = 180
        frame[bar_x, bar_y] = 200
        return frame

    def get_state(self):
        """Return the list of the most recent frames, oldest first."""
        return self.prev_frames