import os
from os import path
import base64
import copy
import json
import logging
import subprocess
import tempfile
import yaml
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from enum import Enum
import jinja2 as jin
from typing import Any, Dict, List, Optional, Tuple
from central_sdi import storage_api, util

logger = logging.getLogger(__name__)

# Certificate configuration:
# - client_certs_key: storage API key prefix under which client certs are stored
# - client_cert_cfg_file: storage API object holding the global cert config (YAML)
# - client_cert_cfg_defaults: values cert_cfg_default() fills in when absent
client_certs_key = "client_certs"
client_cert_cfg_file = 'certrc'
client_cert_cfg_defaults = dict(
    client_cert_days=365,
    client_cert_size=2048,
    client_cert_default_cn='nom-sdi.opengear.com',
)


def get_env(key):
    """
    Look up a required environment variable.

    Returns the variable's value. Raises ValueError when ``key`` is unset or
    is set to an empty string.
    """
    found = os.getenv(key, "")
    if found:
        return found
    raise ValueError(f"required env variable: {key}")


def ca_cert_days():
    """
    Read the CA_DAYS environment variable as a positive integer.

    Raises ValueError if CA_DAYS is unset/empty, is not an integer, or is
    not greater than zero.
    """
    days = int(get_env("CA_DAYS"))
    if days > 0:
        return days
    raise ValueError(f"invalid CA_DAYS: {days}")


def ca_cert_dir():
    """
    Read the CA_DIR environment variable and check that it names an existing
    directory.

    Raises ValueError if CA_DIR is unset/empty or is not a directory.
    """
    directory = get_env("CA_DIR")
    if path.isdir(directory):
        return directory
    raise ValueError(f"invalid CA_DIR: {directory}")


def cert_cfg_default(cfg):
    """
    Apply defaults to global certificate config.

    Returns a new, deep-copied dict made of ``client_cert_cfg_defaults``
    overlaid with the entries of ``cfg``; neither input is mutated.

    ``cfg`` may be None (which is what ``yaml.safe_load`` yields for an empty
    document); it is treated as an empty config instead of raising TypeError.
    """
    overrides = cfg if cfg is not None else {}
    return copy.deepcopy({**client_cert_cfg_defaults, **overrides})


def cert_cfg_load():
    """
    Fetch global certificate config.

    Returns the YAML document stored under ``client_cert_cfg_file`` parsed with
    ``yaml.safe_load``. Returns an empty dict when the object does not exist or
    the document is empty, so the result can always be passed to
    ``cert_cfg_default``.
    """
    try:
        cfg = yaml.safe_load(storage_api.sa_load(client_cert_cfg_file))
    except storage_api.ObjectNotFound:
        return {}
    # safe_load returns None for an empty document; normalise it to {}.
    return {} if cfg is None else cfg


def cert_cfg_store(cfg):
    """
    Update global certificate config.

    Serialises a plain-dict copy of ``cfg`` as YAML and writes it to the
    storage API object ``client_cert_cfg_file``.
    """
    as_yaml = yaml.dump({**cfg})
    storage_api.sa_store(client_cert_cfg_file, as_yaml)


def cert_load(mgr, id) -> Optional["Cert"]:
    """
    Fetch a client certificate from the storage API.

    :param mgr: CertManager the resulting Cert is bound to.
    :param id: storage ID of the cert; the entry is read from
        ``<client_certs_key>/<id>``.
    :returns: the loaded Cert, or None if the stored JSON is malformed
        (not an object, or missing ``metadata.name``, ``key`` or ``cert``).

    Errors raised by ``storage_api.sa_load`` (e.g. ObjectNotFound) propagate.
    """
    data = json.load(
        storage_api.sa_load(f"{client_certs_key}/{id}"))

    valid = (
        isinstance(data, dict)
        and 'metadata' in data
        and 'key' in data
        and 'cert' in data
        and isinstance(data['metadata'], dict)
        and 'name' in data['metadata']
    )
    if not valid:
        # Never log the payload itself: it contains the private key.
        present = sorted(data) if isinstance(data, dict) else type(data).__name__
        logger.error("Invalid cert data for %s (fields present: %s)", id, present)
        return None

    return Cert(mgr,
                id,
                data['metadata']['name'],
                data['key'].encode('ascii'),
                data['cert'].encode('ascii'))


def cert_store(cert: "Cert"):
    """
    Persist a certificate to the storage API.

    Metadata, private key and public certificate are written together as a
    single JSON document under ``<client_certs_key>/<cert.id>``, so one API
    call stores everything (and one call loads it back in cert_load).
    """
    payload = {
        "metadata": {"name": cert.name},
        "key": cert.key,
        "cert": cert.public_string,
    }
    serialized = json.dumps(payload)
    storage_api.sa_store(f"{client_certs_key}/{cert.id}", serialized)


def ca_crt_load(mgr: "CertManager") -> "Cert":
    """
    Load the Certificate Authority's key and certificate from the storage API.

    Reads ``ca/ca.key`` and ``ca/ca.crt`` and wraps them in a Cert whose id
    and name are both ``"ca"``.
    """
    ca_key = storage_api.sa_load("ca/ca.key").read()
    ca_cert = storage_api.sa_load("ca/ca.crt").read()
    return Cert(mgr, id="ca", name="ca", key_data=ca_key, cert_data=ca_cert)


def ca_crl_load() -> "CertRevocationList":
    """
    Load the CA's Certificate Revocation List from the storage API.

    The stored bytes are decoded and re-encoded as ASCII, so non-ASCII
    content raises UnicodeDecodeError before any PEM parsing happens.
    """
    raw = storage_api.sa_load("ca/ca.crl").read()
    ascii_text = raw.decode("ascii")
    return CertRevocationList(ascii_text.encode("ascii"))


def sync_local_ca():
    """
    Sync the local Certificate Authority to the storage API by running the
    ``server_save.sh`` script.

    Returns True on success. Returns False, after logging the script output
    or the OS error, if the script exits non-zero or cannot be started at all.
    """
    try:
        process = subprocess.run(
            ["server_save.sh"],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            check=False,
        )
    except OSError as err:
        # e.g. script not on PATH or not executable: honour the bool contract
        # instead of letting the exception escape to callers that check the result.
        logger.error("Failed to sync Certificate Authority to storage API: '%s'", err)
        return False

    if process.returncode != 0:
        logger.error("Failed to sync Certificate Authority to storage API: '%s'",
                     process.stdout.decode(errors="replace"))
        return False
    return True


def _split_rfc4514(rfc4514):
    """Split an RFC 4514 string on unescaped commas, keeping escape sequences intact."""
    tokens = []
    current = []
    escaped = False
    for char in rfc4514:
        if escaped:
            # Character following a backslash is part of the value, even if it is ','.
            escaped = False
        elif char == '\\':
            escaped = True
        elif char == ',':
            tokens.append(''.join(current))
            current = []
            continue
        current.append(char)
    tokens.append(''.join(current))
    return tokens


def order_rfc4514_fields(rfc4514):
    """
    Reorder the attributes of an RFC 4514 distinguished-name string.

    The crypto library doesn't seem to guarantee the order of fields, so they
    are ordered here to match OpenSSL's format, e.g.
    ``C=AU,ST=Queensland,O=Opengear,OU=NetOps,CN=sdiCA``.

    Only the first occurrence of each of C, ST, O, OU and CN is kept; other
    attribute types are dropped. Commas escaped with a backslash inside a
    value are not treated as field separators.
    """
    field_order = ('C=', 'ST=', 'O=', 'OU=', 'CN=')
    tokens = _split_rfc4514(rfc4514)
    ordered = []
    for field in field_order:
        match = next((token for token in tokens if token.startswith(field)), None)
        if match is not None:
            ordered.append(match)
    return ','.join(ordered)


class InvalidCertConfig(Exception):
    """Exception type for an invalid certificate configuration."""


class Cert:
    """
    This class represents an individual certificate, with metadata such as cert name, and
    certificate details read from both the public certificate and private key using the
    cryptography library.

    The PEM certificate is parsed once in ``__init__``; the private key is kept as the raw
    bytes supplied and only decoded to text when ``key`` is accessed.
    """

    def __init__(self, mgr: "CertManager", id: str, name: str, key_data: bytes, cert_data: bytes):
        """
        :param mgr: owning CertManager (needed for the ``revoked`` property).
        :param id: storage ID of the certificate.
        :param name: display name of the certificate.
        :param key_data: raw private key contents (ASCII text).
        :param cert_data: PEM-encoded public certificate.
        """
        # Don't like needing a reference to the manager, but need it for now
        # to provide the revoked property
        self._mgr = mgr
        self._name = name
        self._id = id

        # TODO consider lazy loading cert details for performance
        self._cert = x509.load_pem_x509_certificate(cert_data, default_backend())

        # Don't think we need to actually load the key, we just need the raw contents
        self._key = key_data

    def __repr__(self) -> str:
        # Deliberately excludes key material so Cert objects are safe to log.
        return f"{type(self).__name__}(id={self._id!r}, name={self._name!r})"

    @property
    def mgr(self) -> "CertManager":
        return self._mgr

    @property
    def id(self) -> str:
        return self._id

    @property
    def name(self) -> str:
        return self._name

    @property
    def serial(self) -> int:
        # serial_number is an integer; compare via str() when matching user input.
        return self.cert.serial_number

    @property
    def public_bytes(self) -> bytes:
        """The public certificate re-serialised as PEM bytes."""
        return self.cert.public_bytes(encoding=serialization.Encoding.PEM)

    @property
    def public_string(self) -> str:
        """The PEM public certificate as text, without surrounding whitespace."""
        return self.public_bytes.decode("ascii").strip()

    @property
    def fingerprint(self) -> str:
        """
        Return certificate SHA1 fingerprint as user-readable, colon-separated string
        e.g. "49:e1:b6:f5:1f:25:3c:df:a6:b7:a1:91:01:86:77:01:9e:5b:f6:b9"
        """
        digest = self.cert.fingerprint(hashes.SHA1())
        return ':'.join(f"{byte:02x}" for byte in digest)

    @property
    def issuer(self) -> str:
        return order_rfc4514_fields(self.cert.issuer.rfc4514_string())

    @property
    def subject(self) -> str:
        return order_rfc4514_fields(self.cert.subject.rfc4514_string())

    @property
    def not_valid_before(self) -> str:
        return self.cert.not_valid_before.strftime("%b %d %H:%M:%S %Y")  # TODO timezone

    @property
    def not_valid_after(self) -> str:
        return self.cert.not_valid_after.strftime("%b %d %H:%M:%S %Y")

    @property
    def cert(self) -> "x509.Certificate":
        """The parsed cryptography certificate object."""
        return self._cert

    @property
    def revoked(self) -> bool:
        """Whether the owning manager's current CRL lists this certificate."""
        return self.mgr.crl.is_certificate_revoked(self)

    @property
    def key(self) -> str:
        """The private key contents as ASCII text, without surrounding whitespace."""
        return self._key.decode("ascii").strip()

    def to_dict(self) -> Dict[str, Any]:
        """Summary of the certificate for display/serialisation (excludes the key)."""
        return {
            "name": self.name,
            "serial": self.serial,
            "issuer": self.issuer,
            "subject": self.subject,
            "fingerprint": self.fingerprint,
            "not_valid_before": self.not_valid_before,
            "not_valid_after": self.not_valid_after,
            "revoked": self.revoked,
        }


class CertRevocationList:
    """
    Wrapper around a PEM-encoded X.509 Certificate Revocation List.

    Keeps the raw PEM bytes alongside the object parsed by the cryptography
    library, and answers revocation queries for Cert instances.
    """

    def __init__(self, crl_data: bytes):
        self._crl_data = crl_data
        # Parsed eagerly, so unparsable data fails at construction time.
        self._crl = x509.load_pem_x509_crl(self._crl_data, default_backend())

    @property
    def crl_data(self) -> bytes:
        return self._crl_data

    @property
    def crl(self):
        return self._crl

    @property
    def issuer(self):
        # Same attribute ordering as Cert.issuer / Cert.subject.
        return order_rfc4514_fields(self.crl.issuer.rfc4514_string())

    @property
    def fingerprint(self) -> str:
        """
        Return CRL SHA1 fingerprint as user-readable, colon-separated string
        e.g. "49:e1:b6:f5:1f:25:3c:df:a6:b7:a1:91:01:86:77:01:9e:5b:f6:b9"
        """
        digest_hex = self.crl.fingerprint(hashes.SHA1()).hex()
        pairs = [digest_hex[pos:pos + 2] for pos in range(0, len(digest_hex), 2)]
        return ':'.join(pairs)

    def is_certificate_revoked(self, cert: Cert) -> bool:
        """True if the CRL contains an entry for ``cert``'s serial number."""
        entry = self.crl.get_revoked_certificate_by_serial_number(cert.serial)
        return entry is not None


class CertManager:
    """
    This class represents a list of x509 certificates with associated metadata.

    On construction it loads the CA's CRL and every client certificate listed in
    the ``<client_certs_key>/index.json`` storage API object. A missing or empty
    index means there are no client certs. Index entries without a name, entries
    whose stored data is malformed, and entries whose object is missing are
    skipped.

    Note: functions which modify underlying cert data should update the objects
    to keep the CertManager instance consistent, but will not handle the case where
    another instance has modified the data.
    """

    def __init__(self, cert_cfg: Dict[str, Any]):
        self._certs = []
        self._cert_cfg = cert_cfg

        # Load CRL from storage API
        self._crl = ca_crl_load()

        # Load client certs from storage API
        try:
            client_certs_index_json = storage_api.sa_load(
                f"{client_certs_key}/index.json")
            if not client_certs_index_json:
                return
        except storage_api.ObjectNotFound:
            return

        client_certs_index = json.load(client_certs_index_json)
        for item in client_certs_index.get('items', []):
            item_name = item.get('name')
            if not item_name:
                continue
            try:
                cert = cert_load(self, item_name)
            except storage_api.ObjectNotFound:
                # A stale index entry shouldn't stop every other cert from loading.
                logger.error("Client cert '%s' is listed in the index but was not found", item_name)
                continue
            if cert is not None:
                self._certs.append(cert)

    @property
    def cert_cfg(self) -> Dict[str, Any]:
        return self._cert_cfg

    @property
    def crl(self) -> "CertRevocationList":
        return self._crl

    @property
    def certs(self) -> List["Cert"]:
        return self._certs

    def get_cert_by_name(self, name: str) -> Optional["Cert"]:
        """First loaded cert whose name equals ``name``, or None."""
        return next((cert for cert in self.certs if str(cert.name) == name), None)

    def get_cert_by_serial(self, serial: str) -> Optional["Cert"]:
        """First loaded cert whose serial number, as a string, equals ``serial``, or None."""
        return next((cert for cert in self.certs if str(cert.serial) == serial), None)

    def get_cert_by_fingerprint(self, fingerprint: str) -> Optional["Cert"]:
        """First loaded cert whose colon-separated SHA1 fingerprint matches, or None."""
        return next((cert for cert in self.certs if cert.fingerprint == fingerprint), None)

    @staticmethod
    def _escape_subject_value(value: str) -> str:
        """Backslash-escape the characters openssl -subj treats as syntax (backslash, '/', '+')."""
        return ''.join('\\' + ch if ch in '\\/+' else ch for ch in value)

    def create_cert(self, name: str, common_name: str) -> Optional["Cert"]:
        """
        Creates a new key + CSR + certificate signed by the local CA, syncs the CA
        to the storage API, stores the new cert and adds it to ``certs``.

        :param name: display name; its URL-safe base64 encoding is the storage ID.
        :param common_name: CN of the certificate subject; when empty, the
            configured ``client_cert_default_cn`` is used instead.
        :returns: the new Cert, or None if generation or the CA sync failed.
        """
        cn = common_name or self.cert_cfg['client_cert_default_cn']
        # Escape so a CN containing '/' or '+' cannot inject extra subject fields.
        subject = f"/CN={self._escape_subject_value(cn)}"

        # TODO set startdate to a day ago for timezone issues?
        # TODO would like to create the cert with the cryptography library, but it returns an
        # openssl error on validation. For now, just use openssl to generate it.

        # Generate unique client cert ID based on name (used as the key for the storage API)
        client_id = base64.urlsafe_b64encode(
            name.encode('utf-8')).decode("utf-8")

        key_data, cert_data = openssl_create_cert(subject, self.cert_cfg["client_cert_days"])
        if not key_data or not cert_data:
            return None

        if not sync_local_ca():
            return None

        new_cert = Cert(self, client_id, name, key_data, cert_data)
        # Persist first so a storage failure doesn't leave an unstored cert in ``certs``.
        cert_store(new_cert)
        self.certs.append(new_cert)

        return new_cert

    def delete_cert(self, cert: "Cert") -> bool:
        """
        Revoke ``cert`` in the local CA, sync the CA to the storage API and reload
        the CRL so that ``Cert.revoked`` reflects the change.

        The cert is not removed from ``certs`` or from storage; it stays listed and
        is reported as revoked.

        :returns: True on success, False if revocation or the CA sync failed.
        """
        # Add cert to CRL
        if not openssl_revoke_cert(cert):
            return False

        if not sync_local_ca():
            return False

        # Reload CRL
        self._crl = ca_crl_load()

        return True


class ExportFormat(Enum):
    """Target formats accepted by Exporter.export."""

    OPENVPN = "openvpn"


class Exporter:
    """
    Builds VPN client configuration bundles for a certificate.

    Templates are loaded with Jinja2 from ``templates_dir``; ``export`` returns
    the bundle as a list of ``(filename, contents)`` pairs.
    """

    def __init__(self, cert: Cert, ca: Cert, username=None, templates_dir="/templates"):
        self._cert = cert
        self._ca = ca
        self._username = username
        loader = jin.FileSystemLoader(templates_dir)
        self._env = jin.Environment(loader=loader, trim_blocks=True, lstrip_blocks=True)

    @property
    def cert(self) -> Cert:
        return self._cert

    @property
    def ca(self) -> Cert:
        return self._ca

    @property
    def username(self) -> str:
        return self._username

    @property
    def env(self):
        return self._env

    def export(self, format: ExportFormat) -> List[Tuple[str, str]]:
        """
        Exports a certificate in the target VPN client format. Currently only supports OpenVPN.

        Returns a list of files, each represented as a tuple containing the filename and
        contents. Any other format yields an empty list.
        """
        if format != ExportFormat.OPENVPN:
            return []

        # OpenVPN exports as a config file, with an optional auth file
        # containing a pre-populated username
        auth_file = "auth.txt"
        template = self._env.get_template("openvpn.j2")
        ovpn_config = template.render(
            server_endpoints=util.get_lighthouse_external_endpoints(),
            server_port="8194",
            ca=self.ca.public_string,
            cert=self.cert.public_string,
            username=self.username,
            auth_file=auth_file,
            key=self.cert.key,
        )
        bundle = [("ipaccess.ovpn", ovpn_config)]
        if self.username:
            bundle.append((auth_file, self.username))

        # Append README to the bundle
        readme_text = "To access a node’s remote IP networks, add the node name to your Lighthouse username using this format: username:node"
        bundle.append(("README.txt", readme_text))

        return bundle


def openssl_create_cert(subject: str, days: int,
                        key_size: Optional[int] = None) -> Tuple[Optional[bytes], Optional[bytes]]:
    """
    Generate a client key and a CA-signed certificate with the ``openssl`` CLI.

    Steps:
    - creates a new private key and CSR (``openssl req``)
    - creates a new cert from the CSR signed by the CA (``openssl ca``)
    - returns the key and cert bytes to the caller; the intermediate files live
      in a temporary directory that is removed before returning

    :param subject: certificate subject in ``openssl -subj`` syntax, e.g. ``/CN=host``.
    :param days: certificate validity in days.
    :param key_size: RSA key size in bits. When None (the default), ``-newkey rsa``
        is passed without a size, so openssl's configured default is used.
    :returns: ``(key_bytes, cert_bytes)``, or ``(None, None)`` if either openssl
        step exits non-zero (its output is logged).
    """
    def run(args: List[str]) -> subprocess.CompletedProcess:
        # stderr is merged into stdout so failures can be logged in one message.
        return subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, check=False)

    newkey = "rsa" if key_size is None else f"rsa:{int(key_size)}"

    with tempfile.TemporaryDirectory() as tmpdir:
        key_path = path.join(tmpdir, "client.key")
        csr_path = path.join(tmpdir, "client.csr")
        crt_path = path.join(tmpdir, "client.crt")

        # -nodes: write the private key unencrypted.
        process = run([
            "openssl", "req", "-nodes", "-new",
            "-newkey", newkey,
            "-keyout", key_path,
            "-out", csr_path,
            "-subj", subject,
        ])
        if process.returncode != 0:
            logger.error("Failed to create OpenSSL CSR: '%s'", process.stdout)
            return None, None

        process = run([
            "openssl", "ca", "-batch",
            "-extensions", "v3_client",
            "-days", str(days),
            "-in", csr_path,
            "-out", crt_path,
        ])
        if process.returncode != 0:
            logger.error("Failed to create OpenSSL certificate: '%s'", process.stdout)
            return None, None

        with open(key_path, "rb") as f:
            key_data = f.read()

        with open(crt_path, "rb") as f:
            cert_data = f.read()

    return key_data, cert_data


def openssl_revoke_cert(cert: Cert) -> bool:
    """
    Revoke ``cert`` in the local CA and regenerate the CRL using the ``openssl`` CLI.

    The CRL is written to ``<CA_DIR>/ca.crl`` with ``-crldays`` set to CA_DAYS.
    Both environment variables are read and validated *before* anything is
    revoked. A configuration error therefore cannot leave the certificate revoked
    in the CA database without a regenerated CRL.

    :returns: True on success (a certificate that was already revoked counts as
        success), False if either openssl step fails (its output is logged).
    :raises ValueError: if CA_DAYS or CA_DIR is missing or invalid.
    """
    # Resolve the CRL settings up front so a bad environment fails early.
    crl_days = str(ca_cert_days())
    crl_path = f"{ca_cert_dir()}/ca.crl"

    # OpenSSL needs the cert to exist locally, so write it to a tmpfile
    with tempfile.NamedTemporaryFile() as cert_file:
        cert_file.write(cert.public_bytes)
        cert_file.flush()

        # TODO string matching error handling sucks, the api would be preferable
        process = subprocess.run(
            ["openssl", "ca", "-revoke", cert_file.name],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            check=False,
        )

    # An already-revoked cert is not an error: carry on and rebuild the CRL.
    if process.returncode != 0 and b"ERROR:Already revoked" not in process.stdout:
        logger.error("Failed to revoke OpenSSL cert: '%s'", process.stdout)
        return False

    process = subprocess.run(
        ["openssl", "ca", "-gencrl", "-crldays", crl_days, "-out", crl_path],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        check=False,
    )
    if process.returncode != 0:
        logger.error("Failed to update OpenSSL certificate revocation list: '%s'", process.stdout)
        return False

    return True
