@@ -4225,43 +4225,62 @@ def test_env_reset_with_hash(self, stateful, include_san):
42254225 td_check = env .reset (td .select ("fen_hash" ))
42264226 assert (td_check == td ).all ()
42274227
4228- @pytest .mark .parametrize ("include_fen" , [False , True ])
4229- @pytest .mark .parametrize ("include_pgn" , [False , True ])
4228+ @pytest .mark .parametrize ("include_fen,include_pgn" , [[False , True ], [True , False ]])
42304229 @pytest .mark .parametrize ("stateful" , [False , True ])
4231- @pytest .mark .parametrize ("mask_actions" , [False , True ])
4232- def test_all_actions (self , include_fen , include_pgn , stateful , mask_actions ):
4233- if not stateful and not include_fen and not include_pgn :
4234- # pytest.skip("fen or pgn must be included if not stateful")
4235- return
4236-
4230+ @pytest .mark .parametrize ("include_hash" , [False , True ])
4231+ @pytest .mark .parametrize ("include_san" , [False , True ])
4232+ @pytest .mark .parametrize ("append_transform" , [False , True ])
4233+ @pytest .mark .parametrize ("mask_actions" , [True ])
4234+ def test_all_actions (
4235+ self ,
4236+ include_fen ,
4237+ include_pgn ,
4238+ stateful ,
4239+ include_hash ,
4240+ include_san ,
4241+ append_transform ,
4242+ mask_actions ,
4243+ ):
42374244 env = ChessEnv (
42384245 include_fen = include_fen ,
42394246 include_pgn = include_pgn ,
4247+ include_san = include_san ,
4248+ include_hash = include_hash ,
4249+ include_hash_inv = include_hash ,
42404250 stateful = stateful ,
42414251 mask_actions = mask_actions ,
42424252 )
4243- td = env .reset ()
42444253
4245- if not mask_actions :
4246- with pytest .raises (RuntimeError , match = "Cannot generate legal actions" ):
4247- env .all_actions ()
4248- return
4254+ def transform_reward (td ):
4255+ if "reward" not in td :
4256+ return td
4257+ reward = td ["reward" ]
4258+ if reward == 0.5 :
4259+ td ["reward" ] = 0
4260+ elif reward == 1 and td ["turn" ]:
4261+ td ["reward" ] = - td ["reward" ]
4262+ return td
4263+
4264+ if append_transform :
4265+ env = env .append_transform (transform_reward )
4266+
4267+ check_env_specs (env )
4268+
4269+ td = env .reset ()
42494270
42504271 # Choose random actions from the output of `all_actions`
4251- for _ in range (100 ):
4252- if stateful :
4253- all_actions = env .all_actions ()
4254- else :
4272+ for step_idx in range (100 ):
4273+ if step_idx % 5 == 0 :
42554274 # Reset to the initial state first, just to make sure
42564275 # `all_actions` knows how to get the board state from the input.
42574276 env .reset ()
4258- all_actions = env .all_actions (td .clone ())
4277+ all_actions = env .all_actions (td .clone ())
42594278
42604279 # Choose some random actions and make sure they match exactly one of
42614280 # the actions from `all_actions`. This part is not tested when
42624281 # `mask_actions == False`, because `rand_action` can pick illegal
42634282 # actions in that case.
4264- if mask_actions :
4283+ if mask_actions and step_idx % 4 == 0 :
42654284 # TODO: Something is wrong in `ChessEnv.rand_action` which makes
42664285 # it fail to work properly for stateless mode. It doesn't know
42674286 # how to correctly reset the board state to what is given in the
@@ -4278,7 +4297,9 @@ def test_all_actions(self, include_fen, include_pgn, stateful, mask_actions):
42784297
42794298 action_idx = torch .randint (0 , all_actions .shape [0 ], ()).item ()
42804299 chosen_action = all_actions [action_idx ]
4281- td = env .step (td .update (chosen_action ))["next" ]
4300+ td_new = env .step (td .update (chosen_action ).clone ())
4301+ assert (td == td_new .exclude ("next" )).all ()
4302+ td = td_new ["next" ]
42824303
42834304 if td ["done" ]:
42844305 td = env .reset ()
0 commit comments