import multiprocessing
from concurrent.futures import ThreadPoolExecutor

import falcon.status_codes as http_status
import grpc
from requests import HTTPError

from central_sdi import ipaccess_cfg, ipaccess_node_cfg, nodes, openvpn
from central_sdi.api import authz, schema
from netops import auth
from netops.ip_access import remote


class Rights():
    """Endpoint reporting which CentralSDI UI actions the requesting user may perform.
    """
    def on_get(self, request, response):
        """Rights GET endpoint.
        ---
        description: Get user's SDI rights
        responses:
            200:
                description: User rights
        """
        token = request.context.get("token")

        response.status = http_status.HTTP_200

        # Table of (UI section, authz resource path, ((UI right, HTTP method), ...)).
        # Each UI right is granted when authz allows that method on the section's path.
        sections = (
            ("certs", "/nom/sdi/certs",
             (("view", "get"), ("create", "post"), ("edit", "put"), ("delete", "delete"))),
            ("status", "/nom/sdi/status",
             (("view", "get"), ("delete", "delete"))),
            ("nodes", "/nom/sdi/nodes",
             (("view", "get"), ("edit", "put"))),
            ("options", "/nom/sdi/options",
             (("view", "get"), ("edit", "put"))),
        )
        rights = {
            section: {right: authz.can(token, path, method) for right, method in checks}
            for section, path, checks in sections
        }
        response.media = {"rights": rights}


class Routes():
    """This endpoint is used by the UI to generate sidebar links
    """
    def __init__(self, nom_auth_stub):
        # Stub for the NOM auth service; its ``Can`` RPC is used in ``on_get`` to
        # check Lighthouse-level rights.
        self.nom_auth = nom_auth_stub

    def _lh_allowed(self, principle, resource, value):
        """Return whether the NOM auth service grants the single right ``value``.

        :param principle: ``auth.Principle`` identifying the caller's session
        :param resource: ``auth.Resource`` the right is checked against
        :param value: ``auth.Rights.Value`` naming the entity (and write flag) to check
        :return: the ``allowed`` field of the ``Can`` response
        """
        action = auth.Action(rights=auth.Rights(values=[value]))
        can_request = auth.Can.Request(principle=principle, resource=resource, action=action)
        return self.nom_auth.Can(can_request).allowed

    def on_get(self, request, response):
        """Routes GET endpoint.
        ---
        description: Get list of SDI UI routes
        responses:
            200:
                description: UI routes
        """
        token = request.context.get("token")

        response.status = http_status.HTTP_200
        ip_access = []
        routes = {"IP Access": ip_access}

        # Evaluate each SDI permission once (the certs check gates two links).
        can_view_nodes = authz.can(token, "/nom/sdi/nodes", "get")
        can_view_certs = authz.can(token, "/nom/sdi/certs", "get")
        can_view_status = authz.can(token, "/nom/sdi/status", "get")

        if can_view_nodes:
            ip_access.append({"name": "Node Access", "route": "access"})
        if can_view_certs:
            ip_access.append({"name": "Client Certificates", "route": "certs"})
        if can_view_status:
            ip_access.append({"name": "Client Status", "route": "status"})
        # Access to the certs endpoint should be sufficient for this page
        if can_view_certs:
            ip_access.append({"name": "Advanced Options", "route": "options"})

        # The policies page is gated on Lighthouse rights instead: write on
        # 'netops_modules', AND either 'groups_and_roles' (no write flag) or write
        # on 'users'. The RPCs short-circuit, so later checks are skipped once the
        # outcome is already decided.
        principle = auth.Principle(session=auth.Session(token=token))
        resource = auth.Resource(lighthouse=auth.Lighthouse())
        netops_rw = auth.Rights.Value(entity='netops_modules', write=True)
        groups_ro = auth.Rights.Value(entity='groups_and_roles')
        users_rw = auth.Rights.Value(entity='users', write=True)

        if self._lh_allowed(principle, resource, netops_rw) and (
                self._lh_allowed(principle, resource, groups_ro)
                or self._lh_allowed(principle, resource, users_rw)):
            ip_access.append({"name": "Network Access Policies", "route": "policies"})

        response.media = routes


class IPAccess:
    """These endpoints represent the global IP Access config, and allow the user
    to completely enable/disable the service.
    """
    schema = schema.IPAccessConfig()

    def on_get(self, request, response):  # pylint: disable=unused-argument
        """IP Access Config GET endpoint.
        ---
        description: Get IP Access configuration
        responses:
            200:
                description: IP Access configuration
                content:
                    application/json:
                        schema: IPAccessConfig
        """
        response.status = http_status.HTTP_200
        # NOTE(review): default() presumably fills in unset keys of the stored config — confirm.
        response.media = ipaccess_cfg.default(ipaccess_cfg.load())

    def on_put(self, request, response):
        """IP Access Config PUT endpoint.
        ---
        description: Update IP Access configuration
        requestBody:
            content:
                application/json:
                    schema: IPAccessConfig
        responses:
            200:
                description: Updated IP Access configuration
                content:
                    application/json:
                        schema: IPAccessConfig
        """
        # The request body is stored as-is; it is presumably validated against
        # ``schema`` before this handler runs — confirm.
        ipaccess_cfg.store(request.context.json)
        response.status = http_status.HTTP_200
        response.media = ipaccess_cfg.default(request.context.json)

    def on_delete(self, request, response):  # pylint: disable=unused-argument
        """This endpoint disables the entire IP Access service.

        Disabling sets a global config flag that will be checked by the VPN
        connection scripts. Existing VPN connections are then disconnected manually.
        This is best effort; there may be edge cases, such as a user connecting
        while the service is being disabled. Every connected client is attempted
        even if some fail to disconnect. Any failures are reported with a 500;
        the disable flag has already been stored at that point.

        The IP Access service will still be running on enabled nodes, but a user would need
        direct CLI access on the node or Lighthouse to use it.

        NOTE: Disabling does not kill existing sessions on other MI instances;
        this may have to be implemented by sending a request from primary to secondaries.
        ---
        description: Disable IP Access
        responses:
            200:
                description: Updated IP Access configuration
                content:
                    application/json:
                        schema: IPAccessConfig
            500:
                description: Failed to disconnect client
        """
        # NOTE(review): assumes load() returns a mutable mapping even when no
        # config has been stored yet — confirm.
        cfg = ipaccess_cfg.load()
        cfg['disable'] = True
        ipaccess_cfg.store(cfg)

        openvpn_mgr = openvpn.OpenVPNManager()
        client_status = openvpn_mgr.get_client_status()
        failed = []
        for client in client_status:
            try:
                openvpn_mgr.kill_client_by_real_address(client_status, client.real_address)
            except openvpn.CommandException:
                # Keep going: one client that fails to disconnect must not leave
                # the remaining clients connected.
                failed.append(str(client.real_address))

        if failed:
            response.status = http_status.HTTP_500
            # Must be a dict (the original used a set literal, which is not
            # JSON-serialisable and broke this error path).
            response.media = {
                "error": f"Failed to disconnect client with address: {', '.join(failed)}",
                "failed": failed,
            }
            return

        response.status = http_status.HTTP_200
        response.media = ipaccess_cfg.default(cfg)


class IPAccessNodes:
    """Read-only endpoint exposing the stored per-node IP Access configuration.
    """

    def on_get(self, request, response):  # pylint: disable=unused-argument
        """IP Access Node configuration GET endpoint.
        ---
        description: Get IP Access node configuration
        responses:
            200:
                description: IP Access node configuration
                content:
                    application/json:
                        schema: IPAccessNodeList
        """
        # Load the stored node configs and serialise them under the "nodes" key.
        stored_nodes = ipaccess_node_cfg.nodes_cfg_load()
        payload = {"nodes": stored_nodes.to_dict()}
        response.media = payload
        response.status = http_status.HTTP_200


def run_node_action(node, action, *, port=8980, timeout=5):
    """
    Perform an IP Access action on a node (enable/disable).

    Calls the node's ``remote.SystemStub`` over an insecure gRPC channel to
    ``<node.lhvpn_address>:<port>`` and collects every message of the response.

    :param node: node object exposing ``lhvpn_address``
    :param action: ``'enable'`` or ``'disable'``
    :param port: port of the node's IP Access gRPC service (default 8980)
    :param timeout: gRPC deadline in seconds for the call (default 5)
    :return: ``(node, result)``. The node is returned so callers can identify the
        result. ``result`` is the list of response messages, or ``None`` if the
        RPC raised ``grpc.RpcError``.
    :raises ValueError: if ``action`` is neither ``'enable'`` nor ``'disable'``
        (previously this surfaced as an ``UnboundLocalError`` after opening a channel).
    """
    if action not in ('enable', 'disable'):
        raise ValueError(f"Unsupported IP Access node action: {action!r}")

    try:
        with grpc.insecure_channel(f"{node.lhvpn_address}:{port}") as channel:
            stub = remote.SystemStub(channel)
            if action == 'enable':
                response = stub.Enable(remote.SystemEnableRequest(), timeout=timeout)
            else:
                response = stub.Disable(remote.SystemDisableRequest(), timeout=timeout)
            # Consume the response while the channel is open and inside the try.
            # The response is iterated (presumably a server stream), so an error
            # raised mid-iteration is also reported as a failure.
            result = list(response)
    except grpc.RpcError:
        return node, None

    return node, result


class IPAccessNodesActions():
    """Handlers for performing actions on IP Access nodes e.g. enable/disable
    """
    schema = schema.IPAccessNodeAction()

    # Upper bound on the number of nodes contacted concurrently.
    max_concurrent_actions = 16

    def on_put(self, request, response):
        """IP Access Node actions PUT endpoint.
        ---
        description: Enable/disable a list of nodes for IP Access
        requestBody:
            content:
                application/json:
                    schema: IPAccessNodeAction
        responses:
            200:
                description: Successfully updated nodes
            400:
                description: Unable to retrieve node(s)
            403:
                description: Permission denied
            500:
                description: Failed to save IP Access node configuration
        """
        token = request.context.get("token")

        action = request.context.json['action']
        node_ids = request.context.json.get('node_ids', [])

        # Fetch each node, and check that we have netops access
        node_details = []
        for node_id in node_ids:
            try:
                ret, node = nodes.node_authorized(node_id, token)
            except HTTPError as err:
                response.status = http_status.HTTP_400
                response.media = {"error": f"Unable to retrieve nodes: {err}"}
                return
            if not (ret and node.rights.get('netops', False)):
                response.status = http_status.HTTP_403
                response.media = {"error": f"Permission denied: {node_id}"}
                return
            node_details.append(node)

        nodes_cfg = ipaccess_node_cfg.nodes_cfg_load()
        action_list = self._pending_actions(node_details, nodes_cfg, action)

        # Run action on nodes; record the new state only for nodes that succeeded
        failed = []
        for node, result in self._run_actions(action_list):
            if result is None:
                failed.append(node.id)
            else:
                nodes_cfg.get_node_cfg_by_id(node.id).disabled = action == 'disable'

        # Store node config (including any node configs added by _pending_actions)
        try:
            ipaccess_node_cfg.nodes_cfg_store(nodes_cfg)
        except HTTPError as e:
            response.status = http_status.HTTP_500
            response.media = {"error": f"Failed to save node IP Access configuration: {e}"}
            return

        # TODO currently just returns a list of failed nodes. Could also try to return error details,
        # and sessions which were enabled/disabled.
        if failed:
            response.status = http_status.HTTP_400
            response.media = {
                "failed": failed
            }
        else:
            response.status = http_status.HTTP_200

    @staticmethod
    def _pending_actions(node_details, nodes_cfg, action):
        """Return ``(node, action)`` pairs for nodes whose IP Access state would change.

        Nodes without an 'sdi' module are skipped. Nodes with no stored config get
        a fresh ``NodeConfig`` added to ``nodes_cfg`` (mutated in place).
        """
        action_list = []
        for node in node_details:
            if not node.get_module_by_name('sdi'):
                continue

            node_cfg = nodes_cfg.get_node_cfg_by_id(node.id)
            if not node_cfg:
                node_cfg = ipaccess_node_cfg.NodeConfig(node.id)
                nodes_cfg.add_node_cfg(node_cfg)

            # Only act on nodes not already in the requested state
            if (action == 'enable' and node_cfg.disabled) or \
                    (action == 'disable' and not node_cfg.disabled):
                action_list.append((node, action))
        return action_list

    def _run_actions(self, action_list):
        """Run ``run_node_action`` concurrently for each ``(node, action)`` pair.

        Uses threads rather than a process pool. Each call only waits on a gRPC
        round trip, and forking the web-server process is both expensive and
        unsafe with gRPC's default fork handling. Results are returned in the
        same order as ``action_list``. Any exception raised by
        ``run_node_action`` propagates to the caller.
        """
        if not action_list:
            return []
        workers = min(self.max_concurrent_actions, len(action_list))
        with ThreadPoolExecutor(max_workers=workers) as executor:
            return list(executor.map(lambda pair: run_node_action(*pair), action_list))
