Skip to content

Commit 6c709e8

Browse files
authored
Merge pull request #366 from Limmen/apt-game
apt_game unit tests
2 parents 8855621 + e9d6494 commit 6c709e8

File tree

8 files changed

+904
-20
lines changed

8 files changed

+904
-20
lines changed

simulation-system/libs/csle-common/tests/test_flags_controller.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
2 1 2 2 5 45 6 1
2+
0 0
3+
1 0
4+
CONTINUE
5+
STOP
6+
CONTINUE
7+
STOP
8+
o_0
9+
o_1
10+
o_2
11+
o_3
12+
o_4
13+
0 1
14+
0 1
15+
0 1
16+
0 0 0 0 0 0.542067363614722
17+
0 0 0 1 0 0.25296476968687015
18+
0 0 0 2 0 0.12901203254030388
19+
0 0 0 3 0 0.05805541464313671
20+
0 0 0 4 0 0.01790041951496717
21+
0 0 1 0 0 0.4336538908917776
22+
0 0 1 1 0 0.20237181574949614
23+
0 0 1 2 0 0.10320962603224311
24+
0 0 1 3 0 0.046444331714509374
25+
0 0 1 4 0 0.014320335611973737
26+
0 0 1 0 1 0.029787234042553196
27+
0 0 1 1 1 0.03220241518113858
28+
0 0 1 2 1 0.035780461312376215
29+
0 0 1 3 1 0.042094660367501424
30+
0 0 1 4 1 0.0601352290964306
31+
0 1 0 0 0 0.542067363614722
32+
0 1 0 1 0 0.25296476968687015
33+
0 1 0 2 0 0.12901203254030388
34+
0 1 0 3 0 0.05805541464313671
35+
0 1 0 4 0 0.01790041951496717
36+
0 1 1 0 0 0.542067363614722
37+
0 1 1 1 0 0.25296476968687015
38+
0 1 1 2 0 0.12901203254030388
39+
0 1 1 3 0 0.05805541464313671
40+
0 1 1 4 0 0.01790041951496717
41+
1 0 0 0 1 0.14893617021276598
42+
1 0 0 1 1 0.1610120759056929
43+
1 0 0 2 1 0.17890230656188105
44+
1 0 0 3 1 0.2104733018375071
45+
1 0 0 4 1 0.300676145482153
46+
1 0 1 0 1 0.14893617021276598
47+
1 0 1 1 1 0.1610120759056929
48+
1 0 1 2 1 0.17890230656188105
49+
1 0 1 3 1 0.2104733018375071
50+
1 0 1 4 1 0.300676145482153
51+
1 1 0 0 0 0.542067363614722
52+
1 1 0 1 0 0.25296476968687015
53+
1 1 0 2 0 0.12901203254030388
54+
1 1 0 3 0 0.05805541464313671
55+
1 1 0 4 0 0.01790041951496717
56+
1 1 1 0 0 0.542067363614722
57+
1 1 1 1 0 0.25296476968687015
58+
1 1 1 2 0 0.12901203254030388
59+
1 1 1 3 0 0.05805541464313671
60+
1 1 1 4 0 0.01790041951496717
61+
0 1 0 -1.0
62+
0 1 1 -1.0
63+
1 0 0 -1.0
64+
1 0 1 -1.0
65+
1 1 0 1.0
66+
1 1 1 1.0
67+
0 1 0

simulation-system/libs/gym-csle-apt-game/src/gym_csle_apt_game/util/apt_game_util.py

-19
Original file line numberDiff line numberDiff line change
@@ -235,25 +235,6 @@ def bayes_filter(s_prime: int, o: int, a1: int, b: npt.NDArray[np.float_], pi2:
235235
assert round(b_prime_s_prime, 2) <= 1
236236
return b_prime_s_prime
237237

238-
@staticmethod
239-
def p_o_given_b_a1_a2(o: int, b: List[float], a1: int, a2: int, config: AptGameConfig) -> float:
240-
"""
241-
Computes P[o|a,b]
242-
243-
:param o: the observation
244-
:param b: the belief point
245-
:param a1: the action of player 1
246-
:param a2: the action of player 2
247-
:param config: the game config
248-
:return: the probability of observing o when taking action a in belief point b
249-
"""
250-
prob = 0
251-
for s in config.S:
252-
for s_prime in config.S:
253-
prob += b[s] * config.T[a1][a2][s][s_prime] * config.Z[a1][a2][s_prime][o]
254-
assert prob < 1
255-
return prob
256-
257238
@staticmethod
258239
def next_belief(o: int, a1: int, b: npt.NDArray[np.float_], pi2: npt.NDArray[Any],
259240
config: AptGameConfig, a2: int = 0, s: int = 0) -> npt.NDArray[np.float_]:
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
from gym_csle_apt_game.envs.apt_game_env import AptGameEnv
2+
from gym_csle_apt_game.dao.apt_game_config import AptGameConfig
3+
from gym_csle_apt_game.util.apt_game_util import AptGameUtil
4+
from gym_csle_apt_game.dao.apt_game_state import AptGameState
5+
from unittest.mock import patch, MagicMock
6+
import csle_common.constants.constants as constants
7+
from csle_common.dao.simulation_config.simulation_trace import SimulationTrace
8+
import pytest
9+
import numpy as np
10+
11+
12+
class TestAptGameEnvSuite:
13+
"""
14+
Test suite for apt_game_env.py
15+
"""
16+
17+
@pytest.fixture(autouse=True)
18+
def setup_env(self) -> None:
19+
"""
20+
Sets up the configuration of the apt game
21+
22+
:return: None
23+
"""
24+
env_name = "test_env"
25+
N = 3
26+
p_a = 0.5
27+
T = AptGameUtil.transition_tensor(N, p_a)
28+
O = AptGameUtil.observation_space(100)
29+
Z = AptGameUtil.observation_tensor(100, N)
30+
C = AptGameUtil.cost_tensor(N)
31+
S = AptGameUtil.state_space(N=3)
32+
A1 = AptGameUtil.defender_actions()
33+
A2 = AptGameUtil.attacker_actions()
34+
b1 = AptGameUtil.b1(N=3)
35+
save_dir = "save_directory"
36+
checkpoint_traces_freq = 100
37+
gamma = 0.9
38+
self.config = AptGameConfig(
39+
env_name,
40+
T,
41+
O,
42+
Z,
43+
C,
44+
S,
45+
A1,
46+
A2,
47+
b1,
48+
N,
49+
p_a,
50+
save_dir,
51+
checkpoint_traces_freq,
52+
gamma,
53+
)
54+
self.env = AptGameEnv(self.config)
55+
56+
def test_init_(self) -> None:
57+
"""
58+
Tests the initializing function
59+
60+
:return: None
61+
"""
62+
assert self.env.config == self.config
63+
assert self.env.state == AptGameState(b1=self.config.b1)
64+
assert (
65+
self.env.attacker_observation_space
66+
== self.config.attacker_observation_space()
67+
)
68+
assert (
69+
self.env.defender_observation_space
70+
== self.config.defender_observation_space()
71+
)
72+
assert self.env.attacker_action_space == self.config.attacker_action_space()
73+
assert self.env.defender_action_space == self.config.defender_action_space()
74+
assert self.env.action_space == self.config.defender_action_space()
75+
assert self.env.observation_space == self.config.defender_observation_space()
76+
assert isinstance(self.env.traces, list)
77+
assert isinstance(self.env.trace, SimulationTrace)
78+
assert self.env.trace.simulation_env == self.config.env_name
79+
80+
def test_step(self) -> None:
81+
"""
82+
Tests the function for taking a step in the environment by executing the given action
83+
84+
:return: None
85+
"""
86+
initial_state = self.env.state.s
87+
action_profile = (
88+
0,
89+
(np.array([[0.5, 0.5], [0.5, 0.5], [0.4, 0.6], [0.5, 0.5]]), 1),
90+
)
91+
obs, reward, terminated, truncated, info = self.env.step(action_profile)
92+
assert obs[0].all() == self.env.state.defender_observation().all()
93+
assert obs[1] == self.env.state.attacker_observation()
94+
assert isinstance(terminated, bool) # type: ignore
95+
assert isinstance(truncated, bool) # type: ignore
96+
assert reward == (
97+
self.config.C[action_profile[0]][initial_state],
98+
-self.config.C[action_profile[0]][initial_state],
99+
)
100+
assert isinstance(info, dict) # type: ignore
101+
102+
def test_mean(self) -> None:
103+
"""
104+
Tests the utility function for getting the mean of a vector
105+
106+
:return: None
107+
"""
108+
test_cases = [
109+
([], 0), # Test case for an empty vector
110+
([5], 0), # Test case for a vector with a single element
111+
([0.2, 0.3, 0.5], 1.3), # Test case for a vector with multiple elements
112+
]
113+
for prob_vector, expected_mean in test_cases:
114+
result = AptGameEnv(self.config).mean(prob_vector)
115+
assert result == expected_mean
116+
117+
def test_info(self) -> None:
118+
"""
119+
Tests the function for adding the cumulative reward and episode length to the info dict
120+
121+
:return: None
122+
"""
123+
info = {} # type: ignore
124+
assert isinstance(self.env._info(info), dict) # type: ignore
125+
126+
def test_reset(self) -> None:
127+
"""
128+
Tests the function for reseting the environment state
129+
130+
:return: None
131+
"""
132+
self.env.trace.attacker_rewards = [1, 2, 3]
133+
self.env.trace.attacker_observations = ["obs1", "obs2"]
134+
self.env.trace.defender_observations = ["obs1", "obs2"]
135+
initial_trace_count = len(self.env.traces)
136+
AptGameState.reset = lambda self: None
137+
initial_obs, info = self.env.reset(seed=10, soft=False, options=None)
138+
assert len(self.env.traces) == initial_trace_count + 1
139+
assert isinstance(self.env.trace, SimulationTrace)
140+
assert self.env.trace.simulation_env == self.config.env_name
141+
assert self.env.trace.attacker_observations == [initial_obs[1]]
142+
assert self.env.trace.defender_observations == [initial_obs[0]]
143+
assert info == {}
144+
145+
def test_render(self) -> None:
146+
"""
147+
Tests the function of rendering the environment
148+
149+
:return: None
150+
"""
151+
with pytest.raises(NotImplementedError):
152+
self.env.render()
153+
154+
def test_is_defense_action_legal(self) -> None:
155+
"""
156+
Tests the function of checking whether a defender action in the environment is legal or not
157+
158+
:return: None
159+
"""
160+
assert self.env.is_defense_action_legal(1)
161+
162+
def test_is_attack_action_legal(self) -> None:
163+
"""
164+
Tests the function of checking whether an attacker action in the environment is legal or not
165+
166+
:return: None
167+
"""
168+
assert self.env.is_attack_action_legal(1)
169+
170+
def test_get_traces(self) -> None:
171+
"""
172+
Tests the function of getting the list of simulation traces
173+
174+
:return: None
175+
"""
176+
assert self.env.get_traces() == self.env.traces
177+
178+
def test_reset_traces(self) -> None:
179+
"""
180+
Tests the function of resetting the list of traces
181+
182+
:return: None
183+
"""
184+
self.env.traces = ["trace1", "trace2"]
185+
self.env.reset_traces()
186+
assert self.env.traces == []
187+
188+
def test_checkpoint_traces(self) -> None:
189+
"""
190+
Tests the function of checkpointing agent traces
191+
192+
:return: None
193+
"""
194+
195+
fixed_timestamp = 123
196+
with patch("time.time", return_value=fixed_timestamp):
197+
with patch(
198+
"csle_common.dao.simulation_config.simulation_trace.SimulationTrace.save_traces"
199+
) as mock_save_traces:
200+
self.env.traces = ["trace1", "trace2"]
201+
self.env._AptGameEnv__checkpoint_traces()
202+
mock_save_traces.assert_called_once_with(
203+
traces_save_dir=constants.LOGGING.DEFAULT_LOG_DIR,
204+
traces=self.env.traces,
205+
traces_file=f"taus{fixed_timestamp}.json",
206+
)
207+
208+
def test_set_model(self) -> None:
209+
"""
210+
Tests the function of setting the model
211+
212+
:return: None
213+
"""
214+
mock_model = MagicMock()
215+
self.env.set_model(mock_model)
216+
assert self.env.model == mock_model
217+
218+
def test_set_state(self) -> None:
219+
"""
220+
Tests the function of setting the state
221+
222+
:return: None
223+
"""
224+
self.env.state = MagicMock()
225+
mock_state = MagicMock(spec=AptGameState)
226+
self.env.set_state(mock_state)
227+
assert self.env.state == mock_state
228+
229+
state_int = 5
230+
self.env.set_state(state_int)
231+
assert self.env.state.s == state_int
232+
233+
with pytest.raises(ValueError):
234+
self.env.set_state([1, 2, 3]) # type: ignore

0 commit comments

Comments
 (0)