#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2022 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import

import itertools
import weakref
from collections import deque, defaultdict, OrderedDict

from ...compat import six, Iterable
from ...dag import DAG, Queue
from . import utils


class NodeMetaclass(type):
    """
    Metaclass that assembles ``__slots__`` for node classes and makes
    ``isinstance`` look through collection proxies.
    """

    def __new__(mcs, name, bases, kv):
        """
        Create the class, building ``__slots__`` from several declarations.

        Unless the class body sets ``_add_args_slots`` to a false value, the
        slot names are collected, in this order, from the ``_mixin_slots`` of
        the bases and from the ``__slots__``, ``_mixin_slots``, ``_slots`` and
        ``_args`` entries of the class body. Duplicates are dropped (first
        occurrence wins) and names that are also keys of the class body are
        left out.
        """
        inherited = []
        for base in bases:
            inherited += list(getattr(base, '_mixin_slots', []))

        # `_add_args_slots` is a build-time switch only, so it is removed
        # from the class namespace here.
        if kv.pop('_add_args_slots', True):
            names = list(inherited)
            for key in ('__slots__', '_mixin_slots', '_slots'):
                names += list(kv.get(key, []))
            names.extend(kv.get('_args', []))
            unique_names = OrderedDict.fromkeys(names)

            kv['__slots__'] = tuple(n for n in unique_names if n not in kv)

        kv.setdefault('__slots__', ())
        if inherited:
            # the inherited mixin slots are reset on the new class
            kv['_mixin_slots'] = []
        return type.__new__(mcs, name, bases, kv)

    def __instancecheck__(self, instance):
        """
        Check ``instance`` against this class; when it is a ``CollectionExpr``
        whose ``_proxy`` is set, the check is done on the proxied object.
        """
        from .expressions import CollectionExpr

        is_collection = issubclass(type(instance), CollectionExpr)
        if not is_collection or not hasattr(instance, '_proxy') \
                or instance._proxy is None:
            return super(NodeMetaclass, self).__instancecheck__(instance)

        return super(NodeMetaclass, self).__instancecheck__(instance._proxy)


class Node(six.with_metaclass(NodeMetaclass)):
    """
    Base class for nodes of an expression tree.

    The attribute names listed in ``_args`` hold the children of a node;
    ``_extra_args`` lists further attribute names that are only treated as
    children when ``extra=True`` is requested. An attribute may hold a single
    node, a list/tuple of nodes, or ``None``.
    """

    __slots__ = '_args_indexes', '__weakref__'
    _args = ()  # the child(ren) of the current node
    _extra_args = ()
    _non_table = False

    def __init__(self, *args, **kwargs):
        """Record the position of every name in ``_args`` and call ``_init``."""
        self._args_indexes = dict((name, i) for i, name in enumerate(self._args))
        self._init(*args, **kwargs)

    def _init(self, *args, **kwargs):
        """
        Assign positional values to the names in ``_args`` (in order) and set
        every keyword argument as an attribute.
        """
        for arg, value in zip(self._args, args):
            setattr(self, arg, value)

        for key, value in six.iteritems(kwargs):
            setattr(self, key, value)

    @property
    def _node_id(self):
        """Identity of the node: ``id(self)``."""
        return id(self)

    def _init_attr(self, attr, val):
        """Set ``attr`` to ``val`` only if the attribute is not yet present."""
        if not hasattr(self, attr):
            setattr(self, attr, val)

    @property
    def args(self):
        """Tuple of the values of the ``_args`` attributes (``None`` if unset)."""
        return tuple(getattr(self, arg, None) for arg in self._args)

    def iter_args(self):
        """Yield ``(name, value)`` for every name in ``_args``."""
        for name, arg in zip(self._args, self.args):
            yield name, arg

    @property
    def extra_args(self):
        """Tuple of the values of the ``_extra_args`` attributes (``None`` if unset)."""
        return tuple(getattr(self, extra_arg, None)
                     for extra_arg in self._extra_args)

    def arg_name_values(self, extra=False):
        """
        Yield ``(name, value)`` for the ``_args`` attributes, followed by the
        ``_extra_args`` attributes when ``extra`` is true.
        """
        arg_names = self._args if not extra else self._args + self._extra_args
        arg_values = self.args if not extra else self.args + self.extra_args
        for arg_name, arg in zip(arg_names, arg_values):
            yield arg_name, arg

    def _data_source(self):
        """Yield the data sources of this node; the base class yields ``None``."""
        yield None

    def data_source(self):
        """
        Yield every non-``None`` item produced by ``_data_source()`` of the
        nodes reached by a unique top-down traversal. The traversal does not
        descend below a node whose ``_source_data`` is not ``None``.
        """
        stop_cond = lambda x: hasattr(x, '_source_data') and x._source_data is not None
        for n in self.traverse(top_down=True, unique=True, stop_cond=stop_cond):
            for ds in n._data_source():
                if ds is not None:
                    yield ds

    def substitute(self, old_arg, new_arg, dag=None):
        """
        Replace ``old_arg`` (matched by identity) with ``new_arg`` among the
        arguments and extra arguments of this node.

        :param old_arg: the child to replace
        :param new_arg: the replacement
        :param dag: when given, the work is delegated to
                    ``dag.substitute(old_arg, new_arg, parents=[self])``
        """
        if dag is not None:
            dag.substitute(old_arg, new_arg, parents=[self])
            return

        # carry the name of the old node over to an unnamed replacement
        if hasattr(old_arg, '_name') and old_arg._name is not None and \
                        new_arg._name is None:
            new_arg._name = old_arg._name

        for arg_name, arg in self.arg_name_values(extra=True):
            if not isinstance(arg, (list, tuple)):
                if arg is old_arg:
                    setattr(self, arg_name, new_arg)
            else:
                # sequences are rebuilt with the same container type
                subs = [new_arg if item is old_arg else item for item in arg]
                setattr(self, arg_name, type(arg)(subs))

    def children(self, extra=False):
        """
        Return the child nodes as a flat list: list/tuple arguments are
        expanded one level and ``None`` entries are dropped.

        :param extra: also include the values of ``_extra_args``
        """
        args = []

        all_args = self.args if not extra else self.args + self.extra_args
        for arg in all_args:
            if isinstance(arg, (list, tuple)):
                args.extend(arg)
            else:
                args.append(arg)

        return [arg for arg in args if arg is not None]

    def leaves(self):
        """Yield each node without children found by a unique traversal."""
        for n in self.traverse(unique=True):
            if len(n.children()) == 0:
                yield n

    def traverse(self, top_down=False, unique=False, traversed=None,
                 extra=False, stop_cond=None):
        """
        Iterate over this node and its descendants.

        :param top_down: yield a node before its children when true,
                         after its children otherwise
        :param unique: skip nodes whose ``_node_id`` has been seen before
        :param traversed: set of already seen node ids; it is updated in place
                          when ``unique`` is true
        :param extra: follow the ``_extra_args`` children as well
        :param stop_cond: callable taking a node; when it returns true the
                          children of that node are not visited
        """
        traversed = traversed if traversed is not None else set()

        def is_trav(n):
            # NOTE: has a side effect, it records `n` as traversed
            if not unique:
                return False
            if n._node_id in traversed:
                return True
            traversed.add(n._node_id)
            return False

        q = deque()
        q.append(self)
        if is_trav(self):
            return

        checked = set()
        yields = set()
        while len(q) > 0:
            curr = q.popleft()
            if top_down:
                yield curr
                if stop_cond is None or not stop_cond(curr):
                    children = [c for c in curr.children(extra=extra)
                                if not is_trav(c)]
                    # reversed so that the first child ends up leftmost
                    q.extendleft(children[::-1])
            else:
                if curr._node_id not in checked:
                    # first visit: a leaf is yielded at once, any other node
                    # is put back behind its children
                    children = curr.children(extra=extra)
                    if len(children) == 0:
                        yield curr
                    else:
                        q.appendleft(curr)
                        if stop_cond is None or not stop_cond(curr):
                            q.extendleft([c for c in children
                                          if not is_trav(c) or c._node_id not in checked][::-1])
                    checked.add(curr._node_id)
                else:
                    # second visit: the children were handled already
                    if curr._node_id not in yields:
                        yield curr
                        if unique:
                            yields.add(curr._node_id)

    def __eq__(self, other):
        return self.equals(other)

    def equals(self, other):
        """
        Structural equality: ``other`` must be an instance of ``type(self)``
        and all attributes reported by ``utils.get_attrs`` must compare equal
        pairwise. Nodes are compared with ``equals``, lists/tuples element by
        element, anything else with ``==``.
        """
        if other is None:
            return False

        if not isinstance(other, type(self)):
            return False

        def slot_values(obj):
            return [getattr(obj, slot, None) for slot in utils.get_attrs(obj)]

        def cmp(x, y):
            if isinstance(x, Node):
                res = x.equals(y)
            elif isinstance(y, Node):
                # `x` is not a Node here, so let the Node side decide;
                # comparing `y` with itself would always report equality
                res = y.equals(x)
            elif isinstance(x, (tuple, list)) and \
                    isinstance(y, (tuple, list)):
                # map() stops at the shorter sequence on Python 3, so the
                # lengths have to be compared explicitly
                res = len(x) == len(y) and all(map(cmp, x, y))
            else:
                res = x == y
            return res
        return all(map(cmp, slot_values(self), slot_values(other)))

    def __hash__(self):
        return hash((type(self), tuple(self.children())))

    def is_ancestor(self, other):
        """
        Return ``True`` if ``other`` (after ``utils.get_proxied_expr``) is this
        node or is reachable from it, compared by identity.
        """
        other = utils.get_proxied_expr(other)
        for n in self.traverse(top_down=True, unique=True):
            if n is other:
                return True
        return False

    def path(self, other, strict=False):
        """
        Return the first path produced by ``all_path(other, strict=strict)``,
        or ``None`` when there is none.
        """
        return next(self.all_path(other, strict=strict), None)

    def _all_path(self, other):
        """
        Yield every path, as a list of nodes starting at this node, that leads
        down through ``children()`` to ``other`` (matched by identity).
        """
        if self is other:
            yield [self, ]

        # index of the next child to visit, per node id
        node_poses = defaultdict(lambda: 0)

        # `q` is used as a stack holding the path currently explored
        q = deque()
        q.append(self)

        while len(q) > 0:
            curr = q[-1]
            children = curr.children()

            pos = node_poses[curr._node_id]
            if len(children) == 0 or pos >= len(children):
                q.pop()
                # TODO: add this in the future
                # if pos >= len(children):
                #     node_poses[curr._node_id] = 0
                continue

            n = children[pos]
            q.append(n)
            if n is other:
                yield list(q)
                q.pop()

            node_poses[curr._node_id] += 1

    def all_path(self, other, strict=False):
        """
        Yield the paths from this node to ``other``. When none exists and
        ``strict`` is false, the paths from ``other`` to this node are yielded
        instead.
        """
        # remember, if the node has been changed into another one during traversing
        # the modification may not be applied to the paths

        found = False
        for path in self._all_path(other):
            found = True
            yield path

        if not found and not strict:
            for path in other._all_path(self):
                yield path

    def _attr_dict(self):
        """Map every attribute name from ``utils.get_attrs`` to its value (``None`` if unset)."""
        slots = utils.get_attrs(self)

        return dict((attr, getattr(self, attr, None)) for attr in slots)

    def _copy_type(self):
        return type(self)

    def copy(self, clear_keys=None, **kw):
        """
        Create a new node of the same type as the proxied expression from its
        attributes.

        :param clear_keys: attribute names that are not passed to the copy
        :param kw: attribute values overriding the ones of this node
        """
        proxied = utils.get_proxied_expr(self)
        attr_dict = proxied._attr_dict()
        if clear_keys is not None:
            for k in clear_keys:
                attr_dict.pop(k, None)
        attr_dict.update(kw)
        copied = type(proxied)(**attr_dict)
        return copied

    def copy_to(self, target):
        """Set every attribute that is present on this node on ``target``."""
        slots = utils.get_attrs(self)
        for attr in slots:
            if hasattr(self, attr):
                setattr(target, attr, getattr(self, attr, None))

    def copy_tree(self, on_copy=None, extra=False, stop_cond=None):
        """
        Copy the tree below this node bottom-up and return the copy of this
        node.

        :param on_copy: a callable, or an iterable of callables, invoked as
                        ``func(node, copied_node)`` for every copied node
        :param extra: follow and copy the ``_extra_args`` children as well
        :param stop_cond: passed to ``traverse``; children that were not
                          copied because of it are referenced as they are
        """
        if on_copy is not None and not isinstance(on_copy, Iterable):
            on_copy = [on_copy, ]

        expr_id_to_copied = dict()

        def get(n):
            if n is None:
                return n
            try:
                return expr_id_to_copied[n._node_id]
            except KeyError:
                # with a stop condition, uncopied children are expected
                if stop_cond is not None:
                    return n
                raise

        for node in self.traverse(unique=True, extra=extra, stop_cond=stop_cond):
            node = utils.get_proxied_expr(node)
            attr_dict = node._attr_dict()

            for arg_name, arg in node.arg_name_values(extra=extra):
                if isinstance(arg, (tuple, list)):
                    attr_dict[arg_name] = type(arg)(get(it) for it in arg)
                else:
                    attr_dict[arg_name] = get(arg)

            copied_node = type(node)(**attr_dict)
            expr_id_to_copied[node._node_id] = copied_node
            if on_copy is not None:
                for func in on_copy:
                    func(node, copied_node)

        return expr_id_to_copied[self._node_id]

    def to_dag(self, copy=True, on_copy=None, dag=None, validate=True, stop_cond=None):
        """
        Build a DAG from the tree below this node, with an edge from every
        child to its parent.

        :param copy: build the DAG from ``copy_tree(on_copy=on_copy, extra=True)``
                     instead of from this node
        :param on_copy: forwarded to ``copy_tree`` when ``copy`` is true
        :param dag: an existing DAG to fill; a new ``ExprDAG`` is created otherwise
        :param validate: call ``dag._validate()`` at the end
        :param stop_cond: callable taking a node; the children of a node for
                          which it returns true are not explored
        :return: the DAG
        """
        if copy:
            expr = self.copy_tree(on_copy=on_copy, extra=True)
        else:
            expr = self

        dag = dag or ExprDAG(expr)
        queue = Queue()

        if not dag.contains_node(expr):
            dag.add_node(expr)
            queue.put(expr)

        traversed = set()
        traversed.add(expr._node_id)

        while not queue.empty():
            node = queue.get()

            for child in node.children(extra=True):
                if not dag.contains_node(child):
                    dag.add_node(child)
                if not dag.contains_edge(child, node):
                    dag.add_edge(child, node, validate=False)

                if child._node_id in traversed:
                    continue
                traversed.add(child._node_id)
                if stop_cond is None or not stop_cond(child):
                    queue.put(child)

        if validate:
            dag._validate()  # validate the DAG
        return dag

    def __getstate__(self):
        """
        Return the state as a tuple of ``(name, value)`` pairs, leaving out
        names starting with a double underscore.
        """
        slots = utils.get_attrs(self)

        return tuple((slot, object.__getattribute__(self, slot)) for slot in slots
                     if not slot.startswith('__'))

    def __setstate__(self, state):
        """Restore the node by calling ``__init__`` with the saved pairs as keywords."""
        self.__init__(**dict(state))


def _extract_df_inputs(o):
    """
    Yield the ``Node`` objects found in ``o``.

    A ``Node`` is yielded as it is. The values of a dict and the items of a
    list, set or tuple are searched recursively, and only the nodes found in
    them are yielded. Any other object yields a single ``None``.
    """
    if isinstance(o, Node):
        yield o
        return

    if isinstance(o, dict):
        members = six.itervalues(o)
    elif isinstance(o, (list, set, tuple)):
        members = o
    else:
        yield None
        return

    for member in members:
        for found in _extract_df_inputs(member):
            # the `None` placeholders of nested non-node objects are dropped
            if found is not None:
                yield found


class ExprProxy(object):
    """
    Hashable stand-in for an expression that only keeps a weak reference to it.

    The hash and the ``_node_id`` of the expression are captured when the
    proxy is created, so the proxy stays usable as a dictionary key after the
    expression has been collected.

    :param expr: the expression to reference
    :param d: optional mapping; when the expression is collected, the entry
              keyed by this proxy is removed from it
    :param compare: compare with non-proxy objects through ``expr.equals``
                    instead of by node id
    """

    def __init__(self, expr, d=None, compare=False):
        if d is not None:
            def callback(_):
                # invoked by weakref once `expr` is about to be finalized
                if self in d:
                    del d[self]
        else:
            callback = None
        self._ref = weakref.ref(expr, callback)
        self._cmp = compare
        self._hash = hash(expr)
        self._expr_id = expr._node_id

    def __call__(self):
        """Return the referenced expression, or ``None`` if it no longer exists."""
        return self._ref()

    def __hash__(self):
        return self._hash

    def __eq__(self, other):
        """
        Two proxies are equal when they reference the same live object, or,
        if either referent is gone, when the stored node ids match. Other
        objects are compared through ``equals`` (``compare=True`` and referent
        alive) or by their ``_node_id``; objects without a ``_node_id`` are
        not comparable and ``NotImplemented`` is returned.
        """
        if isinstance(other, ExprProxy):
            mine, theirs = self._ref(), other()
            if mine is not None and theirs is not None:
                return mine is theirs
            return self._expr_id == other._expr_id

        obj = self._ref()
        if obj is not None and self._cmp:
            return obj.equals(other)

        try:
            other_id = other._node_id
        except AttributeError:
            # not an expression: let Python fall back to its default
            # comparison instead of failing
            return NotImplemented
        return self._expr_id == other_id


class ExprDAG(DAG):
    """
    DAG of expression nodes with a distinguished root.

    The root is held through a weak reference only.

    :param root: the root expression
    :param dag: optional DAG whose ``_graph`` and ``_map`` are shared
                (not copied) with the new instance
    """

    def __init__(self, root, dag=None):
        self._root = weakref.ref(root)
        super(ExprDAG, self).__init__()
        if dag is not None:
            self._graph = dag._graph
            self._map = dag._map

    @property
    def root(self):
        """The root expression, or ``None`` if it has been collected."""
        return self._root()

    @root.setter
    def root(self, root):
        self._root = weakref.ref(root)

    def ensure_all_nodes_in_dag(self):
        """Add the nodes and edges reachable from the root to this DAG."""
        self.root.to_dag(copy=False, dag=self, validate=False)

    def traverse(self, root=None, top_down=False, traversed=None, stop_cond=None):
        """
        Traverse uniquely from ``root`` (the DAG root by default); the other
        arguments are forwarded to the ``traverse`` method of the node.
        """
        root = root or self.root
        return root.traverse(top_down=top_down, unique=True,
                             traversed=traversed, stop_cond=stop_cond)

    def substitute(self, expr, new_expr, parents=None):
        """
        Replace ``expr`` with ``new_expr`` in the given parents.

        :param expr: the node to replace; if it is the root, ``new_expr``
                     becomes the root
        :param new_expr: the replacement; it and the nodes below it are added
                         to the DAG when missing
        :param parents: the nodes whose argument is replaced, defaults to
                        ``self.successors(expr)``
        """
        if expr is self.root:
            self.root = new_expr

        parents = self.successors(expr) if parents is None else parents

        if expr._need_cache:
            new_expr.cache()

        q = Queue()
        q.put(new_expr)

        # register the new sub-tree; nodes already in the DAG are not explored
        while not q.empty():
            node = q.get()
            if not self.contains_node(node):
                self.add_node(node)

            for child in node.children():
                if not self.contains_node(child):
                    q.put(child)
                    self.add_node(child)
                if not self.contains_edge(child, node):
                    self.add_edge(child, node, validate=False)

        for parent in parents:
            parent.substitute(expr, new_expr)
            self.add_edge(new_expr, parent, validate=False)
            try:
                self.remove_edge(expr, parent)
            except KeyError:
                # a missing edge is tolerated
                pass

    def prune(self):
        """
        Repeatedly remove every node without successors except the root,
        until no such node is left.
        """
        while True:
            root = self.root
            removable = [n for n, succ in six.iteritems(self._graph)
                         if len(succ) == 0 and n is not root]
            if not removable:
                # Also reached when the graph has no successor-less node at
                # all; waiting for "only the root is left" would never end then.
                break
            for node in removable:
                self.remove_node(node)

    def closest_ancestors(self, node, cond):
        """
        Yield the nodes satisfying ``cond`` among those visited by
        ``self.bfs`` from ``node`` along ``self.predecessors``.
        """
        collected = set()
        # NOTE(review): how `bfs` uses the stop condition is defined by the
        # base DAG class - confirm there.
        stop_cond = lambda n: not collected.intersection(self.successors(n))
        for n in self.bfs(node, self.predecessors, stop_cond):
            if cond(n):
                collected.add(n)
                yield n


class ExprDictionary(dict):
    """
    Dictionary keyed by expressions.

    Keys are stored as ``ExprProxy`` objects; keys handed back to the caller
    are the referenced expressions. Keys inserted through ``__setitem__`` or
    ``setdefault`` are removed automatically when the expression is
    collected. ``None`` is never a valid key.
    """

    def _ref(self, obj, ref_self=False):
        """
        Wrap ``obj`` into an ``ExprProxy`` unless it already is one.

        :param ref_self: bind the new proxy to this dictionary so that the
                         entry is dropped when ``obj`` is collected
        """
        r = self if ref_self else None
        return obj if isinstance(obj, ExprProxy) else ExprProxy(obj, d=r)

    def __getitem__(self, item):
        if item is None:
            raise KeyError
        return dict.__getitem__(self, self._ref(item))

    def __setitem__(self, key, value):
        if key is None:
            raise KeyError
        return dict.__setitem__(self, self._ref(key, True), value)

    def __iter__(self):
        # yields the referenced expressions rather than the proxies
        for k in dict.__iter__(self):
            yield k()

    def __delitem__(self, key):
        if key is None:
            raise KeyError
        return dict.__delitem__(self, self._ref(key))

    def __contains__(self, item):
        if item is None:
            return False
        return dict.__contains__(self, self._ref(item))

    def get(self, k, d=None):
        """Return the value stored for ``k``, or ``d`` if ``k`` is ``None`` or missing."""
        if k is None:
            return d
        return dict.get(self, self._ref(k), d)

    def has_key(self, k):
        """Return whether ``k`` is a key of the dictionary."""
        if k is None:
            return False
        # `dict.has_key` does not exist on Python 3
        return dict.__contains__(self, self._ref(k))

    def pop(self, k, d=None):
        """
        Remove ``k`` and return its value, or ``d`` if ``k`` is ``None`` or
        missing.
        """
        if k is None:
            return d
        return dict.pop(self, self._ref(k), d)

    def popitem(self):
        """Remove an entry and return it as ``(expression, value)``."""
        k, v = dict.popitem(self)
        return k(), v

    def setdefault(self, k, d=None):
        if k is None:
            raise KeyError
        return dict.setdefault(self, self._ref(k, True), d)

    def update(self, E=None, **F):
        """
        Insert the entries of ``E``, a mapping (anything with ``keys``) or an
        iterable of pairs. The keyword arguments are only used when ``E`` is
        ``None``.
        """
        if hasattr(E, 'keys'):
            for k in E.keys():
                self[k] = E[k]
        elif E is not None:
            for k, v in E:
                self[k] = v
        else:
            for k, v in six.iteritems(F):
                self[k] = v

    if six.PY2:
        def items(self):
            return [(k(), v) for k, v in dict.items(self)]

        def keys(self):
            return [k() for k in dict.keys(self)]

        def iteritems(self):
            for k, v in dict.iteritems(self):
                yield k(), v

        def iterkeys(self):
            for k in dict.iterkeys(self):
                yield k()
    else:
        def items(self):
            for k, v in dict.items(self):
                yield k(), v

        def keys(self):
            for k in dict.keys(self):
                yield k()
