#!/usr/bin/python3

import central_sdi
import grpc
import os
import sys
from central_sdi import (
    nodes,
    sessions,
    util,
    ipaccess_cfg,
    access,
    policies,
)
from netops.lighthouse import groups
from netops.ip_access import remote
from pyroute2 import IPRoute

# Location of the zone override map on the storage API (fetched via
# util.read_override_map()). STORAGE_API_ADDR must be set in the environment;
# if it is missing, importing this script raises KeyError.
OVERRIDE_MAP_URL = f"{os.environ['STORAGE_API_ADDR']}/override_map"
# Name prefix of the per-node gretap interfaces this script creates; it is also
# used to recognise (and clean up) stale interfaces left by earlier enrolments.
SDI_GRETAP_INTERFACE_PREFIX = 'ipa-gretap'

_MAX_ROUTE_METRIC = 4294967295  # max uint32


def _normalize_route_metric(value):
    """Return *value* as an int route metric, replacing out-of-range values with the uint32 maximum.

    Negative values are also mapped to the maximum (not to 0), preserving this
    script's existing behavior.
    """
    metric = int(value)
    if metric < 0 or metric > _MAX_ROUTE_METRIC:
        return _MAX_ROUTE_METRIC
    return metric


def _link_index(iproute2, ifname):
    """Return the index of the link named *ifname*, or None if no such link exists."""
    indices = iproute2.link_lookup(ifname=ifname)
    return indices[0] if indices else None


def _link_info_data(link):
    """Return a link message's IFLA_LINKINFO/IFLA_INFO_DATA attribute, or None if either is absent."""
    link_info = link.get_attr('IFLA_LINKINFO')
    if link_info is None:
        return None
    return link_info.get_attr('IFLA_INFO_DATA')


def _gretap_matches(info_data, local_ip, remote_ip):
    """Return True if *info_data* describes a keyless GRE link from *local_ip* to *remote_ip*.

    *info_data* is the IFLA_INFO_DATA of a gre/gretap link, for example:
    {'attrs': [('IFLA_GRE_LINK', 0), ('IFLA_GRE_IFLAGS', 0), ('IFLA_GRE_OFLAGS', 0),
    ('IFLA_GRE_IKEY', 0), ('IFLA_GRE_OKEY', 0), ('IFLA_GRE_LOCAL', '192.168.200.1'),
    ('IFLA_GRE_REMOTE', '192.168.200.2'), ...]}
    A link without info data (None) never matches.
    """
    if info_data is None:
        return False
    return (
        info_data.get_attr('IFLA_GRE_IKEY') == 0
        and info_data.get_attr('IFLA_GRE_OKEY') == 0
        and info_data.get_attr('IFLA_GRE_LOCAL') == local_ip
        and info_data.get_attr('IFLA_GRE_REMOTE') == remote_ip
    )


def _ensure_gretap(iproute2, conn):
    """Ensure a gretap link named conn.gretap_iface exists from conn.lighthouse_ip to conn.node_ip.

    An existing link with that name is reused only if its GRE parameters match.
    Otherwise it is deleted (this includes a same-named link with no GRE data) and recreated.
    Returns the link's index.
    """
    index = _link_index(iproute2, conn.gretap_iface)
    if index is not None:
        existing = iproute2.link('get', index=index)[0]
        if _gretap_matches(_link_info_data(existing), conn.lighthouse_ip, conn.node_ip):
            return index
        iproute2.link('del', index=index)

    # Need to check here for nodes which have previously been connected but have been un-enrolled and re-enrolled.
    # This will result in the same LH VPN IP being used but with a different node ID, so the original interface
    # would otherwise not be re-used or deleted, and the connection would fail.
    # list() finishes the netlink dump before any delete is issued on the same socket.
    for link in list(iproute2.get_links()):
        ifname = link.get_attr('IFLA_IFNAME') or ''
        if ifname.startswith(SDI_GRETAP_INTERFACE_PREFIX) and \
                _gretap_matches(_link_info_data(link), conn.lighthouse_ip, conn.node_ip):
            iproute2.link('del', index=link['index'])

    iproute2.link(
        'add',
        ifname=conn.gretap_iface,
        kind='gretap',
        gre_local=conn.lighthouse_ip,
        gre_remote=conn.node_ip,
    )
    return _link_index(iproute2, conn.gretap_iface)


def _create_session(conn, username, config, zone_overrides, nap_config, pam_user_groups):
    """Fetch the node's system config and create a client session on the node over gRPC.

    Returns a (session, node_config) tuple taken from the node's responses.
    """
    with grpc.insecure_channel(f"{conn.node_ip}:{conn.node_port}") as channel:
        node_config = remote.SystemStub(channel).Config(remote.SystemConfigRequest()).config
        access_options = []

        # prefer NAP mechanism to determine client access when enabled
        if config['enable_network_access_policies']:
            lh_groups = {group.groupname: group for group in map(groups.Group, util.get_lh_groups())}
            acl = access.calculate_acl(nap_config, [
                access.filter_groups(pam_user_groups),
                access.filter_lighthouse(conn.node_id, lh_groups.get),
            ])
            access_options.append(remote.Access(
                policy_allow_zone_names=remote.Access.AllowZoneNames(acl=list(acl)),
            ))

        # fallback access option (older behavior, and only behavior currently supported by OGCS)
        policy_wan_lan = remote.Access.WanLan(
            enable_wan=config['enable_wan'],
            zone_overrides=zone_overrides,
        )
        access_options.append(remote.Access(policy_wan_lan=policy_wan_lan))

        request = remote.SessionsCreateRequest(config=remote.Session.Config(
            lighthouse_ip=conn.lighthouse_ip,
            node_ip=conn.node_ip,
            user_id=username,
            # policy_wan_lan fields set directly in config message for versions prior to ip-access-remote-v1.3.0
            enable_wan=policy_wan_lan.enable_wan,
            zone_overrides=policy_wan_lan.zone_overrides,
            access_options=access_options,
        ))
        session = remote.SessionsStub(channel).Create(request).session
    return session, node_config


def _write_ovpn_config(path, config, session, node_config):
    """Append the client's ifconfig-push, optional comp-lzo, and route push directives to *path*.

    Routes are routed via the node's bridge address. Returns (client_address, client_netmask).
    """
    node_address = node_config.bridge_ip.split('/', 1)[0]
    client_address, client_prefix = session.state.user_ip.split('/', 1)
    client_netmask = util.prefix2netmask(int(client_prefix))
    with open(path, 'a') as ovpn_config:
        ovpn_config.write(f'ifconfig-push {client_address} {client_netmask}\n')
        if config['compress']:
            # `comp-lzo` is deprecated, but `compress` isn't documented as supported by `push`
            ovpn_config.write('comp-lzo\npush "comp-lzo"\n')
        if not config['disable_routes'] and session.state.routes:
            # TODO revisit this route metric behavior
            route_metric = _normalize_route_metric(config['route_metric'])
            for route in session.state.routes:
                route_address, route_prefix = route.to_prefix.split('/', 1)
                route_netmask = util.prefix2netmask(int(route_prefix))
                ovpn_config.write(f'push "route {route_address} {route_netmask} {node_address} {route_metric:d}"\n')
    return client_address, client_netmask


def main():
    """Provision the node link and remote session for the VPN client that just authenticated.

    Steps:
    - Read the last authentication from central_sdi.env.
    - (Re)create the per-node gretap interface, attach it to conn.bridge_iface, and bring it up.
    - Create a session on the node over gRPC and record the connection in env.connection.
    - Append client directives to the file named by sys.argv[1].

    NOTE(review): the appended ifconfig-push/push directives look like
    OpenVPN client-connect output; confirm against the server config.

    Exits with status 1 when:
    - a connection is already recorded;
    - the node cannot be fetched;
    - user group info is missing;
    - the bridge interface does not exist.
    """
    central_sdi.init()
    env = central_sdi.env
    logger = env.logger

    username, node_id, _lh_rest_api_token, pam_user_groups = env.last_auth
    if pam_user_groups is not None:
        # was deserialized as a list, we want it back as a set
        pam_user_groups = set(pam_user_groups)
    sdi_header = util.log_header(node_id, username)

    if env.connection:
        logger.warning(f'{sdi_header} VPN client connection failed - concurrent connection attempt by {username}')
        sys.exit(1)

    node = nodes.get_node_by_id(node_id)
    if not node:
        logger.warning(f'{sdi_header} VPN client connection failed - could not fetch node information')
        sys.exit(1)

    # Validate before loading config/overrides/policies so a doomed attempt doesn't pay for those fetches.
    if not pam_user_groups:
        logger.error(f'{sdi_header} VPN client connection failed - missing user group info')
        sys.exit(1)

    config = ipaccess_cfg.default(ipaccess_cfg.load())
    zone_overrides = util.read_override_map(OVERRIDE_MAP_URL)
    nap_config = policies.flatten(policies.load())

    conn = sessions.Connection(
        time_unix=env.time_unix,
        username=username,
        lighthouse_ip=util.get_lhvpn_address(),
        node_id=node_id,
        node_name=node.name,
        node_ip=node.lhvpn_address,
        node_port=env.node_port,
        bridge_iface=env.bridge_iface,
        gretap_iface=f"{SDI_GRETAP_INTERFACE_PREFIX}{node_id.split('-', 1)[1]}",
    )

    # The context manager closes the netlink socket, including on sys.exit().
    with IPRoute() as iproute2:
        gretap_index = _ensure_gretap(iproute2, conn)
        bridge_index = _link_index(iproute2, conn.bridge_iface)
        if bridge_index is None:
            logger.error(f'{sdi_header} VPN client connection failed - bridge interface {conn.bridge_iface} not found')
            sys.exit(1)
        iproute2.link('set', index=gretap_index, master=bridge_index)
        iproute2.link('set', index=gretap_index, state='up')

    session, node_config = _create_session(conn, username, config, zone_overrides, nap_config, pam_user_groups)

    conn.session_id = session.id
    conn.client_ip = session.state.user_ip
    env.connection = conn

    client_address, client_netmask = _write_ovpn_config(sys.argv[1], config, session, node_config)
    logger.info(f'{sdi_header} VPN client connected with IP {client_address} netmask {client_netmask}')


if __name__ == '__main__':
    main()
