# Copyright 2021-2023 H2020 TeraFlow (https://www.teraflow-h2020.eu/)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json, os, pytest, sqlalchemy
from _pytest.config import Config
from _pytest.terminal import TerminalReporter
from prettytable import PrettyTable
from typing import Any, Dict, List, Tuple
from common.Constants import ServiceNameEnum
from common.Settings import (
    ENVVAR_SUFIX_SERVICE_HOST, ENVVAR_SUFIX_SERVICE_PORT_GRPC, ENVVAR_SUFIX_SERVICE_PORT_HTTP, get_env_var_name,
    get_service_port_grpc, get_service_port_http)
from common.message_broker.Factory import get_messagebroker_backend, BackendEnum as MessageBrokerBackendEnum
from common.message_broker.MessageBroker import MessageBroker
from context.client.ContextClient import ContextClient
from context.service.ContextService import ContextService
from context.service.Database import Database
from context.service.Engine import Engine
from context.service.database.models._Base import rebuild_database
#from context.service._old_code.Populate import populate
#from context.service.rest_server.RestServer import RestServer
#from context.service.rest_server.Resources import RESOURCES


LOCAL_HOST = '127.0.0.1'
GRPC_PORT = 10000 + int(get_service_port_grpc(ServiceNameEnum.CONTEXT))   # avoid privileged ports
HTTP_PORT = 10000 + int(get_service_port_http(ServiceNameEnum.CONTEXT))   # avoid privileged ports

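# Expose the test host and ports to the Context service and client through the standard environment variables.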
os.environ[get_env_var_name(ServiceNameEnum.CONTEXT, ENVVAR_SUFIX_SERVICE_HOST     )] = str(LOCAL_HOST)
os.environ[get_env_var_name(ServiceNameEnum.CONTEXT, ENVVAR_SUFIX_SERVICE_PORT_GRPC)] = str(GRPC_PORT)
os.environ[get_env_var_name(ServiceNameEnum.CONTEXT, ENVVAR_SUFIX_SERVICE_PORT_HTTP)] = str(HTTP_PORT)

#DEFAULT_REDIS_SERVICE_HOST = LOCAL_HOST
#DEFAULT_REDIS_SERVICE_PORT = 6379
#DEFAULT_REDIS_DATABASE_ID  = 0

#REDIS_CONFIG = {
#    'REDIS_SERVICE_HOST': os.environ.get('REDIS_SERVICE_HOST', DEFAULT_REDIS_SERVICE_HOST),
#    'REDIS_SERVICE_PORT': os.environ.get('REDIS_SERVICE_PORT', DEFAULT_REDIS_SERVICE_PORT),
#    'REDIS_DATABASE_ID' : os.environ.get('REDIS_DATABASE_ID',  DEFAULT_REDIS_DATABASE_ID ),
#}

#SCENARIOS = [
#    ('db:cockroach_mb:inmemory', None, {}, None, {}),
#    ('all_inmemory', DatabaseBackendEnum.INMEMORY, {},           MessageBrokerBackendEnum.INMEMORY, {}          ),
#    ('all_redis',    DatabaseBackendEnum.REDIS,    REDIS_CONFIG, MessageBrokerBackendEnum.REDIS,    REDIS_CONFIG),
#]

#@pytest.fixture(scope='session', ids=[str(scenario[0]) for scenario in SCENARIOS], params=SCENARIOS)
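# Session-scoped fixture: provides a freshly rebuilt database engine and an in-memory message broker.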
@pytest.fixture(scope='session')
def context_db_mb(request) -> Tuple[sqlalchemy.engine.Engine, MessageBroker]:   # pylint: disable=unused-argument
    #name,db_session,mb_backend,mb_settings = request.param
    #msg = 'Running scenario {:s} db_session={:s}, mb_backend={:s}, mb_settings={:s}...'
    #LOGGER.info(msg.format(str(name), str(db_session), str(mb_backend.value), str(mb_settings)))

    _db_engine = Engine.get_engine()
    Engine.drop_database(_db_engine)
    Engine.create_database(_db_engine)
    rebuild_database(_db_engine)

    _msg_broker = MessageBroker(get_messagebroker_backend(backend=MessageBrokerBackendEnum.INMEMORY))
    yield _db_engine, _msg_broker
    _msg_broker.terminate()

RAW_METRICS = None   # populated by the context_service_grpc fixture; read by pytest_terminal_summary

@pytest.fixture(scope='session')
def context_service_grpc(context_db_mb : Tuple[sqlalchemy.engine.Engine, MessageBroker]): # pylint: disable=redefined-outer-name
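    """Session-scoped fixture: starts the Context gRPC service on the test engine/broker and captures its metrics for the terminal summary."""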
    global RAW_METRICS # pylint: disable=global-statement
    _service = ContextService(context_db_mb[0], context_db_mb[1])
    RAW_METRICS = _service.context_servicer._get_metrics()
    _service.start()
    yield _service
    _service.stop()

#@pytest.fixture(scope='session')
#def context_service_rest(context_db_mb : Tuple[Database, MessageBroker]): # pylint: disable=redefined-outer-name
#    database = context_db_mb[0]
#    _rest_server = RestServer()
#    for endpoint_name, resource_class, resource_url in RESOURCES:
#        _rest_server.add_resource(resource_class, resource_url, endpoint=endpoint_name, resource_class_args=(database,))
#    _rest_server.start()
#    time.sleep(1) # give the server time to start
#    yield _rest_server
#    _rest_server.shutdown()
#    _rest_server.join()

@pytest.fixture(scope='session')
def context_client_grpc(
    context_service_grpc : ContextService   # pylint: disable=redefined-outer-name,unused-argument
):
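    """Session-scoped fixture: yields a ContextClient connected to the test Context service instance."""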
    _client = ContextClient()
    yield _client
    _client.close()

@pytest.hookimpl(hookwrapper=True)
def pytest_terminal_summary(
    terminalreporter : TerminalReporter, exitstatus : int, config : Config  # pylint: disable=unused-argument
):
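    """pytest hook: after the test session finishes, print per-method call counters and average durations."""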
    yield

    if RAW_METRICS is None: return   # no metrics were collected (the gRPC service fixture was not used)

    method_to_metric_fields : Dict[str, Dict[str, Dict[str, Any]]] = dict()
    for raw_metric_name,raw_metric_data in RAW_METRICS.items():
        if '_COUNTER_' in raw_metric_name:
            method_name,metric_name = raw_metric_name.split('_COUNTER_')
        elif '_HISTOGRAM_' in raw_metric_name:
            method_name,metric_name = raw_metric_name.split('_HISTOGRAM_')
        else:
            raise Exception('Unsupported metric: {:s}'.format(raw_metric_name))
        metric_data = method_to_metric_fields.setdefault(method_name, dict()).setdefault(metric_name, dict())
        for field_name,labels,value,_,_ in raw_metric_data._child_samples():
            if len(labels) > 0: field_name = '{:s}:{:s}'.format(field_name, json.dumps(labels, sort_keys=True))
            metric_data[field_name] = value
    #print('method_to_metric_fields', method_to_metric_fields)

    def sort_stats_key(item : List) -> float:
        str_duration = str(item[0])
        if str_duration == '---': return 0.0
        return float(str_duration.replace(' ms', ''))

    field_names = ['Method', 'Started', 'Completed', 'Failed', 'avg(Duration)']
    pt_stats = PrettyTable(field_names=field_names, sortby='avg(Duration)', sort_key=sort_stats_key, reversesort=True)
    for f in ['Method']: pt_stats.align[f] = 'l'
    for f in ['Started', 'Completed', 'Failed', 'avg(Duration)']: pt_stats.align[f] = 'r'

    for method_name,metrics in method_to_metric_fields.items():
        counter_started_value = int(metrics['STARTED']['_total'])
        if counter_started_value == 0:
            #pt_stats.add_row([method_name, '---', '---', '---', '---'])
            continue
        counter_completed_value = int(metrics['COMPLETED']['_total'])
        counter_failed_value = int(metrics['FAILED']['_total'])
        duration_count_value = float(metrics['DURATION']['_count'])
        duration_sum_value = float(metrics['DURATION']['_sum'])
        duration_avg_value = duration_sum_value/duration_count_value
        pt_stats.add_row([
            method_name, str(counter_started_value), str(counter_completed_value), str(counter_failed_value),
            '{:.3f} ms'.format(1000.0 * duration_avg_value),
        ])
    print('')
    print('Performance Results:')
    print(pt_stats.get_string())
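
# Illustrative usage sketch: test modules placed next to this conftest.py receive the fixtures above by name.
# The message and RPC names below (Empty, ListContexts, contexts) are assumptions based on the Context
# service proto and may need to be adapted to the actual generated stubs.
#
#     from common.proto.context_pb2 import Empty
#
#     def test_list_contexts_empty(context_client_grpc : ContextClient):  # pylint: disable=redefined-outer-name
#         response = context_client_grpc.ListContexts(Empty())
#         assert len(response.contexts) == 0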