# odps/readers.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import csv
import itertools
import math
from collections import OrderedDict
from requests import Response
from . import compat, options, types, utils
from .compat import StringIO, six
from .models.record import Record
class AbstractRecordReader(object):
    """Base class for record readers.

    Provides iteration, integer/slice subscript access (``reader[n]``,
    ``reader[a:b:c]``) and conversion helpers (:meth:`to_result_frame`,
    :meth:`to_pandas`).  Subclasses must implement :meth:`__next__` to
    produce one record per call and raise ``StopIteration`` when exhausted.
    """

    def __iter__(self):
        return self

    def __next__(self):
        raise NotImplementedError

    next = __next__  # Python 2 compatibility

    @classmethod
    def _calc_count(cls, start, end, step):
        # Number of records produced by slice [start:end:step];
        # ``None`` when the slice is open-ended (``end`` is None).
        if end is None:
            return end
        step = step or 1
        return int(math.ceil(float(end - start) / step))

    @classmethod
    def _get_slice(cls, item):
        # Normalize an integer index or a slice object into (start, end, step).
        if isinstance(item, six.integer_types):
            start = item
            end = start + 1
            step = 1
        elif isinstance(item, slice):
            start = item.start or 0
            end = item.stop
            step = item.step or 1
        else:
            raise ValueError("Reader only supports index and slice operation.")
        return start, end, step

    def __getitem__(self, item):
        """Support ``reader[n]`` and ``reader[a:b:c]``.

        An integer index returns a single record (``IndexError`` when out of
        range); a slice returns a lazy iterator over the selected records.
        Negative indices/steps are not supported.
        """
        start, end, step = self._get_slice(item)
        count = self._calc_count(start, end, step)
        if start < 0 or (count is not None and count <= 0) or step < 0:
            raise ValueError("start, count, or step cannot be negative")

        it = self._get_slice_iter(start=start, end=end, step=step)
        if isinstance(item, six.integer_types):
            try:
                return next(it)
            except StopIteration:
                raise IndexError("Index out of range: %s" % item)
        return it

    def _get_slice_iter(self, start=None, end=None, step=None):
        # Wrap the sliced iterator so that ``.to_pandas()`` can still be
        # invoked on the slicing result with equivalent start/count/step
        # arguments forwarded to the parent reader.
        class SliceIterator(six.Iterator):
            def __init__(self, it):
                self.it = it

            def __iter__(self):
                return self.it

            def __next__(self):
                return next(self.it)

            @staticmethod
            def to_pandas():
                # Convert the slice bounds back into start/count/step kwargs.
                if end is not None:
                    count = (end - (start or 0)) // (step or 1)
                else:
                    count = None
                pstep = None if step == 1 else step
                kw = dict(start=start, count=count, step=pstep)
                kw = {k: v for k, v in kw.items() if v is not None}
                return parent.to_pandas(**kw)

        parent = self
        return SliceIterator(self._iter(start=start, end=end, step=step))

    def _iter(self, start=None, end=None, step=None):
        # Generator over ``self``: skip ``start`` records, then yield every
        # ``step``-th record while the underlying position stays below ``end``.
        start = start or 0
        step = step or 1
        curr = start

        for _ in range(start):
            try:
                next(self)
            except StopIteration:
                return

        while True:
            for i in range(step):
                try:
                    record = next(self)
                except StopIteration:
                    return
                if i == 0:
                    # only the first record of every step-group is yielded
                    yield record
                curr += 1
                if end is not None and curr >= end:
                    return

    def _data_to_result_frame(
        self, data, unknown_as_string=True, as_type=None, columns=None
    ):
        # Build a DataFrame ResultFrame from a list of records (or from
        # ``self.raw`` CSV text when available), resolving schema/columns
        # from the reader's attributes.
        from .df.backends.frame import ResultFrame
        from .df.backends.odpssql.types import (
            odps_schema_to_df_schema,
            odps_type_to_df_type,
        )

        kw = dict()
        if getattr(self, "schema", None) is not None:
            kw["schema"] = odps_schema_to_df_schema(self.schema)
        elif getattr(self, "_schema", None) is not None:
            # do not remove as there might be coverage missing
            kw["schema"] = odps_schema_to_df_schema(self._schema)

        column_names = columns or getattr(self, "_column_names", None)
        if column_names is not None:
            self._columns = [self.schema[c] for c in column_names]
        if getattr(self, "_columns", None) is not None:
            cols = []
            for col in self._columns:
                # copy before mutating: the original column objects are shared
                col = copy.copy(col)
                col.type = odps_type_to_df_type(col.type)
                cols.append(col)
            kw["columns"] = cols

        if hasattr(self, "raw"):
            # raw CSV available: let pandas parse it and infer the schema
            try:
                import pandas as pd
                from .df.backends.pd.types import pd_to_df_schema

                data = pd.read_csv(StringIO(self.raw))
                schema = kw["schema"] = pd_to_df_schema(
                    data, unknown_as_string=unknown_as_string, as_type=as_type
                )
                columns = kw.pop("columns", None)
                if columns and len(columns) < len(schema):
                    # apply column projection on the parsed DataFrame
                    sel_cols = [c.name for c in self._columns]
                    data = data[sel_cols]
                    kw["schema"] = types.OdpsSchema(columns)
            except (ImportError, ValueError):
                # best-effort: fall back to the record-based path
                pass

        if not kw:
            raise ValueError(
                "Cannot convert to ResultFrame from %s." % type(self).__name__
            )
        return ResultFrame(data, **kw)

    def to_result_frame(
        self,
        unknown_as_string=True,
        as_type=None,
        start=None,
        count=None,
        columns=None,
        **iter_kw
    ):
        """Materialize records into a DataFrame ``ResultFrame``.

        Records are collected in batches of ``options.tunnel.read_row_batch_size``
        and converted batch by batch; batches are merged eagerly once more than
        ``options.tunnel.batch_merge_threshold`` frames accumulate.

        :param unknown_as_string: treat columns of unknown type as strings
        :param as_type: explicit type mapping forwarded to schema inference
        :param start: first record index to read
        :param count: number of (stepped) records to read; None means all
        :param columns: optional explicit column selection
        :param iter_kw: extra args for :meth:`_iter` (``step`` / ``end``)
        """
        read_row_batch_size = options.tunnel.read_row_batch_size
        if "end" in iter_kw:
            end = iter_kw["end"]
        else:
            end = (
                None
                if count is None
                else (start or 0) + count * (iter_kw.get("step") or 1)
            )

        frames = []
        if hasattr(self, "raw"):
            # data represented as raw csv: just skip iteration
            data = [r for r in self._iter(start=start, end=end, **iter_kw)]
        else:
            offset_iter = itertools.cycle(compat.irange(read_row_batch_size))
            data = [None] * read_row_batch_size
            for offset, rec in zip(
                offset_iter, self._iter(start=start, end=end, **iter_kw)
            ):
                data[offset] = rec
                if offset != read_row_batch_size - 1:
                    continue
                # batch full: convert and reset the buffer.
                # FIX: forward ``columns`` here as well — previously only the
                # trailing partial batch received the explicit column
                # selection, so concat of mismatched frames could fail when
                # ``columns`` was supplied with more than one batch of data.
                frames.append(
                    self._data_to_result_frame(
                        data,
                        unknown_as_string=unknown_as_string,
                        as_type=as_type,
                        columns=columns,
                    )
                )
                data = [None] * read_row_batch_size
                if len(frames) > options.tunnel.batch_merge_threshold:
                    # merge eagerly to bound the number of pending frames
                    frames = [frames[0].concat(*frames[1:])]

        if not frames or data[0] is not None:
            # convert the (possibly partial) trailing batch
            data = list(itertools.takewhile(lambda x: x is not None, data))
            frames.append(
                self._data_to_result_frame(
                    data,
                    unknown_as_string=unknown_as_string,
                    as_type=as_type,
                    columns=columns,
                )
            )
        return frames[0].concat(*frames[1:])

    def to_pandas(self, start=None, count=None, **kw):
        """Read records into a ``pandas.DataFrame`` (requires pandas)."""
        import pandas  # noqa: F401

        return self.to_result_frame(start=start, count=count, **kw).values
class CsvRecordReader(AbstractRecordReader):
    """Read :class:`Record` instances from raw CSV text or an HTTP response.

    The first CSV row is treated as the header and resolved against the
    provided schema (when any).  Values equal to ``NULL_TOKEN`` become None,
    and boolean / map / array typed columns are parsed from their textual
    representations.
    """

    NULL_TOKEN = "\\N"
    BACK_SLASH_ESCAPE = "\\x%02x" % ord("\\")

    def __init__(self, schema, stream, **kwargs):
        """
        :param schema: table schema used to resolve header names; may be None
        :param stream: raw CSV text or a ``requests.Response``
        :param kwargs: ``max_field_size`` (csv field limit), ``columns``
            (case-insensitive column filter)
        """
        # shift csv field limit size to match table field size
        max_field_size = kwargs.pop("max_field_size", 0) or types.String._max_length
        if csv.field_size_limit() < max_field_size:
            csv.field_size_limit(max_field_size)

        self._schema = schema
        self._csv_columns = None
        self._fp = stream
        if isinstance(self._fp, Response):
            self.raw = self._fp.content if six.PY2 else self._fp.text
        else:
            self.raw = self._fp
        # escape invisible characters so the csv module can parse safely
        if options.tunnel.string_as_binary:
            self._csv = csv.reader(six.StringIO(self._escape_csv_bin(self.raw)))
        else:
            self._csv = csv.reader(six.StringIO(self._escape_csv(self.raw)))
        self._filtered_col_names = (
            set(x.lower() for x in kwargs["columns"]) if "columns" in kwargs else None
        )
        self._columns = None
        self._filtered_col_idxes = None

    @classmethod
    def _escape_csv(cls, s):
        escaped = utils.to_text(s).encode("unicode_escape")
        # Make invisible chars available to `csv` library.
        # Note that '\n' and '\r' should be unescaped.
        # '\\' should be replaced with '\x5c' before unescaping
        # to avoid mis-escaped strings like '\\n'.
        return (
            utils.to_text(escaped)
            .replace("\\\\", cls.BACK_SLASH_ESCAPE)
            .replace("\\n", "\n")
            .replace("\\r", "\r")
        )

    @classmethod
    def _escape_csv_bin(cls, s):
        escaped = utils.to_binary(s).decode("latin1").encode("unicode_escape")
        # Make invisible chars available to `csv` library.
        # Note that '\n' and '\r' should be unescaped.
        # '\\' should be replaced with '\x5c' before unescaping
        # to avoid mis-escaped strings like '\\n'.
        return (
            utils.to_text(escaped)
            .replace("\\\\", cls.BACK_SLASH_ESCAPE)
            .replace("\\n", "\n")
            .replace("\\r", "\r")
        )

    @staticmethod
    def _unescape_csv(s):
        # reverse of _escape_csv: interpret escape sequences back into text
        return s.encode("utf-8").decode("unicode_escape")

    @staticmethod
    def _unescape_csv_bin(s):
        # reverse of _escape_csv_bin: recover the original bytes
        return s.encode("utf-8").decode("unicode_escape").encode("latin1")

    def _readline(self):
        # Read and convert one CSV row; returns None when the stream ends.
        try:
            values = next(self._csv)
            res = []
            read_binary = options.tunnel.string_as_binary
            if read_binary:
                unescape_csv = self._unescape_csv_bin
            else:
                unescape_csv = self._unescape_csv
            for i, value in enumerate(values):
                value = unescape_csv(value)
                if value == self.NULL_TOKEN:
                    res.append(None)
                elif self._csv_columns and self._csv_columns[i].type == types.boolean:
                    if value == "true":
                        res.append(True)
                    elif value == "false":
                        res.append(False)
                    else:
                        res.append(value)
                elif self._csv_columns and isinstance(
                    self._csv_columns[i].type, types.Map
                ):
                    # parse "{k: v, ...}" into an OrderedDict
                    col_type = self._csv_columns[i].type
                    if not (value.startswith("{") and value.endswith("}")):
                        raise ValueError("Dict format error!")
                    items = []
                    for kv in value[1:-1].split(","):
                        k, v = kv.split(":", 1)
                        k = col_type.key_type.cast_value(k.strip(), types.string)
                        v = col_type.value_type.cast_value(v.strip(), types.string)
                        items.append((k, v))
                    res.append(OrderedDict(items))
                elif self._csv_columns and isinstance(
                    self._csv_columns[i].type, types.Array
                ):
                    # parse "[a, b, ...]" into a list
                    col_type = self._csv_columns[i].type
                    if not (value.startswith("[") and value.endswith("]")):
                        raise ValueError("Array format error!")
                    items = []
                    for item in value[1:-1].split(","):
                        item = col_type.value_type.cast_value(
                            item.strip(), types.string
                        )
                        items.append(item)
                    res.append(items)
                else:
                    res.append(value)
            return res
        except StopIteration:
            return

    def __next__(self):
        self._load_columns()
        values = self._readline()
        if not values:
            raise StopIteration
        if self._filtered_col_idxes:
            # apply the case-insensitive column filter requested at init
            values = [values[idx] for idx in self._filtered_col_idxes]
        return Record(self._columns, values=values)

    next = __next__  # Python 2 compatibility

    def read(self, start=None, count=None, step=None):
        """Return an iterator over up to ``count`` records from ``start``."""
        if count is None:
            end = None
        else:
            start = start or 0
            step = step or 1
            end = start + count * step
        return self._iter(start=start, end=end, step=step)

    def _load_columns(self):
        # Lazily read the CSV header row and resolve column metadata.
        if self._csv_columns is not None:
            return
        values = self._readline()
        self._csv_columns = []
        # FIX: ``_readline`` returns None on an empty stream; guard the
        # iteration so an empty response yields no records instead of
        # raising TypeError.
        for value in values or ():
            if self._schema is None:
                # ``typo`` is this project's parameter name for the type
                self._csv_columns.append(types.Column(name=value, typo="string"))
            else:
                if self._schema.is_partition(value):
                    self._csv_columns.append(self._schema.get_partition(value))
                else:
                    self._csv_columns.append(self._schema.get_column(value))

        if self._csv_columns is not None and self._filtered_col_names:
            self._filtered_col_idxes = []
            self._columns = []
            for idx, col in enumerate(self._csv_columns):
                if col.name.lower() in self._filtered_col_names:
                    self._filtered_col_idxes.append(idx)
                    self._columns.append(col)
        else:
            self._columns = self._csv_columns

    def to_pandas(self, start=None, count=None, **kw):
        """Read records into a ``pandas.DataFrame``.

        ``n_process`` is accepted for API compatibility and ignored, as CSV
        parsing here is single-process.
        """
        kw.pop("n_process", None)
        return super(CsvRecordReader, self).to_pandas(start=start, count=count, **kw)

    def close(self):
        """Close the underlying stream when it supports closing."""
        if hasattr(self._fp, "close"):
            self._fp.close()

    def __enter__(self):
        return self

    def __exit__(self, *_):
        self.close()
# make class name compatible
# ``RecordReader`` is the historical public name of this reader; keep it as
# an alias of CsvRecordReader so existing imports continue to work.
RecordReader = CsvRecordReader