# Copyright 2021-2023 H2020 TeraFlow (https://www.teraflow-h2020.eu/)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import threading
from enum import Enum
from typing import Dict, Tuple

from prometheus_client import Counter, Histogram
from prometheus_client.metrics import MetricWrapperBase, INF

class MetricTypeEnum(Enum):
    """Kinds of per-function metric, each valued by a metric-name template.

    Every value is a ``str.format`` template with a single ``{:s}`` field that
    is filled with a function name to obtain the concrete metric name.
    """
    COUNTER_STARTED = '{:s}_counter_requests_started'
    COUNTER_COMPLETED = '{:s}_counter_requests_completed'
    COUNTER_FAILED = '{:s}_counter_requests_failed'
    HISTOGRAM_DURATION = '{:s}_histogram_duration'

# For each MetricTypeEnum member: the prometheus_client class to instantiate and
# the keyword arguments handed to its constructor when the caller supplies none.
# Each entry owns its own kwargs dict (no sharing between entries).
METRIC_TO_CLASS_PARAMS = {
    MetricTypeEnum.COUNTER_STARTED: (Counter, {}),
    MetricTypeEnum.COUNTER_COMPLETED: (Counter, {}),
    MetricTypeEnum.COUNTER_FAILED: (Counter, {}),
    # Bucket upper bounds. meter_method feeds this histogram through
    # Histogram.time(), which observes durations in seconds; the trailing INF
    # bucket collects every call slower than one second.
    MetricTypeEnum.HISTOGRAM_DURATION: (
        Histogram,
        {'buckets': (0.005, 0.01, 0.025, 0.05, 0.075, 0.1, 0.25, 0.5, 0.75, 1.0, INF)},
    ),
}

class MetricsPool:
    """Lock-protected cache of prometheus_client metric objects keyed by name.

    Each metric is constructed at most once per pool; subsequent requests for
    the same function name and metric type return the cached object.
    """

    def __init__(self) -> None:
        # metric name -> metric object; only touched while holding self._lock
        self._metrics : Dict[str, MetricWrapperBase] = {}
        self._lock = threading.Lock()

    def get_or_create(self, function_name : str, metric_type : MetricTypeEnum, **metric_params) -> MetricWrapperBase:
        """Return the ``metric_type`` metric for ``function_name``, creating it on first use.

        The metric name is ``metric_type.value`` formatted with
        ``function_name`` and then upper-cased. The metric is built as
        ``cls(name, '', **kwargs)``, where ``cls`` comes from
        METRIC_TO_CLASS_PARAMS. ``kwargs`` is ``metric_params`` when any are
        given; otherwise it is the defaults registered there. Parameters passed
        for a metric that already exists in this pool are ignored.

        NOTE(review): prometheus_client constructors register into the default
        registry unless told otherwise, so two pools creating the same metric
        name will probably fail with a duplicate-timeseries error — confirm.
        """
        name = str(metric_type.value).format(function_name).upper()
        with self._lock:
            metric = self._metrics.get(name)
            if metric is None:
                metric_cls, defaults = METRIC_TO_CLASS_PARAMS.get(metric_type)
                ctor_kwargs = metric_params if metric_params else defaults
                metric = metric_cls(name, '', **ctor_kwargs)
                self._metrics[name] = metric
            return metric

def meter_method(metrics_pool : MetricsPool):
    """Build a decorator that instruments a method with Prometheus metrics.

    For the decorated method, four metrics named after the method are obtained
    from ``metrics_pool``: a "started" counter, a "completed" counter, a
    "failed" counter and a duration histogram. Every call increments the
    started counter and is timed by the histogram. A call that returns
    increments the completed counter. A call that raises an ``Exception``
    increments the failed counter, and the exception then propagates unchanged
    to the caller. ``BaseException`` subclasses that are not ``Exception``
    subclasses (e.g. ``KeyboardInterrupt``) propagate without touching the
    failed counter.

    Args:
        metrics_pool: pool used to obtain (or lazily create) the metrics.

    Returns:
        A decorator for methods, i.e. callables whose first argument is ``self``.
    """
    def outer_wrapper(func):
        func_name = func.__name__
        histogram_duration : Histogram = metrics_pool.get_or_create(func_name, MetricTypeEnum.HISTOGRAM_DURATION)
        counter_started    : Counter   = metrics_pool.get_or_create(func_name, MetricTypeEnum.COUNTER_STARTED)
        counter_completed  : Counter   = metrics_pool.get_or_create(func_name, MetricTypeEnum.COUNTER_COMPLETED)
        counter_failed     : Counter   = metrics_pool.get_or_create(func_name, MetricTypeEnum.COUNTER_FAILED)

        @histogram_duration.time()
        @functools.wraps(func)  # keep the wrapped method's name/docstring/qualname
        def inner_wrapper(self, *args, **kwargs):
            counter_started.inc()
            try:
                reply = func(self, *args, **kwargs)
            except Exception:
                # Record the failure, then re-raise: the decorator must be
                # transparent to callers and never turn an error into a
                # silent ``None`` return.
                counter_failed.inc()
                raise
            counter_completed.inc()
            return reply

        return inner_wrapper
    return outer_wrapper
