#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2022 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from datetime import date, datetime
from decimal import Decimal

from ...expr.reduction import *
from ...expr.arithmetic import BinOp, Add, Substract, Power, Invert, Negate, Abs
from ...expr.merge import JoinCollectionExpr, UnionCollectionExpr
from ...expr.element import MappedExpr, Func
from ...expr.window import CumSum, NthValue, QCut
from ...expr.datetimes import DTScalar
from ...expr.collections import RowAppliedCollectionExpr
from ...expr import element, strings, datetimes, composites
from ...utils import is_source_collection, traverse_until_source
from ... import types as df_types
from . import types
from .models import MemCacheReference
from ..core import Backend
from .... import utils
from ....models import Function
from ..errors import CompileError


# Binary operator node name -> ODPS SQL operator token
# (the compiler upper-cases the token before emitting it).
BINARY_OP_COMPILE_DIC = dict(
    Add='+',
    Substract='-',
    Multiply='*',
    Divide='/',
    Greater='>',
    GreaterEqual='>=',
    Less='<',
    LessEqual='<=',
    Equal='==',
    NotEqual='!=',
    And='and',
    Or='or',
)

# Unary operator node name -> SQL prefix operator.
UNARY_OP_COMPILE_DIC = dict(
    Negate='-',
)

# Window / cumulative expression node name -> ODPS window function name.
WINDOW_COMPILE_DIC = dict(
    CumSum='sum',
    CumMean='avg',
    CumMedian='median',
    CumStd='stddev',
    CumMax='max',
    CumMin='min',
    CumCount='count',
    NthValue='nth_value',
    Lag='lag',
    Lead='lead',
    Rank='rank',
    DenseRank='dense_rank',
    PercentRank='percent_rank',
    RowNumber='row_number',
    QCut='ntile',
    CumeDist='cume_dist',
)

# Math expression node name -> ODPS SQL function name (emitted upper-cased).
MATH_COMPILE_DIC = dict(
    Abs='abs',
    Sqrt='sqrt',
    Sin='sin',
    Sinh='sinh',
    Cos='cos',
    Cosh='cosh',
    Tan='tan',
    Tanh='tanh',
    Exp='exp',
    Arccos='acos',
    Arcsin='asin',
    Arctan='atan',
    Ceil='ceil',
    Floor='floor',
)

# Date part name -> unit string passed to DATEADD.
DATE_PARTS_DIC = dict(
    Year='yyyy',
    Month='mm',
    Day='dd',
    Hour='hh',
    Minute='mi',
    Second='ss',
)


class OdpsSQLCompiler(Backend):
    """
    OdpsSQLCompiler will compile an Expr into an ODPS SQL.
    """

    def __init__(self, ctx, indent_size=2, beautify=False):
        """Create a compiler bound to a compilation context.

        :param ctx: compilation context used to register aliases and to store
            / look up compiled SQL fragments of expressions.
        :param indent_size: number of spaces used when indenting sub-queries.
        :param beautify: emit multi-line, aligned SELECT lists when True.
        """
        self._ctx = ctx
        self._indent_size = indent_size
        self._beautify = beautify

        # Pre-compiled SQL pieces for `join` / `union` / lateral view /
        # sub-query `IN` nodes, keyed by the node.
        self._sub_compiles = defaultdict(list)
        self._union_no_alias = {}
        # When a `join` or `union` is met, all of its branches are compiled up
        # front (without uniqueness checks on their nodes), then the node's
        # children are replaced by None so later traversals skip them. Each
        # callback here puts the original children back once everything is done.
        self._callbacks = []

        # expression ids referenced through in-memory cache tables
        self._mem_ref_caches = set()

        self._re_init()

    def _re_init(self):
        """Reset every per-query clause so a new SELECT can be assembled."""
        clause_attrs = (
            '_select_clause', '_from_clause', '_where_clause',
            '_group_by_clause', '_having_clause', '_order_by_clause',
            '_table_sample_clause', '_limit',
        )
        for attr in clause_attrs:
            setattr(self, attr, None)

    def _cleanup(self):
        """Discard sub-compilation results and re-attach detached children.

        Every callback registered while pre-compiling join / union / lateral
        view / IN nodes is run, restoring those nodes' original arguments.
        """
        self._sub_compiles = defaultdict(list)
        for restore_children in self._callbacks:
            restore_children()
        self._callbacks = []

    @classmethod
    def _need_recursive_handle_in_expr(cls, node):
        """Return True if `node` is an IN / NOT IN whose values form a sub-query.

        Such a node must still have at least one argument attached and take
        its first value from a SequenceExpr.
        """
        if not isinstance(node, (element.IsIn, element.NotIn)):
            return False
        if all(arg is None for arg in node.args):
            return False
        return isinstance(node.values[0], SequenceExpr)

    @classmethod
    def _retrieve_until_find_root(cls, expr):
        """Yield the nodes under `expr` that must be pre-compiled separately.

        Walks top-down (stopping at source collections). A join / union /
        lateral-view collection that still has arguments attached is yielded
        itself; for any other node, its direct children that are sub-query
        IN / NOT IN expressions are yielded.
        """
        root_types = (JoinCollectionExpr, UnionCollectionExpr, LateralViewCollectionExpr)
        for node in traverse_until_source(expr, top_down=True, unique=True):
            if isinstance(node, root_types) and any(arg is not None for arg in node.args):
                yield node
                continue
            # collect eagerly: the consumer detaches arguments of yielded nodes
            sub_query_ins = [child for child in node.children()
                             if cls._need_recursive_handle_in_expr(child)]
            for in_node in sub_query_ins:
                yield in_node

    def _compile_union_node(self, expr, traversed):
        """Pre-compile both sides of a union, then detach its arguments.

        The SQL of the left and right inputs is appended, in that order, to
        ``self._sub_compiles[expr]``. The node's arguments are saved in
        ``ctx._expr_raw_args`` and set to None; a callback restoring them is
        registered for ``_cleanup``.
        """
        left = expr.lhs
        if isinstance(left, UnionCollectionExpr):
            self._union_no_alias[left] = True
        self._sub_compiles[expr].append(self._compile(left))
        self._sub_compiles[expr].append(self._compile(expr.rhs))

        saved_args = expr.args
        self._ctx._expr_raw_args[id(expr)] = saved_args
        for attr in expr._args:
            setattr(expr, attr, None)

        def restore_args():
            for attr, value in zip(expr._args, saved_args):
                setattr(expr, attr, value)

        self._callbacks.append(restore_args)

    def _compile_join_node(self, expr, traversed):
        """Pre-compile both inputs (and the predicate) of a join node.

        Appends to ``self._sub_compiles[expr]``, in order: the FROM text of the
        left input, the FROM text of the right input, and the compiled join
        predicate (or None when the join has no predicate). Inputs that are not
        source collections are wrapped as aliased sub-queries, except that a
        left input which is itself a join is referenced through its compiled
        text. Map-join / skew-join hints naming the right input's alias are
        queued on the context. Finally the node's arguments are saved in
        ``ctx._expr_raw_args``, set to None, and a restoring callback is
        registered for ``_cleanup``.

        :param expr: the JoinCollectionExpr to pre-compile.
        :param traversed: set of visited node ids. Updated with the nodes
            visited while compiling both inputs when there is no predicate;
            otherwise it is passed to the predicate compilation instead.
        """
        travs = set()

        compiled, trav = self._compile(expr.lhs, return_traversed=True)
        travs.update(trav)
        # a non-source left input becomes "(\n<sql>\n) <alias>"; a source table
        # or a nested join reuses its already-compiled text
        if not is_source_collection(expr.lhs) and not isinstance(expr.lhs, JoinCollectionExpr):
            self._sub_compiles[expr].append(
                '(\n{0}\n) {1}'.format(utils.indent(compiled, self._indent_size),
                                     self._ctx.get_collection_alias(expr.lhs, create=True)[0])
            )
        else:
            self._sub_compiles[expr].append(self._ctx.get_expr_compiled(expr.lhs))

        compiled, trav = self._compile(expr.rhs, return_traversed=True)
        travs.update(trav)
        # the right input is wrapped whenever it is not a source collection
        if not is_source_collection(expr.rhs):
            self._sub_compiles[expr].append(
                '(\n{0}\n) {1}'.format(utils.indent(compiled, self._indent_size),
                                     self._ctx.get_collection_alias(expr.rhs, create=True)[0])
            )
        else:
            self._sub_compiles[expr].append(self._ctx.get_expr_compiled(expr.rhs))

        if expr.predicate is None:
            self._sub_compiles[expr].append(None)
            traversed.update(travs)
        else:
            # NOTE(review): in this branch the ids collected in `travs` are not
            # merged into `traversed`; confirm this is intended
            self._compile(expr.predicate, traversed)
            self._sub_compiles[expr].append(self._ctx.get_expr_compiled(expr.predicate))

        if expr._mapjoin:
            # mapjoin hint lists the alias of the right input
            self._ctx._mapjoin_hints.append(self._ctx.get_collection_alias(expr.rhs)[0])

        if expr._skewjoin:
            # skewjoin hint: "<alias>[(col, ...)][((v1, v2), ...)]"
            skewjoin_expr = self._ctx.get_collection_alias(expr.rhs)[0]
            if isinstance(expr._skewjoin, list):
                skewjoin_expr += '({0})'.format(', '.join(expr._skewjoin))
            if expr._skewjoin_values:
                skewjoin_expr += '({0})'.format(
                    ', '.join(
                        '({0})'.format(', '.join(repr(s) for s in values))
                        for values in expr._skewjoin_values
                    )
                )
            self._ctx._skewjoin_hints.append(skewjoin_expr)

        # detach the children so later traversals skip this subtree; the raw
        # args stay available for column-alias resolution
        args = expr.args
        self._ctx._expr_raw_args[id(expr)] = args
        for arg_name in expr._args:
            setattr(expr, arg_name, None)

        def cb():
            # re-attach the original children (run by _cleanup)
            for arg_name, arg in zip(expr._args, args):
                setattr(expr, arg_name, arg)

        self._callbacks.append(cb)

    @classmethod
    def _find_table(cls, expr):
        """Return the first CollectionExpr found walking `expr` top-down.

        Raises StopIteration when `expr` contains no collection.
        """
        for candidate in expr.traverse(top_down=True, unique=True):
            if isinstance(candidate, CollectionExpr):
                return candidate
        raise StopIteration()

    def _compile_in_node(self, expr, traversed):
        """Pre-compile a sub-query IN / NOT IN node, then detach its arguments.

        Appends to ``self._sub_compiles[expr]``, in order: the compiled input
        expression, and the SQL of the values' source collection projected onto
        the single value sequence.
        """
        self._compile(expr.input)
        self._sub_compiles[expr].append(self._ctx.get_expr_compiled(expr.input))

        value_seq = expr.values[0]
        projection = self._find_table(value_seq)[[value_seq, ]]
        self._sub_compiles[expr].append(self._compile(projection))

        saved_args = expr.args
        self._ctx._expr_raw_args[id(expr)] = saved_args
        for attr in expr._args:
            setattr(expr, attr, None)

        def restore_args():
            for attr, value in zip(expr._args, saved_args):
                setattr(expr, attr, value)

        self._callbacks.append(restore_args)

    def _compile_lateral_view(self, expr, traversed):
        """Pre-compile a lateral-view collection into its FROM text.

        The input is compiled (wrapped as an aliased sub-query unless it is a
        source collection or another lateral view), each lateral view is
        compiled, and the pieces joined with ``' \\n'`` are appended to
        ``self._sub_compiles[expr]``. The ids visited while compiling the input
        are added to `traversed`; the node's arguments are then detached.
        """
        source = expr.input
        input_sql, input_traversed = self._compile(source, return_traversed=True)

        if is_source_collection(source) or isinstance(source, LateralViewCollectionExpr):
            input_sql = self._ctx.get_expr_compiled(source)
        else:
            indented = utils.indent(input_sql, self._indent_size)
            alias = self._ctx.get_collection_alias(source, create=True)[0]
            input_sql = '(\n{0}\n) {1}'.format(indented, alias)

        pieces = [input_sql]
        for lview in expr.lateral_views:
            self._compile(lview)
            pieces.append(self._ctx.get_expr_compiled(lview))

        traversed.update(input_traversed)
        self._sub_compiles[expr].append(' \n'.join(pieces))

        saved_args = expr.args
        self._ctx._expr_raw_args[id(expr)] = saved_args
        for attr in expr._args:
            setattr(expr, attr, None)

        def restore_args():
            for attr, value in zip(expr._args, saved_args):
                setattr(expr, attr, value)

        self._callbacks.append(restore_args)

    def _compile(self, expr, traversed=None,
                 return_traversed=False, root_expr=None):
        """Compile `expr` (down to its source collections) into one SQL string.

        Nodes yielded by ``_retrieve_until_find_root`` (joins, unions, lateral
        views and sub-query IN nodes) are pre-compiled first and then visited;
        every remaining not-yet-visited node under `expr` is visited afterwards,
        and the collected clauses are rendered with ``to_sql``.

        :param expr: expression to compile.
        :param traversed: set of ids of already-visited nodes; created when
            None and updated in place.
        :param return_traversed: if True, also return the traversed id set.
        :param root_expr: top-level expression of the whole compilation. When
            `expr` is it, no SELECT list was produced, and join hints are
            pending, a ``SELECT *`` is added so the hints get attached.
        :return: the SQL string, or ``(sql, traversed)``.
        """
        roots = self._retrieve_until_find_root(expr)

        if traversed is None:
            traversed = set()

        # `roots` is a lazy generator; each pre-compilation below detaches the
        # root's children, so the ongoing traversal does not descend into them
        for root in roots:
            if root is not None:
                if isinstance(root, JoinCollectionExpr):
                    self._compile_join_node(root, traversed)
                elif isinstance(root, UnionCollectionExpr):
                    self._compile_union_node(root, traversed)
                elif isinstance(root, LateralViewCollectionExpr):
                    self._compile_lateral_view(root, traversed)
                elif isinstance(root, (element.IsIn, element.NotIn)):
                    self._compile_in_node(root, traversed)
                # visit the root itself using its pre-compiled pieces
                root.accept(self)
                traversed.add(id(root))

        for node in traverse_until_source(expr):
            if id(node) not in traversed:
                node.accept(self)
                traversed.add(id(node))

        if (
            expr is root_expr
            and self._select_clause is None
            and (self._ctx._mapjoin_hints or self._ctx._skewjoin_hints)
        ):
            self.add_select_clause(expr, '* ')

        # to_sql also resets all clauses for the next query
        sql = self.to_sql().strip()
        if not return_traversed:
            return sql
        return sql, traversed

    def compile(self, expr):
        """Compile `expr` into ODPS SQL.

        :return: the SQL string; when the expression reads from memory-cache
            references, a list of the dependency SQL statements (each made to
            end with ``;``) followed by the main SQL.
        """
        try:
            sql = self._compile(expr, root_expr=expr)
            symbol_columns = dict(self._ctx.get_all_need_alias_column_symbols())
            for post_process in (self._fill_back_columns, self._re_join_select_fields):
                sql = post_process(sql, symbol_columns)

            if not self._mem_ref_caches:
                return sql

            statements = []
            for dep_sql in self._ctx.get_mem_cache_dep_sqls(*self._mem_ref_caches):
                statements.append(dep_sql if dep_sql.endswith(';') else dep_sql + ';')
            statements.append(sql)
            return statements
        finally:
            self._cleanup()

    def to_sql(self):
        """Render the collected clauses as one SQL statement, then reset them."""
        if self._from_clause.startswith('SELECT'):
            # special case due to `union`: the FROM text is a full SELECT
            lines = [self._from_clause]
        else:
            lines = ['SELECT {0} '.format(self._select_clause or '*'),
                     'FROM {0} '.format(self._from_clause)]

        # TABLESAMPLE must come right after the FROM clause
        if self._table_sample_clause:
            lines.append('TABLESAMPLE {0}'.format(self._table_sample_clause))
        if self._where_clause:
            lines.append('WHERE {0} '.format(self._where_clause))
        if self._group_by_clause:
            lines.append(self._group_by_clause)
            if self._having_clause:
                lines.append('HAVING {0} '.format(self._having_clause))
        if self._order_by_clause:
            # `ORDER BY` requires a limit while `SORT BY` does not, so fall back
            # to the configured default when no limit was given
            use_default_limit = (self._order_by_clause.startswith('ORDER BY')
                                 and not self._limit
                                 and options.df.odps.sort.limit)
            if use_default_limit:
                self._limit = options.df.odps.sort.limit
            lines.append(self._order_by_clause)
        if self._limit is not None:
            lines.append('LIMIT {0}'.format(self._limit))

        self._re_init()
        return '\n'.join(lines)

    def _fill_back_columns(self, sql, symbols_to_columns):
        """Replace column placeholders in `sql` with alias-qualified names.

        Each ``%(col_N)s`` placeholder becomes ``<collection alias>.<quoted
        column name>``. Each ``###[col_N]##`` marker becomes `` AS <quoted
        column name>`` when the substituted text, unquoted, differs from the
        column's name, and is removed otherwise. Columns whose collection
        lookup raises KeyError are skipped.
        """
        symbol_compiled = {}
        for symbol, column in six.iteritems(symbols_to_columns):
            try:
                collection, name = self._retrieve_column_alias_collection(column)
            except KeyError:
                # sorted column relative, just ignore
                continue
            alias = self._ctx.get_collection_alias(collection)[0]
            symbol_compiled[symbol] = '{0}.{1}'.format(alias, self._quote(name))
        sql = sql % symbol_compiled

        def replace_marker(matched):
            symbol = matched.group(1)
            compiled_name = self._unquote(symbol_compiled[symbol])
            column_name = symbols_to_columns[symbol].name
            if compiled_name == column_name:
                return ''
            return ' AS {0}'.format(self._quote(column_name))

        return re.sub(r'###\[(col_\d+)\]##', replace_marker, sql)

    def _re_join_select_fields(self, sql, symbols_to_columns):
        """Turn beautify markers in `sql` into aligned multi-line SELECT lists.

        Only active when beautify is enabled. Markers of the form
        ``/{_iN}//{_iN}<field>{_iN}//.../`` (as produced by
        ``_join_select_fields``) are replaced by the output of
        ``_join_compiled_fields`` for the captured fields, re-indented to the
        column where the marker started.

        :param sql: SQL text possibly containing markers.
        :param symbols_to_columns: unused here (kept for a uniform signature
            with ``_fill_back_columns``).
        :return: the rewritten SQL.
        """
        if not self._beautify:
            return sql

        reg = re.compile(r'/{_i(\d+)}')

        # NOTE(review): the per-field markers "//{_iN}" also match this pattern,
        # so an index may be listed several times; later passes presumably find
        # nothing left to replace
        for select_idx in reg.findall(sql):
            s_regex_str = '//{{_i{0}}}(.+?){{_i{0}}}//'.format(select_idx)
            regex_str = '\n?( *)/{{_i{0}}}({1})+/'.format(select_idx, s_regex_str)
            regex = re.compile(regex_str, (re.M | re.DOTALL))
            s_regex = re.compile(s_regex_str, (re.M | re.DOTALL))

            def repl(matched):
                # leading spaces before the marker decide the final indentation
                space = len(matched.group(1))
                joined = matched.group()

                fields = s_regex.findall(joined)
                sub = self._join_compiled_fields(fields)
                # keep the leading newline only if the marker started on its own line
                if not joined.startswith('\n'):
                    sub = sub.lstrip('\n')
                # _join_compiled_fields already indents by _indent_size
                if space > self._indent_size:
                    return utils.indent(sub, space-self._indent_size)
                return sub

            sql = regex.sub(repl, sql)

        return sql

    def _retrieve_column_alias_collection(self, expr):
        """Find the aliased collection that physically provides column `expr`.

        Starting from the column's input collection, this walks:

        * through joins, following ``_column_origins`` to the side (and the
          column name) that owns the column;
        * through lateral views, returning the lateral view whose schema has
          the column, or else the innermost non-lateral-view input;
        * through other collections via ``input`` until one with a registered
          alias is found.

        :param expr: column expression to resolve.
        :return: ``(collection, column_name)``, where ``column_name`` is the
            column's name inside ``collection``.
        :raises KeyError: may propagate from the ``_column_origins`` /
            ``_expr_raw_args`` lookups.
        :raises CompileError: when no aliased collection can be reached.
        """
        column_name = expr.source_name

        collection = expr.input
        while True:
            if isinstance(collection, JoinCollectionExpr):
                idx, column_name = collection._column_origins[column_name]
                # the join's children were detached during compilation;
                # read them back from the saved raw args
                args = self._ctx._expr_raw_args[id(collection)]
                lhs, rhs = args[0], args[1]
                collection = (lhs, rhs)[idx]
            elif isinstance(collection, LateralViewCollectionExpr):
                args = self._ctx._expr_raw_args[id(collection)]  # detached args
                # walk nested lateral views inwards: a column produced by one of
                # the lateral views (args[2]) belongs to it, otherwise to the
                # innermost non-lateral-view input (args[0])
                while True:
                    for lv in args[2]:
                        if utils.to_lower_str(column_name) in lv.schema._name_indexes:
                            return lv, column_name
                    if isinstance(args[0], LateralViewCollectionExpr):
                        args = self._ctx._expr_raw_args[id(args[0])]
                    else:
                        return args[0], column_name
            elif self._ctx.get_collection_alias(collection, silent=True):
                return collection, column_name
            elif isinstance(getattr(collection, 'input', None), CollectionExpr):
                collection = collection.input
            else:
                # nothing left to follow: without this exit the loop would spin
                # forever on the same collection
                break

        raise CompileError('Cannot find table alias for column: \n%s' % repr_obj(expr))

    def sub_sql_to_from_clause(self, expr):
        """Nest the query built so far as an aliased sub-query in FROM.

        The current clauses are rendered, every clause is reset, and the
        rendered SQL, wrapped in parentheses on separate lines and followed by
        the alias of `expr` (created if needed), becomes the new FROM clause.
        """
        inner_sql = self.to_sql()
        alias, _ = self._ctx.get_collection_alias(expr, create=True)
        nested = '(\n{0}\n) {1}'.format(utils.indent(inner_sql, self._indent_size), alias)

        self._re_init()
        self._from_clause = nested

    def add_select_clause(self, expr, select_clause):
        """Set the SELECT list, nesting the current query first when required.

        The pending query becomes a sub-query of ``expr.input`` when a SELECT
        list or an ORDER BY is already set, when a limit is pending and `expr`
        is a Summary / group-by / mutate / row-applied collection, or when
        ``expr.input`` is reshuffled. Pending mapjoin / skewjoin hints on the
        context are consumed and prepended as a ``/*+ ... */`` hint.
        """
        limit_sensitive = (Summary, GroupByCollectionExpr, MutateCollectionExpr,
                           RowAppliedCollectionExpr)
        must_nest = (
            self._select_clause is not None
            or self._order_by_clause is not None
            or (isinstance(expr, limit_sensitive) and self._limit is not None)
            or isinstance(expr.input, ReshuffledCollectionExpr)
        )
        if must_nest:
            self.sub_sql_to_from_clause(expr.input)

        hints = []
        if self._ctx._mapjoin_hints:
            hints.append('mapjoin({0})'.format(', '.join(self._ctx._mapjoin_hints)))
            self._ctx._mapjoin_hints = []
        if self._ctx._skewjoin_hints:
            hints.append('skewjoin({0})'.format(', '.join(self._ctx._skewjoin_hints)))
            self._ctx._skewjoin_hints = []

        if hints:
            select_clause = '/*+ {0} */ {1}'.format(', '.join(hints), select_clause)
        self._select_clause = select_clause

    def add_from_clause(self, expr, from_clause):
        """Record `from_clause` unless a FROM clause is already set (first one wins)."""
        if self._from_clause is not None:
            return
        self._from_clause = from_clause

    def add_where_clause(self, expr, where_clause):
        """Set the WHERE condition.

        The current query is nested first if it already has a WHERE, a SELECT
        list or a LIMIT.
        """
        blocking = (self._where_clause, self._select_clause, self._limit)
        if not all(clause is None for clause in blocking):
            self.sub_sql_to_from_clause(expr.input)
        self._where_clause = where_clause

    def add_group_by_clause(self, expr, group_by_clause):
        """Set the GROUP BY clause.

        The current query is nested first if it is already grouped, or if
        `expr` is a reshuffle and a SELECT list is already set.
        """
        already_grouped = self._group_by_clause is not None
        reshuffle_after_select = isinstance(expr, ReshuffledCollectionExpr) and \
            self._select_clause is not None
        if already_grouped or reshuffle_after_select:
            self.sub_sql_to_from_clause(expr.input)
        self._group_by_clause = group_by_clause

    def add_having_clause(self, expr, having_clause):
        """Set the HAVING condition; later calls must pass the same condition."""
        current = self._having_clause
        if current is None:
            current = self._having_clause = having_clause
        # only one HAVING per query: a repeated call must agree with it
        assert having_clause == current

    def add_order_by_clause(self, expr, order_by_clause):
        """Set the ORDER BY / SORT BY clause, nesting the query if one is already set."""
        already_sorted = self._order_by_clause is not None
        if already_sorted:
            self.sub_sql_to_from_clause(expr.input)
        self._order_by_clause = order_by_clause

    def add_table_sample_clause(self, expr, table_sample_clause):
        """Set the TABLESAMPLE clause.

        The current query is nested first if it already has a WHERE,
        TABLESAMPLE, GROUP BY, HAVING, ORDER BY or LIMIT.
        """
        existing = (self._where_clause, self._table_sample_clause, self._group_by_clause,
                    self._having_clause, self._order_by_clause, self._limit)
        if not all(clause is None for clause in existing):
            self.sub_sql_to_from_clause(expr.input)
        self._table_sample_clause = table_sample_clause

    def set_limit(self, expr, limit):
        """Set the LIMIT, nesting the query first if a limit is already pending."""
        has_limit = self._limit is not None
        if has_limit:
            self.sub_sql_to_from_clause(expr.input)
        self._limit = limit

    def visit_source_collection(self, expr):
        """Emit the FROM text ``<table reference> <alias>`` for a source collection.

        A MemCacheReference is referenced by its ``ref_name`` and its
        ``expr_id`` is recorded in ``_mem_ref_caches``. A table is referenced
        as ``project[.schema].table``, with the table name wrapped in
        backticks when ``options.df.quote`` is set.
        """
        source_data = expr._source_data
        alias = self._ctx.register_collection(expr)

        is_mem_cache = isinstance(source_data, MemCacheReference)
        if is_mem_cache:
            table_ref = source_data.ref_name
        else:
            parts = [source_data.project.name]
            schema = source_data.get_schema()
            if schema is not None:
                parts.append(schema.name)
            parts.append("`%s`" % source_data.name if options.df.quote else source_data.name)
            table_ref = '.'.join(parts)

        from_clause = '{0} {1}'.format(table_ref, alias)
        self.add_from_clause(expr, from_clause)
        self._ctx.add_expr_compiled(expr, from_clause)
        if is_mem_cache:
            self._mem_ref_caches.add(source_data.expr_id)

    def _compile_select_field(self, field):
        """Render one SELECT-list entry for `field`, adding ``AS`` when needed.

        Non-column fields and columns whose type changed always get
        ``AS <quoted name>``. A ``%(col_N)s`` placeholder gets a
        ``###[col_N]##`` marker so the alias can be decided after the
        placeholder is filled; other columns get ``AS`` only if renamed.
        """
        compiled = self._ctx.get_expr_compiled(field)

        needs_alias = not isinstance(field, Column) or \
            field._source_data_type != field._data_type
        if needs_alias:
            return '{0} AS {1}'.format(compiled, self._quote(field.name))

        if compiled.startswith('%(') and compiled.endswith(')s'):
            return '{0}###[{1}]##'.format(compiled, compiled[2:-2])

        if field.name != self._unquote(compiled):
            return '{0} AS {1}'.format(compiled, self._quote(field.name))
        return compiled

    def _join_select_fields(self, fields):
        """Join compiled SELECT entries.

        Without beautify, entries are comma-separated. With beautify, each entry
        is wrapped as ``//{_iN}<entry>{_iN}//`` and the group as
        ``\\n/{_iN}...//`` (N being a fresh select id), then indented; these
        markers are expanded later into an aligned list.
        """
        if not self._beautify:
            return ', '.join(fields)

        select_id = self._ctx.next_select_id()
        tag = '{{_i{0}}}'.format(select_id)
        marked = ''.join('//{0}{1}{0}//'.format(tag, field) for field in fields)
        return utils.indent('\n/{0}{1}/'.format(tag, marked), self._indent_size)

    def _join_compiled_fields(self, fields):
        """Join compiled SELECT entries for output.

        Without beautify, entries are comma-separated on one line. With
        beautify, each entry is on its own line. For ``expr AS alias``
        entries, the expression is left-justified to the widest expression
        (measured on its last line). The result starts with a newline and
        is indented.
        """
        if not self._beautify:
            return ', '.join(fields)

        def last_line(text):
            return text.rsplit('\n', 1)[1] if '\n' in text else text

        pairs = [field.rsplit(' AS ', 1) for field in fields]
        width = max(len(last_line(pair[0])) for pair in pairs)

        rendered = []
        for pair in pairs:
            if len(pair) > 1:
                rendered.append(pair[0].ljust(width) + ' AS ' + pair[1])
            else:
                rendered.append(pair[0])

        return utils.indent('\n' + ',\n'.join(rendered), self._indent_size)

    def visit_project_collection(self, expr):
        """Compile a projection: its fields become the SELECT list."""
        select_list = self._join_select_fields(
            list(map(self._compile_select_field, expr._fields)))

        self._ctx.add_expr_compiled(expr, select_list)
        self.add_select_clause(expr, select_list)

    def visit_lateral_view(self, expr):
        """Use the pre-compiled FROM text of a lateral-view collection."""
        from_text = self._sub_compiles[expr][0]
        self._ctx.add_expr_compiled(expr, from_text)
        self.add_from_clause(expr, from_text)

    def visit_filter_collection(self, expr):
        """Compile a filter: its predicate (second argument) becomes the WHERE clause."""
        condition = self._ctx.get_expr_compiled(expr.args[1])
        self._ctx.add_expr_compiled(expr, condition)
        self.add_where_clause(expr, condition)

    def visit_filter_partition_collection(self, expr):
        """Compile a partition filter.

        Its predicate becomes the WHERE clause and its fields the SELECT list.
        """
        where_sql = self._ctx.get_expr_compiled(expr._predicate)
        self._ctx.add_expr_compiled(expr, where_sql)
        self.add_where_clause(expr, where_sql)

        select_sql = self._join_select_fields(
            [self._compile_select_field(f) for f in expr.fields])
        # NOTE(review): registers a second compiled text for the same expr,
        # presumably replacing the WHERE text registered above
        self._ctx.add_expr_compiled(expr, select_sql)
        self.add_select_clause(expr, select_sql)

    def visit_slice_collection(self, expr):
        """Compile a slice with only a stop value into a LIMIT.

        :raises NotImplementedError: if a start or a step is given.
        :raises CompileError: if the stop value is not positive.
        """
        sliced = expr._indexes
        if sliced[0] is not None or sliced[2] is not None:
            raise NotImplementedError

        limit = sliced[1].value
        if limit <= 0:
            raise CompileError('limit number must be greater than 0')
        self.set_limit(expr, limit)

    def visit_element_op(self, expr):
        """Compile an element-wise operator and register its SQL.

        Supports IS NULL / IS NOT NULL, fillna, IN / NOT IN (either a literal
        value list or a pre-compiled sub-query), if-else, switch
        (``CASE ... END``) and int-to-datetime.

        :raises NotImplementedError: for any other element operator.
        """
        compiled_of = self._ctx.get_expr_compiled

        if isinstance(expr, (element.IsNull, element.NotNull)):
            template = '{0} IS NULL' if isinstance(expr, element.IsNull) else '{0} IS NOT NULL'
            compiled = template.format(self._parenthesis(expr.input))
        elif isinstance(expr, element.FillNa):
            compiled = 'IF(%(input)s IS NULL, %(value)s, %(else_value)s)' % {
                'input': self._parenthesis(expr.input),
                'value': compiled_of(expr._fill_value),
                'else_value': compiled_of(expr.input),
            }
        elif isinstance(expr, (element.IsIn, element.NotIn)):
            template = '{0} IN ({1})' if isinstance(expr, element.IsIn) else '{0} NOT IN ({1})'
            if expr.values is None:
                # sub-query form: the compiled input and the compiled sub-query
                # were stored, in that order, when the node was pre-compiled
                subs = self._sub_compiles[expr]
                compiled = template.format(subs[0], subs[1].replace('\n', ''))
            else:
                compiled = template.format(
                    compiled_of(expr.input),
                    ', '.join(compiled_of(value) for value in expr.values))
        elif isinstance(expr, element.IfElse):
            compiled = 'IF({0}, {1}, {2})'.format(
                compiled_of(expr._input), compiled_of(expr._then), compiled_of(expr._else))
        elif isinstance(expr, element.Switch):
            case = compiled_of(expr.case) + ' ' if expr.case is not None else ''
            lines = ['CASE {0}'.format(case)]
            lines.extend('WHEN {0} THEN {1} '.format(compiled_of(cond), compiled_of(then))
                         for cond, then in zip(expr.conditions, expr.thens))
            if expr.default is not None:
                lines.append('ELSE {0} '.format(compiled_of(expr.default)))
            lines.append('END')
            if self._beautify:
                # indent the WHEN / ELSE lines between CASE and END
                lines[1:-1] = [utils.indent(line, self._indent_size) for line in lines[1:-1]]
                compiled = '\n'.join(lines)
            else:
                compiled = ''.join(lines)
        elif isinstance(expr, element.IntToDatetime):
            compiled = 'FROM_UNIXTIME({0})'.format(compiled_of(expr._input))
        else:
            raise NotImplementedError

        self._ctx.add_expr_compiled(expr, compiled)

    def _parenthesis(self, child):
        """Return the compiled SQL of `child`, parenthesised when needed.

        Binary operators, null checks, IN / NOT IN, between, switch and cut
        expressions are wrapped so they keep their grouping when embedded in a
        larger expression; anything else is returned as compiled.
        """
        wrap_types = (element.IsNull, element.NotNull, element.IsIn, element.NotIn,
                      element.Between, element.Switch, element.Cut)
        needs_parens = isinstance(child, BinOp) or isinstance(child, wrap_types)
        compiled = self._ctx.get_expr_compiled(child)
        return '(%s)' % compiled if needs_parens else compiled

    def visit_binary_op(self, expr):
        """Compile a binary operator expression and register its SQL.

        * string ``+`` becomes ``CONCAT(lhs, rhs)``;
        * datetime ``+`` / ``-`` with a date-part scalar operand becomes
          ``DATEADD(dt, value, 'part')``, with the value negated for
          subtraction (a scalar may appear on the left only for ``+``);
        * power becomes ``POW(lhs, rhs)``, cast back when the result type is
          not floating point;
        * other operators use ``BINARY_OP_COMPILE_DIC`` (upper-cased), with
          operands parenthesised via ``_parenthesis``.

        :raises CompileError: when adding two datetimes.
        :raises NotImplementedError: for unsupported operators, including
            datetime arithmetic without a usable date-part scalar operand.
        """
        if isinstance(expr, Add) and expr.dtype == df_types.string:
            compiled = 'CONCAT({0}, {1})'.format(
                self._ctx.get_expr_compiled(expr.lhs),
                self._ctx.get_expr_compiled(expr.rhs),
            )
        elif isinstance(expr, (Add, Substract)) and expr.dtype == df_types.datetime:
            if isinstance(expr, Add) and \
                    all(child.dtype == df_types.datetime for child in (expr.lhs, expr.rhs)):
                raise CompileError('Cannot add two datetimes')

            if isinstance(expr.rhs, DTScalar):
                dt, scalar = expr.lhs, expr.rhs
            elif isinstance(expr, Add) and isinstance(expr.lhs, DTScalar):
                # addition is commutative, so ``scalar + dt`` is ``dt + scalar``
                dt, scalar = expr.rhs, expr.lhs
            else:
                # no date-part scalar to offset by (``scalar - dt`` included)
                raise NotImplementedError

            # the scalar's type name minus its 6-char suffix (presumably
            # "Scalar") is the DATE_PARTS_DIC key, e.g. "Year"
            class_name = type(scalar).__name__[:-6]
            date_part = DATE_PARTS_DIC[class_name]
            val = scalar.value
            if isinstance(expr, Substract):
                val = -val
            compiled = 'DATEADD({0}, {1}, {2})'.format(
                self._ctx.get_expr_compiled(dt), repr(val), repr(date_part)
            )
        else:
            compiled, op = None, None
            try:
                op = BINARY_OP_COMPILE_DIC[expr.node_name].upper()
            except KeyError:
                if isinstance(expr, Power):
                    compiled = 'POW({0}, {1})'.format(
                        self._ctx.get_expr_compiled(expr.lhs),
                        self._ctx.get_expr_compiled(expr.rhs)
                    )
                    # POW yields a float; cast back to the expected result type
                    if not isinstance(expr.dtype, df_types.Float):
                        compiled = self._cast(compiled, df_types.float64, expr.dtype)
                else:
                    raise NotImplementedError

            if compiled is None:
                lhs, rhs = expr.args
                if op:
                    compiled = '{0} {1} {2}'.format(
                        self._parenthesis(lhs), op, self._parenthesis(rhs)
                    )
                else:
                    raise NotImplementedError

        self._ctx.add_expr_compiled(expr, compiled)

    def visit_unary_op(self, expr):
        """Compile a unary operator expression and register its SQL.

        * negate / invert of a boolean operand becomes ``NOT <operand>``;
        * operators in ``UNARY_OP_COMPILE_DIC`` are emitted as a prefix of the
          (parenthesised when needed) operand;
        * ``Abs`` becomes ``ABS(<operand>)``.

        The operator table is checked with an explicit membership test, so a
        KeyError raised while compiling the operand is no longer mistaken for
        an unsupported operator.

        :raises NotImplementedError: for any other unary operator.
        """
        if isinstance(expr, (Invert, Negate)) and expr.input.dtype == df_types.boolean:
            compiled = 'NOT {0}'.format(self._parenthesis(expr.input))
        elif expr.node_name in UNARY_OP_COMPILE_DIC:
            op = UNARY_OP_COMPILE_DIC[expr.node_name]
            compiled = '{0}{1}'.format(op, self._parenthesis(expr.input))
        elif isinstance(expr, Abs):
            compiled = 'ABS({0})'.format(self._ctx.get_expr_compiled(expr.input))
        else:
            raise NotImplementedError

        self._ctx.add_expr_compiled(expr, compiled)

    def visit_math(self, expr):
        """Compile a math function expression and register its SQL.

        Functions in ``MATH_COMPILE_DIC``, natural log, and trunc / round
        without decimals become ``FUNC(input)`` (upper-cased). Log with a
        base, log2, log10, log1p, expm1, and trunc / round with decimals are
        spelled out explicitly.

        :raises NotImplementedError: for unsupported math functions.
        """
        compiled = None
        try:
            op = MATH_COMPILE_DIC[expr.node_name]
        except KeyError:
            if expr.node_name == 'Log':
                if expr._base is None:
                    op = 'ln'
                else:
                    compiled = 'LOG({0}, {1})'.format(
                        self._ctx.get_expr_compiled(expr._base),
                        self._ctx.get_expr_compiled(expr.input)
                    )
            elif expr.node_name == 'Log2':
                compiled = 'LOG(2, {0})'.format(
                    self._ctx.get_expr_compiled(expr.input)
                )
            elif expr.node_name == 'Log10':
                compiled = 'LOG(10, {0})'.format(
                    self._ctx.get_expr_compiled(expr.input)
                )
            elif expr.node_name == 'Log1p':
                compiled = 'LN(1 + {0})'.format(
                    self._ctx.get_expr_compiled(expr.input)
                )
            elif expr.node_name == 'Expm1':
                # Parenthesised because _parenthesis only wraps BinOp and a few
                # element operators. Without the parentheses an enclosing
                # operator would bind to the trailing "1"
                # (e.g. "EXP(x) - 1 * 2").
                compiled = '(EXP({0}) - 1)'.format(
                    self._ctx.get_expr_compiled(expr.input)
                )
            elif expr.node_name == 'Trunc':
                if expr._decimals is None:
                    op = 'TRUNC'
                else:
                    compiled = 'TRUNC({0}, {1})'.format(
                        self._ctx.get_expr_compiled(expr.input),
                        self._ctx.get_expr_compiled(expr._decimals)
                    )
            elif expr.node_name == 'Round':
                if expr._decimals is None:
                    op = 'ROUND'
                else:
                    compiled = 'ROUND({0}, {1})'.format(
                        self._ctx.get_expr_compiled(expr.input),
                        self._ctx.get_expr_compiled(expr._decimals)
                    )
            else:
                raise NotImplementedError

        if compiled is None:
            compiled = '{0}({1})'.format(
                op.upper(), self._ctx.get_expr_compiled(expr.input))

        self._ctx.add_expr_compiled(expr, compiled)

    def visit_string_op(self, expr):
        """Compile a string accessor expression using ODPS internal functions.

        Python string indexes are 0-based and may be negative (counting from
        the end).  ODPS ``SUBSTR``/``INSTR`` positions are 1-based, and the
        ``Find`` branch already treats negative positions as counting from
        the end.  ``to_sql_pos`` therefore shifts non-negative indexes by one
        and passes negative ones through unchanged.  It is used for ``Find``,
        ``Get`` and ``Slice``.

        :raises NotImplementedError: for operations, or option combinations,
            that have no internal-function translation.
        """
        # FIXME quite a few operations cannot support by internal function
        compiled = None

        def to_sql_pos(pos):
            # 0-based (possibly negative) Python index -> ODPS position
            return pos + 1 if pos >= 0 else pos

        operand = self._ctx.get_expr_compiled(expr.input)
        if isinstance(expr, strings.Capitalize):
            compiled = 'CONCAT(TOUPPER(SUBSTR(%(input)s, 1, 1)), TOLOWER(SUBSTR(%(input)s, 2)))' % {
                'input': operand
            }
        elif isinstance(expr, strings.CatStr):
            nodes = [expr._input]
            if expr._others is not None:
                others = (expr._others, ) if not isinstance(expr._others, Iterable) else expr._others
                for other in others:
                    if expr._sep is not None:
                        nodes.extend([expr._sep, other])
                    else:
                        nodes.append(other)
            compiled = 'CONCAT(%s)' % ', '.join(self._ctx.get_expr_compiled(e) for e in nodes)
        elif isinstance(expr, strings.Contains):
            if expr.regex:
                raise NotImplementedError
            compiled = 'INSTR(%s, %s) > 0' % (operand, self._ctx.get_expr_compiled(expr._pat))
        elif isinstance(expr, strings.Endswith):
            # TODO: any better solution?
            compiled = 'INSTR(REVERSE(%s), REVERSE(%s)) == 1' % (
                operand, self._ctx.get_expr_compiled(expr._pat))
        elif isinstance(expr, strings.Startswith):
            compiled = 'INSTR(%s, %s) == 1' % (operand, self._ctx.get_expr_compiled(expr._pat))
        elif isinstance(expr, strings.Find):
            if isinstance(expr.start, six.integer_types):
                start = to_sql_pos(expr.start)
            else:
                # start is only known at runtime: do the same mapping in SQL
                start = 'IF(%(start)s >= 0, %(start)s + 1, %(start)s)' % {
                    'start': self._ctx.get_expr_compiled(expr._start)
                }
            if expr.end is not None:
                raise NotImplementedError
            else:
                # INSTR returns 0 when not found, so this yields -1 like str.find
                compiled = 'INSTR(%s, %s, %s) - 1' % (
                    operand, self._ctx.get_expr_compiled(expr._sub), start)
        elif isinstance(expr, strings.Get):
            # negative indexes count from the end, as in Python
            compiled = 'SUBSTR(%s, %s, 1)' % (operand, to_sql_pos(expr.index))
        elif isinstance(expr, strings.Len):
            compiled = 'LENGTH(%s)' % operand
        elif isinstance(expr, strings.Lower):
            compiled = 'TOLOWER(%s)' % operand
        elif isinstance(expr, strings.Upper):
            compiled = 'TOUPPER(%s)' % operand
        elif isinstance(expr, (strings.Lstrip, strings.Rstrip, strings.Strip)):
            if expr.to_strip != ' ':
                raise NotImplementedError
            func = {
                'Lstrip': 'LTRIM',
                'Rstrip': 'RTRIM',
                'Strip': 'TRIM'
            }
            compiled = '%s(%s)' % (func[type(expr).__name__], operand)
        elif isinstance(expr, strings.Slice):
            # internal function will be compiled in two cases:
            # 1) start is not None
            # 2) positive start and end
            start = expr.start if expr.start is not None else 0
            if expr.end is None and expr.step is None:
                compiled = 'SUBSTR(%s, %s)' % (operand, to_sql_pos(start))
            else:
                end = expr.end
                if (start < 0) != (end < 0):
                    # the length of a mixed-sign slice depends on the string
                    # length and cannot be emitted as a constant
                    raise NotImplementedError
                compiled = 'SUBSTR(%s, %s, %s)' % (operand, to_sql_pos(start), end - start)
        elif isinstance(expr, strings.Repeat):
            compiled = 'REPEAT(%s, %s)' % (
                operand, self._ctx.get_expr_compiled(expr._repeats))
        elif isinstance(expr, strings.Split):
            if expr.n != -1:
                raise NotImplementedError

            escape_pat = re.escape(expr.pat)
            nre_compiled = 'SPLIT(%s, \'%s\')' % (operand, utils.escape_odps_string(expr.pat))
            re_compiled = 'SPLIT(%s, \'%s\')' % (operand, utils.escape_odps_string(re.escape(expr.pat)))

            try:
                re.compile(expr.pat)
                is_regex = True
            except Exception:
                # any compile failure (re.error, overflow of repeat counts,
                # ...) means the pattern can only be taken literally
                is_regex = False

            if expr.pat == escape_pat or not is_regex:
                compiled = nre_compiled
            else:
                # try the literal (escaped) split first and fall back to the
                # raw pattern when it yields nothing
                compiled = 'IF(SIZE(%(re)s) = 0, %(nre)s, %(re)s)' % dict(re=re_compiled, nre=nre_compiled)
        elif isinstance(expr, strings.StringToDict):
            # escape delimiters the same way Split escapes its pattern so a
            # quote or backslash cannot break the SQL string literal
            compiled = 'STR_TO_MAP(%s, \'%s\', \'%s\')' % (
                operand, utils.escape_odps_string(expr.item_delim),
                utils.escape_odps_string(expr.kv_delim))

        if compiled is not None:
            self._ctx.add_expr_compiled(expr, compiled)
        else:
            raise NotImplementedError

    def visit_datetime_op(self, expr):
        """Compile a datetime accessor expression into an ODPS SQL call.

        Date-part extractions named in ``DATE_PARTS_DIC`` become ``DATEPART``
        calls; week-of-year, weekday, date truncation and Unix timestamps
        use dedicated functions.

        :raises NotImplementedError: for unsupported datetime operations.
        """
        # FIXME quite a few operations cannot support by internal function
        class_name = type(expr).__name__
        operand = self._ctx.get_expr_compiled(expr.input)

        if class_name in DATE_PARTS_DIC:
            compiled = 'DATEPART(%s, %r)' % (operand, DATE_PARTS_DIC[class_name])
        else:
            # checked in order; the first matching class wins
            renderers = [
                (datetimes.WeekOfYear, lambda s: 'WEEKOFYEAR(%s)' % s),
                (datetimes.WeekDay, lambda s: 'WEEKDAY(%s)' % s),
                (datetimes.Date, lambda s: 'DATETRUNC(%s, %r)' % (s, 'dd')),
                (datetimes.UnixTimestamp, lambda s: 'UNIX_TIMESTAMP(%s)' % s),
            ]
            for klass, render in renderers:
                if isinstance(expr, klass):
                    compiled = render(operand)
                    break
            else:
                raise NotImplementedError

        self._ctx.add_expr_compiled(expr, compiled)

    def visit_composite_op(self, expr):
        """Compile a list/dict (composite type) expression into ODPS SQL.

        :raises NotImplementedError: for unsupported composite operations.
        """
        get = self._ctx.get_expr_compiled
        operand = get(expr.input) if getattr(expr, 'input', None) is not None else None

        if isinstance(expr, composites.ListDictLength):
            compiled = 'SIZE(%s)' % operand
        elif isinstance(expr, composites.ListDictGetItem):
            compiled = '%s[%s]' % (operand, get(expr._key))
        elif isinstance(expr, composites.ListSort):
            compiled = 'SORT_ARRAY(%s)' % operand
        elif isinstance(expr, composites.ListContains):
            compiled = 'ARRAY_CONTAINS(%s, %s)' % (operand, get(expr._value))
        elif isinstance(expr, composites.DictKeys):
            compiled = 'MAP_KEYS(%s)' % operand
        elif isinstance(expr, composites.DictValues):
            compiled = 'MAP_VALUES(%s)' % operand
        elif isinstance(expr, composites.ListBuilder):
            items = [get(value) for value in expr._values]
            compiled = 'ARRAY(%s)' % ', '.join(items)
        elif isinstance(expr, composites.DictBuilder):
            # MAP(k1, v1, k2, v2, ...)
            pairs = ['%s, %s' % (get(key), get(value))
                     for key, value in zip(expr._keys, expr._values)]
            compiled = 'MAP(%s)' % ', '.join(pairs)
        else:
            raise NotImplementedError

        self._ctx.add_expr_compiled(expr, compiled)

    def visit_groupby(self, expr):
        """Register SELECT, GROUP BY and (optional) HAVING clauses for a group-by.

        When no explicit output fields are given, the group keys followed by
        the aggregations are selected.
        """
        bys, having, aggs, fields = expr.args[1:]
        targets = bys + aggs if fields is None else fields

        group_keys = [self._ctx.get_expr_compiled(by) for by in bys]
        group_by_clause = 'GROUP BY {0} '.format(self._join_compiled_fields(group_keys))

        select_clause = self._join_select_fields(
            [self._compile_select_field(target) for target in targets])

        self.add_select_clause(expr, select_clause)
        self.add_group_by_clause(expr, group_by_clause)

        if having:
            self.add_having_clause(expr, self._ctx.get_expr_compiled(having))

    def visit_mutate(self, expr):
        """Register the SELECT clause of a grouped mutate (window) projection.

        Without explicit fields, the group keys followed by the mutated
        columns are selected.
        """
        bys, mutates, fields = expr.args[1:]
        targets = fields if fields is not None else bys + mutates

        compiled_targets = [self._compile_select_field(target) for target in targets]
        self.add_select_clause(expr, self._join_select_fields(compiled_targets))

    def visit_sort_column(self, expr):
        """Compile one sort key, suffixing ``DESC`` for descending order.

        Keys that sort a collection or a plain column are referenced by their
        source name; anything else uses the compiled input expression.
        """
        source = expr.input
        if isinstance(source, CollectionExpr):
            key = expr._source_name
        elif isinstance(source, Column):
            key = source.source_name
        else:
            key = self._ctx.get_expr_compiled(source)

        if not expr._ascending:
            key = '{0} DESC'.format(key)
        self._ctx.add_expr_compiled(expr, key)

    def visit_sort(self, expr):
        """Register the ORDER BY clause built from the already-compiled sort keys."""
        compiled_keys = [self._ctx.get_expr_compiled(key) for key in expr.args[1]]
        joined = self._join_compiled_fields(compiled_keys)
        self.add_order_by_clause(expr, 'ORDER BY {0} '.format(joined))

    def visit_sample(self, expr):
        """Register a TABLESAMPLE clause for a fraction- or row-count sample.

        Only samples without parts and without replacement are supported
        (enforced by the assertion).  A fraction is rendered as a whole
        percentage (truncated); when neither fraction nor count is set,
        nothing is registered.
        """
        assert expr._parts is None and not expr._replace.value
        frac, n = expr._frac, expr._n
        if frac is not None:
            spec = "({0} PERCENT)".format(int(frac.value * 100))
        elif n is not None:
            spec = "({0} ROWS)".format(n.value)
        else:
            return
        self.add_table_sample_clause(expr, spec)

    def visit_distinct(self, expr):
        """Register a ``SELECT DISTINCT`` clause over the distinct fields."""
        compiled_fields = [self._compile_select_field(field) for field in expr.args[1]]
        columns = self._join_select_fields(compiled_fields)
        self.add_select_clause(expr, 'DISTINCT {0}'.format(columns))

    def visit_reduction(self, expr):
        """Compile a (possibly grouped) reduction into an ODPS SQL aggregate.

        Most reductions become ``NODE_NAME(arg)``, with ``DISTINCT`` when the
        reduction is unique.  The branches below special-case reductions or
        input types whose SQL spelling differs.

        Quantile probabilities may be a single value or a list/tuple/set;
        several probabilities are rendered as an ``ARRAY(...)`` literal.

        :raises CompileError: for ``Std`` with a ``ddof`` other than 0 or 1.
        """
        is_unique = getattr(expr, '_unique', False)

        if isinstance(expr, (Count, GroupedCount)) and isinstance(expr.input, CollectionExpr):
            # counting rows of a collection, not values of a column
            compiled = 'COUNT(1)'
            self._ctx.add_expr_compiled(expr, compiled)
            return

        if isinstance(expr, (Std, GroupedStd)):
            if expr._ddof not in (0, 1):
                raise CompileError('Does not support %s with ddof=%s' % (
                    expr.node_name, expr._ddof))

        compiled = None

        if isinstance(expr, (Mean, GroupedMean)):
            node_name = 'avg'
        elif isinstance(expr, (Std, GroupedStd)):
            # ddof=0 -> population, ddof=1 -> sample standard deviation
            node_name = 'stddev' if expr._ddof == 0 else 'stddev_samp'
        elif isinstance(expr, (Sum, GroupedSum)) and expr.input.dtype == df_types.string:
            # summing strings concatenates them without separator
            if is_unique:
                compiled = 'WM_CONCAT(DISTINCT \'\', %s)' % self._ctx.get_expr_compiled(expr.input)
            else:
                compiled = 'WM_CONCAT(\'\', %s)' % self._ctx.get_expr_compiled(expr.input)
        elif isinstance(expr, (Sum, GroupedSum)) and expr.input.dtype == df_types.boolean:
            # booleans are summed as 0/1 integers
            if is_unique:
                compiled = 'SUM(DISTINCT IF(%s, 1, 0))' % self._ctx.get_expr_compiled(expr.input)
            else:
                compiled = 'SUM(IF(%s, 1, 0))' % self._ctx.get_expr_compiled(expr.input)
        elif isinstance(expr, (Max, GroupedMax, Min, GroupedMin)) and \
                expr.input.dtype == df_types.boolean:
            compiled = '%s(IF(%s, 1, 0)) == 1' % (
                expr.node_name, self._ctx.get_expr_compiled(expr.input))
        elif isinstance(expr, (Any, GroupedAny)):
            compiled = 'MAX(IF(%s, 1, 0)) == 1' % self._ctx.get_expr_compiled(expr.args[0])
        elif isinstance(expr, (All, GroupedAll)):
            compiled = 'MIN(IF(%s, 1, 0)) == 1' % self._ctx.get_expr_compiled(expr.args[0])
        elif isinstance(expr, (NUnique, GroupedNUnique)):
            compiled = 'COUNT(DISTINCT %s)' % ', '.join(
                self._ctx.get_expr_compiled(c) for c in expr.inputs)
        elif isinstance(expr, (Cat, GroupedCat)):
            if is_unique:
                compiled = 'WM_CONCAT(DISTINCT %s, %s)' % (self._ctx.get_expr_compiled(expr._sep),
                                                           self._ctx.get_expr_compiled(expr.input))
            else:
                compiled = 'WM_CONCAT(%s, %s)' % (self._ctx.get_expr_compiled(expr._sep),
                                                  self._ctx.get_expr_compiled(expr.input))
        elif isinstance(expr, (Quantile, GroupedQuantile)):
            # a tuple must be rendered as ARRAY(...) too; formatting it with
            # %s would emit a Python tuple literal, which is not valid SQL
            if isinstance(expr._prob, (list, tuple, set)):
                probs_expr = 'ARRAY(' + ', '.join(str(p) for p in expr._prob) + ')'
            else:
                probs_expr = expr._prob

            if expr.input.data_type in (df_types.float32, df_types.float64) and types.get_local_use_odps2_types():
                func_name = 'PERCENTILE_APPROX'
            else:
                func_name = 'PERCENTILE'

            if is_unique:
                compiled = '%s(DISTINCT %s, %s)' % (func_name, self._ctx.get_expr_compiled(expr.input), probs_expr)
            else:
                compiled = '%s(%s, %s)' % (func_name, self._ctx.get_expr_compiled(expr.input), probs_expr)
        elif isinstance(expr, (ToList, GroupedToList)):
            func_name = 'COLLECT_SET' if expr._unique else 'COLLECT_LIST'
            compiled = '%s(%s)' % (func_name, self._ctx.get_expr_compiled(expr.input))
        else:
            node_name = expr.node_name

        if compiled is None:
            # generic NODE_NAME([DISTINCT] arg) form
            if is_unique:
                compiled = '{0}(DISTINCT {1})'.format(
                    node_name.upper(), self._ctx.get_expr_compiled(expr.args[0]))
            else:
                compiled = '{0}({1})'.format(
                    node_name.upper(), self._ctx.get_expr_compiled(expr.args[0]))

        self._ctx.add_expr_compiled(expr, compiled)

    def visit_user_defined_aggregator(self, expr):
        """Compile a call to a user-defined aggregate function.

        The aggregator may be a function name, an ODPS ``Function`` object,
        or any other object that the context registers via ``get_udf``.
        Extra ``_func_args`` are appended as ``repr`` literals only when the
        function did not come from ``get_udf``.
        """
        aggregator = expr._aggregator
        via_udf = False
        if isinstance(aggregator, six.string_types):
            func_name = aggregator
        elif isinstance(aggregator, Function):
            func_name = aggregator.name
        else:
            via_udf = True
            func_name = self._ctx.get_udf(aggregator)

        call_args = [self._ctx.get_expr_compiled(arg) for arg in expr.inputs]
        extra_args = getattr(expr, '_func_args', None)
        if extra_args is not None and not via_udf:
            call_args.extend(repr(arg) for arg in extra_args)

        distinct = 'DISTINCT ' if getattr(expr, '_unique', False) else ''
        compiled = '{0}({1}{2})'.format(func_name, distinct, ', '.join(call_args))
        self._ctx.add_expr_compiled(expr, compiled)

    def visit_column(self, expr):
        """Compile a column reference.

        If the owning collection already has an alias, the column is written
        as ``alias.name``.  Otherwise a ``%(symbol)s`` placeholder is
        registered to be filled in later.  A ``CAST`` is added when the
        column's type differs from its source type.
        """
        aliases = self._ctx.get_collection_alias(expr.input, silent=True)
        if aliases:
            compiled = '{0}.{1}'.format(aliases[0], self._quote(expr.source_name))
        else:
            placeholder = self._ctx.add_need_alias_column(expr)
            compiled = '%({0})s'.format(placeholder)

        if expr._source_data_type != expr._data_type:
            target = types.df_type_to_odps_type(expr.dtype)
            compiled = 'CAST({0} AS {1})'.format(compiled, target)

        self._ctx.add_expr_compiled(expr, compiled)

    def visit_function(self, expr):
        """Compile a mapped (UDF / named function) or builtin ``Func`` call.

        The function name is resolved first, possibly registering a UDF via
        ``get_udf``.  Extra ``_func_args`` are appended as ``repr`` literals
        unless the function came from ``get_udf``.

        :raises NotImplementedError: if ``expr`` is neither a ``MappedExpr``
            nor a ``Func``.
        """
        via_udf = False
        if isinstance(expr, Func):
            func_name = expr._func_name
        elif isinstance(expr._func, six.string_types):
            func_name = expr._func
        elif isinstance(expr._func, Function):
            func_name = expr._func.name
        else:
            via_udf = True
            func_name = self._ctx.get_udf(expr._func)

        if not isinstance(expr, (MappedExpr, Func)):
            raise NotImplementedError

        call_args = [self._ctx.get_expr_compiled(arg) for arg in expr.inputs]
        extra_args = getattr(expr, '_func_args', None)
        if extra_args is not None and not via_udf:
            call_args.extend(repr(arg) for arg in extra_args)

        self._ctx.add_expr_compiled(expr, '{0}({1})'.format(func_name, ', '.join(call_args)))

    def visit_builtin_function(self, expr):
        """Compile a builtin function call whose arguments are Python literals."""
        rendered_args = ', '.join(map(repr, expr._func_args))
        self._ctx.add_expr_compiled(expr, '{0}({1})'.format(expr._func_name, rendered_args))

    def _quote(self, compiled):
        """Wrap an identifier in backticks when ``options.df.quote`` is enabled."""
        return '`{0}`'.format(compiled) if options.df.quote else compiled

    def _unquote(self, compiled):
        """Extract a bare column name from a compiled column reference.

        With quoting on, the first backtick-quoted name is returned.  With
        quoting off, the word after the first ``.`` is returned.  If nothing
        matches, ``compiled`` is returned unchanged.
        """
        pattern = r'`([^`]+)`' if options.df.quote else r'\.(\w+)'
        found = re.search(pattern, compiled)
        return found.group(1) if found else compiled

    def visit_reshuffle(self, expr):
        """Register DISTRIBUTE BY (and optional SORT BY) clauses for a reshuffle.

        The distribution keys are unquoted bare column names.  The clauses
        are stored in the group-by and order-by slots respectively.
        """
        sort_keys = expr._sort_fields

        dist_keys = [self._unquote(self._ctx.get_expr_compiled(by)) for by in expr._by]
        self.add_group_by_clause(
            expr, 'DISTRIBUTE BY {0} '.format(self._join_compiled_fields(dist_keys)))

        if not sort_keys:
            return
        compiled_sorts = [self._ctx.get_expr_compiled(key) for key in sort_keys]
        self.add_order_by_clause(
            expr, 'SORT BY {0}'.format(self._join_compiled_fields(compiled_sorts)))

    def visit_apply_collection(self, expr):
        """Compile a table-valued function (UDTF) application.

        In lateral-view mode the result is a ``LATERAL VIEW [OUTER]`` clause
        with a fresh collection alias; ``OUTER`` is used when ``_keep_nulls``
        is set.  Otherwise the ``func(...) AS (cols)`` expression is
        registered as the SELECT clause.
        """
        func = expr._func
        via_udf = False
        if isinstance(func, six.string_types):
            func_name = func
        elif isinstance(func, Function):
            func_name = func.name
        else:
            via_udf = True
            func_name = self._ctx.get_udf(func)

        call_args = [self._ctx.get_expr_compiled(field) for field in expr._fields]
        extra_args = getattr(expr, '_func_args', None)
        if extra_args is not None and not via_udf:
            call_args.extend(repr(arg) for arg in extra_args)

        func_call = '{0}({1})'.format(func_name, ', '.join(call_args))
        columns = ', '.join(self._quote(name) for name in expr.schema.names)

        if expr._lateral_view:
            prefix = 'LATERAL VIEW OUTER ' if expr._keep_nulls else 'LATERAL VIEW '
            alias = self._ctx.get_collection_alias(expr, create=True)[0]
            compiled = prefix + '{0} {1} AS {2}'.format(func_call, alias, columns)
        else:
            compiled = '{0} AS ({1})'.format(func_call, columns)
            self.add_select_clause(expr, compiled)

        self._ctx.add_expr_compiled(expr, compiled)

    def _wrap_typed(self, expr, compiled):
        if expr._source_data_type != expr._data_type:
            compiled = 'cast({0} AS {1})'.format(
                compiled, types.df_type_to_odps_type(expr._data_type))

        return compiled

    def visit_sequence(self, expr):
        """Compile a sequence by its source name, casting if its type was changed."""
        self._ctx.add_expr_compiled(expr, self._wrap_typed(expr, expr._source_name))

    def _compile_window_order_by(self, expr):
        """Render one window ORDER BY key.

        A sort key over a sequence is rendered inline, with ``DESC`` when
        descending; any other key uses its own compiled form.
        """
        if not isinstance(expr.input, SequenceExpr):
            return self._ctx.get_expr_compiled(expr)
        key = self._ctx.get_expr_compiled(expr.input)
        return key if expr._ascending else '%s DESC' % key

    def _compile_window_function(self, func, args, partition_by=None,
                                 order_by=None, preceding=None, following=None):
        partition_by = 'PARTITION BY {0}'.format(partition_by or '1')
        order_by = 'ORDER BY {0}'.format(order_by) if order_by is not None else ''

        if isinstance(preceding, tuple):
            window_clause = 'ROWS BETWEEN {0} PRECEDING AND {1} PRECEDING' \
                .format(*[self._ctx.get_expr_compiled(p) for p in preceding])
        elif isinstance(following, tuple):
            window_clause = 'ROWS BETWEEN {0} FOLLOWING AND {1} FOLLOWING' \
                .format(*[self._ctx.get_expr_compiled(f) for f in following])
        elif preceding is not None and following is not None:
            window_clause = 'ROWS BETWEEN {0} PRECEDING AND {1} FOLLOWING' \
                .format(self._ctx.get_expr_compiled(preceding), self._ctx.get_expr_compiled(following))
        elif preceding is not None:
            window_clause = 'ROWS {0} PRECEDING'.format(self._ctx.get_expr_compiled(preceding))
        elif following is not None:
            window_clause = 'ROWS {0} FOLLOWING'.format(self._ctx.get_expr_compiled(following))
        else:
            window_clause = ''

        over = ' '.join(sub for sub in (partition_by, order_by, window_clause)
                        if len(sub) > 0)

        return '{0}({1}) OVER ({2})'.format(func, args, over)

    def visit_cum_window(self, expr):
        """Compile a cumulative window function (cumsum, cummax, nth_value, ...).

        Boolean inputs of ``CumSum`` are summed as 0/1.  ``NthValue`` takes a
        1-based position and an optional ``true`` flag for skipping nulls.
        """
        args = self._ctx.get_expr_compiled(expr.input)
        if isinstance(expr, CumSum) and expr.input.dtype == df_types.boolean:
            args = 'IF({0}, 1, 0)'.format(args)
        elif isinstance(expr, NthValue):
            args = '{0}, {1}{2}'.format(args, expr._nth + 1,
                                        ', true' if expr._skip_nulls else '')
        if expr.distinct:
            args = 'DISTINCT {0}'.format(args)

        partition_by = None
        if expr._partition_by:
            partition_by = ', '.join(self._ctx.get_expr_compiled(by) for by in expr._partition_by)
        order_by = None
        if expr._order_by:
            order_by = ', '.join(self._compile_window_order_by(by) for by in expr._order_by)

        func_name = WINDOW_COMPILE_DIC[expr.node_name].upper()
        self._ctx.add_expr_compiled(expr, self._compile_window_function(
            func_name, args, partition_by=partition_by, order_by=order_by,
            preceding=expr._preceding, following=expr._following))

    def visit_rank_window(self, expr):
        """Compile a ranking window function; ``QCut`` passes its bin count as the argument."""
        func_name = WINDOW_COMPILE_DIC[expr.node_name].upper()
        args = str(expr._bins) if isinstance(expr, QCut) else ''

        partition_by = None
        if expr._partition_by:
            partition_by = ', '.join(self._ctx.get_expr_compiled(by) for by in expr._partition_by)
        order_by = None
        if expr._order_by:
            order_by = ', '.join(self._compile_window_order_by(by) for by in expr._order_by)

        self._ctx.add_expr_compiled(expr, self._compile_window_function(
            func_name, args, partition_by=partition_by, order_by=order_by))

    def visit_shift_window(self, expr):
        """Compile a shift window function (``LAG`` / ``LEAD``).

        Arguments are positional: ``(input[, offset[, default]])``.  When a
        default is given without an explicit offset, an offset of ``1`` is
        emitted so the default is not passed in the offset position.  A
        call without an offset, e.g. ``LAG(x)``, presumably also means 1 —
        confirm against the ODPS LAG/LEAD reference.
        """
        func_name = WINDOW_COMPILE_DIC[expr.node_name].upper()

        compiled_fields = [self._ctx.get_expr_compiled(expr.input), ]
        if expr._offset:
            compiled_fields.append(self._ctx.get_expr_compiled(expr._offset))
        if expr._default:
            if not expr._offset:
                # keep the default in the third argument slot
                compiled_fields.append('1')
            compiled_fields.append(self._ctx.get_expr_compiled(expr._default))

        col_compiled = self._join_compiled_fields(compiled_fields)

        partition_by = ', '.join(self._ctx.get_expr_compiled(by)
                                 for by in expr._partition_by) if expr._partition_by else None
        order_by = ', '.join(self._compile_window_order_by(by)
                             for by in expr._order_by) if expr._order_by else None

        compiled = self._compile_window_function(func_name, col_compiled, partition_by=partition_by,
                                                 order_by=order_by)

        self._ctx.add_expr_compiled(expr, compiled)

    def visit_scalar(self, expr):
        """Compile a literal scalar into an ODPS SQL literal.

        * ``None`` becomes a typed ``CAST(NULL AS ...)``.
        * Strings are single-quoted, with backslashes and quotes escaped.
        * Integers get ODPS 2.0 suffixes (``Y``/``S``/``L`` for int8/int16/
          int64) when ODPS 2.0 types are enabled.
        * Booleans, datetimes, dates and decimals get dedicated spellings.
        * Anything else falls back to ``repr()`` of the value.
        """
        compiled = None
        if expr._value is not None:
            if expr.dtype == df_types.string:
                val = utils.to_str(expr.value) \
                    if isinstance(expr.value, six.text_type) else expr.value
                # Escape backslashes before quotes.  Otherwise a backslash in
                # the value starts an escape sequence inside the SQL literal,
                # and a trailing one even escapes the closing quote.
                val = val.replace('\\', '\\\\').replace("'", "\\'")
                compiled = "'{0}'".format(val)
            elif isinstance(expr.dtype, df_types.Integer) and types.get_local_use_odps2_types():
                # Use format()/str() rather than repr(): repr() of a Python 2
                # long already ends with 'L' ('5L' -> '5LL'), and repr() of
                # non-builtin integer scalars may not be a bare number.
                if expr.dtype == df_types.int8:
                    compiled = "{0}Y".format(expr._value)
                elif expr.dtype == df_types.int16:
                    compiled = "{0}S".format(expr._value)
                elif expr.dtype == df_types.int64:
                    compiled = "{0}L".format(expr._value)
            elif isinstance(expr._value, bool):
                compiled = 'true' if expr._value else 'false'
            elif isinstance(expr._value, datetime):
                # FIXME: just ignore shorter than second
                compiled = 'FROM_UNIXTIME({0})'.format(utils.to_timestamp(expr._value))
            elif isinstance(expr._value, date):
                compiled = 'CAST({0!r} AS DATE)'.format(expr._value.strftime("%Y-%m-%d"))
            elif isinstance(expr._value, Decimal):
                compiled = 'CAST({0} AS DECIMAL)'.format(repr(str(expr._value)))
        else:
            compiled = 'CAST(NULL AS {0})'.format(types.df_type_to_odps_type(expr._value_type))

        if compiled is None:
            compiled = repr(expr._value)
        self._ctx.add_expr_compiled(expr, compiled)

    @classmethod
    def _cast(cls, compiled, source_type, to_type):
        """Wrap ``compiled`` in an explicit ``CAST`` to ``to_type``.

        :raises CompileError: if ODPS cannot explicitly cast from the source
            type to the target type.
        """
        src_odps = types.df_type_to_odps_type(source_type)
        dst_odps = types.df_type_to_odps_type(to_type)

        if dst_odps.can_explicit_cast(src_odps):
            return 'CAST({0} AS {1})'.format(compiled, dst_odps)
        raise CompileError(
                'Cannot cast from %s to %s' % (src_odps, dst_odps))

    def visit_cast(self, expr):
        """Compile a type cast.

        An integer cast to datetime is treated as a Unix timestamp
        (``FROM_UNIXTIME``).  Other casts emit ``CAST`` only when the types
        actually differ.
        """
        source_type, target_type = expr.source_type, expr.dtype

        if isinstance(source_type, df_types.Integer) and target_type == df_types.datetime:
            compiled = 'FROM_UNIXTIME({0})'.format(self._ctx.get_expr_compiled(expr.input))
        else:
            compiled = self._ctx.get_expr_compiled(expr._input)
            if target_type != source_type:
                compiled = self._cast(compiled, source_type, target_type)

        self._ctx.add_expr_compiled(expr, compiled)

    def visit_join(self, expr):
        """Build the FROM clause for a join of two already-compiled sub-queries.

        The ``ON`` predicate is appended only when one was compiled.
        """
        left, right, predicate = self._sub_compiles[expr]

        from_clause = '{0} \n{1} JOIN \n{2}'.format(
            left, expr._how, utils.indent(right, self._indent_size))
        if predicate:
            from_clause = '{0}\nON {1}'.format(from_clause, predicate)

        self.add_from_clause(expr, from_clause)
        self._ctx.add_expr_compiled(expr, from_clause)

    def visit_union(self, expr):
        """Build the FROM clause for a union of two already-compiled sub-queries.

        ``UNION ALL`` is used unless ``expr._distinct`` is set, in which case
        plain ``UNION`` (de-duplicating) is emitted.  Unless this union is
        flagged in ``self._union_no_alias``, the whole union is wrapped in
        parentheses and given a newly created collection alias.
        """
        union_type = 'UNION ALL' if not expr._distinct else 'UNION'
        left_compiled, right_compiled = tuple(self._sub_compiles[expr])

        from_clause = '{0} \n{1}\n{2}'.format(left_compiled, union_type,
                                              utils.indent(right_compiled, self._indent_size))

        compiled = from_clause
        if not self._union_no_alias.get(expr, False):
            # wrap as an aliased sub-query so it can be selected from
            compiled = '(\n{0}\n) {1}'.format(utils.indent(from_clause, self._indent_size),
                                              self._ctx.get_collection_alias(expr, create=True)[0])

        self.add_from_clause(expr, compiled)
        self._ctx.add_expr_compiled(expr, compiled)