# Copyright 2022-2023 ETSI TeraFlowSDN - TFS OSG (https://tfs.etsi.org/)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import pickle
import uuid
import logging
import multiprocessing
import signal
import os
import threading
import yaml

import requests
import redis
from kubernetes import client, config

from common.Constants import ServiceNameEnum
from common.Settings import get_service_host, get_setting, wait_for_environment_variables

from configs import base_results_folder, datetime_format, hpa_data

# Module logger. It is only assigned in the `__main__` block of this script;
# until then it is None, and every function that calls LOGGER.* assumes it has
# already been set.
LOGGER = None
# Redis key of the list that holds one pickled descriptor per active service
# (overridable through the OPTICALATTACKMANAGER_SERVICE_LIST_KEY setting).
SERVICE_LIST_KEY = get_setting(
    "OPTICALATTACKMANAGER_SERVICE_LIST_KEY", default="opt-sec:active-services"
)

# Load the Kubernetes client configuration at import time.
# Configs can be set in Configuration class directly or using helper utility.
config.load_kube_config()

# Keep the kubernetes client logger at INFO so its DEBUG output does not flood
# the logs (avoid lengthy messages).
logging.getLogger("kubernetes").setLevel(logging.INFO)

# setting up graceful shutdown: signal_handler sets this event and the worker
# loops poll it with `terminate.wait(timeout=...)`.
# NOTE(review): this is a threading.Event, so it is not shared between the
# worker processes started in `__main__`; presumably each forked child relies
# on receiving the signal itself (e.g. Ctrl+C reaching the process group) —
# confirm.
terminate = threading.Event()


def signal_handler(signal, frame):  # pylint: disable=redefined-outer-name
    """Request a graceful shutdown when SIGINT/SIGTERM is received.

    Logs a warning and sets the module-level ``terminate`` event that the
    worker loops poll. ``signal`` and ``frame`` are the standard arguments
    supplied by :func:`signal.signal`; neither is used.
    """
    shutdown_event = terminate
    LOGGER.warning("Terminate signal received")
    shutdown_event.set()


def _build_service_schedule():
    """Return the target number of active services for every load tick.

    The schedule starts with ``[0, 10]``. Then, for each load level in
    ``(120, 240, 480, 960, 1440, 1920, 1922)``, it holds half of the load for
    one tick, the full load for five ticks and zero services for two ticks
    (8 entries per level, 58 entries in total).
    """
    schedule = [0, 10]
    for load in (120, 240, 480, 960, 1440, 1920, 1922):
        schedule.append(int(load / 2))
        schedule.extend([load] * 5)
        schedule.extend([0] * 2)
    return schedule


def _write_services_header(services_file, folder, autoscaling):
    """Write the CSV header of ``services.csv`` and the HPA names to ``hpas.csv``.

    Header columns: ``timestamp``, ``number_services``, one
    ``<hpa-name>_<field>`` column per HPA of namespace ``tfs`` and per entry of
    ``hpa_data``, and finally the CPU/RAM columns of the cache and manager pods.
    """
    services_file.write("# file with number of services\n")
    services_file.write("timestamp,number_services")

    with open(os.path.join(folder, "hpas.csv"), "wt", encoding="utf-8") as hpa_file:
        hpas = autoscaling.list_namespaced_horizontal_pod_autoscaler(namespace="tfs")
        for hpa in hpas.items:
            hpa_file.write(hpa.metadata.name + "\n")
            # NOTE(review): rows carry exactly 4 values per HPA (see
            # _hpa_columns); assumes hpa_data lists 4 names in that order.
            for d in hpa_data:
                services_file.write(f",{hpa.metadata.name}_{d}")

    # monitoring CPU and RAM usage of the single Pods
    for s in ["cache", "manager"]:
        for k in ["cpu", "ram"]:
            services_file.write(f",{s}_{k}")

    services_file.write("\n")
    services_file.flush()


def _hpa_columns(autoscaling):
    """Return the current status of the HPAs of namespace ``tfs`` as a CSV fragment.

    Each HPA contributes four comma-prefixed values, in order: current CPU
    utilization, target CPU utilization, current replicas, desired replicas.
    """
    hpas = autoscaling.list_namespaced_horizontal_pod_autoscaler(namespace="tfs")
    return "".join(
        f",{hpa.status.current_cpu_utilization_percentage}"
        f",{hpa.spec.target_cpu_utilization_percentage}"
        f",{hpa.status.current_replicas}"
        f",{hpa.status.desired_replicas}"
        for hpa in hpas.items
    )


def _pod_usage_columns(api):
    """Return CPU and memory usage of the cache and manager pods as a CSV fragment.

    Values are read from the ``metrics.k8s.io/v1beta1`` pod metrics of
    namespace ``tfs``; only the first container of the first pod whose name
    matches is reported, and a pod that is not found contributes no columns.
    """
    k8s_pods = api.list_cluster_custom_object(
        "metrics.k8s.io", "v1beta1", "namespaces/tfs/pods"
    )
    resource_string = ""
    # look the pods up in a fixed order so the columns match the header
    for name_fragment in ("caching", "opticalattackmanager"):
        for stats in k8s_pods["items"]:
            if name_fragment in stats["metadata"]["name"]:
                usage = stats["containers"][0]["usage"]
                resource_string += f",{usage['cpu']},{usage['memory']}"
                break
    return resource_string


def manage_number_services(terminate, folder):
    """Drive the number of active services stored in Redis through a fixed schedule.

    Every 30 seconds, until ``terminate`` is set or the schedule of
    ``_build_service_schedule()`` has been walked once, the Redis list at
    ``SERVICE_LIST_KEY`` is grown with random service descriptors or shrunk so
    that its length matches the scheduled value. For each step two rows are
    appended to ``<folder>/services.csv``: the number found before the
    adjustment (timestamped one second earlier) and the scheduled number, both
    followed by the HPA status. The HPA names go to ``<folder>/hpas.csv``.
    The Redis ``MONITORING_INTERVAL`` key is set to 30 on start and on exit
    (also on error), and to 60 just before the 961-service step.

    Args:
        terminate: event polled to stop the load early.
        folder: existing directory where the CSV files are written.
    """
    # connecting with Redis
    redis_host = get_service_host(ServiceNameEnum.CACHING)
    if redis_host is None:
        LOGGER.fatal("No environment variables set for Redis")
        return  # nothing to connect to: do not continue with a missing client
    redis_port = int(get_setting("CACHINGSERVICE_SERVICE_PORT_REDIS"))
    redis_password = get_setting("REDIS_PASSWORD")

    try:
        cache = redis.Redis(host=redis_host, port=redis_port, password=redis_password)
    except Exception as e:  # pylint: disable=broad-except
        LOGGER.exception(e)
        return

    # clean the existing list that will be populated later on in this function
    cache.delete(SERVICE_LIST_KEY)

    # make sure we have the correct loop time
    cache.set("MONITORING_INTERVAL", 30)

    # connecting to the HPA API and to the custom objects (metrics) API
    autoscaling = client.AutoscalingV1Api()
    api = client.CustomObjectsApi()

    number_services = _build_service_schedule()
    schedule_len = len(number_services)

    ticks = 1  # run a load step every `ticks` wake-ups
    set_to_60 = False
    cur_tick = 0

    services_path = os.path.join(folder, "services.csv")
    try:
        with open(services_path, "wt", encoding="utf-8") as services_file:
            _write_services_header(services_file, folder, autoscaling)

            LOGGER.info("Starting load!")
            while not terminate.wait(timeout=30):
                if cur_tick % ticks == 0:
                    LOGGER.debug("go load!")

                    hpa_string = _hpa_columns(autoscaling)
                    resource_string = _pod_usage_columns(api)

                    # calculate the difference between current and expected
                    expected_services = number_services[cur_tick % schedule_len]
                    cur_services = cache.llen(SERVICE_LIST_KEY)
                    diff_services = cur_services - expected_services

                    # switch to a 60 s interval one tick before the 961-service
                    # step; the parentheses keep the index inside the schedule
                    next_services = number_services[(cur_tick + 1) % schedule_len]
                    if not set_to_60 and next_services == 961:
                        cache.set("MONITORING_INTERVAL", 60)
                        LOGGER.info("Setting monitoring interval to 60")
                        set_to_60 = True

                    # write current number with one second difference
                    reported_datetime = datetime.datetime.now() - datetime.timedelta(
                        seconds=1
                    )
                    # NOTE(review): this row has no pod-usage columns even
                    # though the header declares them; kept for data compatibility.
                    services_file.write(
                        f"{reported_datetime.strftime(datetime_format)},"
                        f"{cur_services}{hpa_string}\n"
                    )

                    if diff_services < 0:  # current is lower than expected
                        LOGGER.debug("inserting <%s> services", -diff_services)
                        new_services = [
                            pickle.dumps(
                                {
                                    "context": str(uuid.uuid4()),
                                    "service": str(uuid.uuid4()),
                                    "kpi": str(uuid.uuid4()),
                                }
                            )
                            for _ in range(-diff_services)
                        ]
                        # one LPUSH with many values == many single LPUSHes, in order
                        cache.lpush(SERVICE_LIST_KEY, *new_services)
                    elif diff_services > 0:  # current is greater than expected
                        LOGGER.debug("deleting <%s> services", diff_services)
                        cache.lpop(SERVICE_LIST_KEY, diff_services)

                    # writing the new number with the current time
                    services_file.write(
                        f"{datetime.datetime.now().strftime(datetime_format)},"
                        f"{expected_services}"
                        f"{hpa_string}{resource_string}\n"
                    )

                    # sanity check: the list now has exactly the scheduled length
                    assert expected_services == cache.llen(SERVICE_LIST_KEY)

                    services_file.flush()
                else:
                    LOGGER.debug("tick load!")
                cur_tick += 1
                if cur_tick > schedule_len + 1:
                    break
    finally:
        # make sure we have the correct loop time, also when stopping on an error
        cache.set("MONITORING_INTERVAL", 30)

    LOGGER.info("Finished load!")


def _load_prometheus_config(path):
    """Load and return the Prometheus YAML configuration stored at ``path``."""
    with open(path, "rt", encoding="utf-8") as file:
        return yaml.safe_load(file)


def _write_prometheus_config(path, prometheus_config):
    """Write ``prometheus_config`` as YAML to ``path``, replacing its content."""
    with open(path, "wt", encoding="utf-8") as file:
        yaml.dump(prometheus_config, file)


def _reload_prometheus(reload_url):
    """Ask Prometheus to reload its configuration; log (do not raise) on failure.

    docs: https://www.robustperception.io/reloading-prometheus-configuration/
    """
    try:
        response = requests.post(reload_url, timeout=10)
    except requests.RequestException as e:
        LOGGER.warning("Could not reload Prometheus configuration: %s", e)
        return
    if not response.ok:
        LOGGER.warning("Prometheus reload returned HTTP %s", response.status_code)


def _metrics_targets(endpoints):
    """Return ``(has_metrics_port, targets)`` for a Kubernetes Endpoints object.

    ``has_metrics_port`` tells whether any subset exposes a port named
    ``metrics``; ``targets`` lists ``"<ip>:<port>"`` for every address in
    ``subset.addresses`` of those subsets, using that port. ``subsets``,
    ``ports`` and ``addresses`` may be None in the Kubernetes client models
    and are then treated as empty.
    """
    has_metrics_port = False
    targets = []
    for subset in endpoints.subsets or []:
        for port in subset.ports or []:
            if port.name != "metrics":
                continue
            has_metrics_port = True  # endpoint is ready for being scraped
            for address in subset.addresses or []:
                LOGGER.debug("%s\t%s:%s", endpoints.metadata.name, address.ip, port.port)
                targets.append(f"{address.ip}:{port.port}")
    return has_metrics_port, targets


def _upsert_scrape_job(prometheus_config, job_name, targets):
    """Create or update scrape job ``job_name`` so that it scrapes exactly ``targets``."""
    scrape_configs = prometheus_config.get("scrape_configs") or []
    prometheus_config["scrape_configs"] = scrape_configs

    job = next(
        (cfg for cfg in scrape_configs if cfg.get("job_name") == job_name), None
    )
    if job is None:  # write it from zero!
        job = {"job_name": job_name}
        scrape_configs.append(job)

    if not job.get("static_configs"):
        job["static_configs"] = [{"targets": []}]
    # reset IPs: only the endpoints listed right now are scraped
    job["static_configs"][0]["targets"] = targets


def monitor_endpoints(
    terminate,
    prometheus_folder="/home/carda/projects/prometheus",
    reload_url="http://127.0.0.1:9090/-/reload",
):
    """Keep the Prometheus scrape configuration in sync with the ``tfs`` endpoints.

    Every 30 seconds, until ``terminate`` is set: the base configuration
    ``<prometheus_folder>/prometheus.yml.backup`` is loaded, one scrape job per
    Endpoints object of namespace ``tfs`` that exposes a ``metrics`` port is
    created/updated, the result is written to
    ``<prometheus_folder>/prometheus.yml`` and Prometheus is asked to reload.
    On exit (also when an error interrupts the loop) the base configuration is
    written back and Prometheus is reloaded once more.

    Args:
        terminate: event polled to stop the loop.
        prometheus_folder: folder containing ``prometheus.yml`` and its backup.
        reload_url: Prometheus configuration-reload endpoint.
    """
    LOGGER.info("starting experiment!")
    base_config_path = os.path.join(prometheus_folder, "prometheus.yml.backup")
    live_config_path = os.path.join(prometheus_folder, "prometheus.yml")
    v1 = client.CoreV1Api()
    try:
        while not terminate.wait(timeout=30):
            # start each round from the base file, so jobs added in earlier
            # rounds for endpoints that no longer exist are not carried over
            current_version = _load_prometheus_config(base_config_path)

            # checking endpoints
            ret = v1.list_namespaced_endpoints(namespace="tfs", watch=False)
            for item in ret.items:
                has_metrics_port, targets = _metrics_targets(item)
                if not has_metrics_port:
                    continue  # if no `metrics` port, jump!
                _upsert_scrape_job(current_version, item.metadata.name, targets)

            _write_prometheus_config(live_config_path, current_version)
            _reload_prometheus(reload_url)
    finally:
        # resetting prometheus to the original state
        _write_prometheus_config(
            live_config_path, _load_prometheus_config(base_config_path)
        )
        _reload_prometheus(reload_url)

    LOGGER.info("Finished experiment!")


if __name__ == "__main__":
    logging.basicConfig(
        level=logging.DEBUG,
        format='%(asctime)s.%(msecs)03d %(levelname)s - %(funcName)s: %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
    )
    LOGGER = logging.getLogger(__name__)

    wait_for_environment_variables(
        ["CACHINGSERVICE_SERVICE_PORT_REDIS", "REDIS_PASSWORD"]
    )

    # installed before the children start, so forked children inherit them
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    # generate results folder
    output_folder = os.path.join(base_results_folder, "jocn_" + datetime.datetime.now(
        datetime.timezone.utc
    ).strftime("%Y%m%dT%H%M%S.%fUTC"))
    os.makedirs(output_folder)

    # The children receive `terminate` (a threading.Event), LOGGER and the
    # signal handlers by being forked from this process. A threading.Event
    # cannot be pickled for the "spawn"/"forkserver" start methods, so the fork
    # context is requested explicitly instead of relying on the platform default.
    mp_context = multiprocessing.get_context("fork")

    # start load handler
    proc_load = mp_context.Process(
        target=manage_number_services,
        args=(
            terminate,
            output_folder,
        ),
    )
    proc_load.start()

    # start experiment monitoring
    proc_experiment = mp_context.Process(
        target=monitor_endpoints, args=(terminate,)
    )
    proc_experiment.start()

    # Wait for Ctrl+C or termination signal
    while not terminate.wait(timeout=0.1):
        pass

    # Each child holds its own copy of `terminate`, so a signal delivered only to
    # this process (e.g. `kill <pid>`) would not stop them and join() would hang.
    # Forward SIGTERM: the inherited handler sets the child's copy and it shuts
    # down gracefully (a repeated signal after Ctrl+C is harmless).
    for proc in (proc_load, proc_experiment):
        if proc.is_alive():
            proc.terminate()

    # waits for the processes to finish
    proc_load.join()
    proc_experiment.join()

    # exits
    LOGGER.info("Bye!")
