odps/udf/tools/runners.py

# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""UDF runners implementing the local run framework."""

import csv
import re
import sys

from ... import distcache
from ... import types as odps_types
from ... import udf
from ...utils import split_backquoted, to_date, to_milliseconds
from . import utils

__all__ = ["get_csv_runner", "get_table_runner"]

PY2 = sys.version_info[0] == 2
_table_bracket_re = re.compile(r"[^\(]+\([^\)]+\)")


def get_csv_runner(
    udf_class,
    input_col_delim=",",
    null_indicator="NULL",
    stdin=None,
    collector_cls=None,
):
    """Create a runner that reads CSV input with the specified UDF class."""
    proto = udf.get_annotation(udf_class)
    in_types, out_types = parse_proto(proto)

    stdin = stdin or sys.stdin
    arg_parser = ArgParser(in_types, stdin, input_col_delim, null_indicator)
    stdin_feed = arg_parser.parse()
    collector_cls = collector_cls or StdoutCollector
    collector = collector_cls(out_types)
    ctor = _get_runner_class(udf_class)
    return ctor(udf_class, stdin_feed, collector)


def get_table_runner(
    udf_class, odps_entry, table_desc, record_limit=None, collector_cls=None
):
    """Create a runner that reads an ODPS table with the specified UDF class."""
    proto = udf.get_annotation(udf_class)
    in_types, out_types = parse_proto(proto)

    tb_feed = table_feed(odps_entry, table_desc, in_types, record_limit)
    collector_cls = collector_cls or StdoutCollector
    collector = collector_cls(out_types)
    ctor = _get_runner_class(udf_class)
    return ctor(udf_class, tb_feed, collector)


def simple_run(udf_class, args):
    proto = udf.get_annotation(udf_class)
    in_types, out_types = parse_proto(proto)

    feed = direct_feed(args)
    collector = DirectCollector(out_types)
    ctor = _get_runner_class(udf_class)
    runner = ctor(udf_class, feed, collector)
    runner.run()
    return collector.results


def initialize():
    """Initialize the local run environment."""
    distcache.get_cache_table = utils.get_cache_table


def _split_data_types(types_str):
    # split on commas at the top nesting level only, so composite types
    # such as array<bigint> or decimal(10, 2) stay intact
    bracket_level = 0
    ret_types = [""]
    for ch in types_str:
        if bracket_level == 0 and ch == ",":
            ret_types[-1] = ret_types[-1].strip()
            ret_types.append("")
        else:
            ret_types[-1] += ch
            if ch in ("<", "("):
                bracket_level += 1
            elif ch in (">", ")"):
                bracket_level -= 1
    return [s for s in ret_types if s]


def _get_types(types_str):
    entries = []
    for t in _split_data_types(types_str):
        t = t.strip()
        entries.append(odps_types.validate_data_type(t))
    return entries


def _get_in_types(types):
    if types == "":
        return []
    return _get_types(types) if types != "*" else ["*"]


def _get_runner_class(udf_class):
    if udf.BaseUDAF in udf_class.__mro__:
        ctor = UDAFRunner
    elif udf.BaseUDTF in udf_class.__mro__:
        ctor = UDTFRunner
    else:
        ctor = UDFRunner
    return ctor


def parse_proto(proto):
    tokens = proto.lower().split("->")
    if len(tokens) != 2:
        raise ValueError("Illegal format of @annotate(%s)" % proto)
    return _get_in_types(tokens[0].strip()), _get_types(tokens[1].strip())


def direct_feed(args):
    for a in args:
        yield a
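
# A minimal sketch of how the factories above are typically driven (the UDF
# class and its annotation are illustrative, not part of this module):
#
#     from odps.udf import annotate
#
#     @annotate("bigint,bigint->bigint")
#     class Add(object):
#         def evaluate(self, a, b):
#             return a + b
#
#     # direct_feed yields each argument tuple as-is; DirectCollector keeps
#     # the results in memory, so simple_run returns them directly:
#     simple_run(Add, [(1, 2), (3, 4)])  # -> [3, 7]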

def _convert_value(value, tp):
    try:
        odps_types._date_allow_int_conversion = True
        value = odps_types.validate_value(value, tp)
    finally:
        odps_types._date_allow_int_conversion = False
    if not PY2:
        return value
    if isinstance(tp, odps_types.Datetime):
        return to_milliseconds(value)
    elif isinstance(tp, odps_types.Date):
        return to_date(value)
    elif isinstance(tp, odps_types.Array):
        return [_convert_value(v, tp.value_type) for v in value]
    elif isinstance(tp, odps_types.Map):
        return {
            _convert_value(k, tp.key_type): _convert_value(v, tp.value_type)
            for k, v in value.items()
        }
    elif isinstance(tp, odps_types.Struct):
        if isinstance(value, dict):
            vals = {
                k: _convert_value(value[k], ftp)
                for k, ftp in tp.field_types.items()
            }
        else:
            vals = {
                k: _convert_value(getattr(value, k), ftp)
                for k, ftp in tp.field_types.items()
            }
        return tp.namedtuple_type(**vals)
    else:
        return value


def _validate_values(values, types):
    if types == ["*"]:
        return values
    if len(values) != len(types):
        raise ValueError(
            "Input length mismatch: %d expected, %d provided"
            % (len(types), len(values))
        )
    ret_vals = [None] * len(values)
    for idx, (tp, d) in enumerate(zip(types, values)):
        if d is None:
            continue
        try:
            ret_vals[idx] = _convert_value(d, tp)
        except Exception:
            raise ValueError(
                "Input type mismatch: expected %s, received %r" % (tp, d)
            )
    return ret_vals


class ArgParser(object):
    NULL_INDICATOR = "NULL"

    def __init__(self, types, fileobj, delim=",", null_indicator="NULL"):
        self.types = types
        self.delim = delim
        self.null_indicator = null_indicator
        self.reader = csv.reader(fileobj, delimiter=delim)

    def parse(self):
        for record in self.reader:
            tokens = []
            for token in record:
                if token == self.null_indicator:
                    tokens.append(None)
                else:
                    tokens.append(token)
            if len(self.types) == 0 and len(tokens) == 0:
                yield ""
                continue
            yield _validate_values(tokens, self.types)


def _get_table_or_partition(odps_entry, table_desc):
    table_names = []
    table_part = None
    table_cols = None
    for part in split_backquoted(table_desc, "."):
        part = part.strip()
        if not _table_bracket_re.match(part):
            table_names.append(part)
        elif part.startswith("p("):
            table_part = part[2:-1]
        elif part.startswith("c("):
            table_cols = [s.strip() for s in part[2:-1].split(",")]

    data_obj = odps_entry.get_table(".".join(table_names))
    if table_part is not None:
        data_obj = data_obj.get_partition(table_part)
    return data_obj, table_cols


def table_feed(odps_entry, table_desc, in_types, record_limit):
    data_obj, cols = _get_table_or_partition(odps_entry, table_desc)
    with data_obj.open_reader(columns=cols) as reader:
        if record_limit is not None:
            data_src = reader[:record_limit]
        else:
            data_src = reader
        for row in data_src:
            yield _validate_values(row.values, in_types)
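
# The table_desc strings accepted by table_feed combine a (possibly
# project-qualified) table name with optional partition and column segments,
# e.g. (names and values illustrative):
#
#     "my_project.my_table.p(pt='20250101').c(col1,col2)"
#
# _get_table_or_partition splits on dots outside backquotes; a segment of the
# form name(...) is taken as the partition spec ("p(...)") or the column list
# ("c(...)") rather than as part of the table name.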

class ArgFormatter(object):
    DELIM = "\t"
    NULL_INDICATOR = "NULL"

    def __init__(self, types):
        self.types = types

    def format(self, *args):
        ret = self.DELIM.join([str(a) for a in args])
        return ret


class BaseCollector(object):
    """Basic common logic of collectors."""

    def __init__(self, schema):
        self.schema = schema

    def collect(self, *args):
        _validate_values(args, self.schema)
        self.handle_collect(*args)

    def handle_collect(self, *args):
        raise NotImplementedError


class StdoutCollector(BaseCollector):
    """Collect the results to stdout."""

    def __init__(self, schema):
        super(StdoutCollector, self).__init__(schema)
        self.formatter = ArgFormatter(schema)

    def handle_collect(self, *args):
        print(self.formatter.format(*args))


class DirectCollector(BaseCollector):
    """Collect results into memory, to be fetched via self.results."""

    def __init__(self, schema):
        super(DirectCollector, self).__init__(schema)
        self.results = []

    def handle_collect(self, *args):
        if len(self.schema) == 1:
            self.results.append(args[0])
        else:
            self.results.append(args)


class BaseRunner(object):
    def __init__(self, udf_class, feed, collector):
        self.udf_class = udf_class
        self.feed = feed
        self.collector = collector
        # instantiate now to check the constructor signature
        self.obj = udf_class()


class UDFRunner(BaseRunner):
    def run(self):
        obj = self.obj
        collector = self.collector
        for args in self.feed:
            r = obj.evaluate(*args)
            collector.collect(r)


class UDTFRunner(BaseRunner):
    def run(self):
        obj = self.obj
        collector = self.collector

        def local_forward(*r):
            collector.collect(*r)

        obj.forward = local_forward
        for args in self.feed:
            obj.process(*args)
        obj.close()


class UDAFRunner(BaseRunner):
    def run(self):
        obj = self.obj
        collector = self.collector
        buf0 = obj.new_buffer()
        buf1 = obj.new_buffer()
        # alternate rows between two buffers so that merge() is exercised too
        turn = True
        for args in self.feed:
            if turn:
                buf = buf0
                turn = False
            else:
                buf = buf1
                turn = True
            obj.iterate(buf, *args)
        merge_buf = obj.new_buffer()
        obj.merge(merge_buf, buf0)
        obj.merge(merge_buf, buf1)
        collector.collect(obj.terminate(merge_buf))


initialize()
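
# End-to-end sketch of the CSV runner (the file name and UDF class are
# illustrative; with no file object, get_csv_runner reads sys.stdin):
#
#     from odps.udf import annotate
#
#     @annotate("string->string")
#     class Upper(object):
#         def evaluate(self, s):
#             return s.upper() if s is not None else None
#
#     with open("input.csv") as f:
#         get_csv_runner(Upper, stdin=f).run()
#
# Each row is parsed against the annotation's input types ("NULL" fields
# become None); StdoutCollector prints one tab-delimited line per result.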