diff --git a/src/device/service/drivers/p4/P4Driver.py b/src/device/service/drivers/p4/P4Driver.py new file mode 100644 index 0000000000000000000000000000000000000000..9337ade2b8129f28e452691aa04b67fbbcbb9cb4 --- /dev/null +++ b/src/device/service/drivers/p4/P4Driver.py @@ -0,0 +1,123 @@ +import grpc, anytree, logging, threading +from typing import Any, Iterator, List, Optional, Tuple, Union +from .P4Util import P4RuntimeClient,\ + P4_ATTR_DEV_ID, P4_ATTR_DEV_NAME, P4_ATTR_DEV_VENDOR,\ + P4_ATTR_DEV_HW_VER, P4_ATTR_DEV_SW_VER, P4_ATTR_DEV_PIPECONF + +try: + from Checkers import chk_float, chk_length, chk_string, chk_type + from _Driver import _Driver + from AnyTreeTools import TreeNode, dump_subtree, get_subnode,\ + set_subnode_value +except ImportError: + from common.type_checkers.Checkers import chk_float, chk_length, chk_string, chk_type + from device.service.driver_api._Driver import _Driver + from device.service.driver_api.AnyTreeTools import TreeNode, dump_subtree, get_subnode, set_subnode_value + +LOGGER = logging.getLogger(__name__) + + +class P4Driver(_Driver): + def __init__(self, address: str, port: int, **settings) -> None: + # pylint: disable=super-init-not-called + self.__client = None + self.__address = address + self.__port = int(port) + self.__settings = settings + if P4_ATTR_DEV_ID in self.__settings: + self.__dev_id = self.__settings.get(P4_ATTR_DEV_ID) + if P4_ATTR_DEV_NAME in self.__settings: + self.__dev_name = self.__settings.get(P4_ATTR_DEV_NAME) + if P4_ATTR_DEV_VENDOR in self.__settings: + self.__dev_vendor = self.__settings.get(P4_ATTR_DEV_VENDOR) + if P4_ATTR_DEV_HW_VER in self.__settings: + self.__dev_hw_version = self.__settings.get(P4_ATTR_DEV_HW_VER) + if P4_ATTR_DEV_SW_VER in self.__settings: + self.__dev_sw_version = self.__settings.get(P4_ATTR_DEV_SW_VER) + if P4_ATTR_DEV_PIPECONF in self.__settings: + self.__dev_pipeconf = self.__settings.get(P4_ATTR_DEV_PIPECONF) + + LOGGER.info('Initializing P4 device at {}:{} with settings:'.format( + self.__address, self.__port)) + + self.__lock = threading.Lock() + self.__started = threading.Event() + self.__terminate = threading.Event() + + for key, value in settings.items(): + LOGGER.info('\t%8s = %s' %(key, value)) + + def Connect(self) -> bool: + LOGGER.info('Connecting to P4 device {}:{}...'.format( + self.__address, self.__port)) + with self.__lock: + # Skip if already connected + if self.__started.is_set(): + return True + + # Instantiate a gRPC channel with the P4 device + grpc_address = '{}:{}'.format(self.__address, self.__port) + election_id = (1, 0) + self.__client = P4RuntimeClient(self.__dev_id, grpc_address, election_id) + LOGGER.info('\tConnected!') + self.__started.set() + + return True + + def Disconnect(self) -> bool: + LOGGER.info('Disconnecting from P4 device {}:{}...'.format( + self.__address, self.__port)) + + # If not started, assume it is already disconnected + if not self.__started.is_set(): + return True + + # gRPC client must already be instantiated + assert self.__client + + # Trigger termination of loops and processes + self.__terminate.set() + + # Trigger connection tear down with the P4Runtime server + self.__client.tear_down() + self.__client = None + + LOGGER.info('\tDisonnected!') + + return True + + def GetInitialConfig(self) -> List[Tuple[str, Any]]: + LOGGER.info('P4 GetInitialConfig()') + return [] + + def GetConfig(self, resource_keys : List[str] = [])\ + -> List[Tuple[str, Union[Any, None, Exception]]]: + LOGGER.info('P4 GetConfig()') + return [] + + def SetConfig(self, resources : List[Tuple[str, Any]])\ + -> List[Union[bool, Exception]]: + LOGGER.info('P4 SetConfig()') + return [] + + def DeleteConfig(self, resources : List[Tuple[str, Any]]) -> List[Union[bool, Exception]]: + LOGGER.info('P4 DeleteConfig()') + return [] + + def GetResource(self, endpoint_uuid : str) -> Optional[str]: + LOGGER.info('P4 GetResource()') + return "" + + def GetState(self, blocking=False) -> Iterator[Tuple[str, Any]]: + LOGGER.info('P4 GetState()') + return [] + + def SubscribeState(self, subscriptions : List[Tuple[str, float, float]])\ + -> List[Union[bool, Exception]]: + LOGGER.info('P4 SubscribeState()') + return [] + + def UnsubscribeState(self, subscriptions : List[Tuple[str, float, float]])\ + -> List[Union[bool, Exception]]: + LOGGER.info('P4 UnsubscribeState()') + return [] diff --git a/src/device/service/drivers/p4/P4Util.py b/src/device/service/drivers/p4/P4Util.py new file mode 100644 index 0000000000000000000000000000000000000000..278444f5766a0bc29c27a4a28c1cbfee2a016d63 --- /dev/null +++ b/src/device/service/drivers/p4/P4Util.py @@ -0,0 +1,181 @@ +from functools import wraps +import google.protobuf.text_format +from google.rpc import code_pb2 +import grpc +import logging +import queue +import sys +import threading + +from p4.v1 import p4runtime_pb2 +from p4.v1 import p4runtime_pb2_grpc + +P4_ATTR_DEV_ID = 'id' +P4_ATTR_DEV_NAME = 'name' +P4_ATTR_DEV_VENDOR = 'vendor' +P4_ATTR_DEV_HW_VER = 'hw_ver' +P4_ATTR_DEV_SW_VER = 'sw_ver' +P4_ATTR_DEV_PIPECONF = 'pipeconf' + +LOGGER = logging.getLogger(__name__) + + +class P4RuntimeErrorFormatException(Exception): + def __init__(self, message): + super().__init__(message) + + +class P4RuntimeException(Exception): + def __init__(self, grpc_error): + super().__init__() + self.grpc_error = grpc_error + + def __str__(self): + message = "P4Runtime RPC error ({}): {}".format( + self.grpc_error.code().name, self.grpc_error.details()) + return message + + +def parse_p4runtime_error(f): + @wraps(f) + def handle(*args, **kwargs): + try: + return f(*args, **kwargs) + except grpc.RpcError as e: + raise P4RuntimeException(e) from None + return handle + + +class P4RuntimeClient: + def __init__(self, device_id, grpc_address, election_id, role_name=None): + self.device_id = device_id + self.election_id = election_id + self.role_name = role_name + self.stream_in_q = None + self.stream_out_q = None + self.stream = None + self.stream_recv_thread = None + LOGGER.debug( + "Connecting to device {} at {}".format(device_id, grpc_address)) + try: + self.channel = grpc.insecure_channel(grpc_address) + except Exception: + LOGGER.critical("Failed to connect to P4Runtime server") + sys.exit(1) + self.stub = p4runtime_pb2_grpc.P4RuntimeStub(self.channel) + self.set_up_stream() + + def set_up_stream(self): + self.stream_out_q = queue.Queue() + # queues for different messages + self.stream_in_q = { + "arbitration": queue.Queue(), + "packet": queue.Queue(), + "digest": queue.Queue(), + "unknown": queue.Queue(), + } + + def stream_req_iterator(): + while True: + p = self.stream_out_q.get() + if p is None: + break + yield p + + def stream_recv_wrapper(stream): + @parse_p4runtime_error + def stream_recv(): + for p in stream: + if p.HasField("arbitration"): + self.stream_in_q["arbitration"].put(p) + elif p.HasField("packet"): + self.stream_in_q["packet"].put(p) + elif p.HasField("digest"): + self.stream_in_q["digest"].put(p) + else: + self.stream_in_q["unknown"].put(p) + try: + stream_recv() + except P4RuntimeException as e: + LOGGER.critical("StreamChannel error, closing stream") + LOGGER.critical(e) + for k in self.stream_in_q: + self.stream_in_q[k].put(None) + self.stream = self.stub.StreamChannel(stream_req_iterator()) + self.stream_recv_thread = threading.Thread( + target=stream_recv_wrapper, args=(self.stream,)) + self.stream_recv_thread.start() + self.handshake() + + def handshake(self): + req = p4runtime_pb2.StreamMessageRequest() + arbitration = req.arbitration + arbitration.device_id = self.device_id + election_id = arbitration.election_id + election_id.high = self.election_id[0] + election_id.low = self.election_id[1] + if self.role_name is not None: + arbitration.role.name = self.role_name + self.stream_out_q.put(req) + + rep = self.get_stream_packet("arbitration", timeout=2) + if rep is None: + LOGGER.critical("Failed to establish session with server") + sys.exit(1) + is_primary = (rep.arbitration.status.code == code_pb2.OK) + LOGGER.debug("Session established, client is '{}'".format( + 'primary' if is_primary else 'backup')) + if not is_primary: + LOGGER.warning("You are not the primary client, you only have read access to the server") + + def get_stream_packet(self, type_, timeout=1): + if type_ not in self.stream_in_q: + LOGGER.critical("Unknown stream type '{}'".format(type_)) + return None + try: + msg = self.stream_in_q[type_].get(timeout=timeout) + return msg + except queue.Empty: # timeout expired + return None + + @parse_p4runtime_error + def get_p4info(self): + LOGGER.debug("Retrieving P4Info file") + req = p4runtime_pb2.GetForwardingPipelineConfigRequest() + req.device_id = self.device_id + req.response_type = p4runtime_pb2.GetForwardingPipelineConfigRequest.P4INFO_AND_COOKIE + rep = self.stub.GetForwardingPipelineConfig(req) + return rep.config.p4info + + @parse_p4runtime_error + def set_fwd_pipe_config(self, p4info_path, bin_path): + LOGGER.debug("Setting forwarding pipeline config") + req = p4runtime_pb2.SetForwardingPipelineConfigRequest() + req.device_id = self.device_id + if self.role_name is not None: + req.role = self.role_name + election_id = req.election_id + election_id.high = self.election_id[0] + election_id.low = self.election_id[1] + req.action = p4runtime_pb2.SetForwardingPipelineConfigRequest.VERIFY_AND_COMMIT + with open(p4info_path, 'r') as f1: + with open(bin_path, 'rb') as f2: + try: + google.protobuf.text_format.Merge(f1.read(), req.config.p4info) + except google.protobuf.text_format.ParseError: + LOGGER.error("Error when parsing P4Info") + raise + req.config.p4_device_config = f2.read() + return self.stub.SetForwardingPipelineConfig(req) + + def tear_down(self): + if self.stream_out_q: + LOGGER.debug("Cleaning up stream") + self.stream_out_q.put(None) + if self.stream_in_q: + for k in self.stream_in_q: + self.stream_in_q[k].put(None) + if self.stream_recv_thread: + self.stream_recv_thread.join() + self.channel.close() + del self.channel # avoid a race condition if channel deleted when process terminates