# Copyright 2021-2023 H2020 TeraFlow (https://www.teraflow-h2020.eu/)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from typing import Dict, List, Optional, Set, Tuple
from common.proto.context_pb2 import Link
from common.proto.pathcomp_pb2 import Algorithm_KDisjointPath, Algorithm_KShortestPath, PathCompReply
from ._Algorithm import _Algorithm
from .KShortestPathAlgorithm import KShortestPathAlgorithm

class KDisjointPathAlgorithm(_Algorithm):
    """Compute up to ``num_disjoint`` link-disjoint paths for each service.

    The work is delegated to a ``KShortestPathAlgorithm`` configured to return
    a single path. It is executed ``num_disjoint`` times; after every returned
    path, the links that path traverses are removed from the link list handed
    to the following runs. The link list is shared by all services, so a link
    used by one service's path is also unavailable to the other services in
    later runs.
    """

    def __init__(self, algorithm : Algorithm_KDisjointPath, class_name=__name__) -> None:
        """Store the number of disjoint paths requested by ``algorithm``."""
        super().__init__('KDP', False, class_name=class_name)
        self.num_disjoint = algorithm.num_disjoint
        # (context id, service uuid) -> one entry per run; None means "no path".
        # Created here so that get_reply() before execute() yields an empty
        # reply instead of raising AttributeError.
        self.json_reply : Dict[Tuple[str, str], List[Optional[List[Dict]]]] = dict()

    def get_link_from_endpoint(self, endpoint : Dict) -> Optional[Tuple[Dict, Link]]:
        """Look up the link entry registered for an endpoint.

        Args:
            endpoint: dict providing the 'device_id' and 'endpoint_uuid' keys.

        Returns:
            The entry stored in ``endpoint_to_link_dict`` under
            ``(device_id, endpoint_uuid)``, or None (after logging a warning)
            when there is no such entry.
        """
        device_uuid = endpoint['device_id']
        endpoint_uuid = endpoint['endpoint_uuid']
        item = self.endpoint_to_link_dict.get((device_uuid, endpoint_uuid))
        if item is None:
            self.logger.warning('Link for Endpoint(%s, %s) not found', device_uuid, endpoint_uuid)
        return item

    def path_to_links(self, path_endpoints : List[Dict]) -> Tuple[List[Dict], Set[str]]:
        """Translate a path given as endpoints into the links it traverses.

        Endpoints without a known link are skipped. Consecutive endpoints that
        resolve to the same link contribute a single entry.

        Args:
            path_endpoints: endpoints of the path, in traversal order.

        Returns:
            A tuple ``(links, link_ids)``: the JSON links in traversal order
            and the set of their 'link_Id' values.
        """
        path_links : List[Dict] = list()
        path_link_ids : Set[str] = set()
        for endpoint in path_endpoints:
            link_tuple = self.get_link_from_endpoint(endpoint)
            if link_tuple is None: continue
            json_link, _ = link_tuple
            json_link_id = json_link['link_Id']
            # Both endpoints of a hop map to the same link; keep it only once.
            if path_links and path_links[-1]['link_Id'] == json_link_id: continue
            path_links.append(json_link)
            path_link_ids.add(json_link_id)
        return path_links, path_link_ids

    def remove_traversed_links(self, link_list : List[Dict], path_endpoints : List[Dict]) -> List[Dict]:
        """Return a copy of ``link_list`` without the links used by a path.

        Args:
            link_list: candidate JSON links; it is not modified.
            path_endpoints: endpoints of the path whose links must be dropped.

        Returns:
            A new list holding the links whose 'link_Id' is not traversed by
            the path.
        """
        _, path_link_ids = self.path_to_links(path_endpoints)
        new_link_list = [link for link in link_list if link['link_Id'] not in path_link_ids]
        # Lazy %-style arguments: the lists are only rendered if INFO is enabled.
        self.logger.info('cur_link_list = %s', link_list)
        self.logger.info('new_link_list = %s', new_link_list)
        return new_link_list

    def execute(self, dump_request_filename: Optional[str] = None, dump_reply_filename: Optional[str] = None) -> None:
        """Run the inner shortest-path algorithm ``num_disjoint`` times.

        The outcome is stored in ``self.json_reply``: for every service key
        ``(contextId, service_uuid)`` a list with one entry per run in which
        the service was reported, each entry being the list of path endpoints
        or None when no path was found.

        Args:
            dump_request_filename: currently unused; the inner runs are
                executed without dumping their requests.
            dump_reply_filename: currently unused; the inner runs are executed
                without dumping their replies.
        """
        algorithm = KShortestPathAlgorithm(Algorithm_KShortestPath(k_inspection=0, k_return=1))
        algorithm.sync_paths = True
        # The inner algorithm works on the topology and services of this one.
        algorithm.device_list = self.device_list
        algorithm.device_dict = self.device_dict
        algorithm.endpoint_dict = self.endpoint_dict
        algorithm.link_list = self.link_list
        algorithm.link_dict = self.link_dict
        algorithm.endpoint_to_link_dict = self.endpoint_to_link_dict
        algorithm.service_list = self.service_list
        algorithm.service_dict = self.service_dict

        # Start from a clean result on every execution.
        self.json_reply = dict()

        for _ in range(self.num_disjoint):
            algorithm.execute()
            response_list = algorithm.json_reply.get('response-list', [])
            for response in response_list:
                service_id = response['serviceId']
                service_key = (service_id['contextId'], service_id['service_uuid'])
                json_reply_service = self.json_reply.setdefault(service_key, list())

                no_path_issue = response.get('noPath', {}).get('issue')
                if no_path_issue is not None:
                    json_reply_service.append(None)
                    continue

                paths = response.get('path')
                if not paths:
                    # Neither an issue nor a path was reported: record it as
                    # "no path" rather than aborting the whole computation.
                    self.logger.warning(
                        'Response for Service(%s, %s) contains no path',
                        service_key[0], service_key[1])
                    json_reply_service.append(None)
                    continue

                path_endpoints = paths[0]['devices']
                json_reply_service.append(path_endpoints)
                # remove_traversed_links() builds a new list, so self.link_list
                # is left untouched while the inner algorithm loses the links.
                algorithm.link_list = self.remove_traversed_links(algorithm.link_list, path_endpoints)

        self.logger.info('self.json_reply = %s', self.json_reply)

    def get_reply(self) -> PathCompReply:
        """Build the ``PathCompReply`` from the paths gathered by ``execute``.

        One service is added per service key and one connection per path found
        for it; runs that produced no path (None entries) are skipped.

        Returns:
            The populated reply; it is empty when no results are stored.
        """
        reply = PathCompReply()
        for (context_id, service_uuid), paths in self.json_reply.items():
            grpc_service = self.add_service_to_reply(reply, context_id, service_uuid)
            for path_endpoints in paths:
                if path_endpoints is None: continue
                grpc_connection = self.add_connection_to_reply(reply, grpc_service)
                self.add_path_to_connection(grpc_connection, path_endpoints)
        return reply
