# Copyright 2022-2023 ETSI TeraFlowSDN - TFS OSG (https://tfs.etsi.org/)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import pickle
import random

import grpc
import numpy as np
import redis
from prometheus_client import Histogram

from common.Constants import ServiceNameEnum
from common.method_wrappers.Decorator import (MetricsPool, MetricTypeEnum,
                                              safe_and_metered_rpc_method)
from common.proto import dbscanserving_pb2 as dbscan
from common.proto import optical_attack_detector_pb2 as oad
from common.proto.context_pb2 import Empty
from common.proto.monitoring_pb2 import Kpi
from common.proto.optical_attack_detector_pb2_grpc import \
    OpticalAttackDetectorServiceServicer
from common.proto.optical_attack_mitigator_pb2 import (AttackDescription,
                                                       AttackResponse)
from common.Settings import get_service_host, get_setting
from common.tools.timestamp.Converters import timestamp_utcnow_to_float
from dbscanserving.client.DbscanServingClient import DbscanServingClient
from monitoring.client.MonitoringClient import MonitoringClient
from opticalattackmitigator.client.OpticalAttackMitigatorClient import \
    OpticalAttackMitigatorClient

LOGGER = logging.getLogger(__name__)

METRICS_POOL = MetricsPool("OpticalAttackDetector", "RPC")

METRICS_POOL_DETAILS = MetricsPool(
    "OpticalAttackDetector",
    "execution",
    labels={
        "operation": "",
        "step": "",
    },
)

METRIC_LABELS = dict(operation="detect")
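# HISTOGRAM_DURATION tracks per-step execution time; the "step" label is set
# at each timed block inside DetectAttack below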
HISTOGRAM_DURATION: Histogram = METRICS_POOL_DETAILS.get_or_create(
    "details", MetricTypeEnum.HISTOGRAM_DURATION
)

# module-level clients: created once at import time and reused by every RPC
monitoring_client: MonitoringClient = MonitoringClient()
dbscanserving_client: DbscanServingClient = DbscanServingClient()
attack_mitigator_client: OpticalAttackMitigatorClient = OpticalAttackMitigatorClient()

redis_host = get_service_host(ServiceNameEnum.CACHING)
r = None
if redis_host is not None:
    redis_port = int(get_setting("CACHINGSERVICE_SERVICE_PORT_REDIS", default=6379))
    redis_password = get_setting("REDIS_PASSWORD")

    r = redis.Redis(host=redis_host, port=redis_port, password=redis_password)
    r.ping()
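
# NOTE: DetectAttack below uses `r` unconditionally, so a reachable caching
# service is assumed; the `r.ping()` above fails fast at start-up otherwise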

# detect whether a pre-computed dataset of OPM statistics should be preloaded
path = get_setting("PATH_OPM_INFORMATION_SUMMARY", default=None)
if path is not None and len(path) > 0:
    with open(path, "rb") as file:
        opm_information_stats = pickle.load(file)
        LOGGER.info("Using provided dataset: {}".format(path))
else:
    opm_information_stats = None

# window-based attack detection (WAD) parameters, following
# https://ieeexplore.ieee.org/abstract/document/9064530: an attack is declared
# when at least WAD_SAMPLES of the last WAD_WINDOW samples are DBSCAN outliers
WAD_WINDOW = 20
WAD_SAMPLES = 9


class OpticalAttackDetectorServiceServicerImpl(OpticalAttackDetectorServiceServicer):
    def __init__(self):
        LOGGER.debug("Creating Servicer...")
        LOGGER.debug("Servicer Created")

    @safe_and_metered_rpc_method(METRICS_POOL, LOGGER)
    def DetectAttack(
        self, request: oad.DetectionRequest, context: grpc.ServicerContext
    ) -> Empty:

        s_uuid = request.service_id.service_uuid.uuid
        service_key = s_uuid.replace("-", "_")  # UUID normalized for Redis key names

        # detect if specific configuration is required
        # to set the variable on the fly, try:
        # https://stackoverflow.com/questions/45050050/can-i-modify-containers-environment-variables-without-restarting-pod-using-kube
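        # `_class` selects which row of the preloaded statistics (if any) is
        # used when synthesizing samples below; 0 means no attack present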
        _temp = r.get("CLASS_SERVICE_{}".format(s_uuid.replace("-", "_")))
        if _temp is not None and len(_temp) > 0:
            _class = int(_temp)
        else:  # if not, assume no attack is present
            _class = 0
        LOGGER.debug("Using class {} for service {}".format(_class, s_uuid))

        _temp = r.get("PREVIOUS_CLASS_SERVICE_{}".format(s_uuid.replace("-", "_")))
        if _temp is not None and int(_temp) != 0 and _class == 0:
            # if value changed to no-attack
            # reset the experiment
            r.delete("opm_{}".format(service_key))
        r.set("PREVIOUS_CLASS_SERVICE_{}".format(service_key), _class)

        # code used to validate resiliency against failures and timeouts
        # (re-enabling it also requires `import time`):
        # if random.random() > 0.5:
        #     time.sleep(10)

        # run attack detection for the requested service
        detection_request: dbscan.DetectionRequest = dbscan.DetectionRequest()
        detection_request.num_samples = 310
        detection_request.num_features = 11
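        # the request carries a sliding window of num_samples observations,
        # each with num_features monitored optical parameters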

        # checking if we have enough samples already for this service
        length = r.llen("opm_{}".format(service_key))
        if length < detection_request.num_samples:
            # if the number of samples is not sufficient,
            # we insert new samples
            for _ in range(detection_request.num_samples - length):
                detection_sample = []
                if opm_information_stats is not None:
                    # the stats table holds per-feature (mean, std) column
                    # pairs, so step by two to visit each feature name once
                    for col in range(1, len(opm_information_stats.columns), 2):
                        name = opm_information_stats.columns[col][0]
                        detection_sample.append(
                            np.random.normal(
                                loc=opm_information_stats[name]["mean"][_class],
                                scale=opm_information_stats[name]["std"][_class],
                            )
                        )
                else:
                    for __ in range(detection_request.num_features):
                        detection_sample.append(random.uniform(0.0, 10.0))
                # push the sample into the list
                r.rpush(
                    "opm_{}".format(service_key),
                    pickle.dumps(tuple(detection_sample)),
                )

        # remove the oldest sample from the list
        r.lpop("opm_{}".format(service_key))

        # generate the latest sample
        detection_sample = []
        if opm_information_stats is not None:
            # the stats table holds per-feature (mean, std) column pairs, so
            # step by two to visit each feature name once
            for col in range(1, len(opm_information_stats.columns), 2):
                name = opm_information_stats.columns[col][0]
                detection_sample.append(
                    np.random.normal(
                        loc=opm_information_stats[name]["mean"][_class],
                        scale=opm_information_stats[name]["std"][_class],
                    )
                )

            # DBSCAN parameters follow the configuration from
            # https://dx.doi.org/10.1109/JLT.2020.2987032
            detection_request.eps = 1.0
            detection_request.min_samples = 5
        else:
            detection_request.eps = 100.5
            detection_request.min_samples = 5
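            # the larger eps matches the scale of the synthetic sample ranges
            # generated below (no attack: 0.0-10.0; attack: 5000.0-6000.0)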

            if _class == 0:
                for __ in range(detection_request.num_features):
                    detection_sample.append(random.uniform(0.0, 10.0))
            else:  # attack present: draw samples from a clearly separated range
                for __ in range(detection_request.num_features):
                    detection_sample.append(random.uniform(5000.0, 6000.0))

        # push the new sample into the cache and fetch the complete window back
        with HISTOGRAM_DURATION.labels(step="cachefetch", **METRIC_LABELS).time():
            r.rpush(
                "opm_{}".format(service_key),
                pickle.dumps(tuple(detection_sample)),
            )
            cached_samples = r.lrange("opm_{}".format(service_key), 0, -1)
            LOGGER.info(
                "Recovered {} samples from the cache".format(len(cached_samples))
            )

        # deserialize the cached samples and convert them into protobuf Sample
        # messages for the DBSCAN request
        for raw_sample in cached_samples:
            features = pickle.loads(raw_sample)
            grpc_sample = dbscan.Sample()
            for feature in features:
                grpc_sample.features.append(feature)
            detection_request.samples.append(grpc_sample)

        with HISTOGRAM_DURATION.labels(step="uldetection", **METRIC_LABELS).time():
            response: dbscan.DetectionResponse = dbscanserving_client.Detect(
                detection_request
            )

        # report the detection result as a monitoring KPI
        kpi = Kpi()
        kpi.kpi_id.kpi_id.uuid = request.kpi_id.kpi_id.uuid
        kpi.timestamp.timestamp = timestamp_utcnow_to_float()
        # implementing WAD from https://ieeexplore.ieee.org/abstract/document/9064530
        if response.cluster_indices[-WAD_WINDOW:].count(-1) >= WAD_SAMPLES:
            kpi.kpi_value.int32Val = 1
            LOGGER.info(
                "Attack detected for service {}".format(
                    request.service_id.service_uuid.uuid
                )
            )
        else:
            kpi.kpi_value.int32Val = 0

        with HISTOGRAM_DURATION.labels(step="includekpi", **METRIC_LABELS).time():
            monitoring_client.IncludeKpi(kpi)

        # if an attack was detected, notify the attack mitigator
        if kpi.kpi_value.int32Val == 1:
            attack = AttackDescription()
            attack.cs_id.uuid = request.service_id.service_uuid.uuid
            with HISTOGRAM_DURATION.labels(step="mitigation", **METRIC_LABELS).time():
                mitigation_response: AttackResponse = attack_mitigator_client.NotifyAttack(
                    attack
                )

        return Empty()