import json
import logging
import uuid
from typing import Dict, List, Optional

from common.Constants import DEFAULT_CONTEXT_NAME
from common.proto.context_pb2 import (
    ConfigRule,
    Constraint,
    DeviceId,
    Device,
    Empty,
    EndPointId,
    ServiceConfig,
    Slice,
    SliceStatusEnum,
)
from common.tools.context_queries.Slice import get_slice_by_defualt_name
from common.tools.grpc.ConfigRules import update_config_rule_custom
from common.tools.object_factory.Device import json_device_id
from common.DeviceTypes import DeviceTypeEnum
from context.client import ContextClient

from .YangValidator import YangValidator

LOGGER = logging.getLogger(__name__)

# Resource keys of the two custom config rules holding the IETF slice JSON on a
# TFS Slice. The candidate copy is what the edit operations modify; the running
# copy is written at creation time and refreshed from the candidate copy by
# IETFSliceHandler.copy_candidate_ietf_slice_data_to_running.
RUNNING_RESOURCE_KEY: str = "running_ietf_slice"
CANDIDATE_RESOURCE_KEY: str = "candidate_ietf_slice"

# Prefix length written into every endpoint's "address_prefix" setting.
ADDRESS_PREFIX: int = 24

# Flag paired with every value in the (value, flag) tuples handed to
# update_config_rule_custom; presumably disables its "raise if the stored value
# differs" check — TODO confirm against update_config_rule_custom.
RAISE_IF_DIFFERS: bool = False


def get_endpoint_controller_type(
    endpoint: EndPointId, context_client: ContextClient
) -> str:
    """Return the device type of the controller that manages ``endpoint``'s device.

    Args:
        endpoint: endpoint whose owning device is looked up in the context.
        context_client: client used for the ``GetDevice`` lookups.

    Returns:
        The controller device's ``device_type``, or ``""`` when the endpoint's
        device has no controller (its ``controller_id`` equals ``DeviceId()``).

    Raises:
        Exception: if the controller lookup returns ``None``.
    """
    owner_device: Device = context_client.GetDevice(endpoint.device_id)
    controller_id = owner_device.controller_id
    if controller_id == DeviceId():
        # Device is not managed by any controller.
        return ""
    controller_device = context_client.GetDevice(controller_id)
    # NOTE(review): GetDevice presumably raises on unknown ids rather than
    # returning None, so this check may be purely defensive — confirm.
    if controller_device is not None:
        return controller_device.device_type
    missing_uuid = controller_id.device_uuid.uuid
    raise Exception("Device({:s}) not found".format(str(missing_uuid)))


def get_custom_config_rule(
    service_config: ServiceConfig, resource_key: str
) -> Optional[ConfigRule]:
    """Return the first custom config rule whose resource key is ``resource_key``.

    Any message exposing a repeated ``config_rules`` field works here; callers
    in this module pass a slice's ``slice_config``.

    Args:
        service_config: config holding the ``config_rules`` to search.
        resource_key: resource key to look for among the custom rules.

    Returns:
        The first matching ``ConfigRule``, or ``None`` if there is none.
    """
    matching_rules = (
        rule
        for rule in service_config.config_rules
        if rule.WhichOneof("config_rule") == "custom"
        and rule.custom.resource_key == resource_key
    )
    return next(matching_rules, None)


def sort_endpoints(
    endpoinst_list: List[EndPointId],
    sdps: List,
    connection_group: Dict,
    context_client: ContextClient,
) -> List[EndPointId]:
    """Order the slice endpoints so that the "source" endpoint comes first.

    The parameter name ``endpoinst_list`` (sic) is kept for caller compatibility.

    Rules, applied in order:
      1. If the first endpoint's controller is an NCE, the list is returned as-is.
      2. Else, if the last endpoint's controller is an NCE, it is returned reversed.
      3. Otherwise the endpoint whose device matches the node of the
         ``p2p-sender-sdp`` of the connection group's first connectivity
         construct is put first.

    Args:
        endpoinst_list: endpoints to order.
        sdps: IETF SDP dicts; each must provide ``id`` and ``node-id``.
        connection_group: IETF connection-group dict.
        context_client: client used to resolve endpoint controllers.

    Returns:
        Either the same list object or a reversed copy of it.
    """
    first_controller_type = get_endpoint_controller_type(
        endpoinst_list[0], context_client
    )
    last_controller_type = get_endpoint_controller_type(
        endpoinst_list[-1], context_client
    )
    nce_type = DeviceTypeEnum.NCE.value
    if first_controller_type == nce_type:
        return endpoinst_list
    if last_controller_type == nce_type:
        return endpoinst_list[::-1]

    # Neither end is NCE-controlled: fall back to the declared sender SDP.
    sender_sdp_id = connection_group["connectivity-construct"][0]["p2p-sender-sdp"]
    node_by_sdp_id = {sdp["id"]: sdp["node-id"] for sdp in sdps}
    first_device_uuid = endpoinst_list[0].device_id.device_uuid.uuid
    if first_device_uuid != node_by_sdp_id[sender_sdp_id]:
        return endpoinst_list[::-1]
    return endpoinst_list


def _find_linked_emu_dc_endpoint(
    device_name: str,
    links: List,
    uuid_name_map: Dict[str, str],
    uuid_device_map: Dict[str, Device],
) -> Optional[EndPointId]:
    """Return the "emu-datacenter" endpoint of a link attached to ``device_name``.

    Links are first searched with ``device_name`` on their first endpoint and
    the emu-datacenter device on their second endpoint (preferred orientation).
    Only if none matches are links searched in the reverse orientation. Links
    with fewer than two endpoints, and endpoints whose device is not in the
    maps, are ignored.

    Returns:
        The matching link endpoint, or ``None`` if no link qualifies.
    """
    for reverse in (False, True):
        for link in links:
            link_endpoints = list(link.link_endpoint_ids)
            if len(link_endpoints) < 2:
                continue
            near_ep, far_ep = link_endpoints[0], link_endpoints[1]
            if reverse:
                near_ep, far_ep = far_ep, near_ep
            near_name = uuid_name_map.get(near_ep.device_id.device_uuid.uuid)
            far_device = uuid_device_map.get(far_ep.device_id.device_uuid.uuid)
            if (
                near_name == device_name
                and far_device is not None
                and far_device.device_type == "emu-datacenter"
            ):
                return far_ep
    return None


def replace_ont_endpoint_with_emu_dc(
    endpoint_list: List, context_client: ContextClient
) -> List:
    """Swap the controller-managed endpoint for its linked emu-datacenter endpoint.

    The device of ``endpoint_list[0]`` is checked first, then that of
    ``endpoint_list[1]``; the first one with a non-empty ``controller_id``
    (presumably the ONT, per the function name) is replaced in place by the
    endpoint of an "emu-datacenter" device it is linked to.

    The endpoints' ``device_uuid`` fields are looked up as device *names*
    (create_slice_service fills them with the SDP ``node-id``).

    Args:
        endpoint_list: list of at least two ``EndPointId``; mutated in place.
        context_client: client used to list links and devices.

    Returns:
        ``endpoint_list`` itself. If no matching link is found, a warning is
        logged and the list is returned unchanged.

    Raises:
        KeyError: if an endpoint's device name is not a known device.
        Exception: if neither of the first two devices has a controller.
    """
    links = list(context_client.ListLinks(Empty()).links)
    devices = context_client.ListDevices(Empty()).devices
    uuid_name_map = {d.device_id.device_uuid.uuid: d.name for d in devices}
    uuid_device_map = {d.device_id.device_uuid.uuid: d for d in devices}
    name_device_map = {d.name: d for d in devices}

    device_name_1 = endpoint_list[0].device_id.device_uuid.uuid
    device_1 = name_device_map[device_name_1]
    device_name_2 = endpoint_list[1].device_id.device_uuid.uuid
    device_2 = name_device_map[device_name_2]

    if device_1.controller_id != DeviceId():
        controlled_idx, controlled_name = 0, device_name_1
    elif device_2.controller_id != DeviceId():
        controlled_idx, controlled_name = 1, device_name_2
    else:
        raise Exception(
            "one of the sdps should be managed by a controller and the other one should not be controlled"
        )

    emu_dc_endpoint = _find_linked_emu_dc_endpoint(
        controlled_name, links, uuid_name_map, uuid_device_map
    )
    if emu_dc_endpoint is None:
        LOGGER.warning(
            "No link from device %s to an emu-datacenter device found; "
            "keeping its endpoint unchanged",
            controlled_name,
        )
    else:
        endpoint_list[controlled_idx] = emu_dc_endpoint
    return endpoint_list


def validate_ietf_slice_data(request_data: Dict) -> None:
    """Validate ``request_data`` against the ``ietf-network-slice-service`` YANG model.

    The parsed result is discarded; only the validation matters. Invalid data
    presumably makes ``YangValidator.parse_to_dict`` raise — TODO confirm the
    exception type.

    The validator is always destroyed, including when validation fails, so its
    underlying resources are not leaked.
    """
    yang_validator = YangValidator("ietf-network-slice-service")
    try:
        yang_validator.parse_to_dict(request_data)
    finally:
        yang_validator.destroy()


class IETFSliceHandler:
    """Translate IETF network slice service requests into TFS ``Slice`` messages.

    The IETF JSON of a slice is stored on the TFS slice as two custom config
    rules:
      * the candidate copy (``CANDIDATE_RESOURCE_KEY``), which every edit
        operation below modifies, and
      * the running copy (``RUNNING_RESOURCE_KEY``), which is written only at
        creation time and by ``copy_candidate_ietf_slice_data_to_running``.
    Only the first ``slice-service`` entry of the stored data is read or edited.

    Every method builds and returns a ``Slice``; none of them writes it back to
    the context service (presumably the caller does).
    """

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _get_existing_slice(context_client: ContextClient, slice_name: str) -> Slice:
        """Fetch ``slice_name`` via ``get_slice_by_defualt_name``; raise if it returns None."""
        slice_request = get_slice_by_defualt_name(
            context_client, slice_name, rw_copy=False
        )
        if slice_request is None:
            raise Exception("Slice({:s}) not found".format(str(slice_name)))
        return slice_request

    @staticmethod
    def _to_config_fields(data: Dict) -> Dict:
        """Wrap each top-level item of ``data`` as the ``(value, RAISE_IF_DIFFERS)``
        tuples expected by ``update_config_rule_custom``."""
        return {name: (value, RAISE_IF_DIFFERS) for name, value in data.items()}

    @staticmethod
    def _load_candidate_ietf_data(
        slice_request: Slice, not_found_message: str = "ietf data not found"
    ) -> Dict:
        """Decode and return the candidate IETF JSON stored on ``slice_request``.

        Raises:
            Exception: with ``not_found_message`` if there is no candidate rule.
        """
        candidate_cr = get_custom_config_rule(
            slice_request.slice_config, CANDIDATE_RESOURCE_KEY
        )
        if candidate_cr is None:
            raise Exception(not_found_message)
        return json.loads(candidate_cr.custom.resource_value)

    @staticmethod
    def _store_candidate_ietf_data(slice_request: Slice, ietf_data: Dict) -> None:
        """Write ``ietf_data`` back into ``slice_request``'s candidate config rule."""
        update_config_rule_custom(
            slice_request.slice_config.config_rules,
            CANDIDATE_RESOURCE_KEY,
            IETFSliceHandler._to_config_fields(ietf_data),
        )

    @staticmethod
    def _first_slice_service(ietf_data: Dict) -> Dict:
        """Return the first (and only handled) slice-service of ``ietf_data``."""
        return ietf_data["network-slice-services"]["slice-service"][0]

    @staticmethod
    def _index_of(items: List[Dict], key: str, value, item_kind: str) -> int:
        """Return the index of the first item with ``item[key] == value``.

        Raises:
            ValueError: if no item matches.
        """
        for idx, item in enumerate(items):
            if item[key] == value:
                return idx
        raise ValueError("{:s}({:s}) not found".format(item_kind, str(value)))

    @staticmethod
    def _first_target_connection_group_id(sdp: Dict) -> Optional[str]:
        """Return the target connection group of the SDP's first match criterion,
        or ``None`` if the SDP has no match criteria."""
        match_criteria = sdp["service-match-criteria"]["match-criterion"]
        if not match_criteria:
            return None
        return match_criteria[0]["target-connection-group-id"]

    @staticmethod
    def _sdp_to_endpoint_id(sdp: Dict) -> EndPointId:
        """Build the ``EndPointId`` of an SDP from its node and attachment circuit.

        Raises:
            Exception: if the SDP does not have exactly one attachment circuit.
        """
        attachment_circuits = sdp["attachment-circuits"]["attachment-circuit"]
        if len(attachment_circuits) != 1:
            raise Exception("All SDPs should have 1 attachment-circuit")
        endpoint = EndPointId()
        endpoint.topology_id.context_id.context_uuid.uuid = DEFAULT_CONTEXT_NAME
        endpoint.device_id.device_uuid.uuid = sdp["node-id"]
        endpoint.endpoint_uuid.uuid = attachment_circuits[0]["ac-tp-id"]
        return endpoint

    @staticmethod
    def _build_slo_constraints(connection_group: Dict) -> List[Constraint]:
        """Translate the SLO metric bounds of a connection group into constraints.

        Only the first connectivity-construct is read. Supported metric types:
          * ``ietf-nss:one-way-delay-maximum`` -> ``sla_latency.e2e_latency_ms``
          * ``ietf-nss:one-way-bandwidth`` -> ``sla_capacity.capacity_gbps``
            (bound divided by 1e3; presumably Mbps -> Gbps, TODO confirm units)
        Other metric types are ignored.
        """
        metric_bounds = connection_group["connectivity-construct"][0][
            "service-slo-sle-policy"
        ]["slo-policy"]["metric-bound"]
        constraints = []
        for metric in metric_bounds:
            metric_type = metric["metric-type"]
            if metric_type == "ietf-nss:one-way-delay-maximum":
                constraint = Constraint()
                constraint.sla_latency.e2e_latency_ms = float(metric["bound"])
            elif metric_type == "ietf-nss:one-way-bandwidth":
                constraint = Constraint()
                constraint.sla_capacity.capacity_gbps = float(metric["bound"]) / 1.0e3
            else:
                continue
            constraints.append(constraint)
        return constraints

    # -------------------------------------------------------------- public API

    @staticmethod
    def get_all_ietf_slices(context_client: ContextClient) -> Dict:
        """Return the candidate IETF data of every slice as a single document.

        The result has the shape
        ``{"network-slice-services": {"slice-service": [...]}}`` and holds the
        first slice-service of each slice's candidate copy. Slices without a
        candidate IETF config rule (presumably created through another
        interface) are skipped.

        Raises:
            Exception: if the context service does not hold exactly one context.
        """
        context_ids = list(context_client.ListContextIds(Empty()).context_ids)
        if len(context_ids) != 1:
            raise Exception("Number of contexts should be 1")
        slice_service_list = []
        for slice_obj in context_client.ListSlices(context_ids[0]).slices:
            candidate_cr = get_custom_config_rule(
                slice_obj.slice_config, CANDIDATE_RESOURCE_KEY
            )
            if candidate_cr is None:
                LOGGER.debug(
                    "Skipping slice %s: no candidate IETF data",
                    slice_obj.slice_id.slice_uuid.uuid,
                )
                continue
            candidate_ietf_data = json.loads(candidate_cr.custom.resource_value)
            slice_service_list.append(
                IETFSliceHandler._first_slice_service(candidate_ietf_data)
            )
        return {"network-slice-services": {"slice-service": slice_service_list}}

    @staticmethod
    def create_slice_service(
        request_data: dict, context_client: ContextClient
    ) -> Slice:
        """Build a new ``Slice`` from an IETF slice-service creation request.

        ``request_data`` may be the full document or just the content of
        ``network-slice-services``; in the latter case it is wrapped. Only the
        first slice-service is used. It must have exactly two SDPs, each with
        exactly one attachment circuit, and both SDPs' first match criteria must
        target the same connection group.

        The returned slice (status PLANNED) carries:
          * the endpoints, ordered by ``sort_endpoints`` and adjusted by
            ``replace_ont_endpoint_with_emu_dc``;
          * SLA constraints from the targeted connection group's SLO policy;
          * an owner derived from the slice id;
          * both the running and candidate copies of the request;
          * one ``/device[..]/endpoint[..]/settings`` rule per endpoint.

        Raises:
            Exception: if any structural check above fails, or if the targeted
                connection group does not exist.
        """
        if "network-slice-services" not in request_data:
            request_data = {"network-slice-services": request_data}
        validate_ietf_slice_data(request_data)
        slice_service = IETFSliceHandler._first_slice_service(request_data)
        slice_id = slice_service["id"]
        sdps = slice_service["sdps"]["sdp"]
        connection_groups = slice_service["connection-groups"]["connection-group"]
        if len(sdps) != 2:
            raise Exception("Number of SDPs should be 2")

        slice_request: Slice = Slice()
        slice_request.slice_id.context_id.context_uuid.uuid = DEFAULT_CONTEXT_NAME
        slice_request.slice_id.slice_uuid.uuid = slice_id
        slice_request.slice_status.slice_status = SliceStatusEnum.SLICESTATUS_PLANNED

        list_endpoints = []
        endpoint_config_rules = []
        connection_group_ids = set()
        for sdp in sdps:
            endpoint = IETFSliceHandler._sdp_to_endpoint_id(sdp)
            list_endpoints.append(endpoint)
            connection_group_ids.add(
                IETFSliceHandler._first_target_connection_group_id(sdp)
            )
            device_uuid = endpoint.device_id.device_uuid.uuid
            endpoint_uuid = endpoint.endpoint_uuid.uuid
            # NOTE(review): the attachment-circuit tp id is stored as the
            # endpoint's "address_ip" — confirm this is intended.
            endpoint_config_rules.append(
                (
                    f"/device[{device_uuid}]/endpoint[{endpoint_uuid}]/settings",
                    {
                        "address_ip": (endpoint_uuid, RAISE_IF_DIFFERS),
                        "address_prefix": (ADDRESS_PREFIX, RAISE_IF_DIFFERS),
                    },
                )
            )
        if len(connection_group_ids) != 1:
            raise Exception("SDPs target-connection-group-id do not match")
        (target_connection_group_id,) = connection_group_ids

        target_connection_group = next(
            (cg for cg in connection_groups if cg["id"] == target_connection_group_id),
            None,
        )
        if target_connection_group is None:
            raise Exception("connection group not found")
        list_constraints = IETFSliceHandler._build_slo_constraints(
            target_connection_group
        )

        list_endpoints = sort_endpoints(
            list_endpoints, sdps, target_connection_group, context_client
        )
        list_endpoints = replace_ont_endpoint_with_emu_dc(
            list_endpoints, context_client
        )
        slice_request.slice_endpoint_ids.extend(list_endpoints)
        slice_request.slice_constraints.extend(list_constraints)

        # TODO adding owner, needs to be recoded after updating the bindings
        owner = slice_id
        slice_request.slice_owner.owner_string = owner
        slice_request.slice_owner.owner_uuid.uuid = str(
            uuid.uuid5(uuid.NAMESPACE_DNS, owner)
        )

        config_rules = slice_request.slice_config.config_rules
        ietf_slice_fields = IETFSliceHandler._to_config_fields(request_data)
        update_config_rule_custom(config_rules, RUNNING_RESOURCE_KEY, ietf_slice_fields)
        update_config_rule_custom(
            config_rules, CANDIDATE_RESOURCE_KEY, ietf_slice_fields
        )
        for ep_cr_key, ep_cr_fields in endpoint_config_rules:
            update_config_rule_custom(config_rules, ep_cr_key, ep_cr_fields)
        return slice_request

    @staticmethod
    def create_sdp(
        request_data: dict, slice_uuid: str, context_client: ContextClient
    ) -> Slice:
        """Append the single SDP in ``request_data["sdp"]`` to the candidate data.

        NOTE(review): unlike create_connection_group, the updated data is not
        re-validated against the YANG model.

        Raises:
            Exception: if the request does not carry exactly one SDP, or the slice
                or its candidate IETF data is missing.
        """
        sdps = request_data["sdp"]
        if len(sdps) != 1:
            raise Exception("Number of SDPs should be 1")
        new_sdp = sdps[0]
        slice_request = IETFSliceHandler._get_existing_slice(context_client, slice_uuid)
        ietf_data = IETFSliceHandler._load_candidate_ietf_data(slice_request)
        slice_service = IETFSliceHandler._first_slice_service(ietf_data)
        slice_service["sdps"]["sdp"].append(new_sdp)
        IETFSliceHandler._store_candidate_ietf_data(slice_request, ietf_data)
        return slice_request

    @staticmethod
    def delete_sdp(
        slice_uuid: str, sdp_id: str, context_client: ContextClient
    ) -> Slice:
        """Remove the SDP with id ``sdp_id`` from the candidate data.

        Raises:
            ValueError: if no SDP has that id.
            Exception: if the slice or its candidate IETF data is missing.
        """
        slice_request = IETFSliceHandler._get_existing_slice(context_client, slice_uuid)
        ietf_data = IETFSliceHandler._load_candidate_ietf_data(slice_request)
        slice_sdps = IETFSliceHandler._first_slice_service(ietf_data)["sdps"]["sdp"]
        sdp_idx = IETFSliceHandler._index_of(slice_sdps, "id", sdp_id, "SDP")
        slice_sdps.pop(sdp_idx)
        IETFSliceHandler._store_candidate_ietf_data(slice_request, ietf_data)
        return slice_request

    @staticmethod
    def create_connection_group(
        request_data: dict, slice_id: str, context_client: ContextClient
    ) -> Slice:
        """Append the single connection group in ``request_data`` to the candidate data.

        After the candidate rule is updated, the whole candidate data is
        validated against the YANG model.

        Raises:
            Exception: if the request does not carry exactly one connection
                group, or the slice or its candidate IETF data is missing.
        """
        connection_groups = request_data["connection-group"]
        if len(connection_groups) != 1:
            raise Exception("Number of connection groups should be 1")
        new_connection_group = connection_groups[0]
        slice_request = IETFSliceHandler._get_existing_slice(context_client, slice_id)
        ietf_data = IETFSliceHandler._load_candidate_ietf_data(slice_request)
        slice_service = IETFSliceHandler._first_slice_service(ietf_data)
        slice_service["connection-groups"]["connection-group"].append(
            new_connection_group
        )
        IETFSliceHandler._store_candidate_ietf_data(slice_request, ietf_data)
        validate_ietf_slice_data(ietf_data)
        return slice_request

    @staticmethod
    def update_connection_group(
        slice_name: str,
        updated_connection_group: dict,
        context_client: ContextClient,
    ) -> Slice:
        """Replace the connection group with the same ``id`` in the candidate data.

        The slice status is reset to PLANNED.

        Raises:
            ValueError: if no connection group has that id.
            Exception: if the slice or its candidate IETF data is missing.
        """
        slice_request = IETFSliceHandler._get_existing_slice(context_client, slice_name)
        candidate_ietf_data = IETFSliceHandler._load_candidate_ietf_data(slice_request)
        slice_service = IETFSliceHandler._first_slice_service(candidate_ietf_data)
        slice_connection_groups = slice_service["connection-groups"]["connection-group"]
        cg_idx = IETFSliceHandler._index_of(
            slice_connection_groups,
            "id",
            updated_connection_group["id"],
            "connection group",
        )
        slice_connection_groups[cg_idx] = updated_connection_group
        IETFSliceHandler._store_candidate_ietf_data(slice_request, candidate_ietf_data)
        slice_request.slice_status.slice_status = SliceStatusEnum.SLICESTATUS_PLANNED
        return slice_request

    @staticmethod
    def delete_connection_group(
        slice_uuid: str, connection_group_id: str, context_client: ContextClient
    ) -> Slice:
        """Remove the connection group ``connection_group_id`` from the candidate data.

        The slice status is reset to PLANNED.

        Raises:
            ValueError: if no connection group has that id.
            Exception: if the slice or its candidate IETF data is missing.
        """
        slice_request = IETFSliceHandler._get_existing_slice(context_client, slice_uuid)
        candidate_ietf_data = IETFSliceHandler._load_candidate_ietf_data(slice_request)
        slice_service = IETFSliceHandler._first_slice_service(candidate_ietf_data)
        slice_connection_groups = slice_service["connection-groups"]["connection-group"]
        cg_idx = IETFSliceHandler._index_of(
            slice_connection_groups, "id", connection_group_id, "connection group"
        )
        slice_connection_groups.pop(cg_idx)
        IETFSliceHandler._store_candidate_ietf_data(slice_request, candidate_ietf_data)
        slice_request.slice_status.slice_status = SliceStatusEnum.SLICESTATUS_PLANNED
        return slice_request

    @staticmethod
    def create_match_criteria(
        request_data: dict, slice_name: str, sdp_id: str, context_client: ContextClient
    ) -> Slice:
        """Add one match criterion to SDP ``sdp_id`` and recompute the constraints.

        The slice constraints are replaced by those derived from the connection
        group targeted by the new criterion, and the slice status is set to
        PLANNED.

        Raises:
            Exception: in any of these cases:
                * the request does not carry exactly one match criterion;
                * no SDP's first match criterion already targets that
                  connection group ("Second SDP not found");
                * ``sdp_id`` does not exist;
                * either SDP lacks exactly one attachment circuit;
                * the targeted connection group does not exist;
                * the slice or its candidate data is missing.
        """
        match_criteria = request_data["match-criterion"]
        if len(match_criteria) != 1:
            raise Exception("Number of match criteria should be 1")
        new_match_criterion = match_criteria[0]
        target_connection_group_id = new_match_criterion["target-connection-group-id"]
        slice_request = IETFSliceHandler._get_existing_slice(context_client, slice_name)
        ietf_data = IETFSliceHandler._load_candidate_ietf_data(slice_request)
        slice_service = IETFSliceHandler._first_slice_service(ietf_data)
        sdps = slice_service["sdps"]["sdp"]
        connection_groups = slice_service["connection-groups"]["connection-group"]
        slice_request.slice_status.slice_status = SliceStatusEnum.SLICESTATUS_PLANNED

        # The peer SDP is the first one whose first match criterion already
        # targets the same connection group; SDPs without criteria are skipped.
        peer_sdp = next(
            (
                sdp
                for sdp in sdps
                if IETFSliceHandler._first_target_connection_group_id(sdp)
                == target_connection_group_id
            ),
            None,
        )
        if peer_sdp is None:
            raise Exception("Second SDP not found")
        # NOTE(review): endpoint ids are built only to validate the attachment
        # circuits; they are not written to slice_endpoint_ids — confirm intended.
        IETFSliceHandler._sdp_to_endpoint_id(peer_sdp)

        target_sdp = next((sdp for sdp in sdps if sdp["id"] == sdp_id), None)
        if target_sdp is None:
            raise Exception("SDP not found")
        target_sdp["service-match-criteria"]["match-criterion"].append(
            new_match_criterion
        )
        IETFSliceHandler._sdp_to_endpoint_id(target_sdp)

        target_connection_group = next(
            (cg for cg in connection_groups if cg["id"] == target_connection_group_id),
            None,
        )
        if target_connection_group is None:
            raise Exception("connection group not found")
        list_constraints = IETFSliceHandler._build_slo_constraints(
            target_connection_group
        )
        del slice_request.slice_constraints[:]
        slice_request.slice_constraints.extend(list_constraints)
        IETFSliceHandler._store_candidate_ietf_data(slice_request, ietf_data)
        return slice_request

    @staticmethod
    def delete_match_criteria(
        slice_uuid: str,
        sdp_id: str,
        match_criterion_id: int,
        context_client: ContextClient,
    ) -> Slice:
        """Remove the match criterion with ``index == match_criterion_id`` from SDP ``sdp_id``.

        Raises:
            ValueError: if the SDP has no criterion with that index.
            Exception: if the SDP, the slice or its candidate data is missing.
        """
        slice_request = IETFSliceHandler._get_existing_slice(context_client, slice_uuid)
        ietf_data = IETFSliceHandler._load_candidate_ietf_data(slice_request)
        sdps = IETFSliceHandler._first_slice_service(ietf_data)["sdps"]["sdp"]
        target_sdp = next((sdp for sdp in sdps if sdp["id"] == sdp_id), None)
        if target_sdp is None:
            raise Exception("SDP not found")
        match_criteria = target_sdp["service-match-criteria"]["match-criterion"]
        # NOTE(review): "index" is compared with ==, so an int id will not match
        # a str index (or vice versa) — confirm what the caller passes.
        match_criterion_idx = IETFSliceHandler._index_of(
            match_criteria, "index", match_criterion_id, "match-criterion"
        )
        del match_criteria[match_criterion_idx]
        IETFSliceHandler._store_candidate_ietf_data(slice_request, ietf_data)
        return slice_request

    @staticmethod
    def copy_candidate_ietf_slice_data_to_running(
        slice_uuid: str, context_client: ContextClient
    ) -> Slice:
        """Overwrite the running IETF copy of the slice with its candidate copy.

        Raises:
            Exception: if the slice or its candidate IETF data is missing.
        """
        slice_request = IETFSliceHandler._get_existing_slice(context_client, slice_uuid)
        candidate_ietf_data = IETFSliceHandler._load_candidate_ietf_data(
            slice_request, "candidate ietf slice data not found"
        )
        update_config_rule_custom(
            slice_request.slice_config.config_rules,
            RUNNING_RESOURCE_KEY,
            IETFSliceHandler._to_config_fields(candidate_ietf_data),
        )
        return slice_request
