#!/usr/bin/python3
import hashlib
import json
import os
import tempfile
from enum import IntEnum, auto

import grpc
import requests

from central_sdi import util
from netops import authz

# Silence urllib3 warnings (e.g. InsecureRequestWarning): the Lighthouse API
# requests made in this module pass verify=False, i.e. TLS certificate
# verification is deliberately disabled, which would otherwise warn.
requests.urllib3.disable_warnings()


class Verb(IntEnum):
    """Operation a caller wants to perform on an entity.

    UNKNOWN is the fallback for unrecognised verb names; main_handler()
    denies it outright.
    """
    READ = 1
    WRITE = 2
    CREATE = 3
    DELETE = 4
    EXECUTE = 5
    UNKNOWN = 6

# Global rights definitions


class Right(IntEnum):
    """Global rights checked by the entity handlers before allowing a verb.

    Each member's value is used directly as an index into a per-handler
    ``[False] * len(Right)`` boolean list. The values must therefore stay
    contiguous from 0: a gap would make some value >= len(Right) and indexing
    with it would raise IndexError. A new right takes the next unused value.
    """
    RIGHT_VIEW_RIGHTS = 0
    RIGHT_VIEW_ROUTES = 1
    RIGHT_VIEW_IPACCESS_STATUS = 2
    RIGHT_VIEW_SDI_CERT = 3
    RIGHT_CREATE_SDI_CERT = 4
    RIGHT_MODIFY_SDI_CERT = 5
    RIGHT_DELETE_SDI_CERT = 6
    RIGHT_VIEW_SDI_IPACCESS_NODES = 7
    RIGHT_MODIFY_SDI_IPACCESS_NODES = 8
    RIGHT_VIEW_SDI_IPACCESS = 9
    RIGHT_MODIFY_SDI_IPACCESS = 10
    RIGHT_DELETE_SDI_IPACCESS = 11
    RIGHT_VIEW_NETWORK_ACCESS_POLICY = 12
    RIGHT_MODIFY_NETWORK_ACCESS_POLICY = 13


# Role->Rights definitions
#
# Each list names the Right members granted by one Lighthouse role.
# add_rights_for_role() uses the "*_global_rights" list when scope == "global"
# and the "*_sg_rights" (smart group) list for any other scope. All smart
# group lists are currently empty.

# LighthouseAdmin: every Right currently defined, listed explicitly.
lh_admin_global_rights = [
    Right.RIGHT_VIEW_RIGHTS,
    Right.RIGHT_VIEW_ROUTES,
    Right.RIGHT_VIEW_IPACCESS_STATUS,
    Right.RIGHT_VIEW_SDI_CERT,
    Right.RIGHT_CREATE_SDI_CERT,
    Right.RIGHT_MODIFY_SDI_CERT,
    Right.RIGHT_DELETE_SDI_CERT,
    Right.RIGHT_VIEW_SDI_IPACCESS_NODES,
    Right.RIGHT_MODIFY_SDI_IPACCESS_NODES,
    Right.RIGHT_VIEW_SDI_IPACCESS,
    Right.RIGHT_MODIFY_SDI_IPACCESS,
    Right.RIGHT_DELETE_SDI_IPACCESS,
    Right.RIGHT_VIEW_NETWORK_ACCESS_POLICY,
    Right.RIGHT_MODIFY_NETWORK_ACCESS_POLICY,
]

lh_admin_sg_rights = []

# NodeAdmin: view rights and routes, view/modify IP Access nodes, and view
# (but not modify) network access policies.
node_admin_global_rights = [
    Right.RIGHT_VIEW_RIGHTS,
    Right.RIGHT_VIEW_ROUTES,
    Right.RIGHT_VIEW_SDI_IPACCESS_NODES,
    Right.RIGHT_MODIFY_SDI_IPACCESS_NODES,
    Right.RIGHT_VIEW_NETWORK_ACCESS_POLICY,
]

node_admin_sg_rights = []
# NodeUser: view rights and routes only.
node_user_global_rights = [
    Right.RIGHT_VIEW_RIGHTS,
    Right.RIGHT_VIEW_ROUTES
]
node_user_sg_rights = []


def add_rights_for_role(role, rights, scope):
    """Mark every right held by *role* in *scope* as granted in *rights*.

    :param role: Lighthouse role name: "LighthouseAdmin", "NodeAdmin" or
        "NodeUser".
    :param rights: mutable sequence indexed by ``Right``; each granted entry
        is set to True in place.
    :param scope: "global" selects the role's global rights; any other value
        selects its smart-group rights.

    An unrecognised role grants nothing, so authorization fails closed rather
    than raising.
    """
    if scope == "global":
        rights_by_role = {
            "LighthouseAdmin": lh_admin_global_rights,
            "NodeAdmin": node_admin_global_rights,
            "NodeUser": node_user_global_rights,
        }
    else:
        rights_by_role = {
            "LighthouseAdmin": lh_admin_sg_rights,
            "NodeAdmin": node_admin_sg_rights,
            "NodeUser": node_user_sg_rights,
        }

    # An unknown role previously fell through with target_rights = None,
    # and iterating None raised TypeError. It now contributes no rights.
    for r in rights_by_role.get(role, ()):
        rights[r] = True


def add_global_rights(rights, groups):
    """Grant the global-scope rights of every enabled group in *groups*.

    *groups* is read as a mapping with a 'groups' list. Each entry is read as
    a mapping with 'enabled' and 'mode' keys. A 'global' mode entry that has
    a 'global_roles' key grants LighthouseAdmin rights. A 'smart_group' mode
    entry grants NodeAdmin or NodeUser rights when its 'smart_group_roles'
    value equals that role name. *rights* is updated in place.
    """
    # NOTE(review): only the presence of 'global_roles' is checked, not its
    # contents; confirm an empty role list is meant to grant LighthouseAdmin.
    # NOTE(review): 'smart_group_roles' is compared to a single role string;
    # confirm the API never returns a list here.
    for grp in groups['groups']:
        if not grp['enabled']:
            continue
        mode = grp['mode']
        if mode == 'global':
            if 'global_roles' in grp:
                add_rights_for_role("LighthouseAdmin", rights, "global")
        elif mode == 'smart_group':
            sg_role = grp['smart_group_roles']
            for role_name in ("NodeAdmin", "NodeUser"):
                if sg_role == role_name:
                    add_rights_for_role(role_name, rights, "global")
                    break


def string_to_verb(verb_str):
    """Map a verb name such as "READ" to the matching Verb member.

    Matching is exact and case-sensitive. Any other input, including the
    literal "UNKNOWN", yields Verb.UNKNOWN.
    """
    for candidate in (Verb.READ, Verb.WRITE, Verb.CREATE, Verb.DELETE, Verb.EXECUTE):
        if verb_str == candidate.name:
            return candidate
    return Verb.UNKNOWN

############# Caching LH API functions ##################


def _atomic_write_text(path, text):
    """Best-effort atomic write of *text* to *path*; return True on success.

    The text goes to a temporary file (created by tempfile.mkstemp) in the
    same directory. That file is then renamed over *path* with os.replace, so
    readers see either the old file or the complete new one, never a partial
    write. On any OSError the temporary file is removed and False is returned.
    """
    try:
        fd, tmp_path = tempfile.mkstemp(dir=os.path.dirname(path))
    except OSError:
        return False
    try:
        with os.fdopen(fd, "w") as out:
            out.write(text)
        os.replace(tmp_path, path)
    except OSError:
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        return False
    return True


def get_cached_groups(token, timeout=30):
    """Retrieve the list of groups for a session token.

    If a cached result exists it is returned. Otherwise the groups are fetched
    from the Lighthouse API and cached for next time.
    Note: The cache has the same lifetime as the session token, which mirrors
    the LH behavior where a user needs to log out for authz changes to apply.

    :param token: Lighthouse session token, sent as "Authorization: Token ...".
    :param timeout: seconds ``requests`` waits for the Lighthouse API before
        raising a timeout error (default 30).
    :returns: the decoded JSON groups document, or None if the API did not
        answer 200 OK.
    """
    # Name the cache directory after a digest of the token, not the token
    # itself. The token is caller-supplied, so a value such as "../x" would
    # otherwise escape the cache root and could load a planted "groups" file.
    # Raw token directory names would also expose live session tokens to
    # anyone able to list /tmp.
    digest = hashlib.sha256(str(token).encode("utf-8")).hexdigest()
    token_cache_dir = os.path.join("/tmp/.authz_cache", digest)
    # exist_ok avoids the isdir()/makedirs() race between concurrent requests.
    os.makedirs(token_cache_dir, mode=0o700, exist_ok=True)

    group_file = os.path.join(token_cache_dir, "groups")
    try:
        with open(group_file) as data:
            # Parse the whole file: the cached API body may span several lines.
            return json.load(data)
    except FileNotFoundError:
        pass  # Cache miss.
    except ValueError:
        pass  # Corrupt or undecodable cache entry: refetch and overwrite it.

    headers = {"Authorization": "Token {}".format(token)}
    resp = requests.get(f"https://{util.lh_addr()}/api/{util.get_api_version()}/groups",
                        headers=headers, verify=False, timeout=timeout)
    if resp.status_code != 200:
        # Failed to get groups, cannot authorize
        return None

    # Decode first so that a non-JSON body is never written to the cache.
    groups = resp.json()
    # Caching is only an optimisation; a failed write must not deny access.
    _atomic_write_text(group_file, resp.text)
    return groups

################ Entity Handlers ########################


class BaseEntityHandler:
    """Base class for entity handlers.

    On construction it computes the caller's global rights from the
    Lighthouse groups and stores the requested entity. Each entity type is
    handled by a subclass that overrides ``can``.
    """

    def __init__(self, entity, lh_groups):
        # One boolean per Right, indexed by the Right's integer value.
        self._rights = [False] * len(Right)
        self._entity = entity
        self._lh_groups = lh_groups
        add_global_rights(self._rights, lh_groups)

    @property
    def rights(self):
        """List of booleans indexed by ``Right``; True means granted."""
        return self._rights

    @property
    def entity(self):
        """The entity (endpoint path) this handler was created for."""
        return self._entity

    @property
    def lh_groups(self):
        """The Lighthouse groups document the rights were derived from."""
        return self._lh_groups

    def can(self, verb):
        """Return whether *verb* is allowed on this entity.

        The base implementation denies everything. It returns an explicit
        False (not an implicit None) so the result is a bool, like every
        subclass's ``can``.
        """
        return False


def new_handler_for_entity(entity, lh_groups):
    """Build the handler responsible for an SDI endpoint *entity*.

    Rules are tried in order and the first match wins. Exact rules compare
    the whole entity string; prefix rules use ``str.startswith``. Returns
    None when no rule matches, and main_handler() then denies the request.
    """
    # NOTE(review): '/nom/sdi/detailed' is gated by the network access policy
    # rights, exactly like '/nom/sdi/policies'; confirm that is intended.
    # NOTE(review): prefix rules carry no trailing '/', so e.g.
    # '/nom/sdi/certsXYZ' also matches the certs rule.
    rules = (
        ('/nom/sdi/ipaccess', False, IPAccessEntityHandler),
        ('/nom/sdi/status', False, IPAccessStatusEntityHandler),
        ('/nom/sdi/certs', True, CertsEntityHandler),
        ('/nom/sdi/nodes', True, IPAccessNodesEntityHandler),
        ('/nom/sdi/rights', False, RightsEntityHandler),
        ('/nom/sdi/routes', False, RoutesEntityHandler),
        ('/nom/sdi/policies', False, PolicyEntityHandler),
        ('/nom/sdi/detailed', False, PolicyEntityHandler),
    )
    for path, is_prefix, handler_cls in rules:
        matched = entity.startswith(path) if is_prefix else entity == path
        if matched:
            return handler_cls(entity, lh_groups)
    return None


class IPAccessEntityHandler(BaseEntityHandler):
    """Entity handler for the SDI IP Access global configuration.

    READ, WRITE and DELETE are each gated by their own right; every other
    verb is denied.
    """

    def can(self, verb):
        required = {
            Verb.READ: Right.RIGHT_VIEW_SDI_IPACCESS,
            Verb.WRITE: Right.RIGHT_MODIFY_SDI_IPACCESS,
            Verb.DELETE: Right.RIGHT_DELETE_SDI_IPACCESS,
        }.get(verb)
        # Compare against None explicitly: a Right may have the value 0.
        if required is None:
            return False
        return bool(self.rights[required])


class IPAccessStatusEntityHandler(BaseEntityHandler):
    """Entity handler for SDI IP Access status; only READ can be granted."""

    def can(self, verb):
        if verb != Verb.READ:
            return False
        return bool(self.rights[Right.RIGHT_VIEW_IPACCESS_STATUS])


class IPAccessNodesEntityHandler(BaseEntityHandler):
    """Entity handler for SDI IP Access nodes.

    READ needs RIGHT_VIEW_SDI_IPACCESS_NODES and WRITE needs
    RIGHT_MODIFY_SDI_IPACCESS_NODES; other verbs are denied.
    """

    def can(self, verb):
        grants = (
            (Verb.READ, Right.RIGHT_VIEW_SDI_IPACCESS_NODES),
            (Verb.WRITE, Right.RIGHT_MODIFY_SDI_IPACCESS_NODES),
        )
        return any(verb == wanted and self.rights[right] for wanted, right in grants)


class RightsEntityHandler(BaseEntityHandler):
    """Entity handler for the SDI rights listing; READ only."""

    def can(self, verb):
        allowed = verb == Verb.READ and self.rights[Right.RIGHT_VIEW_RIGHTS]
        return bool(allowed)


class RoutesEntityHandler(BaseEntityHandler):
    """Entity handler for SDI routes; READ only."""

    def can(self, verb):
        return bool(self.rights[Right.RIGHT_VIEW_ROUTES]) if verb == Verb.READ else False


class PolicyEntityHandler(BaseEntityHandler):
    """Entity handler for network access policies.

    READ needs RIGHT_VIEW_NETWORK_ACCESS_POLICY and WRITE needs
    RIGHT_MODIFY_NETWORK_ACCESS_POLICY; other verbs are denied.
    """

    def can(self, verb):
        required = {
            Verb.READ: Right.RIGHT_VIEW_NETWORK_ACCESS_POLICY,
            Verb.WRITE: Right.RIGHT_MODIFY_NETWORK_ACCESS_POLICY,
        }.get(verb)
        return required is not None and bool(self.rights[required])


class CertsEntityHandler(BaseEntityHandler):
    """Entity handler for SDI certificate endpoints.

    READ, CREATE, WRITE and DELETE are each gated by their own cert right;
    every other verb is denied.
    """

    def can(self, verb):
        grants = (
            (Verb.READ, Right.RIGHT_VIEW_SDI_CERT),
            (Verb.CREATE, Right.RIGHT_CREATE_SDI_CERT),
            (Verb.WRITE, Right.RIGHT_MODIFY_SDI_CERT),
            (Verb.DELETE, Right.RIGHT_DELETE_SDI_CERT),
        )
        for wanted, right in grants:
            if verb == wanted:
                return bool(self.rights[right])
        return False


def _fetch_lh_groups(token):
    """Return the Lighthouse groups for *token*.

    The authz gRPC service is asked first. If it reports the call as
    UNIMPLEMENTED, the cached Lighthouse REST lookup is used instead. Any
    other gRPC error propagates unchanged.
    """
    try:
        with authz.grpc_channel() as channel:
            return authz.dummy_groups(channel, token)
    except grpc.RpcError as err:
        # Only RpcErrors that are also grpc.Call objects expose code(). Look
        # it up defensively so a code-less error is re-raised as itself
        # instead of being masked by an AttributeError.
        code_fn = getattr(err, "code", None)
        status = code_fn() if callable(code_fn) else None
        if status != grpc.StatusCode.UNIMPLEMENTED:
            raise
    return get_cached_groups(token)


def main_handler(token, entity, verb):
    """Decide whether session *token* may perform *verb* on *entity*.

    :param token: Lighthouse session token.
    :param entity: SDI endpoint path, e.g. '/nom/sdi/certs'.
    :param verb: a Verb member, or its name as a string ("READ", ...).
    :returns: the matching entity handler's ``can`` result, or False when the
        verb is unknown, no groups could be obtained, or no handler matches
        *entity*.
    :raises grpc.RpcError: for authz service errors other than UNIMPLEMENTED.
    """
    # Compatible with either string or enum
    verb_enum = string_to_verb(verb) if isinstance(verb, str) else verb
    if verb_enum == Verb.UNKNOWN:
        return False

    lh_groups = _fetch_lh_groups(token)
    if not lh_groups:
        return False

    handler = new_handler_for_entity(entity, lh_groups)
    if handler is None:
        return False

    return handler.can(verb_enum)
