# access: implements network access policy internals

from typing import (
    Dict,
    Set,
    Optional,
    List,
    Callable,
)
from netops.lighthouse import (
    smartgroups,
    groups,
)


def calculate_acl(policies: Dict[str, Set[str]], filters: Optional[List[Callable[[str], bool]]] = None) -> Set[str]:
    """Return the union of allowed zones across every policy group that passes all filters.

    Args:
        policies: Mapping of group name to the (already flattened) set of
            zones that group grants.
        filters: Optional predicates, each called with a group name. A group
            contributes its zones only if every predicate returns a truthy
            value. When omitted (or ``None``), every group contributes.

    Returns:
        A new set containing every zone granted by the groups that passed.
    """
    predicates = [] if filters is None else filters
    allowed = set()
    for group_name, zones in policies.items():
        # all() short-circuits, so later predicates are skipped once one fails.
        if not all(predicate(group_name) for predicate in predicates):
            continue
        allowed.update(zones)
    return allowed


def filter_groups(group_names: Set[str]) -> Callable[[str], bool]:
    """Build a calculate_acl filter that admits only groups named in ``group_names``.

    The returned predicate checks membership against the passed-in set object
    itself (not a copy), so later changes to that set are reflected.
    """

    def _is_listed(group_name):
        return group_name in group_names

    return _is_listed


def filter_lighthouse(node_id: str, group_by_name: Callable[[str], Optional[groups.Group]]) -> Callable[[str], bool]:
    """Build a calculate_acl filter driven by Lighthouse group data.

    The returned predicate looks up each group name via ``group_by_name`` and
    admits the group only when it exists and is enabled, and either:

    * its mode is ``'global'``, or
    * its mode is ``'smart_group'`` and ``smartgroups.get_nodes_in_group``
      yields a node whose ``id`` equals ``node_id``.

    Unknown groups, disabled groups, and any other mode are rejected.
    """

    def _admits(group_name):
        group = group_by_name(group_name)
        if group is None or not group.enabled:
            return False
        if group.mode == 'global':
            return True
        if group.mode != 'smart_group':
            return False
        # Smart-group membership is resolved dynamically on every call.
        members = smartgroups.get_nodes_in_group(group)
        return any(member.id == node_id for member in members)

    return _admits
