"""
A mock P4Runtime service implementation.
"""

import queue
from google.rpc import code_pb2
from p4.v1 import p4runtime_pb2, p4runtime_pb2_grpc
from p4.config.v1 import p4info_pb2

try:
    from p4_util import STREAM_ATTR_ARBITRATION, STREAM_ATTR_PACKET
except ImportError:
    from device.service.drivers.p4.p4_util import STREAM_ATTR_ARBITRATION,\
        STREAM_ATTR_PACKET

class MockP4RuntimeServicerImpl(p4runtime_pb2_grpc.P4RuntimeServicer):
    """
    In-memory stand-in for a P4Runtime server, intended for tests.

    State kept between calls:
      * ``p4info`` -- the pipeline description last installed through
        SetForwardingPipelineConfig (starts as an empty P4Info).
      * ``p4runtime_api_version`` -- the version string reported by
        Capabilities.
      * ``stored_packet_out`` -- a queue of every StreamChannel request that
        carried a packet, so tests can inspect what the client sent.
    """

    def __init__(self):
        # Packet-out requests received on StreamChannel are parked here.
        self.stored_packet_out = queue.Queue()
        self.p4runtime_api_version = "1.3.0"
        self.p4info = p4info_pb2.P4Info()

    def GetForwardingPipelineConfig(self, request, context):
        """
        Return the stored P4Info inside a GetForwardingPipelineConfigResponse.

        The request's fields are not inspected. If ``self.p4info`` has been
        set to None, an empty response is returned instead.
        """
        response = p4runtime_pb2.GetForwardingPipelineConfigResponse()
        if self.p4info is None:
            return response
        response.config.p4info.CopyFrom(self.p4info)
        return response

    def SetForwardingPipelineConfig(self, request, context):
        """
        Replace the stored P4Info with the one carried in the request.

        Only ``request.config.p4info`` is read; the reply is always empty.
        """
        incoming_p4info = request.config.p4info
        self.p4info.CopyFrom(incoming_p4info)
        return p4runtime_pb2.SetForwardingPipelineConfigResponse()

    def Write(self, request, context):
        """Accept any write request without applying it and reply empty."""
        return p4runtime_pb2.WriteResponse()

    def Read(self, request, context):
        """Stream back exactly one empty ReadResponse, whatever was asked."""
        yield p4runtime_pb2.ReadResponse()

    def StreamChannel(self, request_iterator, context):
        """
        Serve the bidirectional stream.

        For each incoming message:
          * an arbitration update is echoed back with its status forced to
            ``code_pb2.OK``;
          * a packet message is queued (the whole request) on
            ``stored_packet_out`` and produces no reply;
          * anything else is silently ignored.
        """
        for message in request_iterator:
            if not message.HasField(STREAM_ATTR_ARBITRATION):
                if message.HasField(STREAM_ATTR_PACKET):
                    self.stored_packet_out.put(message)
                continue

            reply = p4runtime_pb2.StreamMessageResponse()
            reply.arbitration.CopyFrom(message.arbitration)
            # Always grant the arbitration request.
            reply.arbitration.status.code = code_pb2.OK
            yield reply

    def Capabilities(self, request, context):
        """Report the configured P4Runtime API version."""
        return p4runtime_pb2.CapabilitiesResponse(
            p4runtime_api_version=self.p4runtime_api_version)
