# easy_rec/python/input/odps_input_v3.py
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
import sys
import tensorflow as tf
from easy_rec.python.input.input import Input
from easy_rec.python.utils import odps_util
from easy_rec.python.utils.tf_utils import get_tf_type
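
# common_io provides TableReader for MaxCompute (ODPS) tables; it is
# imported defensively so this module can still be loaded where the
# package is missing (a clear error is raised in __init__ instead).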
try:
import common_io
except Exception:
  common_io = None


class OdpsInputV3(Input):
"""Common IO based interface, could run at local or on data science."""
def __init__(self,
data_config,
feature_config,
input_path,
task_index=0,
task_num=1,
check_mode=False,
pipeline_config=None):
super(OdpsInputV3,
self).__init__(data_config, feature_config, input_path, task_index,
task_num, check_mode, pipeline_config)
self._num_epoch = 0
if common_io is None:
logging.error("""please install common_io pip install
https://easyrec.oss-cn-beijing.aliyuncs.com/3rdparty/common_io-0.1.0-cp37-cp37m-linux_x86_64.whl"""
)
sys.exit(1)
def _parse_table(self, *fields):
fields = list(fields)
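    # map positional column values back to their configured field names;
    # _effective_fids index the feature columns actually used and
    # _label_fids index the label columns.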
inputs = {self._input_fields[x]: fields[x] for x in self._effective_fids}
for x in self._label_fids:
inputs[self._input_fields[x]] = fields[x]
    return inputs

  def _odps_read(self):
logging.info('start epoch[%d]' % self._num_epoch)
self._num_epoch += 1
    if not isinstance(self._input_path, list):
      self._input_path = self._input_path.split(',')
    assert len(self._input_path) > 0, \
        'no input tables found in %s' % self._input_path
    # check that data_config is consistent with the odps tables
odps_util.check_input_field_and_types(self._data_config)
record_defaults = [
self.get_type_defaults(x, v)
for x, v in zip(self._input_field_types, self._input_field_defaults)
]
selected_cols = ','.join(self._input_fields)
for table_path in self._input_path:
reader = common_io.table.TableReader(
table_path,
selected_cols=selected_cols,
slice_id=self._task_index,
slice_count=self._task_num)
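      # slice_id / slice_count shard the table rows across workers so
      # each task reads a disjoint slice.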
total_records_num = reader.get_row_count()
batch_num = int(total_records_num / self._data_config.batch_size)
res_num = total_records_num - batch_num * self._data_config.batch_size
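      # e.g. 1050 rows with batch_size 32 -> 32 full batches plus 26
      # leftover rows; the pre-filled defaults below are overwritten by
      # any non-empty cell read from the table.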
batch_defaults = [
[x] * self._data_config.batch_size for x in record_defaults
]
for batch_id in range(batch_num):
batch_data_np = [x.copy() for x in batch_defaults]
for row_id, one_data in enumerate(
reader.read(self._data_config.batch_size)):
for col_id in range(len(record_defaults)):
if one_data[col_id] not in ['', 'NULL', None]:
batch_data_np[col_id][row_id] = one_data[col_id]
yield tuple(batch_data_np)
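      # emit one last, smaller batch for the leftover rows (if any)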
if res_num > 0:
batch_data_np = [x[:res_num] for x in batch_defaults]
for row_id, one_data in enumerate(reader.read(res_num)):
for col_id in range(len(record_defaults)):
if one_data[col_id] not in ['', 'NULL', None]:
batch_data_np[col_id][row_id] = one_data[col_id]
yield tuple(batch_data_np)
reader.close()
    logging.info('finish epoch[%d]' % self._num_epoch)

  def _build(self, mode, params):
# get input type
list_type = [get_tf_type(x) for x in self._input_field_types]
list_type = tuple(list_type)
    list_shapes = [tf.TensorShape([None]) for _ in range(len(list_type))]
list_shapes = tuple(list_shapes)
# read odps tables
dataset = tf.data.Dataset.from_generator(
self._odps_read, output_types=list_type, output_shapes=list_shapes)
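    # each element from the generator is a whole batch: one 1-D tensor
    # per input column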
if mode == tf.estimator.ModeKeys.TRAIN:
dataset = dataset.shuffle(
self._data_config.shuffle_buffer_size,
seed=2020,
reshuffle_each_iteration=True)
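      # elements are whole batches, so this shuffles batches, not rows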
dataset = dataset.repeat(self.num_epochs)
else:
dataset = dataset.repeat(1)
dataset = dataset.map(
self._parse_table,
num_parallel_calls=self._data_config.num_parallel_calls)
    # preprocess is necessary to transform the data
    # so that it can be fed into FeatureColumns
dataset = dataset.map(
map_func=self._preprocess,
num_parallel_calls=self._data_config.num_parallel_calls)
dataset = dataset.prefetch(buffer_size=self._prefetch_size)
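    # at predict time there are no labels, so only features are returned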
if mode != tf.estimator.ModeKeys.PREDICT:
dataset = dataset.map(lambda x:
(self._get_features(x), self._get_labels(x)))
else:
      dataset = dataset.map(lambda x: self._get_features(x))
return dataset
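

# A minimal usage sketch (not part of the original file; the table path is
# hypothetical, and `data_config`/`feature_config` are assumed to come from
# a parsed EasyRec pipeline config):
#
#   input_obj = OdpsInputV3(
#       data_config,
#       feature_config,
#       'odps://my_project/tables/my_table',
#       task_index=0,
#       task_num=1)
#   dataset = input_obj._build(tf.estimator.ModeKeys.TRAIN, params={})
#   # `dataset` yields (features, labels) tuples ready for an Estimator.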