# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torchrl
from tensordict import TensorDict
import time

start_time = time.time()

pgn_or_fen = "fen"
mask_actions = True

env = torchrl.envs.ChessEnv(
    include_pgn=False,
    include_fen=True,
    include_hash=True,
    include_hash_inv=True,
    include_san=True,
    stateful=True,
    mask_actions=mask_actions,
)
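# The hash keys included above give each board position a compact identifier;
# they are used below as the forest's observation keys to index tree nodes.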


def transform_reward(td):
    if "reward" not in td:
        return td
    reward = td["reward"]
    if reward == 0.5:
        td["reward"] = 0
    elif reward == 1 and td["turn"]:
        # "turn" is True when it is white's move in the resulting position,
        # i.e. the decisive result was delivered by black, so flip the sign.
        td["reward"] = -td["reward"]
    return td


# ChessEnv sets the reward to 0.5 for a draw and 1 for a win for either player.
# Need to transform the reward to be:
# white win = 1
# draw = 0
# black win = -1
env = env.append_transform(transform_reward)

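# The forest stores the search tree; the key lists below tell it which entries
# of the environment's tensordicts identify rewards, dones, actions, and node
# observations.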
forest = torchrl.data.MCTSForest()
forest.reward_keys = env.reward_keys
forest.done_keys = env.done_keys
forest.action_keys = env.action_keys

if mask_actions:
    forest.observation_keys = [f"{pgn_or_fen}_hash", "turn", "action_mask"]
else:
    forest.observation_keys = [f"{pgn_or_fen}_hash", "turn"]

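# Exploration constant for the UCB1 priority; sqrt(2) is the classic choice.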
C = 2.0**0.5


def traversal_priority_UCB1(tree):
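    """Computes a UCB1-style priority for each child of ``tree``.

    Unvisited children get infinite priority so that each of them is
    expanded at least once.
    """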
    subtree = tree.subtree
    visits = subtree.visits
    reward_sum = subtree.wins

    # If it's black's turn, flip the reward, since black wants to
    # optimize for the lowest reward, not highest.
    if not subtree.rollout[0, 0]["turn"]:
        reward_sum = -reward_sum

    parent_visits = tree.visits
    reward_sum = reward_sum.squeeze(-1)
    priority = (reward_sum + C * torch.sqrt(torch.log(parent_visits))) / visits
    priority[visits == 0] = float("inf")
    return priority


def _traverse_MCTS_one_step(forest, tree, env, max_rollout_steps):
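    """Runs a single MCTS iteration starting from ``tree``: selection,
    expansion, simulation, and backpropagation.
    """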
    done = False
    trees_visited = [tree]

    while not done:
        if tree.subtree is None:
            td_tree = tree.rollout[-1]["next"].clone()

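            # Expansion: if this leaf has been visited before (or is the root)
            # and the game is not over, create a child node for every legal
            # action and pick one of them at random.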
            if (tree.visits > 0 or tree.parent is None) and not td_tree["done"]:
                actions = env.all_actions(td_tree)
                subtrees = []

                for action in actions:
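                    # The env is stateful, so reset it to this node's position
                    # before stepping each candidate action.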
                    td = env.step(env.reset(td_tree).update(action))
                    new_node = torchrl.data.Tree(
                        rollout=td.unsqueeze(0),
                        node_data=td["next"].select(*forest.node_map.in_keys),
                        count=torch.tensor(0),
                        wins=torch.zeros_like(td["next"]["reward"]),
                    )
                    subtrees.append(new_node)

                # NOTE: This whole script runs about 2x faster with lazy stack
                # versus eager stack.
                tree.subtree = TensorDict.lazy_stack(subtrees)
                chosen_idx = torch.randint(0, len(subtrees), ()).item()
                rollout_state = subtrees[chosen_idx].rollout[-1]["next"]

            else:
                rollout_state = td_tree

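            # Simulation: if the chosen node is terminal, use its reward
            # directly; otherwise play out a rollout from it.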
            if rollout_state["done"]:
                rollout_reward = rollout_state["reward"]
            else:
                rollout = env.rollout(
                    max_steps=max_rollout_steps,
                    tensordict=rollout_state,
                )
                rollout_reward = rollout[-1]["next", "reward"]
            done = True

        else:
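            # Selection: descend to the child with the highest UCB1 priority.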
            priorities = traversal_priority_UCB1(tree)
            chosen_idx = torch.argmax(priorities).item()
            tree = tree.subtree[chosen_idx]
            trees_visited.append(tree)

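    # Backpropagation: credit the rollout reward to every node visited on the
    # way down.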
    for tree in trees_visited:
        tree.visits += 1
        tree.wins += rollout_reward


def traverse_MCTS(forest, root, env, num_steps, max_rollout_steps):
    """Performs Monte-Carlo tree search in an environment.

    Args:
        forest (MCTSForest): Forest of the tree to update. If the tree does not
            exist yet, it is added.
        root (TensorDict): The root step of the tree to update.
        env (EnvBase): Environment to perform actions in.
        num_steps (int): Number of iterations to traverse.
        max_rollout_steps (int): Maximum number of steps for each rollout.
    """
    if root not in forest:
        for action in env.all_actions(root):
            td = env.step(env.reset(root.clone()).update(action))
            forest.extend(td.unsqueeze(0))

    tree = forest.get_tree(root)
    tree.wins = torch.zeros_like(td["next", "reward"])
    for subtree in tree.subtree:
        subtree.wins = torch.zeros_like(td["next", "reward"])

    for _ in range(num_steps):
        _traverse_MCTS_one_step(forest, tree, env, max_rollout_steps)

    return tree


def tree_format_fn(tree):
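    """Formats a node as its SAN move, position, win total, and visit count."""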
    td = tree.rollout[-1]["next"]
    return [
        td["san"],
        td[pgn_or_fen].split("\n")[-1],
        tree.wins,
        tree.visits,
    ]


def get_best_move(fen, mcts_steps, rollout_steps):
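    """Runs MCTS from the given FEN and returns the SAN of the best move found.

    Each candidate move is printed with its average value from the perspective
    of the side to move.
    """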
    root = env.reset(TensorDict({"fen": fen}))
    tree = traverse_MCTS(forest, root, env, mcts_steps, rollout_steps)
    moves = []

    for subtree in tree.subtree:
        san = subtree.rollout[0]["next", "san"]
        reward_sum = subtree.wins
        visits = subtree.visits
        value_avg = (reward_sum / visits).item()
        if not subtree.rollout[0]["turn"]:
            value_avg = -value_avg
        moves.append((value_avg, san))

    moves = sorted(moves, key=lambda x: -x[0])

    print("------------------")
    for value_avg, san in moves:
        print(f" {value_avg:0.02f} {san}")
    print("------------------")

    return moves[0][1]


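# Sanity checks on a few tactics positions (M1/M2 = mate in one/two moves).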
# White has M1, best move Rd8#. Any other moves lose to M2 or M1.
fen0 = "7k/6pp/7p/7K/8/8/6q1/3R4 w - - 0 1"
assert get_best_move(fen0, 100, 10) == "Rd8#"

# Black has M1, best move Qg6#. Other moves give rough equality or worse.
fen1 = "6qk/2R4p/7K/8/8/8/8/4R3 b - - 1 1"
assert get_best_move(fen1, 100, 10) == "Qg6#"

# White has M2, best move Rxg8+. Any other move loses.
fen2 = "2R3qk/5p1p/7K/8/8/8/5r2/2R5 w - - 0 1"
assert get_best_move(fen2, 1000, 10) == "Rxg8+"

end_time = time.time()
total_time = end_time - start_time

print(f"Took {total_time} s")