#include <memory>
#include <new>
#include <string>
#include <vector>

#include <openssl/evp.h>
#include <openssl/pem.h>
#include <openssl/err.h>
#include <openssl/x509.h>

#include "converter.hh"

#include "loggers.hh"

#include "xml_converters.hh"

#include "TTCN3.hh"

#include "sha1.hh"
#include "sha256.hh"
#include "certs_loader.hh"
#include "security_services.hh"

// Certificates database shared by all security_services calls; null until initialize() creates it.
certs_db* security_services::_certs_db{nullptr};

/**
 * @brief Create the certificates database used by the other security_services functions.
 *
 * @param p_certs_db_path Path passed to the certs_db constructor.
 * @return 0 on success, -1 if the certs_db object could not be allocated.
 */
int security_services::initialize(const std::string& p_certs_db_path) {
  loggers::get_instance().log(">>> security_services::initialize: '%s'", p_certs_db_path.c_str());

  // A plain 'new' throws std::bad_alloc rather than returning nullptr, which made the
  // null check below unreachable; the nothrow form lets the check do its job.
  // NOTE(review): an instance created by a previous call is not released here; certificates
  // or keys handed out from it may still be in use, so confirm ownership before deleting it.
  security_services::_certs_db = new (std::nothrow) certs_db(p_certs_db_path);
  if (security_services::_certs_db == nullptr) { // Memory allocation issue
    loggers::get_instance().warning("security_services::initialize: _certs_db pointer is NULL");
    return -1;
  }

  loggers::get_instance().log("<<< security_services::initialize");
  return 0;
}

/**
 * @brief Retrieve a certificate from the certificates database.
 *
 * @param p_certificate_name   Certificate name, forwarded to certs_db::get_certificate.
 * @param p_private_key_name   Private key name, forwarded to certs_db::get_certificate.
 * @param p_private_key_passwd Private key password, forwarded to certs_db::get_certificate (never logged).
 * @param p_certificate        [out] Receives the certificate pointer from certs_db::get_certificate.
 * @return The result of certs_db::get_certificate, or -1 if initialize() has not created the database.
 */
int security_services::load_certificate(const std::string& p_certificate_name, const std::string& p_private_key_name, const std::string& p_private_key_passwd, const X509** p_certificate) {
  loggers::get_instance().log("security_services::load_certificate: '%s'", p_certificate_name.c_str());
  loggers::get_instance().log("security_services::load_certificate: '%s'", p_private_key_name.c_str());
  // The private key password is a secret and is deliberately not written to the logs.

  if (_certs_db == nullptr) {
    loggers::get_instance().warning("security_services::load_certificate: Certificates database not initialized");
    return -1;
  }
  return _certs_db->get_certificate(p_certificate_name, p_private_key_name, p_private_key_passwd, p_certificate);
}

int security_services::do_sign(const OCTETSTRING& p_encoded_message, const CHARSTRING& p_certificate_name, const CHARSTRING& p_private_key_name, const CHARSTRING& p_private_key_passwd, OCTETSTRING& p_signature, OCTETSTRING& p_digest, CHARSTRING& p_x509_certificate_subject, CHARSTRING& p_x509_certificate_pem, CHARSTRING& p_pull_request_signed_canonicalized) {
  loggers::get_instance().log_msg(">>> security_services::do_sign: ", p_encoded_message);
  loggers::get_instance().log_msg(">>> security_services::do_sign: ", p_certificate_name);
  loggers::get_instance().log_msg(">>> security_services::do_sign: ", p_private_key_name);
  loggers::get_instance().log_msg(">>> security_services::do_sign: ", p_private_key_passwd);

  // Canonicalization
  std::string canonicalized;
  xml_converters::get_instance().xml_canonicalization(std::string((const char*)(static_cast<const unsigned char*>(p_encoded_message)), p_encoded_message.lengthof()), canonicalized);
  p_pull_request_signed_canonicalized = CHARSTRING(canonicalized.c_str());
  OCTETSTRING encoded_message(char2oct(p_pull_request_signed_canonicalized));
  loggers::get_instance().log_msg("security_services::do_sign: p_pull_request_signed_canonicalized: ", p_pull_request_signed_canonicalized);

  // Retrieve certificate
  std::string certificate_id;
  if (certs_loader::get_instance().get_certificate_id(std::string(static_cast<const char*>(p_certificate_name)), certificate_id) != 0) {
    loggers::get_instance().warning("security_services::do_sign: Failed to retrieve certificate identifier");
    return -1;  
  }
  loggers::get_instance().log("security_services::do_sign: certificate_id: '%s'", certificate_id.c_str());
  const X509* certificate;
  if (_certs_db->get_certificate(certificate_id, &certificate) != 0) {
    loggers::get_instance().warning("security_services::do_sign: Failed to retrieve certificate");
    return -1;  
  }
  loggers::get_instance().log("security_services::do_sign: certificate dump: '%s'", _certs_db->cert_to_string(certificate_id).c_str());
  X509_NAME* sn = X509_get_subject_name((X509*)certificate);
  std::string subject(512, (char)0x00);
  ::X509_NAME_oneline(sn, (char*)subject.c_str(), subject.length());
  loggers::get_instance().log("security_services::do_sign: X509_NAME_oneline: '%s'", subject.c_str());
  p_x509_certificate_subject = CHARSTRING(subject.c_str());

  std::string str;
  if (_certs_db->get_certificate_pem(certificate_id, str) != 0) {
    loggers::get_instance().warning("security_services::do_sign: Failed to retrieve certificate PEM");
    return -1;  
  }
  p_x509_certificate_pem = CHARSTRING(str.c_str());

  // Compute the digest
  sha256 digest;
  digest.generate(char2oct(CHARSTRING(canonicalized.c_str())), p_digest);
  loggers::get_instance().log_msg("security_services::do_sign: digest: ", p_digest);

  // Retrive the private key
  const EVP_PKEY* private_key;
  int ret = _certs_db->get_private_key(certificate_id, &private_key);
  if (ret == 1) {
    loggers::get_instance().warning("security_services::do_sign: Failed to retrieve private key");
    return -1;  
  }
  loggers::get_instance().log("security_services::do_sign: private_key: '%p'", private_key);

  // Create signing context
  EVP_MD_CTX* ctx = ::EVP_MD_CTX_new();
  if (ctx == NULL) {
    loggers::get_instance().warning("security_services::do_sign: EVP_MD_CTX_create failed, error 0x%lx", ::ERR_get_error());
    return -1;
  }
  if (::EVP_DigestSignInit(ctx, NULL, EVP_sha256(), NULL, (EVP_PKEY*)private_key) != 1) { // FIXME Add parameter to chose the digest algorithm
    loggers::get_instance().warning("security_services::do_sign: EVP_DigestSignInit failed, error 0x%lx", ::ERR_get_error());
    ::EVP_MD_CTX_free(ctx);
    return -1;
  }
  if (::EVP_DigestSignUpdate(ctx, (const unsigned char*)canonicalized.c_str(), canonicalized.length()) != 1) {
    loggers::get_instance().warning("security_services::do_sign: EVP_DigestSignUpdate failed, error 0x%lx", ::ERR_get_error());
    ::EVP_MD_CTX_free(ctx);
    return -1;
  }
  // Get signature lengthof
  size_t signature_length = 0;
  if (::EVP_DigestSignFinal(ctx, NULL, &signature_length) != 1) {
    loggers::get_instance().warning("security_services::do_sign: EVP_DigestSignUpdate failed, error 0x%lx", ::ERR_get_error());
    ::EVP_MD_CTX_free(ctx);
    return -1;
  }
  if (signature_length != ::EVP_PKEY_size(private_key)) {
    loggers::get_instance().warning("security_services::do_sign: Wrong signature length");
    ::EVP_MD_CTX_free(ctx);
    return -1; 
  }
  loggers::get_instance().log("security_services::do_sign: signature_length: %d", signature_length);
  std::vector<unsigned char> s(signature_length, 0x00);
  if (::EVP_DigestSignFinal(ctx, s.data(), &signature_length) != 1) {
    loggers::get_instance().warning("security_services::do_sign: EVP_DigestSignUpdate failed, error 0x%lx", ::ERR_get_error());
    ::EVP_MD_CTX_free(ctx);
    s.clear();
    return -1;
  }
  ::EVP_MD_CTX_free(ctx);

  // Ensure that the signature round-trips
  // Retrive the public key
  const EVP_PKEY* public_key;
  if (_certs_db->get_private_key(certificate_id, &public_key) == 1) {
    loggers::get_instance().warning("security_services::do_sign: Failed to retrieve private key");
    return -1;  
  }
  loggers::get_instance().log("security_services::do_sign: public_key: '%p'", public_key);
  std::vector<unsigned char> buffer;
  _certs_db->publickey_to_string(public_key, buffer);
  ctx = ::EVP_MD_CTX_new();
  if (!::EVP_DigestVerifyInit(ctx, NULL, EVP_sha256(), NULL, (EVP_PKEY*)public_key) ||
      !::EVP_DigestVerifyUpdate(ctx, (const unsigned char*)canonicalized.c_str(), canonicalized.length()) ||
      !::EVP_DigestVerifyFinal(ctx, s.data(), s.size())) {
    loggers::get_instance().warning("security_services::do_sign: Failed to verify signature, error 0x%lx", ::ERR_get_error());
    ::EVP_MD_CTX_free(ctx);
    s.clear();
    return -1;
  }
  ::EVP_MD_CTX_free(ctx);
  loggers::get_instance().log("security_services::do_sign: signature was verified with public key");

  p_signature = OCTETSTRING(signature_length, s.data());

  loggers::get_instance().log_msg("<<< security_services::do_sign: signature: ", p_signature);

  return 0;
}

/**
 * @brief Verify a Base64-encoded signature over p_message with the public key of the certificate
 *        whose subject name is p_subject_name.
 *
 * If no certificate with that subject name is in the database, p_certificate is stored under a
 * name derived from the subject's CN value ('-' replaced by '_') and its public key is used.
 *
 * @param p_message                 Data covered by the signature, fed as-is to EVP_DigestVerifyUpdate.
 * @param p_canonicalization_method Only logged.
 * @param p_signature_method        Algorithm URI; the digest is picked by the substring "sha1", "sha256" or "sha384".
 * @param p_digest_method           Not used.
 * @param p_digest_value            Not used (the SHA-256 of p_message is only logged, not compared).
 * @param p_signature_value         Base64-encoded signature value.
 * @param p_subject_name            Subject name of the signing certificate.
 * @param p_certificate             Certificate passed to certs_db::store_certificate when the subject is unknown.
 * @param p_debug_message           Only logged.
 * @return true if the signature verifies, false otherwise.
 */
bool security_services::do_sign_verify(const CHARSTRING& p_message, const UNIVERSAL_CHARSTRING& p_canonicalization_method, const UNIVERSAL_CHARSTRING& p_signature_method, const UNIVERSAL_CHARSTRING& p_digest_method, const UNIVERSAL_CHARSTRING& p_digest_value, const UNIVERSAL_CHARSTRING& p_signature_value, const UNIVERSAL_CHARSTRING& p_subject_name, const UNIVERSAL_CHARSTRING& p_certificate, const CHARSTRING& p_debug_message) {
  loggers::get_instance().log_msg(">>> security_services::do_sign_verify: p_message:       ", p_message);
  loggers::get_instance().log_msg(">>> security_services::do_sign_verify: p_debug_message: ", p_debug_message);
  loggers::get_instance().log_msg(">>> security_services::do_sign_verify: p_canonicalization_method: ", p_canonicalization_method);
  loggers::get_instance().log_msg(">>> security_services::do_sign_verify: p_subject_name: ", p_subject_name);
  loggers::get_instance().log_msg(">>> security_services::do_sign_verify: p_signature_method: ", p_signature_method);

  if (_certs_db == nullptr) {
    loggers::get_instance().warning("security_services::do_sign_verify: Certificates database not initialized");
    return false;
  }

  // Compute the digest (debug output only)
  sha256 digest;
  OCTETSTRING dg;
  digest.generate(char2oct(p_message), dg);
  loggers::get_instance().log_msg("security_services::do_sign_verify: digest: ", dg);

  // Retrieve the public key of the signing certificate
  const EVP_PKEY* public_key = nullptr;
  std::string sn(static_cast<const char*>(unichar2char(p_subject_name)));
  const X509* certificate = nullptr;
  if (_certs_db->get_certificate_by_subject_name(sn, &certificate) == 0 && certificate != nullptr) {
    // Known certificate: no certificate identifier is available on this path, so take the key
    // directly from the X509 (internal reference owned by the certificate; not freed here).
    public_key = ::X509_get0_pubkey(const_cast<X509*>(certificate));
  } else {
    loggers::get_instance().warning("security_services::do_sign_verify: Failed to retrieve certificate");
    // Use provided certificate if any
    loggers::get_instance().log("security_services::do_sign_verify: sn='%s'", sn.c_str());
    const size_t cn_pos = sn.find("CN=");
    if (cn_pos == std::string::npos) {
      loggers::get_instance().warning("security_services::do_sign_verify: Failed to extract certificate name");
      return false;
    }
    // The CN value runs up to the next ',' or to the end of the subject name
    const size_t cn_end = sn.find(',', cn_pos + 3);
    std::string certificate_name = (cn_end == std::string::npos) ? sn.substr(cn_pos + 3) : sn.substr(cn_pos + 3, cn_end - cn_pos - 3);
    certificate_name = converter::get_instance().replace(certificate_name, "-", "_");
    loggers::get_instance().log("security_services::do_sign_verify: certificate name: '%s'", certificate_name.c_str());
    std::string certificate_id;
    const certs_db_record* record = nullptr;
    if (_certs_db->store_certificate(certificate_name, std::string(static_cast<const char*>(unichar2char(p_certificate))), certificate_id, &record) == -1) {
      loggers::get_instance().warning("security_services::do_sign_verify: Failed to store certificate");
      return false;
    }
    loggers::get_instance().log("security_services::do_sign_verify: certificate id: '%s'", record->certificate_id().c_str());
    loggers::get_instance().log("security_services::do_sign_verify: certificate dump: '%s'", _certs_db->cert_to_string(certificate_id).c_str());
    if (_certs_db->get_public_keys(certificate_id, &public_key) != 0) {
      public_key = nullptr;
    }
  }
  if (public_key == nullptr) {
    loggers::get_instance().warning("security_services::do_sign_verify: Failed to retrieve public key");
    return false;
  }
  loggers::get_instance().log("security_services::do_sign_verify: public_key: '%p'", public_key);

  // Select the digest from the algorithm URI, e.g. http://www.w3.org/2000/09/xmldsig#rsa-sha1
  std::string signature_method(static_cast<const char*>(unichar2char(p_signature_method)));
  const EVP_MD* md = nullptr;
  if (signature_method.find("sha1") != std::string::npos) {
    loggers::get_instance().log("security_services::do_sign_verify: signature method: sha1");
    md = ::EVP_sha1();
  } else if (signature_method.find("sha256") != std::string::npos) {
    loggers::get_instance().log("security_services::do_sign_verify: signature method: sha256");
    md = ::EVP_sha256();
  } else if (signature_method.find("sha384") != std::string::npos) {
    loggers::get_instance().log("security_services::do_sign_verify: signature method: sha384");
    md = ::EVP_sha384();
  } else {
    loggers::get_instance().error("security_services::do_sign_verify: Unsupported signature method");
    return false;
  }

  // Create verification context; the unique_ptr frees it on every return path
  std::unique_ptr<EVP_MD_CTX, decltype(&::EVP_MD_CTX_free)> ctx(::EVP_MD_CTX_new(), &::EVP_MD_CTX_free);
  if (!ctx) {
    loggers::get_instance().warning("security_services::do_sign_verify: EVP_MD_CTX_new failed, error 0x%lx", ::ERR_get_error());
    return false;
  }
  // Check signature
  if (::EVP_DigestVerifyInit(ctx.get(), nullptr, md, nullptr, const_cast<EVP_PKEY*>(public_key)) != 1) {
    loggers::get_instance().warning("security_services::do_sign_verify: EVP_DigestVerifyInit failed, error 0x%lx", ::ERR_get_error());
    return false;
  }
  if (::EVP_DigestVerifyUpdate(ctx.get(), static_cast<const char*>(p_message), p_message.lengthof()) != 1) {
    // Pop the error once so the code and its text describe the same error
    const unsigned long err = ::ERR_get_error();
    loggers::get_instance().warning("security_services::do_sign_verify: EVP_DigestVerifyUpdate failed, error 0x%lx - %s", err, ::ERR_error_string(err, nullptr));
    return false;
  }
  std::string signature_value(static_cast<const char*>(unichar2char(p_signature_value)));
  // Remove CR/LF if any
  //signature_value.erase(std::remove(signature_value.begin(), signature_value.end(), '\r'), signature_value.end());
  //signature_value.erase(std::remove(signature_value.begin(), signature_value.end(), '\n'), signature_value.end());
  loggers::get_instance().log("security_services::do_sign_verify: Before B64 decoded: '%s'", signature_value.c_str());
  // Convert into bytes buffer
  const std::vector<unsigned char> to_decode(signature_value.begin(), signature_value.end());
  const std::vector<unsigned char> b64 = converter::get_instance().base64_to_buffer(to_decode);
  loggers::get_instance().log("security_services::do_sign_verify: B64 decoded: '%s'", converter::get_instance().bytes_to_hexa(b64, true).c_str());
  loggers::get_instance().log("security_services::do_sign_verify: B64 decoded len: '%d'", static_cast<int>(b64.size()));
  if (::EVP_DigestVerifyFinal(ctx.get(), b64.data(), b64.size()) != 1) {
    const unsigned long err = ::ERR_get_error();
    loggers::get_instance().warning("security_services::do_sign_verify: EVP_DigestVerifyFinal failed, error 0x%lx - %s", err, ::ERR_error_string(err, nullptr));
    return false;
  }

  return true;
}
