#!/usr/bin/python3
"""
This is a simple TFTP file server which extends
fbtftp (https://github.com/facebook/fbtftp).

It serves files out of /tftpboot, and hooks into Redis to run Jinja2
templating for device-specific variables on the files it serves.
"""

from fbtftp.base_handler import BaseHandler
from fbtftp.base_handler import ResponseData
from fbtftp.base_server import BaseServer

from jinja2 import Environment, FileSystemLoader
from jinja2.exceptions import TemplateNotFound
from remote_dop.provd import load_config
from remote_dop.target_device import TargetDevice, from_ip, from_mac
import remote_dop.file_serve
from remote_dop.ordered_provisioning import dependencies_satisfied
from remote_dop.util import normalise_host_addr

from io import BytesIO
import os
import redis

# Global handle for the redis connection.
# The port defaults to 6379 and can be overridden through the REDIS_PORT
# environment variable. Environment values are always strings, so the value
# is converted to int: redis_port then has the same type whichever way it
# was obtained, and a malformed override fails here, at start-up, with a
# ValueError instead of surfacing later on first use of the connection.
redis_port = int(os.environ.get('REDIS_PORT', 6379))

# Created once at import time; TemplateResponseData uses this handle for
# every request it serves.
REDIS_HANDLE = redis.StrictRedis(host="localhost", port=redis_port, db=0)

# default directory to serve files from
ROOT_DIR = "/tftpboot"

def print_session_stats(stats):
    """Write the statistics of one finished TFTP session to stdout.

    Passed to StaticServer as the per-handler stats callback (see main()).

    Args:
        stats: the session statistics object; it is printed using its own
            string representation.
    """
    print(stats)

def print_server_stats(stats):
    """Write the server-wide counters to stdout and reset them.

    Passed to StaticServer as the server stats callback (see main()).

    Args:
        stats: object exposing ``get_and_reset_all_counters()`` and an
            ``interval`` attribute (the reporting period, in seconds
            according to the printed message).
    """
    # Fetch (and thereby reset) the counters before anything is printed.
    snapshot = stats.get_and_reset_all_counters()
    header = 'Server stats - every {} seconds'.format(stats.interval)
    print(header)
    print(snapshot)

class TemplateResponseData(ResponseData):
    """Response payload for a single TFTP read request, with Jinja2 support.

    The requested name is resolved to a file under /tftpboot. Files whose
    resolved name ends in ``.j2`` are rendered as Jinja2 templates with
    device-specific variables; every other file is served as-is.

    If the request cannot be served (no matching file, unsatisfied
    provisioning dependencies, or a template that cannot be loaded) the
    instance is left without a reader: ``size()`` returns 0 and ``read()``
    raises ``AttributeError``.
    NOTE(review): presumably fbtftp reports that exception to the client as
    a transfer error - confirm against the fbtftp handler.
    """

    def __init__(self, path, peer):
        """Resolve, authorise and open (or render) the requested file.

        Args:
            path: requested path; only its basename is used, so the file is
                always looked up directly under /tftpboot.
            peer: client address tuple; ``peer[0]`` is the host address used
                to look up the target device.

        Raises:
            OSError: if a non-template file cannot be stat'ed or opened.
        """
        self._size = 0
        self._reader = None

        # Validate requested file. If we have a template for this file
        # (ending in .j2), get_filename_to_serve is expected to return that
        # name instead; None means there is nothing to serve.
        filename = os.path.basename(path)
        actual_filename = remote_dop.file_serve.get_filename_to_serve(filename)
        if actual_filename is None:
            return
        actual_filepath = '/tftpboot/{}'.format(actual_filename)

        config = load_config('/etc/provd.conf')

        # Identify the requesting device from its (normalised) address.
        normalised_address = normalise_host_addr(peer[0])
        target_device = from_ip(TargetDevice(REDIS_HANDLE), normalised_address)

        if not dependencies_satisfied(REDIS_HANDLE, target_device, actual_filename, config):
            return

        if actual_filepath.endswith(".j2"):
            # Use root as the loader path, since the template is addressed
            # by its absolute path (/tftpboot/...).
            jinja2_env = Environment(loader=FileSystemLoader('/'))

            # A file that cannot be loaded as a template (e.g. a binary
            # file, or one that has disappeared) leaves the response empty.
            try:
                template = jinja2_env.get_template(actual_filepath)
            except (UnicodeDecodeError, TemplateNotFound):
                template = None

            if template is not None:
                rendered_template = template.render(
                    nom_device_ipv4_address=target_device.ip,
                    nom_device_mac_address=target_device.mac,
                    nom_device_hostname=target_device.hostname)

                # Encode once and measure the encoded payload: the size must
                # be the number of bytes actually sent, as an int (matching
                # the static-file branch), not a character count in a str.
                payload = rendered_template.encode("utf-8")
                self._size = len(payload)

                # Create a BytesIO so it can be treated as a file
                self._reader = BytesIO(payload)
        else:
            self._size = os.stat(actual_filepath).st_size
            self._reader = open(actual_filepath, 'rb')

        print("Get file {} for device {} via TFTP".format(actual_filename, target_device.ip))

    def read(self, num_bytes):
        """Return up to ``num_bytes`` bytes from the payload.

        Raises:
            AttributeError: if the request could not be served and no reader
                was created.
        """
        return self._reader.read(num_bytes)

    def size(self):
        """Return the payload size in bytes (0 if nothing can be served)."""
        return self._size

    def close(self):
        """Close the underlying reader, if one was created."""
        if self._reader is not None:
            self._reader.close()

class StaticHandler(BaseHandler):
    """Per-request TFTP handler that answers with a TemplateResponseData."""

    def __init__(self, server_addr, peer, path, options, root, stats_callback):
        """Remember the serving root, then initialise the base handler.

        Args:
            server_addr, peer, path, options, stats_callback: forwarded
                unchanged to BaseHandler.
            root: directory prefix joined onto the requested path.
        """
        # Assigned before delegating to the base initialiser, so the
        # attribute already exists whatever the base class does there.
        self._root = root
        super().__init__(server_addr, peer, path, options, stats_callback)

    def get_response_data(self):
        """Build the response object for the path requested by the peer."""
        # self._path and self._peer are not assigned in this class;
        # presumably BaseHandler stores them - verify against fbtftp.
        requested_path = os.path.join(self._root, self._path)
        return TemplateResponseData(requested_path, self._peer)

class StaticServer(BaseServer):
    """TFTP server that creates a StaticHandler for each incoming request."""

    def __init__(self, address, port, retries, timeout, root,
                 handler_stats_callback, server_stats_callback=None):
        """Store the handler settings, then initialise the base server.

        Args:
            address, port, retries, timeout, server_stats_callback:
                forwarded unchanged to BaseServer.
            root: directory prefix handed to every handler.
            handler_stats_callback: stats callback handed to every handler.
        """
        self._root = root
        self._handler_stats_callback = handler_stats_callback
        super().__init__(address, port, retries, timeout, server_stats_callback)

    def get_handler(self, server_addr, peer, path, options):
        """Return a new StaticHandler for one request."""
        handler = StaticHandler(
            server_addr,
            peer,
            path,
            options,
            self._root,
            self._handler_stats_callback,
        )
        return handler

def main():
    """Serve the files in ROOT_DIR over TFTP until interrupted.

    Listens on 0.0.0.0, port 69, with 3 retries and a timeout of 5, and
    prints session and server statistics through the module's two print
    callbacks. A KeyboardInterrupt closes the server.
    """
    server = StaticServer(
        address='0.0.0.0',
        port=69,
        retries=3,
        timeout=5,
        root=ROOT_DIR,
        handler_stats_callback=print_session_stats,
        server_stats_callback=print_server_stats,
    )
    try:
        server.run()
    except KeyboardInterrupt:
        server.close()


# Only start the server when executed as a script, not when imported.
if __name__ == "__main__":
    main()
