-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmonte_carlo.py
More file actions
173 lines (119 loc) · 4.46 KB
/
monte_carlo.py
File metadata and controls
173 lines (119 loc) · 4.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
from typing import Tuple

import torch
from numpy.typing import NDArray

from network import NeuralNet
from tree_node import TreeNode
from data_preprocess import node_to_tensor
from game_node import GameNode
from config import *
class MonteCarlo:
    """
    Monte Carlo Tree Search
    Wrap Tree Node w/ Evaluation Function

    Holds the search tree (``root``), the node searches currently start
    from (``curr``), and the network used to evaluate leaf nodes.

    Args:
        model: the NN to use (assumes is on cpu); it is moved to ``device``
        root: root of the tree
        device: the device to place the model on
        always_allow_pass: whether to always allow pass
    """

    def __init__(
        self,
        model: NeuralNet,
        root: TreeNode,
        device: str = DEVICE,
        always_allow_pass: bool = False,
    ):
        self.always_allow_pass = always_allow_pass
        self.device = device
        self.model = model.to(self.device)
        self.root = root
        # Searches start from `curr`; it begins at the root and is advanced
        # by move_curr().
        self.curr = root

    def __str__(self):
        # After reset() there is no current node; show an empty child list
        # instead of raising from inside __str__.
        children = [] if self.curr is None else self.curr.nexts
        listing = ", ".join(str(child) for child in children)
        return f"Current node: {self.curr!s}\nCurrent node children: [{listing}]\n"

    def reset(self):
        """
        Releases the whole tree, including the root

        Clears both ``self.root`` and ``self.curr`` so that this instance no
        longer keeps any node alive. Calling reset() again is a no-op.
        This instance should not be used after calling reset
        """
        old_root = self.root
        # Drop our own references first: a node can only be reclaimed once
        # nothing refers to it any more.
        self.root = None
        self.curr = None
        if old_root is not None:
            self.delete_node(old_root)

    def delete_node(self, node: TreeNode):
        """
        Walks the given node and all of its descendants

        Note that this does not unlink or modify any node: in Python an
        object is freed only when no references to it remain, so the caller
        must drop its own references (as reset() does) for the subtree to be
        reclaimed. The walk is iterative so that very deep trees cannot hit
        the interpreter's recursion limit.

        Args:
            node: the root of the subtree to delete
        """
        pending = [node]
        while pending:
            current = pending.pop()
            pending.extend(child for child in current.nexts if child is not None)

    def select(self, node: TreeNode) -> TreeNode:
        """
        Returns the leaf node from the subtree rooted at node
        resulting from greedily maximizing Q+U for all children

        Args:
            node: the TreeNode to start selection from

        Returns:
            The first node on the greedy path for which ``is_leaf()`` is true
            (``node`` itself if it is already a leaf).
        """
        while not node.is_leaf():
            # Q is negated on odd move numbers, so each level maximizes from
            # the perspective of the side to move.
            # NOTE(review): presumably Q is stored from a fixed player's
            # point of view -- confirm against TreeNode.
            sign = 1 if node.move % 2 == 0 else -1
            node = max(
                node.nexts,
                key=lambda child: sign * child.Q_value() + child.u_value(),
            )
        return node

    def evaluate(self, node: TreeNode) -> Tuple[float, NDArray]:
        """
        Get the network's eval (value AND policy) of the current game state
        Consider evaluating on all d8 transformations and taking the mean
        if model is fast enough

        Args:
            node: the node to evaluate

        Returns:
            ``(value, policy)``: the value as a Python float and the policy
            as a numpy array with the batch dimension removed.
        """
        self.model.eval()
        with torch.no_grad():
            # Add a batch dimension of 1 and move the input to the model's device.
            batch = node_to_tensor(node).unsqueeze(0).to(self.device)
            # Call the module itself rather than .forward() so that any
            # registered hooks run as well.
            out = self.model(batch)
            # The network output is indexed as (policy, value).
            value = out[1].item()
            policy = out[0].squeeze(0).detach().cpu().numpy()
        return value, policy

    def expand(self, node: TreeNode, prior: NDArray, allow_pass: bool = True) -> None:
        """
        Adds all valid children to the tree and initializes
        all values as described in the slides

        Terminal nodes are left untouched.

        Args:
            node: the TreeNode from select
            prior: the precomputed output from the policy head
            allow_pass: whether a pass move may be generated for this node;
                overridden to True when ``always_allow_pass`` is set
        """
        if node.is_terminal():
            return

        node.get_children(allow_pass=(allow_pass or self.always_allow_pass))

        # Use the expanded node's own board size so this method does not
        # depend on `self.curr`.
        size = node.size
        for child in node.nexts:
            move = child.prev_move
            # Flatten (row, col) into an index of the policy vector.
            # NOTE(review): how a pass move maps into `prior` is not visible
            # here -- confirm the encoding of `prev_move` for passes.
            child.prior = prior[move[0] * size + move[1]]

    def search(self) -> None:
        """
        Performs one iteration of search
        Note that this doesn't return a value because the goal of search is improve the tree,
        not to actually "search" for a specific node.
        """
        selected = self.select(self.curr)
        val, policy = self.evaluate(selected)
        # Passing is only offered once the current position is past the
        # configured move threshold.
        allow_pass = bool(self.curr.move > NUM_MOVES_ALLOW_PASS)
        self.expand(selected, policy, allow_pass=allow_pass)
        selected.backprop(val)

    def move_curr(self, loc: Tuple[int, int]) -> None:
        """
        Moves the curr node forward by playing the action at loc

        Args:
            loc: the location of the move to play

        Raises:
            ValueError: if no child of the current node was reached by
                playing ``loc`` (e.g. the node has not been expanded yet).
        """
        match = next(
            (child for child in self.curr.nexts if child.prev_move == loc),
            None,
        )
        if match is None:
            raise ValueError(f"Child from move at {loc} not found. Maybe you forgot to search?")
        self.curr = match
if __name__ == "__main__":
    # Smoke test: build a fresh network and a 9x9 game, run a single search
    # iteration from the root, and print the resulting tree summary.
    net = NeuralNet()
    start_position = GameNode(9)
    mcts = MonteCarlo(model=net, root=TreeNode(start_position))
    mcts.search()
    print(mcts)