DP_Gridworld.py — executable file, 78 lines (59 loc), 2.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import numpy as np
class DPGridworld:
    """Iterative policy evaluation on a 4x4 gridworld.

    Evaluates the equiprobable random policy with in-place (Gauss-Seidel)
    sweeps until the value function changes by less than ``theta`` in a
    full sweep. Terminal states are the two corners (0, 0) and
    (gridsize-1, gridsize-1); every non-terminal transition yields a
    reward of -1 and there is no discounting (gamma = 1), so values
    approximate the negated expected number of steps to termination.

    NOTE(review): unlike the classic Sutton & Barto Fig. 4.1 gridworld,
    where an off-grid action leaves the agent in place, this model drops
    off-grid moves and renormalizes the policy over the remaining moves
    (see ``getValidMoves``) — presumably intentional; confirm.
    """

    def __init__(self) -> None:
        # Convergence threshold: stop when the largest per-state value
        # change in a sweep falls below this.
        self.theta = 0.01
        self.gridsize = 4
        # Value of each state under the current (random) policy,
        # initialized to zero.
        self.state_values = np.zeros(shape=(self.gridsize, self.gridsize))

    def evaluatePolicy(self):
        """Run in-place iterative policy evaluation until convergence.

        Mutates ``self.state_values``; prints the grid each sweep for
        debugging and once more on convergence. Returns ``None``.
        """
        count = 0
        while True:
            # Debug output: print the value grid every sweep
            # (``count % 1`` is always 0; raise the modulus to thin this out).
            if count % 1 == 0:
                print(f"after {count} iterations")
                self.prettyPrintGrid()
            count += 1
            delta = 0
            # One full sweep over all states.
            for r in range(self.gridsize):
                for c in range(self.gridsize):
                    # Terminal states keep value 0 and receive no update.
                    if (r, c) == (0, 0) or (r, c) == (self.gridsize - 1, self.gridsize - 1):
                        continue
                    # Every non-terminal transition costs -1.
                    reward = -1
                    # Remember the old value so we can measure the change.
                    old_v = self.state_values[r, c]
                    # Valid successor states s' as (row, col) tuples.
                    valid_moves = self.getValidMoves(r, c)
                    # Expected value over the equiprobable valid actions:
                    # V(s) = sum_a (1/|A|) * (r + V(s')).
                    new_v = 0
                    for (sp_r, sp_c) in valid_moves:
                        sp_v = self.state_values[sp_r, sp_c]
                        new_v += (1 / len(valid_moves)) * (reward + sp_v)
                    self.state_values[r, c] = new_v
                    # BUG FIX: delta must be the MAXIMUM change across the
                    # whole sweep. The original ``delta = abs(old_v - new_v)``
                    # overwrote it each state, so convergence was judged on
                    # the last state only and the loop could stop early.
                    delta = max(delta, abs(old_v - new_v))
            if delta < self.theta:
                print(f"delta({round(delta, 5)}) < theta({round(self.theta, 5)}) after {count} iterations")
                self.prettyPrintGrid()
                break

    def getValidMoves(self, r, c):
        """Return the in-grid neighbors of (r, c) as (row, col) tuples.

        Neighbors are the four cells up/down/left/right; any that fall
        outside the grid are dropped (the policy is then renormalized
        over the survivors by the caller).
        """
        moves = [(r - 1, c), (r + 1, c), (r, c - 1), (r, c + 1)]
        valid_moves = []
        for a_r, a_c in moves:
            if a_r < 0 or a_r > self.gridsize - 1 or a_c < 0 or a_c > self.gridsize - 1:
                continue
            valid_moves.append((a_r, a_c))
        return valid_moves

    def prettyPrintGrid(self):
        """Print the value grid, one row per line, values rounded to 2 dp."""
        for r in self.state_values:
            for v in r:
                print(round(v, 2), end="\t")
            print()
        print()
if __name__ == "__main__":
    # Script entry point: evaluate the random policy on the 4x4 gridworld.
    DPGridworld().evaluatePolicy()