# Copyright 2021-2023 H2020 TeraFlow (https://www.teraflow-h2020.eu/)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pip install psycopg==3.1.6
# Ref: https://www.cockroachlabs.com/docs/stable/changefeed-for.html
# (current implementation) Ref: https://www.cockroachlabs.com/docs/v22.1/changefeed-for
# Ref: https://www.psycopg.org/psycopg3/docs/api/crdb.html

import contextlib, json, logging, psycopg, psycopg.conninfo, psycopg.crdb, sys, time
from typing import Any, Dict, Iterator, List, Optional, Tuple
from common.Settings import get_setting

LOGGER = logging.getLogger(__name__)

# Cluster-wide setting that enables rangefeeds; CockroachDB changefeeds need it
# (see the CockroachDB CHANGEFEED FOR reference linked at the top of this file).
SQL_ACTIVATE_CHANGE_FEED = 'SET CLUSTER SETTING kv.rangefeed.enabled = true'
# Core (sinkless) changefeed over '<database>.<table>', filled in via str.format().
#   format=json      -> the key and value columns are emitted as JSON documents
#   no_initial_scan  -> do not emit the rows already present when the feed starts
#   updated          -> include the change timestamp in each value's 'updated' field
# NOTE(review): both names are interpolated unquoted; callers must pass trusted identifiers.
SQL_START_CHANGE_FEED = 'EXPERIMENTAL CHANGEFEED FOR {:s}.{:s} WITH format=json, no_initial_scan, updated'

class ChangeFeedClient:
    """Stream row changes of a CockroachDB table through a core changefeed.

    Usage: call `initialize()` once (it reads the CRDB_URI setting and opens the
    connection), then iterate `get_changes(table_name)`.
    """

    def __init__(self) -> None:
        # Connection opened by initialize(); None until then.
        self._connection : Optional[psycopg.crdb.CrdbConnection] = None
        # Connection parameters parsed from CRDB_URI; empty until initialize() parses it.
        self._conn_info_dict : Dict = dict()
        # Whether the connected server identified itself as CockroachDB.
        self._is_crdb : bool = False

    def initialize(self) -> bool:
        """Connect to the database named by CRDB_URI and enable rangefeeds.

        Returns:
            True when connected to a CockroachDB server and rangefeeds were enabled.
            False when CRDB_URI is missing or unparsable, or the server is not CockroachDB.

        Raises:
            psycopg.Error: if connecting or executing the cluster setting fails.
        """
        crdb_uri = get_setting('CRDB_URI')
        if crdb_uri is None:
            LOGGER.error('Connection string not found in EnvVar CRDB_URI')
            return False

        try:
            # psycopg does not accept the 'cockroachdb://' scheme; rewrite only the
            # leading scheme, never occurrences elsewhere in the URI.
            if crdb_uri.startswith('cockroachdb://'):
                crdb_uri = 'postgres://' + crdb_uri[len('cockroachdb://'):]
            self._conn_info_dict = psycopg.conninfo.conninfo_to_dict(crdb_uri)
        except psycopg.ProgrammingError:
            # NOTE(review): the logged URI may embed credentials.
            LOGGER.exception('Invalid connection string: {:s}'.format(str(crdb_uri)))
            return False

        self._connection = psycopg.crdb.connect(**self._conn_info_dict)
        self._is_crdb = psycopg.crdb.CrdbConnection.is_crdb(self._connection)
        LOGGER.debug('is_crdb = {:s}'.format(str(self._is_crdb)))

        # disable multi-statement transactions
        self._connection.autocommit = True

        # SET CLUSTER SETTING is CockroachDB-specific; other servers reject it.
        if not self._is_crdb:
            return False

        # activate change feeds
        self._connection.execute(SQL_ACTIVATE_CHANGE_FEED)

        return True

    def get_changes(self, table_name : str) -> Iterator[Tuple[float, str, List[Any], bool, Optional[Dict]]]:
        """Yield the changes of `table_name` as they are emitted by the changefeed.

        This generator blocks while waiting for new changes. The cursor is closed
        when the generator is exhausted, closed, or garbage-collected.

        Args:
            table_name: table in the connection's database to watch. It is inserted
                into the SQL unquoted, so it must be a trusted identifier.

        Yields:
            (timestamp, table, primary_key, is_delete, after):
            timestamp   -- change time in seconds since the epoch, or the local time
                           if the message carries no 'updated' field;
            table       -- table name reported by the feed;
            primary_key -- JSON-decoded key column;
            is_delete   -- True when the message has an 'after' field that is null;
            after       -- the row after the change (None for deletions).

        Raises:
            Exception: if initialize() has not successfully opened a connection.
        """
        db_name = self._conn_info_dict.get('dbname')
        if db_name is None or self._connection is None:
            raise Exception('ChangeFeed has not been initialized!')
        str_sql_query = SQL_START_CHANGE_FEED.format(db_name, table_name)
        with self._connection.cursor() as cursor:
            with contextlib.closing(cursor.stream(str_sql_query)) as feed:
                for change in feed:
                    LOGGER.info(change)
                    change_table = change[0]
                    primary_key = json.loads(change[1])
                    data = json.loads(change[2])
                    # 'updated' is a decimal HLC timestamp in nanoseconds, sent as a
                    # string (e.g. "1549997012394318000.0000000000"); parse before scaling.
                    updated = data.get('updated')
                    timestamp = time.time() if updated is None else float(updated) / 1.e9
                    after = data.get('after')
                    is_delete = ('after' in data) and (after is None)
                    yield timestamp, change_table, primary_key, is_delete, after

def main() -> int:
    """Log every change of the 'context' table until the change feed ends.

    Returns:
        0 if the change feed terminates.

    Raises:
        Exception: if the change-feed client cannot be initialized.
    """
    logging.basicConfig(level=logging.INFO)

    cf = ChangeFeedClient()
    ready = cf.initialize()
    if not ready: raise Exception('Unable to initialize ChangeFeed')
    for change in cf.get_changes('context'):
        LOGGER.info(change)

    return 0

if __name__ == '__main__':
    sys.exit(main())
