Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • tfs/controller
1 result
Show changes
Commits on Source (2)
......@@ -18,7 +18,7 @@ import logging
LOG_LEVEL = logging.WARNING
# gRPC settings
GRPC_SERVICE_PORT = 10002 # TODO UPM FIXME
GRPC_SERVICE_PORT = 10002
GRPC_MAX_WORKERS = 10
GRPC_GRACE_PERIOD = 60
......
......@@ -14,41 +14,29 @@
from __future__ import print_function
import logging
import grpc
import time
import json
from common.proto.l3_centralizedattackdetector_pb2 import Empty
from common.proto.l3_attackmitigator_pb2_grpc import L3AttackmitigatorServicer
from common.proto.context_pb2 import (
Service,
ServiceId,
ServiceConfig,
ServiceTypeEnum,
ServiceStatusEnum,
ServiceStatus,
Context,
ContextId,
Uuid,
Timestamp,
ConfigRule,
ConfigRule_Custom,
ConfigActionEnum,
Device,
DeviceId,
DeviceConfig,
DeviceOperationalStatusEnum,
DeviceDriverEnum,
EndPoint,
Link,
LinkId,
EndPoint,
EndPointId,
Topology,
TopologyId,
)
from common.proto.context_pb2_grpc import ContextServiceStub
from common.proto.service_pb2_grpc import ServiceServiceStub
from datetime import datetime
import grpc
import time
import json
from common.proto.acl_pb2 import AclForwardActionEnum, AclLogActionEnum, AclRuleTypeEnum
from common.proto.context_pb2 import ConfigActionEnum, Service, ServiceId
from common.tools.grpc.Tools import grpc_message_to_json_string
from context.client.ContextClient import ContextClient
from service.client.ServiceClient import ServiceClient
LOGGER = logging.getLogger(__name__)
CONTEXT_CHANNEL = "192.168.165.78:1010"
......@@ -57,9 +45,11 @@ SERVICE_CHANNEL = "192.168.165.78:3030"
class l3_attackmitigatorServiceServicerImpl(L3AttackmitigatorServicer):
def __init__(self):
LOGGER.debug("Creating Servicer...")
LOGGER.info("Creating Attack Mitigator Service")
self.last_value = -1
self.last_tag = 0
self.sequence_id = 0
def GenerateRuleValue(self, ip_o, ip_d, port_o, port_d):
value = {
......@@ -69,6 +59,7 @@ class l3_attackmitigatorServiceServicerImpl(L3AttackmitigatorServicer):
"transport:destination-port": port_d,
"forwarding-action": "DROP",
}
return value
def GenerateContextId(self, context_id):
......@@ -76,6 +67,7 @@ class l3_attackmitigatorServiceServicerImpl(L3AttackmitigatorServicer):
uuid = Uuid()
uuid.uuid = context_id
context_id_obj.context_uuid.CopyFrom(uuid)
return context_id_obj
def GenerateServiceId(self, service_id):
......@@ -86,47 +78,109 @@ class l3_attackmitigatorServiceServicerImpl(L3AttackmitigatorServicer):
context_id.context_uuid.CopyFrom(uuid)
service_id_obj.context_id.CopyFrom(context_id)
service_id_obj.service_uuid.CopyFrom(uuid)
return service_id_obj
def GetConfigRule(self, ip_o, ip_d, port_o, port_d):
config_rule = ConfigRule()
config_rule_custom = ConfigRule_Custom()
config_rule.action = ConfigActionEnum.CONFIGACTION_SET
config_rule_custom.resource_key = "test"
# config_rule_custom.resource_value = str(self.GenerateRuleValue(ip_o, ip_d, port_o, port_d))
config_rule_custom.resource_key = "acl"
config_rule_custom.resource_value = json.dumps(self.GenerateRuleValue(ip_o, ip_d, port_o, port_d))
config_rule.custom.CopyFrom(config_rule_custom)
return config_rule
def SendOutput(self, request, context):
# SEND CONFIDENCE TO MITIGATION SERVER
print("Server received mitigation values...", request.confidence, flush=True)
def configure_acl_rule(
self,
context_uuid: str,
service_uuid: str,
device_uuid: str,
endpoint_uuid: str,
src_ip: str,
dst_ip: str,
src_port: str,
dst_port: str,
) -> None:
# Create ServiceId
service_id = ServiceId()
service_id.context_id.context_uuid.uuid = context_uuid
service_id.service_uuid.uuid = service_uuid
# Get service form Context
context_client = ContextClient()
try:
_service: Service = context_client.GetService(service_id)
except:
raise Exception("Service({:s}) not found".format(grpc_message_to_json_string(service_id)))
# _service is read-only; copy it to have an updatable service message
service_request = Service()
service_request.CopyFrom(_service)
# Add ACL ConfigRule into the service service_request
acl_config_rule = service_request.service_config.config_rules.add()
acl_config_rule.action = ConfigActionEnum.CONFIGACTION_SET
# Set EndpointId associated to the ACLRuleSet
acl_endpoint_id = acl_config_rule.acl.endpoint_id
acl_endpoint_id.device_id.device_uuid.uuid = device_uuid
acl_endpoint_id.endpoint_uuid.uuid = endpoint_uuid
# Set RuleSet for this ACL ConfigRule
acl_rule_set = acl_config_rule.acl.rule_set
# TODO: update the following parameters; for instance, add them as parameters of the method configure_acl_rule
acl_rule_set.name = "DROP-HTTPS"
acl_rule_set.type = AclRuleTypeEnum.ACLRULETYPE_IPV4
acl_rule_set.description = "DROP undesired HTTPS traffic"
# Add ACLEntry to the ACLRuleSet
acl_entry = acl_rule_set.entries.add()
acl_entry.sequence_id = self.sequence_id
acl_entry.description = "DROP-{src_ip}:{src_port}-{dst_ip}:{dst_port}".format(
src_ip=src_ip, src_port=src_port, dst_ip=dst_ip, dst_port=dst_port
)
acl_entry.match.protocol = (
6 # TCP according to https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml
)
acl_entry.match.src_address = "{}/32".format(src_ip)
acl_entry.match.dst_address = "{}/32".format(dst_ip)
acl_entry.match.src_port = src_port
acl_entry.match.dst_port = dst_port
# TODO: update the following parameters; for instance, add them as parameters of the method configure_acl_rule
acl_entry.action.forward_action = AclForwardActionEnum.ACLFORWARDINGACTION_DROP
acl_entry.action.log_action = AclLogActionEnum.ACLLOGACTION_NOLOG
# Update the Service with the new ACL RuleSet
service_client = ServiceClient()
service_reply: ServiceId = service_client.UpdateService(service_request)
# TODO: Log the service_reply details
if service_reply != service_request.service_id: # pylint: disable=no-member
raise Exception("Service update failed. Wrong ServiceId was returned")
def SendOutput(self, request, context):
last_value = request.confidence
last_tag = request.tag
LOGGER.info(
"Attack Mitigator received attack mitigation information. Prediction confidence: %s, Predicted class: %s",
last_value,
last_tag,
)
ip_o = request.ip_o
ip_d = request.ip_d
port_o = request.port_o
port_d = request.port_d
# service_id = self.GenerateServiceId(request.service_id)
# service = GetService(service_id)
# context_id = self.GenerateContextId("admin")
sentinel = True
counter = 0
# service_id_list = self.ListServiceIds(context_id)
# print(hello, flush = True)
# print(hello.service_ids[0].service_uuid.uuid, flush=True)
# service_id = service_id_list.service_ids[0]
service_id = request.service_id
print("Service id: ", service_id, flush=True)
LOGGER.info("ServiceId:\n{}".format(service_id))
while sentinel:
try:
......@@ -134,24 +188,36 @@ class l3_attackmitigatorServiceServicerImpl(L3AttackmitigatorServicer):
sentinel = False
except Exception as e:
counter = counter + 1
print("Waiting 2 seconds", counter, e, flush=True)
LOGGER.debug("Waiting 2 seconds", counter, e)
time.sleep(2)
print("Service obtained from id: ", service, flush=True)
LOGGER.info("Service obtained from ServiceId:\n{}".format(service))
# Old version
config_rule = self.GetConfigRule(ip_o, ip_d, port_o, port_d)
service_config = ServiceConfig()
service_config.config_rules.extend([config_rule])
service.service_config.CopyFrom(service_config)
print("Service with new rule: ", service, flush=True)
# New version
# self.configure_acl_rule(
# context_uuid=service_id.context_id.context_uuid.uuid,
# service_uuid=service_id.service_uuid.uuid,
# device_uuid=request.device_id.device_uuid.uuid,
# endpoint_uuid=request.endpoint_id.endpoint_uuid.uuid,
# src_ip=ip_o,
# dst_ip=ip_d,
# src_port=port_o,
# dst_port=port_d,
# )
LOGGER.info("Service with new rule:\n{}".format(service))
self.UpdateService(service)
service2 = self.GetService(service_id)
print("Service obtained from id after updating: ", service2, flush=True)
LOGGER.info("Service obtained from ServiceId after updating with the new rule:\n{}".format(service2))
# RETURN OK TO THE CALLER
return Empty(message=f"OK, received values: {last_tag} with confidence {last_value}.")
def GetService(self, service_id):
......@@ -170,8 +236,7 @@ class l3_attackmitigatorServiceServicerImpl(L3AttackmitigatorServicer):
stub.UpdateService(service)
def GetMitigation(self, request, context):
# GET OR PERFORM MITIGATION STRATEGY
logging.debug("")
print("Returing mitigation strategy...")
logging.info("Returning mitigation strategy...")
k = self.last_value * 2
return Empty(message=f"Mitigation with double confidence = {k}")