Commit 96345310 authored by Georgios Katsikas's avatar Georgios Katsikas Committed by Georgios P. Katsikas
Browse files

feat: basic P4 device driver

parent 94382769
Loading
Loading
Loading
Loading
+123 −0
Original line number Diff line number Diff line
import grpc, anytree, logging, threading
from typing import Any, Iterator, List, Optional, Tuple, Union
from .P4Util import P4RuntimeClient,\
    P4_ATTR_DEV_ID, P4_ATTR_DEV_NAME, P4_ATTR_DEV_VENDOR,\
    P4_ATTR_DEV_HW_VER, P4_ATTR_DEV_SW_VER, P4_ATTR_DEV_PIPECONF

try:
    from Checkers import chk_float, chk_length, chk_string, chk_type
    from _Driver import _Driver
    from AnyTreeTools import TreeNode, dump_subtree, get_subnode,\
        set_subnode_value
except ImportError:
    from common.type_checkers.Checkers import chk_float, chk_length, chk_string, chk_type
    from device.service.driver_api._Driver import _Driver
    from device.service.driver_api.AnyTreeTools import TreeNode, dump_subtree, get_subnode, set_subnode_value

LOGGER = logging.getLogger(__name__)


class P4Driver(_Driver):
    def __init__(self, address: str, port: int, **settings) -> None:
            # pylint: disable=super-init-not-called
        self.__client = None
        self.__address = address
        self.__port = int(port)
        self.__settings = settings
        if P4_ATTR_DEV_ID in self.__settings:
            self.__dev_id = self.__settings.get(P4_ATTR_DEV_ID)
        if P4_ATTR_DEV_NAME in self.__settings:
            self.__dev_name = self.__settings.get(P4_ATTR_DEV_NAME)
        if P4_ATTR_DEV_VENDOR in self.__settings:
            self.__dev_vendor = self.__settings.get(P4_ATTR_DEV_VENDOR)
        if P4_ATTR_DEV_HW_VER in self.__settings:
            self.__dev_hw_version = self.__settings.get(P4_ATTR_DEV_HW_VER)
        if P4_ATTR_DEV_SW_VER in self.__settings:
            self.__dev_sw_version = self.__settings.get(P4_ATTR_DEV_SW_VER)
        if P4_ATTR_DEV_PIPECONF in self.__settings:
            self.__dev_pipeconf = self.__settings.get(P4_ATTR_DEV_PIPECONF)

        LOGGER.info('Initializing P4 device at {}:{} with settings:'.format(
            self.__address, self.__port))

        self.__lock = threading.Lock()
        self.__started = threading.Event()
        self.__terminate = threading.Event()

        for key, value in settings.items():
            LOGGER.info('\t%8s = %s' %(key, value))

    def Connect(self) -> bool:
        LOGGER.info('Connecting to P4 device {}:{}...'.format(
            self.__address, self.__port))
        with self.__lock:
            # Skip if already connected
            if self.__started.is_set():
                return True

            # Instantiate a gRPC channel with the P4 device
            grpc_address = '{}:{}'.format(self.__address, self.__port)
            election_id = (1, 0)
            self.__client = P4RuntimeClient(self.__dev_id, grpc_address, election_id)
            LOGGER.info('\tConnected!')
            self.__started.set()

            return True

    def Disconnect(self) -> bool:
        LOGGER.info('Disconnecting from P4 device {}:{}...'.format(
            self.__address, self.__port))

        # If not started, assume it is already disconnected
        if not self.__started.is_set():
            return True

        # gRPC client must already be instantiated
        assert self.__client

        # Trigger termination of loops and processes
        self.__terminate.set()

        # Trigger connection tear down with the P4Runtime server
        self.__client.tear_down()
        self.__client = None

        LOGGER.info('\tDisonnected!')

        return True

    def GetInitialConfig(self) -> List[Tuple[str, Any]]:
        LOGGER.info('P4 GetInitialConfig()')
        return []

    def GetConfig(self, resource_keys : List[str] = [])\
            -> List[Tuple[str, Union[Any, None, Exception]]]:
        LOGGER.info('P4 GetConfig()')
        return []

    def SetConfig(self, resources : List[Tuple[str, Any]])\
            -> List[Union[bool, Exception]]:
        LOGGER.info('P4 SetConfig()')
        return []

    def DeleteConfig(self, resources : List[Tuple[str, Any]]) -> List[Union[bool, Exception]]:
        LOGGER.info('P4 DeleteConfig()')
        return []

    def GetResource(self, endpoint_uuid : str) -> Optional[str]:
        LOGGER.info('P4 GetResource()')
        return ""

    def GetState(self, blocking=False) -> Iterator[Tuple[str, Any]]:
        LOGGER.info('P4 GetState()')
        return []

    def SubscribeState(self, subscriptions : List[Tuple[str, float, float]])\
            -> List[Union[bool, Exception]]:
        LOGGER.info('P4 SubscribeState()')
        return []

    def UnsubscribeState(self, subscriptions : List[Tuple[str, float, float]])\
            -> List[Union[bool, Exception]]:
        LOGGER.info('P4 UnsubscribeState()')
        return []
+181 −0
Original line number Diff line number Diff line
from functools import wraps
import google.protobuf.text_format
from google.rpc import code_pb2
import grpc
import logging
import queue
import sys
import threading

from p4.v1 import p4runtime_pb2
from p4.v1 import p4runtime_pb2_grpc

P4_ATTR_DEV_ID = 'id'
P4_ATTR_DEV_NAME = 'name'
P4_ATTR_DEV_VENDOR = 'vendor'
P4_ATTR_DEV_HW_VER = 'hw_ver'
P4_ATTR_DEV_SW_VER = 'sw_ver'
P4_ATTR_DEV_PIPECONF = 'pipeconf'

LOGGER = logging.getLogger(__name__)


class P4RuntimeErrorFormatException(Exception):
    def __init__(self, message):
        super().__init__(message)


class P4RuntimeException(Exception):
    def __init__(self, grpc_error):
        super().__init__()
        self.grpc_error = grpc_error

    def __str__(self):
        message = "P4Runtime RPC error ({}): {}".format(
            self.grpc_error.code().name, self.grpc_error.details())
        return message


def parse_p4runtime_error(f):
    @wraps(f)
    def handle(*args, **kwargs):
        try:
            return f(*args, **kwargs)
        except grpc.RpcError as e:
            raise P4RuntimeException(e) from None
    return handle


class P4RuntimeClient:
    def __init__(self, device_id, grpc_address, election_id, role_name=None):
        self.device_id = device_id
        self.election_id = election_id
        self.role_name = role_name
        self.stream_in_q = None
        self.stream_out_q = None
        self.stream = None
        self.stream_recv_thread = None
        LOGGER.debug(
            "Connecting to device {} at {}".format(device_id, grpc_address))
        try:
            self.channel = grpc.insecure_channel(grpc_address)
        except Exception:
            LOGGER.critical("Failed to connect to P4Runtime server")
            sys.exit(1)
        self.stub = p4runtime_pb2_grpc.P4RuntimeStub(self.channel)
        self.set_up_stream()

    def set_up_stream(self):
        self.stream_out_q = queue.Queue()
        # queues for different messages
        self.stream_in_q = {
            "arbitration": queue.Queue(),
            "packet": queue.Queue(),
            "digest": queue.Queue(),
            "unknown": queue.Queue(),
        }

        def stream_req_iterator():
            while True:
                p = self.stream_out_q.get()
                if p is None:
                    break
                yield p

        def stream_recv_wrapper(stream):
            @parse_p4runtime_error
            def stream_recv():
                for p in stream:
                    if p.HasField("arbitration"):
                        self.stream_in_q["arbitration"].put(p)
                    elif p.HasField("packet"):
                        self.stream_in_q["packet"].put(p)
                    elif p.HasField("digest"):
                        self.stream_in_q["digest"].put(p)
                    else:
                        self.stream_in_q["unknown"].put(p)
            try:
                stream_recv()
            except P4RuntimeException as e:
                LOGGER.critical("StreamChannel error, closing stream")
                LOGGER.critical(e)
                for k in self.stream_in_q:
                    self.stream_in_q[k].put(None)
        self.stream = self.stub.StreamChannel(stream_req_iterator())
        self.stream_recv_thread = threading.Thread(
            target=stream_recv_wrapper, args=(self.stream,))
        self.stream_recv_thread.start()
        self.handshake()

    def handshake(self):
        req = p4runtime_pb2.StreamMessageRequest()
        arbitration = req.arbitration
        arbitration.device_id = self.device_id
        election_id = arbitration.election_id
        election_id.high = self.election_id[0]
        election_id.low = self.election_id[1]
        if self.role_name is not None:
            arbitration.role.name = self.role_name
        self.stream_out_q.put(req)

        rep = self.get_stream_packet("arbitration", timeout=2)
        if rep is None:
            LOGGER.critical("Failed to establish session with server")
            sys.exit(1)
        is_primary = (rep.arbitration.status.code == code_pb2.OK)
        LOGGER.debug("Session established, client is '{}'".format(
            'primary' if is_primary else 'backup'))
        if not is_primary:
            LOGGER.warning("You are not the primary client, you only have read access to the server")

    def get_stream_packet(self, type_, timeout=1):
        if type_ not in self.stream_in_q:
            LOGGER.critical("Unknown stream type '{}'".format(type_))
            return None
        try:
            msg = self.stream_in_q[type_].get(timeout=timeout)
            return msg
        except queue.Empty:  # timeout expired
            return None

    @parse_p4runtime_error
    def get_p4info(self):
        LOGGER.debug("Retrieving P4Info file")
        req = p4runtime_pb2.GetForwardingPipelineConfigRequest()
        req.device_id = self.device_id
        req.response_type = p4runtime_pb2.GetForwardingPipelineConfigRequest.P4INFO_AND_COOKIE
        rep = self.stub.GetForwardingPipelineConfig(req)
        return rep.config.p4info

    @parse_p4runtime_error
    def set_fwd_pipe_config(self, p4info_path, bin_path):
        LOGGER.debug("Setting forwarding pipeline config")
        req = p4runtime_pb2.SetForwardingPipelineConfigRequest()
        req.device_id = self.device_id
        if self.role_name is not None:
            req.role = self.role_name
        election_id = req.election_id
        election_id.high = self.election_id[0]
        election_id.low = self.election_id[1]
        req.action = p4runtime_pb2.SetForwardingPipelineConfigRequest.VERIFY_AND_COMMIT
        with open(p4info_path, 'r') as f1:
            with open(bin_path, 'rb') as f2:
                try:
                    google.protobuf.text_format.Merge(f1.read(), req.config.p4info)
                except google.protobuf.text_format.ParseError:
                    LOGGER.error("Error when parsing P4Info")
                    raise
                req.config.p4_device_config = f2.read()
        return self.stub.SetForwardingPipelineConfig(req)

    def tear_down(self):
        if self.stream_out_q:
            LOGGER.debug("Cleaning up stream")
            self.stream_out_q.put(None)
        if self.stream_in_q:
            for k in self.stream_in_q:
                self.stream_in_q[k].put(None)
        if self.stream_recv_thread:
            self.stream_recv_thread.join()
        self.channel.close()
        del self.channel  # avoid a race condition if channel deleted when process terminates