@@ -1,94 +1,112 @@
+from typing import List
 import numpy as np
-import torch
-import random
-import json
-import io
-from gym_csle_cyborg.dao.csle_cyborg_config import CSLECyborgConfig
-from gym_csle_cyborg.dao.red_agent_type import RedAgentType
-from gym_csle_cyborg.envs.cyborg_scenario_two_defender import CyborgScenarioTwoDefender
+import time
 from gym_csle_cyborg.envs.cyborg_scenario_two_wrapper import CyborgScenarioTwoWrapper
 from gym_csle_cyborg.dao.csle_cyborg_wrapper_config import CSLECyborgWrapperConfig
 from csle_agents.agents.pomcp.pomcp import POMCP
 from csle_agents.agents.pomcp.pomcp_acquisition_function_type import POMCPAcquisitionFunctionType
 import csle_agents.constants.constants as agents_constants
 from csle_common.logging.log import Logger
+from gym_csle_cyborg.util.cyborg_env_util import CyborgEnvUtil
+from gym_csle_cyborg.dao.red_agent_type import RedAgentType
+
+
+def heuristic_value(o: List[List[int]]) -> float:
+    """
+    A heuristic value function
+
+    :param o: the observation vector
+    :return: the value
+    """
+    host_costs = CyborgEnvUtil.get_host_compromised_costs()
+    val = 0
+    for i in range(len(o)):
+        if o[i][2] > 0:
+            val += host_costs[i]
+    return val
+
 
 if __name__ == '__main__':
-    # ppo_policy = PPOPolicy(model=None, simulation_name="", save_path="")
-    config = CSLECyborgConfig(
-        gym_env_name="csle-cyborg-scenario-two-v1", scenario=2, baseline_red_agents=[RedAgentType.B_LINE_AGENT],
-        maximum_steps=100, red_agent_distribution=[1.0], reduced_action_space=True, decoy_state=True,
-        scanned_state=True, decoy_optimization=False, cache_visited_states=False)
-    eval_env = CyborgScenarioTwoDefender(config=config)
-    config = CSLECyborgWrapperConfig(maximum_steps=100, gym_env_name="",
-                                     save_trace=False, reward_shaping=False, scenario=2)
+    config = CSLECyborgWrapperConfig(
+        gym_env_name="csle-cyborg-scenario-two-wrapper-v1", maximum_steps=100, save_trace=False, scenario=2,
+        reward_shaping=True, red_agent_type=RedAgentType.B_LINE_AGENT)
+    eval_env = CyborgScenarioTwoWrapper(config=config)
     train_env = CyborgScenarioTwoWrapper(config=config)
+    action_id_to_type_and_host, type_and_host_to_action_id \
+        = CyborgEnvUtil.get_action_dicts(scenario=2, reduced_action_space=True, decoy_state=True,
+                                         decoy_optimization=False)
 
-    num_evaluations = 10
-    max_horizon = 100
-    returns = []
-    seed = 215125
-    random.seed(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
+    N = 5000
+    rollout_policy = lambda x, deterministic: 35
+    value_function = heuristic_value
     A = train_env.get_action_space()
-    gamma = 0.75
-    c = 1
-    print("Starting policy evaluation")
-    for i in range(num_evaluations):
+    gamma = 0.99
+    reinvigoration = False
+    reinvigorated_particles_ratio = 0.0
+    initial_particles = train_env.initial_particles
+    planning_time = 3.75
+    prune_action_space = False
+    max_particles = 1000
+    max_planning_depth = 50
+    max_rollout_depth = 4
+    c = 0.5
+    c2 = 15000
+    use_rollout_policy = False
+    prior_weight = 5
+    prior_confidence = 0
+    acquisition_function_type = POMCPAcquisitionFunctionType.UCB
+    log_steps_frequency = 1
+    max_negative_samples = 20
+    default_node_value = 0
+    verbose = False
+    eval_batch_size = 100
+    max_env_steps = 100
+    prune_size = 3
+    start = time.time()
+
+    # Run N episodes
+    returns = []
+    for i in range(N):
+        done = False
+        action_sequence = []
         _, info = eval_env.reset()
         s = info[agents_constants.COMMON.STATE]
         train_env.reset()
-        initial_particles = train_env.initial_particles
-        max_particles = 1000
-        planning_time = 60
-        value_function = lambda x: 0
-        reinvigoration = False
-        rollout_policy = False
-        verbose = False
-        default_node_value = 0
-        prior_weight = 1
-        acquisition_function_type = POMCPAcquisitionFunctionType.UCB
-        use_rollout_policy = False
-        reinvigorated_particles_ratio = False
-        prune_action_space = False
-        prune_size = 3
-        prior_confidence = 0
         pomcp = POMCP(A=A, gamma=gamma, env=train_env, c=c, initial_particles=initial_particles,
                       planning_time=planning_time, max_particles=max_particles, rollout_policy=rollout_policy,
                       value_function=value_function, reinvigoration=reinvigoration, verbose=verbose,
                       default_node_value=default_node_value, prior_weight=prior_weight,
-                      acquisition_function_type=acquisition_function_type, c2=1500,
+                      acquisition_function_type=acquisition_function_type, c2=c2,
                       use_rollout_policy=use_rollout_policy, prior_confidence=prior_confidence,
                       reinvigorated_particles_ratio=reinvigorated_particles_ratio,
                       prune_action_space=prune_action_space, prune_size=prune_size)
-        rollout_depth = 4
-        planning_depth = 100
         R = 0
-        t = 0
-        action_sequence = []
-        while t < max_horizon:
-            pomcp.solve(max_rollout_depth=rollout_depth, max_planning_depth=planning_depth)
+        t = 1
+
+        # Run episode
+        while not done and t <= max_env_steps:
+            rollout_depth = max_rollout_depth
+            planning_depth = max_planning_depth
+            pomcp.solve(max_rollout_depth=rollout_depth, max_planning_depth=planning_depth, t=t)
             action = pomcp.get_action()
-            o, r, done, _, info = eval_env.step(action)
+            o, _, done, _, info = eval_env.step(action)
+            r = info[agents_constants.COMMON.REWARD]
             action_sequence.append(action)
             s_prime = info[agents_constants.COMMON.STATE]
             obs_id = info[agents_constants.COMMON.OBSERVATION]
-            pomcp.update_tree_with_new_samples(action_sequence=action_sequence, observation=obs_id)
-            print(eval_env.get_true_table())
-            print(eval_env.get_table())
+            pomcp.update_tree_with_new_samples(action_sequence=action_sequence, observation=obs_id, t=t)
            R += r
            t += 1
-            Logger.__call__().get_logger().info(f"[POMCP] t: {t}, a: {action}, r: {r}, o: {obs_id}, "
-                                                f"s_prime: {s_prime},"
-                                                f", action sequence: {action_sequence}, R: {R}")
+            if t % log_steps_frequency == 0:
+                Logger.__call__().get_logger().info(f"[POMCP] t: {t}, a: {action_id_to_type_and_host[action]}, r: {r}, "
+                                                    f"action sequence: {action_sequence}, R: {round(R, 2)}")
+
+        # Logging
         returns.append(R)
-        print(f"{i}/{num_evaluations}, avg R: {np.mean(returns)}, R: {R}")
-        results = {}
-        results["seed"] = seed
-        results["training_time"] = 0
-        results["returns"] = returns
-        results["planning_time"] = planning_time
-        json_str = json.dumps(results, indent=4, sort_keys=True)
-        with io.open(f"/Users/kim/pomcp_{0}_60s.json", 'w', encoding='utf-8') as f:
-            f.write(json_str)
+        progress = round((i + 1) / N, 2)
+        time_elapsed_minutes = round((time.time() - start) / 60, 3)
+        Logger.__call__().get_logger().info(
+            f"[POMCP] episode: {i}, J:{R}, "
+            f"J_avg: {np.mean(returns)}, "
+            f"progress: {round(progress * 100, 2)}%, "
+            f"runtime: {time_elapsed_minutes} min")
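
Note: the script configures POMCPAcquisitionFunctionType.UCB with exploration constant c = 0.5. Assuming the standard UCB1 selection rule from the POMCP literature (a sketch of the general technique, not necessarily the exact bookkeeping inside csle_agents; its c2 parameter is not modeled here), action selection at a belief node looks roughly like this:

import math
from typing import Dict, Optional


def ucb_select(q_values: Dict[int, float], visit_counts: Dict[int, int],
               node_visits: int, c: float = 0.5) -> Optional[int]:
    """Returns the action maximizing Q(h, a) + c * sqrt(ln N(h) / N(h, a)).

    Unvisited actions get an infinite score so that every action is tried at
    least once before the exploitation term dominates.
    """
    best_action, best_score = None, -math.inf
    for a, q in q_values.items():
        n_a = visit_counts.get(a, 0)
        score = math.inf if n_a == 0 else q + c * math.sqrt(math.log(node_visits) / n_a)
        if score > best_score:
            best_action, best_score = a, score
    return best_action


# Illustrative action ids and values: action 35 has a slightly lower Q-value
# but far fewer visits, so the exploration bonus tips the selection its way.
print(ucb_select({27: -2.0, 35: -2.3}, {27: 50, 35: 5}, node_visits=55))  # -> 35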