easy_rec/python/input/dummy_input.py (33 lines of code) (raw):
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import tensorflow as tf
from easy_rec.python.input.input import Input
from easy_rec.python.utils.tf_utils import get_tf_type
# Switch to the TF1-style API when running under TensorFlow 2.x or newer.
# The major version is compared as an integer: comparing the raw version
# strings is lexicographic, so e.g. '10.0' >= '2.0' would evaluate to False.
if int(tf.__version__.split('.')[0]) >= 2:
  tf = tf.compat.v1
class DummyInput(Input):
  """Dummy memory input.

  Dummy Input is used to debug the performance bottleneck of data pipeline.
  `_build` does not read any data itself: every input field is either a
  caller-supplied value from `input_vals` or a constant tensor filled with
  the field's default value.
  """

  def __init__(self,
               data_config,
               feature_config,
               input_path,
               task_index=0,
               task_num=1,
               check_mode=False,
               pipeline_config=None,
               input_vals=None):
    """Initialize the dummy input.

    Args:
      data_config: forwarded unchanged to Input.__init__
      feature_config: forwarded unchanged to Input.__init__
      input_path: forwarded unchanged to Input.__init__
      task_index: forwarded unchanged to Input.__init__
      task_num: forwarded unchanged to Input.__init__
      check_mode: forwarded unchanged to Input.__init__
      pipeline_config: forwarded unchanged to Input.__init__
      input_vals: optional dict mapping an input field name to the value
        that should be used for that field instead of a generated constant;
        None (the default) means no overrides
    """
    super(DummyInput,
          self).__init__(data_config, feature_config, input_path, task_index,
                         task_num, check_mode, pipeline_config)
    # A fresh dict is created per instance when nothing is passed; a literal
    # `{}` default would be one object shared by every DummyInput instance.
    # An explicitly passed dict (even an empty one) is kept as-is.
    self._input_vals = {} if input_vals is None else input_vals

  def _build(self, mode, params):
    """Build fake constant input.

    Args:
      mode: tf.estimator.ModeKeys.TRAIN / tf.estimator.ModeKeys.EVAL / tf.estimator.ModeKeys.PREDICT
      params: parameters passed by estimator, currently not used

    Returns:
      features tensor dict
      label tensor dict
    """
    features = {}
    # _input_fields / _input_field_types / _input_field_defaults and
    # _batch_size are not set in this class; presumably they are populated by
    # Input.__init__ from data_config -- confirm against the base class.
    for field, field_type, raw_default in zip(self._input_fields,
                                              self._input_field_types,
                                              self._input_field_defaults):
      tf_type = get_tf_type(field_type)
      default_val = self.get_type_defaults(field_type, default_val=raw_default)
      if field in self._input_vals:
        # Caller-supplied value takes precedence over the generated constant.
        tensor = self._input_vals[field]
      else:
        # One copy of the default value per example in the batch.
        tensor = tf.constant([default_val] * self._batch_size, dtype=tf_type)
      features[field] = tensor
    parse_dict = self._preprocess(features)
    return self._get_features(parse_dict), self._get_labels(parse_dict)