Skip to content
Snippets Groups Projects
Commit 86ca2f05 authored by Carlos Natalino Da Silva's avatar Carlos Natalino Da Silva
Browse files

Initial implementation of the gRPC calls.

parent 2977f991
No related branches found
No related tags found
No related merge requests found
# Copyright 2021-2023 H2020 TeraFlow (https://www.teraflow-h2020.eu/)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
---
apiVersion: apps/v1
kind: Deployment
metadata:
name: cachingservice
spec:
selector:
matchLabels:
app: cachingservice
template:
metadata:
labels:
app: cachingservice
spec:
containers:
- name: redis
image: redis:7.0-alpine
env:
- name: REDIS_PASSWORD
valueFrom:
secretKeyRef:
name: redis-secrets
key: REDIS_PASSWORD
ports:
- containerPort: 6379
name: client
command: ["redis-server"]
args:
- --requirepass
- $(REDIS_PASSWORD)
---
apiVersion: v1
kind: Service
metadata:
name: cachingservice
spec:
type: ClusterIP
selector:
app: cachingservice
ports:
- name: redis
port: 6379
targetPort: 6379
\ No newline at end of file
...@@ -17,15 +17,10 @@ syntax = "proto3"; ...@@ -17,15 +17,10 @@ syntax = "proto3";
package optical_attack_detector; package optical_attack_detector;
import "context.proto"; import "context.proto";
import "monitoring.proto";
service OpticalAttackDetectorService { service OpticalAttackDetectorService {
// rpc that triggers the attack detection loop // rpc that triggers the attack detection loop
rpc DetectAttack (context.Empty ) returns (context.Empty) {} rpc DetectAttack (context.ServiceId ) returns (context.Empty) {}
// rpc called by the distributed component to report KPIs
rpc ReportSummarizedKpi (monitoring.KpiList) returns (context.Empty) {}
rpc ReportKpi (monitoring.KpiList) returns (context.Empty) {}
} }
...@@ -53,5 +53,5 @@ class DbscanServingClient: ...@@ -53,5 +53,5 @@ class DbscanServingClient:
request.num_features request.num_features
)) ))
response = self.stub.Detect(request) response = self.stub.Detect(request)
LOGGER.debug('Detect result: {:s}'.format(str(response))) LOGGER.debug('Detect result with {:s} cluster indices'.format(len(response.cluster_indices)))
return response return response
celery[redis]
...@@ -33,7 +33,7 @@ from common.proto.optical_attack_detector_pb2_grpc import ( ...@@ -33,7 +33,7 @@ from common.proto.optical_attack_detector_pb2_grpc import (
LOGGER = logging.getLogger(__name__) LOGGER = logging.getLogger(__name__)
SERVICE_NAME = 'OpticalAttackDetector' SERVICE_NAME = 'OpticalAttackDetector'
METHOD_NAMES = ['NotifyServiceUpdate', 'DetectAttack', 'ReportSummarizedKpi', 'ReportKpi'] METHOD_NAMES = ['DetectAttack']
METRICS = create_metrics(SERVICE_NAME, METHOD_NAMES) METRICS = create_metrics(SERVICE_NAME, METHOD_NAMES)
context_client: ContextClient = ContextClient() context_client: ContextClient = ContextClient()
...@@ -50,65 +50,37 @@ class OpticalAttackDetectorServiceServicerImpl(OpticalAttackDetectorServiceServi ...@@ -50,65 +50,37 @@ class OpticalAttackDetectorServiceServicerImpl(OpticalAttackDetectorServiceServi
LOGGER.debug('Servicer Created') LOGGER.debug('Servicer Created')
@safe_and_metered_rpc_method(METRICS, LOGGER) @safe_and_metered_rpc_method(METRICS, LOGGER)
def DetectAttack(self, request : Empty, context : grpc.ServicerContext) -> Empty: def DetectAttack(self, service_id : ServiceId, context : grpc.ServicerContext) -> Empty:
LOGGER.debug('Received request for {}/{}...'.format(
# retrieve list with current contexts service_id.context_id.context_uuid.uuid,
# import pdb; pdb.set_trace() service_id.service_uuid.uuid
context_ids: ContextIdList = context_client.ListContextIds(Empty()) ))
# run attack detection for every service
# for each context, retrieve list of current services request: DetectionRequest = DetectionRequest()
services = [] request.num_samples = 310
for context_id in context_ids.context_ids: request.num_features = 100
request.eps = 100.5
context_services: ServiceIdList = context_client.ListServices(context_id) request.min_samples = 5
for service in context_services.services: for _ in range(200):
services.append(service) grpc_sample = Sample()
for __ in range(100):
for service in services: grpc_sample.features.append(random.uniform(0., 10.))
for endpoint in service.service_endpoint_ids: request.samples.append(grpc_sample)
# get instant KPI for this endpoint for _ in range(100):
LOGGER.warning(f'service: {service.service_id.service_uuid.uuid}\t endpoint: {endpoint.endpoint_uuid.uuid}\tdevice: {endpoint.device_id.device_uuid.uuid}') grpc_sample = Sample()
for __ in range(100):
# run attack detection for every service grpc_sample.features.append(random.uniform(50., 60.))
request: DetectionRequest = DetectionRequest() request.samples.append(grpc_sample)
for _ in range(10):
request.num_samples = 310 grpc_sample = Sample()
request.num_features = 100 for __ in range(100):
request.eps = 100.5 grpc_sample.features.append(random.uniform(5000., 6000.))
request.min_samples = 50 request.samples.append(grpc_sample)
response: DetectionResponse = dbscanserving_client.Detect(request)
for _ in range(200): if -1 in response.cluster_indices: # attack detected
grpc_sample = Sample() attack = AttackDescription()
for __ in range(100): attack.cs_id.uuid = service_id.service_uuid.uuid
grpc_sample.features.append(random.uniform(0., 10.)) response: AttackResponse = attack_mitigator_client.NotifyAttack(attack)
request.samples.append(grpc_sample)
for _ in range(100):
grpc_sample = Sample()
for __ in range(100):
grpc_sample.features.append(random.uniform(50., 60.))
request.samples.append(grpc_sample)
for _ in range(10):
grpc_sample = Sample()
for __ in range(100):
grpc_sample.features.append(random.uniform(5000., 6000.))
request.samples.append(grpc_sample)
response: DetectionResponse = dbscanserving_client.Detect(request)
if -1 in response.cluster_indices: # attack detected
attack = AttackDescription()
attack.cs_id.uuid = service.service_id.service_uuid.uuid
response: AttackResponse = attack_mitigator_client.NotifyAttack(attack)
# if attack is detected, run the attack mitigator # if attack is detected, run the attack mitigator
return Empty() return Empty()
@safe_and_metered_rpc_method(METRICS, LOGGER)
def ReportSummarizedKpi(self, request : KpiList, context : grpc.ServicerContext) -> Empty:
return Empty()
@safe_and_metered_rpc_method(METRICS, LOGGER)
def ReportKpi(self, request : KpiList, context : grpc.ServicerContext) -> Empty:
return Empty()
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
import logging, signal, sys, time, threading, random import logging, signal, sys, time, threading, random
from multiprocessing import Manager, Process from multiprocessing import Manager, Process
from prometheus_client import start_http_server from prometheus_client import start_http_server
from celery import Celery
import asyncio import asyncio
from common.Constants import DEFAULT_GRPC_MAX_WORKERS, DEFAULT_GRPC_GRACE_PERIOD from common.Constants import DEFAULT_GRPC_MAX_WORKERS, DEFAULT_GRPC_GRACE_PERIOD
...@@ -47,28 +46,6 @@ def signal_handler(signal, frame): # pylint: disable=redefined-outer-name ...@@ -47,28 +46,6 @@ def signal_handler(signal, frame): # pylint: disable=redefined-outer-name
terminate.set() terminate.set()
def detect_attack_old(monitoring_interval):
time.sleep(10) # wait for the service to start
LOGGER.info("Starting the attack detection loop")
client = OpticalAttackDetectorClient(address='localhost', port=GRPC_SERVICE_PORT)
client.connect()
while True: # infinite loop that runs until the terminate is set
if terminate.is_set(): # if terminate is set
LOGGER.warning("Stopping execution...")
client.close()
break # break the while and stop execution
client.DetectAttack(Empty())
# sleep
LOGGER.debug("Sleeping for {} seconds...".format(monitoring_interval))
time.sleep(monitoring_interval)
async def call_detection_grpc(request):
dbscanserving_client: DbscanServingClient = DbscanServingClient()
response: DetectionResponse = dbscanserving_client.Detect(request)
dbscanserving_client.close()
return result
def main(): def main():
global LOGGER # pylint: disable=global-statement global LOGGER # pylint: disable=global-statement
...@@ -87,87 +64,13 @@ def main(): ...@@ -87,87 +64,13 @@ def main():
# Start metrics server # Start metrics server
metrics_port = get_metrics_port() metrics_port = get_metrics_port()
start_http_server(metrics_port) # start_http_server(metrics_port) # TODO: remove this comment
attack_mitigator_client: OpticalAttackMitigatorClient = OpticalAttackMitigatorClient()
monitoring_client: MonitoringClient = MonitoringClient()
# Starting CentralizedCybersecurity service # Starting CentralizedCybersecurity service
grpc_service = OpticalAttackDetectorService( grpc_service = OpticalAttackDetectorService(
port=service_port, max_workers=max_workers, grace_period=grace_period) port=service_port, max_workers=max_workers, grace_period=grace_period)
grpc_service.start() grpc_service.start()
# p = multiprocessing.Process(target=detect_attack, args=(monitoring_interval, ))
# p.start()
# detect_attack(monitoring_interval)
LOGGER.info('Connecting with REDIS...')
REDIS_PASSWORD = get_setting('REDIS_PASSWORD')
REDIS_HOST = get_setting('CACHINGSERVICE_SERVICE_HOST')
REDIS_PORT = get_setting('CACHINGSERVICE_SERVICE_PORT')
BROKER_URL = f'redis://:{REDIS_PASSWORD}@{REDIS_HOST}:{REDIS_PORT}/0'
BACKEND_URL = f'redis://:{REDIS_PASSWORD}@{REDIS_HOST}:{REDIS_PORT}/1'
app = Celery(
'cybersecurity',
broker=BROKER_URL,
backend=BACKEND_URL
)
LOGGER.info('Connected to REDIS...')
@app.task(name='detect_attack')
def detect_attack(context_id, service_id, kpi_id):
LOGGER.info('Detecting attack for {}/{}'.format(context_id, service_id))
alien_samples = random.randint(2, 10)
# run attack detection for every service
request: DetectionRequest = DetectionRequest()
request.num_samples = 200 + alien_samples
request.num_features = 20
request.eps = 100.5
request.min_samples = 5
for _ in range(200):
grpc_sample = Sample()
for __ in range(request.num_features):
grpc_sample.features.append(random.uniform(0., 10.))
request.samples.append(grpc_sample)
# for _ in range(100):
# grpc_sample = Sample()
# for __ in range(20):
# grpc_sample.features.append(random.uniform(50., 60.))
# request.samples.append(grpc_sample)
for _ in range(alien_samples):
grpc_sample = Sample()
for __ in range(request.num_features):
grpc_sample.features.append(random.uniform(5000., 6000.))
request.samples.append(grpc_sample)
# call the grpc
dbscanserving_client: DbscanServingClient = DbscanServingClient()
# response: DetectionResponse = dbscanserving_client.Detect(request)
dbscanserving_client.connect()
dbscanserving_client.close()
# including KPI
kpi = Kpi()
kpi.kpi_id.kpi_id.uuid = kpi_id
kpi.timestamp.timestamp = timestamp_utcnow_to_float()
# kpi.kpi_value.int32Val = response.cluster_indices[-1]
kpi.kpi_value.int32Val = 1
# monitoring_client.IncludeKpi(kpi)
# if -1 in response.cluster_indices: # attack detected
# attack = AttackDescription()
# # attack.cs_id.uuid = service.service_id.service_uuid.uuid
# response: AttackResponse = attack_mitigator_client.NotifyAttack(attack)
return "0"
app.worker_main([
'worker',
'--loglevel={}'.format(log_level),
'--autoscale',
'1',
'--pool=gevent'
])
# Wait for Ctrl+C or termination signal # Wait for Ctrl+C or termination signal
while not terminate.wait(timeout=0.1): pass while not terminate.wait(timeout=0.1): pass
......
celery[redis]
\ No newline at end of file
from celery import Celery import asyncio
import logging
import grpc
import random import random
from common.Settings import get_log_level, get_metrics_port, get_setting from common.Settings import get_log_level, get_setting
from common.proto.dbscanserving_pb2 import DetectionRequest, DetectionResponse, Sample from common.proto.dbscanserving_pb2 import DetectionRequest, DetectionResponse, Sample
from dbscanserving.client.DbscanServingClient import DbscanServingClient from common.proto.dbscanserving_pb2_grpc import DetectorStub
# For more channel options, please see https://grpc.io/grpc/core/group__grpc__arg__keys.html
dbscanserving_client: DbscanServingClient = DbscanServingClient() CHANNEL_OPTIONS = [('grpc.lb_policy_name', 'pick_first'),
('grpc.enable_retries', True),
alien_samples = random.randint(2, 20) ('grpc.keepalive_timeout_ms', 10000)]
request: DetectionRequest = DetectionRequest() # based on https://github.com/grpc/grpc/blob/master/examples/python/helloworld/async_greeter_client_with_options.py
request.num_samples = 300 + alien_samples
request.num_features = 20 async def run(endpoint, service_id) -> None:
request.eps = 100.5
request.min_samples = 5 async with grpc.aio.insecure_channel(target=endpoint,
for _ in range(200): options=CHANNEL_OPTIONS) as channel:
grpc_sample = Sample() stub = DetectorStub(channel)
for __ in range(20):
grpc_sample.features.append(random.uniform(0., 10.)) # generate data
request.samples.append(grpc_sample) alien_samples = random.randint(2, 20)
for _ in range(100):
grpc_sample = Sample() request: DetectionRequest = DetectionRequest()
for __ in range(20): request.num_samples = 300 + alien_samples
grpc_sample.features.append(random.uniform(50., 60.)) request.num_features = 20
request.samples.append(grpc_sample) request.eps = 100.5
for _ in range(alien_samples): request.min_samples = 5
grpc_sample = Sample() for _ in range(200):
for __ in range(20): grpc_sample = Sample()
grpc_sample.features.append(random.uniform(5000., 6000.)) for __ in range(20):
request.samples.append(grpc_sample) grpc_sample.features.append(random.uniform(0., 10.))
response: DetectionResponse = dbscanserving_client.Detect(request) request.samples.append(grpc_sample)
for _ in range(100):
REDIS_PASSWORD = get_setting('REDIS_PASSWORD') grpc_sample = Sample()
REDIS_HOST = get_setting('CACHINGSERVICE_SERVICE_HOST') for __ in range(20):
REDIS_PORT = get_setting('CACHINGSERVICE_SERVICE_PORT') grpc_sample.features.append(random.uniform(50., 60.))
BROKER_URL = f'redis://:{REDIS_PASSWORD}@{REDIS_HOST}:{REDIS_PORT}/0' request.samples.append(grpc_sample)
app = Celery( for _ in range(alien_samples):
'cybersecurity', grpc_sample = Sample()
broker=BROKER_URL, for __ in range(20):
backend=BROKER_URL grpc_sample.features.append(random.uniform(5000., 6000.))
) request.samples.append(grpc_sample)
# Timeout in seconds.
service = {'context': 'admin', 'service': '23bb5c96-e377-4943-a47a-4db9c54104cc', 'kpi': '1'} # Please refer gRPC Python documents for more detail. https://grpc.io/grpc/python/grpc.html
result = app.send_task('detect_attack', (service['context'], service['service'], service['kpi'])) response: DetectionResponse = await stub.Detect(request,
print('waiting for result...') timeout=10)
print('Result:', result.get()) print("Greeter client received:", service_id)
\ No newline at end of file return service_id * 2
async def main_loop():
host = get_setting('DBSCANSERVINGSERVICE_SERVICE_HOST')
port = get_setting('DBSCANSERVINGSERVICE_SERVICE_PORT_GRPC')
endpoint = '{:s}:{:s}'.format(str(host), str(port))
ret = await asyncio.gather(
run(endpoint, 1),
run(endpoint, 2)
)
print(ret)
if __name__ == '__main__':
logging.basicConfig()
asyncio.run(main_loop())
\ No newline at end of file
import asyncio, grpc, random
from common.proto.optical_attack_detector_pb2_grpc import OpticalAttackDetectorServiceStub
import logging, signal, sys, time, threading import logging, signal, sys, time, threading
from multiprocessing import Manager, Process from multiprocessing import Manager, Process
from typing import List from typing import List
from prometheus_client import start_http_server from prometheus_client import start_http_server
from celery import Celery
from common.Settings import get_log_level, get_metrics_port, get_setting from common.Settings import get_log_level, get_metrics_port, get_setting
from common.proto.context_pb2 import ContextIdList, Empty, EventTypeEnum, ServiceId, ServiceIdList from common.proto.context_pb2 import ContextIdList, Empty, EventTypeEnum, ServiceId, ServiceIdList
...@@ -15,14 +16,41 @@ from monitoring.client.MonitoringClient import MonitoringClient ...@@ -15,14 +16,41 @@ from monitoring.client.MonitoringClient import MonitoringClient
terminate = threading.Event() terminate = threading.Event()
LOGGER = None LOGGER = None
# For more channel options, please see https://grpc.io/grpc/core/group__grpc__arg__keys.html
CHANNEL_OPTIONS = [('grpc.lb_policy_name', 'pick_first'),
('grpc.enable_retries', True),
('grpc.keepalive_timeout_ms', 10000)]
# TODO: configure retries
def signal_handler(signal, frame): # pylint: disable=redefined-outer-name def signal_handler(signal, frame): # pylint: disable=redefined-outer-name
LOGGER.warning('Terminate signal received') LOGGER.warning('Terminate signal received')
terminate.set() terminate.set()
def monitor_services(app: Celery, service_list: List[ServiceId]): async def detect_attack(endpoint, context_id, service_id):
async with grpc.aio.insecure_channel(target=endpoint,
options=CHANNEL_OPTIONS) as channel:
stub = OpticalAttackDetectorServiceStub(channel)
service_id = ServiceId()
service_id.context_id.context_uuid.uuid = context_id
service_id.service_uuid.uuid = str(service_id)
# Timeout in seconds.
# Please refer gRPC Python documents for more detail. https://grpc.io/grpc/python/grpc.html
await stub.DetectAttack(service_id,
timeout=10)
print("Greeter client received:", service_id)
async def monitor_services(service_list: List[ServiceId]):
monitoring_interval = int(get_setting('MONITORING_INTERVAL', default=MONITORING_INTERVAL))
host = get_setting('OPTICALATTACKDETECTORSERVICE_SERVICE_HOST')
port = get_setting('OPTICALATTACKDETECTORSERVICE_SERVICE_PORT_GRPC')
endpoint = '{:s}:{:s}'.format(str(host), str(port))
monitoring_interval = get_setting('MONITORING_INTERVAL', default=MONITORING_INTERVAL) LOGGER.info('Starting execution of the async loop')
while not terminate.is_set(): while not terminate.is_set():
...@@ -32,26 +60,16 @@ def monitor_services(app: Celery, service_list: List[ServiceId]): ...@@ -32,26 +60,16 @@ def monitor_services(app: Celery, service_list: List[ServiceId]):
start_time = time.time() start_time = time.time()
try: tasks = []
tasks = [] for service in service_list:
aw = detect_attack(endpoint, service['context'], service['service'])
for service in service_list: tasks.append(aw)
LOGGER.debug('Scheduling service: {}'.format(service)) [await aw for aw in tasks]
tasks.append(
app.send_task('detect_attack', (service['context'], service['service'], service['kpi']))
)
for task in tasks:
LOGGER.debug('Waiting for task {}...'.format(task))
result = task.get()
LOGGER.debug('Result for task {} is {}...'.format(task, result))
except Exception as e:
LOGGER.exception(e)
end_time = time.time() end_time = time.time()
diff = end_time - start_time diff = end_time - start_time
LOGGER.info('Monitoring loop with {} services took {} seconds...'.format(len(service_list), diff)) LOGGER.info('Monitoring loop with {} services took {} seconds ({:.2f}%)... Waiting for {:.2f} seconds...'.format(len(service_list), diff, (diff / monitoring_interval) * 100, monitoring_interval - diff))
if diff / monitoring_interval > 0.9: if diff / monitoring_interval > 0.9:
LOGGER.warning('Monitoring loop is taking {} % of the desired time ({} seconds)'.format((diff / monitoring_interval) * 100, monitoring_interval)) LOGGER.warning('Monitoring loop is taking {} % of the desired time ({} seconds)'.format((diff / monitoring_interval) * 100, monitoring_interval))
...@@ -117,28 +135,17 @@ def main(): ...@@ -117,28 +135,17 @@ def main():
# Start metrics server # Start metrics server
metrics_port = get_metrics_port() metrics_port = get_metrics_port()
start_http_server(metrics_port) # start_http_server(metrics_port) # TODO: uncomment this line
LOGGER.info('Connecting with context component...') LOGGER.info('Connecting with context component...')
context_client: ContextClient = ContextClient() context_client: ContextClient = ContextClient()
context_client.connect() context_client.connect()
LOGGER.info('Connected successfully...') LOGGER.info('Connected successfully...')
LOGGER.info('Connecting with REDIS...')
REDIS_PASSWORD = get_setting('REDIS_PASSWORD')
REDIS_HOST = get_setting('CACHINGSERVICE_SERVICE_HOST')
REDIS_PORT = get_setting('CACHINGSERVICE_SERVICE_PORT')
BROKER_URL = f'redis://:{REDIS_PASSWORD}@{REDIS_HOST}:{REDIS_PORT}/0'
BACKEND_URL = f'redis://:{REDIS_PASSWORD}@{REDIS_HOST}:{REDIS_PORT}/1'
app = Celery(
'cybersecurity',
broker=BROKER_URL,
backend=BACKEND_URL
)
LOGGER.info('Connected to REDIS...')
# creating a thread-safe list to be shared among threads # creating a thread-safe list to be shared among threads
service_list = Manager().list() service_list = Manager().list()
service_list.append({'context': 'admin', "service": "1213"})
service_list.append({'context': 'admin', "service": "1456"})
context_ids: ContextIdList = context_client.ListContextIds(Empty()) context_ids: ContextIdList = context_client.ListContextIds(Empty())
...@@ -155,10 +162,10 @@ def main(): ...@@ -155,10 +162,10 @@ def main():
process_context = Process(target=get_context_updates, args=(service_list,)) process_context = Process(target=get_context_updates, args=(service_list,))
process_context.start() process_context.start()
monitor_services(app, service_list) # runs the async loop in the background
loop = asyncio.get_event_loop()
# process_security_loop = Process(target=monitor_services, args=(app, service_list)) loop.run_until_complete(monitor_services(service_list))
# process_security_loop.start() # asyncio.create_task(monitor_services(service_list))
# Wait for Ctrl+C or termination signal # Wait for Ctrl+C or termination signal
while not terminate.wait(timeout=0.1): pass while not terminate.wait(timeout=0.1): pass
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment