from unittest.mock import patch, MagicMock
from gymnasium.spaces import Box, Discrete
import numpy as np
+from gym_csle_stopping_game.util.stopping_game_util import StoppingGameUtil
from gym_csle_stopping_game.envs.stopping_game_env import StoppingGameEnv
from gym_csle_stopping_game.dao.stopping_game_config import StoppingGameConfig
from gym_csle_stopping_game.dao.stopping_game_state import StoppingGameState
@@ -23,19 +24,19 @@ def setup_env(self) -> None:
        :return: None
        """
        env_name = "test_env"
-        T = np.array([[[0.1, 0.9], [0.4, 0.6]], [[0.7, 0.3], [0.2, 0.8]]])
-        O = np.array([0, 1])
-        Z = np.array([[[0.8, 0.2], [0.5, 0.5]], [[0.4, 0.6], [0.9, 0.1]]])
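+        # Game dynamics generated by StoppingGameUtil: transition tensor T,
+        # observation space O (n=100 observations), and observation tensor Z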
+        T = StoppingGameUtil.transition_tensor(L=3, p=0)
+        O = StoppingGameUtil.observation_space(n=100)
+        Z = StoppingGameUtil.observation_tensor(n=100)
        R = np.zeros((2, 3, 3, 3))
-        S = np.array([0, 1, 2])
-        A1 = np.array([0, 1, 2])
-        A2 = np.array([0, 1, 2])
+        S = StoppingGameUtil.state_space()
+        A1 = StoppingGameUtil.defender_actions()
+        A2 = StoppingGameUtil.attacker_actions()
        L = 2
        R_INT = 1
        R_COST = 2
        R_SLA = 3
        R_ST = 4
-        b1 = np.array([0.6, 0.4])
+        b1 = StoppingGameUtil.b1()
        save_dir = "save_directory"
        checkpoint_traces_freq = 100
        gamma = 0.9
@@ -69,12 +70,12 @@ def test_stopping_game_init_(self) -> None:
        :return: None
        """
-        T = np.array([[[0.1, 0.9], [0.4, 0.6]], [[0.7, 0.3], [0.2, 0.8]]])
-        O = np.array([0, 1])
-        A1 = np.array([0, 1, 2])
-        A2 = np.array([0, 1, 2])
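+        # Reuse the StoppingGameUtil generators so the dynamics match setup_env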
+        T = StoppingGameUtil.transition_tensor(L=3, p=0)
+        O = StoppingGameUtil.observation_space(n=100)
+        A1 = StoppingGameUtil.defender_actions()
+        A2 = StoppingGameUtil.attacker_actions()
        L = 2
-        b1 = np.array([0.6, 0.4])
+        b1 = StoppingGameUtil.b1()
        attacker_observation_space = Box(
            low=np.array([0.0, 0.0, 0.0]),
            high=np.array([float(L), 1.0, 2.0]),
@@ -304,7 +305,7 @@ def test_is_state_terminal(self) -> None:
        assert not env.is_state_terminal(state_tuple)

        with pytest.raises(ValueError):
-            env.is_state_terminal([1, 2, 3]) # type: ignore
+            env.is_state_terminal([1, 2, 3])  # type: ignore

    def test_get_observation_from_history(self) -> None:
        """
@@ -346,26 +347,6 @@ def test_step(self) -> None:
        :return: None
        """
        env = StoppingGameEnv(self.config)
-        env.state = MagicMock()
-        env.state.s = 1
-        env.state.l = 2
-        env.state.t = 0
-        env.state.attacker_observation.return_value = np.array([1, 2, 3])
-        env.state.defender_observation.return_value = np.array([4, 5, 6])
-        env.state.b = np.array([0.5, 0.5, 0.0])
-
-        env.trace = MagicMock()
-        env.trace.defender_rewards = []
-        env.trace.attacker_rewards = []
-        env.trace.attacker_actions = []
-        env.trace.defender_actions = []
-        env.trace.infos = []
-        env.trace.states = []
-        env.trace.beliefs = []
-        env.trace.infrastructure_metrics = []
-        env.trace.attacker_observations = []
-        env.trace.defender_observations = []
-
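+        # Patch the stochastic samplers so env.step is deterministic (next state fixed to 2)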
        with patch("gym_csle_stopping_game.util.stopping_game_util.StoppingGameUtil.sample_next_state",
                   return_value=2):
            with patch("gym_csle_stopping_game.util.stopping_game_util.StoppingGameUtil.sample_next_observation",
@@ -376,32 +357,20 @@ def test_step(self) -> None:
                    1,
                    (
                        np.array(
-                            [[0.2, 0.8, 0.0], [0.6, 0.4, 0.0], [0.5, 0.5, 0.0]]
+                            [[0.2, 0.8], [0.6, 0.4], [0.5, 0.5]]
                        ),
                        2,
                    ),
                )
                observations, rewards, terminated, truncated, info = env.step(
                    action_profile
                )
-
-                assert (observations[0] == np.array([4, 5, 6])).all(), "Incorrect defender observations"
-                assert (observations[1] == np.array([1, 2, 3])).all(), "Incorrect attacker observations"
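+                # observations[0] is the defender view, observations[1] the attacker view;
+                # the expected values assume the deterministic sampling patched above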
+                assert (observations[0] == np.array([1, 0.7])).all(), "Incorrect defender observations"
+                assert (observations[1] == np.array([1, 2, 3])).all(), "Incorrect attacker observations"
                assert rewards == (0, 0)
                assert not terminated
                assert not truncated
-                assert env.trace.defender_rewards[-1] == 0
-                assert env.trace.attacker_rewards[-1] == 0
-                assert env.trace.attacker_actions[-1] == 2
-                assert env.trace.defender_actions[-1] == 1
-                assert env.trace.infos[-1] == info
-                assert env.trace.states[-1] == 2
-                print(env.trace.beliefs)
-                assert env.trace.beliefs[-1] == 0.7
-                assert env.trace.infrastructure_metrics[-1] == 1
-                assert (env.trace.attacker_observations[-1] == np.array([1, 2, 3])).all()
-                assert (env.trace.defender_observations[-1] == np.array([4, 5, 6])).all()
-
+
    def test_info(self) -> None:
        """
        Tests the function of adding the cumulative reward and episode length to the info dict