# Path: blob/main/cyberbattle/simulation/generate_network.py
# 960 views
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Generating random graphs"""

from cyberbattle.simulation.model import Identifiers, NodeID, CredentialID, PortName, FirewallConfiguration, FirewallRule, RulePermission
import numpy as np
import networkx as nx
from cyberbattle.simulation import model as m
import random
from typing import List, Optional, Tuple, DefaultDict

from collections import defaultdict

# Identifiers (properties, ports and vulnerability names) used by the
# environments generated in this module.
ENV_IDENTIFIERS = Identifiers(
    properties=["breach_node"],
    ports=["SMB", "HTTP", "RDP"],
    local_vulnerabilities=["ScanWindowsCredentialManagerForRDP", "ScanWindowsExplorerRecentFiles", "ScanWindowsCredentialManagerForSMB"],
    remote_vulnerabilities=["Traceroute"],
)


def generate_random_traffic_network(
    n_clients: int = 200,
    n_servers={
        "SMB": 1,
        "HTTP": 1,
        "RDP": 1,
    },
    seed: Optional[int] = 0,
    tolerance: np.float32 = np.float32(1e-3),
    alpha=np.array([(0.1, 0.3), (0.18, 0.09)], dtype=float),
    beta=np.array([(100, 10), (10, 100)], dtype=float),
) -> nx.DiGraph:
    """
    Randomly generate a directed network graph representing
    fictitious SMB, HTTP, and RDP traffic.

    One stochastic block model with two blocks (clients, servers) is sampled
    per protocol; the per-protocol edges are then merged into a single graph
    in which each edge is labelled with the set of protocols seen on it.

    Arguments:
        n_clients: number of workstation nodes that can initiate sessions with server nodes
        n_servers: dictionary indicating the number of nodes listening to each protocol
            (read-only: the default dictionary is never mutated here)
        seed: seed for the pseudo-random number generator. The global NumPy
            generator is re-seeded with this value before every protocol is
            sampled, and the same value is passed to the block model sampler.
        tolerance: absolute tolerance for bounding the edge probabilities in [tolerance, 1-tolerance]
        alpha: beta distribution parameters alpha, a 2x2 array such that
            E(edge prob) = alpha / (alpha + beta) before scaling and clipping
        beta: beta distribution parameters beta, a 2x2 array such that
            E(edge prob) = alpha / (alpha + beta) before scaling and clipping

    Returns:
        (nx.DiGraph): the randomly generated network. Every edge carries a
        `protocol` attribute holding the set of protocol names observed
        between its two end points.
    """
    edges_labels = defaultdict(set)  # set backed multidict

    for protocol, server_count in n_servers.items():
        sizes = [n_clients, server_count]

        # sample edge probabilities from a beta distribution
        # NOTE: re-seeding inside the loop means every protocol starts from
        # the same random state when `seed` is not None; kept as-is so that
        # seeded results stay reproducible.
        np.random.seed(seed)
        probs: np.ndarray = np.random.beta(a=alpha, b=beta, size=(2, 2))

        # scale by edge type
        if protocol == "SMB":
            probs = 3 * probs
        if protocol == "RDP":
            probs = 4 * probs

        # don't allow probs too close to zero or one
        probs = np.clip(probs, a_min=tolerance, a_max=np.float32(1.0 - tolerance))

        # sample edges using block models given edge probabilities
        di_graph_for_protocol = nx.stochastic_block_model(sizes=sizes, p=probs, directed=True, seed=seed)

        for edge in di_graph_for_protocol.edges:
            edges_labels[edge].add(protocol)

    # Merge the per-protocol edges into one graph, one edge per (u, v) pair.
    digraph = nx.DiGraph()
    for (u, v), protocols in edges_labels.items():
        digraph.add_edge(u, v, protocol=protocols)
    return digraph


def cyberbattle_model_from_traffic_graph(
    traffic_graph: nx.DiGraph,
    cached_smb_password_probability=0.75,
    cached_rdp_password_probability=0.8,
    cached_accessed_network_shares_probability=0.6,
    cached_password_has_changed_probability=0.1,
    traceroute_discovery_probability=0.5,
    probability_two_nodes_use_same_password_to_access_given_resource=0.8,
) -> nx.DiGraph:
    """Generate a random CyberBattle network model from a specified traffic (directed) graph.

    The input graph can for instance be generated with `generate_random_traffic_network`.
    Each edge of the input graph indicates that a communication took place
    between the two nodes with the protocol(s) specified in the edge's
    `protocol` attribute.

    Returns a CyberBattle network with the same nodes (IDs converted to
    strings) and implanted vulnerabilities, to be used to instantiate a
    CyberBattleSim gym. The returned graph has no edges: every node carries a
    `data` attribute holding its `NodeInfo`.

    Randomness is drawn from the global `random` module.

    Arguments:

        traffic_graph:
            the traffic graph to convert; it is not modified
        cached_smb_password_probability, cached_rdp_password_probability:
            probability that a password used for authenticated traffic was cached by the OS for SMB and RDP
        cached_accessed_network_shares_probability:
            probability that a network share accessed by the system was cached by the OS
        cached_password_has_changed_probability:
            probability that a given password cached on a node has been rotated on the target node
            (typically low as people tend to change their password infrequently)
        probability_two_nodes_use_same_password_to_access_given_resource:
            as the variable name says
        traceroute_discovery_probability:
            probability that a target node of an SMB/RDP connection gets exposed by a traceroute attack
    """
    # convert node IDs to string
    graph = nx.relabel_nodes(traffic_graph, {i: str(i) for i in traffic_graph.nodes})

    password_counter: int = 0

    def generate_password() -> CredentialID:
        """Return a fresh credential identifier, unique within this call."""
        nonlocal password_counter
        password_counter = password_counter + 1
        return f"unique_pwd{password_counter}"

    def traffic_targets(source_node: NodeID, protocol: str) -> List[NodeID]:
        """Return the nodes that `source_node` talks to over `protocol`."""
        # Only the out-neighbours of the source node are inspected instead of
        # scanning every edge of the graph; the resulting order is the same.
        return [target for target, edge_data in graph[source_node].items() if protocol in edge_data["protocol"]]

    # Map (node, port name) -> assigned pwd
    assigned_passwords: DefaultDict[Tuple[NodeID, PortName], List[CredentialID]] = defaultdict(list)

    def assign_new_valid_password(node: NodeID, port: PortName) -> CredentialID:
        """Create a new password and register it as valid for that node and port."""
        pwd = generate_password()
        assigned_passwords[node, port].append(pwd)
        return pwd

    def reuse_valid_password(node: NodeID, port: PortName) -> CredentialID:
        """Reuse a password already assigned to that node and port; if none is already
        assigned, create and assign a new valid password"""
        if (node, port) not in assigned_passwords:
            return assign_new_valid_password(node, port)

        # reuse any of the existing assigned valid passwords for that node/port
        return random.choice(assigned_passwords[node, port])

    def create_cached_credential(node: NodeID, port: PortName) -> CredentialID:
        """Return the credential cached on some client for the given target node and port."""
        if random.random() < cached_password_has_changed_probability:
            # generate a new invalid password (never registered as valid)
            return generate_password()
        if random.random() < probability_two_nodes_use_same_password_to_access_given_resource:
            return reuse_valid_password(node, port)
        return assign_new_valid_password(node, port)

    def add_leak_neighbors_vulnerability(node_id: m.NodeID, library: Optional[m.VulnerabilityLibrary] = None) -> m.VulnerabilityLibrary:
        """Create random vulnerabilities
        that reveal immediate traffic neighbors from a given node"""

        if library is None:
            library = {}

        rdp_neighbors = traffic_targets(node_id, "RDP")

        if rdp_neighbors:
            library["ScanWindowsCredentialManagerForRDP"] = m.VulnerabilityInfo(
                description="Look for RDP credentials in the Windows Credential Manager",
                type=m.VulnerabilityType.LOCAL,
                outcome=m.LeakedCredentials(
                    credentials=[
                        m.CachedCredential(node=target_node, port="RDP", credential=create_cached_credential(target_node, "RDP"))
                        for target_node in rdp_neighbors
                        if random.random() < cached_rdp_password_probability
                    ]
                ),
                reward_string="Discovered creds in the Windows Credential Manager",
                cost=2.0,
            )

        smb_neighbors = traffic_targets(node_id, "SMB")

        if smb_neighbors:
            library["ScanWindowsExplorerRecentFiles"] = m.VulnerabilityInfo(
                description="Look for network shares in the Windows Explorer Recent files",
                type=m.VulnerabilityType.LOCAL,
                outcome=m.LeakedNodesId([target_node for target_node in smb_neighbors if random.random() < cached_accessed_network_shares_probability]),
                reward_string="Windows Explorer Recent Files revealed network shares",
                cost=1.0,
            )

            library["ScanWindowsCredentialManagerForSMB"] = m.VulnerabilityInfo(
                description="Look for network credentials in the Windows Credential Manager",
                type=m.VulnerabilityType.LOCAL,
                outcome=m.LeakedCredentials(
                    credentials=[
                        m.CachedCredential(node=target_node, port="SMB", credential=create_cached_credential(target_node, "SMB"))
                        for target_node in smb_neighbors
                        if random.random() < cached_smb_password_probability
                    ]
                ),
                reward_string="Discovered SMB creds in the Windows Credential Manager",
                cost=2.0,
            )

        # NOTE(review): the vulnerability is only implanted when the node has
        # both SMB and RDP neighbors; confirm whether `or` was intended.
        if smb_neighbors and rdp_neighbors:
            # Candidates are the SMB *and* RDP neighbors, de-duplicated while
            # keeping their order. (`smb_neighbors or rdp_neighbors` is a
            # boolean expression that only ever yielded the SMB neighbors.)
            traceroute_candidates = list(dict.fromkeys(smb_neighbors + rdp_neighbors))
            library["Traceroute"] = m.VulnerabilityInfo(
                description="Attempt to discvover network nodes using Traceroute",
                type=m.VulnerabilityType.REMOTE,
                outcome=m.LeakedNodesId([target_node for target_node in traceroute_candidates if random.random() < traceroute_discovery_probability]),
                reward_string="Discovered new network nodes via traceroute",
                cost=5.0,
            )

        return library

    def create_vulnerabilities_from_traffic_data(node_id: m.NodeID):
        """Build the vulnerability library of a node from its outgoing traffic."""
        return add_leak_neighbors_vulnerability(node_id=node_id)

    # The same RDP/SMB allow rules are used for both rule lists of the firewall.
    firewall_conf = FirewallConfiguration(
        [FirewallRule("RDP", RulePermission.ALLOW), FirewallRule("SMB", RulePermission.ALLOW)],
        [FirewallRule("RDP", RulePermission.ALLOW), FirewallRule("SMB", RulePermission.ALLOW)],
    )

    # Pick a random node as the agent entry node
    entry_node_index = random.randrange(len(graph.nodes))
    entry_node_id = list(graph.nodes)[entry_node_index]
    graph.nodes[entry_node_id].clear()
    graph.nodes[entry_node_id].update(
        {
            "data": m.NodeInfo(
                services=[],
                value=0,
                properties=["breach_node"],
                vulnerabilities=create_vulnerabilities_from_traffic_data(entry_node_id),
                agent_installed=True,
                firewall=firewall_conf,
                reimagable=False,
            )
        }
    )

    def create_node_data(node_id: m.NodeID, vulnerabilities: m.VulnerabilityLibrary) -> m.NodeInfo:
        """Build the NodeInfo of a non-entry node, listening on every port
        for which at least one valid password was assigned to that node."""
        return m.NodeInfo(
            services=[
                m.ListeningService(name=port, allowedCredentials=credentials)
                for (target_node, port), credentials in assigned_passwords.items()
                if target_node == node_id
            ],
            value=random.randint(0, 100),
            vulnerabilities=vulnerabilities,
            agent_installed=False,
            firewall=firewall_conf,
        )

    non_entry_nodes = [node for node in graph.nodes if node != entry_node_id]

    # Step 1: Generate the vulnerabilities of every node.
    # Valid passwords are assigned to (node, port) pairs as a side effect of
    # generating the leaked credentials, so this must run for all the nodes
    # before any listening service is created; otherwise credentials leaked
    # by later nodes would point to ports nobody listens on.
    vulnerabilities_by_node = {node: create_vulnerabilities_from_traffic_data(node) for node in non_entry_nodes}

    # Step 2: Create all the nodes with associated services, vulnerabilities
    # and firewall configuration, now that every password is known.
    for node in non_entry_nodes:
        graph.nodes[node].clear()
        graph.nodes[node].update({"data": create_node_data(node, vulnerabilities_by_node[node])})

    # remove all the edges inherited from the network graph
    graph.clear_edges()

    return graph


def new_environment(n_servers_per_protocol: int) -> m.Environment:
    """Create a new simulation environment based on
    a randomly generated network topology.

    Arguments:
        n_servers_per_protocol: number of server nodes listening to each of
            the SMB, HTTP and RDP protocols

    NOTE: the probabilities and parameter values used
    here for the statistical generative model
    were arbitrarily picked. We recommend exploring different values for those parameters.
    """
    traffic = generate_random_traffic_network(
        seed=None,
        n_clients=50,
        n_servers={
            "SMB": n_servers_per_protocol,
            "HTTP": n_servers_per_protocol,
            "RDP": n_servers_per_protocol,
        },
        alpha=np.array([(1, 1), (0.2, 0.5)], dtype=float),
        beta=np.array([(1000, 10), (10, 100)], dtype=float),
    )

    network = cyberbattle_model_from_traffic_graph(
        traffic,
        cached_rdp_password_probability=0.8,
        cached_smb_password_probability=0.7,
        cached_accessed_network_shares_probability=0.8,
        cached_password_has_changed_probability=0.01,
        probability_two_nodes_use_same_password_to_access_given_resource=0.9,
    )
    return m.Environment(network=network, vulnerability_library={}, identifiers=ENV_IDENTIFIERS)