#include <memory>
#include <new>
#include <string>
#include <vector>

#include <openssl/evp.h>
#include <openssl/pem.h>
#include <openssl/err.h>

#include "loggers.hh"

#include "xml_converters.hh"

#include "TTCN3.hh"

#include "sha1.hh"
#include "sha256.hh"
#include "certs_loader.hh"
#include "security_services.hh"

// Certificate database shared by all security_services calls; it stays null
// until security_services::initialize() succeeds.
certs_db* security_services::_certs_db{nullptr};

/**
 * @brief Create the certificate database used by the security services.
 *
 * A database created by a previous call is released first, so calling this
 * function more than once does not leak the earlier instance.
 *
 * @param p_certs_db_path Path handed to the certs_db constructor.
 * @return 0 on success, -1 if the certs_db object could not be allocated.
 */
int security_services::initialize(const std::string& p_certs_db_path) {
  loggers::get_instance().log(">>> security_services::initialize: '%s'", p_certs_db_path.c_str());

  // Release any database from an earlier initialization (deleting nullptr is a no-op).
  delete security_services::_certs_db;
  security_services::_certs_db = nullptr;

  // std::nothrow makes an allocation failure observable as nullptr; a plain
  // `new` would throw std::bad_alloc instead and the check below would be dead code.
  security_services::_certs_db = new (std::nothrow) certs_db(p_certs_db_path);
  if (security_services::_certs_db == nullptr) { // Memory allocation issue
    loggers::get_instance().warning("security_services::initialize: _certs_db pointer is NULL");
    return -1;
  }

  loggers::get_instance().log("<<< security_services::initialize");
  return 0;
}

/**
 * @brief Look up a certificate in the certificate database.
 *
 * @param p_certificate_name Certificate name forwarded to certs_db::get_certificate.
 * @param p_private_key_name Private key name forwarded to certs_db::get_certificate.
 * @param p_certificate      [out] Receives the certificate pointer from the database.
 * @return The result of certs_db::get_certificate, or -1 when initialize() has
 *         not created the database yet.
 */
int security_services::load_certificate(const std::string& p_certificate_name, const std::string& p_private_key_name, const X509** p_certificate) {
  loggers::get_instance().log("security_services::load_certificate: '%s'", p_certificate_name.c_str());
  loggers::get_instance().log("security_services::load_certificate: '%s'", p_private_key_name.c_str());

  // Guard against use before initialize(): _certs_db starts out as nullptr.
  if (_certs_db == nullptr) {
    loggers::get_instance().warning("security_services::load_certificate: _certs_db pointer is NULL");
    return -1;
  }

  return _certs_db->get_certificate(p_certificate_name, p_private_key_name, p_certificate);
}

int security_services::do_sign(const OCTETSTRING& p_encoded_message, const CHARSTRING& p_certificate_name, OCTETSTRING& p_signature, OCTETSTRING& p_digest, CHARSTRING& p_x509_certificate_subject, CHARSTRING& p_x509_certificate_pem) {
  loggers::get_instance().log_msg(">>> security_services::do_sign: ", p_encoded_message);
  loggers::get_instance().log_msg(">>> security_services::do_sign: ", p_certificate_name);

  // Canonicalization
  std::string canonicalized;
  xml_converters::get_instance().xml_canonicalization(std::string((const char*)(static_cast<const unsigned char*>(p_encoded_message)), p_encoded_message.lengthof()), canonicalized);
  OCTETSTRING encoded_message(char2oct(CHARSTRING(canonicalized.c_str())));
  loggers::get_instance().log("security_services::do_sign: canonicalized: '%s", canonicalized.c_str());

  // Retrieve certificate
  std::string certificate_id;
  if (certs_loader::get_instance().get_certificate_id(std::string(static_cast<const char*>(p_certificate_name)), certificate_id) != 0) {
    loggers::get_instance().warning("fx__sign: Failed to retrieve certificate identifier");
    return -1;  
  }
  loggers::get_instance().log("fx__sign: certificate_id: '%s'", certificate_id.c_str());
  const X509* certificate;
  if (_certs_db->get_certificate(certificate_id, &certificate) != 0) {
    loggers::get_instance().warning("fx__sign: Failed to retrieve certificate");
    return -1;  
  }
  X509_NAME* sn = X509_get_subject_name((X509*)certificate);
  std::string subject(512, (char)0x00);
  X509_NAME_oneline(sn, (char*)subject.c_str(), subject.length());
  loggers::get_instance().log("fx__sign: certificate_id: X509_NAME_oneline: '%s'", subject.c_str());
  p_x509_certificate_subject = CHARSTRING(subject.c_str());
  X509_free((X509*)certificate);

  std::string str;
  if (_certs_db->get_certificate_pem(certificate_id, str) != 0) {
    loggers::get_instance().warning("fx__sign: Failed to retrieve certificate PEM");
    return -1;  
  }
  p_x509_certificate_pem = CHARSTRING(str.c_str());

  // Compute the digest
  sha256 digest;
  digest.generate(char2oct(CHARSTRING(canonicalized.c_str())), p_digest);
  loggers::get_instance().log_msg("fx__sign: certificate_id: p_digest: ", p_digest);

  // Retrive the private key
  const EVP_PKEY* private_key;
  int ret = _certs_db->get_private_key(certificate_id, &private_key);
  if (ret == 1) {
    loggers::get_instance().warning("fx__sign: Failed to retrieve private key");
    return -1;  
  }
  loggers::get_instance().log("fx__sign: certificate_id: private_key: '%p'", private_key);

  // Create signing context
  EVP_MD_CTX* ctx = ::EVP_MD_CTX_new();
  if (ctx == NULL) {
    loggers::get_instance().warning("fx__sign: EVP_MD_CTX_create failed, error 0x%lx", ::ERR_get_error());
    return -1;
  }
  if (::EVP_DigestSignInit(ctx, NULL, EVP_sha256(), NULL, (EVP_PKEY*)private_key) != 1) { // FIXME Add parameter to chose the digest algorithm
    loggers::get_instance().warning("fx__sign: EVP_DigestSignInit failed, error 0x%lx", ::ERR_get_error());
    ::EVP_MD_CTX_free(ctx);
    return -1;
  }
  if (::EVP_DigestSignUpdate(ctx, (const unsigned char*)canonicalized.c_str(), canonicalized.length()) != 1) {
    loggers::get_instance().warning("fx__sign: EVP_DigestSignUpdate failed, error 0x%lx", ::ERR_get_error());
    ::EVP_MD_CTX_free(ctx);
    return -1;
  }
  // Get signature lengthof
  size_t signature_length = 0;
  if (::EVP_DigestSignFinal(ctx, NULL, &signature_length) != 1) {
    loggers::get_instance().warning("fx__sign: EVP_DigestSignUpdate failed, error 0x%lx", ::ERR_get_error());
    ::EVP_MD_CTX_free(ctx);
    return -1;
  }
  loggers::get_instance().log("fx__sign: signature_length: %d", signature_length);
  std::vector<unsigned char> s(signature_length, 0x00);
  unsigned char* val = s.data();
  if (::EVP_DigestSignFinal(ctx, val, &signature_length) != 1) {
    loggers::get_instance().warning("fx__sign: EVP_DigestSignUpdate failed, error 0x%lx", ::ERR_get_error());
    ::EVP_MD_CTX_free(ctx);
    s.clear();
    return -1;
  }

  p_signature = OCTETSTRING(signature_length, s.data());

  ::EVP_MD_CTX_free(ctx);
  loggers::get_instance().log_msg("fx__sign: signature: ", p_signature);

  return 0;
}