Commit 2d723659 authored by Manuel Álvarez-Campana's avatar Manuel Álvarez-Campana
Browse files

Device config - P4 Driver

Enhanced P4 Driver with Barefoot Runtime API support
parent 7b9f5eb4
Loading
Loading
Loading
Loading
+534 −0
Original line number Diff line number Diff line
# Copyright 2022-2025 ETSI SDG TeraFlowSDN (TFS) (https://tfs.etsi.org/)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
P4Runtime client.
"""

import logging
import sys
from functools import wraps
import grpc

try:
    from .bfrt_grpc import bfruntime_pb2 as bfruntime_pb2
    from .bfrt_grpc import bfruntime_pb2_grpc as bfruntime_pb2_grpc
    from .bfrt_grpc import client as gc
except ImportError:
    import bfrt_grpc.bfruntime_pb2 as bfruntime_pb2
    import bfrt_grpc.bfruntime_pb2_grpc as bfruntime_pb2_grpc
    import bfrt_grpc.client as gc

LOGGER = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG)

DEF_DEVICE_ID = 0
DEF_CLIENT_ID = 0
DEF_PROFILE= "pipe"

class BFRTException(Exception):
    """
    BFRT Runtime exception handler.
    """
    def __init__(self, value):
        self.value = value
    def __str__(self):
        message = f"BFRT Exception {self.value} "
        return message

def parse_bfrt_error(func):
    """
    Parse BFRT Eception
    :param func: function
    :return: parsed error
    """
    @wraps(func)
    def handle(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            raise BFRTException(e)
    return handle


class BFRTRuntimeClient:
    """
    BFRT Runtime client.

    Attributes
    ----------
    device_id : int
        P4 device ID
    grpc_address : str
        IP address and port
    """

    def __init__(self, device_id, grpc_address):
        self.device_id = device_id
        self.grpc_address = grpc_address
        self.profile = DEF_PROFILE
        LOGGER.debug("Connecting to device %d at %s", device_id, grpc_address)
        self.p4_name = None
        try:
            self.interface = gc.ClientInterface(self.grpc_address, client_id=DEF_CLIENT_ID, device_id=self.device_id, notifications=None)
            #self.interface = gc.ClientInterface(self.grpc_address, client_id=DEF_CLIENT_ID, device_id=DEF_DEVICE_ID, notifications=None)
        except Exception as e:
            LOGGER.critical(f"Failed to connect to server: {e}")
            sys.exit(1)

        LOGGER.info("Successfully connected")

    #@parse_bfrt_error
    def bfrt_get_info(self):
        """
        Retrieve P4Info content.

        :return: P4Info object.
        """
        LOGGER.debug(f"Retreiving bfrt info from device {self.device_id} at {self.grpc_address}")
        try:
            self.bfrt_info = self.interface.bfrt_info_get()
        except Exception as e:
            msg = f"Cannot retrieve bfrt info: {e}"
            LOGGER.critical(msg)
            raise BFRTException(msg)

        self.p4_name = self.bfrt_info.p4_name_get()

        self.all_tables = self.bfrt_info.table_name_list_get()
        self.tables = []
        for table_name in self.all_tables:
            #if not table_name.startswith('pipe.'):
            #    #LOGGER.debug(f"Table {table_name} does not start with 'pipe.'")
            #    continue
            table = self.bfrt_info.table_get(table_name)
            attr = table.info.attributes_supported_get()
            if not "EntryScope" in attr:
                #LOGGER.debug(f"Table {table_name} does not support entries")
                continue
            self.tables.append(table_name)

        try:
            LOGGER.debug(f"BINDING PIPELINE config for P4 program {self.p4_name}")
            self.interface.bind_pipeline_config(self.p4_name)
        except Exception as e:
            msg = f"BINDING FAILURE for P4 program {self.p4_name}"
            LOGGER.critical(msg)
            raise BFRTException(msg)

        self.__tgt = gc.Target(0)

        return self.bfrt_info

    @parse_bfrt_error
    def bfrt_get_table(self,table_name):
        """
        Retrieve BFRT table

        :return: BFRT table object.
        """
        return self.bfrt_info.table_get(table_name)

    def bfrt_get_tables(self,only_entry_scope=True):
        """
        Retrieve BFRT Tables.

        :return: List of BFRT Tables susceptible of having entries
        """
        return self.tables

    def bfrt_get_all_tables(self,only_entry_scope=True):
        """
        Retrieve BFRT Tables.

        :return: List of BFRT Tables susceptible of having entries
        """
        return self.all_tables

    @parse_bfrt_error
    def bfrt_get_table_entries(self, table, filter_match_map=None, filter_action_name=None, filter_data_map=None):
        """
        Get a list of P4 table entries by table name and optionally by action.
    
        :param table_name: name of a P4 table
        :param action_name: action name
        :return: list of P4 table entries or None
        """
        def table_entry_to_json(table_name,match_map,action_name=None,action_params=[],
                                match_annotations={},action_annotations={}):
            json_data = {}
            json_data["table-name"] = table_name
            if match_annotations:
                json_data["match-annotations"] = match_annotations
            if match_map:
                json_data["match-fields"] = match_map
            if action_name:
                json_data["action-name"] = action_name
            if action_annotations:
                json_data["action-annotations"] = action_annotations
            if action_params:
                json_data["action-params"] = action_params
            return json_data

        def get_value(table,k,v):
            match_type = table.info.key_field_match_type_get(k)
            #LOGGER.info(f"CHECKING DEFAULT VALUE {table.info.name}: match {match_type} key {k} val {v}")
            if (match_type == "Exact"):
                return v["value"]
            if match_type == "Ternary":
               return None if v["mask"] == 0 else v
            if match_type == "LPM":
               return None if ["prefix_len"] == 0 else v
            if (match_type == "Range"):
                # to do
                return v
            if (match_type == "Optional"):
                return v
            return v

        def build_filter_key_list(table, filter_match_map):
            key_list = None
            if filter_match_map:
                LOGGER.info(f"filter_match_map: {filter_match_map}")
                if isinstance(filter_match_map,dict):
                    key_list = [self.bfrt_build_key_list(table, filter_match_map)]
                else:
                    key_list = []
                    for match  in filter_match_map:
                        key_list.append(self.bfrt_build_key_list(table, match))
                return key_list
    
        def build_filter_required_data(table,action_name,data_map):
            required_data = None
            if action_name or data_map:
                LOGGER.info(f"required_data: action_name {action_name} data_map {data_map}")
                required_data = self.bfrt_build_data_list(table, action_name, data_map, get=True)
            return required_data

        try:
            if self.bfrt_count_table_entries(table) == 0:
                return []
            table_entries = []
            table_name = table.info.name_get()

            key_list = build_filter_key_list(table,filter_match_map)
            required_data = build_filter_required_data(table,filter_action_name,filter_data_map)

            #def entry_get(self, target,
            #key_list=None, flags={"from_hw":True}, required_data=None, handle=None, p4_name=None, metadata=None,
            #entry_tgt_list=None):

            entries = table.entry_get(self.__tgt, key_list=key_list, flags={"from_hw": True}, required_data = required_data)
            for data, key in entries:
                key_dict = key.to_dict()
                data_dict = data.to_dict()
                action_name = data_dict.get("action_name")
                match_annotations = {}
                match_map = {}
                for k,v in key_dict.items():
                    val  = get_value(table,k,v)
                    if not val:
                        continue
                    match_map[k] = val
                    ann = self.bfrt_get_key_field_annotations(table,k)
                    if ann != "":
                        match_annotations[k] = ann
                param_dict = {p:v for p,v in data_dict.items() if p not in ["action_name", "is_default_entry"]}
                action_params = {}
                action_annotations = {}
                for p,v in param_dict.items():
                    action_params[p] = v
                    ann = self.bfrt_get_action_data_annotations(table,action_name,p)
                    if ann != "":
                        action_annotations[p] = ann
                entry_json =  table_entry_to_json(table_name,match_map,action_name,action_params,
                                                    match_annotations,action_annotations)
                #LOGGER.info(f"found entry in table name {table_name}")
                table_entries.append(entry_json)
            return table_entries
        except Exception as e:
            LOGGER.error(f"Failed to get table {table_name} entries: {e}")
            raise BFRTException(e)
    
        return []


    @parse_bfrt_error
    def bfrt_count_table_entries(self, table, action_name=None):
        """
        Count Table Entries.

        :return: number of table entries
        """
        count = 0
        entries = table.entry_get(self.__tgt, None, {"from_hw": True})
        for data, key in entries:
            count = count + 1
        return count

    @parse_bfrt_error
    def bfrt_build_key_tuple(self,key,val):
        #LOGGER.info(f"build_key_tuple {key}: {val}")
        if not isinstance(val,dict):
            return gc.KeyTuple(key,val)
        if isinstance(val,dict):
            n = len(val)
            if n ==0:
                raise RuntimeError(f"Invalid key tuple with empty dictionary {val}")
            if not "value" in val:
                raise RuntimeError(f"Invalid key tuple '{val}' does not have 'value' entry")
            value = val["value"]
            if n ==1:
               return gc.KeyTuple(key,value)
            if n==2:
                if "mask" in val:
                    mask = val["mask"]
                    return gc.KeyTuple(key,value=value,mask=mask)
                if "prefix_len" in val:
                    prefix_len = val["prefix_len"]
                    return gc.KeyTuple(key,value=value,prefix_len=prefix_len)
                if "is_valid" in val:
                    is_valid = val["is_valid"]
                    return gc.KeyTuple(key,value=value,is_valid=is_valid)
                raise RuntimeError(f"Invalid ternary key tuple '{val}' does not have mask/prefix_len/is_valid entry")
            if len==3:
                if "low" in val and "high" in val:
                    return gc.KeyTuple(key,val.value, low=val["low"], high=val["high"])
            raise RuntimeError(f"Invalid ternary key tuple '{val}' does not have mask/prefix_len/is_valid entry")

    @parse_bfrt_error
    def bfrt_get_key_field_annotations(self,table,key_field_name):
        return ",".join([obj.value for obj in table.info.key_field_annotations_get(key_field_name)])

    @parse_bfrt_error
    def bfrt_get_key_field(self,table,key_field_name):
        res = {}
        res["key_field_name"] = key_field_name
        res["match_type"] = table.info.key_field_match_type_get(key_field_name)
        res["type"] = table.info.key_field_type_get(key_field_name)
        res["size"] = table.info.key_field_size_get(key_field_name) 
        res["repeated"] = table.info.key_field_repeated_get(key_field_name) 
        res["mandatory"] = table.info.key_field_mandatory_get(key_field_name)
        res["annotations"] = ",".join([obj.value for obj in table.info.key_field_annotations_get(key_field_name)])
        return res

    @parse_bfrt_error
    def bfrt_build_key_list(self, table, key_dict):
        key_tuples = []
        for key, val in key_dict.items():
            key_tuple = self.bfrt_build_key_tuple(key,val)
            key_tuples.append(key_tuple)
        return table.make_key(key_tuples)

    @parse_bfrt_error
    def bfrt_set_key_list_annotations(self, table, key_ann_dict):
        for key, ann in key_ann_dict.items():
            #key_anns = [obj.value for obj in table.info.key_field_annotations_get(key)]
            if ann in [obj.value for obj in table.info.key_field_annotations_get(key)]:
                #LOGGER.info(f"SKIP DUPLICATED KEY ANNOTATION {key}:{ann}")
                continue
            #LOGGER.info(f"MATCH FIELD ANNOTATION {key}:{ann}")
            #LOGGER.info(f"INSERT KEY ANNOTATION {key}:{ann}")
            table.info.key_field_annotation_add(key, ann)

    @parse_bfrt_error
    def bfrt_get_action_data_annotations(self,table,action_name,data_field_name):
        table_name = table.info.name_get()
        return ",".join([obj.value for obj in table.info.data_field_annotations_get(data_field_name,action_name)])

    @parse_bfrt_error
    def bfrt_get_action_data_field(self,table,action_name,data_field_name):
        res = {}
        res["data_field_name"] = data_field_name
        res["type"] = table.info.data_field_type_get(data_field_name,action_name)
        res["size"] = table.info.data_field_size_get(data_field_name,action_name)
        res["repeated"] = table.info.data_field_repeated_get(data_field_name,action_name)
        res["mandatory"] = table.info.data_field_mandatory_get(data_field_name,action_name)
        res["read_only"] = table.info.data_field_read_only_get(data_field_name,action_name)
        res["choices"] = table.info.data_field_allowed_choices_get(data_field_name,action_name)
        res["annotations"] = ",".join([obj.value for obj in table.info.data_field_annotations_get(data_field_name,action_name)])
        return res

    @parse_bfrt_error
    def bfrt_build_data_list(self, table, action_name, param_dict, get=False):
        def build_data_tuple(par,val):
            if par.startswith("$"):
                par_type = table.info.data_field_type_get(par,action_name)
                repeated = table.info.data_field_repeated_get(par,action_name)
                #LOGGER.info(f"expected type of {par}: {par_type} repeated: {repeated}")
                if isinstance(val,str):
                    return gc.DataTuple(par,str_val=val)
                if isinstance(val,bool):
                    return gc.DataTuple(par,bool_val=val)
                if isinstance(val,float):
                    return gc.DataTuple(par,float_val=val)
                if isinstance(val, list):
                    if par_type =="bool":
                        #LOGGER.info("BOOL ARR")
                        return gc.DataTuple(par,bool_arr_val=val)
                    if par_type =="string":
                        #LOGGER.info("STR ARR")
                        return gc.DataTuple(par,str_arr_val=val)
                    #LOGGER.info("INT ARR")
                    return gc.DataTuple(par,int_arr_val=val)
            return gc.DataTuple(par,val)

        if get:
            data_tuples = []
            for par in param_dict:
                data_tuples.append(build_data_tuple(par,None))
            data_list = table.make_data(data_tuples,action_name,get=True)
        else:
            data_tuples = []
            for par, val in param_dict.items():
                data_tuples.append(build_data_tuple(par,val))
            data_list = table.make_data(data_tuples,action_name)
        return data_list


    @parse_bfrt_error
    def bfrt_set_data_list_annotations(self, table, action_name, par_ann_dict):
        for par, ann in par_ann_dict.items():
            if ann in [obj.value for obj in table.info.data_field_annotations_get(par,action_name)]:
                #LOGGER.info(f"SKIP DUPLICATED ACTION PARAM ANNOTATION {par}:{ann}")
                continue
            #LOGGER.info(f"INSERT ACTION PARAM ANNOTATION {action_name} {par}:{ann}")
            table.info.data_field_annotation_add(par,action_name,ann)

    @parse_bfrt_error
    def bfrt_add_table_entry(self, table, key_dict, action_name, param_dict):
        """@brief Insert BFRT table entry.
            @param key_dict
            @param action_name
            @param action_param_dict
        """

        key_list = [self.bfrt_build_key_list(table, key_dict)]
        data_list = [self.bfrt_build_data_list(table, action_name, param_dict)]

        try:
            table.entry_add(self.__tgt, key_list, data_list)
        except gc.BfruntimeRpcException as e:
            msg = f"Cannot add entry to table {table.info.name}: {e}"
            LOGGER.critical(msg)
            raise BFRTException(msg)

    @parse_bfrt_error
    def bfrt_mod_table_entry(self, table, key_dict, action_name, param_dict):
        """@brief Modify BFRT table entry.
            @param key_dict
            @param action_name
            @param action_param_dict
        """

        key_list = [self.bfrt_build_key_list(table, key_dict)]
        data_list = [self.bfrt_build_data_list(table, action_name, param_dict)]

        try:
            table.entry_mod(self.__tgt, key_list, data_list)
        except gc.BfruntimeRpcException as e:
            msg = f"Cannot modify table entry {table.info.name}: {e}"
            LOGGER.critical(msg)
            raise BFRTException(msg)

    @parse_bfrt_error
    def bfrt_add_or_mod_table_entry(self, table, key_dict, action_name, param_dict):
        """@brief Modify table entries.
            @param key_list List of Keys.
            @param data_list List of Data. Each Data object contains action_name info as well
        """

        key_list = [self.bfrt_build_key_list(table, key_dict)]
        data_list = [self.bfrt_build_data_list(table, action_name, param_dict)]

        try:
            table.entry_add_or_mod(self.__tgt, key_list, data_list)
        except gc.BfruntimeRpcException as e:
            msg = f"Cannot add/modify table entry {table.info.name}: {e}"
            LOGGER.critical(msg)
            raise BFRTException(msg)

    @parse_bfrt_error
    def bfrt_del_table_entry(self, table, key_dict):
        """@brief Modify table entries.
            @param key_list List of Keys.
            @param data_list List of Data. Each Data object contains action_name info as well
        """
        key_list = [self.bfrt_build_key_list(table, key_dict)]

        try:
            table.entry_del(self.__tgt, key_list)
        except gc.BfruntimeRpcException as e:
            msg = f"Cannot delete table entry {table.info.name}: {e}"
            LOGGER.critical(msg)
            raise BFRTException(msg)

    @parse_bfrt_error
    def bfrt_del_table_entries(self, table):
        """@brief Modify table entries.
            @param key_list List of Keys.
            @param data_list List of Data. Each Data object contains action_name info as well
        """
        try:
            table.entry_del(self.__tgt)
        except gc.BfruntimeRpcException as e:
            msg = f"Cannot delete table entries {table.info.name}: {e}"
            LOGGER.critical(msg)
            raise BFRTException(msg)


    @parse_bfrt_error
    def set_fwd_pipe_config(self, p4_name, tofino_bin_path, bfrt_json_path, context_json_path):
        """
        Configure the pipeline.

        :param tofino_bin_path
        :param bfrt_json_path
        :param context_json_path
        :return:
        """

        self.p4_name = p4_name
        p4_path = "" # path in remote server where files will be uploaded
        pipe_list = [0,1,2,3]
        try:
            LOGGER.info("Forwarding pipeline config")
            profile = gc.ProfileInfo(self.profile, context_json_path, tofino_bin_path, pipe_list)
            cfg = [ gc.ForwardingConfig(self.p4_name, bfrt_json_path, [profile]) ]
            action = bfruntime_pb2.SetForwardingPipelineConfigRequest.VERIFY_AND_WARM_INIT_BEGIN_AND_END
            success = self.interface.send_set_forwarding_pipeline_config_request(action, p4_path, cfg)
            #success = self.interface.send_set_forwarding_pipeline_config_request(action=action, base_path=p4_path, forwarding_config_list=cfg)
            if success:
                LOGGER.info("Succesfully forwarded pipeline config")
            else:
                msg = f"Failure when forwardng pipeline configuration {p4_name}"
                LOGGER.critical(msg)
                raise BFRTException(msg)
        except Exception as e:
            msg = f"Cannot forward pipeline configuration {p4_name}: {e}"
            LOGGER.critical(msg)
            raise BFRTException(msg)

    @parse_bfrt_error
    def tear_down(self):
        if self.interface is not None:
            LOGGER.warning(f"Tearing down gracefully grpc session with {self.grpc_address}")
            try:
                self.interface.tear_down_stream()
            except:
                msg = f"Error tearing down grpc session with {self.grpc_address}"
                LOGGER.critical(msg)
                raise BFRTException(msg)

+192 −0

File added.

Preview size limit exceeded, changes collapsed.

+204 −0

File added.

Preview size limit exceeded, changes collapsed.

+3045 −0

File added.

Preview size limit exceeded, changes collapsed.

+926 −0

File added.

Preview size limit exceeded, changes collapsed.

Loading