from openmtc_app.onem2m import XAE
from openmtc_onem2m.model import Container
from tensorflow.keras.preprocessing.text import tokenizer_from_json
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.models import load_model
from json import load as json_load


class NLP_AE(XAE):
    """oneM2M application entity that classifies cleaned texts with a Keras model.

    On registration it loads a Keras model and a tokenizer from disk, creates a
    ``processed_texts`` container and subscribes to it. Each value delivered
    through that subscription is tokenized, padded and scored by the model, and
    the score is logged as a percentage.
    """

    remove_registration = True
    # CSE path this AE talks to (presumably the one it registers with —
    # confirm against XAE).
    remote_cse = '/mn-cse-1/onem2m'

    # Used as ``maxlen`` for pad_sequences() below, i.e. the padded/truncated
    # token-sequence length fed to the model — not an embedding dimension.
    EMB_SIZE = 50

    def _on_register(self):
        """Load model and tokenizer, create the input container and subscribe to it.

        Ends by calling ``run_forever``, which presumably does not return
        (TODO confirm with XAE), so everything that must happen at
        registration time is done before that call.
        """
        # Load models. Both paths are relative to the current working directory.
        self.model = load_model('models/lstm_model')
        with open('tokenizers/tokenizer.json', encoding='utf-8') as f:
            data = json_load(f)
        if not isinstance(data, str):
            # tokenizer_from_json() expects a JSON *string*. A file written as
            # json.dump(tokenizer.to_json(), f) decodes to exactly that string.
            # A file holding the config as a plain JSON object decodes to a
            # dict instead, so re-encode it rather than crash.
            from json import dumps as json_dumps
            data = json_dumps(data)
        self.tokenizer_obj = tokenizer_from_json(data)

        # init base structure
        label = 'processed_texts'
        container = Container(resourceName=label)
        # max_nr_of_instances is passed through unchanged; its exact meaning
        # is defined by XAE.create_container.
        self._clean_text_container = self.create_container(
            None,
            container,
            labels=[label, 'cleaned', 'cleaned_tweet'],
            max_nr_of_instances=0
        )
        print('Container created', self._clean_text_container.path)

        print('Subscribing to container')
        self.add_container_subscription(self._clean_text_container, self.handle_processed_text)

        # log message — emitted before run_forever(), because a statement
        # placed after a call that never returns would never execute.
        self.logger.debug('registered')

        # trigger periodically, default=1000
        self.run_forever(30)

    def handle_processed_text(self, container, value):
        """Classify one received text and log the model's score.

        :param container: the container the data arrived on (not used here).
        :param value: received content. It is read as a mapping whose
            ``'clean_text'`` entry is the text to classify. The optional
            ``'text'`` entry is only used in the log message.
        """
        text = value.get('clean_text')
        if text is None:
            # Without the cleaned text there is nothing to classify. Report it
            # instead of letting a KeyError escape the subscription callback.
            self.logger.warning('Ignoring data without "clean_text": %s', value)
            return

        self.logger.info('Received data: %s. Classifying...', value.get('text', text))
        sequences = self.tokenizer_obj.texts_to_sequences([text])
        # Pad (or truncate) at the end so the sequence is exactly EMB_SIZE tokens.
        tweet_pad = pad_sequences(sequences, maxlen=self.EMB_SIZE, truncating='post', padding='post')
        y_pre = self.model.predict(tweet_pad)

        # y_pre[0][0] is the first output of the single input row. It is
        # presumably a probability in [0, 1] — TODO confirm against the
        # model's output layer.
        self.logger.info('Classification: %d%% occurrence!', int(y_pre[0][0] * 100))
