""" This module is used to add any utilities which are
    required by the REST API and other scripts.
"""

import ipaddress
import json
import os
import pathlib
import shlex
import socket
import struct
import urllib
import urllib.request
import zipfile
from io import BytesIO
from subprocess import Popen, PIPE
from typing import (
    Dict,
    Iterable,
    Set,
    Tuple,
    Optional,
)

import requests
from requests.exceptions import HTTPError

from netops.ip_access import remote

# Cache value: the Lighthouse API version string (e.g. "v3.2"). It starts as
# None and is filled in lazily by get_api_version() on its first call, from
# the LH_API_VERSION environment variable or the built-in fallback.
api_version = None


def lh_addr():
    """
    Return the Lighthouse address used to build API URLs.

    The value is read from the ``LH_ADDR`` environment variable on every call.
    Callers in this module interpolate it into URLs of the form
    ``https://{address}/api/...``.

    :return: str
    :raises ValueError: if ``LH_ADDR`` is unset or empty.
    """
    address = os.getenv('LH_ADDR')
    if not address:
        raise ValueError('LH_ADDR not set')
    return address


def get_lighthouse_external_endpoints():
    """
    Fetch the external endpoints configured on Lighthouse.

    Clients put this address in their client ovpn file to connect to the SDI
    VPN server.

    :return: the ``system_external_endpoints`` value from the API response,
             or None if the response status is not 200 or the key is absent.
    """
    # The Lighthouse API is reachable via localhost because central-sdi runs
    # in host mode. Authentication uses the default client certificate.
    endpoint_url = f'https://{lh_addr()}/api/{get_api_version()}/system/external_endpoints'
    response = make_request(endpoint_url)

    if response.status_code != 200:
        return None

    payload = json.loads(response.content)
    return payload.get('system_external_endpoints', None)


def run_command(cmd):
    """Run a command (e.g. a bash or python script) as a separate child process.

    ``cmd`` is tokenised with ``shlex.split``, so no shell is involved. The
    child's stdout and stderr are captured.

    :param cmd: command line string.
    :return: ``((stdout, stderr), returncode)``; stdout and stderr are bytes.
    """
    # The context manager closes the pipes and waits for the child even if
    # communicate() is interrupted by an exception.
    with Popen(shlex.split(cmd), stdout=PIPE, stderr=PIPE) as proc:
        output = proc.communicate()
    return output, proc.returncode


def get_lhvpn_address():
    """Return the IPv4 address that follows the Lighthouse LHVPN address.

    Queries the ``services/lhvpn`` endpoint and returns, as a string, the
    address one greater than ``lhvpn.address`` in the response.

    :raises requests.exceptions.HTTPError: if the API returns an error status.
    """
    service_url = f"https://{lh_addr()}/api/{get_api_version()}/services/lhvpn"
    response = make_request(service_url)
    response.raise_for_status()

    lhvpn_base = ipaddress.IPv4Address(response.json()['lhvpn']['address'])
    next_address = lhvpn_base + 1
    return str(next_address)


def get_api_version():
    """Return the Lighthouse API version string used in request URLs.

    On the first call the value is taken from the ``LH_API_VERSION``
    environment variable and cached in the module-level ``api_version``.
    Later calls return the cached value. An unset or empty variable falls
    back to ``"v3.2"``.
    """
    global api_version

    if api_version is None:
        # An empty value would otherwise produce URLs like "/api//...", so it
        # is treated the same as an unset variable.
        api_version = os.environ.get("LH_API_VERSION") or "v3.2"

    return api_version


def make_request(url, token=None, cert=None):
    """
    Make a GET request and return the response.

    If a token is provided (any non-None value is treated as a token), it is
    sent as an ``Authorization: Token <token>`` header. Otherwise the request
    uses client-certificate auth: ``cert`` if given, else the global-access
    keypair returned by get_certs().

    TLS certificate verification is disabled (``verify=False``).

    :param url: full URL to GET.
    :param token: optional Lighthouse session token.
    :param cert: optional ``(certificate, key)`` path tuple, in that order,
                 used instead of the default keypair when no token is given.
    :return: requests.Response
    """
    if token is not None:
        headers = {"Authorization": f"Token {token}"}
        return requests.get(url, verify=False, headers=headers)

    if cert is None:
        cert = get_certs()
    return requests.get(url, verify=False, cert=cert)


def get_certs():
    """Return the (certificate, private key) file paths for the requests library.

    Internal scripts pass this as ``cert=`` to reach Lighthouse endpoints
    using the LHVPN server keypair.
    """
    lhvpn_dir = "/root/.netops/lhvpn/"
    certificate = lhvpn_dir + "lhvpn_server.crt"
    private_key = lhvpn_dir + "lhvpn_server.key"
    return certificate, private_key


def get_lh_session_token(username, password):
    """
    Log in to the Lighthouse API and return the new session token.

    :param username: Lighthouse username.
    :param password: Lighthouse password.
    :return: the ``session`` value of the newly created session.
    :raises requests.exceptions.HTTPError: if the request fails, or if the
        returned session ``state`` is anything other than ``"authenticated"``.
    """
    sessions_url = f"https://{lh_addr()}/api/{get_api_version()}/sessions"
    credentials = json.dumps({'username': username, 'password': password})
    response = requests.post(sessions_url, data=credentials, verify=False)
    response.raise_for_status()

    session = response.json()
    state = session['state']
    if state == "authenticated":
        return session['session']

    raise HTTPError(
        f"{response.status_code} Authentication Error: {state} for url: {response.url}",
        response=response,
    )


def zip_files(files):
    """
    Build an in-memory ZIP archive from ``(name, contents)`` tuples.

    Each entry is added with ``ZipFile.writestr`` using DEFLATE compression.

    :param files: iterable of tuples whose first item is the archive member
                  name and second item is its contents.
    :return: the finished archive as bytes.
    """
    buffer = BytesIO()
    archive = zipfile.ZipFile(buffer, mode="w", compression=zipfile.ZIP_DEFLATED)
    with archive:
        for entry in files:
            member_name, member_data = entry[0], entry[1]
            archive.writestr(member_name, member_data)
    return buffer.getvalue()


def prefix2netmask(prefix):
    """Convert an IPv4 prefix length to a dotted-quad netmask.

    >>> prefix2netmask(24)
    '255.255.255.0'

    :param prefix: prefix length, 0 to 32 inclusive.
    :return: netmask string, e.g. ``'255.255.255.0'``.
    :raises ValueError: if ``prefix`` is outside 0..32.
    """
    # Without this check a negative prefix silently yields '0.0.0.0': the
    # shift exceeds 32 bits and the 32-bit mask clears everything.
    if not 0 <= prefix <= 32:
        raise ValueError(f"IPv4 prefix length must be between 0 and 32, got {prefix!r}")
    mask = (0xffffffff << (32 - prefix)) & 0xffffffff
    return socket.inet_ntoa(struct.pack(">I", mask))


def log_header(node_id, username=None):
    """Build the ``[NetOps-SDI ...]`` header string for a node.

    The ``username`` field is appended only when ``username`` is truthy.
    """
    fields = f'node="{node_id}"'
    if username:
        fields += f' username="{username}"'
    return f"[NetOps-SDI {fields}]"


def read_override_map(file_url):
    """
    Read the interface-zone override mapping from ``file_url``.

    The resource should hold a JSON object. Single quotes are converted to
    double quotes before parsing, so a Python-style dict literal is also
    accepted. Each value is looked up as a ``remote.Net.Zone`` name and
    replaced with its zone value. Unrecognised names map to ``ZONE_UNKNOWN``.

    :param file_url: URL passed to ``urllib.request.urlopen``
                     (e.g. ``file://...`` or ``http://...``).
    :return: dict of key -> zone value, or ``{}`` if the resource cannot be
             read or does not contain a JSON object.
    """
    try:
        with urllib.request.urlopen(file_url) as f:
            raw = f.read().decode('utf-8').strip('\n')
        map_dict = json.loads(raw.replace("\'", '\"'))
    except (OSError, ValueError, AttributeError):
        # Best effort: an unreadable upstream file or malformed data yields an
        # empty map. OSError covers urllib.error.URLError and its subclass
        # HTTPError (missing file, connection failure, HTTP error status).
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        # AttributeError is what urlopen raises for a non-string URL such as None.
        return {}
    if not isinstance(map_dict, dict):
        return {}  # Valid JSON, but not an object: invalid file format

    zone_mapping = {remote.Net.Zone.Name(v): v for v in remote.Net.Zone.values()} # pylint: disable=E1101
    # ZONE_UNKNOWN is only looked up when a name is unrecognised.
    return {
        key: zone_mapping[name] if name in zone_mapping else zone_mapping["ZONE_UNKNOWN"]
        for key, name in map_dict.items()
    }


def get_lh_users():
    """Return the decoded JSON body of the Lighthouse ``users`` endpoint.

    :raises requests.exceptions.HTTPError: if the API returns an error status.
    """
    users_url = f"https://{lh_addr()}/api/{get_api_version()}/users"
    response = make_request(users_url)
    response.raise_for_status()
    body = response.json()
    return body


def get_lh_user(user):
    """Return the first Lighthouse user record whose ``username`` equals ``user``.

    :return: the matching record from the ``users`` list, or None if none match.
    """
    matches = [
        record
        for record in get_lh_users()['users']
        if record['username'] == user
    ]
    return matches[0] if matches else None


def parse_username_uid(lines: Iterable[str]) -> Dict[str, str]:
    """Map usernames to uids (both as strings) from /etc/passwd-format lines.

    Lines with fewer than three ``:``-separated fields, or with an empty name
    or uid field, are skipped. A later entry for the same name overrides an
    earlier one.
    """
    split_lines = (raw.rstrip('\n\r').split(sep=':', maxsplit=4) for raw in lines)
    return {
        fields[0]: fields[2]
        for fields in split_lines
        if len(fields) >= 3 and fields[0] and fields[2]
    }


def parse_lh_secctxt(lines: Iterable[str]) -> Tuple[Optional[str], Optional[Set[str]]]:
    """Parse a secctxt file into ``(username, groups)``.

    This is the format currently used to propagate a user's groups from the
    PAM-based auth mechanism to the relevant session row in the Lighthouse
    config database. The first line is the username. The second line is a
    comma-separated group list with empty entries dropped. Anything after
    the second line is not read.

    Missing lines yield None for the corresponding element. The groups are
    not guaranteed to exist in the Lighthouse configuration.
    """
    line_iter = iter(lines)

    first_line = next(line_iter, None)
    if first_line is None:
        return None, None
    username = first_line.rstrip('\n\r')

    second_line = next(line_iter, None)
    if second_line is None:
        return username, None
    groups = {name for name in second_line.rstrip('\n\r').split(sep=',') if name}
    return username, groups


def get_lh_user_groups(username: str) -> Set[str]:
    """Return ``username``'s groups from the host's cached secctxt files.

    Looks up the user's uid in the host passwd file (/etc/host/passwd), then
    parses /var/run/host/secctxt/<uid>. This is crude parsing. Per the
    original note, it fails badly if the user has not logged in since boot;
    presumably the secctxt file is then absent and ``open`` raises.

    :raises ValueError: if the user is not in the passwd file, the secctxt
        file names a different user, or the secctxt file has no groups line.
    """
    with open('/etc/host/passwd', 'r') as passwd_file:
        uid_by_name = parse_username_uid(passwd_file)
    if username not in uid_by_name:
        raise ValueError(f"username '{username}' not in host /etc/passwd")

    secctxt_path = pathlib.Path('/var/run/host/secctxt') / uid_by_name[username]
    with open(str(secctxt_path), 'r') as secctxt_file:
        found_name, found_groups = parse_lh_secctxt(secctxt_file)

    if found_name != username:
        raise ValueError(f"username '{username}' != secctxt '{found_name}'")
    if found_groups is None:
        raise ValueError('missing secctxt groups')
    return found_groups


def get_lh_groups():
    """Return the ``groups`` value from the Lighthouse ``groups`` endpoint.

    :raises requests.exceptions.HTTPError: if the API returns an error status.
    """
    resp = make_request(f"https://{lh_addr()}/api/{get_api_version()}/groups")
    # Surface HTTP failures as HTTPError, as get_lh_users() does, rather than
    # as a KeyError/JSON error from parsing an error response body.
    resp.raise_for_status()
    return resp.json()['groups']


def get_lh_group_by_id(group_id):
    """Look up a Lighthouse group by its ``id``.

    :return: the first group whose ``id`` equals ``group_id``, or None if
             there is no such group.
    """
    found = [group for group in get_lh_groups() if group['id'] == group_id]
    if not found:
        return None
    return found[0]


def get_lh_group_smartgroup(group_id):
    """Return the ``smart_group`` of the Lighthouse group with ``group_id``.

    :return: the group's ``smart_group`` value, or None if the group is not
             found (or is empty) or has no ``smart_group`` key.
    """
    group = get_lh_group_by_id(group_id)
    if group and 'smart_group' in group:
        return group['smart_group']
    return None


def get_lh_group_smartgroup_nodes(group_id):
    """Return the nodes of the smart group attached to a Lighthouse group.

    :return: the ``nodes`` value from the smartgroup nodes endpoint, or ``[]``
             if the group has no (truthy) ``smart_group``.
    :raises requests.exceptions.HTTPError: if the nodes request returns an
        error status.
    """
    smartgroup = get_lh_group_smartgroup(group_id)
    if not smartgroup:
        return []
    resp = make_request(
        f"https://{lh_addr()}/api/{get_api_version()}/nodes/smartgroups/{smartgroup}/nodes")
    # Fail with a clear HTTPError instead of a KeyError on an error body.
    resp.raise_for_status()
    return resp.json()['nodes']
