Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • tfs/controller
1 result
Show changes
Showing
with 2063 additions and 160 deletions
...@@ -12,32 +12,52 @@ ...@@ -12,32 +12,52 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import random, logging, pytest, numpy import logging
from dbscanserving.Config import GRPC_SERVICE_PORT, GRPC_MAX_WORKERS, GRPC_GRACE_PERIOD import os
import random
from unittest.mock import patch
import pytest
from common.proto.dbscanserving_pb2 import (DetectionRequest,
DetectionResponse, Sample)
from dbscanserving.client.DbscanServingClient import DbscanServingClient from dbscanserving.client.DbscanServingClient import DbscanServingClient
from dbscanserving.Config import GRPC_SERVICE_PORT
from dbscanserving.service.DbscanService import DbscanService from dbscanserving.service.DbscanService import DbscanService
from dbscanserving.proto.dbscanserving_pb2 import DetectionRequest, DetectionResponse, Sample
port = 10000 + GRPC_SERVICE_PORT # avoid privileged ports port = 10000 + GRPC_SERVICE_PORT # avoid privileged ports
LOGGER = logging.getLogger(__name__) LOGGER = logging.getLogger(__name__)
LOGGER.setLevel(logging.DEBUG) LOGGER.setLevel(logging.DEBUG)
@pytest.fixture(scope='session')
@pytest.fixture(scope="session")
def dbscanserving_service(): def dbscanserving_service():
_service = DbscanService( _service = DbscanService(port=port)
port=port, max_workers=GRPC_MAX_WORKERS, grace_period=GRPC_GRACE_PERIOD)
_service.start() _service.start()
yield _service yield _service
_service.stop() _service.stop()
@pytest.fixture(scope='session')
@pytest.fixture(scope="session")
def dbscanserving_client(): def dbscanserving_client():
_client = DbscanServingClient(address='127.0.0.1', port=port) with patch.dict(
yield _client os.environ,
{
"DBSCANSERVINGSERVICE_SERVICE_HOST": "127.0.0.1",
"DBSCANSERVINGSERVICE_SERVICE_PORT_GRPC": str(port),
},
clear=True,
):
_client = DbscanServingClient()
yield _client
_client.close() _client.close()
def test_detection_correct(dbscanserving_service, dbscanserving_client: DbscanServingClient):
def test_detection_correct(
dbscanserving_service, dbscanserving_client: DbscanServingClient
):
request: DetectionRequest = DetectionRequest() request: DetectionRequest = DetectionRequest()
request.num_samples = 310 request.num_samples = 310
...@@ -48,25 +68,28 @@ def test_detection_correct(dbscanserving_service, dbscanserving_client: DbscanSe ...@@ -48,25 +68,28 @@ def test_detection_correct(dbscanserving_service, dbscanserving_client: DbscanSe
for _ in range(200): for _ in range(200):
grpc_sample = Sample() grpc_sample = Sample()
for __ in range(100): for __ in range(100):
grpc_sample.features.append(random.uniform(0., 10.)) grpc_sample.features.append(random.uniform(0.0, 10.0))
request.samples.append(grpc_sample) request.samples.append(grpc_sample)
for _ in range(100): for _ in range(100):
grpc_sample = Sample() grpc_sample = Sample()
for __ in range(100): for __ in range(100):
grpc_sample.features.append(random.uniform(50., 60.)) grpc_sample.features.append(random.uniform(50.0, 60.0))
request.samples.append(grpc_sample) request.samples.append(grpc_sample)
for _ in range(10): for _ in range(10):
grpc_sample = Sample() grpc_sample = Sample()
for __ in range(100): for __ in range(100):
grpc_sample.features.append(random.uniform(5000., 6000.)) grpc_sample.features.append(random.uniform(5000.0, 6000.0))
request.samples.append(grpc_sample) request.samples.append(grpc_sample)
response: DetectionResponse = dbscanserving_client.Detect(request) response: DetectionResponse = dbscanserving_client.Detect(request)
assert len(response.cluster_indices) == 310 assert len(response.cluster_indices) == 310
def test_detection_incorrect(dbscanserving_service, dbscanserving_client: DbscanServingClient):
def test_detection_incorrect(
dbscanserving_service, dbscanserving_client: DbscanServingClient
):
request: DetectionRequest = DetectionRequest() request: DetectionRequest = DetectionRequest()
request.num_samples = 210 request.num_samples = 210
...@@ -77,25 +100,28 @@ def test_detection_incorrect(dbscanserving_service, dbscanserving_client: Dbscan ...@@ -77,25 +100,28 @@ def test_detection_incorrect(dbscanserving_service, dbscanserving_client: Dbscan
for _ in range(200): for _ in range(200):
grpc_sample = Sample() grpc_sample = Sample()
for __ in range(100): for __ in range(100):
grpc_sample.features.append(random.uniform(0., 10.)) grpc_sample.features.append(random.uniform(0.0, 10.0))
request.samples.append(grpc_sample) request.samples.append(grpc_sample)
for _ in range(100): for _ in range(100):
grpc_sample = Sample() grpc_sample = Sample()
for __ in range(100): for __ in range(100):
grpc_sample.features.append(random.uniform(50., 60.)) grpc_sample.features.append(random.uniform(50.0, 60.0))
request.samples.append(grpc_sample) request.samples.append(grpc_sample)
for _ in range(10): for _ in range(10):
grpc_sample = Sample() grpc_sample = Sample()
for __ in range(100): for __ in range(100):
grpc_sample.features.append(random.uniform(5000., 6000.)) grpc_sample.features.append(random.uniform(5000.0, 6000.0))
request.samples.append(grpc_sample) request.samples.append(grpc_sample)
with pytest.raises(Exception): with pytest.raises(Exception):
response: DetectionResponse = dbscanserving_client.Detect(request) _: DetectionResponse = dbscanserving_client.Detect(request)
def test_detection_clusters(dbscanserving_service, dbscanserving_client: DbscanServingClient):
def test_detection_clusters(
dbscanserving_service, dbscanserving_client: DbscanServingClient
):
request: DetectionRequest = DetectionRequest() request: DetectionRequest = DetectionRequest()
request.num_samples = 310 request.num_samples = 310
...@@ -106,19 +132,19 @@ def test_detection_clusters(dbscanserving_service, dbscanserving_client: DbscanS ...@@ -106,19 +132,19 @@ def test_detection_clusters(dbscanserving_service, dbscanserving_client: DbscanS
for _ in range(200): for _ in range(200):
grpc_sample = Sample() grpc_sample = Sample()
for __ in range(100): for __ in range(100):
grpc_sample.features.append(random.uniform(0., 10.)) grpc_sample.features.append(random.uniform(0.0, 10.0))
request.samples.append(grpc_sample) request.samples.append(grpc_sample)
for _ in range(100): for _ in range(100):
grpc_sample = Sample() grpc_sample = Sample()
for __ in range(100): for __ in range(100):
grpc_sample.features.append(random.uniform(50., 60.)) grpc_sample.features.append(random.uniform(50.0, 60.0))
request.samples.append(grpc_sample) request.samples.append(grpc_sample)
for _ in range(10): for _ in range(10):
grpc_sample = Sample() grpc_sample = Sample()
for __ in range(100): for __ in range(100):
grpc_sample.features.append(random.uniform(5000., 6000.)) grpc_sample.features.append(random.uniform(5000.0, 6000.0))
request.samples.append(grpc_sample) request.samples.append(grpc_sample)
response: DetectionResponse = dbscanserving_client.Detect(request) response: DetectionResponse = dbscanserving_client.Detect(request)
......
...@@ -66,7 +66,7 @@ def main(): ...@@ -66,7 +66,7 @@ def main():
grpc_service.start() grpc_service.start()
# Wait for Ctrl+C or termination signal # Wait for Ctrl+C or termination signal
while not terminate.wait(timeout=0.1): pass while not terminate.wait(timeout=1.0): pass
LOGGER.info('Terminating...') LOGGER.info('Terminating...')
grpc_service.stop() grpc_service.stop()
......
...@@ -81,7 +81,7 @@ class P4Driver(_Driver): ...@@ -81,7 +81,7 @@ class P4Driver(_Driver):
self.__endpoint = None self.__endpoint = None
self.__settings = settings self.__settings = settings
self.__id = None self.__id = None
self.__name = None self.__name = DRIVER_NAME
self.__vendor = P4_VAL_DEF_VENDOR self.__vendor = P4_VAL_DEF_VENDOR
self.__hw_version = P4_VAL_DEF_HW_VER self.__hw_version = P4_VAL_DEF_HW_VER
self.__sw_version = P4_VAL_DEF_SW_VER self.__sw_version = P4_VAL_DEF_SW_VER
......
...@@ -58,7 +58,7 @@ def main(): ...@@ -58,7 +58,7 @@ def main():
grpc_service.start() grpc_service.start()
# Wait for Ctrl+C or termination signal # Wait for Ctrl+C or termination signal
while not terminate.wait(timeout=0.1): pass while not terminate.wait(timeout=1.0): pass
LOGGER.info('Terminating...') LOGGER.info('Terminating...')
grpc_service.stop() grpc_service.stop()
......
...@@ -49,7 +49,7 @@ def main(): ...@@ -49,7 +49,7 @@ def main():
grpc_service.start() grpc_service.start()
# Wait for Ctrl+C or termination signal # Wait for Ctrl+C or termination signal
while not terminate.wait(timeout=0.1): pass while not terminate.wait(timeout=1.0): pass
LOGGER.info('Terminating...') LOGGER.info('Terminating...')
grpc_service.stop() grpc_service.stop()
......
...@@ -72,7 +72,7 @@ def main(): ...@@ -72,7 +72,7 @@ def main():
#remote_domain_clients.add_peer('remote-teraflow', 'remote-teraflow', interdomain_service_port_grpc) #remote_domain_clients.add_peer('remote-teraflow', 'remote-teraflow', interdomain_service_port_grpc)
# Wait for Ctrl+C or termination signal # Wait for Ctrl+C or termination signal
while not terminate.wait(timeout=0.1): pass while not terminate.wait(timeout=1.0): pass
LOGGER.info('Terminating...') LOGGER.info('Terminating...')
topology_abstractor.stop() topology_abstractor.stop()
......
...@@ -18,7 +18,7 @@ import logging ...@@ -18,7 +18,7 @@ import logging
LOG_LEVEL = logging.WARNING LOG_LEVEL = logging.WARNING
# gRPC settings # gRPC settings
GRPC_SERVICE_PORT = 10002 # TODO UPM FIXME GRPC_SERVICE_PORT = 10002
GRPC_MAX_WORKERS = 10 GRPC_MAX_WORKERS = 10
GRPC_GRACE_PERIOD = 60 GRPC_GRACE_PERIOD = 60
......
...@@ -63,6 +63,9 @@ RUN python3 -m pip install -r requirements.txt ...@@ -63,6 +63,9 @@ RUN python3 -m pip install -r requirements.txt
# Add component files into working directory # Add component files into working directory
WORKDIR /var/teraflow WORKDIR /var/teraflow
COPY src/l3_attackmitigator/. l3_attackmitigator COPY src/l3_attackmitigator/. l3_attackmitigator
COPY src/monitoring/. monitoring
COPY src/context/. context/
COPY src/service/. service/
# Start the service # Start the service
ENTRYPOINT ["python", "-m", "l3_attackmitigator.service"] ENTRYPOINT ["python", "-m", "l3_attackmitigator.service"]
...@@ -13,13 +13,18 @@ ...@@ -13,13 +13,18 @@
# limitations under the License. # limitations under the License.
import grpc, logging import grpc, logging
from common.Constants import ServiceNameEnum
from common.Settings import get_service_host, get_service_port_grpc
from common.tools.client.RetryDecorator import retry, delay_exponential from common.tools.client.RetryDecorator import retry, delay_exponential
from common.proto.l3_attackmitigator_pb2_grpc import ( from common.proto.l3_attackmitigator_pb2_grpc import (
L3AttackmitigatorStub, L3AttackmitigatorStub,
) )
from common.proto.l3_attackmitigator_pb2 import ( from common.proto.l3_attackmitigator_pb2 import (
Output, L3AttackmitigatorOutput, ACLRules
EmptyMitigator )
from common.proto.context_pb2 import (
Empty
) )
LOGGER = logging.getLogger(__name__) LOGGER = logging.getLogger(__name__)
...@@ -28,8 +33,10 @@ DELAY_FUNCTION = delay_exponential(initial=0.01, increment=2.0, maximum=5.0) ...@@ -28,8 +33,10 @@ DELAY_FUNCTION = delay_exponential(initial=0.01, increment=2.0, maximum=5.0)
RETRY_DECORATOR = retry(max_retries=MAX_RETRIES, delay_function=DELAY_FUNCTION, prepare_method_name='connect') RETRY_DECORATOR = retry(max_retries=MAX_RETRIES, delay_function=DELAY_FUNCTION, prepare_method_name='connect')
class l3_attackmitigatorClient: class l3_attackmitigatorClient:
def __init__(self, address, port): def __init__(self, host=None, port=None):
self.endpoint = "{}:{}".format(address, port) if not host: host = get_service_host(ServiceNameEnum.L3_AM)
if not port: port = get_service_port_grpc(ServiceNameEnum.L3_AM)
self.endpoint = "{}:{}".format(host, port)
LOGGER.debug("Creating channel to {}...".format(self.endpoint)) LOGGER.debug("Creating channel to {}...".format(self.endpoint))
self.channel = None self.channel = None
self.stub = None self.stub = None
...@@ -47,16 +54,23 @@ class l3_attackmitigatorClient: ...@@ -47,16 +54,23 @@ class l3_attackmitigatorClient:
self.stub = None self.stub = None
@RETRY_DECORATOR @RETRY_DECORATOR
def SendOutput(self, request: Output) -> EmptyMitigator: def PerformMitigation(self, request: L3AttackmitigatorOutput) -> Empty:
LOGGER.debug('SendOutput request: {}'.format(request)) LOGGER.debug('PerformMitigation request: {}'.format(request))
response = self.stub.SendOutput(request) response = self.stub.PerformMitigation(request)
LOGGER.debug('SendOutput result: {}'.format(response)) LOGGER.debug('PerformMitigation result: {}'.format(response))
return response return response
@RETRY_DECORATOR @RETRY_DECORATOR
def GetMitigation(self, request: EmptyMitigator) -> EmptyMitigator: def GetMitigation(self, request: Empty) -> Empty:
LOGGER.debug('GetMitigation request: {}'.format(request)) LOGGER.debug('GetMitigation request: {}'.format(request))
response = self.stub.GetMitigation(request) response = self.stub.GetMitigation(request)
LOGGER.debug('GetMitigation result: {}'.format(response)) LOGGER.debug('GetMitigation result: {}'.format(response))
return response return response
@RETRY_DECORATOR
def GetConfiguredACLRules(self, request: Empty) -> ACLRules:
LOGGER.debug('GetConfiguredACLRules request: {}'.format(request))
response = self.stub.GetConfiguredACLRules(request)
LOGGER.debug('GetConfiguredACLRules result: {}'.format(response))
return response
...@@ -52,7 +52,7 @@ def main(): ...@@ -52,7 +52,7 @@ def main():
grpc_service.start() grpc_service.start()
# Wait for Ctrl+C or termination signal # Wait for Ctrl+C or termination signal
while not terminate.wait(timeout=0.1): pass while not terminate.wait(timeout=1.0): pass
logger.info('Terminating...') logger.info('Terminating...')
grpc_service.stop() grpc_service.stop()
......
...@@ -14,41 +14,184 @@ ...@@ -14,41 +14,184 @@
from __future__ import print_function from __future__ import print_function
import logging import logging
from common.proto.l3_attackmitigator_pb2 import ( import time
EmptyMitigator
) from common.proto.l3_centralizedattackdetector_pb2 import Empty
from common.proto.l3_attackmitigator_pb2_grpc import ( from common.proto.l3_attackmitigator_pb2_grpc import L3AttackmitigatorServicer
L3AttackmitigatorServicer, from common.proto.l3_attackmitigator_pb2 import ACLRules
from common.proto.context_pb2 import (
ServiceId,
ConfigActionEnum,
) )
from common.proto.acl_pb2 import AclForwardActionEnum, AclLogActionEnum, AclRuleTypeEnum
from common.proto.context_pb2 import ConfigActionEnum, Service, ServiceId, ConfigRule
from common.tools.grpc.Tools import grpc_message_to_json_string
from context.client.ContextClient import ContextClient
from service.client.ServiceClient import ServiceClient
from common.method_wrappers.Decorator import MetricsPool, safe_and_metered_rpc_method
LOGGER = logging.getLogger(__name__) LOGGER = logging.getLogger(__name__)
class l3_attackmitigatorServiceServicerImpl(L3AttackmitigatorServicer): METRICS_POOL = MetricsPool('l3_attackmitigator', 'RPC')
class l3_attackmitigatorServiceServicerImpl(L3AttackmitigatorServicer):
def __init__(self): def __init__(self):
LOGGER.debug("Creating Servicer...") LOGGER.info("Creating Attack Mitigator Service")
self.last_value = -1 self.last_value = -1
self.last_tag = 0 self.last_tag = 0
self.sequence_id = 0
def SendOutput(self, request, context):
# SEND CONFIDENCE TO MITIGATION SERVER self.context_client = ContextClient()
logging.debug("") self.service_client = ServiceClient()
print("Server received mitigation values...", request.confidence) self.configured_acl_config_rules = []
def configure_acl_rule(
self,
context_uuid: str,
service_uuid: str,
device_uuid: str,
endpoint_uuid: str,
src_ip: str,
dst_ip: str,
src_port: str,
dst_port: str,
) -> None:
# Create ServiceId
service_id = ServiceId()
service_id.context_id.context_uuid.uuid = context_uuid
service_id.service_uuid.uuid = service_uuid
# Get service form Context
# context_client = ContextClient()
try:
_service: Service = self.context_client.GetService(service_id)
except:
raise Exception("Service({:s}) not found".format(grpc_message_to_json_string(service_id)))
# _service is read-only; copy it to have an updatable service message
service_request = Service()
service_request.CopyFrom(_service)
# Add ACL ConfigRule into the service service_request
acl_config_rule = service_request.service_config.config_rules.add()
acl_config_rule.action = ConfigActionEnum.CONFIGACTION_SET
# Set EndpointId associated to the ACLRuleSet
acl_endpoint_id = acl_config_rule.acl.endpoint_id
acl_endpoint_id.device_id.device_uuid.uuid = device_uuid
acl_endpoint_id.endpoint_uuid.uuid = endpoint_uuid
# Set RuleSet for this ACL ConfigRule
acl_rule_set = acl_config_rule.acl.rule_set
# TODO: update the following parameters; for instance, add them as parameters of the method configure_acl_rule
# acl_rule_set.name = "DROP-HTTPS"
acl_rule_set.name = "DROP-TCP"
acl_rule_set.type = AclRuleTypeEnum.ACLRULETYPE_IPV4
# acl_rule_set.description = "DROP undesired HTTPS traffic"
acl_rule_set.description = "DROP undesired TCP traffic"
# Add ACLEntry to the ACLRuleSet
acl_entry = acl_rule_set.entries.add()
acl_entry.sequence_id = self.sequence_id
acl_entry.description = "DROP-{src_ip}:{src_port}-{dst_ip}:{dst_port}".format(
src_ip=src_ip, src_port=src_port, dst_ip=dst_ip, dst_port=dst_port
)
acl_entry.match.protocol = (
6 # TCP according to https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml
)
acl_entry.match.src_address = "{}/32".format(src_ip)
acl_entry.match.dst_address = "{}/32".format(dst_ip)
acl_entry.match.src_port = int(src_port)
acl_entry.match.dst_port = int(dst_port)
# TODO: update the following parameters; for instance, add them as parameters of the method configure_acl_rule
acl_entry.action.forward_action = AclForwardActionEnum.ACLFORWARDINGACTION_DROP
acl_entry.action.log_action = AclLogActionEnum.ACLLOGACTION_NOLOG
LOGGER.info("ACL Rule Set: %s", acl_rule_set)
LOGGER.info("ACL Config Rule: %s", acl_config_rule)
# Add the ACLRuleSet to the list of configured ACLRuleSets
self.configured_acl_config_rules.append(acl_config_rule)
# Update the Service with the new ACL RuleSet
# service_client = ServiceClient()
service_reply: ServiceId = self.service_client.UpdateService(service_request)
# TODO: Log the service_reply details
if service_reply != service_request.service_id: # pylint: disable=no-member
raise Exception("Service update failed. Wrong ServiceId was returned")
@safe_and_metered_rpc_method(METRICS_POOL, LOGGER)
def PerformMitigation(self, request, context):
last_value = request.confidence last_value = request.confidence
last_tag = request.tag last_tag = request.tag
# RETURN OK TO THE CALLER
return EmptyMitigator( LOGGER.info(
message=f"OK, received values: {last_tag} with confidence {last_value}." "Attack Mitigator received attack mitigation information. Prediction confidence: %s, Predicted class: %s",
last_value,
last_tag,
) )
def GetMitigation(self, request, context): ip_o = request.ip_o
# GET OR PERFORM MITIGATION STRATEGY ip_d = request.ip_d
logging.debug("") port_o = request.port_o
print("Returing mitigation strategy...") port_d = request.port_d
k = self.last_value * 2
return EmptyMitigator( sentinel = True
message=f"Mitigation with double confidence = {k}" counter = 0
service_id = request.service_id
LOGGER.info("Service Id.:\n{}".format(service_id))
LOGGER.info("Retrieving service from Context")
while sentinel:
try:
service = self.context_client.GetService(service_id)
sentinel = False
except Exception as e:
counter = counter + 1
LOGGER.debug("Waiting 2 seconds", counter, e)
time.sleep(2)
LOGGER.info(f"Service with Service Id.: {service_id}\n{service}")
LOGGER.info("Adding new rule to the service to block the attack")
self.configure_acl_rule(
context_uuid=service_id.context_id.context_uuid.uuid,
service_uuid=service_id.service_uuid.uuid,
device_uuid=request.endpoint_id.device_id.device_uuid.uuid,
endpoint_uuid=request.endpoint_id.endpoint_uuid.uuid,
src_ip=ip_o,
dst_ip=ip_d,
src_port=port_o,
dst_port=port_d,
) )
LOGGER.info("Service with new rule:\n{}".format(service))
LOGGER.info("Updating service with the new rule")
self.service_client.UpdateService(service)
LOGGER.info(
"Service obtained from Context after updating with the new rule:\n{}".format(
self.context_client.GetService(service_id)
)
)
return Empty(message=f"OK, received values: {last_tag} with confidence {last_value}.")
@safe_and_metered_rpc_method(METRICS_POOL, LOGGER)
def GetConfiguredACLRules(self, request, context):
acl_rules = ACLRules()
for acl_config_rule in self.configured_acl_config_rules:
acl_rules.acl_rules.append(acl_config_rule)
return acl_rules
# Copyright 2022-2023 ETSI TeraFlowSDN - TFS OSG (https://tfs.etsi.org/)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import logging
from common.proto.l3_centralizedattackdetector_pb2 import (
Empty
)
from common.proto.l3_attackmitigator_pb2_grpc import (
L3AttackmitigatorServicer,
)
from common.proto.context_pb2 import (
Service, ServiceId, ServiceConfig, ServiceTypeEnum, ServiceStatusEnum, ServiceStatus, Context, ContextId, Uuid,
Timestamp, ConfigRule, ConfigRule_Custom, ConfigActionEnum, Device, DeviceId, DeviceConfig,
DeviceOperationalStatusEnum, DeviceDriverEnum, EndPoint, Link, LinkId, EndPoint, EndPointId, Topology, TopologyId
)
from common.proto.context_pb2_grpc import (
ContextServiceStub
)
from common.proto.service_pb2_grpc import (
ServiceServiceStub
)
from datetime import datetime
import grpc
LOGGER = logging.getLogger(__name__)
CONTEXT_CHANNEL = "192.168.165.78:1010"
SERVICE_CHANNEL = "192.168.165.78:3030"
class l3_attackmitigatorServiceServicerImpl(L3AttackmitigatorServicer):
def GetNewService(self, service_id):
service = Service()
service_id_obj = self.GenerateServiceId(service_id)
service.service_id.CopyFrom(service_id_obj)
service.service_type = ServiceTypeEnum.SERVICETYPE_L3NM
service_status = ServiceStatus()
service_status.service_status = ServiceStatusEnum.SERVICESTATUS_ACTIVE
service.service_status.CopyFrom(service_status)
timestamp = Timestamp()
timestamp.timestamp = datetime.timestamp(datetime.now())
service.timestamp.CopyFrom(timestamp)
return service
def GetNewContext(self, service_id):
context = Context()
context_id = ContextId()
uuid = Uuid()
uuid.uuid = service_id
context_id.context_uuid.CopyFrom(uuid)
context.context_id.CopyFrom(context_id)
return context
def GetNewDevice(self, service_id):
device = Device()
device_id = DeviceId()
uuid = Uuid()
uuid.uuid = service_id
device_id.device_uuid.CopyFrom(uuid)
device.device_type="test"
device.device_id.CopyFrom(device_id)
device.device_operational_status = DeviceOperationalStatusEnum.DEVICEOPERATIONALSTATUS_ENABLED
return device
def GetNewLink(self, service_id):
link = Link()
link_id = LinkId()
uuid = Uuid()
uuid.uuid = service_id
link_id.link_uuid.CopyFrom(uuid)
link.link_id.CopyFrom(link_id)
return link
def GetNewTopology(self,context_id, device_id, link_id):
topology = Topology()
topology_id = TopologyId()
topology_id.context_id.CopyFrom(context_id)
uuid = Uuid()
uuid.uuid = "test_crypto"
topology_id.topology_uuid.CopyFrom(uuid)
topology.topology_id.CopyFrom(topology_id)
topology.device_ids.extend([device_id])
topology.link_ids.extend([link_id])
return topology
def GetNewEndpoint(self, topology_id, device_id, uuid_name):
endpoint = EndPoint()
endpoint_id = EndPointId()
endpoint_id.topology_id.CopyFrom(topology_id)
endpoint_id.device_id.CopyFrom(device_id)
uuid = Uuid()
uuid.uuid = uuid_name
endpoint_id.endpoint_uuid.CopyFrom(uuid)
endpoint.endpoint_id.CopyFrom(endpoint_id)
endpoint.endpoint_type = "test"
return endpoint
def __init__(self):
LOGGER.debug("Creating Servicer...")
self.last_value = -1
self.last_tag = 0
"""
context = self.GetNewContext("test_crypto")
print(context, flush=True)
print(self.SetContext(context))
service = self.GetNewService("test_crypto")
print("This is the new service", self.CreateService(service), flush = True)
ip_o = "127.0.0.1"
ip_d = "127.0.0.2"
port_o = "123"
port_d = "124"
service_id = self.GenerateServiceId("test_crypto")
config_rule = self.GetConfigRule(ip_o, ip_d, port_o, port_d)
service = self.GetService(service_id)
print("Service obtained from id", service, flush=True)
config_rule = self.GetConfigRule(ip_o, ip_d, port_o, port_d)
#service_config = service.service_config
#service_config.append(config_rule)
service_config = ServiceConfig()
service_config.config_rules.extend([config_rule])
service.service_config.CopyFrom(service_config)
device = self.GetNewDevice("test_crypto")
print("New device", device, flush=True)
device_id = self.SetDevice(device)
link = self.GetNewLink("test_crypto")
print("New link", link, flush=True)
link_id = self.SetLink(link)
topology = self.GetNewTopology(context.context_id, device.device_id, link.link_id)
print("New topology", topology, flush=True)
topology_id = self.SetTopology(topology)
endpoint = self.GetNewEndpoint(topology.topology_id, device.device_id, "test_crypto")
print("New endpoint", endpoint, flush=True)
link.link_endpoint_ids.extend([endpoint.endpoint_id])
self.SetLink(link)
print("Service with new rule", service, flush=True)
self.UpdateService(service)
service2 = self.GetService(service_id)
print("Service obtained from id after updating", service2, flush=True)
"""
def GenerateRuleValue(self, ip_o, ip_d, port_o, port_d):
value = {
'ipv4:source-address': ip_o,
'ipv4:destination-address': ip_d,
'transport:source-port': port_o,
'transport:destination-port': port_d,
'forwarding-action': 'DROP',
}
return value
def GetConfigRule(self, ip_o, ip_d, port_o, port_d):
config_rule = ConfigRule()
config_rule_custom = ConfigRule_Custom()
config_rule.action = ConfigActionEnum.CONFIGACTION_SET
config_rule_custom.resource_key = 'test'
config_rule_custom.resource_value = str(self.GenerateRuleValue(ip_o, ip_d, port_o, port_d))
config_rule.custom.CopyFrom(config_rule_custom)
return config_rule
def GenerateServiceId(self, service_id):
service_id_obj = ServiceId()
context_id = ContextId()
uuid = Uuid()
uuid.uuid = service_id
context_id.context_uuid.CopyFrom(uuid)
service_id_obj.context_id.CopyFrom(context_id)
service_id_obj.service_uuid.CopyFrom(uuid)
return service_id_obj
def SendOutput(self, request, context):
# SEND CONFIDENCE TO MITIGATION SERVER
print("Server received mitigation values...", request.confidence, flush=True)
last_value = request.confidence
last_tag = request.tag
ip_o = request.ip_o
ip_d = request.ip_d
port_o = request.port_o
port_d = request.port_d
service_id = self.GenerateServiceId(request.service_id)
config_rule = self.GetConfigRule(ip_o, ip_d, port_o, port_d)
service = GetService(service_id)
print(service)
#service.config_rules.append(config_rule)
#UpdateService(service)
# RETURN OK TO THE CALLER
return Empty(
message=f"OK, received values: {last_tag} with confidence {last_value}."
)
def SetDevice(self, device):
with grpc.insecure_channel(CONTEXT_CHANNEL) as channel:
stub = ContextServiceStub(channel)
return stub.SetDevice(device)
def SetLink(self, link):
with grpc.insecure_channel(CONTEXT_CHANNEL) as channel:
stub = ContextServiceStub(channel)
return stub.SetLink(link)
def SetTopology(self, link):
with grpc.insecure_channel(CONTEXT_CHANNEL) as channel:
stub = ContextServiceStub(channel)
return stub.SetTopology(link)
def GetService(self, service_id):
with grpc.insecure_channel(CONTEXT_CHANNEL) as channel:
stub = ContextServiceStub(channel)
return stub.GetService(service_id)
def SetContext(self, context):
with grpc.insecure_channel(CONTEXT_CHANNEL) as channel:
stub = ContextServiceStub(channel)
return stub.SetContext(context)
def UpdateService(self, service):
with grpc.insecure_channel(SERVICE_CHANNEL) as channel:
stub = ServiceServiceStub(channel)
stub.UpdateService(service)
def CreateService(self, service):
with grpc.insecure_channel(SERVICE_CHANNEL) as channel:
stub = ServiceServiceStub(channel)
stub.CreateService(service)
def GetMitigation(self, request, context):
# GET OR PERFORM MITIGATION STRATEGY
logging.debug("")
print("Returing mitigation strategy...")
k = self.last_value * 2
return Empty(
message=f"Mitigation with double confidence = {k}"
)
...@@ -52,7 +52,7 @@ unit test l3_centralizedattackdetector: ...@@ -52,7 +52,7 @@ unit test l3_centralizedattackdetector:
- if docker container ls | grep $IMAGE_NAME; then docker rm -f $IMAGE_NAME; else echo "$IMAGE_NAME image is not in the system"; fi - if docker container ls | grep $IMAGE_NAME; then docker rm -f $IMAGE_NAME; else echo "$IMAGE_NAME image is not in the system"; fi
script: script:
- docker pull "$CI_REGISTRY_IMAGE/$IMAGE_NAME:$IMAGE_TAG" - docker pull "$CI_REGISTRY_IMAGE/$IMAGE_NAME:$IMAGE_TAG"
- docker run --name $IMAGE_NAME -d -p 10001:10001 -v "$PWD/src/$IMAGE_NAME/tests:/opt/results" --network=teraflowbridge $CI_REGISTRY_IMAGE/$IMAGE_NAME:$IMAGE_TAG - docker run --name $IMAGE_NAME -d -p 10001:10001 --env CAD_CLASSIFICATION_THRESHOLD=0.5 -v "$PWD/src/$IMAGE_NAME/tests:/opt/results" --network=teraflowbridge $CI_REGISTRY_IMAGE/$IMAGE_NAME:$IMAGE_TAG
- sleep 5 - sleep 5
- docker ps -a - docker ps -a
- docker logs $IMAGE_NAME - docker logs $IMAGE_NAME
......
...@@ -44,6 +44,10 @@ WORKDIR /var/teraflow/common ...@@ -44,6 +44,10 @@ WORKDIR /var/teraflow/common
COPY src/common/. ./ COPY src/common/. ./
RUN rm -rf proto RUN rm -rf proto
RUN mkdir -p /var/teraflow/l3_attackmitigator
WORKDIR /var/teraflow/l3_attackmitigator
COPY src/l3_attackmitigator/. ./
# Create proto sub-folder, copy .proto files, and generate Python code # Create proto sub-folder, copy .proto files, and generate Python code
RUN mkdir -p /var/teraflow/common/proto RUN mkdir -p /var/teraflow/common/proto
WORKDIR /var/teraflow/common/proto WORKDIR /var/teraflow/common/proto
...@@ -63,6 +67,7 @@ RUN python3 -m pip install -r requirements.txt ...@@ -63,6 +67,7 @@ RUN python3 -m pip install -r requirements.txt
# Add component files into working directory # Add component files into working directory
WORKDIR /var/teraflow WORKDIR /var/teraflow
COPY src/l3_centralizedattackdetector/. l3_centralizedattackdetector COPY src/l3_centralizedattackdetector/. l3_centralizedattackdetector
COPY src/monitoring/. monitoring
# Start the service # Start the service
ENTRYPOINT ["python", "-m", "l3_centralizedattackdetector.service"] ENTRYPOINT ["python", "-m", "l3_centralizedattackdetector.service"]
...@@ -18,7 +18,10 @@ from common.proto.l3_centralizedattackdetector_pb2_grpc import ( ...@@ -18,7 +18,10 @@ from common.proto.l3_centralizedattackdetector_pb2_grpc import (
L3CentralizedattackdetectorStub, L3CentralizedattackdetectorStub,
) )
from common.proto.l3_centralizedattackdetector_pb2 import ( from common.proto.l3_centralizedattackdetector_pb2 import (
AutoFeatures,
Empty, Empty,
L3CentralizedattackdetectorBatchInput,
L3CentralizedattackdetectorMetrics,
ModelInput, ModelInput,
ModelOutput ModelOutput
) )
...@@ -48,17 +51,24 @@ class l3_centralizedattackdetectorClient: ...@@ -48,17 +51,24 @@ class l3_centralizedattackdetectorClient:
self.stub = None self.stub = None
@RETRY_DECORATOR @RETRY_DECORATOR
def SendInput(self, request: ModelInput) -> Empty: def AnalyzeConnectionStatistics(self, request: L3CentralizedattackdetectorMetrics) -> Empty:
LOGGER.debug('SendInput request: {}'.format(request)) LOGGER.debug('AnalyzeConnectionStatistics request: {}'.format(request))
response = self.stub.SendInput(request) response = self.stub.AnalyzeConnectionStatistics(request)
LOGGER.debug('SendInput result: {}'.format(response)) LOGGER.debug('AnalyzeConnectionStatistics result: {}'.format(response))
return response return response
@RETRY_DECORATOR @RETRY_DECORATOR
def GetOutput(self, request: Empty) -> ModelOutput: def AnalyzeBatchConnectionStatistics(self, request: L3CentralizedattackdetectorBatchInput) -> Empty:
LOGGER.debug('GetOutput request: {}'.format(request)) LOGGER.debug('AnalyzeBatchConnectionStatistics request: {}'.format(request))
response = self.stub.GetOutput(request) response = self.stub.GetOutput(request)
LOGGER.debug('GetOutput result: {}'.format(response)) LOGGER.debug('AnalyzeBatchConnectionStatistics result: {}'.format(response))
return response
@RETRY_DECORATOR
def GetFeaturesIds(self, request: Empty) -> AutoFeatures:
LOGGER.debug('GetFeaturesIds request: {}'.format(request))
response = self.stub.GetOutput(request)
LOGGER.debug('GetFeaturesIds result: {}'.format(response))
return response return response
...@@ -53,7 +53,7 @@ def main(): ...@@ -53,7 +53,7 @@ def main():
grpc_service.start() grpc_service.start()
# Wait for Ctrl+C or termination signal # Wait for Ctrl+C or termination signal
while not terminate.wait(timeout=0.1): pass while not terminate.wait(timeout=1.0): pass
logger.info('Terminating...') logger.info('Terminating...')
grpc_service.stop() grpc_service.stop()
......
...@@ -85,7 +85,6 @@ class l3_centralizedattackdetectorService: ...@@ -85,7 +85,6 @@ class l3_centralizedattackdetectorService:
) # pylint: disable=maybe-no-member ) # pylint: disable=maybe-no-member
LOGGER.debug("Service started") LOGGER.debug("Service started")
#self.l3_centralizedattackdetector_servicer.setup_l3_centralizedattackdetector()
def stop(self): def stop(self):
LOGGER.debug( LOGGER.debug(
......