"""
devices implements the user-facing (frontend, etc.) devices API.
"""
import re
import falcon
import grpc
from typing import (
    Any,
    Callable,
    Dict,
    Optional,
    Tuple,
    List,
)
from netops.lighthouse import (
    nodes as lh_nodes,
)
from google.protobuf import (
    json_format as jsonpb,
)
from netops import (
    automation_gateway as ag,
)
from netops.grpcutil import (
    header,
)
from . import (
    auth,
    util,
)


class Devices:
    """
    Devices is the falcon resource serving the paginated device listing.

    Devices are streamed from the automation gateway (``ag_devices.Range``),
    filtered by the query parameters, kept ordered by (updated.first, id) and
    finally reduced to what the caller's token can access, using the nodes
    returned by ``get_nodes(token, node_ids)``.
    """

    def __init__(
            self,
            *,
            identity: Optional[auth.Identity] = None,
            ag_devices: Optional[ag.DevicesStub] = None,
            get_nodes: Optional[Callable[[str, List[str]], Dict[str, lh_nodes.Node]]] = None,  # (token, node_ids)
    ):
        self._identity = identity
        self._ag_devices = ag_devices
        # get_nodes(token, node_ids): the returned mapping is treated as the nodes the token provides access to
        self._get_nodes = get_nodes

    def on_get(self, req: falcon.Request, res: falcon.Response):
        """
        ---
        description: GET multiple devices, oldest first.
        parameters:
            - name: limit
              description: Limits the number of devices.
              in: query
              required: false
              schema:
                type: integer
                minimum: 1
                maximum: 300
                default: 100
            - name: device
              description: Name of device to filter on.
              in: query
              required: false
              schema:
                type: string
            - name: ip
              description: IP address of device to filter on.
              in: query
              required: false
              schema:
                type: string
            - name: service
              description: Name of services to filter on.
              in: query
              required: false
              schema:
                type: string
            - name: node
              description: Name or id of node to show devices from.
              in: query
              required: false
              schema:
                type: string
            - name: mac
              description: MAC address fragment to filter devices on.
              in: query
              required: false
              schema:
                type: string
            - name: offset
              description: Filter devices including only those following the previous result (that provided the offset, see DevicesResponse.offset).
              in: query
              required: false
              schema:
                type: string
        responses:
            200:
                description: OK
                content:
                    application/json:
                        schema: DevicesResponse
            400:
                description: Bad Request
        """
        # WARNING change the doc if you change the params
        return self._write(res, *self.search(
            token=req.context['token'],
            limit=req.get_param_as_int('limit', min_value=1, max_value=300, default=100),
            offset=req.get_param('offset'),
            device_name=req.get_param('device'),
            ip_address=req.get_param('ip'),
            service_name=req.get_param('service'),
            node_id_or_name=req.get_param('node'),
            mac_addr=req.get_param('mac'),
        ))

    def search(
            self,
            token: str,
            limit: int,
            offset: Optional[str] = None,
            device_name: Optional[str] = None,
            ip_address: Optional[str] = None,
            service_name: Optional[str] = None,
            node_id_or_name: Optional[str] = None,
            mac_addr: Optional[str] = None,
    ) -> Tuple[List[ag.Device], Optional[str]]:
        """
        search is paginated with consistent sorting and filtering
        it returns a devices list and the next offset (None if none left)

        Devices are ordered by (updated.first, id). Filters left as None are
        ignored. device_name, ip_address and mac_addr are substring matches.
        service_name must equal a host service's nmap name.

        Authorization filtering happens after a page has been cut, so a page
        can come back short. When that happens, the gateway is scanned again
        from the returned offset until ``limit`` devices are collected or the
        data runs out.

        Raises falcon.HTTPBadRequest for a malformed offset and
        falcon.HTTPInternalServerError if the gateway Range call fails.
        """

        def check_node_id_or_name(nodes_with_rights: dict, checked_nodes: set, node_id_or_name: str, device: ag.models_pb2.Device) -> bool:
            """
            Returns True if one of the device's host sources matches
            node_id_or_name. A match is a substring of the Lighthouse node
            name or a substring of the node id.

            A source on a node that has not been auth-checked yet is
            optimistically accepted. The final pass re-evaluates it once the
            node has been checked.
            """
            for host in device.hosts:
                for source in host.sources:
                    if source.node.id not in checked_nodes:
                        # Haven't yet checked the API's to see if the token grants us access
                        return True
                    if is_node_name_or_id_an_id(node_id_or_name) and node_id_or_name in checked_nodes and nodes_with_rights.get(node_id_or_name, None) is None:
                        # User is searching by node-id for devices discovered by a node they're not authorized to access - omit this device from result
                        return False
                    lh_node = util.lookup_node(nodes_with_rights, source.node)
                    if lh_node is not None and node_id_or_name in lh_node.name:
                        return True
                    if node_id_or_name in source.node.id:
                        return True
            return False

        def check_service_name(service_name, device):
            """Returns True if any host exposes a service whose nmap name equals service_name."""
            for host in device.hosts:
                for service in host.services:
                    if service.nmap.name == service_name:
                        return True
            return False

        def check_ip_addr(ip_addr, device):
            """
            Returns True if ip_addr is a substring of a host's IP address.

            Hosts with availability records are searched first. All hosts are
            only considered when no host has any availability record.
            """
            available = False
            # Check available hosts first
            for host in device.hosts:
                if len(host.availability) > 0:
                    available = True
                    if ip_addr in host.ip_address:
                        return True
            # Check unavailable hosts if no availability was detected
            if not available:
                for host in device.hosts:
                    if ip_addr in host.ip_address:
                        return True
            return False

        def check_mac_addr(mac_addr, device):
            """Returns True if mac_addr is a substring of any host's MAC address."""
            for host in device.hosts:
                if mac_addr in host.mac_address:
                    return True
            return False

        def is_forbidden(nodes_with_rights, checked_nodes, device):
            """
            Returns false if the device has at least one host
            with a source node that hasn't been checked or has been
            checked and has sufficient rights for AG.
            """
            for host in device.hosts:
                for source in host.sources:
                    if source.node.id not in checked_nodes or util.node_sufficient_rights(nodes_with_rights, source.node):
                        return False
            return True

        devices = []  # List of all AG discovered devices that gets filtered through the following steps before being returned in the response
        nodes_with_rights = {}  # Fetched nodes that the auth token provides access to
        checked_nodes = set()  # Nodes that have been checked for authorisation via the token

        def scan_page():
            """
            Runs one pass over the gateway's devices that follow ``offset``.

            Candidates are added to ``devices``, which is capped at ``limit``
            entries in total. ``offset`` is then set to the last kept device,
            or to None if nothing had to be cut off. Finally, the devices
            added by this pass are filtered by authorization.
            """
            # offset is read below and then advanced for the next pass
            nonlocal offset

            # devices before this index survived a previous pass; only the newly added tail needs auth filtering
            index = len(devices)
            limited = False
            offset_ts, offset_id = None, None
            if offset is not None:
                offset_ts, offset_id = self._decode_offset(offset)

            def apply_limit():
                nonlocal devices, limited
                if len(devices) <= limit:
                    return
                devices = devices[:limit]
                limited = True

            try:
                for msg in self._ag_devices.Range(ag.DevicesRange.Request()):
                    # skip everything at or before the offset, using the same (updated.first, id) ordering as devices_insert_sort
                    if offset is not None and (offset_ts > msg.device.updated.first or (offset_ts == msg.device.updated.first and offset_id >= msg.device.id)):
                        continue

                    if device_name is not None and (device_name not in msg.device.name):
                        continue

                    if ip_address is not None and not check_ip_addr(ip_address, msg.device):
                        continue

                    if mac_addr is not None and not check_mac_addr(mac_addr, msg.device):
                        continue

                    if is_forbidden(nodes_with_rights, checked_nodes, msg.device):
                        continue

                    if node_id_or_name is not None and not check_node_id_or_name(nodes_with_rights, checked_nodes, node_id_or_name, msg.device):
                        continue

                    if service_name is not None and not check_service_name(service_name, msg.device):
                        continue

                    # note that the range order is not stable (golang's sync.Map)
                    # so we insert sort it into the slice based on updated.first then id
                    devices_insert_sort(devices, msg.device)

                    # drop devices that are past the limit in an attempt to save memory
                    # the +50 is to avoid copying the full slice every single time
                    if len(devices) > limit + 50:
                        apply_limit()
            except grpc.RpcError as e:
                code = e.code()  # pylint: disable=no-member
                headers = [
                    header.grpc_status(code),
                ]
                raise falcon.HTTPInternalServerError(
                    description=f"netops.automation_gateway.Devices.Range {code}: {e.details()}",  # pylint: disable=no-member
                    headers=headers,
                )

            # final limit check
            apply_limit()

            # determine the next offset (if there is any more data available)
            # WARNING needs to use the last item PRIOR to auth filtering
            offset = None
            if limited and devices:
                last = devices[-1]
                offset = self._encode_offset(last.updated.first, last.id)

            # Filter out devices we don't have rights over.
            # Do this at the end to minimise the list of nodes we need to fetch
            unchecked_nodes = util.devices_node_ids(devices).difference(checked_nodes)
            checked_nodes.update(unchecked_nodes)
            unchecked_nodes = list(unchecked_nodes)
            if unchecked_nodes:
                for k, v in self._get_nodes(token, unchecked_nodes).items():
                    nodes_with_rights[k] = v
            while index < len(devices):
                if node_id_or_name is None or check_node_id_or_name(nodes_with_rights, checked_nodes, node_id_or_name, devices[index]):
                    devices[index] = util.filter_device(devices[index], nodes_with_rights)
                    if devices[index].hosts:
                        index += 1
                        continue
                del devices[index]

        # Auth filtering can shrink a page below limit, so keep scanning from the
        # advanced offset until the page is full or the gateway has nothing further.
        while len(devices) < limit:
            scan_page()
            if offset is None:
                break

        filter_out_unauthorized_content(devices, nodes_with_rights.keys())

        return devices, offset

    @staticmethod
    def _encode_offset(offset_ts: int, offset_id: str) -> str:
        """Encodes a pagination offset as "<updated.first>,<id>"."""
        return f"{offset_ts},{offset_id}"

    @staticmethod
    def _decode_offset(offset: str) -> Tuple[int, str]:
        """
        Parses an offset produced by _encode_offset back into (updated.first, id).

        Only the first comma separates the two parts. The timestamp is an
        integer, so device ids that contain commas still round-trip.

        Raises falcon.HTTPBadRequest if the offset is malformed.
        """
        try:
            ts_part, separator, id_part = offset.partition(',')
            if not separator:
                raise ValueError
            return int(ts_part), id_part
        except ValueError:
            raise falcon.HTTPBadRequest(description=f"invalid param offset: {offset}") from None

    @staticmethod
    def _write(res: falcon.Response, devices: List[ag.Device], offset: Optional[str]):
        """Writes a 200 response whose JSON body is built by Devices.media."""
        res.set_header(*header.grpc_status(grpc.StatusCode.OK))
        res.status = falcon.HTTP_OK
        res.media = Devices.media(devices, offset)

    @staticmethod
    def media(devices: List[ag.Device], offset: Optional[str]) -> Dict[str, Any]:
        """
        Builds the DevicesResponse body: {'devices': [...], 'offset': ...}.

        'offset' is only present when there is a next page.
        """
        media = {'devices': []}
        for device in devices:
            media['devices'].append(jsonpb.MessageToDict(
                device,
                including_default_value_fields=True,
                preserving_proto_field_name=True,
            ))
        if offset is not None:
            media['offset'] = offset
        return media


class Device:
    """
    Device is the falcon resource for a single device addressed by id.

    Every handler first loads the device and filters it against the nodes the
    caller's token can access (``get_nodes(token, node_ids)``). The request
    is rejected if nothing remains visible.
    """

    def __init__(
            self,
            *,
            identity: Optional[auth.Identity] = None,
            ag_devices: Optional[ag.DevicesStub] = None,
            get_nodes: Optional[Callable[[str, List[str]], Dict[str, lh_nodes.Node]]] = None,  # (token, node_ids)
    ):
        self._identity = identity
        self._ag_devices = ag_devices
        self._get_nodes = get_nodes

    def on_head(self, req: falcon.Request, res: falcon.Response, device_id: str):  # pylint: disable=unused-argument
        """
        ---
        description: HEAD device by id.
        parameters:
            - name: device_id
              in: path
              required: true
              schema:
                type: string
        responses:
            200:
                description: OK
            403:
                description: Forbidden
            404:
                description: Not Found
        """
        self._write(res, self._load_filtered(req.context['token'], device_id), body=False)

    def on_get(self, req: falcon.Request, res: falcon.Response, device_id: str):  # pylint: disable=unused-argument
        """
        ---
        description: GET device by id.
        parameters:
            - name: device_id
              in: path
              required: true
              schema:
                type: string
        responses:
            200:
                description: OK
                content:
                    application/json:
                        schema: DeviceResponse
            403:
                description: Forbidden
            404:
                description: Not Found
        """
        self._write(res, self._load_filtered(req.context['token'], device_id))

    def on_delete(self, req: falcon.Request, res: falcon.Response, device_id: str):  # pylint: disable=unused-argument
        """
        ---
        description: DELETE a device by id.
        parameters:
            - name: device_id
              in: path
              required: true
              schema:
                type: string
        responses:
            200:
                description: OK
            403:
                description: Forbidden
            404:
                description: Not Found
            503:
                description: Service Unavailable
        """
        # Unfortunately, have to load the device first to check whether we're allowed to delete it
        # TODO maybe we can just limit deleting to LH Admins or something
        self._load_filtered(req.context['token'], device_id)
        try:
            self._ag_devices.Delete(ag.DevicesDelete.Request(id=device_id))
        except grpc.RpcError as e:
            code = e.code()  # pylint: disable=no-member
            headers = [
                header.grpc_status(code),
            ]
            if code == grpc.StatusCode.UNAVAILABLE:
                raise falcon.HTTPServiceUnavailable(headers=headers)
            raise falcon.HTTPInternalServerError(
                description=f"netops.automation_gateway.Devices.Delete {code}: {e.details()}",  # pylint: disable=no-member
                headers=headers,
            )

        res.set_header(*header.grpc_status(grpc.StatusCode.OK))
        res.status = falcon.HTTP_OK

    def _load_filtered(self, token: str, device_id: str) -> ag.Device:
        """
        Loads a device and reduces it to what ``token`` is allowed to see.

        The device is passed through util.filter_device with the nodes
        returned by get_nodes. Host and service sources for any other node
        are then removed, matching what the listing endpoint (Devices.search)
        returns.

        Raises falcon.HTTPNotFound / falcon.HTTPInternalServerError (via _load)
        and falcon.HTTPForbidden if no host remains after filtering.
        """
        device = self._load(ag.DevicesLoad.Request(id=device_id)).device
        nodes_with_rights = self._get_nodes(token, list(util.devices_node_ids([device])))
        device = util.filter_device(device, nodes_with_rights)
        if not device.hosts:
            raise falcon.HTTPForbidden()
        # Apply the same source-level scrubbing as the listing endpoint so a
        # single-device lookup never exposes more node metadata than the list does.
        filter_out_unauthorized_content([device], nodes_with_rights.keys())
        return device

    def _load(self, req: ag.DevicesLoad.Request) -> ag.DevicesLoad.Response:
        """
        Calls the gateway's Devices.Load.

        gRPC NOT_FOUND is mapped to falcon.HTTPNotFound; any other gRPC error
        is mapped to falcon.HTTPInternalServerError.
        """
        try:
            return self._ag_devices.Load(req)
        except grpc.RpcError as e:
            code = e.code()  # pylint: disable=no-member
            headers = [
                header.grpc_status(code),
            ]
            if code == grpc.StatusCode.NOT_FOUND:
                raise falcon.HTTPNotFound(headers=headers)
            raise falcon.HTTPInternalServerError(
                description=f"netops.automation_gateway.Devices.Load {code}: {e.details()}",  # pylint: disable=no-member
                headers=headers,
            )

    @staticmethod
    def _write(res: falcon.Response, device: ag.Device, body: bool = True):
        """
        Writes a 200 response with Last-Modified taken from device.updated.last.

        The JSON body {'device': ...} is only set when ``body`` is True
        (False for HEAD).
        """
        res.set_header(*header.last_modified(device.updated.last))
        res.set_header(*header.grpc_status(grpc.StatusCode.OK))
        res.status = falcon.HTTP_OK
        if body:
            res.media = {'device': jsonpb.MessageToDict(
                device,
                including_default_value_fields=True,
                preserving_proto_field_name=True,
            )}


def devices_insert_sort(devices: List[ag.Device], device: ag.Device):
    """
    Inserts ``device`` into the already-sorted list ``devices``, in place.

    The ordering key is ``updated.first``, then ``id`` (plain string
    comparison). The position is located by binary search. If an entry with
    exactly the same key is hit during the search, the new device is
    inserted immediately before that entry.
    """
    wanted = (device.updated.first, device.id)
    low, high = 0, len(devices) - 1
    while low <= high:
        mid = (low + high) // 2
        probe = devices[mid]
        current = (probe.updated.first, probe.id)
        if current < wanted:
            low = mid + 1
        elif current > wanted:
            high = mid - 1
        else:
            # identical key: insert right before the entry we landed on
            low = mid
            break
    devices.insert(low, device)


# A node id is the literal prefix "nodes-" followed by one or more ASCII digits.
_NODE_ID_PATTERN = re.compile(r"nodes-[0-9]+")


def is_node_name_or_id_an_id(node_name_or_id: str) -> bool:
    """ Determines if `node_name_or_id` represents a node-id and
        if so returns True, else False and should therefore be treated
        as a name.

        The whole string must match: `$` in the previous pattern also matched
        before a trailing newline ("nodes-1\\n"), and `\\d` accepted non-ASCII
        digits, so fullmatch with an explicit ASCII digit class is used.
    """
    return _NODE_ID_PATTERN.fullmatch(node_name_or_id) is not None


def _drop_unauthorized_sources(sources, nodes_with_rights) -> None:
    """Deletes, in place, every entry of `sources` whose `node.id` is not in `nodes_with_rights`."""
    # Walk backwards so a deletion never shifts an entry that is still to be inspected.
    for position in reversed(range(len(sources))):
        if sources[position].node.id not in nodes_with_rights:
            del sources[position]


def filter_out_unauthorized_content(devices: List[ag.models_pb2.Device], nodes_with_rights: List[str]):
    """ Mutates `devices` removing all metadata for nodes the user is
        unauthorized to access.

        For every host, both `host.sources` and each `service.sources` lose
        the entries whose node id is not contained in `nodes_with_rights`.
        Any container supporting `in` works (callers pass dict keys).

        NOTE: This does not change the length of `devices`
    """
    for device in devices:
        for host in device.hosts:
            _drop_unauthorized_sources(host.sources, nodes_with_rights)
            for service in host.services:
                _drop_unauthorized_sources(service.sources, nodes_with_rights)
