import os
import socket
import time


class CommandException(Exception):
    """Signals that a management command did not produce a success reply.

    The exception message carries the raw response text received from the
    management socket.
    """


class OpenVPNClientStatus:
    """Plain record describing one client row of the OpenVPN status output.

    Every constructor argument is optional (defaulting to ``None``) and is
    stored unchanged on the attribute of the same name; no parsing or type
    conversion happens here.
    """

    def __init__(
            self,
            common_name=None,
            real_address=None,
            virtual_ipv4_address=None,
            virtual_ipv6_address=None,
            bytes_rec=None,
            bytes_sent=None,
            connected_since=None,
            connected_since_unix=None,
            username=None,
            client_id=None,
            peer_id=None,
    ):
        # Who the client is.
        self.common_name, self.username = common_name, username
        self.client_id, self.peer_id = client_id, peer_id

        # Where the client is: its real endpoint and its tunnel addresses.
        self.real_address = real_address
        self.virtual_ipv4_address, self.virtual_ipv6_address = virtual_ipv4_address, virtual_ipv6_address

        # Traffic counters.
        self.bytes_rec, self.bytes_sent = bytes_rec, bytes_sent

        # Connection start, in the two forms the caller supplies.
        self.connected_since, self.connected_since_unix = connected_since, connected_since_unix


class OpenVPNSocket:
    """Single-command client for the OpenVPN management interface on a Unix socket.

    Usable as a context manager: entering connects and waits for the greeting
    line, leaving closes the socket. Because ``send_command`` shuts down the
    write side of the connection, each connection carries exactly one command.
    """

    def __init__(self, socket_path, timeout=10):
        """
        :param socket_path: filesystem path of the management Unix socket.
        :param timeout: seconds (a number) used both as the per-operation
            socket timeout and as the overall limit for receiving the
            greeting line in ``connect``.
        :raises ValueError: if ``socket_path`` is empty.
        """
        if not socket_path:
            raise ValueError('OpenVPNSocket: empty socket path')
        self._socket_path = socket_path
        self._timeout = timeout
        self._socket = None

    def __enter__(self):
        self.connect()
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # Always release the socket; exceptions from the body propagate.
        self.disconnect()

    def _recv(self):
        """Read up to 4096 bytes from the socket (b'' means the peer closed)."""
        return self._socket.recv(4096)

    def connect(self):
        """
        Open the connection and consume the greeting, e.g.
        ">INFO:OpenVPN Management Interface Version 1 -- type 'help' for more info\\r\\n".

        Any previously open connection is closed first. If connecting or
        reading the greeting fails, the socket is closed again before the
        exception propagates, so a failed connect never leaks a descriptor.

        :raises ValueError: if the peer closes before the greeting completes.
        :raises RuntimeError: if no complete greeting line arrives in time.
        :raises OSError: on socket-level failures (including recv timeouts).
        """
        self.disconnect()
        self._socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        try:
            self._socket.settimeout(self._timeout)
            self._socket.connect(self._socket_path)
            # A monotonic clock keeps the deadline immune to wall-clock jumps.
            deadline = time.monotonic() + self._timeout
            while time.monotonic() < deadline:
                chunk = self._recv()
                if not chunk:
                    raise ValueError('OpenVPNManager: unexpected eof')
                # The greeting is complete once a chunk ends on a newline.
                if chunk.endswith(b'\n'):
                    return
            raise RuntimeError('OpenVPNManager: info timeout')
        except BaseException:
            # __exit__ is not invoked when __enter__ raises, so clean up here.
            self.disconnect()
            raise

    def disconnect(self):
        """Close the socket if one is open; safe to call repeatedly."""
        if not self._socket:
            return
        # Clear the attribute first so the object is consistent even if close() raises.
        s = self._socket
        self._socket = None
        s.close()

    def send_command(self, cmd):
        """
        Sends a command to the OpenVPN socket, and return the response as a
        string, note that it requires connect to be called beforehand, and
        will close the write side of the socket.

        According to the OpenVPN management interface spec, this will be either:
        - "SUCCESS: [text]" or "ERROR: [text]" in the case of a basic command
        - "END" in the case of a command with multi-line response

        For a multi-line response the trailing "END" line is stripped (an
        empty listing yields ''); for a basic command the "SUCCESS: " prefix
        is stripped and the rest, including its line terminator, is returned.

        Failure cases will result in the output being raised as an exception.

        :param cmd: command text, including its terminating newline.
        :raises CommandException: if the response is not a success reply.
        """
        self._socket.sendall(cmd.encode('utf-8'))
        # Half-close so the server sees EOF and closes after replying; that
        # close is what ends the read loop below.
        self._socket.shutdown(socket.SHUT_WR)
        chunks = []
        while True:
            chunk = self._recv()
            if not chunk:
                break
            chunks.append(chunk)
        # Decode once, after all bytes arrived, so multi-byte characters split
        # across recv() boundaries are handled correctly.
        resp = b''.join(chunks).decode('utf-8')

        end_marker = '\r\nEND\r\n'
        if resp == 'END\r\n':
            # Multi-line response with no content lines at all.
            return ''
        if resp.endswith(end_marker):
            return resp[:-len(end_marker)]
        success_prefix = 'SUCCESS: '
        if resp.startswith(success_prefix):
            return resp[len(success_prefix):]
        raise CommandException(resp)


class OpenVPNManager:
    """High-level helper issuing commands to an OpenVPN management socket.

    Every command is sent over its own short-lived ``OpenVPNSocket``
    connection, opened and closed around the call.
    """

    # CLIENT_LIST column titles of "status 2" paired with the matching
    # OpenVPNClientStatus keyword, in the order used when the output carries
    # no HEADER row.
    _CLIENT_LIST_COLUMNS = (
        ('Common Name', 'common_name'),
        ('Real Address', 'real_address'),
        ('Virtual Address', 'virtual_ipv4_address'),
        ('Virtual IPv6 Address', 'virtual_ipv6_address'),
        ('Bytes Received', 'bytes_rec'),
        ('Bytes Sent', 'bytes_sent'),
        ('Connected Since', 'connected_since'),
        ('Connected Since (time_t)', 'connected_since_unix'),
        ('Username', 'username'),
        ('Client ID', 'client_id'),
        ('Peer ID', 'peer_id'),
    )

    def __init__(self, socket_path=None):
        """
        :param socket_path: path of the management Unix socket; when ``None``
            the ``MGMT_SOCKET`` environment variable is used instead.
        :raises ValueError: if no non-empty path can be determined.
        """
        if socket_path is None:
            socket_path = os.environ.get("MGMT_SOCKET")
        if not socket_path:
            raise ValueError('OpenVPNManager: empty socket path')

        def send_command(cmd):
            # One connection per command: the socket's send_command half-closes
            # the connection, so it cannot be reused.
            with OpenVPNSocket(socket_path) as sock:
                return sock.send_command(cmd)

        # Stored as an attribute so the transport can be replaced.
        self._send_command = send_command

    def send_command(self, cmd):
        """Sends a command to the OpenVPN socket, and return the response.
        """
        return self._send_command(cmd)

    def get_client_status(self):
        """Return the current OpenVPN client list

        Each ``CLIENT_LIST`` row of the ``status 2`` output becomes one
        ``OpenVPNClientStatus``. Columns are matched by title using the
        ``HEADER,CLIENT_LIST,...`` row when present, so servers that emit
        fewer or additional columns are handled; columns absent from the
        output leave the corresponding attribute at its default. Without a
        header row the fixed column order listed on the class is assumed.

        All values are returned as strings, exactly as sent by the server.
        This assumes none of the user-defined fields contain a comma
        (validated at API level).
        """
        # 2 is the new version of the status output, with comma-separated fields to make it easier to parse
        output = self.send_command('status 2\n')

        attr_by_title = dict(self._CLIENT_LIST_COLUMNS)
        titles = [title for title, _ in self._CLIENT_LIST_COLUMNS]
        client_list = []

        for raw_line in output.split('\n'):
            # Lines are CRLF-terminated; drop the '\r' so it does not end up
            # glued to the last field of the row.
            fields = raw_line.rstrip('\r').split(',')
            if fields[0] == 'HEADER' and fields[1:2] == ['CLIENT_LIST']:
                # e.g. HEADER,CLIENT_LIST,Common Name,Real Address,...
                titles = fields[2:]
            elif fields[0] == 'CLIENT_LIST':
                # zip() stops at the shorter side, so short rows cannot raise;
                # unknown column titles are simply skipped.
                kwargs = {
                    attr_by_title[title]: value
                    for title, value in zip(titles, fields[1:])
                    if title in attr_by_title
                }
                client_list.append(OpenVPNClientStatus(**kwargs))

        return client_list

    def kill_client_by_real_address(self, client_list, address):
        """Kills an OpenVPN client connection based on real address e.g. 192.168.0.1:12345

        Only issues the kill command when ``address`` matches the
        ``real_address`` of an entry in ``client_list``.

        Returns True on success, or False if the connection wasn't found, throwing an
        exception on command failure
        """
        for client in client_list:
            if client.real_address == address:
                self.send_command(f'kill {address}\n')
                return True
        return False
