# Copyright 2021-2023 H2020 TeraFlow (https://www.teraflow-h2020.eu/) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import print_function import logging import grpc import time import json from common.proto.l3_centralizedattackdetector_pb2 import Empty from common.proto.l3_attackmitigator_pb2_grpc import L3AttackmitigatorServicer from common.proto.context_pb2 import ( ServiceId, ServiceConfig, ContextId, Uuid, ConfigRule, ConfigRule_Custom, ConfigActionEnum, ) from common.proto.context_pb2_grpc import ContextServiceStub from common.proto.service_pb2_grpc import ServiceServiceStub from common.proto.acl_pb2 import AclForwardActionEnum, AclLogActionEnum, AclRuleTypeEnum from common.proto.context_pb2 import ConfigActionEnum, Service, ServiceId from common.tools.grpc.Tools import grpc_message_to_json_string from context.client.ContextClient import ContextClient from service.client.ServiceClient import ServiceClient LOGGER = logging.getLogger(__name__) CONTEXT_CHANNEL = "192.168.165.78:1010" SERVICE_CHANNEL = "192.168.165.78:3030" class l3_attackmitigatorServiceServicerImpl(L3AttackmitigatorServicer): def __init__(self): LOGGER.info("Creating Attack Mitigator Service") self.last_value = -1 self.last_tag = 0 self.sequence_id = 0 def GenerateRuleValue(self, ip_o, ip_d, port_o, port_d): value = { "ipv4:source-address": ip_o, "ipv4:destination-address": ip_d, "transport:source-port": port_o, "transport:destination-port": port_d, "forwarding-action": "DROP", } return value def GenerateContextId(self, context_id): context_id_obj = ContextId() uuid = Uuid() uuid.uuid = context_id context_id_obj.context_uuid.CopyFrom(uuid) return context_id_obj def GenerateServiceId(self, service_id): service_id_obj = ServiceId() context_id = ContextId() uuid = Uuid() uuid.uuid = service_id context_id.context_uuid.CopyFrom(uuid) service_id_obj.context_id.CopyFrom(context_id) service_id_obj.service_uuid.CopyFrom(uuid) return service_id_obj def GetConfigRule(self, ip_o, ip_d, port_o, port_d): config_rule = ConfigRule() config_rule_custom = ConfigRule_Custom() config_rule.action = ConfigActionEnum.CONFIGACTION_SET config_rule_custom.resource_key = "acl" config_rule_custom.resource_value = json.dumps(self.GenerateRuleValue(ip_o, ip_d, port_o, port_d)) config_rule.custom.CopyFrom(config_rule_custom) return config_rule def configure_acl_rule( self, context_uuid: str, service_uuid: str, device_uuid: str, endpoint_uuid: str, src_ip: str, dst_ip: str, src_port: str, dst_port: str, ) -> None: # Create ServiceId service_id = ServiceId() service_id.context_id.context_uuid.uuid = context_uuid service_id.service_uuid.uuid = service_uuid # Get service form Context context_client = ContextClient() try: _service: Service = context_client.GetService(service_id) except: raise Exception("Service({:s}) not found".format(grpc_message_to_json_string(service_id))) # _service is read-only; copy it to have an updatable service message service_request = Service() service_request.CopyFrom(_service) # Add ACL ConfigRule into the service service_request acl_config_rule = service_request.service_config.config_rules.add() acl_config_rule.action = ConfigActionEnum.CONFIGACTION_SET # Set EndpointId associated to the ACLRuleSet acl_endpoint_id = acl_config_rule.acl.endpoint_id acl_endpoint_id.device_id.device_uuid.uuid = device_uuid acl_endpoint_id.endpoint_uuid.uuid = endpoint_uuid # Set RuleSet for this ACL ConfigRule acl_rule_set = acl_config_rule.acl.rule_set # TODO: update the following parameters; for instance, add them as parameters of the method configure_acl_rule acl_rule_set.name = "DROP-HTTPS" acl_rule_set.type = AclRuleTypeEnum.ACLRULETYPE_IPV4 acl_rule_set.description = "DROP undesired HTTPS traffic" # Add ACLEntry to the ACLRuleSet acl_entry = acl_rule_set.entries.add() acl_entry.sequence_id = self.sequence_id acl_entry.description = "DROP-{src_ip}:{src_port}-{dst_ip}:{dst_port}".format( src_ip=src_ip, src_port=src_port, dst_ip=dst_ip, dst_port=dst_port ) acl_entry.match.protocol = ( 6 # TCP according to https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml ) acl_entry.match.src_address = "{}/32".format(src_ip) acl_entry.match.dst_address = "{}/32".format(dst_ip) acl_entry.match.src_port = src_port acl_entry.match.dst_port = dst_port # TODO: update the following parameters; for instance, add them as parameters of the method configure_acl_rule acl_entry.action.forward_action = AclForwardActionEnum.ACLFORWARDINGACTION_DROP acl_entry.action.log_action = AclLogActionEnum.ACLLOGACTION_NOLOG # Update the Service with the new ACL RuleSet service_client = ServiceClient() service_reply: ServiceId = service_client.UpdateService(service_request) # TODO: Log the service_reply details if service_reply != service_request.service_id: # pylint: disable=no-member raise Exception("Service update failed. Wrong ServiceId was returned") def SendOutput(self, request, context): last_value = request.confidence last_tag = request.tag LOGGER.info( "Attack Mitigator received attack mitigation information. Prediction confidence: %s, Predicted class: %s", last_value, last_tag, ) ip_o = request.ip_o ip_d = request.ip_d port_o = request.port_o port_d = request.port_d sentinel = True counter = 0 service_id = request.service_id LOGGER.info("ServiceId:\n{}".format(service_id)) while sentinel: try: service = self.GetService(service_id) sentinel = False except Exception as e: counter = counter + 1 LOGGER.debug("Waiting 2 seconds", counter, e) time.sleep(2) LOGGER.info("Service obtained from ServiceId:\n{}".format(service)) # Old version config_rule = self.GetConfigRule(ip_o, ip_d, port_o, port_d) service_config = ServiceConfig() service_config.config_rules.extend([config_rule]) service.service_config.CopyFrom(service_config) # New version # self.configure_acl_rule( # context_uuid=service_id.context_id.context_uuid.uuid, # service_uuid=service_id.service_uuid.uuid, # device_uuid=request.device_id.device_uuid.uuid, # endpoint_uuid=request.endpoint_id.endpoint_uuid.uuid, # src_ip=ip_o, # dst_ip=ip_d, # src_port=port_o, # dst_port=port_d, # ) LOGGER.info("Service with new rule:\n{}".format(service)) self.UpdateService(service) service2 = self.GetService(service_id) LOGGER.info("Service obtained from ServiceId after updating with the new rule:\n{}".format(service2)) return Empty(message=f"OK, received values: {last_tag} with confidence {last_value}.") def GetService(self, service_id): with grpc.insecure_channel(CONTEXT_CHANNEL) as channel: stub = ContextServiceStub(channel) return stub.GetService(service_id) def ListServiceIds(self, context_id): with grpc.insecure_channel(CONTEXT_CHANNEL) as channel: stub = ContextServiceStub(channel) return stub.ListServiceIds(context_id) def UpdateService(self, service): with grpc.insecure_channel(SERVICE_CHANNEL) as channel: stub = ServiceServiceStub(channel) stub.UpdateService(service) def GetMitigation(self, request, context): logging.info("Returning mitigation strategy...") k = self.last_value * 2 return Empty(message=f"Mitigation with double confidence = {k}")