Commit 92d12d8

Merge pull request #365 from Limmen/fix_stopping_tests
update stopping_game tests
2 parents: f2db9d5 + 2b996bf

3 files changed, +78 −80 lines changed
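
In short, the updated tests build their game configuration from the StoppingGameUtil factory functions instead of small hand-coded numpy arrays, and they exercise the real step/reset logic with RandomPolicy strategies instead of MagicMock state and traces. The snippet below is a minimal sketch of the new setup pattern condensed from the hunks that follow; the StoppingGameConfig keyword names are an assumption for illustration, since the constructor call itself lies outside the hunks shown.

# Minimal sketch of the new setup pattern (StoppingGameConfig kwargs are assumed).
import numpy as np
from gym_csle_stopping_game.util.stopping_game_util import StoppingGameUtil
from gym_csle_stopping_game.dao.stopping_game_config import StoppingGameConfig

T = StoppingGameUtil.transition_tensor(L=3, p=0)  # transition tensor built by the util
O = StoppingGameUtil.observation_space(n=100)     # observation space with 100 observations
Z = StoppingGameUtil.observation_tensor(n=100)    # observation probability tensor
R = np.zeros((2, 3, 3, 3))                        # zero reward tensor, as in the tests
S = StoppingGameUtil.state_space()                # state space
A1 = StoppingGameUtil.defender_actions()          # defender actions
A2 = StoppingGameUtil.attacker_actions()          # attacker actions
b1 = StoppingGameUtil.b1()                        # initial belief

# Hypothetical constructor call; the keyword names mirror the local variables
# in setup_env but are not shown in the diff itself.
config = StoppingGameConfig(
    env_name="test_env", T=T, O=O, Z=Z, R=R, S=S, A1=A1, A2=A2, L=2,
    R_INT=1, R_COST=2, R_SLA=3, R_ST=4, b1=b1, save_dir="save_directory",
    checkpoint_traces_freq=100, gamma=0.9)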

simulation-system/libs/gym-csle-stopping-game/tests/test_stopping_game_env.py

+18 −49
@@ -3,6 +3,7 @@
 from unittest.mock import patch, MagicMock
 from gymnasium.spaces import Box, Discrete
 import numpy as np
+from gym_csle_stopping_game.util.stopping_game_util import StoppingGameUtil
 from gym_csle_stopping_game.envs.stopping_game_env import StoppingGameEnv
 from gym_csle_stopping_game.dao.stopping_game_config import StoppingGameConfig
 from gym_csle_stopping_game.dao.stopping_game_state import StoppingGameState
@@ -23,19 +24,19 @@ def setup_env(self) -> None:
         :return: None
         """
         env_name = "test_env"
-        T = np.array([[[0.1, 0.9], [0.4, 0.6]], [[0.7, 0.3], [0.2, 0.8]]])
-        O = np.array([0, 1])
-        Z = np.array([[[0.8, 0.2], [0.5, 0.5]], [[0.4, 0.6], [0.9, 0.1]]])
+        T = StoppingGameUtil.transition_tensor(L=3, p=0)
+        O = StoppingGameUtil.observation_space(n=100)
+        Z = StoppingGameUtil.observation_tensor(n=100)
         R = np.zeros((2, 3, 3, 3))
-        S = np.array([0, 1, 2])
-        A1 = np.array([0, 1, 2])
-        A2 = np.array([0, 1, 2])
+        S = StoppingGameUtil.state_space()
+        A1 = StoppingGameUtil.defender_actions()
+        A2 = StoppingGameUtil.attacker_actions()
         L = 2
         R_INT = 1
         R_COST = 2
         R_SLA = 3
         R_ST = 4
-        b1 = np.array([0.6, 0.4])
+        b1 = StoppingGameUtil.b1()
         save_dir = "save_directory"
         checkpoint_traces_freq = 100
         gamma = 0.9
@@ -69,12 +70,12 @@ def test_stopping_game_init_(self) -> None:

         :return: None
         """
-        T = np.array([[[0.1, 0.9], [0.4, 0.6]], [[0.7, 0.3], [0.2, 0.8]]])
-        O = np.array([0, 1])
-        A1 = np.array([0, 1, 2])
-        A2 = np.array([0, 1, 2])
+        T = StoppingGameUtil.transition_tensor(L=3, p=0)
+        O = StoppingGameUtil.observation_space(n=100)
+        A1 = StoppingGameUtil.defender_actions()
+        A2 = StoppingGameUtil.attacker_actions()
         L = 2
-        b1 = np.array([0.6, 0.4])
+        b1 = StoppingGameUtil.b1()
         attacker_observation_space = Box(
             low=np.array([0.0, 0.0, 0.0]),
             high=np.array([float(L), 1.0, 2.0]),
@@ -304,7 +305,7 @@ def test_is_state_terminal(self) -> None:
         assert not env.is_state_terminal(state_tuple)

         with pytest.raises(ValueError):
-            env.is_state_terminal([1, 2, 3]) # type: ignore
+            env.is_state_terminal([1, 2, 3])  # type: ignore

     def test_get_observation_from_history(self) -> None:
         """
@@ -346,26 +347,6 @@ def test_step(self) -> None:
         :return: None
         """
         env = StoppingGameEnv(self.config)
-        env.state = MagicMock()
-        env.state.s = 1
-        env.state.l = 2
-        env.state.t = 0
-        env.state.attacker_observation.return_value = np.array([1, 2, 3])
-        env.state.defender_observation.return_value = np.array([4, 5, 6])
-        env.state.b = np.array([0.5, 0.5, 0.0])
-
-        env.trace = MagicMock()
-        env.trace.defender_rewards = []
-        env.trace.attacker_rewards = []
-        env.trace.attacker_actions = []
-        env.trace.defender_actions = []
-        env.trace.infos = []
-        env.trace.states = []
-        env.trace.beliefs = []
-        env.trace.infrastructure_metrics = []
-        env.trace.attacker_observations = []
-        env.trace.defender_observations = []
-
         with patch("gym_csle_stopping_game.util.stopping_game_util.StoppingGameUtil.sample_next_state",
                    return_value=2):
             with patch("gym_csle_stopping_game.util.stopping_game_util.StoppingGameUtil.sample_next_observation",
@@ -376,32 +357,20 @@ def test_step(self) -> None:
                     1,
                     (
                         np.array(
-                            [[0.2, 0.8, 0.0], [0.6, 0.4, 0.0], [0.5, 0.5, 0.0]]
+                            [[0.2, 0.8], [0.6, 0.4], [0.5, 0.5]]
                         ),
                         2,
                     ),
                 )
                 observations, rewards, terminated, truncated, info = env.step(
                     action_profile
                 )
-
-                assert (observations[0] == np.array([4, 5, 6])).all(), "Incorrect defender observations"
-                assert (observations[1] == np.array([1, 2, 3])).all(), "Incorrect attacker observations"
+                assert observations[0].all() == np.array([1, 0.7]).all(), "Incorrect defender observations"
+                assert observations[1].all() == np.array([1, 2, 3]).all(), "Incorrect attacker observations"
                 assert rewards == (0, 0)
                 assert not terminated
                 assert not truncated
-                assert env.trace.defender_rewards[-1] == 0
-                assert env.trace.attacker_rewards[-1] == 0
-                assert env.trace.attacker_actions[-1] == 2
-                assert env.trace.defender_actions[-1] == 1
-                assert env.trace.infos[-1] == info
-                assert env.trace.states[-1] == 2
-                print(env.trace.beliefs)
-                assert env.trace.beliefs[-1] == 0.7
-                assert env.trace.infrastructure_metrics[-1] == 1
-                assert (env.trace.attacker_observations[-1] == np.array([1, 2, 3])).all()
-                assert (env.trace.defender_observations[-1] == np.array([4, 5, 6])).all()
-
+
     def test_info(self) -> None:
         """
         Tests the function of adding the cumulative reward and episode length to the info dict
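
For reference, the reworked test_step above no longer replaces env.state and env.trace with MagicMock objects; it lets StoppingGameEnv manage them and patches only the stochastic sampling helpers. Below is a condensed sketch of that flow, assuming a config built as in the setup sketch near the top and an assumed placeholder return value for the patched observation sampler (the real value lies outside the hunks shown).

# Condensed sketch of the patched step test; return_value=1 for
# sample_next_observation is an assumed placeholder.
from unittest.mock import patch
import numpy as np
from gym_csle_stopping_game.envs.stopping_game_env import StoppingGameEnv

env = StoppingGameEnv(config)
with patch("gym_csle_stopping_game.util.stopping_game_util.StoppingGameUtil.sample_next_state",
           return_value=2):
    with patch("gym_csle_stopping_game.util.stopping_game_util.StoppingGameUtil.sample_next_observation",
               return_value=1):
        # joint action profile: defender action 1, attacker stage policy and action 2
        action_profile = (1, (np.array([[0.2, 0.8], [0.6, 0.4], [0.5, 0.5]]), 2))
        observations, rewards, terminated, truncated, info = env.step(action_profile)
        assert rewards == (0, 0)  # reward tensor is all zeros in the test config
        assert not terminated and not truncated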

simulation-system/libs/gym-csle-stopping-game/tests/test_stopping_game_mdp_attacker_env.py

+36 −19
@@ -5,8 +5,12 @@
 from gym_csle_stopping_game.dao.stopping_game_attacker_mdp_config import (
     StoppingGameAttackerMdpConfig,
 )
+from gym_csle_stopping_game.util.stopping_game_util import StoppingGameUtil
 from gym_csle_stopping_game.envs.stopping_game_env import StoppingGameEnv
 from csle_common.dao.training.policy import Policy
+from csle_common.dao.training.random_policy import RandomPolicy
+from csle_common.dao.training.player_type import PlayerType
+from csle_common.dao.simulation_config.action import Action
 import pytest
 from unittest.mock import MagicMock
 import numpy as np
@@ -25,19 +29,19 @@ def setup_env(self) -> None:
         :return: None
         """
         env_name = "test_env"
-        T = np.array([[[0.1, 0.9], [0.4, 0.6]], [[0.7, 0.3], [0.2, 0.8]]])
-        O = np.array([0, 1])
-        Z = np.array([[[0.8, 0.2], [0.5, 0.5]], [[0.4, 0.6], [0.9, 0.1]]])
+        T = StoppingGameUtil.transition_tensor(L=3, p=0)
+        O = StoppingGameUtil.observation_space(n=100)
+        Z = StoppingGameUtil.observation_tensor(n=100)
         R = np.zeros((2, 3, 3, 3))
-        S = np.array([0, 1, 2])
-        A1 = np.array([0, 1, 2])
-        A2 = np.array([0, 1, 2])
+        S = StoppingGameUtil.state_space()
+        A1 = StoppingGameUtil.defender_actions()
+        A2 = StoppingGameUtil.attacker_actions()
         L = 2
         R_INT = 1
         R_COST = 2
         R_SLA = 3
         R_ST = 4
-        b1 = np.array([0.6, 0.4])
+        b1 = StoppingGameUtil.b1()
         save_dir = "save_directory"
         checkpoint_traces_freq = 100
         gamma = 0.9
@@ -107,9 +111,8 @@ def test_reset(self) -> None:
         )

         env = StoppingGameMdpAttackerEnv(config=attacker_mdp_config)
-        attacker_obs, info = env.reset()
-        assert env.latest_defender_obs.all() == np.array([2, 0.4]).all() # type: ignore
-        assert info == {}
+        info = env.reset()
+        assert info[-1] == {}

     def test_set_model(self) -> None:
         """
@@ -144,7 +147,7 @@ def test_set_state(self) -> None:
         )

         env = StoppingGameMdpAttackerEnv(config=attacker_mdp_config)
-        assert not env.set_state(1) # type: ignore
+        assert not env.set_state(1)  # type: ignore

     def test_calculate_stage_policy(self) -> None:
         """
@@ -190,7 +193,7 @@ def test_get_attacker_dist(self) -> None:
     def test_render(self) -> None:
         """
         Tests the function for rendering the environment
-
+
         :return: None
         """
         defender_strategy = MagicMock(spec=Policy)
@@ -317,7 +320,7 @@ def test_get_actions_from_particles(self) -> None:
         particles = [1, 2, 3]
         t = 0
         observation = 0
-        expected_actions = [0, 1, 2]
+        expected_actions = [0, 1]
         assert (
             env.get_actions_from_particles(particles, t, observation)
             == expected_actions
@@ -326,18 +329,32 @@ def test_step(self) -> None:
     def test_step(self) -> None:
         """
         Tests the function for taking a step in the environment by executing the given action
-
+
         :return: None
         """
-        defender_strategy = MagicMock(spec=Policy)
+        defender_stage_strategy = np.zeros((3, 2))
+        defender_stage_strategy[0][0] = 0.9
+        defender_stage_strategy[0][1] = 0.1
+        defender_stage_strategy[1][0] = 0.9
+        defender_stage_strategy[1][1] = 0.1
+        defender_actions = list(map(lambda x: Action(id=x, descr=""), self.config.A1))
+        defender_strategy = RandomPolicy(
+            actions=defender_actions,
+            player_type=PlayerType.DEFENDER,
+            stage_policy_tensor=list(defender_stage_strategy),
+        )
         attacker_mdp_config = StoppingGameAttackerMdpConfig(
             env_name="test_env",
             stopping_game_config=self.config,
             defender_strategy=defender_strategy,
             stopping_game_name="csle-stopping-game-v1",
         )
-
         env = StoppingGameMdpAttackerEnv(config=attacker_mdp_config)
-        pi2 = np.array([[0.5, 0.5]])
-        with pytest.raises(AssertionError):
-            env.step(pi2)
+        env.reset()
+        pi2 = env.calculate_stage_policy(o=list(env.latest_attacker_obs), a2=0)  # type: ignore
+        attacker_obs, reward, terminated, truncated, info = env.step(pi2)
+        assert isinstance(attacker_obs[0], float)  # type: ignore
+        assert isinstance(terminated, bool)  # type: ignore
+        assert isinstance(truncated, bool)  # type: ignore
+        assert isinstance(reward, float)  # type: ignore
+        assert isinstance(info, dict)  # type: ignore

simulation-system/libs/gym-csle-stopping-game/tests/test_stopping_game_pomdp_defender_env.py

+24 −12
@@ -1,9 +1,14 @@
-from gym_csle_stopping_game.envs.stopping_game_pomdp_defender_env import StoppingGamePomdpDefenderEnv
+from gym_csle_stopping_game.envs.stopping_game_pomdp_defender_env import (
+    StoppingGamePomdpDefenderEnv,
+)
 from gym_csle_stopping_game.dao.stopping_game_config import StoppingGameConfig
-from gym_csle_stopping_game.dao.stopping_game_defender_pomdp_config import StoppingGameDefenderPomdpConfig
+from gym_csle_stopping_game.dao.stopping_game_defender_pomdp_config import (
+    StoppingGameDefenderPomdpConfig,
+)
 from gym_csle_stopping_game.envs.stopping_game_env import StoppingGameEnv
 from gym_csle_stopping_game.util.stopping_game_util import StoppingGameUtil
 from csle_common.dao.training.policy import Policy
+from csle_common.dao.simulation_config.action import Action
 from csle_common.dao.training.random_policy import RandomPolicy
 from csle_common.dao.training.player_type import PlayerType
 import pytest
@@ -219,7 +224,7 @@ def test_set_state(self) -> None:
             stopping_game_name="csle-stopping-game-v1",
         )
         env = StoppingGamePomdpDefenderEnv(config=defender_pomdp_config)
-        assert env.set_state(1) is None # type: ignore
+        assert env.set_state(1) is None  # type: ignore

     def test_get_observation_from_history(self) -> None:
         """
@@ -301,7 +306,10 @@ def test_get_actions_from_particles(self) -> None:
         t = 0
         observation = 0
         expected_actions = [0, 1]
-        assert env.get_actions_from_particles(particles, t, observation) == expected_actions
+        assert (
+            env.get_actions_from_particles(particles, t, observation)
+            == expected_actions
+        )

     def test_step(self) -> None:
         """
@@ -315,8 +323,12 @@ def test_step(self) -> None:
         attacker_stage_strategy[1][0] = 0.9
         attacker_stage_strategy[1][1] = 0.1
         attacker_stage_strategy[2] = attacker_stage_strategy[1]
-        attacker_strategy = RandomPolicy(actions=list(self.config.A2), player_type=PlayerType.ATTACKER,
-                                         stage_policy_tensor=list(attacker_stage_strategy))
+        attacker_actions = list(map(lambda x: Action(id=x, descr=""), self.config.A2))
+        attacker_strategy = RandomPolicy(
+            actions=attacker_actions,
+            player_type=PlayerType.ATTACKER,
+            stage_policy_tensor=list(attacker_stage_strategy),
+        )
         defender_pomdp_config = StoppingGameDefenderPomdpConfig(
             env_name="test_env",
             stopping_game_config=self.config,
@@ -328,9 +340,9 @@ def test_step(self) -> None:
         env.reset()
         defender_obs, reward, terminated, truncated, info = env.step(a1)
         assert len(defender_obs) == 2
-        assert isinstance(defender_obs[0], float) # type: ignore
-        assert isinstance(defender_obs[1], float) # type: ignore
-        assert isinstance(reward, float) # type: ignore
-        assert isinstance(terminated, bool) # type: ignore
-        assert isinstance(truncated, bool) # type: ignore
-        assert isinstance(info, dict) # type: ignore
+        assert isinstance(defender_obs[0], float)  # type: ignore
+        assert isinstance(defender_obs[1], float)  # type: ignore
+        assert isinstance(reward, float)  # type: ignore
+        assert isinstance(terminated, bool)  # type: ignore
+        assert isinstance(truncated, bool)  # type: ignore
+        assert isinstance(info, dict)  # type: ignore
