Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit caf5130

Browse files
author
Jules Pondard
committed
Add vanilla MCTS/UCT algorithm for options
Generate options using UCT algorithm. Useful for benchmarking.
1 parent 05405ef commit caf5130

File tree

1 file changed

+181
-0
lines changed

1 file changed

+181
-0
lines changed
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
import tensor_comprehensions as tc
2+
import torch
3+
import my_utils
4+
import numpy as np
5+
#from tqdm import tqdm
6+
from visdom import Visdom
7+
8+
# Module-level Visdom client; MCTS calls viz.line() on it to draw and
# refresh the reward plot. Created once, when this module is loaded.
viz = Visdom()
9+
10+
class Node:
    """One node of the search tree, i.e. a partially chosen option vector.

    A node at depth ``pos`` has had the first ``pos`` entries of
    ``stateVector`` fixed by the actions taken on the path from the root;
    the remaining entries keep whatever value was inherited from the parent
    (0 at the root).
    """

    def __init__(self, father=None, new_act=0):
        """Create a root node, or the child of ``father`` reached by ``new_act``.

        Args:
            father: parent node, or None to create a root.
            new_act: action stored in the slot this node fixes, i.e. index
                ``father.pos`` of the state vector (ignored for a root).
        """
        self.value = 0           # sum of the values backed up through this node
        self.values = []         # individual backed-up values
        self.nbVisits = 0        # number of back-ups that went through this node
        self.nbChildrenSeen = 0  # number of children created so far
        self.pos = 0             # depth in the tree (0 for the root)
        self.children = []
        self.parent = father
        self.stateVector = [0] * my_utils.NB_HYPERPARAMS
        if father is not None:
            self.pos = father.pos + 1
            # Copy, so that editing this node never changes the parent's vector.
            self.stateVector = father.stateVector[:]
            self.stateVector[self.pos - 1] = new_act

    def getRoot(self):
        """Return the root of the tree this node belongs to."""
        node = self
        while node.parent is not None:
            node = node.parent
        return node

    def getParent(self):
        """Return the parent node, or None for the root."""
        return self.parent

    def notRoot(self):
        """Return True when this node has a parent."""
        return self.parent is not None
35+
36+
class MCTS:
    """Vanilla Monte-Carlo tree search (UCT) over the option vector.

    Each tree level fixes one entry of the option vector; level ``p`` has
    ``self.nbActions[p]`` possible actions. Leaves are scored with random
    roll-outs whose reward is ``-log(my_utils.evalTime(options))``
    (evalTime is presumably a measured run time -- confirm in my_utils).
    """

    def __init__(self, C=1):
        """Set up the benchmarked kernel, the empty tree and the live plot.

        Args:
            C: exploration constant of the UCT formula (to tune).
        """
        self.C = C

        (tc_code, tc_name, inp, _) = my_utils.get_convolution_example(
            size_type="input", inp_sz_list=[8, 2, 28, 28, 8, 1, 1])

        my_utils.computeCat(inp)
        my_utils.set_tc(tc_code, tc_name)

        self.nbActions = my_utils.cat_sz  # indexed by tree depth
        self.tree = Node()

        self.best_rewards = []  # -(best reward so far), one entry per roll-out
        self.rws = []           # -(running average reward), one entry per roll-out

        self.curIter = 0
        self.curr_best = 0
        self.running_reward = 0
        # Placeholder curve; saveReward() later redraws into this same window.
        self.win0 = viz.line(X=np.arange(5), Y=np.random.rand(5))

    def main_search(self, starting_pos, nb_iters=10):
        """Run the search from ``starting_pos`` and return the best action.

        Args:
            starting_pos: node to search from.
            nb_iters: minimum number of select/evaluate/backup rounds.

        Returns:
            Index of the child of ``starting_pos`` with the best mean value.
        """
        node = starting_pos
        # Run at least one round per action so that every child of `node`
        # exists and has been visited before the final choice.
        for _ in range(max(nb_iters, self.nbActions[node.pos])):
            leaf = self.getLeaf(node)
            val = self.evaluate(leaf)
            self.backup(leaf, val)
        _, action = self.getBestChild2(node)
        return action

    def take_action(self, node, act):
        """Return the child of ``node`` for action ``act``, creating it if needed.

        Children are created in action order, so child ``i`` corresponds to
        action ``i``; callers are expected to pass ``act <= node.nbChildrenSeen``.
        """
        if node.nbChildrenSeen > act:
            return node.children[act]
        new_child = Node(father=node, new_act=act)
        node.children.append(new_child)
        node.nbChildrenSeen += 1
        return new_child

    def getLeaf(self, node):
        """Walk down from ``node`` and return the node to evaluate next.

        Fully expanded nodes are descended through their best UCT child;
        the first node that still lacks a child gets a new one, which is
        returned. A node at full depth is returned as is.
        """
        first = True
        while node.pos < my_utils.NB_HYPERPARAMS and (first or node.nbVisits != 0):
            first = False
            if node.nbChildrenSeen == self.nbActions[node.pos]:
                node, _ = self.getBestChild(node)
            else:
                # Expand the next action that has not been tried yet.
                return self.take_action(node, node.nbChildrenSeen)
        return node

    def _select_child(self, node, explore):
        """Return ``(child, action)`` maximising the mean value of the children.

        With ``explore`` set, the UCT exploration bonus is added to the mean
        and unvisited children are tried first; otherwise unvisited children
        are never preferred, since they have no estimate. Ties go to the
        lowest action. Raises IndexError when ``node`` has no children.
        """
        # max(..., 1) keeps the logarithm finite for an unvisited parent.
        log_visits = np.log(max(node.nbVisits, 1)) if explore else 0.0
        best_action = 0
        best_score = None
        for act, child in enumerate(node.children):
            if child.nbVisits == 0:
                # No mean is defined yet; avoid dividing by zero.
                score = np.inf if explore else -np.inf
            else:
                score = child.value / child.nbVisits
                if explore:
                    score += self.C * np.sqrt(2 * log_visits / child.nbVisits)
            if best_score is None or score > best_score:
                best_score = score
                best_action = act
        return node.children[best_action], best_action

    def getBestChild2(self, node):
        """Return ``(child, action)`` with the highest mean value (no exploration)."""
        return self._select_child(node, explore=False)

    def getBestChild(self, node):
        """Return ``(child, action)`` with the highest UCT score."""
        return self._select_child(node, explore=True)

    def saveReward(self, reward, opts):
        """Record one roll-out reward, print new bests and refresh the plot.

        Args:
            reward: reward of the roll-out (higher is better).
            opts: option vector that produced it; printed on a new best.
        """
        INTER_DISP = 20  # redraw the plot every INTER_DISP roll-outs
        if self.curIter == 0:
            self.running_reward = reward
            self.curr_best = reward
        if self.curIter == 0 or reward > self.curr_best:
            print(-reward)
            print(opts)
        self.curIter += 1
        # Exponential moving average of the reward.
        self.running_reward = self.running_reward * 0.99 + reward * 0.01
        self.curr_best = max(self.curr_best, reward)
        self.best_rewards.append(-self.curr_best)
        self.rws.append(-self.running_reward)
        if self.curIter % INTER_DISP == 0:
            iters = np.arange(self.curIter)
            viz.line(X=np.column_stack((iters, iters)),
                     Y=np.column_stack((np.array(self.rws), np.array(self.best_rewards))),
                     win=self.win0, opts=dict(legend=["Geometric run", "Best time"]))

    def randomSampleScoreFrom(self, node):
        """Complete the options of ``node`` at random and return the reward.

        Entries already fixed by ``node`` (indices below ``node.pos``) are
        kept; the others are drawn uniformly among the actions of their level.
        """
        # Work on a copy: the roll-out must not overwrite the node's own
        # state vector, which its future children inherit.
        optsVector = list(node.stateVector)
        for i in range(node.pos, my_utils.NB_HYPERPARAMS):
            optsVector[i] = np.random.randint(self.nbActions[i])
        reward = -np.log(my_utils.evalTime(optsVector))
        self.saveReward(reward, optsVector)
        return reward

    def evaluate(self, leaf, nb_iters=5):
        """Return the mean reward of ``nb_iters`` random roll-outs from ``leaf``."""
        score = 0
        for _ in range(nb_iters):
            score += self.randomSampleScoreFrom(leaf)
        return score / nb_iters

    def backup(self, leaf, val):
        """Add ``val`` to the statistics of ``leaf`` and all of its ancestors.

        Only the root additionally keeps the individual values in ``values``.
        """
        node = leaf
        while node.parent is not None:
            node.nbVisits += 1
            node.value += val
            node = node.parent
        # `node` is now the root of the tree.
        node.nbVisits += 1
        node.value += val
        node.values.append(val)
170+
171+
# Build the searcher; its constructor sets up the kernel through my_utils
# and opens the plot window.
mcts = MCTS()

# Fix the options one level at a time: search from the current node, keep
# the best action found, then step into the matching child.
opts = []
curr_node = mcts.tree
for i in range(my_utils.NB_HYPERPARAMS):
    best_action = mcts.main_search(curr_node)
    opts.append(best_action)
    curr_node = mcts.take_action(curr_node, best_action)

# Report the chosen options and the time measured for them.
print(opts)
opts = np.array(opts).astype(int)
print(my_utils.evalTime(opts.tolist()))
my_utils.print_opt(opts)

0 commit comments

Comments
 (0)