# Copyright 2022-2023 ETSI TeraFlowSDN - TFS OSG (https://tfs.etsi.org/)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import logging
import os
import signal
import time
from sys import stdout

import grpc
import numpy as np

from common.proto.context_pb2 import ContextId, Empty, ServiceTypeEnum
from common.proto.context_pb2_grpc import ContextServiceStub
from common.proto.l3_centralizedattackdetector_pb2 import (
    ConnectionMetadata,
    Feature,
    L3CentralizedattackdetectorBatchInput,
    L3CentralizedattackdetectorMetrics,
)
from common.proto.l3_centralizedattackdetector_pb2_grpc import L3CentralizedattackdetectorStub

# --- Logging: "dad_LOGGER" at INFO level, written to stdout ------------------
LOGGER = logging.getLogger("dad_LOGGER")
LOGGER.setLevel(logging.INFO)
consoleHandler = logging.StreamHandler(stdout)
logFormatter = logging.Formatter(fmt="%(levelname)-8s %(message)s")
consoleHandler.setFormatter(logFormatter)
LOGGER.addHandler(consoleHandler)

# --- Deployment settings ------------------------------------------------------
CONTROLLER_IP = "192.168.165.78"  # Change this to the IP of the controller
CONTEXT_ID = "admin"  # Change this to the context ID to be used
CONTEXT_CHANNEL = CONTROLLER_IP + ":1010"  # gRPC endpoint used with ContextServiceStub
CENTRALIZED_ATTACK_DETECTOR = CONTROLLER_IP + ":10001"  # gRPC endpoint of the CAD

# Directory, relative to this file, holding the Tstat output sub-directories
TSTAT_DIR_NAME = "piped/"

# Template copied for every new connection record
JSON_BLANK = dict(
    ip_o="",  # client IP
    port_o="",  # client port
    ip_d="",  # server IP
    port_d="",  # server port
    flow_id="",  # "c_ip:c_port:s_ip:s_port"
    protocol="",  # connection protocol
    time_start=0.0,  # when the connection was first seen
    time_end=0.0,  # when the connection was last seen
)

# --- Runtime flags ------------------------------------------------------------
STOP = False  # shutdown flag polled by the traffic-processing loop
IGNORE_FIRST_LINE_TSTAT = True  # skip the first Tstat line (presumably a header)
PROFILING = False
SEND_DATA_IN_BATCHES = False  # send metrics in batches instead of one by one
BATCH_SIZE = 10  # metrics per batch when SEND_DATA_IN_BATCHES is enabled


class l3_distributedattackdetector:
    """Distributed Attack Detector (DAD).

    Follows the newest Tstat TCP log, turns every complete log line into a
    feature vector plus connection metadata, and forwards it to the
    Centralized Attack Detector (CAD) over gRPC.

    Note: the constructor does not return until traffic processing stops,
    because it runs the processing loop itself.
    """

    def __init__(self):
        LOGGER.info("Creating Distributed Attack Detector")

        self.feature_ids = []  # 1-based Tstat column ids requested by the CAD

        self.cad_features = {}  # features + metadata of the line being processed
        self.conn_id = ()  # (client IP, client port, server IP, server port) of that line

        self.connections_dict = {}  # conn_id -> connection metadata record (ALL data)
        self.new_connections = {}  # conn_id -> latest feature vector (NEW data)

        self.known_attack_ips = self.read_kwnown_attack_ips()

        # First Ctrl+C requests a graceful stop, a second one exits immediately
        signal.signal(signal.SIGINT, self.handler)

        with grpc.insecure_channel(CENTRALIZED_ATTACK_DETECTOR) as channel:
            self.cad = L3CentralizedattackdetectorStub(channel)
            LOGGER.info("Connected to the Centralized Attack Detector")

            LOGGER.info("Obtaining features Ids. from the Centralized Attack Detector...")
            self.feature_ids = self.get_features_ids()
            LOGGER.info("Features Ids.: {:s}".format(str(self.feature_ids)))

            # Runs until process_traffic() returns; the CAD channel stays open meanwhile
            asyncio.run(self.process_traffic())

    def read_kwnown_attack_ips(self, path="known_attack_ips.csv"):
        """Return the list of known attacker IPs stored in `path`.

        Entries may be separated by commas and/or whitespace (including
        newlines). Surrounding whitespace and empty entries are discarded, so
        a trailing newline no longer corrupts the last IP and an empty file
        yields []. A relative `path` is resolved against the current working
        directory.
        """
        with open(path, "r") as f:
            return f.read().replace(",", " ").split()

    def handler(self, signum=None, frame=None):
        """SIGINT handler: request a graceful stop, or exit on a second signal.

        signal.signal() invokes handlers as handler(signum, frame); both
        arguments are optional so the method can also be called directly.
        """
        # Without `global`, the assignment below would make STOP a local name
        # and the `if STOP` check would raise UnboundLocalError.
        global STOP

        if STOP:
            raise SystemExit()

        STOP = True

        LOGGER.info("Gracefully stopping...")

    def follow(self, file, time_sleep):
        """
        Generator function that yields new lines in a file

        Reads `file` from its current position (the beginning for a freshly
        opened file) and yields only complete, newline-terminated lines. A
        partially written line is buffered until its newline arrives. When no
        new data is available it sleeps `time_sleep` seconds and polls again,
        or finishes once a stop has been requested (module flag STOP).
        """
        chunk = ""

        while True:
            line = file.readline()

            if not line:
                # Nothing new yet: finish if asked to stop, otherwise wait and poll again
                if STOP:
                    return
                time.sleep(time_sleep)
                continue

            if line[-1] != "\n":
                # The writer has not finished this line yet: keep it for later
                chunk += line
                continue

            if chunk:
                line = chunk + line
                chunk = ""

            yield line

    def load_file(self, dirname=None):
        """Return the path of the newest Tstat "log_tcp_temp_complete" file.

        `dirname` (default: TSTAT_DIR_NAME) is resolved relative to this
        source file's directory. The newest sub-directory is the
        lexicographically greatest name. While the directory is empty, the
        method logs and retries every 5 seconds.
        """
        if dirname is None:
            dirname = TSTAT_DIR_NAME

        here = os.path.dirname(os.path.abspath(__file__))
        tstat_piped = os.path.join(here, dirname)

        while True:
            tstat_dirs = os.listdir(tstat_piped)

            if tstat_dirs:
                # os.path.join works whether or not dirname ends with a separator
                tstat_file = os.path.join(tstat_piped, max(tstat_dirs), "log_tcp_temp_complete")

                LOGGER.info("Following: {:s}".format(str(tstat_file)))

                return tstat_file

            LOGGER.info("No Tstat directory!")
            time.sleep(5)

    def process_line(self, line):
        """
        Extract the feature values requested by the CAD from one Tstat line.

        - Each id in self.feature_ids is a 1-based column of the
          space-separated line; the column is converted with float()
          (so a "nan" column becomes float("nan")).
        - If a column is missing or not numeric, the error is logged and the
          values collected so far are returned (a shorter list).
        - Returns a list of floats in self.feature_ids order.
        """
        fields = line.split(" ")
        values = []

        try:
            for feature_id in self.feature_ids:
                values.append(float(fields[int(feature_id) - 1]))
        except IndexError:
            LOGGER.error("IndexError: {0}".format(fields))
        except ValueError:
            LOGGER.error("ValueError: {0}".format(fields))

        return values

    def get_service_ids(self, context_id_str):
        """Return the ListServiceIds response of the Context service for `context_id_str`."""
        with grpc.insecure_channel(CONTEXT_CHANNEL) as channel:
            stub = ContextServiceStub(channel)
            context_id = ContextId()
            context_id.context_uuid.uuid = context_id_str

            return stub.ListServiceIds(context_id)

    def get_services(self, context_id_str):
        """Return the ListServices response of the Context service for `context_id_str`."""
        with grpc.insecure_channel(CONTEXT_CHANNEL) as channel:
            stub = ContextServiceStub(channel)
            context_id = ContextId()
            context_id.context_uuid.uuid = context_id_str

            return stub.ListServices(context_id)

    def _find_l3nm_service(self, context_id):
        """Return the first SERVICETYPE_L3NM service of `context_id`, or None (one RPC)."""
        service_list = self.get_services(context_id)

        return next(
            (s for s in service_list.services if s.service_type == ServiceTypeEnum.SERVICETYPE_L3NM),
            None,
        )

    @staticmethod
    def _first_endpoint_id(service):
        """Return the first endpoint id of `service`, or None if there is no service/endpoint."""
        if service is None or not service.service_endpoint_ids:
            return None

        return service.service_endpoint_ids[0]

    def get_service_id(self, context_id):
        """Return the service id of the first L3NM service in `context_id`, or None."""
        service = self._find_l3nm_service(context_id)

        return None if service is None else service.service_id

    def get_endpoint_id(self, context_id):
        """Return the first endpoint id of the first L3NM service in `context_id`, or None."""
        return self._first_endpoint_id(self._find_l3nm_service(context_id))

    def get_features_ids(self):
        """Ask the CAD which Tstat columns (auto_features) it expects."""
        return self.cad.GetFeaturesIds(Empty()).auto_features

    def check_types(self):
        """Assert that self.cad_features holds the expected Python types.

        Uses `assert`, so these checks are skipped when Python runs with -O.
        """
        for feature in self.cad_features["features"]:
            assert isinstance(feature, float)

        metadata = self.cad_features["connection_metadata"]

        for key in ("ip_o", "port_o", "ip_d", "port_d", "flow_id", "protocol"):
            assert isinstance(metadata[key], str)

        for key in ("time_start", "time_end"):
            assert isinstance(metadata[key], float)

    def insert_connection(self):
        """Record the current connection (self.conn_id) in self.connections_dict.

        A known connection only gets its "time_end" refreshed. A new one is
        built from JSON_BLANK. Its service/endpoint ids come from the first
        L3NM service of CONTEXT_ID, fetched with a single ListServices call.
        The entry is stored only once it is complete, so a failed RPC leaves
        no half-built record behind.
        """
        now = time.time()

        connection = self.connections_dict.get(self.conn_id)
        if connection is not None:
            connection["time_end"] = now
            return

        service = self._find_l3nm_service(CONTEXT_ID)

        connection = JSON_BLANK.copy()
        connection["time_start"] = now
        connection["time_end"] = now
        connection["ip_o"] = self.conn_id[0]
        connection["port_o"] = self.conn_id[1]
        connection["flow_id"] = ":".join(self.conn_id)
        connection["service_id"] = None if service is None else service.service_id
        connection["endpoint_id"] = self._first_endpoint_id(service)
        connection["protocol"] = "TCP"
        connection["ip_d"] = self.conn_id[2]
        connection["port_d"] = self.conn_id[3]

        self.connections_dict[self.conn_id] = connection

    def check_if_connection_is_attack(self):
        """Log when either end of the current connection is a known attacker IP."""
        if self.conn_id[0] in self.known_attack_ips or self.conn_id[2] in self.known_attack_ips:
            LOGGER.info("Attack detected. Origin: {0}, destination: {1}".format(self.conn_id[0], self.conn_id[2]))

    def create_cad_features(self):
        """Build self.cad_features (first 10 features + metadata) for self.conn_id."""
        connection = self.connections_dict[self.conn_id]
        metadata_keys = (
            "ip_o",
            "port_o",
            "ip_d",
            "port_d",
            "flow_id",
            "service_id",
            "endpoint_id",
            "protocol",
            "time_start",
            "time_end",
        )

        self.cad_features = {
            "features": self.new_connections[self.conn_id][0:10],
            "connection_metadata": {key: connection[key] for key in metadata_keys},
        }

    def _build_metrics(self):
        """Convert self.cad_features into an L3CentralizedattackdetectorMetrics message."""
        connection_metadata = ConnectionMetadata(**self.cad_features["connection_metadata"])
        metrics = L3CentralizedattackdetectorMetrics()

        for value in self.cad_features["features"]:
            feature_obj = Feature()
            feature_obj.feature = value
            metrics.features.append(feature_obj)

        metrics.connection_metadata.CopyFrom(connection_metadata)

        return metrics

    async def send_batch_async(self, metrics_list_pb):
        """Send `metrics_list_pb` to the CAD as one batch, in the default executor.

        Errors are logged, not raised.
        """
        loop = asyncio.get_running_loop()

        # Create metrics batch
        metrics_batch = L3CentralizedattackdetectorBatchInput()
        metrics_batch.metrics.extend(metrics_list_pb)

        # Send batch without blocking the event loop thread
        future = loop.run_in_executor(None, self.cad.AnalyzeBatchConnectionStatistics, metrics_batch)

        try:
            await future
        except Exception as e:
            LOGGER.error(f"Error sending batch: {e}")

    async def send_data(self, metrics_list_pb, send_data_times):
        """Send pending metrics to the CAD.

        In batch mode (SEND_DATA_IN_BATCHES) nothing is sent until BATCH_SIZE
        metrics are pending; then they go out as one batch. Otherwise the
        newest metrics message is sent on its own. Sent metrics are removed
        from the returned list, so it no longer grows without bound.

        Send durations are appended to `send_data_times` only when PROFILING
        is enabled. numpy.append copies the whole array on every call, which
        would otherwise make each send cost O(lines processed so far).

        Returns:
            (pending metrics list, send-time array)
        """
        if not metrics_list_pb:
            return metrics_list_pb, send_data_times

        send_data_time_start = time.time()

        if SEND_DATA_IN_BATCHES:
            if len(metrics_list_pb) < BATCH_SIZE:
                return metrics_list_pb, send_data_times
            await self.send_batch_async(metrics_list_pb)
        else:
            try:
                self.cad.AnalyzeConnectionStatistics(metrics_list_pb[-1])
            except grpc.RpcError as e:
                # Same policy as send_batch_async(): log and keep processing
                LOGGER.error(f"Error sending metrics: {e}")

        # The pending metrics have been handed to the CAD (in non-batch mode
        # the list only ever holds the single newest message)
        metrics_list_pb = []

        if PROFILING:
            send_data_times = np.append(send_data_times, time.time() - send_data_time_start)

        return metrics_list_pb, send_data_times

    async def process_traffic(self):
        """Main loop: follow the Tstat log and send one metrics message per line.

        Stops when STOP is set. This happens either after the next line is
        read, or while idle once follow() notices the flag. The log file is
        closed on exit.
        """
        LOGGER.info("Loading Tstat log file...")
        with open(self.load_file(), "r") as logfile:
            LOGGER.info("Following Tstat log file...")
            loglines = self.follow(logfile, 5)

            total_process_time = 0.0  # running sum; avoids re-summing a growing list
            num_lines = 0

            send_data_times = np.array([])
            metrics_list_pb = []

            LOGGER.info("Starting to process data...")

            skip_first_line = IGNORE_FIRST_LINE_TSTAT
            while True:
                line = next(loglines, None)

                # follow() only finishes after a stop has been requested
                if line is None:
                    break

                if skip_first_line:
                    skip_first_line = False
                    continue

                if STOP:
                    break

                start = time.time()
                line_id = line.split(" ")

                # Columns 0, 1, 14 and 15 identify the connection; skip lines that lack them
                if len(line_id) < 16:
                    LOGGER.error("Malformed Tstat line skipped: {0}".format(line))
                    continue

                num_lines += 1

                self.conn_id = (line_id[0], line_id[1], line_id[14], line_id[15])
                self.new_connections[self.conn_id] = self.process_line(line)

                self.check_if_connection_is_attack()
                self.insert_connection()
                self.create_cad_features()
                self.check_types()

                metrics_list_pb.append(self._build_metrics())

                metrics_list_pb, send_data_times = await self.send_data(metrics_list_pb, send_data_times)

                total_process_time += time.time() - start

                if num_lines % 10 == 0:
                    LOGGER.info(f"Number of lines: {num_lines} - Average processing time: {total_process_time / num_lines}")