diff --git a/src/l3_attackmitigator/service/l3_attackmitigatorServiceServicerImpl.py b/src/l3_attackmitigator/service/l3_attackmitigatorServiceServicerImpl.py index e3e75f590fe381e2a18d48c0812de5a5bea1ec52..07b5c6db19bfd726ffd4fedf71788507414972c9 100644 --- a/src/l3_attackmitigator/service/l3_attackmitigatorServiceServicerImpl.py +++ b/src/l3_attackmitigator/service/l3_attackmitigatorServiceServicerImpl.py @@ -32,6 +32,12 @@ from common.proto.context_pb2 import ( from common.proto.context_pb2_grpc import ContextServiceStub from common.proto.service_pb2_grpc import ServiceServiceStub +from common.proto.acl_pb2 import AclForwardActionEnum, AclLogActionEnum, AclRuleTypeEnum +from common.proto.context_pb2 import ConfigActionEnum, Service, ServiceId +from common.tools.grpc.Tools import grpc_message_to_json_string +from context.client.ContextClient import ContextClient +from service.client.ServiceClient import ServiceClient + LOGGER = logging.getLogger(__name__) CONTEXT_CHANNEL = "192.168.165.78:1010" SERVICE_CHANNEL = "192.168.165.78:3030" @@ -43,6 +49,7 @@ class l3_attackmitigatorServiceServicerImpl(L3AttackmitigatorServicer): self.last_value = -1 self.last_tag = 0 + self.sequence_id = 0 def GenerateRuleValue(self, ip_o, ip_d, port_o, port_d): value = { @@ -84,6 +91,76 @@ class l3_attackmitigatorServiceServicerImpl(L3AttackmitigatorServicer): return config_rule + def configure_acl_rule( + self, + context_uuid: str, + service_uuid: str, + device_uuid: str, + endpoint_uuid: str, + src_ip: str, + dst_ip: str, + src_port: str, + dst_port: str, + ) -> None: + # Create ServiceId + service_id = ServiceId() + service_id.context_id.context_uuid.uuid = context_uuid + service_id.service_uuid.uuid = service_uuid + + # Get service form Context + context_client = ContextClient() + + try: + _service: Service = context_client.GetService(service_id) + except: + raise Exception("Service({:s}) not found".format(grpc_message_to_json_string(service_id))) + + # _service is read-only; copy it to have an updatable service message + service_request = Service() + service_request.CopyFrom(_service) + + # Add ACL ConfigRule into the service service_request + acl_config_rule = service_request.service_config.config_rules.add() + acl_config_rule.action = ConfigActionEnum.CONFIGACTION_SET + + # Set EndpointId associated to the ACLRuleSet + acl_endpoint_id = acl_config_rule.acl.endpoint_id + acl_endpoint_id.device_id.device_uuid.uuid = device_uuid + acl_endpoint_id.endpoint_uuid.uuid = endpoint_uuid + + # Set RuleSet for this ACL ConfigRule + acl_rule_set = acl_config_rule.acl.rule_set + # TODO: update the following parameters; for instance, add them as parameters of the method configure_acl_rule + acl_rule_set.name = "DROP-HTTPS" + acl_rule_set.type = AclRuleTypeEnum.ACLRULETYPE_IPV4 + acl_rule_set.description = "DROP undesired HTTPS traffic" + + # Add ACLEntry to the ACLRuleSet + acl_entry = acl_rule_set.entries.add() + acl_entry.sequence_id = self.sequence_id + acl_entry.description = "DROP-{src_ip}:{src_port}-{dst_ip}:{dst_port}".format( + src_ip=src_ip, src_port=src_port, dst_ip=dst_ip, dst_port=dst_port + ) + acl_entry.match.protocol = ( + 6 # TCP according to https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml + ) + acl_entry.match.src_address = "{}/32".format(src_ip) + acl_entry.match.dst_address = "{}/32".format(dst_ip) + acl_entry.match.src_port = src_port + acl_entry.match.dst_port = dst_port + # TODO: update the following parameters; for instance, add them as parameters of the method configure_acl_rule + acl_entry.action.forward_action = AclForwardActionEnum.ACLFORWARDINGACTION_DROP + acl_entry.action.log_action = AclLogActionEnum.ACLLOGACTION_NOLOG + + # Update the Service with the new ACL RuleSet + service_client = ServiceClient() + service_reply: ServiceId = service_client.UpdateService(service_request) + + # TODO: Log the service_reply details + + if service_reply != service_request.service_id: # pylint: disable=no-member + raise Exception("Service update failed. Wrong ServiceId was returned") + def SendOutput(self, request, context): last_value = request.confidence last_tag = request.tag @@ -103,7 +180,7 @@ class l3_attackmitigatorServiceServicerImpl(L3AttackmitigatorServicer): counter = 0 service_id = request.service_id - LOGGER.info("Service id:\n{}".format(service_id)) + LOGGER.info("ServiceId:\n{}".format(service_id)) while sentinel: try: @@ -114,19 +191,32 @@ class l3_attackmitigatorServiceServicerImpl(L3AttackmitigatorServicer): LOGGER.debug("Waiting 2 seconds", counter, e) time.sleep(2) - LOGGER.info("Service obtained from id:\n{}".format(service)) + LOGGER.info("Service obtained from ServiceId:\n{}".format(service)) + # Old version config_rule = self.GetConfigRule(ip_o, ip_d, port_o, port_d) service_config = ServiceConfig() service_config.config_rules.extend([config_rule]) service.service_config.CopyFrom(service_config) + # New version + # self.configure_acl_rule( + # context_uuid=service_id.context_id.context_uuid.uuid, + # service_uuid=service_id.service_uuid.uuid, + # device_uuid=request.device_id.device_uuid.uuid, + # endpoint_uuid=request.endpoint_id.endpoint_uuid.uuid, + # src_ip=ip_o, + # dst_ip=ip_d, + # src_port=port_o, + # dst_port=port_d, + # ) + LOGGER.info("Service with new rule:\n{}".format(service)) self.UpdateService(service) service2 = self.GetService(service_id) - LOGGER.info("Service obtained from id after updating:\n{}".format(service2)) + LOGGER.info("Service obtained from ServiceId after updating with the new rule:\n{}".format(service2)) return Empty(message=f"OK, received values: {last_tag} with confidence {last_value}.")