Commit 1c312bc1 authored by Georgios Katsikas's avatar Georgios Katsikas Committed by Georgios P. Katsikas
Browse files

chore: style changes and code review

parent ffe8dc40
Loading
Loading
Loading
Loading
+0 −123
Original line number Diff line number Diff line
import grpc, anytree, logging, threading
from typing import Any, Iterator, List, Optional, Tuple, Union
from .P4Util import P4RuntimeClient,\
    P4_ATTR_DEV_ID, P4_ATTR_DEV_NAME, P4_ATTR_DEV_VENDOR,\
    P4_ATTR_DEV_HW_VER, P4_ATTR_DEV_SW_VER, P4_ATTR_DEV_PIPECONF

try:
    from Checkers import chk_float, chk_length, chk_string, chk_type
    from _Driver import _Driver
    from AnyTreeTools import TreeNode, dump_subtree, get_subnode,\
        set_subnode_value
except ImportError:
    from common.type_checkers.Checkers import chk_float, chk_length, chk_string, chk_type
    from device.service.driver_api._Driver import _Driver
    from device.service.driver_api.AnyTreeTools import TreeNode, dump_subtree, get_subnode, set_subnode_value

LOGGER = logging.getLogger(__name__)


class P4Driver(_Driver):
    def __init__(self, address: str, port: int, **settings) -> None:
            # pylint: disable=super-init-not-called
        self.__client = None
        self.__address = address
        self.__port = int(port)
        self.__settings = settings
        if P4_ATTR_DEV_ID in self.__settings:
            self.__dev_id = self.__settings.get(P4_ATTR_DEV_ID)
        if P4_ATTR_DEV_NAME in self.__settings:
            self.__dev_name = self.__settings.get(P4_ATTR_DEV_NAME)
        if P4_ATTR_DEV_VENDOR in self.__settings:
            self.__dev_vendor = self.__settings.get(P4_ATTR_DEV_VENDOR)
        if P4_ATTR_DEV_HW_VER in self.__settings:
            self.__dev_hw_version = self.__settings.get(P4_ATTR_DEV_HW_VER)
        if P4_ATTR_DEV_SW_VER in self.__settings:
            self.__dev_sw_version = self.__settings.get(P4_ATTR_DEV_SW_VER)
        if P4_ATTR_DEV_PIPECONF in self.__settings:
            self.__dev_pipeconf = self.__settings.get(P4_ATTR_DEV_PIPECONF)

        LOGGER.info('Initializing P4 device at {}:{} with settings:'.format(
            self.__address, self.__port))

        self.__lock = threading.Lock()
        self.__started = threading.Event()
        self.__terminate = threading.Event()

        for key, value in settings.items():
            LOGGER.info('\t%8s = %s' %(key, value))

    def Connect(self) -> bool:
        LOGGER.info('Connecting to P4 device {}:{}...'.format(
            self.__address, self.__port))
        with self.__lock:
            # Skip if already connected
            if self.__started.is_set():
                return True

            # Instantiate a gRPC channel with the P4 device
            grpc_address = '{}:{}'.format(self.__address, self.__port)
            election_id = (1, 0)
            self.__client = P4RuntimeClient(self.__dev_id, grpc_address, election_id)
            LOGGER.info('\tConnected!')
            self.__started.set()

            return True

    def Disconnect(self) -> bool:
        LOGGER.info('Disconnecting from P4 device {}:{}...'.format(
            self.__address, self.__port))

        # If not started, assume it is already disconnected
        if not self.__started.is_set():
            return True

        # gRPC client must already be instantiated
        assert self.__client

        # Trigger termination of loops and processes
        self.__terminate.set()

        # Trigger connection tear down with the P4Runtime server
        self.__client.tear_down()
        self.__client = None

        LOGGER.info('\tDisonnected!')

        return True

    def GetInitialConfig(self) -> List[Tuple[str, Any]]:
        LOGGER.info('P4 GetInitialConfig()')
        return []

    def GetConfig(self, resource_keys : List[str] = [])\
            -> List[Tuple[str, Union[Any, None, Exception]]]:
        LOGGER.info('P4 GetConfig()')
        return []

    def SetConfig(self, resources : List[Tuple[str, Any]])\
            -> List[Union[bool, Exception]]:
        LOGGER.info('P4 SetConfig()')
        return []

    def DeleteConfig(self, resources : List[Tuple[str, Any]]) -> List[Union[bool, Exception]]:
        LOGGER.info('P4 DeleteConfig()')
        return []

    def GetResource(self, endpoint_uuid : str) -> Optional[str]:
        LOGGER.info('P4 GetResource()')
        return ""

    def GetState(self, blocking=False) -> Iterator[Tuple[str, Any]]:
        LOGGER.info('P4 GetState()')
        return []

    def SubscribeState(self, subscriptions : List[Tuple[str, float, float]])\
            -> List[Union[bool, Exception]]:
        LOGGER.info('P4 SubscribeState()')
        return []

    def UnsubscribeState(self, subscriptions : List[Tuple[str, float, float]])\
            -> List[Union[bool, Exception]]:
        LOGGER.info('P4 UnsubscribeState()')
        return []
+0 −181
Original line number Diff line number Diff line
from functools import wraps
import google.protobuf.text_format
from google.rpc import code_pb2
import grpc
import logging
import queue
import sys
import threading

from p4.v1 import p4runtime_pb2
from p4.v1 import p4runtime_pb2_grpc

P4_ATTR_DEV_ID = 'id'
P4_ATTR_DEV_NAME = 'name'
P4_ATTR_DEV_VENDOR = 'vendor'
P4_ATTR_DEV_HW_VER = 'hw_ver'
P4_ATTR_DEV_SW_VER = 'sw_ver'
P4_ATTR_DEV_PIPECONF = 'pipeconf'

LOGGER = logging.getLogger(__name__)


class P4RuntimeErrorFormatException(Exception):
    def __init__(self, message):
        super().__init__(message)


class P4RuntimeException(Exception):
    def __init__(self, grpc_error):
        super().__init__()
        self.grpc_error = grpc_error

    def __str__(self):
        message = "P4Runtime RPC error ({}): {}".format(
            self.grpc_error.code().name, self.grpc_error.details())
        return message


def parse_p4runtime_error(f):
    @wraps(f)
    def handle(*args, **kwargs):
        try:
            return f(*args, **kwargs)
        except grpc.RpcError as e:
            raise P4RuntimeException(e) from None
    return handle


class P4RuntimeClient:
    def __init__(self, device_id, grpc_address, election_id, role_name=None):
        self.device_id = device_id
        self.election_id = election_id
        self.role_name = role_name
        self.stream_in_q = None
        self.stream_out_q = None
        self.stream = None
        self.stream_recv_thread = None
        LOGGER.debug(
            "Connecting to device {} at {}".format(device_id, grpc_address))
        try:
            self.channel = grpc.insecure_channel(grpc_address)
        except Exception:
            LOGGER.critical("Failed to connect to P4Runtime server")
            sys.exit(1)
        self.stub = p4runtime_pb2_grpc.P4RuntimeStub(self.channel)
        self.set_up_stream()

    def set_up_stream(self):
        self.stream_out_q = queue.Queue()
        # queues for different messages
        self.stream_in_q = {
            "arbitration": queue.Queue(),
            "packet": queue.Queue(),
            "digest": queue.Queue(),
            "unknown": queue.Queue(),
        }

        def stream_req_iterator():
            while True:
                p = self.stream_out_q.get()
                if p is None:
                    break
                yield p

        def stream_recv_wrapper(stream):
            @parse_p4runtime_error
            def stream_recv():
                for p in stream:
                    if p.HasField("arbitration"):
                        self.stream_in_q["arbitration"].put(p)
                    elif p.HasField("packet"):
                        self.stream_in_q["packet"].put(p)
                    elif p.HasField("digest"):
                        self.stream_in_q["digest"].put(p)
                    else:
                        self.stream_in_q["unknown"].put(p)
            try:
                stream_recv()
            except P4RuntimeException as e:
                LOGGER.critical("StreamChannel error, closing stream")
                LOGGER.critical(e)
                for k in self.stream_in_q:
                    self.stream_in_q[k].put(None)
        self.stream = self.stub.StreamChannel(stream_req_iterator())
        self.stream_recv_thread = threading.Thread(
            target=stream_recv_wrapper, args=(self.stream,))
        self.stream_recv_thread.start()
        self.handshake()

    def handshake(self):
        req = p4runtime_pb2.StreamMessageRequest()
        arbitration = req.arbitration
        arbitration.device_id = self.device_id
        election_id = arbitration.election_id
        election_id.high = self.election_id[0]
        election_id.low = self.election_id[1]
        if self.role_name is not None:
            arbitration.role.name = self.role_name
        self.stream_out_q.put(req)

        rep = self.get_stream_packet("arbitration", timeout=2)
        if rep is None:
            LOGGER.critical("Failed to establish session with server")
            sys.exit(1)
        is_primary = (rep.arbitration.status.code == code_pb2.OK)
        LOGGER.debug("Session established, client is '{}'".format(
            'primary' if is_primary else 'backup'))
        if not is_primary:
            LOGGER.warning("You are not the primary client, you only have read access to the server")

    def get_stream_packet(self, type_, timeout=1):
        if type_ not in self.stream_in_q:
            LOGGER.critical("Unknown stream type '{}'".format(type_))
            return None
        try:
            msg = self.stream_in_q[type_].get(timeout=timeout)
            return msg
        except queue.Empty:  # timeout expired
            return None

    @parse_p4runtime_error
    def get_p4info(self):
        LOGGER.debug("Retrieving P4Info file")
        req = p4runtime_pb2.GetForwardingPipelineConfigRequest()
        req.device_id = self.device_id
        req.response_type = p4runtime_pb2.GetForwardingPipelineConfigRequest.P4INFO_AND_COOKIE
        rep = self.stub.GetForwardingPipelineConfig(req)
        return rep.config.p4info

    @parse_p4runtime_error
    def set_fwd_pipe_config(self, p4info_path, bin_path):
        LOGGER.debug("Setting forwarding pipeline config")
        req = p4runtime_pb2.SetForwardingPipelineConfigRequest()
        req.device_id = self.device_id
        if self.role_name is not None:
            req.role = self.role_name
        election_id = req.election_id
        election_id.high = self.election_id[0]
        election_id.low = self.election_id[1]
        req.action = p4runtime_pb2.SetForwardingPipelineConfigRequest.VERIFY_AND_COMMIT
        with open(p4info_path, 'r') as f1:
            with open(bin_path, 'rb') as f2:
                try:
                    google.protobuf.text_format.Merge(f1.read(), req.config.p4info)
                except google.protobuf.text_format.ParseError:
                    LOGGER.error("Error when parsing P4Info")
                    raise
                req.config.p4_device_config = f2.read()
        return self.stub.SetForwardingPipelineConfig(req)

    def tear_down(self):
        if self.stream_out_q:
            LOGGER.debug("Cleaning up stream")
            self.stream_out_q.put(None)
        if self.stream_in_q:
            for k in self.stream_in_q:
                self.stream_in_q[k].put(None)
        if self.stream_recv_thread:
            self.stream_recv_thread.join()
        self.channel.close()
        del self.channel  # avoid a race condition if channel deleted when process terminates

src/device/tests/Device_P4.py

deleted100644 → 0
+0 −42
Original line number Diff line number Diff line
from copy import deepcopy
from device.proto.context_pb2 import DeviceDriverEnum, DeviceOperationalStatusEnum
from .Tools import config_rule_set

DEVICE_P4_ID = 0
DEVICE_P4_NAME = 'device:leaf1'
DEVICE_P4_TYPE = 'p4-switch'
DEVICE_P4_ADDRESS = '127.0.0.1'
DEVICE_P4_PORT = '50101'
DEVICE_P4_DRIVERS = [DeviceDriverEnum.DEVICEDRIVER_P4]
DEVICE_P4_VENDOR = 'Open Networking Foundation'
DEVICE_P4_HW_VER = 'BMv2 simple_switch'
DEVICE_P4_SW_VER = 'Stratum'
DEVICE_P4_PIPECONF = 'org.onosproject.pipelines.fabric'
DEVICE_P4_WORKERS = 2
DEVICE_P4_GRACE_PERIOD = 60

DEVICE_P4_UUID = {'device_uuid': {'uuid': DEVICE_P4_NAME}}
DEVICE_P4 = {
    'device_id': deepcopy(DEVICE_P4_UUID),
    'device_type': DEVICE_P4_TYPE,
    'device_config': {'config_rules': []},
    'device_operational_status': DeviceOperationalStatusEnum.DEVICEOPERATIONALSTATUS_DISABLED,
    'device_drivers': DEVICE_P4_DRIVERS,
    'device_endpoints': [],
}

DEVICE_P4_CONNECT_RULES = [
    config_rule_set('_connect/address', DEVICE_P4_ADDRESS),
    config_rule_set('_connect/port', DEVICE_P4_PORT),
    config_rule_set('_connect/settings', {
        'id': int(DEVICE_P4_ID),
        'name': DEVICE_P4_NAME,
        'hw-ver': DEVICE_P4_HW_VER,
        'sw-ver': DEVICE_P4_SW_VER,
        'pipeconf': DEVICE_P4_PIPECONF
    }),
]

DEVICE_P4_CONFIG_RULES = [
    config_rule_set('key1', 'value1'),
]
+0 −50
Original line number Diff line number Diff line
import grpc, logging
from concurrent import futures
from p4.v1 import p4runtime_pb2_grpc

from .Device_P4 import(
    DEVICE_P4_ADDRESS, DEVICE_P4_PORT,
    DEVICE_P4_WORKERS, DEVICE_P4_GRACE_PERIOD)
from .MockP4RuntimeServicerImpl import MockP4RuntimeServicerImpl

LOGGER = logging.getLogger(__name__)


class MockP4RuntimeService:
    def __init__(
            self, address=DEVICE_P4_ADDRESS, port=DEVICE_P4_PORT,
            max_workers=DEVICE_P4_WORKERS,
            grace_period=DEVICE_P4_GRACE_PERIOD):
        self.address = address
        self.port = port
        self.endpoint = '{:s}:{:s}'.format(str(self.address), str(self.port))
        self.max_workers = max_workers
        self.grace_period = grace_period
        self.pool = None
        self.server = None
        self.servicer = None

    def start(self):
        LOGGER.info(
            'Starting P4Runtime service on {:s} with max_workers: {:s})'.format(
                str(self.endpoint), str(self.max_workers)))

        self.pool = futures.ThreadPoolExecutor(max_workers=self.max_workers)
        self.server = grpc.server(self.pool)

        self.servicer = MockP4RuntimeServicerImpl()
        p4runtime_pb2_grpc.add_P4RuntimeServicer_to_server(
            self.servicer, self.server)

        _ = self.server.add_insecure_port(self.endpoint)
        LOGGER.info('Listening on {:s}...'.format(str(self.endpoint)))

        self.server.start()
        LOGGER.debug('P4Runtime service started')

    def stop(self):
        LOGGER.debug(
            'Stopping P4Runtime service (grace period {:s} seconds)...'.format(
                str(self.grace_period)))
        self.server.stop(self.grace_period)
        LOGGER.debug('P4Runtime service stopped')
+0 −42
Original line number Diff line number Diff line
import queue
from google.rpc import code_pb2
from p4.v1 import p4runtime_pb2, p4runtime_pb2_grpc
from p4.config.v1 import p4info_pb2


class MockP4RuntimeServicerImpl(p4runtime_pb2_grpc.P4RuntimeServicer):
    def __init__(self):
        self.p4info = p4info_pb2.P4Info()
        self.p4runtime_api_version = "1.3.0"
        self.stored_packet_out = queue.Queue()

    def GetForwardingPipelineConfig(self, request, context):
        rep = p4runtime_pb2.GetForwardingPipelineConfigResponse()
        if self.p4info is not None:
            rep.config.p4info.CopyFrom(self.p4info)
        return rep

    def SetForwardingPipelineConfig(self, request, context):
        self.p4info.CopyFrom(request.config.p4info)
        return p4runtime_pb2.SetForwardingPipelineConfigResponse()

    def Write(self, request, context):
        return p4runtime_pb2.WriteResponse()

    def Read(self, request, context):
        yield p4runtime_pb2.ReadResponse()

    def StreamChannel(self, request_iterator, context):
        for req in request_iterator:
            if req.HasField('arbitration'):
                rep = p4runtime_pb2.StreamMessageResponse()
                rep.arbitration.CopyFrom(req.arbitration)
                rep.arbitration.status.code = code_pb2.OK
                yield rep
            elif req.HasField('packet'):
                self.stored_packet_out.put(req)

    def Capabilities(self, request, context):
        rep = p4runtime_pb2.CapabilitiesResponse()
        rep.p4runtime_api_version = self.p4runtime_api_version
        return rep