#!/var/opt/python/bin/python
import signal
import sys
import json
import os
import pexpect
import time
import threading
import re
import argparse
import logging
from logging.handlers import SysLogHandler
from . import device, state, ogcs, ngcs

class DiscoveryThread(threading.Thread):
    """
    Worker thread that runs device discovery against a single serial port.

    ``discovery_details`` is a dict shared with the code that created the thread.
    The thread reads 'port_id', 'current_baud' and 'current_pinout' from it and
    writes back:
      * 'discovered' - set to True when anything about the device was found;
      * 'results'    - the dict of discovered attributes.
    'results' is left untouched when the port is busy, when discovery is
    interrupted via ``shutdown``, or when discovery raises an exception.

    Set the ``shutdown`` event to ask the thread to stop while it is stepping
    through device states.
    """

    # Upper bound on state-machine steps taken while trying to reach the enabled
    # state, so the thread can't loop forever bouncing between two states.
    MAX_STATE_TRANSITIONS = 5

    def __init__(self, discovery_details):
        super().__init__()
        self.discovery_details = discovery_details
        # The thread name appears in log lines formatted with '%(threadName)s'.
        # (Thread.setName() is deprecated; assign the attribute instead.)
        self.name = '[port{}]'.format(self.port)
        self.logger = logging.getLogger()
        self.child = None
        self.state_machine = None
        self.shutdown = threading.Event()
        self.start_time = time.time()

    @property
    def discovered(self):
        return self.discovery_details['discovered']

    @discovered.setter
    def discovered(self, discovered):
        self.discovery_details['discovered'] = discovered

    @property
    def port(self):
        return self.discovery_details['port_id']

    @property
    def baud(self):
        return self.discovery_details['current_baud']

    @property
    def pinout(self):
        return self.discovery_details['current_pinout']

    @property
    def results(self):
        return self.discovery_details['results']

    @results.setter
    def results(self, results):
        self.discovery_details['results'] = results

    def _spawn_child(self):
        """Start a pmshell session attached to this thread's port."""
        self.child = pexpect.spawn('/bin/pmshell -l port{}'.format(self.port))
        self.child.timeout = 5              # seconds to wait in expect() calls
        self.child.delaybeforesend = 1.0    # seconds to pause before each send
        # Fixed settle delay before the first interaction with the session
        time.sleep(5)

    def cleanup(self):
        """
        Close the pmshell child if one was spawned and is still open.

        A child that has already exited is closed too, so its pty is released.
        Safe to call more than once.
        """
        if self.child is not None and not self.child.closed:
            self.child.close()

    def run(self):
        """Thread entry point: run discovery and always release the pmshell session."""
        try:
            self._run_discovery()
        except Exception:
            # Send the failure to the configured log handlers rather than relying on
            # threading's default excepthook (stderr only). The port is then simply
            # left undiscovered.
            self.logger.exception('Discovery failed')
        finally:
            self.cleanup()

    def _run_discovery(self):
        """Perform one discovery attempt at the current baud/pinout."""
        self.logger.info('Discovery starting')

        self.logger.info('Checking port readiness')
        if not device.ready(self.port):
            self.logger.info('Device is currently busy, skipping discovery')
            return

        self._spawn_child()

        results = {
            'port_id': self.port,
            'device_type': device.UNKNOWN,
            'label': None,
            'baud': self.baud,
            'pinout': self.pinout,
            'mac_address': None,
            'serial_no': None,
            'factory_default': None,
            'discovery_username': None,
            'discovery_password': None
        }

        # Start by attempting to exit out of any existing session
        device.attempt_exit(self.child)

        self.state_machine = state.StateMachine(results, self.child)

        if not self._reach_enabled_state():
            self.logger.warning("Signal received, exiting")
            device.attempt_exit(self.child)
            self.cleanup()
            return

        # Once in the enabled state, we can fetch device information
        if str(self.state_machine.current_state) == state.ENABLED_STATE:
            self._gather_device_info(results)

        # Attempt to exit the device session
        device.attempt_exit(self.child)

        # Sleep to account for any child send delay
        time.sleep(3)

        # The label can also be discovered from the banner shown after exiting any
        # current session, so try that if the prompt didn't yield one.
        if not results['label']:
            results['label'] = device.get_label_banner(self.child)

        self.cleanup()

        # Mark the port discovered if we managed to learn anything about it
        found_something = (
            results['device_type'] != device.UNKNOWN
            or results['label'] is not None
            or results['mac_address'] is not None
            or results['serial_no'] is not None
        )
        if found_something:
            self.discovered = True
            self.logger.info("Finished discovery in {} seconds".format(round(time.time() - self.start_time)))

        self.results = results

    def _reach_enabled_state(self):
        """
        Step the state machine towards the enabled state.

        Stops early when the next state is unknown or enabled, or after
        MAX_STATE_TRANSITIONS steps. Returns False if ``shutdown`` was set before
        a step, True otherwise (whether or not the enabled state was reached).
        """
        for _ in range(self.MAX_STATE_TRANSITIONS):
            if self.shutdown.is_set():
                return False

            next_state = self.state_machine.next_state()

            if str(next_state) == state.UNKNOWN_STATE:
                # Note: this doesn't mean we can't do any discovery. It just means
                # we can't do device-specific discovery.
                self.logger.warning("Could not determine device state")
                break

            if str(next_state) == state.ENABLED_STATE:
                break

        return True

    def _gather_device_info(self, results):
        """Fill label, device type and device-specific details into ``results`` (called in the enabled state)."""
        results['label'] = device.get_label_prompt(self.child)
        results['device_type'] = device.get_device_type(self.child)
        device_type = results['device_type']
        if device_type == device.UNKNOWN:
            return

        self.logger.debug("Discovered device type '{}'".format(device_type))

        # Device-specific discovery preparation
        device.prepare_discovery(self.child, device_type)

        # Device-specific information gathering
        results['mac_address'] = device.get_mac_address(self.child, device_type)
        if results['mac_address']:
            self.logger.debug(
                "Discovered MAC address '{}'".format(results['mac_address']))

        results['serial_no'] = device.get_serial_number(self.child, device_type)
        if results['serial_no']:
            self.logger.debug(
                "Discovered Serial Number '{}'".format(results['serial_no']))

        results['factory_default'] = device.get_factory_default(self.child, device_type)
        if results['factory_default']:
            self.logger.debug(
                "Discovered factory default state: {}".format(results['factory_default'])
            )

def get_node():
    """
    Build the node object matching the platform this script runs on.

    The first line of /etc/version is inspected: if it contains "OpenGear" an
    ``ogcs.OGCS`` instance is returned, otherwise an ``ngcs.NGCS`` instance.
    """
    with open("/etc/version") as version_file:
        is_opengear = "OpenGear" in version_file.readline()
        return ogcs.OGCS() if is_opengear else ngcs.NGCS()

def infod_port_to_list(info_name):
    """
    Convert an infod port name such as 'port5' into a single-element port list.

    The name must end in a 1- or 2-digit port number. Returns ``[number]``, or
    an empty list (after logging an error) when the name doesn't end that way.
    """
    # The look-behind rejects longer trailing numbers (e.g. 'port123'). Without it,
    # such a name would be silently truncated to its last two digits and a
    # different port would be discovered.
    match = re.search(r'(?<![0-9])[0-9]{1,2}$', info_name)
    if match is None:
        logging.getLogger().error("'{}' is not a valid infod portname".format(info_name))
        return []
    return [int(match.group())]

class SignalException(Exception):
    """Raised by signal_handler so the main thread can unwind when SIGTERM/SIGINT arrives."""

def signal_handler(signum=None, frame=None):
    """
    SIGTERM/SIGINT handler: log the signal and raise SignalException.

    ``signal.signal`` invokes handlers as ``handler(signum, frame)``, so both
    parameters must be accepted. Otherwise every delivered signal fails with
    TypeError instead of raising SignalException. Both default to None so the
    handler can still be called directly with no arguments.

    Raises:
        SignalException: always.
    """
    logging.getLogger().error("Received signal, exiting")
    raise SignalException

def do_discovery(verbose=None, log=None, skip=None, bauds=None, pinouts=None, port=None):
    """
    Discover what is attached to the node's serial ports and apply the results.

    Each pinout/baud combination is tried in turn (pinouts in the outer loop,
    bauds in the inner loop). Each pass starts one DiscoveryThread per port that
    has not yet been discovered. Afterwards, discovered ports are configured with
    the discovered baud/pinout/label/credentials. All other ports have their
    backed-up configuration restored.

    Args:
        verbose: enable DEBUG-level logging.
        log: also log to stderr in addition to syslog.
        skip: passed through to ``node.get_available_ports`` (the CLI describes
            it as "skip ports with existing user sessions").
        bauds: comma-separated baud rates to try instead of device.DEFAULT_BAUDS.
        pinouts: comma-separated pinouts to try instead of device.DEFAULT_PINOUTS.
        port: infod port name (e.g. 'port5') to restrict discovery to one port.

    Returns:
        A copy of the results dict for the first port processed, with
        'port_id' rewritten as 'portN' (it is also printed). Returns None if
        discovery was interrupted by SIGTERM/SIGINT, or if the first port
        produced no results (e.g. its port was busy).

    Exits the process (``sys.exit``) when no ports are available or when
    backing up, setting, restoring or applying port configuration fails.
    """
    discovery_bauds = device.DEFAULT_BAUDS
    discovery_pinouts = device.DEFAULT_PINOUTS

    threading.current_thread().name = '[' + os.path.basename(__file__) + ']'
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(threadName)s %(message)s')

    # NOTE(review): handlers are added to the root logger on every call, so
    # calling do_discovery more than once in one process duplicates log output.
    sh = SysLogHandler(address='/dev/log')
    sh.setFormatter(formatter)
    logger.addHandler(sh)

    if bauds:
        discovery_bauds = bauds.split(',')
    if pinouts:
        discovery_pinouts = pinouts.split(',')
    if log:
        # Log to stderr as well as syslog
        ch = logging.StreamHandler(sys.stderr)
        ch.setFormatter(formatter)
        logger.addHandler(ch)
    if verbose:
        # Set verbose logging
        logger.setLevel(logging.DEBUG)

    # signal.signal() may only be called from the main thread (ValueError otherwise)
    signal.signal(signal.SIGTERM, signal_handler)
    signal.signal(signal.SIGINT, signal_handler)

    node = get_node()
    if not node:
        logger.error("Could not get node type")
        sys.exit(1)

    # Initialised outside the try so the signal branch below can always use it
    threads = []
    try:
        # Get port details
        if port:
            ports = infod_port_to_list(port)
        else:
            ports = node.get_all_ports()
        available_ports = node.get_available_ports(ports, skip)

        if not available_ports:
            print("No ports available for discovery, exiting...")
            sys.exit()

        # Aggregate results by giving each thread its own dict that it populates itself
        port_details = [
            {
                "port_id": port_id,
                "discovered": False,
                "current_baud": None,
                "current_pinout": None,
                "results": None
            }
            for port_id in available_ports
        ]

        # Before we start, back up port config to restore if detection fails (use external config in
        # case this process dies horribly and loses everything).
        # TODO: again, there is no locking around the config, so if someone changes it while discovery is
        # running, restoring config could overwrite their changes.
        for d in port_details:
            if not node.backup_port_config(d['port_id']):
                logger.error("Failed to backup port config")
                sys.exit(1)

        # Baud rate/pinout is sort of like meta-discovery, which is done by iterating over
        # each configuration and running the entire discovery process across all ports. Any ports
        # which failed to be discovered will be configured to the next baud rate/pinout and run through
        # the discovery process again.
        # TODO would it be better to start discovery on the current baud/pinout for each port, instead
        # of using the same values across the board?
        for pinout in discovery_pinouts:
            for baud in discovery_bauds:
                threads = []

                undiscovered_port_details = [d for d in port_details if not d['discovered']]
                if not undiscovered_port_details:
                    break

                logger.info("Starting discovery with {} baud and {} pinout".format(baud, pinout))

                # First, set and apply the current baud/pinout for all ports
                configuration_dirty = False
                for d in undiscovered_port_details:
                    d['current_baud'] = baud
                    d['current_pinout'] = pinout

                    success, port_config_dirty = node.set_port_config(d['port_id'], baud, pinout)
                    if not success:
                        logger.error("Failed to set serial port config")
                        sys.exit(1)

                    configuration_dirty = configuration_dirty or port_config_dirty

                if configuration_dirty:
                    if not node.apply_port_config():
                        logger.error("Failed to apply serial config changes")
                        sys.exit(1)

                for d in undiscovered_port_details:
                    t = DiscoveryThread(d)
                    t.start()
                    threads.append(t)

                # Join with a timeout so the main thread keeps returning to the
                # interpreter, where the signal handler can run.
                for t in threads:
                    while t.is_alive():
                        t.join(1)

        # Apply discovered port config
        for d in port_details:
            if d['discovered']:
                found = d['results']
                # set_port_config returns (success, config_dirty). A tuple is always
                # truthy, so the success flag must be unpacked before it is checked.
                success, _ = node.set_port_config(d['port_id'],
                        found['baud'],
                        found['pinout'],
                        found['label'],
                        found['discovery_username'],
                        found['discovery_password'])
                if not success:
                    logger.error("Failed to apply discovered port config")
                    sys.exit(1)
            else:
                # Restore backup config for ports that failed to be discovered
                if not node.restore_port_config(d['port_id']):
                    logger.error("Failed to restore port config after discovery")
                    sys.exit(1)

        # Apply any configuration changes made
        if not node.apply_port_config():
            logger.error("Failed to apply discovery configuration changes")
            sys.exit(1)

        discovery_results = [d['results'] for d in port_details]
        logger.info("Discovered Port Info: {}".format(json.dumps(discovery_results)))

    except SignalException:
        # Ask every in-flight worker to stop, then wait for them to finish
        for t in threads:
            t.shutdown.set()

        for t in threads:
            t.join()

        # Interrupted: there are no complete results to report. Note that ports
        # are not restored from their backups on this path.
        return None

    first_result = discovery_results[0]
    if first_result is None:
        # The first port's thread never stored results (e.g. its port was busy)
        logger.warning("No discovery results for port {}".format(port_details[0]['port_id']))
        return None

    discovered_node = dict(first_result)
    discovered_node['port_id'] = 'port' + str(first_result['port_id'])
    print(discovered_node)
    return discovered_node

if __name__ == "__main__":
    # Command-line entry point: parse the options and pass them to do_discovery().
    # NOTE: the module uses a relative import (`from . import ...`), so it has to
    # be launched as part of its package (e.g. `python -m <package>.<module>`).
    cli = argparse.ArgumentParser(description='Discover port information')
    cli.add_argument('-v', action='store_true', help='Turn on verbose logging')
    cli.add_argument('-l', action='store_true', help='Log to stderr as well as syslog')
    cli.add_argument('-s', action='store_true', help='Skip ports with existing user sessions')
    cli.add_argument('-b', dest='bauds', help='Specify baud rates to discover as a comma-separated list e.g. 9600,115200')
    cli.add_argument('-o', dest='pinouts', help='Specify pinouts to discover as a comma-separated list e.g. X1,X2')
    cli.add_argument('-p', dest='port', help='Specify port to discover')

    opts = cli.parse_args()
    do_discovery(
        verbose=opts.v,
        log=opts.l,
        skip=opts.s,
        bauds=opts.bauds,
        pinouts=opts.pinouts,
        port=opts.port,
    )