Commit 3b14aafe authored by Lluis Gifre Renom's avatar Lluis Gifre Renom
Browse files

Tests - Tools - Firewall Agent:

- Update TODO.md file
- Update test commands
- Enable add/insert rules
- Enabled sorting of rules
- Factorized code for converting openconfig to NFT
parent 35a3f12a
Loading
Loading
Loading
Loading
+10 −0
Original line number Diff line number Diff line
# TODO

- To block traffic to a MicroK8s service, apply rule:
  `sudo nft add/insert rule ip filter FORWARD iifname "enp0s3" tcp dport 85 drop`

  - when applying to ingress, also apply to FORWARD
  - when applying to egress, also apply to FORWARD
  - rules (INPUT, FORWARD, OUTPUT) should be applied based on sequence_id
    - negative sequence_id => insert on top (first -1, then -2, then -3) so that order of rules is -3,-2,-1 at top of the chain
    - positive sequence_id => append on bottom (first 0, then 1, then 2) so that order of rules is 0, 1, 2 at bottom of the chain
+131 −147
Original line number Diff line number Diff line
@@ -13,22 +13,88 @@
# limitations under the License.


import ipaddress, logging
import logging
from flask import request
from flask_restful import Api, Resource, abort
from typing import Dict, List, Set, Tuple
from .nft_model.ActionEnum import ActionEnum, get_action_from_str
from .nft_model.DirectionEnum import DirectionEnum
from .nft_model.FamilyEnum import FamilyEnum
from .nft_model.NFTables import NFTables
from .nft_model.ProtocolEnum import ProtocolEnum, get_protocol_from_str
from .nft_model.Rule import Rule
from .nft_model.TableEnum import TableEnum


LOGGER = logging.getLogger(__name__)


BASE_URL_ROOT = '/restconf/data/openconfig-acl:acl'
BASE_URL_ITEM = '/restconf/data/openconfig-acl:acl/acl-sets/acl-set=<name>'

LOGGER = logging.getLogger(__name__)
CHAIN_NAME_INPUT   = 'INPUT'
CHAIN_NAME_FORWARD = 'FORWARD'
CHAIN_NAME_OUTPUT  = 'OUTPUT'

CHAINS_INPUT  = [CHAIN_NAME_INPUT, CHAIN_NAME_FORWARD]
CHAINS_OUTPUT = [CHAIN_NAME_FORWARD, CHAIN_NAME_OUTPUT]
CHAINS_ALL    = [CHAIN_NAME_INPUT, CHAIN_NAME_FORWARD, CHAIN_NAME_OUTPUT]

TYPE_ACL_RULE_SEQ_ID    = Tuple[str, int]
TYPE_IFACE_DIRECTION    = Tuple[str, DirectionEnum]
TYPE_IFACE_DIRECTIONS   = List[TYPE_IFACE_DIRECTION]
TYPE_ACL_RULE_TO_IF_DIR = Dict[TYPE_ACL_RULE_SEQ_ID, TYPE_IFACE_DIRECTIONS]

def get_family_from_acl_set_type(acl_set_type : str) -> FamilyEnum:
    return {
        'ACL_IPV4' : FamilyEnum.IPV4,
        'ACL_IPV6' : FamilyEnum.IPV6,
    }[acl_set_type]

class AclRuleToInterfaceDirection:
    def __init__(self, nft : NFTables):
        self._nft = nft
        self._acl_rule_to_iface_direction : TYPE_ACL_RULE_TO_IF_DIR = dict()

    def create_nft_chains_in_table(self, acl_set_type : str, chain_names : List[str]) -> None:
        family = get_family_from_acl_set_type(acl_set_type)
        table = self._nft.get_or_create_table(family, TableEnum.FILTER)
        for chain_name in chain_names:
            table.get_or_create_chain(chain_name)

    def add_acl_set(self, if_name : str, acl_set : Dict, direction : DirectionEnum) -> None:
        acl_set_name = acl_set['config']['set-name']
        acl_set_type = acl_set['config']['type']

        if direction == DirectionEnum.INGRESS:
            self.create_nft_chains_in_table(acl_set_type, CHAINS_INPUT)
        elif direction == DirectionEnum.EGRESS:
            self.create_nft_chains_in_table(acl_set_type, CHAINS_OUTPUT)
        else:
            self.create_nft_chains_in_table(acl_set_type, CHAINS_ALL)

        for acl_set_entry in acl_set['acl-entries']['acl-entry']:
            sequence_id = int(acl_set_entry['sequence-id'])
            key = (acl_set_name, sequence_id)
            if_dir_list = self._acl_rule_to_iface_direction.setdefault(key, list())
            if_dir_list.append((if_name, direction))

    def add_interface(self, interface : Dict) -> None:
        if_name = interface['config']['id']
        for direction in [DirectionEnum.INGRESS, DirectionEnum.EGRESS]:
            direction_value = direction.value
            acl_sets_obj = interface.get(f'{direction_value}-acl-sets', dict())
            acl_sets_lst = acl_sets_obj.get(f'{direction_value}-acl-set', list())
            for acl_set in acl_sets_lst:
                self.add_acl_set(if_name, acl_set, DirectionEnum.INGRESS)

    def add_interfaces(self, interfaces : List[Dict]) -> None:
        for interface in interfaces:
            self.add_interface(interface)

    def get_interfaces_directions(
        self, acl_set_name : str, sequence_id : int
    ) -> TYPE_IFACE_DIRECTIONS:
        return self._acl_rule_to_iface_direction.get((acl_set_name, sequence_id), [])


class ACLs(Resource):
    def get(self):
@@ -38,136 +104,57 @@ class ACLs(Resource):

    def post(self):
        payload = request.get_json(force=True)
        # RESTCONF wrapper may be 'openconfig-acl:acl-set' or 'acl-set'
        if not isinstance(payload, dict): abort(400, message='invalid payload')
        if 'openconfig-acl:acl' in payload:
        if not isinstance(payload, dict):
            abort(400, message='invalid payload')

        content = payload.get('openconfig-acl:acl')
        elif 'acl' in payload:
        if content is None:
            content = payload.get('acl')
        else:
            if content is None:
                abort(400, message='invalid payload')

        if not isinstance(content, dict): abort(400, message='invalid content')


        nft = NFTables()
        if not isinstance(content, dict):
            abort(400, message='invalid content')

        interfaces = content['interfaces']['interface']
        if not isinstance(interfaces, list): abort(400, message='invalid interfaces')
        interfaces_struct : Dict[str, Dict[DirectionEnum, Dict[str, Set[int]]]] = dict()
        acl_rule_to_iface_direction : Dict[Tuple[str, int], List[Tuple[str, DirectionEnum]]] = dict()
        for interface in interfaces:
            if_name = interface['config']['id']

            ingress_acl_sets = interface.get('ingress-acl-sets', dict()).get('ingress-acl-set', list())
            for ingress_acl_set in ingress_acl_sets:
                acl_set_name = ingress_acl_set['config']['set-name']
                acl_set_type = ingress_acl_set['config']['type']

                family = {
                    'ACL_IPV4' : FamilyEnum.IPV4,
                    'ACL_IPV6' : FamilyEnum.IPV6,
                }.get(acl_set_type)
                table = nft.get_or_create_table(family, TableEnum.FILTER)
                table.get_or_create_chain('input')

                acl_set_sequence_ids = {
                    int(acl_set_entry['sequence-id'])
                    for acl_set_entry in ingress_acl_set['acl-entries']['acl-entry']
                }

                interfaces_struct.setdefault(if_name, dict())\
                    .setdefault(DirectionEnum.INGRESS, dict())\
                    .setdefault(acl_set_name, set())\
                    .update(acl_set_sequence_ids)

                for sequence_id in acl_set_sequence_ids:
                    key = (acl_set_name, sequence_id)
                    acl_rule_to_iface_direction.setdefault(key, list()).append((if_name, DirectionEnum.INGRESS))

            egress_acl_sets = interface.get('egress-acl-sets', dict()).get('egress-acl-set', list())
            for egress_acl_set in egress_acl_sets:
                acl_set_name = egress_acl_set['config']['set-name']
                acl_set_type = egress_acl_set['config']['type']

                family = {
                    'ACL_IPV4' : FamilyEnum.IPV4,
                    'ACL_IPV6' : FamilyEnum.IPV6,
                }.get(acl_set_type)
                table = nft.get_or_create_table(family, TableEnum.FILTER)
                table.get_or_create_chain('output')
        if not isinstance(interfaces, list):
            abort(400, message='invalid interfaces')

                acl_set_sequence_ids = {
                    int(acl_set_entry['sequence-id'])
                    for acl_set_entry in egress_acl_set['acl-entries']['acl-entry']
                }

                interfaces_struct.setdefault(if_name, dict())\
                    .setdefault(DirectionEnum.EGRESS, dict())\
                    .setdefault(acl_set_name, set())\
                    .update(acl_set_sequence_ids)
        nft = NFTables()

                for sequence_id in acl_set_sequence_ids:
                    key = (acl_set_name, sequence_id)
                    acl_rule_to_iface_direction.setdefault(key, list()).append((if_name, DirectionEnum.EGRESS))
        arid = AclRuleToInterfaceDirection(nft)
        arid.add_interfaces(interfaces)

        acl_sets = content['acl-sets']['acl-set']
        if not isinstance(acl_sets, list): abort(400, message='invalid acl_sets')
        if not isinstance(acl_sets, list):
            abort(400, message='invalid acl_sets')

        for acl_set in acl_sets:
            acl_set_name = acl_set['config']['name']
            acl_set_type = acl_set['config']['type']

            family = {
                'ACL_IPV4' : FamilyEnum.IPV4,
                'ACL_IPV6' : FamilyEnum.IPV6,
            }.get(acl_set_type)
            family = get_family_from_acl_set_type(acl_set_type)
            table = TableEnum.FILTER

            for acl_entry in acl_set['acl-entries']['acl-entry']:
                sequence_id = acl_entry['config']['sequence-id']
                description = acl_entry['config']['description']
                ipv4_config = acl_entry.get('ipv4', {}).get('config', {})
                transp_config = acl_entry.get('transport', {}).get('config', {})

                interfaces_directions = acl_rule_to_iface_direction[(acl_set_name, sequence_id)]
                interfaces_directions = arid.get_interfaces_directions(
                    acl_set_name, acl_entry['config']['sequence-id']
                )
                for if_name, direction in interfaces_directions:
                    table = nft.get_or_create_table(family, TableEnum.FILTER)

                    if direction == DirectionEnum.INGRESS:
                        rule = Rule(family=family, table=TableEnum.FILTER, chain='input')
                        rule.input_if_name = if_name
                        chain = table.get_or_create_chain('input')
                        chain_list = CHAINS_INPUT
                        input_if_name, output_if_name = if_name, None
                    elif direction == DirectionEnum.EGRESS:
                        rule = Rule(family=family, table=TableEnum.FILTER, chain='output')
                        rule.output_if_name = if_name
                        chain = table.get_or_create_chain('output')
                        chain_list = CHAINS_OUTPUT
                        input_if_name, output_if_name = None, if_name
                    else:
                        raise Exception('Unsupported direction: {:s}'.format(str(direction)))

                    rule.comment       = description

                    if 'source-address' in ipv4_config:
                        rule.src_ip_addr = ipaddress.IPv4Interface(ipv4_config['source-address'])

                    if 'destination-address' in ipv4_config:
                        rule.dst_ip_addr = ipaddress.IPv4Interface(ipv4_config['destination-address'])

                    if 'protocol' in ipv4_config:
                        rule.ip_protocol = {
                            'IP_TCP'  : ProtocolEnum.TCP,
                            'IP_UDP'  : ProtocolEnum.UDP,
                            'IP_ICMP' : ProtocolEnum.ICMP,
                        }.get(ipv4_config['protocol'], None)

                    rule.src_port = transp_config.get('source-port')
                    rule.dst_port = transp_config.get('destination-port')

                    rule.action = {
                        'ACCEPT' : ActionEnum.ACCEPT,
                        'DROP'   : ActionEnum.DROP,
                        'REJECT' : ActionEnum.REJECT,
                    }.get(acl_entry['actions']['config']['forwarding-action'], None)

                    chain.rules.append(rule)
                    for chain_name in chain_list:
                        rule = Rule.from_openconfig(family, table, chain_name, acl_entry)
                        rule.input_if_name = input_if_name
                        rule.output_if_name = output_if_name
                        nft.add_rule(rule)

        entries = nft.dump()
        for entry in entries:
@@ -178,45 +165,42 @@ class ACLs(Resource):
        return {}, 201


class ACL(Resource):
    def get(self, name : str):
def load_nftables_by_rule_comment(rule_comment : str) -> NFTables:
    nft = NFTables()
    nft.load(FamilyEnum.IPV4, TableEnum.FILTER)

    tables_to_remove : Set[Tuple[FamilyEnum, TableEnum]] = set()
    for table_key, table in nft.tables.items():

        chains_to_remove : Set[str] = set()
        for chain_name, chain in table.chains.items():

            for rule in reversed(chain.rules):
                    if rule.comment == name: continue
                if rule.comment == rule_comment: continue
                chain.rules.remove(rule) # not a rule of interest

            if len(chain.rules) > 0: continue
            chains_to_remove.add(chain_name)

        for chain_name in chains_to_remove:
            table.chains.pop(chain_name)

        if len(nft.tables) > 0: continue
        tables_to_remove.add(table_key)

    for table_key in tables_to_remove:
        nft.tables.pop(table_key)

    return nft


class ACL(Resource):
    def get(self, name : str):
        nft = load_nftables_by_rule_comment(name)
        return nft.to_openconfig(), 200

    def delete(self, name : str):
        nft = NFTables()
        nft.load(FamilyEnum.IPV4, TableEnum.FILTER)
        tables_to_remove : Set[Tuple[FamilyEnum, TableEnum]] = set()
        for table_key, table in nft.tables.items():
            chains_to_remove : Set[str] = set()
            for chain_name, chain in table.chains.items():
                for rule in reversed(chain.rules):
                    if rule.comment == name: continue
                    chain.rules.remove(rule) # not a rule of interest
                if len(chain.rules) > 0: continue
                chains_to_remove.add(chain_name)
            for chain_name in chains_to_remove:
                table.chains.pop(chain_name)
            if len(nft.tables) > 0: continue
            tables_to_remove.add(table_key)
        for table_key in tables_to_remove:
            nft.tables.pop(table_key)
        nft = load_nftables_by_rule_comment(name)
        nft.execute(removal=True, verbose=True)
        return {}, 204

+3 −3
Original line number Diff line number Diff line
@@ -74,8 +74,8 @@ class Chain:
        for rule in self.rules: entries.extend(rule.dump())
        return entries

    def get_commands(self, removal : bool = False) -> List[str]:
        commands : List[Dict] = list()
    def get_commands(self, removal : bool = False) -> List[Tuple[int, str]]:
        commands : List[Tuple[int, str]] = list()
        if removal:
            # NOTE: For now, do not remove chains. We do not process all kinds of
            # chains and their removal might cause side effects on NFTables.
@@ -88,7 +88,7 @@ class Chain:
                'policy', 'accept', ';',
                '}'
            ]
            commands.append(' '.join(parts))
            commands.append(-1, ' '.join(parts))
        for rule in self.rules:
            commands.append(rule.get_command(removal=removal))
        return commands
+15 −7
Original line number Diff line number Diff line
@@ -15,11 +15,13 @@

import logging
from dataclasses import dataclass, field
import operator
from typing import Dict, List, Optional, Set, Tuple
from .DirectionEnum import DirectionEnum
from .Exceptions import UnsupportedElementException
from .FamilyEnum import FamilyEnum, get_family_from_str
from .NFTablesCommand import NFTablesCommand
from .Rule import Rule
from .Table import Table
from .TableEnum import TableEnum, get_table_from_str

@@ -38,10 +40,6 @@ class NFTables:
        entries = NFTablesCommand.list(family=family, table=table, chain=chain)
        for entry in entries: self.parse_entry(entry)

    def execute(self, removal : bool = False, verbose : bool = True) -> None:
        commands = self.get_commands(removal=removal)
        NFTablesCommand.execute(commands, verbose=verbose)

    def get_or_create_table(self, family : FamilyEnum, table : TableEnum) -> Table:
        return self.tables.setdefault((family, table), Table(family, table))

@@ -82,6 +80,11 @@ class NFTables:
        if table not in {TableEnum.FILTER}: return
        self.get_or_create_table(family, table).add_rule_by_entry(entry)

    def add_rule(self, rule : Rule) -> None:
        table = self.get_or_create_table(rule.family, rule.table)
        chain = table.get_or_create_chain(rule.chain)
        chain.rules.append(rule)

    def to_openconfig(self) -> List[Dict]:
        acl_sets : List[Dict] = list()
        interfaces_struct : Dict[str, Dict[DirectionEnum, Dict[str, Set[int]]]] = dict()
@@ -146,8 +149,13 @@ class NFTables:
        for table in self.tables.values(): entries.extend(table.dump())
        return entries

    def get_commands(self, removal : bool = False) -> List[str]:
        commands : List[Dict] = list()
    def get_commands(self, removal : bool = False) -> List[Tuple[int, str]]:
        commands : List[Tuple[int, str]] = list()
        for table in self.tables.values():
            commands.extend(table.get_commands(removal=removal))
        return commands
        # return a sorted list of commands by their priority (lower first)
        return sorted(commands, key=operator.itemgetter(0))

    def execute(self, removal : bool = False, verbose : bool = True) -> None:
        commands = self.get_commands(removal=removal)
        NFTablesCommand.execute(commands, verbose=verbose)
+4 −4
Original line number Diff line number Diff line
@@ -14,7 +14,7 @@


import json, logging, nftables
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Tuple
from .Exceptions import (
    InvalidArgumentException, MalformedOutputException, RuntimeException
)
@@ -67,12 +67,12 @@ class NFTablesCommand:
        return json_nftables['nftables']

    @staticmethod
    def execute(commands : List[str], verbose : bool = True) -> None:
    def execute(commands : List[Tuple[int, str]], verbose : bool = True) -> None:
        nft = nftables.Nftables()
        nft.set_json_output(True)
        for command in commands:
        for priority, command in commands:
            if verbose:
                LOGGER.info(f'Executing: {command}')
                LOGGER.info(f'Executing [priority={str(priority)}]: {command}')
            rc, output, error = nft.cmd(command)
            if verbose:
                LOGGER.info(f'rc={str(rc)} output={str(output)} error={str(error)}')
Loading