#!/var/opt/python/bin/python
import signal
import sys
import json
import os
import pexpect
import time
import threading
import re
import argparse
import logging
from logging.handlers import SysLogHandler

import discovery.device as device
import discovery.state as state
import discovery.ogcs as ogcs
import discovery.ngcs as ngcs

class DiscoveryThread(threading.Thread):
    """
    Worker thread that runs one discovery attempt against a single serial port.

    The thread shares a ``discovery_details`` dict with its creator. It reads
    'port', 'current_baud' and 'current_pinout' from it. It writes back
    'results' (the dict of discovered information) and sets 'discovered' to
    True when anything useful was found. Setting the ``shutdown`` event asks
    the thread to stop before its next state-machine step.
    """

    # Upper bound on state-machine steps, so a device that keeps bouncing
    # between states cannot trap the thread in a loop.
    MAX_STATE_TRANSITIONS = 5

    def __init__(self, discovery_details):
        threading.Thread.__init__(self)
        self.discovery_details = discovery_details
        # Thread.setName() is deprecated; assigning the 'name' property is
        # equivalent. main() logs with '%(threadName)s', so this becomes the
        # per-port log prefix.
        self.name = '[port{}]'.format(self.port)
        self.logger = logging.getLogger()
        self.child = None
        self.state_machine = None
        self.shutdown = threading.Event()
        self.start_time = time.time()

    @property
    def discovered(self):
        return self.discovery_details['discovered']

    @discovered.setter
    def discovered(self, discovered):
        self.discovery_details['discovered'] = discovered

    @property
    def port(self):
        return self.discovery_details['port']

    @property
    def baud(self):
        return self.discovery_details['current_baud']

    @property
    def pinout(self):
        return self.discovery_details['current_pinout']

    @property
    def results(self):
        return self.discovery_details['results']

    @results.setter
    def results(self, results):
        self.discovery_details['results'] = results

    def _spawn_child(self):
        """Start a pmshell session on this port and wait for it to settle."""
        self.child = pexpect.spawn('/bin/pmshell -l port{}'.format(self.port))
        self.child.timeout = 5
        # Delay (seconds) pexpect applies before each send to the device.
        self.child.delaybeforesend = 1.0
        # NOTE(review): presumably gives pmshell time to attach to the port.
        time.sleep(5)

    def cleanup(self):
        """
        Close the pmshell child, if one has been spawned.

        close() is called even if the child process has already exited, so
        the pty file descriptor is always released. pexpect permits calling
        close() more than once.
        """
        if self.child is not None:
            self.child.close()

    def run(self):
        """
        Thread body: run discovery for this port and publish the outcome.

        Any exception raised during discovery is logged through the
        configured handlers instead of only reaching the default thread
        excepthook. The pexpect child is closed on every exit path.
        """
        self.logger.info('Discovery starting')
        try:
            self._discover()
        except Exception:
            self.logger.exception('Discovery failed')
        finally:
            self.cleanup()

    def _discover(self):
        """
        Perform one discovery attempt at the current baud rate and pinout.

        The port is skipped when device.ready() reports it busy. Otherwise:
        1. spawn pmshell and exit any existing device session;
        2. drive the state machine towards the enabled state;
        3. gather device information;
        4. exit the session again.

        The results dict is stored in discovery_details['results']. If a
        shutdown is requested, returns early without storing results.
        """
        self.logger.info('Checking port readiness')
        if not device.ready(self.port):
            self.logger.info('Device is currently busy, skipping discovery')
            return

        self._spawn_child()

        results = {
            'port': self.port,
            'device_type': device.UNKNOWN,
            'label': None,
            'baud': self.baud,
            'pinout': self.pinout,
            'mac_address': None,
            'serial_no': None,
            'factory_default': None,
            'discovery_username': None,
            'discovery_password': None
        }

        # Start by attempting to exit out of any existing session
        device.attempt_exit(self.child)

        self.state_machine = state.StateMachine(results, self.child)
        if not self._advance_to_enabled_state():
            return

        # Once in the enabled state, we can fetch device information
        if str(self.state_machine.current_state) == state.ENABLED_STATE:
            self._gather_device_info(results)

        # Attempt to exit the device session
        device.attempt_exit(self.child)

        # Sleep to account for any child send delay
        time.sleep(3)

        # Weirdly, the label can also be discovered from the banner after
        # exiting from any current session, and this is the best place to do it
        if not results['label']:
            results['label'] = device.get_label_banner(self.child)

        # Publish results first, so whoever sees discovered=True also finds them.
        self.results = results

        found_anything = (
            results['device_type'] != device.UNKNOWN
            or any(results[key] is not None for key in ('label', 'mac_address', 'serial_no'))
        )
        if found_anything:
            self.discovered = True
            self.logger.info("Finished discovery in {} seconds".format(round(time.time() - self.start_time)))

    def _advance_to_enabled_state(self):
        """
        Step the state machine until it reports the enabled state.

        Stops early when the state cannot be determined. Non-device-specific
        discovery can still continue in that case. Also stops after
        MAX_STATE_TRANSITIONS steps.

        Returns False if a shutdown was requested, after exiting the device
        session. Returns True otherwise.
        """
        for _ in range(self.MAX_STATE_TRANSITIONS):
            if self.shutdown.is_set():
                self.logger.warning("Signal received, exiting")
                device.attempt_exit(self.child)
                return False

            next_state = str(self.state_machine.next_state())

            if next_state == state.UNKNOWN_STATE:
                # Note: this doesn't mean we can't do any discovery. It just means
                # we can't do device-specific discovery.
                self.logger.warning("Could not determine device state")
                break

            if next_state == state.ENABLED_STATE:
                break
        return True

    def _gather_device_info(self, results):
        """
        Fill in the label, device type and device-specific fields of *results*.

        Called only once the state machine is in the enabled state. The
        MAC address, serial number and factory-default state are queried
        only when the device type is recognised.
        """
        results['label'] = device.get_label_prompt(self.child)
        device_type = device.get_device_type(self.child)
        results['device_type'] = device_type
        if device_type == device.UNKNOWN:
            return

        self.logger.debug("Discovered device type '{}'".format(device_type))

        # Device-specific discovery preparation
        device.prepare_discovery(self.child, device_type)

        # Device-specific information gathering
        results['mac_address'] = device.get_mac_address(self.child, device_type)
        if results['mac_address']:
            self.logger.debug(
                "Discovered MAC address '{}'".format(results['mac_address']))

        results['serial_no'] = device.get_serial_number(self.child, device_type)
        if results['serial_no']:
            self.logger.debug(
                "Discovered Serial Number '{}'".format(results['serial_no']))

        results['factory_default'] = device.get_factory_default(self.child, device_type)
        if results['factory_default']:
            self.logger.debug(
                "Discovered factory default state: {}".format(results['factory_default'])
            )

def get_node():
    """
    Build the node object for the *GCS platform this script is running on.

    The first line of /etc/version decides the platform. If it contains
    "OpenGear", an ogcs.OGCS node is returned; otherwise an ngcs.NGCS node.
    """
    with open("/etc/version") as version_file:
        banner = version_file.readline()
        return ogcs.OGCS() if "OpenGear" in banner else ngcs.NGCS()

def infod_port_to_list(info_name):
    """
    Convert an infod port name such as 'port5' into a one-element port list.

    The port number is the run of digits at the end of the name. Every
    trailing digit is used, so names like 'port100' map to 100 instead of
    being truncated to their last two digits.

    Returns [] and logs an error when the name does not end in a digit.
    """
    match = re.search(r'[0-9]+$', info_name)
    if match is None:
        logging.getLogger().error("'{}' is not a valid infod portname".format(info_name))
        return []
    return [int(match.group())]

class SignalException(Exception):
    """Raised by signal_handler() to unwind the main thread when SIGTERM/SIGINT arrives."""

def signal_handler(signum, frame):
    """
    Signal callback that main() installs for SIGTERM and SIGINT.

    Logs the event, then raises SignalException so the main thread abandons
    its current work and reaches main()'s shutdown handling. The signal
    number and stack frame arguments are not used.
    """
    root_logger = logging.getLogger()
    root_logger.error("Received signal, exiting")
    raise SignalException

def _split_csv(value):
    """Split a comma-separated command-line value into stripped, non-empty items."""
    return [item.strip() for item in value.split(',') if item.strip()]

def _set_config_succeeded(outcome):
    """
    Return True if a node.set_port_config() call reported success.

    main()'s per-pass call unpacks the return value as a
    (success, config_dirty) tuple. A non-empty tuple is always truthy, so
    the success flag must be extracted before it is tested. A bare boolean
    is also accepted in case a node implementation returns one
    (NOTE(review): confirm against the ogcs/ngcs implementations).
    """
    if isinstance(outcome, tuple):
        return bool(outcome) and bool(outcome[0])
    return bool(outcome)

def main():
    """
    Discover what is attached to the serial ports and apply the findings.

    Algorithm:
    1. Back up each available port's config.
    2. Iterate over every pinout/baud combination (pinout outer, baud inner).
    3. For each combination, reconfigure the still-undiscovered ports and
       probe them in parallel with DiscoveryThread workers.
    4. Apply the discovered settings to discovered ports and restore the
       backed-up config on the rest.

    Exits with status 1 if any port configuration step fails. SIGTERM or
    SIGINT makes the running workers shut down; main() joins them and then
    returns.
    """
    import itertools

    threading.current_thread().name = '[' + os.path.basename(__file__) + ']'
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(threadName)s %(message)s')

    sh = SysLogHandler(address='/dev/log')
    sh.setFormatter(formatter)
    logger.addHandler(sh)

    parser = argparse.ArgumentParser(description='Discover port information')
    parser.add_argument('-v', action='store_true', help='Turn on verbose logging')
    parser.add_argument('-l', action='store_true', help='Log to stderr as well as syslog')
    parser.add_argument('-s', action='store_true', help='Skip ports with existing user sessions')
    parser.add_argument('-b', dest='bauds', help='Specify baud rates to discover as a comma-separated list e.g. 9600,115200')
    parser.add_argument('-o', dest='pinouts', help='Specify pinouts to discover as a comma-separated list e.g. X1,X2')
    parser.add_argument('-p', dest='port', help='Specify port to discover')

    args = parser.parse_args()
    discovery_bauds = _split_csv(args.bauds) if args.bauds else device.DEFAULT_BAUDS
    discovery_pinouts = _split_csv(args.pinouts) if args.pinouts else device.DEFAULT_PINOUTS
    if args.l:
        # Log to stderr as well as syslog
        ch = logging.StreamHandler(sys.stderr)
        ch.setFormatter(formatter)
        logger.addHandler(ch)
    if args.v:
        # Set verbose logging
        logger.setLevel(logging.DEBUG)

    # Workers of the pass currently running. The SignalException handler
    # below stops them. It is created before the signal handlers are
    # installed so that handler can always reference it.
    threads = []

    try:
        # Install the handlers inside the try, so no signal can raise
        # SignalException outside it.
        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)

        node = get_node()
        if not node:
            logger.error("Could not get node type")
            sys.exit(1)

        # Get port details
        if args.port:
            ports = infod_port_to_list(args.port)
        else:
            ports = node.get_all_ports()
        available_ports = node.get_available_ports(ports, args.s)

        if not available_ports:
            print("No ports available for discovery, exiting...")
            sys.exit()

        # One shared dict per port. Each DiscoveryThread writes its outcome
        # ('discovered', 'results') back into the dict it was given.
        port_details = [
            {
                "port": port,
                "discovered": False,
                "current_baud": None,
                "current_pinout": None,
                "results": None,
            }
            for port in available_ports
        ]

        # Before we start, back up port config to restore if detection fails (use external config in
        # case this process dies horribly and loses everything).
        # TODO: again, there is no locking around the config, so if someone changes it while discovery is
        # running, restoring config could overwrite their changes.
        for d in port_details:
            if not node.backup_port_config(d['port']):
                logger.error("Failed to backup port config")
                sys.exit(1)

        # Baud rate/pinout discovery is a kind of meta-discovery. It iterates
        # over each configuration and runs the entire discovery process across
        # all ports. Ports that failed to be discovered are switched to the
        # next baud rate/pinout and run through discovery again.
        # TODO would it be better to start discovery on the current baud/pinout for each port, instead
        # of using the same values across the board?
        for pinout, baud in itertools.product(discovery_pinouts, discovery_bauds):
            threads = []

            undiscovered_port_details = [d for d in port_details if not d['discovered']]
            if not undiscovered_port_details:
                break

            logger.info("Starting discovery with {} baud and {} pinout".format(baud, pinout))

            # First, set and apply the current baud/pinout for all ports
            configuration_dirty = False
            for d in undiscovered_port_details:
                d['current_baud'] = baud
                d['current_pinout'] = pinout

                success, port_config_dirty = node.set_port_config(d['port'], baud, pinout)
                if not success:
                    logger.error("Failed to set serial port config")
                    sys.exit(1)

                configuration_dirty = configuration_dirty or port_config_dirty

            if configuration_dirty and not node.apply_port_config():
                logger.error("Failed to apply serial config changes")
                sys.exit(1)

            for d in undiscovered_port_details:
                t = DiscoveryThread(d)
                t.start()
                threads.append(t)

            # Join with a timeout in a loop rather than a bare join(). That
            # way the main thread keeps returning to the interpreter, where
            # the signal handler can run; an untimed join() is not
            # interruptible on older Pythons.
            for t in threads:
                while t.is_alive():
                    t.join(1)

        # Apply discovered port config
        for d in port_details:
            if d['discovered']:
                found = d['results']
                outcome = node.set_port_config(d['port'],
                                               found['baud'],
                                               found['pinout'],
                                               found['label'],
                                               found['discovery_username'],
                                               found['discovery_password'])
                if not _set_config_succeeded(outcome):
                    logger.error("Failed to apply discovered port config")
                    sys.exit(1)
            elif not node.restore_port_config(d['port']):
                # Restoring the backup for a port that failed to be discovered
                logger.error("Failed to restore port config after discovery")
                sys.exit(1)

        # Apply any configuration changes made
        if not node.apply_port_config():
            logger.error("Failed to apply discovery configuration changes")
            sys.exit(1)

        discovery_results = [d['results'] for d in port_details]
        logger.info("Discovered Port Info: {}".format(json.dumps(discovery_results)))

    except SignalException:
        # NOTE(review): ports interrupted mid-discovery keep the trial
        # baud/pinout. Their backed-up config is not restored on this path.
        for t in threads:
            t.shutdown.set()

        for t in threads:
            t.join()

if __name__ == "__main__":
    # Run discovery only when executed as a script, not when imported.
    main()
