from functools import wraps
import google.protobuf.text_format
from google.rpc import code_pb2
import grpc
import logging
import queue
import sys
import threading

from p4.v1 import p4runtime_pb2
from p4.v1 import p4runtime_pb2_grpc

# Attribute keys for describing a device (not referenced by the code below).
P4_ATTR_DEV_ID = 'id'
P4_ATTR_DEV_NAME = 'name'
P4_ATTR_DEV_VENDOR = 'vendor'
P4_ATTR_DEV_HW_VER = 'hw_ver'
P4_ATTR_DEV_SW_VER = 'sw_ver'
P4_ATTR_DEV_PIPECONF = 'pipeconf'

LOGGER = logging.getLogger(__name__)


class P4RuntimeErrorFormatException(Exception):
    """Exception carrying a plain message.

    Presumably meant for malformed P4Runtime error payloads; it is not raised
    by the code in this module.
    """

    def __init__(self, message):
        super().__init__(message)


class P4RuntimeException(Exception):
    """Wraps a ``grpc.RpcError`` raised by a P4Runtime RPC.

    The original error is kept in ``grpc_error``; ``str()`` renders its status
    code name and details.
    """

    def __init__(self, grpc_error):
        super().__init__()
        self.grpc_error = grpc_error

    def __str__(self):
        message = "P4Runtime RPC error ({}): {}".format(
            self.grpc_error.code().name, self.grpc_error.details())
        return message


def parse_p4runtime_error(f):
    """Decorator re-raising any ``grpc.RpcError`` from *f* as ``P4RuntimeException``.

    The exception context is suppressed (``from None``).
    """
    @wraps(f)
    def handle(*args, **kwargs):
        try:
            return f(*args, **kwargs)
        except grpc.RpcError as e:
            raise P4RuntimeException(e) from None
    return handle


class P4RuntimeClient:
    """Client for one P4Runtime device reached over an insecure gRPC channel.

    Construction opens the channel, starts the bidirectional ``StreamChannel``
    RPC with a background receiver thread, and performs arbitration
    (``handshake``). Call ``tear_down`` to release the stream and channel.
    """

    def __init__(self, device_id, grpc_address, election_id, role_name=None):
        """Connect to the device and establish the stream session.

        Args:
            device_id: P4Runtime device id placed in every request.
            grpc_address: address handed to ``grpc.insecure_channel``.
            election_id: indexable pair; ``[0]`` is used as the election id
                ``high`` word and ``[1]`` as ``low``.
            role_name: optional role name; when not ``None`` it is sent in
                the arbitration request and in
                ``SetForwardingPipelineConfig`` requests.

        Exits the process (``sys.exit(1)``) if channel creation raises or
        arbitration gets no reply.
        """
        self.device_id = device_id
        self.election_id = election_id
        self.role_name = role_name
        self.stream_in_q = None
        self.stream_out_q = None
        self.stream = None
        self.stream_recv_thread = None
        LOGGER.debug("Connecting to device %s at %s", device_id, grpc_address)
        try:
            self.channel = grpc.insecure_channel(grpc_address)
        except Exception:
            LOGGER.critical("Failed to connect to P4Runtime server")
            sys.exit(1)
        self.stub = p4runtime_pb2_grpc.P4RuntimeStub(self.channel)
        self.set_up_stream()

    def set_up_stream(self):
        """Open the ``StreamChannel`` RPC, start the receiver thread, then handshake.

        Outgoing messages are taken from ``stream_out_q``; putting ``None``
        there ends the request stream. Incoming messages are sorted into
        ``stream_in_q`` by kind ("arbitration", "packet", "digest", or
        "unknown"). When the receive loop finishes, for any reason, a ``None``
        sentinel is put on every incoming queue so waiting consumers wake up.
        """
        self.stream_out_q = queue.Queue()
        # queues for different messages
        self.stream_in_q = {
            "arbitration": queue.Queue(),
            "packet": queue.Queue(),
            "digest": queue.Queue(),
            "unknown": queue.Queue(),
        }

        def stream_req_iterator():
            # Feeds requests to gRPC until the None sentinel is dequeued.
            while True:
                p = self.stream_out_q.get()
                if p is None:
                    break
                yield p

        def stream_recv_wrapper(stream):
            @parse_p4runtime_error
            def stream_recv():
                for p in stream:
                    if p.HasField("arbitration"):
                        self.stream_in_q["arbitration"].put(p)
                    elif p.HasField("packet"):
                        self.stream_in_q["packet"].put(p)
                    elif p.HasField("digest"):
                        self.stream_in_q["digest"].put(p)
                    else:
                        self.stream_in_q["unknown"].put(p)

            try:
                stream_recv()
            except P4RuntimeException as e:
                LOGGER.critical("StreamChannel error, closing stream")
                LOGGER.critical(e)
            finally:
                # Wake every consumer on error, on a clean end of stream, and
                # on any unexpected exception alike.
                for q in self.stream_in_q.values():
                    q.put(None)

        self.stream = self.stub.StreamChannel(stream_req_iterator())
        self.stream_recv_thread = threading.Thread(
            target=stream_recv_wrapper, args=(self.stream,))
        self.stream_recv_thread.start()
        self.handshake()

    def handshake(self):
        """Send the arbitration request and wait up to 2 s for the reply.

        The client is considered primary when the reply's status code is
        ``code_pb2.OK``; otherwise a warning is logged. If no reply arrives,
        the stream is cancelled and the process exits with status 1.
        """
        req = p4runtime_pb2.StreamMessageRequest()
        arbitration = req.arbitration
        arbitration.device_id = self.device_id
        election_id = arbitration.election_id
        election_id.high = self.election_id[0]
        election_id.low = self.election_id[1]
        if self.role_name is not None:
            arbitration.role.name = self.role_name
        self.stream_out_q.put(req)

        rep = self.get_stream_packet("arbitration", timeout=2)
        if rep is None:
            LOGGER.critical("Failed to establish session with server")
            # The receiver thread is non-daemon and blocks on the open stream.
            # Cancel the stream so the thread ends; otherwise interpreter
            # shutdown triggered by sys.exit would wait on it.
            if self.stream is not None:
                self.stream.cancel()
            sys.exit(1)
        is_primary = (rep.arbitration.status.code == code_pb2.OK)
        LOGGER.debug("Session established, client is '%s'",
                     'primary' if is_primary else 'backup')
        if not is_primary:
            LOGGER.warning("You are not the primary client, you only have read access to the server")

    def get_stream_packet(self, type_, timeout=1):
        """Return the next queued stream message of kind *type_*.

        Returns ``None`` if *type_* is not a known kind, if *timeout*
        seconds pass without a message, or if the receiver's ``None``
        sentinel is dequeued.
        """
        if type_ not in self.stream_in_q:
            LOGGER.critical("Unknown stream type '%s'", type_)
            return None
        try:
            return self.stream_in_q[type_].get(timeout=timeout)
        except queue.Empty:  # timeout expired
            return None

    @parse_p4runtime_error
    def get_p4info(self):
        """Fetch the device's P4Info via ``GetForwardingPipelineConfig``.

        Requests ``P4INFO_AND_COOKIE`` and returns ``rep.config.p4info``.
        RPC errors are raised as ``P4RuntimeException``.
        """
        LOGGER.debug("Retrieving P4Info file")
        req = p4runtime_pb2.GetForwardingPipelineConfigRequest()
        req.device_id = self.device_id
        req.response_type = p4runtime_pb2.GetForwardingPipelineConfigRequest.P4INFO_AND_COOKIE
        rep = self.stub.GetForwardingPipelineConfig(req)
        return rep.config.p4info

    @parse_p4runtime_error
    def set_fwd_pipe_config(self, p4info_path, bin_path):
        """Push a pipeline with action ``VERIFY_AND_COMMIT``.

        Args:
            p4info_path: path to a text-format P4Info file.
            bin_path: path to the device config, read as raw bytes.

        Returns the ``SetForwardingPipelineConfig`` response. Raises
        ``google.protobuf.text_format.ParseError`` (after logging) if the
        P4Info cannot be parsed, and ``P4RuntimeException`` on RPC errors.
        """
        LOGGER.debug("Setting forwarding pipeline config")
        req = p4runtime_pb2.SetForwardingPipelineConfigRequest()
        req.device_id = self.device_id
        if self.role_name is not None:
            req.role = self.role_name
        election_id = req.election_id
        election_id.high = self.election_id[0]
        election_id.low = self.election_id[1]
        req.action = p4runtime_pb2.SetForwardingPipelineConfigRequest.VERIFY_AND_COMMIT
        with open(p4info_path, 'r') as f1, open(bin_path, 'rb') as f2:
            try:
                google.protobuf.text_format.Merge(f1.read(), req.config.p4info)
            except google.protobuf.text_format.ParseError:
                LOGGER.error("Error when parsing P4Info")
                raise
            req.config.p4_device_config = f2.read()
        return self.stub.SetForwardingPipelineConfig(req)

    def tear_down(self):
        """Close the stream, wait for the receiver thread and close the channel.

        Safe to call more than once: the channel is closed and removed only
        if it is still present.
        """
        if self.stream_out_q is not None:
            LOGGER.debug("Cleaning up stream")
            self.stream_out_q.put(None)  # ends stream_req_iterator
        if self.stream_in_q:
            for q in self.stream_in_q.values():
                q.put(None)
        if self.stream_recv_thread:
            self.stream_recv_thread.join()
        channel = getattr(self, "channel", None)
        if channel is not None:
            channel.close()
            del self.channel  # avoid a race condition if channel deleted when process terminates