easy_rec/python/input/tfrecord_input.py (78 lines of code) (raw):

# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging

import tensorflow as tf

from easy_rec.python.input.input import Input
from easy_rec.python.utils.tf_utils import get_tf_type

# Compare the numeric major version rather than the raw version string:
# a lexicographic comparison would mis-order e.g. '10.0' vs '2.0'.
if int(tf.__version__.split('.')[0]) >= 2:
  tf = tf.compat.v1


class TFRecordInput(Input):
  """Input pipeline that reads serialized tf.Example records from TFRecords.

  Every configured input field is parsed as a single-element
  ``tf.FixedLenFeature`` (``shape=1``) whose dtype and default value are
  derived from the field type/default lists prepared by the base ``Input``.
  """

  def __init__(self,
               data_config,
               feature_config,
               input_path,
               task_index=0,
               task_num=1,
               check_mode=False,
               pipeline_config=None):
    super(TFRecordInput,
          self).__init__(data_config, feature_config, input_path, task_index,
                         task_num, check_mode, pipeline_config)

    # field name -> FixedLenFeature spec, consumed by `_parse_tfrecord`.
    self.feature_desc = {}
    for name, field_type, default in zip(self._input_fields,
                                         self._input_field_types,
                                         self._input_field_defaults):
      default = self.get_type_defaults(field_type, default)
      tf_type = get_tf_type(field_type)
      self.feature_desc[name] = tf.FixedLenFeature(
          dtype=tf_type, shape=1, default_value=default)

  def _parse_tfrecord(self, example):
    """Parse one serialized tf.Example using `self.feature_desc`.

    Args:
      example: string tensor holding one serialized tf.Example (an element
        produced by `tf.data.TFRecordDataset`).

    Returns:
      dict mapping each field name in `self.feature_desc` to its tensor.
    """
    try:
      inputs = tf.parse_single_example(example, features=self.feature_desc)
    except AttributeError:
      # Fallback for TF builds that only expose the parser under `tf.io`.
      inputs = tf.io.parse_single_example(example, features=self.feature_desc)
    return inputs

  def _resolve_file_paths(self):
    """Expand `self._input_path` into the list of matching files.

    `self._input_path` may be a list/tuple of glob patterns or a single
    comma-separated string. A string is normalized to a list in place (as
    the original implementation did), since other code may read the
    attribute later -- NOTE(review): confirm whether anything does.

    Returns:
      Non-empty list of file paths matched by the patterns.

    Raises:
      AssertionError: if no file matches any pattern. Raised explicitly
        rather than via `assert` so the check survives `python -O`.
    """
    if not isinstance(self._input_path, (list, tuple)):
      self._input_path = self._input_path.split(',')

    file_paths = []
    for pattern in self._input_path:
      pattern = pattern.strip()
      if not pattern:
        # Tolerate stray separators such as a trailing comma.
        continue
      file_paths.extend(tf.gfile.Glob(pattern))

    if not file_paths:
      raise AssertionError('match no files with %s' % self._input_path)
    return file_paths

  def _build(self, mode, params):
    """Build the tf.data pipeline for `mode`.

    TRAIN: the file list is optionally shuffled and read in parallel through
    `interleave`. Records are sharded across `task_num` workers,
    optionally shuffled at record level, and repeated `self.num_epochs`
    times. Any other mode reads all files sequentially, exactly once.

    In every mode, records are then parsed, batched, and run through
    `self._preprocess`. Finally they are split into (features, labels), or
    into features only for PREDICT.

    Args:
      mode: a `tf.estimator.ModeKeys` value.
      params: not used by this method; kept for interface compatibility.

    Returns:
      A `tf.data.Dataset`.
    """
    file_paths = self._resolve_file_paths()

    num_parallel_calls = self._data_config.num_parallel_calls
    data_compression_type = self._data_config.data_compression_type

    if mode == tf.estimator.ModeKeys.TRAIN:
      logging.info('train files[%d]: %s', len(file_paths),
                   ','.join(file_paths))
      dataset = tf.data.Dataset.from_tensor_slices(file_paths)
      if self._data_config.shuffle:
        # shuffle input files
        dataset = dataset.shuffle(len(file_paths))
      # Too many readers on the same file hurt performance because the same
      # data gets read multiple times, so never use more readers than files.
      parallel_num = min(num_parallel_calls, len(file_paths))
      dataset = dataset.interleave(
          lambda x: tf.data.TFRecordDataset(
              x, compression_type=data_compression_type),
          cycle_length=parallel_num,
          num_parallel_calls=parallel_num)
      # Record-level sharding: each worker keeps every task_num-th record.
      dataset = dataset.shard(self._task_num, self._task_index)
      if self._data_config.shuffle:
        dataset = dataset.shuffle(
            self._data_config.shuffle_buffer_size,
            seed=2020,
            reshuffle_each_iteration=True)
      dataset = dataset.repeat(self.num_epochs)
    else:
      logging.info('eval files[%d]: %s', len(file_paths),
                   ','.join(file_paths))
      dataset = tf.data.TFRecordDataset(
          file_paths, compression_type=data_compression_type)
      dataset = dataset.repeat(1)

    dataset = dataset.map(
        self._parse_tfrecord, num_parallel_calls=num_parallel_calls)
    dataset = dataset.batch(self._data_config.batch_size)
    dataset = dataset.prefetch(buffer_size=self._prefetch_size)
    dataset = dataset.map(
        map_func=self._preprocess, num_parallel_calls=num_parallel_calls)

    dataset = dataset.prefetch(buffer_size=self._prefetch_size)

    if mode != tf.estimator.ModeKeys.PREDICT:
      dataset = dataset.map(lambda x:
                            (self._get_features(x), self._get_labels(x)))
    else:
      dataset = dataset.map(lambda x: (self._get_features(x)))
    return dataset