#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#

import copy
from abc import abstractmethod
from collections import OrderedDict
from enum import Enum, EnumMeta


def _string_representation(x):
    """Return a human-readable label for *x* for use in error messages.

    Objects that carry a ``__name__`` (classes, functions, ...) are shown by
    that name; everything else falls back to ``str(x)``.
    """
    try:
        return x.__name__
    except AttributeError:
        return str(x)


def _check_record_or_field(x):
    """Ensure *x* can describe the item/value type of an ``Array`` or ``Map``.

    Accepted values are a ``Record`` subclass, a ``Record`` instance, or a
    ``Field`` instance.

    :raises Exception: for anything else.
    """
    # The previous check (``type(x) is type``) only rejected plain classes, so
    # ints, strings, Enum classes and other arbitrary objects were accepted
    # here and failed much later with unrelated errors.
    is_record_class = isinstance(x, type) and issubclass(x, Record)
    if not (is_record_class or isinstance(x, (Record, Field))):
        raise Exception('Argument ' + _string_representation(x) + ' is not a Record or a Field')


class RecordMeta(type):
    """Metaclass that collects the declared fields of every ``Record`` subclass.

    For each class other than the one named ``Record`` (the base class), the
    class namespace is scanned and the recognised field declarations are
    stored, in declaration order, in the ``_fields`` class attribute. The
    class attribute ``_required`` is also set to ``False``.
    """

    def __new__(metacls, name, parents, dct):
        # The base class itself declares no fields, so leave it untouched.
        if name == 'Record':
            return type.__new__(metacls, name, parents, dct)
        dct['_fields'] = RecordMeta._get_fields(dct)
        dct['_required'] = False
        return type.__new__(metacls, name, parents, dct)

    @classmethod
    def _get_fields(cls, dct):
        """Return an ``OrderedDict`` mapping attribute name to field object.

        Enum classes are wrapped in ``CustomEnum``; record classes are replaced
        by an instance created with no arguments. Any value that does not end
        up as a ``Record`` or ``Field`` instance is ignored.
        """
        collected = OrderedDict()
        for attr_name, declared in dct.items():
            if issubclass(type(declared), EnumMeta):
                candidate = CustomEnum(declared)
            elif type(declared) is RecordMeta:
                # A record class was declared; a record instance is expected.
                candidate = declared()
            else:
                candidate = declared

            if isinstance(candidate, (Record, Field)):
                collected[attr_name] = candidate
        return collected


class Record(metaclass=RecordMeta):
    """Base class for user-declared record schemas.

    Subclasses declare their fields as class attributes (``Field`` instances,
    nested ``Record`` instances or classes, or ``Enum`` classes); the
    ``RecordMeta`` metaclass collects them into the ``_fields`` class
    attribute. An instance holds one attribute per declared field and, when
    used as a nested field of another record, also carries that field's
    ``default`` / ``required`` / ``required_default`` settings.
    """

    # This field is used to set namespace for Avro Record schema.
    _avro_namespace = None

    # Generate a schema where fields are sorted alphabetically
    _sorted_fields = False

    def __init__(self, default=None, required_default=False, required=False, *args, **kwargs):
        """Create a record, initialising each declared field.

        :param default: value returned by ``default()`` when this record is
            used as a field of another record
        :param required_default: whether the generated schema includes a
            ``default`` entry for this record when used as a field
        :param required: when False, the field's schema type is a union with
            ``'null'``
        :param args: accepted and ignored
        :param kwargs: values for declared fields. A nested record may be
            given as a ``dict``, an array of records as a list of ``dict`` and
            a map of records as a ``dict`` of ``dict``. Fields not present in
            ``kwargs`` get their declared default.
        """
        self._required_default = required_default
        self._default = default
        self._required = required

        for k, value in self._fields.items():
            if k in kwargs:
                arg = kwargs[k]
                if isinstance(value, Record) and isinstance(arg, dict):
                    # Build the nested record from a dict: copy the declared
                    # sub-record and re-initialise it with the dict's items.
                    copied = copy.copy(value)
                    copied.__init__(**arg)
                    self.__setattr__(k, copied)
                elif isinstance(value, Array) and isinstance(arg, list) and len(arg) > 0 \
                        and isinstance(value.array_type, Record) and isinstance(arg[0], dict):
                    # Array of records given as a list of dicts; the first
                    # element decides whether conversion is needed.
                    arr = []
                    for item in arg:
                        copied = copy.copy(value.array_type)
                        copied.__init__(**item)
                        arr.append(copied)
                    self.__setattr__(k, arr)
                elif isinstance(value, Map) and isinstance(arg, dict) and len(arg) > 0 \
                        and isinstance(value.value_type, Record) and isinstance(list(arg.values())[0], dict):
                    # Map of records given as a dict of dicts; the first value
                    # decides whether conversion is needed.
                    dic = {}
                    for map_key, map_value in arg.items():
                        copied = copy.copy(value.value_type)
                        copied.__init__(**map_value)
                        dic[map_key] = copied
                    self.__setattr__(k, dic)
                else:
                    # Value was overridden at constructor
                    self.__setattr__(k, arg)
            elif isinstance(value, Record):
                # Value is a subrecord
                self.__setattr__(k, value)
            else:
                # Set field to default value, without revalidating the default value type
                super(Record, self).__setattr__(k, value.default())

    @classmethod
    def schema(cls):
        """Return the Avro schema for this record class."""
        return cls.schema_info(set())

    @classmethod
    def schema_info(cls, defined_names):
        """Return the Avro schema for this record class.

        :param defined_names: set of (namespace-qualified) names already
            emitted in the enclosing schema; updated in place. If this
            record's name is already in it, only the name is returned so the
            type is referenced rather than redefined.
        """
        namespace_prefix = '' if cls._avro_namespace is None else cls._avro_namespace + '.'
        namespace_name = namespace_prefix + cls.__name__

        if namespace_name in defined_names:
            return namespace_name
        defined_names.add(namespace_name)

        schema = {
            'type': 'record',
            'name': str(cls.__name__)
        }
        if cls._avro_namespace is not None:
            schema['namespace'] = cls._avro_namespace
        schema['fields'] = []

        def to_schema_default(value):
            # Enum defaults are written as the member's name.
            return value.name if isinstance(value, Enum) else value

        field_names = sorted(cls._fields.keys()) if cls._sorted_fields else cls._fields.keys()
        for name in field_names:
            field = cls._fields[name]
            field_type = field.schema_info(defined_names)
            if not field._required:
                # Optional fields are a union with null.
                field_type = ['null', field_type]
            entry = {'name': name}
            if field.required_default():
                entry['default'] = to_schema_default(field.default())
            entry['type'] = field_type
            schema['fields'].append(entry)

        return schema

    def __setattr__(self, key, value):
        """Set a field value after validating it against the declared field.

        The per-instance settings ``_default``, ``_required_default`` and
        ``_required`` are stored without validation.

        :raises AttributeError: if *key* is not a declared field.
        """
        if key in ('_default', '_required_default', '_required'):
            super(Record, self).__setattr__(key, value)
            return

        if key not in self._fields:
            raise AttributeError('Cannot set undeclared field ' + key + ' on record')

        # Check that type of value matches the field type; validation may also
        # convert the value (e.g. str -> enum member).
        value = self._fields[key].validate_type(key, value)
        super(Record, self).__setattr__(key, value)

    def __eq__(self, other):
        """Compare records field by field (using this record's fields).

        Comparing with a non-record returns ``NotImplemented`` so Python falls
        back to its default handling (e.g. ``record == None`` is False)
        instead of raising ``AttributeError``.
        """
        if not isinstance(other, Record):
            return NotImplemented
        return all(getattr(self, field) == getattr(other, field) for field in self._fields)

    def __ne__(self, other):
        result = self.__eq__(other)
        # Propagate NotImplemented rather than negating it.
        if result is NotImplemented:
            return result
        return not result

    def __str__(self):
        return str(self.__dict__)

    def type(self):
        """Return the record's class name."""
        return str(self.__class__.__name__)

    def python_type(self):
        """Return the record's class, used for ``isinstance`` validation."""
        return self.__class__

    def validate_type(self, name, val):
        """Validate *val* for a field of this record type.

        ``None`` is replaced by ``default()`` when the field is not required.

        :raises TypeError: if *val* is not an instance of this record's class.
        """
        if val is None and not self._required:
            return self.default()

        if not isinstance(val, self.__class__):
            raise TypeError("Invalid type '%s' for sub-record field '%s'. Expected: %s" % (
                type(val), name, _string_representation(self.__class__)))
        return val

    def default(self):
        """Return the default value given at construction (may be None)."""
        return self._default

    def required_default(self):
        """Return whether the schema should include a ``default`` entry."""
        return self._required_default


class Field(object):
    """Base class for a single typed field of a ``Record``.

    Subclasses implement ``type()`` (the schema type name) and
    ``python_type()`` (the Python type, or tuple of types, accepted by
    ``validate_type``).
    """

    def __init__(self, default=None, required=False, required_default=False):
        # An explicit default is validated (and possibly converted) before any
        # attribute is stored.
        self._default = None if default is None else self.validate_type('default', default)
        self._required_default = required_default
        self._required = required

    @abstractmethod
    def type(self):
        """Return the schema type name of this field."""

    @abstractmethod
    def python_type(self):
        """Return the Python type (or tuple of types) accepted for values."""

    def validate_type(self, name, val):
        """Return *val* if it matches ``python_type()``.

        ``None`` is replaced by ``default()`` when the field is not required.

        :raises TypeError: if *val* is not an instance of ``python_type()``.
        """
        if val is None:
            if not self._required:
                return self.default()

        accepted = self.python_type()
        if isinstance(val, accepted):
            return val
        raise TypeError("Invalid type '%s' for field '%s'. Expected: %s" % (type(val), name, _string_representation(accepted)))

    def schema(self):
        """Return the schema of this field; for primitives, just its type name."""
        return self.schema_info(set())

    def schema_info(self, defined_names):
        """Return the schema type name; *defined_names* is unused for primitives."""
        return self.type()

    def default(self):
        """Return the configured default value (may be None)."""
        return self._default

    def required_default(self):
        """Return whether the record schema should include a ``default`` entry."""
        return self._required_default


# All types


class Null(Field):
    """``null`` field: ``validate_type`` accepts only ``None``."""

    def type(self):
        return 'null'

    def python_type(self):
        return type(None)

    def validate_type(self, name, val):
        """Return ``None``; any other value raises ``TypeError``."""
        if val is None:
            return val
        raise TypeError('Field ' + name + ' is set to be None')


class Boolean(Field):
    """``boolean`` field; ``default()`` yields ``False`` when no default was given."""

    def type(self):
        return 'boolean'

    def python_type(self):
        return bool

    def default(self):
        return False if self._default is None else self._default


class Integer(Field):
    """``int`` field holding Python ``int`` values."""

    def type(self):
        return 'int'

    def python_type(self):
        return int

    def default(self):
        # No implicit fallback: an unset default stays None.
        return self._default


class Long(Field):
    """``long`` field holding Python ``int`` values."""

    def type(self):
        return 'long'

    def python_type(self):
        return int

    def default(self):
        # No implicit fallback: an unset default stays None.
        return self._default


class Float(Field):
    """``float`` field; accepts both ``float`` and ``int`` values."""

    def type(self):
        return 'float'

    def python_type(self):
        return (float, int)

    def default(self):
        # No implicit fallback: an unset default stays None.
        return self._default


class Double(Field):
    """``double`` field; accepts both ``float`` and ``int`` values."""

    def type(self):
        return 'double'

    def python_type(self):
        return (float, int)

    def default(self):
        # No implicit fallback: an unset default stays None.
        return self._default


class Bytes(Field):
    """``bytes`` field; ``str`` values are encoded to ``bytes``."""

    def type(self):
        return 'bytes'

    def python_type(self):
        return bytes, str

    def default(self):
        """Return the configured default value (may be None)."""
        return self._default

    def validate_type(self, name, val):
        """Encode ``str`` to ``bytes``; otherwise apply the generic field check.

        ``None`` is replaced by ``default()`` when the field is not required.

        :raises TypeError: if *val* is neither ``bytes`` nor ``str``.
        """
        if isinstance(val, str):
            return val.encode()
        # Previously every non-str value was returned unchecked (e.g. an int
        # was accepted). Delegate to Field so invalid types are rejected and
        # None is handled like every other field type.
        return super(Bytes, self).validate_type(name, val)


class String(Field):
    """``string`` field; ``bytes`` values are decoded to ``str``."""

    def type(self):
        return 'string'

    def python_type(self):
        return str, bytes

    def validate_type(self, name, val):
        """Return *val* as ``str``, decoding ``bytes``.

        ``None`` is replaced by ``default()`` when the field is not required.

        :raises TypeError: if *val* is not text or bytes.
        """
        value_type = type(val)

        if val is None and not self._required:
            return self.default()

        # The 'unicode' name check is a leftover for Python 2 unicode objects.
        is_text = isinstance(val, (str, bytes)) or value_type.__name__ == 'unicode'
        if not is_text:
            raise TypeError("Invalid type '%s' for field '%s'. Expected a string" % (value_type, name))
        return val.decode() if isinstance(val, bytes) else val

    def default(self):
        # No implicit fallback: an unset default stays None.
        return self._default

# Complex types


class CustomEnum(Field):
    """``enum`` field backed by a Python ``Enum`` class.

    Values may be given as enum members, as member names (``str``) or as
    member values (``int``); names and values are converted to members.
    """

    def __init__(self, enum_type, default=None, required=False, required_default=False):
        if not issubclass(enum_type, Enum):
            raise Exception(_string_representation(enum_type) + " is not a valid Enum type")
        self.enum_type = enum_type
        # Lookup table from member value to member, used for int inputs.
        self.values = {member.value: member for member in enum_type.__members__.values()}
        super(CustomEnum, self).__init__(default, required, required_default)

    def type(self):
        return 'enum'

    def python_type(self):
        return self.enum_type

    def validate_type(self, name, val):
        """Return the enum member designated by *val* (``None`` passes through).

        :raises TypeError: for an unknown name/value or an unsupported type.
        """
        if val is None:
            return None

        kind = type(val)
        if kind is str:
            # Given by member name.
            members = self.enum_type.__members__
            if val not in members:
                raise TypeError(
                    "Invalid enum value '%s' for field '%s'. Expected: %s" % (val, name, members.keys()))
            return members[val]

        if kind is int:
            # Given by member value.
            if val not in self.values:
                raise TypeError(
                    "Invalid enum value '%s' for field '%s'. Expected: %s" % (val, name, self.values.keys()))
            return self.values[val]

        if not isinstance(val, self.python_type()):
            raise TypeError("Invalid type '%s' for field '%s'. Expected: %s" % (type(val), name, _string_representation(self.python_type())))
        return val

    def schema(self):
        return self.schema_info(set())

    def schema_info(self, defined_names):
        """Return the enum schema, or just its name if already in *defined_names*."""
        enum_name = self.enum_type.__name__
        if enum_name in defined_names:
            return enum_name
        defined_names.add(enum_name)
        symbols = [member.name for member in self.enum_type]
        return {
            'type': self.type(),
            'name': enum_name,
            'symbols': symbols
        }

    def default(self):
        # No implicit fallback: an unset default stays None.
        return self._default


class Array(Field):
    """``array`` field whose items are described by *array_type*.

    *array_type* must pass ``_check_record_or_field`` (a ``Record`` class or
    instance, or a ``Field`` instance).
    """

    def __init__(self, array_type, default=None, required=False, required_default=False):
        _check_record_or_field(array_type)
        self.array_type = array_type
        super(Array, self).__init__(default=default, required=required, required_default=required_default)

    def type(self):
        return 'array'

    def python_type(self):
        return list

    def validate_type(self, name, val):
        """Check that *val* is a list whose items match the item type.

        ``None`` is returned unchanged (no default substitution).

        :raises TypeError: if *val* is not a list or an item has the wrong type.
        """
        if val is None:
            return None

        super(Array, self).validate_type(name, val)

        for x in val:
            if not isinstance(x, self.array_type.python_type()):
                raise TypeError('Array field ' + name + ' items should all be of type ' +
                                _string_representation(self.array_type.type()))
        return val

    def schema(self):
        return self.schema_info(set())

    def schema_info(self, defined_names):
        """Return the array schema, delegating the item schema to *array_type*."""
        # Always delegate: Field.schema_info() returns the bare type name for
        # primitives, while CustomEnum items need their full enum definition.
        # Calling .type() on an enum produced the invalid Avro type 'enum'.
        return {
            'type': self.type(),
            'items': self.array_type.schema_info(defined_names)
        }

    def default(self):
        """Return the configured default value (may be None)."""
        return self._default


class Map(Field):
    """``map`` field with string keys and values described by *value_type*.

    *value_type* must pass ``_check_record_or_field`` (a ``Record`` class or
    instance, or a ``Field`` instance).
    """

    def __init__(self, value_type, default=None, required=False, required_default=False):
        _check_record_or_field(value_type)
        self.value_type = value_type
        super(Map, self).__init__(default=default, required=required, required_default=required_default)

    def type(self):
        return 'map'

    def python_type(self):
        return dict

    def validate_type(self, name, val):
        """Check that *val* is a dict with ``str`` keys and well-typed values.

        ``None`` is returned unchanged (no default substitution).

        :raises TypeError: if *val* is not a dict, a key is not a string, or a
            value has the wrong type.
        """
        if val is None:
            return None

        super(Map, self).validate_type(name, val)

        for k, v in val.items():
            # This module is Python 3 only (it uses ``metaclass=`` syntax), so
            # isinstance covers all text keys, including str subclasses that the
            # old ``type(k) != str`` test rejected.
            if not isinstance(k, str):
                raise TypeError('Map keys for field ' + name + '  should all be strings')
            if not isinstance(v, self.value_type.python_type()):
                raise TypeError('Map values for field ' + name + ' should all be of type '
                                + _string_representation(self.value_type.python_type()))

        return val

    def schema(self):
        return self.schema_info(set())

    def schema_info(self, defined_names):
        """Return the map schema, delegating the value schema to *value_type*."""
        # Always delegate: Field.schema_info() returns the bare type name for
        # primitives, while CustomEnum values need their full enum definition.
        # Calling .type() on an enum produced the invalid Avro type 'enum'.
        return {
            'type': self.type(),
            'values': self.value_type.schema_info(defined_names)
        }

    def default(self):
        """Return the configured default value (may be None)."""
        return self._default


# Python 3 has no `unicode` type. This is a duck-typed stand-in for Python 2's
# `isinstance(x, unicode)` that can still be imported under Python 3.
def is_unicode(x):
    """Return True if *x* has an ``encode`` attribute and ``x.encode()`` returns ``str``.

    Under Python 3, ``str.encode()`` returns ``bytes``, so ordinary strings
    yield False.
    """
    if 'encode' not in dir(x):
        return False
    return type(x.encode()) == str
