"""
schema provides tools to work with protobuf messages, including the ability to import them
as marshmallow schemas

references:
https://marshmallow.readthedocs.io/en/stable/quickstart.html#creating-schemas-from-dictionaries
https://gist.github.com/trianta2/fd04bdbfc9bdef5631c0d76582a04aca
"""

import marshmallow
from typing import (
    Dict,
)
from collections import (
    namedtuple,
)
from google.protobuf import (
    descriptor,
)

# Reverse lookup: value of a FieldDescriptor.TYPE_* constant -> the constant's
# attribute name (e.g. the value of TYPE_INT32 maps to the string 'TYPE_INT32').
_type_dict = {
    constant: attr
    for attr, constant in vars(descriptor.FieldDescriptor).items()
    if attr.startswith('TYPE_')
}

# Lightweight descriptions of a field's type, as produced by field_type():
#   Repeated(value)       - repeated field whose elements have type `value`
#   Map(key, value)       - map field with the given key and value types
#   Message(name, schema) - message type: its full name plus a namedtuple
#                           holding the type of each of its fields
Repeated = namedtuple('Repeated', 'value')
Map = namedtuple('Map', 'key value')
Message = namedtuple('Message', 'name schema')


def _field_type(field, context):
    '''Helper that returns the type of a single field, ignoring its label

    Fields that reference a message type yield the Message namedtuple built by
    message_as_namedtuple (sharing ``context`` as its cache); every other field
    yields the 'TYPE_*' name looked up from ``field.type``.
    '''
    nested = field.message_type
    if nested is None:
        return _type_dict[field.type]
    return message_as_namedtuple(nested, context)


def field_type(field, context):
    '''Returns the protobuf type for a given field descriptor

    The result is one of:
      * Repeated(value) for a repeated field that is not a map,
      * Map(key, value) for a repeated field whose message type carries the
        ``map_entry`` option (key and value come from its first two fields),
      * a Message namedtuple for a singular message field,
      * a 'TYPE_*' string for a singular field of any other type.

    ``context`` is the message cache handed through to message_as_namedtuple.
    '''
    if field.label != descriptor.FieldDescriptor.LABEL_REPEATED:
        return _field_type(field, context)

    entry = field.message_type
    if entry is not None and entry.GetOptions().map_entry:
        # Map fields are described by an entry message: fields[0] is the key,
        # fields[1] the value.
        key_type = _field_type(entry.fields[0], context)
        value_type = _field_type(entry.fields[1], context)
        return Map(key_type, value_type)

    return Repeated(_field_type(field, context))


def message_as_namedtuple(descr, context):
    '''Returns the Message namedtuple corresponding to the given message descriptor

    The result is ``Message(full_name, schema)`` where ``schema`` is a
    namedtuple with one attribute per protobuf field, each holding the value
    returned by field_type for that field.

    Args:
        descr: message descriptor; its ``full_name`` and ``fields`` are read.
        context: dict used as a cache keyed by full message name. It is mutated
            in place so that each message type is converted only once.

    Returns:
        The Message stored in ``context`` for ``descr``.

    Raises:
        ValueError: if the message type refers to itself, directly or through
            other message types. Such a type cannot be expanded into a finite
            namedtuple tree (without this check the expansion would recurse
            until RecursionError). ``namedtuple`` itself also raises ValueError
            for field names that are Python keywords or start with an
            underscore.
    '''
    name = descr.full_name
    if name in context:
        cached = context[name]
        if cached is None:
            # None marks a message that is still being built further up the
            # call stack, i.e. a cycle led back to it.
            raise ValueError(f"recursive message type {name} is not supported")
        return cached

    context[name] = None  # in-progress marker, replaced on success
    try:
        fields = descr.fields
        schema_cls = namedtuple(name.replace('.', '_'), [f.name for f in fields])
        result = Message(name, schema_cls(*(field_type(f, context) for f in fields)))
    except BaseException:
        # Never leave the marker behind: a later call using the same context
        # must not mistake a failed build for a cycle.
        context.pop(name, None)
        raise
    context[name] = result
    return result


class FieldInt64(marshmallow.fields.Integer):
    '''Integer field used for the 64-bit protobuf integer types

    Identical to ``marshmallow.fields.Integer`` except that ``as_string``
    defaults to True. NOTE(review): presumably this mirrors protobuf's JSON
    mapping, which writes 64-bit integers as strings - confirm.
    '''

    def __init__(self, *, as_string: bool = True, **kwargs):
        # as_string is keyword-only here, so it can never also be in kwargs.
        super().__init__(as_string=as_string, **kwargs)


class FieldBytes(marshmallow.fields.String):
    '''String field used for protobuf ``bytes`` values

    Adds no behaviour of its own; the distinct class lets bytes fields be told
    apart from ordinary strings (see ``field_mapping``).
    '''


# Associates each custom field class defined above with a pair of strings.
# NOTE(review): presumably (OpenAPI type, OpenAPI format), meant to be handed
# to an API-spec generator such as apispec's marshmallow plugin - nothing in
# this module consumes it, so confirm against the caller.
field_mapping = {
    FieldInt64: ('string', 'int64'),
    FieldBytes: ('string', 'byte'),
}


def register_message(mapping: Dict[str, type], descr):
    '''Builds (and caches) the marshmallow schema class for a message descriptor

    The descriptor is first converted with message_as_namedtuple, then every
    field is translated into a marshmallow field:

      * 'TYPE_*' scalar names map to Float / Integer / FieldInt64 / Boolean /
        String / FieldBytes (enums are plain strings),
      * message fields become Nested schemas, registered recursively, except
        for the well-known types listed below which become String,
      * Repeated becomes List and Map becomes Dict.

    Args:
        mapping: cache of generated schema classes keyed by full message name.
            It is mutated in place: each message turned into a schema gets an
            entry, and entries that already exist are reused untouched.
        descr: protobuf message descriptor to convert.

    Returns:
        The schema class stored in ``mapping`` for ``descr``.

    Raises:
        ValueError: if a field type is not a str, Message, Repeated or Map.
    '''
    # Constructors for the scalar protobuf types. They are called once per
    # field, so no field instance is ever shared between schemas.
    scalar_field = {
        'TYPE_DOUBLE': marshmallow.fields.Float,
        'TYPE_FLOAT': marshmallow.fields.Float,
        'TYPE_INT32': marshmallow.fields.Integer,
        'TYPE_UINT32': marshmallow.fields.Integer,
        'TYPE_FIXED32': marshmallow.fields.Integer,
        'TYPE_SFIXED32': marshmallow.fields.Integer,
        'TYPE_SINT32': marshmallow.fields.Integer,
        'TYPE_INT64': FieldInt64,
        'TYPE_UINT64': FieldInt64,
        'TYPE_FIXED64': FieldInt64,
        'TYPE_SFIXED64': FieldInt64,
        'TYPE_SINT64': FieldInt64,
        'TYPE_BOOL': marshmallow.fields.Boolean,
        'TYPE_STRING': marshmallow.fields.String,
        'TYPE_BYTES': FieldBytes,
        'TYPE_ENUM': marshmallow.fields.String,  # TODO validation support?
    }

    # special cases for google's well known types (that have custom json encoding)
    # NOTE(review): other well-known types (e.g. google.protobuf.Duration) also
    # have a custom JSON form and are presumably candidates for this list.
    string_encoded = ('google.protobuf.Timestamp', 'google.protobuf.FieldMask')

    def element_field(proto):
        # Field for a single (non-container) type: a scalar name or a Message.
        if isinstance(proto, str):
            return scalar_field[proto]()
        if isinstance(proto, Message):
            if proto.name in string_encoded:
                return marshmallow.fields.String()
            return marshmallow.fields.Nested(register(proto))
        raise ValueError(f"unexpected type {proto}")

    def build_field(proto):
        # Field for any type description, wrapping containers around elements.
        if isinstance(proto, Repeated):
            return marshmallow.fields.List(element_field(proto.value))
        if isinstance(proto, Map):
            keys = element_field(proto.key)
            values = element_field(proto.value)
            return marshmallow.fields.Dict(keys=keys, values=values)
        return element_field(proto)

    def register(msg: Message):
        # The entry is stored only after all fields were built, so a failure
        # part-way through leaves no half-built schema for msg in the cache.
        if msg.name not in mapping:
            mapping[msg.name] = marshmallow.Schema.from_dict({
                field_name: build_field(proto)
                for field_name, proto in msg.schema._asdict().items()
            })
        return mapping[msg.name]

    return register(message_as_namedtuple(descr, {}))
