diff --git a/src/l3_attackmitigator/README.md b/src/l3_attackmitigator/README.md index 04c937a1d35e91071e0357278c81b33335e2e37a..fad689a5e5b64f9d64481b6d819f087a27e5f492 100644 --- a/src/l3_attackmitigator/README.md +++ b/src/l3_attackmitigator/README.md @@ -1,3 +1,6 @@ # l3_attackmitigator -- Receives packages and process it with TSTAT -- Functions: ReportSummarizeKpi(KpiList) +- Receives detected attacks from the Centralized Attack Detector component and performs the necessary mitigations. +- Functions: + - PerformMitigation(self, request: L3AttackmitigatorOutput) + - GetMitigation(self, request: Empty) + - GetConfiguredACLRules(self, request: Empty) diff --git a/src/l3_centralizedattackdetector/Config.py b/src/l3_centralizedattackdetector/Config.py index f6c7e33553820b1214e5265cf219db629bcfe006..809380b2cda1c8c556f973e570de36e3189edb99 100644 --- a/src/l3_centralizedattackdetector/Config.py +++ b/src/l3_centralizedattackdetector/Config.py @@ -18,7 +18,7 @@ import logging LOG_LEVEL = logging.WARNING # gRPC settings -GRPC_SERVICE_PORT = 10001 # TODO UPM FIXME +GRPC_SERVICE_PORT = 10001 GRPC_MAX_WORKERS = 10 GRPC_GRACE_PERIOD = 60 diff --git a/src/l3_centralizedattackdetector/README.md b/src/l3_centralizedattackdetector/README.md index bcec4052cc9aa2ea734e08a4ed6b9158609b3532..0569132915165bd500d5a0caee9a5a222f8b4500 100644 --- a/src/l3_centralizedattackdetector/README.md +++ b/src/l3_centralizedattackdetector/README.md @@ -1,3 +1,6 @@ # l3_centralizedattackdetector -- Receives packages and process it with TSTAT -- Functions: ReportSummarizeKpi(KpiList) +- Receives snapshot statistics from Distributed Attack Detector component and performs an inference to detect attacks. It then sends the detected attacks to the Attack Mitigator component for them to be mitigated. 
+- Functions: + - AnalyzeConnectionStatistics(self, request: L3CentralizedattackdetectorMetrics) + - AnalyzeBatchConnectionStatistics(self, request: L3CentralizedattackdetectorBatchInput) + - GetFeaturesIds(self, request: Empty) diff --git a/src/l3_distributedattackdetector/Config.py b/src/l3_distributedattackdetector/Config.py index e04de0b2622b621fb95f1c382ac3a152836de760..a1419ef09c9b3dcbff5aa576536fae8ffe6bc7a4 100644 --- a/src/l3_distributedattackdetector/Config.py +++ b/src/l3_distributedattackdetector/Config.py @@ -18,7 +18,7 @@ import logging LOG_LEVEL = logging.WARNING # gRPC settings -GRPC_SERVICE_PORT = 10000 # TODO UPM FIXME +GRPC_SERVICE_PORT = 10000 GRPC_MAX_WORKERS = 10 GRPC_GRACE_PERIOD = 60 diff --git a/src/l3_distributedattackdetector/README.md b/src/l3_distributedattackdetector/README.md index d8cac8b72d41c6eb6ce2b2908e6ab7402966ad62..9b87f3944f844fa555948ec1435d9e68d0937913 100644 --- a/src/l3_distributedattackdetector/README.md +++ b/src/l3_distributedattackdetector/README.md @@ -1,3 +1,2 @@ # l3_distributedattackdetector -- Receives packages and process it with TSTAT -- Functions: ReportSummarizeKpi(KpiList) +- Receives packages and processes them with TSTAT to generate traffic snapshot statistics diff --git a/src/l3_distributedattackdetector/requirements.in b/src/l3_distributedattackdetector/requirements.in index a8aba849708799232f6b0732c3661396266da329..01a759eac1419c596bec983cbad08eca1c310d4d 100644 --- a/src/l3_distributedattackdetector/requirements.in +++ b/src/l3_distributedattackdetector/requirements.in @@ -12,4 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-# no extra dependency +grpcio-health-checking>=1.47.0 +grpcio>=1.47.0 +grpcio-tools>=1.47.0 +protobuf>=3.20.0 +numpy +# asyncio is part of the Python standard library; no extra dependency needed \ No newline at end of file diff --git a/src/l3_distributedattackdetector/service/__init__.py b/src/l3_distributedattackdetector/service/__init__.py index 1549d9811aa5d1c193a44ad45d0d7773236c0612..f80ccfd52ebfd4fa1783267201c52eb7381741bf 100644 --- a/src/l3_distributedattackdetector/service/__init__.py +++ b/src/l3_distributedattackdetector/service/__init__.py @@ -10,5 +10,4 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. - +# limitations under the License. \ No newline at end of file diff --git a/src/l3_distributedattackdetector/service/__main__.py b/src/l3_distributedattackdetector/service/__main__.py index 1f558dfb6c271cf63a9e36ae06cb9993f7e49c57..a8f0ac3c4f9737091c2c1a39134b97ee7bd6de7d 100644 --- a/src/l3_distributedattackdetector/service/__main__.py +++ b/src/l3_distributedattackdetector/service/__main__.py @@ -13,207 +13,39 @@ # limitations under the License.
import logging -import sys -import os -import time -import grpc -from common.proto.l3_centralizedattackdetector_pb2_grpc import ( - L3CentralizedattackdetectorStub, -) -from common.proto.l3_centralizedattackdetector_pb2 import ( - ModelInput, -) +from sys import stdout +import sys +from l3_distributedattackdetector import l3_distributedattackdetector -LOGGER = logging.getLogger(__name__) -TSTAT_DIR_NAME = "piped/" -JSON_BLANK = { - "ip_o": "", # Client IP - "port_o": "", # Client port - "ip_d": "", # Server ip - "port_d": "", # Server port - "flow_id": "", # Identifier:c_ip,c_port,s_ip,s_port,time_start - "protocol": "", # Connection protocol - "time_start": 0, # Start of connection - "time_end": 0, # Time of last packet -} +# Setup LOGGER +LOGGER = logging.getLogger("main_dad_LOGGER") +LOGGER.setLevel(logging.INFO) +logFormatter = logging.Formatter(fmt="%(levelname)-8s %(message)s") +consoleHandler = logging.StreamHandler(stdout) +consoleHandler.setFormatter(logFormatter) +LOGGER.addHandler(consoleHandler) -def follow(thefile, time_sleep): - """ - Generator function that yields new lines in a file - It reads the logfie (the opened file) - """ - # seek the end of the file - thefile.seek(0, os.SEEK_END) +PROFILING = False - trozo = "" - # start infinite loop - while True: - # read last line of file - line = thefile.readline() - # sleep if file hasn't been updated - if not line: - time.sleep(time_sleep) # FIXME - continue - - if line[-1] != "\n": - trozo += line - # print ("OJO :"+line+":") - else: - if trozo != "": - line = trozo + line - trozo = "" - yield line - -def load_file(dirname=TSTAT_DIR_NAME): - """ - - Client side - - """ - # "/home/dapi/Tstat/TOSHI/tstat/tstat_DRv4/tstat/piped/" - - while True: - here = os.path.dirname(os.path.abspath(__file__)) - tstat_piped = os.path.join(here, dirname) - tstat_dirs = os.listdir(tstat_piped) - if len(tstat_dirs) > 0: - tstat_dirs.sort() - new_dir = tstat_dirs[-1] - print(new_dir) - # print("dir: {0}".format(new_dir)) - 
tstat_file = tstat_piped + new_dir + "/log_tcp_temp_complete" - print("tstat_file: {0}".format(tstat_file)) - return tstat_file - else: - print("No tstat directory!") - time.sleep(1) - -def process_line(line): - """ - - Preprocessing before a message per line - - Avoids crash when nan are found by generating a 0s array - - Returns a list of values - """ - - def makeDivision(i, j): - """ - Helper function - """ - return i / j if (j and type(i) != str and type(j) != str) else 0 - - line = line.split(" ") - try: - n_packets_server, n_packets_client = float( - line[16]), float(line[2]) - except: - return [0 for i in range(9)] - n_bits_server, n_bits_client = float(line[22]), float(line[8]) - seconds = float(line[30]) / 1e6 # Duration in ms - values = [ - makeDivision(n_packets_server, seconds), - makeDivision(n_packets_client, seconds), - makeDivision(n_bits_server, seconds), - makeDivision(n_bits_client, seconds), - makeDivision(n_bits_server, n_packets_server), - makeDivision(n_bits_client, n_packets_client), - makeDivision(n_packets_server, n_packets_client), - makeDivision(n_bits_server, n_bits_client), - ] - return values - -def open_channel(input_information): - with grpc.insecure_channel("localhost:10001") as channel: - stub = L3CentralizedattackdetectorStub(channel) - response = stub.SendInput( - ModelInput(**input_information)) - LOGGER.debug("Inferencer send_input sent and received: ", - response.message) - # response = stub.get_output(Inferencer_pb2.empty(message="")) - # print("Inferencer get_output response: \n", response) - -def run(time_sleep, max_lines): - - filename = load_file() - write_salida = None - print( - "following: ", - filename, - " time to wait:", - time_sleep, - "lineas_tope:", - max_lines, - "write salida:", - write_salida, - ) - logfile = open(filename, "r") - # iterate over the generator - loglines = follow(logfile, time_sleep) - lin = 0 - ultima_lin = 0 - last_line = "" - cryptos = 0 - new_connections = {} # Dict for storing NEW data - 
connections_db = {} # Dict for storing ALL data - print('Reading lines') - for line in loglines: - print('Received Line') - start = time.time() - line_id = line.split(" ") - conn_id = (line_id[0], line_id[1], line_id[14], line_id[15]) - new_connections[conn_id] = process_line(line) - try: - connections_db[conn_id]["time_end"] = time.time() - except KeyError: - connections_db[conn_id] = JSON_BLANK.copy() - connections_db[conn_id]["time_start"] = time.time() - connections_db[conn_id]["time_end"] = time.time() - connections_db[conn_id]["ip_o"] = conn_id[0] - connections_db[conn_id]["port_o"] = conn_id[1] - connections_db[conn_id]["flow_id"] = "".join(conn_id) - connections_db[conn_id]["protocol"] = "TCP" - connections_db[conn_id]["ip_d"] = conn_id[2] - connections_db[conn_id]["port_d"] = conn_id[3] +def main(): + l3_distributedattackdetector() - # CRAFT DICT - inference_information = { - "n_packets_server_seconds": new_connections[conn_id][0], - "n_packets_client_seconds": new_connections[conn_id][1], - "n_bits_server_seconds": new_connections[conn_id][2], - "n_bits_client_seconds": new_connections[conn_id][3], - "n_bits_server_n_packets_server": new_connections[conn_id][4], - "n_bits_client_n_packets_client": new_connections[conn_id][5], - "n_packets_server_n_packets_client": new_connections[conn_id][6], - "n_bits_server_n_bits_client": new_connections[conn_id][7], - "ip_o": connections_db[conn_id]["ip_o"], - "port_o": connections_db[conn_id]["port_o"], - "ip_d": connections_db[conn_id]["ip_d"], - "port_d": connections_db[conn_id]["port_d"], - "flow_id": connections_db[conn_id]["flow_id"], - "protocol": connections_db[conn_id]["protocol"], - "time_start": connections_db[conn_id]["time_start"], - "time_end": connections_db[conn_id]["time_end"], - } - # SEND MSG - try: - open_channel(inference_information) - except: - LOGGER.info("Centralized Attack Mitigator is not up") +if __name__ == "__main__": + if PROFILING: + import cProfile, pstats, io - if write_salida: - 
print(line, end="") - sys.stdout.flush() - lin += 1 - if lin >= max_lines: - break - elif lin == 1: - print("primera:", ultima_lin) + pr = cProfile.Profile() + pr.enable() - end = time.time() - start - print(end) + main() + if PROFILING: + pr.disable() + s = io.StringIO() + sortby = "cumulative" + ps = pstats.Stats(pr, stream=s).sort_stats(sortby) + ps.print_stats() + LOGGER.info(s.getvalue()) -def main(): - logging.basicConfig() - run(5, 70) - -if __name__ == '__main__': - sys.exit(main()) + sys.exit(0) \ No newline at end of file diff --git a/src/l3_distributedattackdetector/service/l3_distributedattackdetector.py b/src/l3_distributedattackdetector/service/l3_distributedattackdetector.py new file mode 100644 index 0000000000000000000000000000000000000000..0a3c03793d9f9de54918da5ca82665ad7f6685eb --- /dev/null +++ b/src/l3_distributedattackdetector/service/l3_distributedattackdetector.py @@ -0,0 +1,378 @@ +# Copyright 2022-2023 ETSI TeraFlowSDN - TFS OSG (https://tfs.etsi.org/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +from sys import stdout +import os +import time +import signal +import grpc +import numpy as np +import asyncio +from common.proto import L3CentralizedattackdetectorStub +from common.proto.l3_centralizedattackdetector_pb2 import ( + L3CentralizedattackdetectorMetrics, + L3CentralizedattackdetectorBatchInput, + ConnectionMetadata, + Feature, + Empty, +) + +from common.proto.context_pb2 import ( + ServiceTypeEnum, + ContextId, +) + +from common.proto.context_pb2_grpc import ContextServiceStub + +# Setup LOGGER +LOGGER = logging.getLogger("dad_LOGGER") +LOGGER.setLevel(logging.INFO) +logFormatter = logging.Formatter(fmt="%(levelname)-8s %(message)s") +consoleHandler = logging.StreamHandler(stdout) +consoleHandler.setFormatter(logFormatter) +LOGGER.addHandler(consoleHandler) + +TSTAT_DIR_NAME = "piped/" +CENTRALIZED_ATTACK_DETECTOR = "192.168.165.78:10001" +JSON_BLANK = { + "ip_o": "", # Client IP + "port_o": "", # Client port + "ip_d": "", # Server ip + "port_d": "", # Server port + "flow_id": "", # Identifier:c_ip,c_port,s_ip,s_port,time_start + "protocol": "", # Connection protocol + "time_start": 0.0, # Start of connection + "time_end": 0.0, # Time of last packet +} + +STOP = False +IGNORE_FIRST_LINE_TSTAT = True + +CONTEXT_ID = "admin" +CONTEXT_CHANNEL = "192.168.165.78:1010" +PROFILING = False +SEND_DATA_IN_BATCHES = False +BATCH_SIZE = 10 +ATTACK_IPS = ["37.187.95.110", "91.121.140.167", "94.23.23.52", "94.23.247.226", "149.202.83.171"] + +class l3_distributedattackdetector(): + def __init__(self): + LOGGER.info("Creating Distributed Attack Detector") + + self.feature_ids = [] + + self.cad_features = {} + self.conn_id = () + + self.connections_dict = {} # Dict for storing ALL data + self.new_connections = {} # Dict for storing NEW data + + signal.signal(signal.SIGINT, self.handler) + + with grpc.insecure_channel(CENTRALIZED_ATTACK_DETECTOR) as channel: + self.cad = L3CentralizedattackdetectorStub(channel) + LOGGER.info("Connected to the 
centralized attack detector") + + LOGGER.info("Obtaining features...") + self.feature_ids = self.get_features_ids() + LOGGER.info("Features Ids.: {0}".format(self.feature_ids)) + + asyncio.run(self.process_traffic()) + + + def handler(self): + if STOP: + exit() + + STOP = True + + LOGGER.info("Gracefully stopping...") + + def follow(self, thefile, time_sleep): + """ + Generator function that yields new lines in a file + It reads the logfie (the opened file) + """ + # seek the end of the file + # thefile.seek(0, os.SEEK_END) + + trozo = "" + + # start infinite loop + while True: + # read last line of file + line = thefile.readline() + + # sleep if file hasn't been updated + if not line: + time.sleep(time_sleep) + continue + if line[-1] != "\n": + trozo += line + else: + if trozo != "": + line = trozo + line + trozo = "" + yield line + + + def load_file(self, dirname=TSTAT_DIR_NAME): # - Client side - + while True: + here = os.path.dirname(os.path.abspath(__file__)) + tstat_piped = os.path.join(here, dirname) + tstat_dirs = os.listdir(tstat_piped) + if len(tstat_dirs) > 0: + tstat_dirs.sort() + new_dir = tstat_dirs[-1] + tstat_file = tstat_piped + new_dir + "/log_tcp_temp_complete" + LOGGER.info("Following: {0}".format(tstat_file)) + return tstat_file + else: + LOGGER.info("No Tstat directory!") + time.sleep(5) + + + def process_line(self, line): + """ + - Preprocessing before a message per line + - Avoids crash when nan are found by generating a 0s array + - Returns a list of values + """ + line = line.split(" ") + + try: + values = [] + for feature_id in self.feature_ids: + feature_id = int(feature_id) + feature = feature_id - 1 + values.append(float(line[feature])) + except IndexError: + print("IndexError: {0}".format(line)) + + return values + + + def get_service_ids(self, context_id_str): + with grpc.insecure_channel(CONTEXT_CHANNEL) as channel: + stub = ContextServiceStub(channel) + context_id = ContextId() + context_id.context_uuid.uuid = context_id_str + 
return stub.ListServiceIds(context_id) + + + def get_services(self, context_id_str): + with grpc.insecure_channel(CONTEXT_CHANNEL) as channel: + stub = ContextServiceStub(channel) + context_id = ContextId() + context_id.context_uuid.uuid = context_id_str + return stub.ListServices(context_id) + + + def get_service_id(self, context_id): + service_id_list = self.get_service_ids(context_id) + service_id = None + for s_id in service_id_list.service_ids: + if ( + s_id.service_uuid.uuid == "0eaa0752-c7b6-4c2e-97da-317fbfee5112" + ): # TODO: Change this identifier to the L3VPN service identifier with the real router for the demo v2 + service_id = s_id + break + + return service_id + + + def get_service_id2(self, context_id): + service_list = self.get_services(context_id) + service_id = None + for s in service_list.services: + if s.service_type == ServiceTypeEnum.SERVICETYPE_L3NM: + service_id = s.service_id + break + else: + pass + return service_id + + + def get_endpoint_id(self, context_id): + service_list = self.get_services(context_id) + endpoint_id = None + for s in service_list.services: + if s.service_type == ServiceTypeEnum.SERVICETYPE_L3NM: + endpoint_id = s.service_endpoint_ids[0] + break + return endpoint_id + + + def get_features_ids(self): + return self.cad.GetFeaturesIds(Empty()).auto_features + + + def check_types(self): + for feature in self.cad_features["features"]: + assert isinstance(feature, float) + + assert isinstance(self.cad_features["connection_metadata"]["ip_o"], str) + assert isinstance(self.cad_features["connection_metadata"]["port_o"], str) + assert isinstance(self.cad_features["connection_metadata"]["ip_d"], str) + assert isinstance(self.cad_features["connection_metadata"]["port_d"], str) + assert isinstance(self.cad_features["connection_metadata"]["flow_id"], str) + assert isinstance(self.cad_features["connection_metadata"]["protocol"], str) + assert isinstance(self.cad_features["connection_metadata"]["time_start"], float) + assert 
isinstance(self.cad_features["connection_metadata"]["time_end"], float) + + + def insert_connection(self): + try: + self.connections_dict[self.conn_id]["time_end"] = time.time() + except KeyError: + self.connections_dict[self.conn_id] = JSON_BLANK.copy() + self.connections_dict[self.conn_id]["time_start"] = time.time() + self.connections_dict[self.conn_id]["time_end"] = time.time() + self.connections_dict[self.conn_id]["ip_o"] = self.conn_id[0] + self.connections_dict[self.conn_id]["port_o"] = self.conn_id[1] + self.connections_dict[self.conn_id]["flow_id"] = ":".join(self.conn_id) + self.connections_dict[self.conn_id]["service_id"] = self.get_service_id2(CONTEXT_ID) + self.connections_dict[self.conn_id]["endpoint_id"] = self.get_endpoint_id(CONTEXT_ID) + self.connections_dict[self.conn_id]["protocol"] = "TCP" + self.connections_dict[self.conn_id]["ip_d"] = self.conn_id[2] + self.connections_dict[self.conn_id]["port_d"] = self.conn_id[3] + + + def check_if_connection_is_attack(self): + if self.conn_id[0] in ATTACK_IPS or self.conn_id[2] in ATTACK_IPS: + LOGGER.info("Attack detected. 
Origin: {0}, destination: {1}".format(self.conn_id[0], self.conn_id[2])) + + + def create_cad_features(self): + self.cad_features = { + "features": self.new_connections[self.conn_id][0:10], + "connection_metadata": { + "ip_o": self.connections_dict[self.conn_id]["ip_o"], + "port_o": self.connections_dict[self.conn_id]["port_o"], + "ip_d": self.connections_dict[self.conn_id]["ip_d"], + "port_d": self.connections_dict[self.conn_id]["port_d"], + "flow_id": self.connections_dict[self.conn_id]["flow_id"], + "service_id": self.connections_dict[self.conn_id]["service_id"], + "endpoint_id": self.connections_dict[self.conn_id]["endpoint_id"], + "protocol": self.connections_dict[self.conn_id]["protocol"], + "time_start": self.connections_dict[self.conn_id]["time_start"], + "time_end": self.connections_dict[self.conn_id]["time_end"], + }, + } + + + async def send_batch_async(self, metrics_list_pb): + loop = asyncio.get_running_loop() + + # Create metrics batch + metrics_batch = L3CentralizedattackdetectorBatchInput() + metrics_batch.metrics.extend(metrics_list_pb) + + # Send batch + future = loop.run_in_executor( + None, self.cad.AnalyzeBatchConnectionStatistics, metrics_batch + ) + + try: + await future + except Exception as e: + LOGGER.error(f"Error sending batch: {e}") + + + async def send_data(self, metrics_list_pb, send_data_times): + # Send to CAD + if SEND_DATA_IN_BATCHES: + if len(metrics_list_pb) == BATCH_SIZE: + send_data_time_start = time.time() + await self.send_batch_async(metrics_list_pb) + metrics_list_pb = [] + + send_data_time_end = time.time() + send_data_time = send_data_time_end - send_data_time_start + send_data_times = np.append(send_data_times, send_data_time) + + else: + send_data_time_start = time.time() + self.cad.AnalyzeConnectionStatistics(metrics_list_pb[-1]) + + send_data_time_end = time.time() + send_data_time = send_data_time_end - send_data_time_start + send_data_times = np.append(send_data_times, send_data_time) + + return metrics_list_pb, 
send_data_times + + + async def process_traffic(self): + LOGGER.info("Loading Tstat log file...") + logfile = open(self.load_file(), "r") + + LOGGER.info("Following Tstat log file...") + loglines = self.follow(logfile, 5) + + process_time = [] + num_lines = 0 + + send_data_times = np.array([]) + metrics_list_pb = [] + + LOGGER.info("Starting to process data...") + + index = 0 + while True: + line = next(loglines, None) + + while line == None: + LOGGER.info("Waiting for new data...") + time.sleep(1 / 100) + line = next(loglines, None) + if index == 0 and IGNORE_FIRST_LINE_TSTAT: + index = index + 1 + continue + if STOP: + break + + num_lines += 1 + start = time.time() + line_id = line.split(" ") + self.conn_id = (line_id[0], line_id[1], line_id[14], line_id[15]) + self.new_connections[self.conn_id] = self.process_line(line) + + self.check_if_connection_is_attack() + + self.insert_connection() + + self.create_cad_features() + + self.check_types() + + connection_metadata = ConnectionMetadata(**self.cad_features["connection_metadata"]) + metrics = L3CentralizedattackdetectorMetrics() + + for feature in self.cad_features["features"]: + feature_obj = Feature() + feature_obj.feature = feature + metrics.features.append(feature_obj) + + metrics.connection_metadata.CopyFrom(connection_metadata) + metrics_list_pb.append(metrics) + + metrics_list_pb, send_data_times = await self.send_data(metrics_list_pb, send_data_times) + + index = index + 1 + + process_time.append(time.time() - start) + + if num_lines % 10 == 0: + LOGGER.info(f"Number of lines: {num_lines} - Average processing time: {sum(process_time) / num_lines}") \ No newline at end of file