import grpc
import falcon
import time
import falcon.status_codes as http_status
from urllib import parse as urllib
from typing import (
    Optional,
    Callable,
    List,
    Dict,
)
from netops.lighthouse import (
    nodes as lh_nodes,
)
from netops import (
    automation_gateway as ag,
)
from netops.grpcutil import (
    header,
)
from . import (
    auth,
    util,
)

# The RFC 3986 "sub-delims" characters. Used as the ``safe`` set when re-quoting URL
# userinfo (username/password), so these stay literal. Every other reserved character,
# including ":", "@" and "/", is percent-encoded.
_url_subdelims = "!$&'()*+,;="


def wait_for_guacamole_session(stub: ag.RemoteStub, name: str, wait_time: int, wait_interval: int):
    """Poll a remote Guacamole session until it reports a ``Ready`` condition.

    :param stub: Remote service stub used to call ``GetGuacamoleSession``.
    :param name: Name of the Guacamole session to poll.
    :param wait_time: Polling budget; each unsuccessful poll consumes ``wait_interval``
        of it, so about ``wait_time / wait_interval`` polls are made.
    :param wait_interval: Seconds to sleep between polls; must be positive.
    :returns: True as soon as a condition with type ``'Ready'`` and status ``'True'`` is
        seen. False once the budget is used up, or immediately if ``wait_time <= 0``.
    :raises ValueError: if ``wait_interval`` is not positive (a zero interval would never
        consume the budget, so the loop would spin forever).

    Errors raised by the stub call are not caught and propagate to the caller.
    """
    if wait_interval <= 0:
        raise ValueError(f"wait_interval must be positive, got {wait_interval}")
    remaining = wait_time
    while remaining > 0:
        conditions = stub.GetGuacamoleSession(ag.GetGuacamoleSessionRequest(name=name)).conditions
        if any(condition.type == 'Ready' and condition.status == 'True' for condition in conditions):
            return True
        remaining -= wait_interval
        # Only sleep if another poll will follow; sleeping after the last poll just
        # delays the False result.
        if remaining > 0:
            time.sleep(wait_interval)
    return False


def wait_for_sshttyd_session(stub: ag.RemoteStub, name: str, wait_time: int, wait_interval: int):
    """Poll a remote sshttyd session until it reports a ``Ready`` condition.

    :param stub: Remote service stub used to call ``GetSshttydSession``.
    :param name: Name of the sshttyd session to poll.
    :param wait_time: Polling budget; each unsuccessful poll consumes ``wait_interval``
        of it, so about ``wait_time / wait_interval`` polls are made.
    :param wait_interval: Seconds to sleep between polls; must be positive.
    :returns: True as soon as a condition with type ``'Ready'`` and status ``'True'`` is
        seen. False once the budget is used up, or immediately if ``wait_time <= 0``.
    :raises ValueError: if ``wait_interval`` is not positive (a zero interval would never
        consume the budget, so the loop would spin forever).

    Errors raised by the stub call are not caught and propagate to the caller.
    """
    if wait_interval <= 0:
        raise ValueError(f"wait_interval must be positive, got {wait_interval}")
    remaining = wait_time
    while remaining > 0:
        conditions = stub.GetSshttydSession(ag.GetSshttydSessionRequest(name=name)).conditions
        if any(condition.type == 'Ready' and condition.status == 'True' for condition in conditions):
            return True
        remaining -= wait_interval
        # Only sleep if another poll will follow; sleeping after the last poll just
        # delays the False result.
        if remaining > 0:
            time.sleep(wait_interval)
    return False


class Tokens:
    """Falcon resource that issues and refreshes Automation Gateway auth tokens.

    ``POST`` bodies choose an operation with ``request_type``:

    * ``"new"``: requires ``node``, ``url`` and ``session``. For ``rdp``/``vnc``/``ssh`` URLs,
      it first starts a remote Guacamole or sshttyd session on the node. It then calls
      ``CreateToken``.
    * ``"refresh"``: requires ``token`` and ``session``, and calls ``RefreshToken``.
    """

    def __init__(
            self,
            *,
            identity: Optional[auth.Identity] = None,
            ag_auth: Optional[ag.AuthStub] = None,
            get_nodes: Optional[Callable[[str, List[str]], Dict[str, lh_nodes.Node]]] = None,  # (session, node_ids)
    ):
        """
        :param identity: Stored as ``self._identity``; not read by the methods of this class.
        :param ag_auth: Automation Gateway auth stub providing ``CreateToken``/``RefreshToken``.
        :param get_nodes: Called as ``get_nodes(session, [node_id])``; the result is looked up
            by node id. Only used for rdp/vnc/ssh URLs.
        """
        self._identity = identity
        self._ag_auth = ag_auth
        self._get_nodes = get_nodes

    def on_post(self, req: falcon.Request, res: falcon.Response):
        """
        ---
        description: Create Automation Gateway auth token
        requestBody:
            content:
                application/json:
                    schema: CreateAuthTokenRequest
        responses:
            200:
                description: OK
                content:
                    application/json:
                        schema: CreateAuthTokenResponse
        """
        body = req.context.json
        request_type = body.get("request_type", None)
        if request_type == "new":
            response = self._create_token(body)
        elif request_type == "refresh":
            response = self._refresh_token(body)
        else:
            raise falcon.HTTPBadRequest(
                description="Valid type not provided in request body")

        res.media = response
        res.status = http_status.HTTP_200

    def _create_token(self, body: dict) -> dict:
        """Handle ``request_type == "new"`` and return the response body.

        The response holds:

        * ``service_url``: a scheme-less, host-less URL built from the re-encoded userinfo,
          the query pairs and the fragment.
        * ``service_name``: present only when a remote session was started.
        * ``token``: the token returned by ``CreateToken``.

        :raises falcon.HTTPBadRequest: missing/invalid request fields or unsupported scheme.
        :raises falcon.HTTPNotFound / falcon.HTTPForbidden: ``CreateToken`` returned
            NOT_FOUND / PERMISSION_DENIED.
        :raises falcon.HTTPInternalServerError: any other ``CreateToken`` failure.
        """
        node = body.get("node", None)
        if not node:
            raise falcon.HTTPBadRequest(
                description="Node not provided in request body")

        url = body.get("url", None)
        if not url:
            raise falcon.HTTPBadRequest(
                description="URL not provided in request body")

        session = body.get("session", None)
        if not session:
            raise falcon.HTTPBadRequest(
                description="`session` not provided in request body")

        parsed_url, parsed_scheme, service_netloc, service_query, service_fragment = \
            self._split_service_url(url)

        # For remote-session schemes, the token is minted for the session's URL rather than
        # the client-supplied one.
        service_name = ''
        if parsed_scheme in {'rdp', 'vnc'}:
            url, service_name, service_query = self._start_guacamole_session(
                body, session, node, url, parsed_url, parsed_scheme, service_query)
        elif parsed_scheme == 'ssh':
            url, service_name, service_query = self._start_sshttyd_session(
                session, node, url, parsed_url)
        elif parsed_scheme not in ('http', 'https'):
            raise falcon.HTTPBadRequest(description=f"Unsupported scheme {parsed_scheme}")

        response = {
            'service_url': urllib.urlunsplit((
                '',
                service_netloc,
                '/',
                urllib.urlencode(service_query),
                service_fragment,
            )),
        }
        if service_name:
            response['service_name'] = service_name
        try:
            response['token'] = self._ag_auth.CreateToken(ag.CreateTokenRequest(session=session, node=node, url=url)).token
        except grpc.RpcError as e:
            raise self._rpc_http_error(e, "CreateTokenRequest", map_client_errors=True) from e
        return response

    def _refresh_token(self, body: dict) -> dict:
        """Handle ``request_type == "refresh"`` and return ``{'token': <refreshed token>}``.

        :raises falcon.HTTPBadRequest: ``token`` or ``session`` missing.
        :raises falcon.HTTPInternalServerError: any ``RefreshToken`` failure.
        """
        token = body.get("token", None)
        if not token:
            raise falcon.HTTPBadRequest(
                description="Token not provided in request body")

        session = body.get("session", None)
        if not session:
            raise falcon.HTTPBadRequest(
                description="`session` not provided in request body")

        try:
            refreshed = self._ag_auth.RefreshToken(ag.RefreshTokenRequest(session=session, token=token)).token
        except grpc.RpcError as e:
            # NOTE(review): unlike CreateToken, NOT_FOUND/PERMISSION_DENIED are reported as 500
            # here. Kept as-is; confirm whether clients should get 404/403 instead.
            raise self._rpc_http_error(e, "RefreshTokenRequest", map_client_errors=False) from e
        return {'token': refreshed}

    @staticmethod
    def _split_service_url(url: str):
        """Parse ``url`` into the pieces reused when building the service URL.

        :returns: ``(parsed_url, scheme, netloc, query_pairs, fragment)``.
            ``netloc`` is only the re-encoded userinfo prefix (``user:pass@``), or ``''``
            when the URL has none. ``scheme`` is lower-cased. ``query_pairs`` is the
            strictly parsed query string, with blank values kept.
        :raises falcon.HTTPBadRequest: if the URL or its query string cannot be parsed.
        """
        service_netloc = ''
        service_query = []
        try:
            parsed_url = urllib.urlsplit(url)
            # NOTE(review): urlsplit does not unquote username/password, so an already
            # percent-encoded "%" is re-encoded as "%25". Confirm that clients send raw,
            # unencoded credentials.
            if parsed_url.username is not None:
                service_netloc += urllib.quote(parsed_url.username, safe=_url_subdelims)
            # ":" is emitted whenever userinfo is present, so "user@host" yields "user:@".
            if parsed_url.username is not None or parsed_url.password is not None:
                service_netloc += ':'
            if parsed_url.password is not None:
                service_netloc += urllib.quote(parsed_url.password, safe=_url_subdelims)
            if service_netloc:
                service_netloc += '@'
            if parsed_url.query:
                service_query.extend(urllib.parse_qsl(
                    parsed_url.query,
                    keep_blank_values=True,
                    strict_parsing=True,
                    errors='strict',
                ))
            service_fragment = parsed_url.fragment
            parsed_scheme = parsed_url.scheme.lower()
        except ValueError as e:
            raise falcon.HTTPBadRequest(description=f"invalid url: {url}") from e
        return parsed_url, parsed_scheme, service_netloc, service_query, service_fragment

    @staticmethod
    def _require_host_and_port(parsed_url: urllib.SplitResult, url: str):
        """Return ``(hostname, port)`` from ``parsed_url``, or raise ``falcon.HTTPBadRequest``.

        ``SplitResult.port`` raises ``ValueError`` for a non-numeric or out-of-range port.
        That is reported as a bad URL (400) rather than escaping as a server error.
        """
        if parsed_url.hostname is None:
            raise falcon.HTTPBadRequest(description="Hostname must be provided")
        try:
            port = parsed_url.port
        except ValueError as e:
            raise falcon.HTTPBadRequest(description=f"invalid url: {url}") from e
        if port is None:
            raise falcon.HTTPBadRequest(description="Port must be provided")
        return parsed_url.hostname, port

    def _lookup_node(self, session: str, node: str):
        """Fetch ``node`` via ``get_nodes`` and return ``(lh_node, nodes)``.

        :raises falcon.HTTPBadRequest: the node is unknown or lacks a MAC or LHVPN address.
        """
        nodes = self._get_nodes(session, [node])
        lh_node = nodes.get(node)
        if not lh_node or not lh_node.mac_address or not lh_node.lhvpn_address:
            raise falcon.HTTPBadRequest(description=f"Invalid node {node}")
        return lh_node, nodes

    def _start_guacamole_session(self, body, session, node, url, parsed_url, scheme, service_query):
        """Start an rdp/vnc Guacamole session on ``node`` and wait until it is ready.

        :returns: ``(session_url, session_name, query_pairs)``. ``query_pairs`` is the
            session's generated username/password followed by the original query pairs.
        :raises falcon.HTTPBadRequest: missing/invalid host or port, invalid node, or a VNC
            request without a password.
        :raises falcon.HTTPForbidden: ``util.node_sufficient_rights`` rejects the node.
        :raises falcon.HTTPInternalServerError: the session was not ready within the
            60 s / 1 s polling budget.
        """
        hostname, port = self._require_host_and_port(parsed_url, url)
        lh_node, nodes = self._lookup_node(session, node)
        node_ref = ag.Node.Reference(id=node, mac_address=lh_node.mac_address, ip_address=lh_node.lhvpn_address)
        if not util.node_sufficient_rights(nodes, node_ref):
            raise falcon.HTTPForbidden(description=f"Insufficient permissions for node {node}")

        # For RDP, the first `security` query parameter (if any) is forwarded to Guacamole.
        rdp_security = None
        if scheme == 'rdp':
            rdp_security = next((val for key, val in service_query if key == 'security'), None)

        username = body.get("username", None)
        password = body.get("password", None)
        # Validate before opening a channel to the node.
        if scheme == 'vnc' and not password:
            raise falcon.HTTPBadRequest(description="Password must be provided for VNC")

        parameters = {
            'hostname': f"{hostname}",
            'port': f"{port}",
            'ignore-cert': 'true',
        }
        if username:
            parameters['username'] = username
        if password:
            parameters['password'] = password
        if rdp_security:
            parameters['security'] = rdp_security

        with util.ag_remote_channel(lh_node.lhvpn_address) as channel:
            stub = ag.RemoteStub(channel)
            guacamole_session = stub.CreateGuacamoleSession(ag.CreateGuacamoleSessionRequest(guacamole_session=ag.GuacamoleSession(
                protocol=scheme,
                parameters=parameters
            )))
            if not wait_for_guacamole_session(stub, guacamole_session.name, 60, 1):
                raise falcon.HTTPInternalServerError(description="Failed to start remote Guacamole session")

        credentials = [('username', guacamole_session.username), ('password', guacamole_session.password)]
        return guacamole_session.url, guacamole_session.name, credentials + service_query

    def _start_sshttyd_session(self, session, node, url, parsed_url):
        """Start an ssh sshttyd session on ``node`` and wait until it is ready.

        :returns: ``(session_url, session_name, query_pairs)``. ``query_pairs`` replaces,
            rather than extends, the original URL's query with two ``arg`` pairs: the
            session host and the last path segment of the session name.
        :raises falcon.HTTPBadRequest: missing/invalid host or port, or invalid node.
        :raises falcon.HTTPInternalServerError: the session was not ready within the
            30 s / 1 s polling budget.
        """
        # NOTE(review): unlike rdp/vnc, there is no util.node_sufficient_rights check here. The
        # port is validated but not passed to SshttydSession. Confirm that both are intended.
        hostname, _port = self._require_host_and_port(parsed_url, url)
        lh_node, _nodes = self._lookup_node(session, node)
        with util.ag_remote_channel(lh_node.lhvpn_address) as channel:
            stub = ag.RemoteStub(channel)
            sshttyd_session = stub.CreateSshttydSession(ag.CreateSshttydSessionRequest(sshttyd_session=ag.SshttydSession(
                host=hostname,
            )))
            if not wait_for_sshttyd_session(stub, sshttyd_session.name, 30, 1):
                raise falcon.HTTPInternalServerError(description="Failed to start remote Sshttyd session")
        service_query = [('arg', sshttyd_session.host), ('arg', sshttyd_session.name.split('/')[-1])]
        return sshttyd_session.url, sshttyd_session.name, service_query

    @staticmethod
    def _rpc_http_error(e: grpc.RpcError, rpc_name: str, *, map_client_errors: bool) -> falcon.HTTPError:
        """Translate a failed Automation Gateway auth RPC into a Falcon HTTP error to raise.

        Every returned error carries the gRPC status header. When ``map_client_errors`` is
        true, NOT_FOUND maps to 404 and PERMISSION_DENIED maps to 403. Everything else is a
        500 that includes the status code and details.
        """
        code = e.code()  # pylint: disable=no-member
        headers = [
            header.grpc_status(code),
        ]
        if map_client_errors:
            if code == grpc.StatusCode.NOT_FOUND:
                return falcon.HTTPNotFound(headers=headers)
            if code == grpc.StatusCode.PERMISSION_DENIED:
                return falcon.HTTPForbidden(headers=headers)
        return falcon.HTTPInternalServerError(
            description=f"netops.automation_gateway.{rpc_name} {code}: {e.details()}",  # pylint: disable=no-member
            headers=headers,
        )


class Validate:
    """Falcon resource that asks the Automation Gateway whether an auth token is valid."""

    def __init__(
            self,
            *,
            identity: Optional[auth.Identity] = None,
            ag_auth: Optional[ag.AuthStub] = None,
    ):
        # ``identity`` is stored but not read by on_post.
        self._identity = identity
        # Stub providing ``ValidateToken``.
        self._ag_auth = ag_auth

    def on_post(self, req: falcon.Request, res: falcon.Response):
        """
        ---
        description: Validate Automation Gateway auth token
        requestBody:
            content:
                application/json:
                    schema: AuthTokenRequest
        responses:
            200:
                description: OK
                content:
                    application/json:
                        schema: ValidateAuthTokenResponse
        """
        token = req.context.json.get("token", None)
        if token:
            try:
                validation = self._ag_auth.ValidateToken(
                    ag.ValidateTokenRequest(token=token))
            except grpc.RpcError as err:
                status_code = err.code()  # pylint: disable=no-member
                # Any RPC failure becomes a 500 that carries the gRPC status header.
                raise falcon.HTTPInternalServerError(
                    description=f"netops.automation_gateway.ValidateTokenRequest {status_code}: {err.details()}",  # pylint: disable=no-member
                    headers=[header.grpc_status(status_code)],
                )
            res.media = {"valid": validation.valid}
            res.status = http_status.HTTP_200
        else:
            # A missing token is reported in the response body rather than by raising.
            res.status = http_status.HTTP_400
            res.media = {"error": "Token not provided in request body"}
