import backoff
import base64
import json
import requests
import grpc
from typing import (
    List,
    Set,
    Dict,
    Optional,
)
from netops.lighthouse import nodes as lh_nodes
from netops import (
    automation_gateway as ag,
)


def _build_nodes_query(node_ids) -> str:
    """
    Build the base64-encoded ``jb64`` search payload for the LH nodes API.

    One search item is emitted per node ID, each matching the node's ``config:_id``.
    The numeric ``type``/``oper``/``datatype`` codes are LH search-API constants.
    NOTE(review): confirm their exact meaning against the LH API documentation.

    :param node_ids: Node IDs to match
    :return: JSON query, UTF-8 encoded and then standard-base64 encoded, as ``str``
    """
    query = {
        "type": 2,
        "items": [
            {
                "fieldname": "config:_id",
                "oper": 1,
                "datatype": 9,
                "value": node_id,
                "type": 3,
            }
            for node_id in node_ids
        ],
    }
    return base64.b64encode(json.dumps(query).encode("utf-8")).decode("utf-8")


def _is_non_retryable(e) -> bool:
    """
    ``backoff`` giveup predicate: stop retrying on HTTP 403/404.

    Exceptions without a response (connection errors, timeouts) have
    ``response is None`` and are retried rather than crashing the predicate.
    """
    response = getattr(e, "response", None)
    return response is not None and response.status_code in (403, 404)


def get_nodes(lh_addr, nodes, token, verify, batch_size: int = 50) -> Dict[str, lh_nodes.Node]:
    """
    Fetch a list of nodes from Lighthouse using the provided token.

    To minimise requests, hit the nodes API with each node ID as an 'OR' query param.
    Unsure of how many query params LH can safely handle, so batches are capped at
    ``batch_size`` (50 by default).

    :param lh_addr: Lighthouse address used to build ``https://{lh_addr}/api/v3.4/...``
    :param nodes: Node IDs to fetch; any iterable (list, set, ...) of IDs
    :param token: LH API token, sent as ``Authorization: Token <token>``
    :param verify: TLS verification setting assigned to ``requests.Session.verify``
    :param batch_size: Maximum number of node IDs per request; must be >= 1
    :return: Dict mapping node ID -> ``lh_nodes.Node`` for every node LH returned.
             There is no guarantee that every requested node is present.
    :raises ValueError: if ``batch_size`` < 1
    :raises requests.exceptions.RequestException: on 403/404 immediately, or on
             any other request failure once 5 attempts have been exhausted
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")

    # Materialise once so sets/generators can be sliced into batches.
    node_ids = list(nodes)
    ret = {}
    if not node_ids:
        # An empty 'OR' query would not filter anything; never send one.
        return ret

    with requests.Session() as api:
        api.verify = verify
        # update() keeps requests' default headers (User-Agent, Accept-Encoding, ...).
        api.headers.update({"Authorization": "Token {}".format(token)})

        @backoff.on_exception(backoff.constant,
                              requests.exceptions.RequestException,
                              interval=1,
                              max_tries=5,
                              giveup=_is_non_retryable)
        def fetch_batch(query_string):
            # Passing the query via ``params`` percent-encodes base64's '+', '/' and '='.
            resp = api.get(
                f"https://{lh_addr}/api/v3.4/nodes",
                params={"ports": "false", "etags": "false", "jb64": query_string},
            )
            resp.raise_for_status()
            return resp

        # range() stops exactly at the last non-empty batch; no trailing empty request.
        for start in range(0, len(node_ids), batch_size):
            resp = fetch_batch(_build_nodes_query(node_ids[start:start + batch_size]))

            # NOTE: just storing the response, no guarantee the requested nodes are actually in there
            for node in resp.json().get('nodes', []):
                ret[node['id']] = lh_nodes.Node(node)

    return ret


def devices_node_ids(devices: List[ag.Device]) -> Set[str]:
    """
    Collect the IDs of every node that acts as a source for any host of the given AG devices.

    :param devices: AG devices to inspect
    :return: De-duplicated set of ``source.node.id`` values across all hosts of all devices
    """
    return {
        src.node.id
        for dev in devices
        for hst in dev.hosts
        for src in hst.sources
    }


def lookup_node(nodes: Dict[str, lh_nodes.Node], node: ag.Node.Reference) -> Optional[lh_nodes.Node]:
    """
    Resolve an AG node reference to its LH node, or None if it cannot be trusted.

    A match requires all of the following:
      * ``nodes`` and ``node`` are both truthy,
      * ``nodes`` has an entry for ``node.id``,
      * that entry has a MAC address equal to ``node.mac_address`` (node IDs can be reused),
      * that entry lists a module named 'ag' (AG activated on the node).

    :param nodes: Dict of node objects (pulled from LH API), keyed by node ID
    :param node: Target Node object (from ag-devices response)
    :return: The matching LH node, otherwise None
    """
    if nodes and node:
        candidate = nodes.get(node.id)
        if candidate and candidate.mac_address and candidate.mac_address == node.mac_address:
            if any(module.name == 'ag' for module in candidate.modules):
                return candidate
    return None


def node_sufficient_rights(nodes: Dict[str, lh_nodes.Node], node: ag.Node.Reference) -> bool:
    """
    Report whether ``node`` may be used for AG operations.

    The node must resolve through ``lookup_node`` (known, same MAC, AG module enabled)
    and its LH ``rights`` mapping must hold a truthy 'netops' entry.

    :param nodes: Dict of node objects (pulled from LH API)
    :param node: Target Node object (from ag-devices response)
    :return: True when both conditions hold, otherwise False
    """
    matched = lookup_node(nodes, node)
    if matched is not None and matched.rights.get('netops', False):
        return True
    return False


def filter_device(device: ag.Device, nodes):
    """
    Filter a device in place down to what the provided LH nodes grant rights for.

    For each host, and for each service of a surviving host:
      * the entry is dropped entirely if none of its sources' nodes pass
        ``node_sufficient_rights``;
      * otherwise its ``availability`` entries are kept only when their
        (id, mac_address, ip_address) matches one of those permitted source nodes.

    The ``sources`` lists themselves are left untouched.
    NOTE(review): confirm that keeping unpermitted sources is intended.

    :param device: AG device to filter; it is mutated in place
    :param nodes: Dict of LH node objects, as passed to ``node_sufficient_rights``
    :return: The same ``device`` object, for convenience
    """

    def keyed(ref):
        # Identity tuple shared by node references and availability entries.
        return ref.id, ref.mac_address, ref.ip_address

    def permitted_keys(sources):
        """Keys of the source nodes that have sufficient rights."""
        return {keyed(source.node) for source in sources if node_sufficient_rights(nodes, source.node)}

    def replace_contents(repeated, items):
        """Overwrite a repeated field / list in place with ``items``."""
        del repeated[:]
        repeated.extend(items)

    def retain_availability(entity, allowed_keys):
        """Keep only availability entries reported by an allowed node."""
        replace_contents(entity.availability,
                         [entry for entry in entity.availability if keyed(entry) in allowed_keys])

    hosts = []
    for host in device.hosts:
        host_keys = permitted_keys(host.sources)
        if not host_keys:
            continue
        retain_availability(host, host_keys)

        services = []
        for service in host.services:
            service_keys = permitted_keys(service.sources)
            if not service_keys:
                continue
            retain_availability(service, service_keys)
            services.append(service)
        replace_contents(host.services, services)

        hosts.append(host)
    replace_contents(device.hosts, hosts)

    return device


def ag_remote_channel(hostname: str, port: int = 9045) -> grpc.Channel:
    """
    Open a plaintext gRPC channel to the Automation Gateway endpoint on ``hostname``.

    :param hostname: Host name, IPv4 address, or IPv6 address (bare or already bracketed)
    :param port: TCP port of the AG gRPC endpoint (defaults to 9045)
    :return: An insecure (unencrypted, unauthenticated) ``grpc.Channel``
    """
    import ipaddress

    # gRPC targets need IPv6 literals bracketed, otherwise the address and port are ambiguous.
    try:
        if ipaddress.ip_address(hostname).version == 6:
            hostname = f"[{hostname}]"
    except ValueError:
        pass  # Not a bare IP literal (DNS name or already bracketed): use as-is.

    # NOTE(review): traffic on this channel is not encrypted; confirm this is only used on trusted networks.
    return grpc.insecure_channel(f"{hostname}:{port}")
