
Commit 3df633c

Merge branch 'master' of https://github.com/Limmen/csle

2 parents: 8d13b14 + 21f8773

File tree: 1 file changed, +149 -0 lines

@@ -0,0 +1,149 @@
from typing import List
import csle_common.constants.constants as constants
from csle_common.dao.training.experiment_config import ExperimentConfig
from csle_common.metastore.metastore_facade import MetastoreFacade
from csle_common.dao.training.agent_type import AgentType
from csle_common.dao.training.hparam import HParam
from csle_common.dao.training.player_type import PlayerType
from csle_agents.agents.pomcp.pomcp_agent import POMCPAgent
from csle_agents.agents.pomcp.pomcp_acquisition_function_type import POMCPAcquisitionFunctionType
import csle_agents.constants.constants as agents_constants
from csle_agents.common.objective_type import ObjectiveType
from csle_common.dao.simulation_config.simulation_env_config import SimulationEnvConfig
from gym_csle_cyborg.dao.csle_cyborg_wrapper_config import CSLECyborgWrapperConfig
from gym_csle_cyborg.envs.cyborg_scenario_two_wrapper import CyborgScenarioTwoWrapper
from gym_csle_cyborg.dao.red_agent_type import RedAgentType
from gym_csle_cyborg.dao.csle_cyborg_config import CSLECyborgConfig
from gym_csle_cyborg.util.cyborg_env_util import CyborgEnvUtil


def heuristic_value(o: List[List[int]]) -> float:
    """
    A heuristic value function for truncated rollouts: sums the compromise
    cost of every host whose compromised indicator (o[i][2]) is positive.

    :param o: the observation vector (one feature vector per host)
    :return: the heuristic value
    """
    host_costs = CyborgEnvUtil.get_host_compromised_costs()
    val = 0
    for i in range(len(o)):
        if o[i][2] > 0:
            val += host_costs[i]
    return val
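
# Illustration (not part of the original commit): assuming
# CyborgEnvUtil.get_host_compromised_costs() returns one (typically non-positive)
# cost per host index, an observation where only host 1 has a positive
# compromised indicator evaluates to that host's cost, e.g.:
#
#   heuristic_value([[0, 0, 0], [0, 0, 2], [1, 1, 0]])  # == host_costs[1]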


if __name__ == '__main__':
    emulation_name = "csle-level9-040"
    # no emulation environment is used in this experiment
    emulation_env_config = None
    simulation_name = "csle-cyborg-001"
    simulation_env_config = SimulationEnvConfig(name="", version="", gym_env_name="", simulation_env_input_config="",
                                                players_config="", joint_action_space_config="",
                                                joint_observation_space_config="", time_step_type=None,
                                                reward_function_config=None, transition_operator_config=None,
                                                observation_function_config=None,
                                                initial_state_distribution_config=None, env_parameters_config=None,
                                                plot_transition_probabilities=False, plot_observation_function=False,
                                                plot_reward_function=False, descr="", state_space_config=None)
    eval_env_config = CSLECyborgConfig(
        gym_env_name="csle-cyborg-scenario-two-v1", scenario=2, baseline_red_agents=[RedAgentType.B_LINE_AGENT],
        maximum_steps=100, red_agent_distribution=[1.0], reduced_action_space=True, scanned_state=True,
        decoy_state=True, decoy_optimization=False, cache_visited_states=True, save_trace=False,
        randomize_topology=False)
    simulation_env_config.simulation_env_input_config = CSLECyborgWrapperConfig(
        gym_env_name="csle-cyborg-scenario-two-wrapper-v1", maximum_steps=100, save_trace=False, scenario=2,
        reward_shaping=True, red_agent_type=RedAgentType.B_LINE_AGENT)
    simulation_env_config.gym_env_name = "csle-cyborg-scenario-two-wrapper-v1"
    # instantiate the wrapper environment and read off the action space and initial belief particles
    csle_cyborg_env = CyborgScenarioTwoWrapper(config=simulation_env_config.simulation_env_input_config)
    A = csle_cyborg_env.get_action_space()
    initial_particles = csle_cyborg_env.initial_particles
    # rollout policy that always returns action index 35, regardless of the observation
    rollout_policy = lambda x, deterministic: 35
    value_function = heuristic_value
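    # Sketch (an assumption, not part of the original commit): instead of the
    # constant rollout policy above, a uniform-random policy over the action
    # space A could be plugged in the same way, e.g.:
    #
    #   import random
    #   rollout_policy = lambda o, deterministic: random.choice(A)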
    experiment_config = ExperimentConfig(
        output_dir=f"{constants.LOGGING.DEFAULT_LOG_DIR}pomcp_test", title="POMCP test",
        random_seeds=[555512, 98912, 999, 555],
        agent_type=AgentType.POMCP,
        log_every=1,
        hparams={
            agents_constants.POMCP.N: HParam(value=5000, name=agents_constants.POMCP.N,
                                             descr="the number of episodes"),
            agents_constants.POMCP.OBJECTIVE_TYPE: HParam(
                value=ObjectiveType.MAX, name=agents_constants.POMCP.OBJECTIVE_TYPE,
                descr="the type of objective (max or min)"),
            agents_constants.POMCP.ROLLOUT_POLICY: HParam(
                value=rollout_policy, name=agents_constants.POMCP.ROLLOUT_POLICY,
                descr="the policy to use for rollouts"),
            agents_constants.POMCP.VALUE_FUNCTION: HParam(
                value=value_function, name=agents_constants.POMCP.VALUE_FUNCTION,
                descr="the value function to use for truncated rollouts"),
            agents_constants.POMCP.A: HParam(value=A, name=agents_constants.POMCP.A, descr="the action space"),
            agents_constants.POMCP.GAMMA: HParam(value=0.99, name=agents_constants.POMCP.GAMMA,
                                                 descr="the discount factor"),
            agents_constants.POMCP.REINVIGORATION: HParam(value=False, name=agents_constants.POMCP.REINVIGORATION,
                                                          descr="whether reinvigoration should be used"),
            agents_constants.POMCP.REINVIGORATED_PARTICLES_RATIO: HParam(
                value=0.0, name=agents_constants.POMCP.REINVIGORATED_PARTICLES_RATIO,
                descr="the ratio of reinvigorated particles in the particle filter"),
            agents_constants.POMCP.INITIAL_PARTICLES: HParam(value=initial_particles,
                                                             name=agents_constants.POMCP.INITIAL_PARTICLES,
                                                             descr="the initial belief"),
            agents_constants.POMCP.PLANNING_TIME: HParam(value=3.75, name=agents_constants.POMCP.PLANNING_TIME,
                                                         descr="the planning time"),
            agents_constants.POMCP.PRUNE_ACTION_SPACE: HParam(
                value=False, name=agents_constants.POMCP.PRUNE_ACTION_SPACE,
                descr="boolean flag indicating whether the action space should be pruned or not"),
            agents_constants.POMCP.PRUNE_SIZE: HParam(
                value=3, name=agents_constants.POMCP.PRUNE_SIZE, descr="size of the pruned action space"),
            agents_constants.POMCP.MAX_PARTICLES: HParam(value=1000, name=agents_constants.POMCP.MAX_PARTICLES,
                                                         descr="the maximum number of belief particles"),
            agents_constants.POMCP.MAX_PLANNING_DEPTH: HParam(
                value=50, name=agents_constants.POMCP.MAX_PLANNING_DEPTH, descr="the maximum depth for planning"),
            agents_constants.POMCP.MAX_ROLLOUT_DEPTH: HParam(value=4, name=agents_constants.POMCP.MAX_ROLLOUT_DEPTH,
                                                             descr="the maximum depth for rollout"),
            agents_constants.POMCP.C: HParam(value=0.5, name=agents_constants.POMCP.C,
                                             descr="the weighting factor for UCB exploration"),
            agents_constants.POMCP.C2: HParam(value=15000, name=agents_constants.POMCP.C2,
                                              descr="the weighting factor for AlphaGo exploration"),
            agents_constants.POMCP.USE_ROLLOUT_POLICY: HParam(
                value=False, name=agents_constants.POMCP.USE_ROLLOUT_POLICY,
                descr="boolean flag indicating whether the rollout policy should be used"),
            agents_constants.POMCP.PRIOR_WEIGHT: HParam(value=5, name=agents_constants.POMCP.PRIOR_WEIGHT,
                                                        descr="the weight on the prior"),
            agents_constants.POMCP.PRIOR_CONFIDENCE: HParam(value=0, name=agents_constants.POMCP.PRIOR_CONFIDENCE,
                                                            descr="the prior confidence"),
            agents_constants.POMCP.ACQUISITION_FUNCTION_TYPE: HParam(
                value=POMCPAcquisitionFunctionType.UCB, name=agents_constants.POMCP.ACQUISITION_FUNCTION_TYPE,
                descr="the type of acquisition function"),
            agents_constants.POMCP.LOG_STEP_FREQUENCY: HParam(
                value=1, name=agents_constants.POMCP.LOG_STEP_FREQUENCY, descr="frequency of logging time-steps"),
            agents_constants.POMCP.MAX_NEGATIVE_SAMPLES: HParam(
                value=20, name=agents_constants.POMCP.MAX_NEGATIVE_SAMPLES,
                descr="maximum number of negative samples when filling belief particles"),
            agents_constants.POMCP.DEFAULT_NODE_VALUE: HParam(
                value=0, name=agents_constants.POMCP.DEFAULT_NODE_VALUE,
                descr="the default node value in the search tree"),
            agents_constants.POMCP.VERBOSE: HParam(value=False, name=agents_constants.POMCP.VERBOSE,
                                                   descr="verbose logging flag"),
            agents_constants.POMCP.EVAL_ENV_NAME: HParam(value="csle-cyborg-scenario-two-v1",
                                                         name=agents_constants.POMCP.EVAL_ENV_NAME,
                                                         descr="the name of the evaluation environment"),
            agents_constants.POMCP.EVAL_ENV_CONFIG: HParam(value=eval_env_config,
                                                           name=agents_constants.POMCP.EVAL_ENV_CONFIG,
                                                           descr="the configuration of the evaluation environment"),
            agents_constants.COMMON.EVAL_BATCH_SIZE: HParam(value=100, name=agents_constants.COMMON.EVAL_BATCH_SIZE,
                                                            descr="number of evaluation episodes"),
            agents_constants.COMMON.CONFIDENCE_INTERVAL: HParam(
                value=0.95, name=agents_constants.COMMON.CONFIDENCE_INTERVAL,
                descr="confidence interval"),
            agents_constants.COMMON.MAX_ENV_STEPS: HParam(
                value=100, name=agents_constants.COMMON.MAX_ENV_STEPS,
                descr="maximum number of steps in the environment (generally for envs with infinite horizon)"),
            agents_constants.COMMON.RUNNING_AVERAGE: HParam(
                value=100, name=agents_constants.COMMON.RUNNING_AVERAGE,
                descr="the number of samples to include when computing the running average")
        },
        player_type=PlayerType.DEFENDER, player_idx=0
    )
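    # Rough budget (an assumption, not stated in the original commit): if
    # PLANNING_TIME is measured in seconds per decision, then 100 environment
    # steps * 3.75 s of planning per step is roughly 375 s (about 6 minutes)
    # of search per episode, before any environment-step overhead.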
    agent = POMCPAgent(emulation_env_config=emulation_env_config, simulation_env_config=simulation_env_config,
                       experiment_config=experiment_config, save_to_metastore=False)
    experiment_execution = agent.train()
    MetastoreFacade.save_experiment_execution(experiment_execution)
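
A minimal way to try the added script (a sketch, assuming the csle_common, csle_agents, and gym_csle_cyborg packages are importable and the file is saved locally; the file name run_pomcp_cyborg.py is hypothetical):

    python run_pomcp_cyborg.py

Since save_to_metastore=False is passed to the agent, only the final MetastoreFacade.save_experiment_execution call persists results, which presumably requires a running CSLE metastore.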
