# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""UDF runners implementing the local run framework."""

import csv
import re
import sys

from ... import distcache
from ... import types as odps_types
from ... import udf
from ...utils import split_backquoted, to_date, to_milliseconds
from . import utils

__all__ = ["get_csv_runner", "get_table_runner"]

# True when running under Python 2; _convert_value only performs its extra
# per-type conversions in that case.
PY2 = sys.version_info[0] == 2
# Matches a table-description segment shaped like ``name(args)`` (non-empty
# name and non-empty args), e.g. the ``p(...)`` / ``c(...)`` segments handled
# by _get_table_or_partition. Used with .match(), so it is anchored at the
# start of the segment only.
_table_bracket_re = re.compile(r"[^\(]+\([^\)]+\)")


def get_csv_runner(
    udf_class,
    input_col_delim=",",
    null_indicator="NULL",
    stdin=None,
    collector_cls=None,
):
    """Build a runner that feeds delimited text rows to ``udf_class``.

    :param udf_class: annotated UDF / UDTF / UDAF class to execute
    :param input_col_delim: column delimiter used to parse each input line
    :param null_indicator: cell text converted to ``None`` before validation
    :param stdin: file-like object to read from; ``sys.stdin`` when falsy
    :param collector_cls: collector class instantiated with the output types;
        ``StdoutCollector`` when falsy
    :return: a runner instance chosen by ``_get_runner_class``
    """
    input_types, output_types = parse_proto(udf.get_annotation(udf_class))
    source = stdin if stdin else sys.stdin
    rows = ArgParser(input_types, source, input_col_delim, null_indicator).parse()

    sink_cls = collector_cls if collector_cls else StdoutCollector
    sink = sink_cls(output_types)
    runner_cls = _get_runner_class(udf_class)
    return runner_cls(udf_class, rows, sink)


def get_table_runner(
    udf_class, odps_entry, table_desc, record_limit=None, collector_cls=None
):
    """Build a runner that feeds rows of an ODPS table to ``udf_class``.

    :param udf_class: annotated UDF / UDTF / UDAF class to execute
    :param odps_entry: entry object whose ``get_table`` resolves ``table_desc``
    :param table_desc: table description, optionally with ``p(...)`` partition
        and ``c(...)`` column segments
    :param record_limit: maximum number of records to read, or ``None`` for all
    :param collector_cls: collector class instantiated with the output types;
        ``StdoutCollector`` when falsy
    :return: a runner instance chosen by ``_get_runner_class``
    """
    input_types, output_types = parse_proto(udf.get_annotation(udf_class))
    rows = table_feed(odps_entry, table_desc, input_types, record_limit)

    sink_cls = collector_cls if collector_cls else StdoutCollector
    sink = sink_cls(output_types)
    runner_cls = _get_runner_class(udf_class)
    return runner_cls(udf_class, rows, sink)


def simple_run(udf_class, args):
    """Run ``udf_class`` in memory over ``args`` and return the collected output.

    The annotation is still parsed (and so validated) even though only the
    output types are used here; inputs are passed through ``direct_feed``
    without per-value type validation.

    :param udf_class: annotated UDF / UDTF / UDAF class to execute
    :param args: iterable of argument sequences, one per input row
    :return: ``DirectCollector.results`` after the run
    """
    _, output_types = parse_proto(udf.get_annotation(udf_class))
    sink = DirectCollector(output_types)
    runner_cls = _get_runner_class(udf_class)
    runner_cls(udf_class, direct_feed(args), sink).run()
    return sink.results


def initialize():
    """Install local-run replacements into the ``distcache`` module.

    Currently this only rebinds ``distcache.get_cache_table`` to
    ``utils.get_cache_table``.
    """
    local_impl = utils.get_cache_table
    distcache.get_cache_table = local_impl


def _split_data_types(types_str):
    bracket_level = 0
    ret_types = [""]
    for ch in types_str:
        if bracket_level == 0 and ch == ",":
            ret_types[-1] = ret_types[-1].strip()
            ret_types.append("")
        else:
            ret_types[-1] += ch
            if ch in ("<", "("):
                bracket_level += 1
            elif ch in (">", ")"):
                bracket_level -= 1
    return [s for s in ret_types if s]


def _get_types(types_str):
    """Parse a comma-separated type list into validated ODPS data types.

    :param types_str: text such as ``"bigint,string"``
    :return: list of objects returned by ``odps_types.validate_data_type``
    """
    return [
        odps_types.validate_data_type(type_name.strip())
        for type_name in _split_data_types(types_str)
    ]


def _get_in_types(types):
    if types == "":
        return []
    return _get_types(types) if types != "*" else ["*"]


def _get_runner_class(udf_class):
    """Choose the runner class for ``udf_class`` from its base classes.

    A ``udf.BaseUDAF`` subclass gets ``UDAFRunner``, otherwise a
    ``udf.BaseUDTF`` subclass gets ``UDTFRunner``; anything else is run with
    ``UDFRunner``.
    """
    mro = udf_class.__mro__
    if udf.BaseUDAF in mro:
        return UDAFRunner
    if udf.BaseUDTF in mro:
        return UDTFRunner
    return UDFRunner


def parse_proto(proto):
    """Split an ``"<inputs> -> <outputs>"`` annotation into type lists.

    The text is lower-cased before parsing.

    :param proto: annotation string, e.g. ``"bigint,string->string"``
    :return: tuple ``(input_types, output_types)``
    :raises ValueError: if ``proto`` does not contain exactly one ``->``
    """
    sides = proto.lower().split("->")
    if len(sides) == 2:
        in_part, out_part = sides[0].strip(), sides[1].strip()
        return _get_in_types(in_part), _get_types(out_part)
    raise ValueError("Illegal format of @annotate(%s)" % proto)


def direct_feed(args):
    """Lazily yield each item of ``args`` unchanged."""
    iterator = iter(args)
    for item in iterator:
        yield item


def _convert_value(value, tp):
    """Validate ``value`` against ODPS type ``tp`` and return the result.

    ``odps_types.validate_value`` is called while the module-level flag
    ``odps_types._date_allow_int_conversion`` is set to True (presumably so
    integer inputs are accepted for date/datetime types -- confirm in
    ``odps.types``). The flag is set back to False afterwards, even if
    validation raises.

    On Python 3 the validated value is returned as is. On Python 2 it is
    further converted: Datetime values are passed through
    ``to_milliseconds``, Date values through ``to_date``, and Array / Map /
    Struct values are converted recursively per element (Struct values are
    rebuilt as ``tp.namedtuple_type``).

    :param value: raw input value
    :param tp: ODPS data type instance to validate against
    :return: the validated (and, on Python 2, converted) value
    :raises: whatever ``odps_types.validate_value`` raises for invalid input
    """
    try:
        # NOTE(review): this toggles shared module state and always resets it
        # to False rather than to its previous value.
        odps_types._date_allow_int_conversion = True
        value = odps_types.validate_value(value, tp)
    finally:
        odps_types._date_allow_int_conversion = False

    if not PY2:
        return value

    if isinstance(tp, odps_types.Datetime):
        return to_milliseconds(value)
    elif isinstance(tp, odps_types.Date):
        return to_date(value)
    elif isinstance(tp, odps_types.Array):
        return [_convert_value(v, tp.value_type) for v in value]
    elif isinstance(tp, odps_types.Map):
        return {
            _convert_value(k, tp.key_type): _convert_value(v, tp.value_type)
            for k, v in value.items()
        }
    elif isinstance(tp, odps_types.Struct):
        # Struct values may be a dict keyed by field name or an object that
        # exposes one attribute per field.
        if isinstance(value, dict):
            vals = {
                k: _convert_value(value[k], ftp) for k, ftp in tp.field_types.items()
            }
        else:
            vals = {
                k: _convert_value(getattr(value, k), ftp)
                for k, ftp in tp.field_types.items()
            }
        return tp.namedtuple_type(**vals)
    else:
        return value


def _validate_values(values, types):
    if types == ["*"]:
        return values
    if len(values) != len(types):
        raise ValueError(
            "Input length mismatch: %d expected, %d provided"
            % (len(types), len(values))
        )
    ret_vals = [None] * len(values)
    for idx, (tp, d) in enumerate(zip(types, values)):
        if d is None:
            continue
        try:
            ret_vals[idx] = _convert_value(d, tp)
        except:
            raise ValueError("Input type mismatch: expected %s, received %r" % (tp, d))
    return ret_vals


class ArgParser(object):
    """Turn delimited text rows into validated argument lists."""

    NULL_INDICATOR = "NULL"

    def __init__(self, types, fileobj, delim=",", null_indicator="NULL"):
        """
        :param types: declared input types (``[]``, ``["*"]`` or ODPS types)
        :param fileobj: iterable of text lines handed to ``csv.reader``
        :param delim: column delimiter
        :param null_indicator: cell text that is converted to ``None``
        """
        self.types = types
        self.delim = delim
        self.null_indicator = null_indicator

        self.reader = csv.reader(fileobj, delimiter=delim)

    def parse(self):
        """Yield one argument list per CSV record.

        Cells equal to ``self.null_indicator`` become ``None``. When no input
        types are declared, an empty record yields ``""`` (no arguments);
        every other record is passed through ``_validate_values``.
        """
        for record in self.reader:
            null_text = self.null_indicator
            tokens = [None if cell == null_text else cell for cell in record]
            if not self.types and not tokens:
                yield ""
            else:
                yield _validate_values(tokens, self.types)


def _get_table_or_partition(odps_entry, table_desc):
    """Resolve a table description into a data object and a column list.

    ``table_desc`` is split on ``.`` via ``split_backquoted``. Segments that
    do not look like ``name(args)`` form the table name. A ``p(...)`` segment
    selects a partition, and a ``c(...)`` segment lists columns. Other
    bracketed segments are ignored.

    :return: tuple ``(data_obj, columns)``; ``columns`` is ``None`` when no
        ``c(...)`` segment is present
    """
    name_parts = []
    partition_spec = None
    columns = None
    for raw in split_backquoted(table_desc, "."):
        segment = raw.strip()
        if _table_bracket_re.match(segment) is None:
            name_parts.append(segment)
            continue
        inner = segment[2:-1]
        if segment.startswith("p("):
            partition_spec = inner
        elif segment.startswith("c("):
            columns = [col.strip() for col in inner.split(",")]

    table = odps_entry.get_table(".".join(name_parts))
    if partition_spec is None:
        return table, columns
    return table.get_partition(partition_spec), columns


def table_feed(odps_entry, table_desc, in_types, record_limit):
    """Yield validated argument lists read from an ODPS table or partition.

    :param odps_entry: entry object used to resolve ``table_desc``
    :param table_desc: table description (see ``_get_table_or_partition``)
    :param in_types: declared input types passed to ``_validate_values``
    :param record_limit: maximum number of records, or ``None`` for all
    """
    data_obj, cols = _get_table_or_partition(odps_entry, table_desc)
    with data_obj.open_reader(columns=cols) as reader:
        records = reader if record_limit is None else reader[:record_limit]
        for record in records:
            yield _validate_values(record.values, in_types)


class ArgFormatter(object):
    """Render output values as a single delimited line of text."""

    DELIM = "\t"
    NULL_INDICATOR = "NULL"

    def __init__(self, types):
        """
        :param types: declared output types (stored for reference)
        """
        self.types = types

    def format(self, *args):
        """Join ``args`` with ``DELIM``; ``None`` is written as ``NULL_INDICATOR``.

        Writing nulls as ``NULL_INDICATOR`` matches the default ``"NULL"``
        token ``ArgParser`` recognizes on input. It also keeps a null
        distinguishable from the string ``"None"``.
        """
        null_text = self.NULL_INDICATOR
        return self.DELIM.join(
            [null_text if a is None else str(a) for a in args]
        )


class BaseCollector(object):
    """Shared validation logic for output collectors.

    Subclasses implement ``handle_collect`` to consume each output row.
    """

    def __init__(self, schema):
        # Declared output types of the UDF.
        self.schema = schema

    def collect(self, *args):
        """Check ``args`` against ``self.schema``, then forward them.

        The converted values computed during validation are discarded; the
        original ``args`` are what ``handle_collect`` receives.
        """
        row = args
        _validate_values(row, self.schema)
        self.handle_collect(*row)

    def handle_collect(self, *args):
        """Consume one validated output row; must be overridden."""
        raise NotImplementedError


class StdoutCollector(BaseCollector):
    """Collector that prints each output row as one line on stdout."""

    def __init__(self, schema):
        super(StdoutCollector, self).__init__(schema)
        # Formats rows with ArgFormatter's delimiter before printing.
        self.formatter = ArgFormatter(schema)

    def handle_collect(self, *args):
        line = self.formatter.format(*args)
        print(line)


class DirectCollector(BaseCollector):
    """Collector that accumulates output rows in memory in ``self.results``.

    With a single-column schema each entry is the bare value. Otherwise each
    entry is the tuple of collected arguments.
    """

    def __init__(self, schema):
        super(DirectCollector, self).__init__(schema)
        self.results = []

    def handle_collect(self, *args):
        single_column = len(self.schema) == 1
        self.results.append(args[0] if single_column else args)


class BaseRunner(object):
    """Common runner state: the UDF class and instance, input feed and collector."""

    def __init__(self, udf_class, feed, collector):
        self.udf_class = udf_class
        self.feed = feed
        self.collector = collector
        # Instantiate with no arguments right away so constructor problems
        # surface when the runner is built rather than during run().
        self.obj = udf_class()


class UDFRunner(BaseRunner):
    """Runner for scalar UDFs: one ``evaluate`` call per input row."""

    def run(self):
        """Call ``evaluate`` for every row of the feed and collect each result."""
        udf_obj, sink = self.obj, self.collector
        for row in self.feed:
            sink.collect(udf_obj.evaluate(*row))


class UDTFRunner(BaseRunner):
    """Runner for UDTFs: ``process`` per input row, then a final ``close``."""

    def run(self):
        """Wire ``forward`` to the collector, process every row, then close."""
        udtf = self.obj
        sink = self.collector

        def _forward(*row):
            # Rows emitted by the UDTF go straight into the collector.
            sink.collect(*row)

        udtf.forward = _forward
        for row in self.feed:
            udtf.process(*row)
        udtf.close()


class UDAFRunner(BaseRunner):
    """Runner for UDAFs using two partial buffers that are merged at the end."""

    def run(self):
        """Aggregate the feed and collect the terminated result.

        Rows alternate between two partial buffers, starting with the first
        (presumably to exercise ``merge`` as a split aggregation would). Both
        partials are merged into a fresh buffer, which is then passed to
        ``terminate``.
        """
        udaf = self.obj
        sink = self.collector
        partials = [udaf.new_buffer(), udaf.new_buffer()]
        for index, row in enumerate(self.feed):
            udaf.iterate(partials[index % 2], *row)
        merged = udaf.new_buffer()
        for partial in partials:
            udaf.merge(merged, partial)
        sink.collect(udaf.terminate(merged))


# Runs at import time: rebinds distcache.get_cache_table to the local
# implementation in utils.
initialize()
