from __future__ import print_function
from concurrent import futures
from subprocess import Popen, DEVNULL
import argparse
import sys
import time
from datetime import datetime
import os
import grpc
import numpy as np
import pickle as pkl
import logging
import grpc, logging
from common.orm.Database import Database
from l3_centralizedattackdetector.proto.context_pb2 import Empty
from l3_centralizedattackdetector.proto.l3_centralizedattackdetector_pb2 import (
    Empty,
)
from l3_centralizedattackdetector.proto.l3_centralizedattackdetector_pb2_grpc import (
    L3CentralizedattackdetectorServicer,
    add_L3CentralizedattackdetectorServicer_to_server
)

from l3_centralizedattackdetector.proto.l3_attackmitigator_pb2 import (
    Output,
)
from l3_centralizedattackdetector.proto.l3_attackmitigator_pb2_grpc import (
    l3_attackmitigatorStub,
)

LOGGER = logging.getLogger(__name__)
here = os.path.dirname(os.path.abspath(__file__))
MODEL_FILE = os.path.join(here, "ml_model/RF_Netflow_TF")
INFERENCE_VALUES = []


class l3_centralizedattackdetectorServiceServicerImpl(L3CentralizedattackdetectorServicer):
    """Servicer for the L3 centralized attack detector (CAD) component.

    Holds a ``Database`` handle. ``setup_l3_centralizedattackdetector``
    starts a separate gRPC server backed by the nested ``CAD`` servicer.
    """

    def __init__(self, database: Database):
        """Store the database handle used by this servicer.

        :param database: ORM ``Database`` instance.
        """
        LOGGER.debug("Creating Servicer...")
        self.database = database
        LOGGER.debug("Servicer Created")

    class CAD(L3CentralizedattackdetectorServicer):
        """gRPC servicer that classifies flow statistics with a pre-loaded model.

        Each input received by ``send_input`` is recorded in the module-level
        ``INFERENCE_VALUES`` list. The input is classified, and the resulting
        ``Output`` is forwarded to the L3 attack mitigator.
        """

        # gRPC address of the l3_attackmitigator server that receives outputs.
        MITIGATOR_ADDRESS = "localhost:10002"
        # Minimum class-1 probability for a flow to be tagged as "Crypto".
        CRYPTO_THRESHOLD = 0.5

        def __init__(self, ml_model):
            """Store the classifier used by ``make_inference``.

            :param ml_model: object exposing ``predict_proba``.
                ``setup_l3_centralizedattackdetector`` loads it from ``MODEL_FILE``.
            """
            self.ml_model = ml_model

        def send_input(self, request, context):
            """Classify one input and forward the result to the mitigator.

            :param request: flow-statistics message carrying the fields read by
                ``make_inference``.
            :param context: gRPC servicer context (unused).
            :returns: ``Empty`` whose ``message`` reports whether the
                mitigator was notified.
            """
            LOGGER.debug("send_input request: %s", request)
            print("Inferencing ...")

            # Record the raw input; get_output re-evaluates the latest entry.
            # NOTE(review): INFERENCE_VALUES is never trimmed, so it grows for
            # the lifetime of the process.
            INFERENCE_VALUES.append(request)

            output = self.make_inference(request)

            # Delivery to the mitigator is best-effort: an unreachable
            # mitigator is reported to the caller instead of failing the RPC.
            # Only gRPC errors are caught so genuine bugs still surface.
            try:
                with grpc.insecure_channel(self.MITIGATOR_ADDRESS) as channel:
                    stub = l3_attackmitigatorStub(channel)
                    print("Sending to mitigator...")
                    response = stub.send_output(output)
                    print("Sent output to mitigator and received: ", response.message)
            except grpc.RpcError as err:
                LOGGER.warning(
                    "Could not deliver output to l3_attackmitigator at %s: %s",
                    self.MITIGATOR_ADDRESS, err,
                )
                print('Couldnt find l3_attackmitigator')
                return Empty(message="Mitigator Not found")

            # RETURN "OK" TO THE CALLER
            return Empty(message="OK, information received and mitigator notified")

        def make_inference(self, request):
            """Classify one flow with ``ml_model`` and build its ``Output``.

            :param request: message carrying the eight flow-ratio features plus
                ``ip_o``, ``flow_id``, ``protocol``, ``port_d``, ``time_start``
                and ``time_end``.
            :returns: ``Output`` tagged "Crypto" (tag 1) when the class-1
                probability is at least ``CRYPTO_THRESHOLD``. Otherwise it is
                tagged "Normal" (tag 0). ``confidence`` is the probability of
                the chosen class.
            """
            features = [
                request.n_packets_server_seconds,
                request.n_packets_client_seconds,
                request.n_bits_server_seconds,
                request.n_bits_client_seconds,
                request.n_bits_server_n_packets_server,
                request.n_bits_client_n_packets_client,
                request.n_packets_server_n_packets_client,
                request.n_bits_server_n_bits_client,
            ]
            # A single row is classified. The indexing below assumes
            # predict_proba returns one row of [P(normal), P(crypto)]
            # (scikit-learn style). TODO confirm against the trained model.
            probabilities = self.ml_model.predict_proba([features])[0]

            if probabilities[1] >= self.CRYPTO_THRESHOLD:
                confidence, tag_name, tag = probabilities[1], "Crypto", 1
            else:
                confidence, tag_name, tag = probabilities[0], "Normal", 0

            return Output(
                # Plain float: protobuf float fields may reject NumPy scalar types.
                confidence=float(confidence),
                timestamp=datetime.now().strftime("%d/%m/%Y %H:%M:%S"),
                ip_o=request.ip_o,
                tag_name=tag_name,
                tag=tag,
                flow_id=request.flow_id,
                protocol=request.protocol,
                port_d=request.port_d,
                ml_id="RandomForest",
                time_start=request.time_start,
                time_end=request.time_end,
            )

        def get_output(self, request, context):
            """Return the classification of the most recently received input.

            :param request: ignored apart from debug logging.
            :param context: gRPC servicer context. It is used to abort with
                ``FAILED_PRECONDITION`` when nothing has been received yet.
            :returns: ``Output`` produced by ``make_inference``.
            """
            LOGGER.debug("get_output request: %s", request)
            print("Returing inference output...")
            # Stored inputs are protobuf messages; they cannot be combined
            # numerically, so re-evaluate the latest one. The list only grows,
            # so the emptiness check and the [-1] read cannot race.
            if not INFERENCE_VALUES:
                context.abort(
                    grpc.StatusCode.FAILED_PRECONDITION,
                    "No inference input has been received yet",
                )
            return self.make_inference(INFERENCE_VALUES[-1])

    def setup_l3_centralizedattackdetector(self):
        """Load the model from ``MODEL_FILE`` and serve ``CAD`` on port 10001.

        Blocks until the gRPC server terminates.
        """
        print('Starting CAD')
        # SECURITY: pickle.load can execute arbitrary code. MODEL_FILE must
        # only ever point at a trusted, locally bundled model file.
        with open(MODEL_FILE, "rb") as model_file:
            ml_model = pkl.load(model_file)
        server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
        add_L3CentralizedattackdetectorServicer_to_server(self.CAD(ml_model), server)
        server.add_insecure_port("[::]:10001")
        server.start()
        server.wait_for_termination()

    def DetectAttack(self, request: Empty, grpc_context: grpc.ServicerContext) -> Empty:
        """Placeholder RPC: log the request and reply with an empty ``Empty``."""
        LOGGER.debug('DetectAttack request: %s', request)
        reply = Empty()
        LOGGER.debug('DetectAttack reply: %s', reply)
        return reply




    
