# P4Util.py
from functools import wraps
import google.protobuf.text_format
from google.rpc import code_pb2
import grpc
import logging
import queue
import sys
import threading

from p4.v1 import p4runtime_pb2
from p4.v1 import p4runtime_pb2_grpc

# String constants named after device attributes (id, name, vendor, hardware
# version, software version, pipeconf). Nothing in this part of the file reads
# them; presumably they are dictionary keys for device descriptions used by
# callers elsewhere -- TODO confirm.
P4_ATTR_DEV_ID = 'id'
P4_ATTR_DEV_NAME = 'name'
P4_ATTR_DEV_VENDOR = 'vendor'
P4_ATTR_DEV_HW_VER = 'hw_ver'
P4_ATTR_DEV_SW_VER = 'sw_ver'
P4_ATTR_DEV_PIPECONF = 'pipeconf'

# Module-level logger, named after this module.
LOGGER = logging.getLogger(__name__)


class P4RuntimeErrorFormatException(Exception):
    """Exception that carries exactly one, required, message.

    Judging by its name it signals a badly formatted P4Runtime error; it is
    not raised anywhere in this part of the file.
    """

    def __init__(self, message):
        # Forward the message so it ends up in ``args`` and in ``str(self)``.
        Exception.__init__(self, message)


class P4RuntimeException(Exception):
    """Exception wrapping a gRPC error raised by a P4Runtime RPC.

    The wrapped error is kept on ``grpc_error``; it must provide ``code()``
    (returning an object with a ``name`` attribute) and ``details()``, which
    are used to build the string form of this exception.
    """

    def __init__(self, grpc_error):
        # No positional args are handed to Exception, so ``args`` stays empty.
        super().__init__()
        self.grpc_error = grpc_error

    def __str__(self):
        """Return a one-line description with the status name and details."""
        status_name = self.grpc_error.code().name
        details = self.grpc_error.details()
        return "P4Runtime RPC error ({}): {}".format(status_name, details)


def parse_p4runtime_error(f):
    """Decorator that converts ``grpc.RpcError`` into ``P4RuntimeException``.

    The wrapped callable is invoked with the caller's arguments and its
    result is returned untouched. A ``grpc.RpcError`` raised by it is
    re-raised as ``P4RuntimeException`` (with exception chaining suppressed);
    any other exception propagates unchanged.
    """
    def _translate(*call_args, **call_kwargs):
        try:
            outcome = f(*call_args, **call_kwargs)
        except grpc.RpcError as rpc_failure:
            # ``from None`` hides the gRPC traceback; the original error is
            # still reachable through the new exception's ``grpc_error``.
            raise P4RuntimeException(rpc_failure) from None
        return outcome

    # Copy name, docstring, etc. from ``f`` onto the wrapper.
    return wraps(f)(_translate)


class P4RuntimeClient:
    """Client for a single P4Runtime device over an insecure gRPC channel.

    Creating an instance opens the channel, starts the bidirectional
    ``StreamChannel`` RPC together with a background thread that sorts
    incoming stream messages into per-type queues, and performs the
    arbitration handshake. Call :meth:`tear_down` to release the stream,
    the thread and the channel.
    """

    def __init__(self, device_id, grpc_address, election_id, role_name=None):
        """Connect to the device and establish the stream session.

        Args:
            device_id: Device id placed in every request sent by this client.
            grpc_address: Target passed to ``grpc.insecure_channel``.
            election_id: Indexable pair; item 0 is used as the ``high`` part
                and item 1 as the ``low`` part of the election id.
            role_name: Optional role name; omitted from requests when None.

        Exits the process with status 1 if the channel cannot be created or
        the handshake fails.
        """
        self.device_id = device_id
        self.election_id = election_id
        self.role_name = role_name
        self.stream_in_q = None
        self.stream_out_q = None
        self.stream = None
        self.stream_recv_thread = None
        LOGGER.debug(
            "Connecting to device {} at {}".format(device_id, grpc_address))
        try:
            self.channel = grpc.insecure_channel(grpc_address)
        except Exception:
            LOGGER.critical("Failed to connect to P4Runtime server")
            sys.exit(1)
        self.stub = p4runtime_pb2_grpc.P4RuntimeStub(self.channel)
        self.set_up_stream()

    def set_up_stream(self):
        """Open the StreamChannel RPC, start the receiver thread, handshake.

        Outgoing messages are taken from ``stream_out_q`` (a ``None`` item
        ends the request stream). Incoming messages are sorted into the
        queues of ``stream_in_q`` keyed by "arbitration", "packet", "digest"
        or "unknown".
        """
        self.stream_out_q = queue.Queue()
        # queues for different messages
        self.stream_in_q = {
            "arbitration": queue.Queue(),
            "packet": queue.Queue(),
            "digest": queue.Queue(),
            "unknown": queue.Queue(),
        }

        def stream_req_iterator():
            # Feed gRPC with queued requests until the None sentinel arrives.
            while True:
                p = self.stream_out_q.get()
                if p is None:
                    break
                yield p

        def stream_recv_wrapper(stream):
            @parse_p4runtime_error
            def stream_recv():
                for p in stream:
                    # Route to the first known message type that is set;
                    # anything else goes to the "unknown" queue.
                    for kind in ("arbitration", "packet", "digest"):
                        if p.HasField(kind):
                            self.stream_in_q[kind].put(p)
                            break
                    else:
                        self.stream_in_q["unknown"].put(p)
            try:
                stream_recv()
            except P4RuntimeException as e:
                LOGGER.critical("StreamChannel error, closing stream")
                LOGGER.critical(e)
                # Wake up every consumer with a None sentinel.
                for k in self.stream_in_q:
                    self.stream_in_q[k].put(None)
        self.stream = self.stub.StreamChannel(stream_req_iterator())
        self.stream_recv_thread = threading.Thread(
            target=stream_recv_wrapper, args=(self.stream,))
        self.stream_recv_thread.start()
        self.handshake()

    def handshake(self):
        """Send the arbitration request and wait up to 2 seconds for a reply.

        Logs whether this client is primary (reply status code is OK) or
        backup. If no reply arrives, the stream and channel are released and
        the process exits with status 1.
        """
        req = p4runtime_pb2.StreamMessageRequest()
        arbitration = req.arbitration
        arbitration.device_id = self.device_id
        election_id = arbitration.election_id
        election_id.high = self.election_id[0]
        election_id.low = self.election_id[1]
        if self.role_name is not None:
            arbitration.role.name = self.role_name
        self.stream_out_q.put(req)

        rep = self.get_stream_packet("arbitration", timeout=2)
        if rep is None:
            LOGGER.critical("Failed to establish session with server")
            # The caller never receives this instance when the handshake
            # fails during construction, so clean up here. Cancelling the RPC
            # unblocks the non-daemon receiver thread, which could otherwise
            # keep the interpreter from exiting.
            if self.stream is not None:
                self.stream.cancel()
            self.tear_down()
            sys.exit(1)
        is_primary = (rep.arbitration.status.code == code_pb2.OK)
        LOGGER.debug("Session established, client is '{}'".format(
            'primary' if is_primary else 'backup'))
        if not is_primary:
            LOGGER.warning("You are not the primary client, you only have read access to the server")

    def get_stream_packet(self, type_, timeout=1):
        """Return the next queued stream message of the given type.

        Args:
            type_: One of the keys of ``stream_in_q``.
            timeout: Seconds to wait for a message.

        Returns:
            The queued item, or None when ``type_`` is unknown or the wait
            timed out. A queued ``None`` sentinel (stream closed) is also
            returned as None.
        """
        if type_ not in self.stream_in_q:
            LOGGER.critical("Unknown stream type '{}'".format(type_))
            return None
        try:
            return self.stream_in_q[type_].get(timeout=timeout)
        except queue.Empty:  # timeout expired
            return None

    @parse_p4runtime_error
    def get_p4info(self):
        """Fetch the forwarding pipeline config and return its P4Info.

        Raises:
            P4RuntimeException: If the RPC fails with ``grpc.RpcError``.
        """
        LOGGER.debug("Retrieving P4Info file")
        req = p4runtime_pb2.GetForwardingPipelineConfigRequest()
        req.device_id = self.device_id
        req.response_type = p4runtime_pb2.GetForwardingPipelineConfigRequest.P4INFO_AND_COOKIE
        rep = self.stub.GetForwardingPipelineConfig(req)
        return rep.config.p4info

    @parse_p4runtime_error
    def set_fwd_pipe_config(self, p4info_path, bin_path):
        """Push a forwarding pipeline config with action VERIFY_AND_COMMIT.

        Args:
            p4info_path: Path of the P4Info file in protobuf text format.
            bin_path: Path of the device config, read as raw bytes.

        Returns:
            The response of the ``SetForwardingPipelineConfig`` RPC.

        Raises:
            google.protobuf.text_format.ParseError: If the P4Info file cannot
                be parsed (logged, then re-raised).
            P4RuntimeException: If the RPC fails with ``grpc.RpcError``.
        """
        LOGGER.debug("Setting forwarding pipeline config")
        req = p4runtime_pb2.SetForwardingPipelineConfigRequest()
        req.device_id = self.device_id
        if self.role_name is not None:
            req.role = self.role_name
        election_id = req.election_id
        election_id.high = self.election_id[0]
        election_id.low = self.election_id[1]
        req.action = p4runtime_pb2.SetForwardingPipelineConfigRequest.VERIFY_AND_COMMIT
        with open(p4info_path, 'r') as f1, open(bin_path, 'rb') as f2:
            try:
                google.protobuf.text_format.Merge(f1.read(), req.config.p4info)
            except google.protobuf.text_format.ParseError:
                LOGGER.error("Error when parsing P4Info")
                raise
            req.config.p4_device_config = f2.read()
        return self.stub.SetForwardingPipelineConfig(req)

    def tear_down(self):
        """Stop the stream, join the receiver thread and close the channel.

        Safe to call more than once: after the first call the ``channel``
        attribute no longer exists and later calls return immediately.
        """
        if not hasattr(self, "channel"):
            # Already torn down (the attribute is deleted at the end).
            return
        if self.stream_out_q is not None:
            LOGGER.debug("Cleaning up stream")
            # None ends the request iterator, closing our side of the stream.
            self.stream_out_q.put(None)
        if self.stream_in_q:
            # Wake up any consumer blocked on an incoming queue.
            for k in self.stream_in_q:
                self.stream_in_q[k].put(None)
        if self.stream_recv_thread:
            self.stream_recv_thread.join()
        self.channel.close()
        del self.channel  # avoid a race condition if channel deleted when process terminates