in easy_rec/python/input/input.py [0:0]
def __init__(self,
             data_config,
             feature_configs,
             input_path,
             task_index=0,
             task_num=1,
             check_mode=False,
             pipeline_config=None,
             **kwargs):
  """Cache the dataset settings and resolve which input columns are used.

  Args:
    data_config: dataset config message. Read for input_fields,
      label_fields, feature_fields, label_sep, label_dim, batch_size,
      prefetch_size, auto_expand_input_fields and the optional
      sample_weight. When auto_expand_input_fields is set, its
      input_fields are rewritten in place.
    feature_configs: iterable of feature config messages; copied to a list.
    input_path: stored as-is; a sampler is built only when it is not None.
    task_index: stored as self._task_index.
    task_num: stored as self._task_num.
    check_mode: stored as self._check_mode and logged.
    pipeline_config: optional pipeline config; when given, its model_config
      and metric settings are consulted.
    **kwargs: accepted but not used in this method.

  Raises:
    AssertionError: a feature config names an input that is not declared
      in input_fields.
    ValueError: a label, sample weight, uid or session id field is not
      declared in input_fields (raised by list.index).
  """
  self._pipeline_config = pipeline_config
  self._data_config = data_config
  self._check_mode = check_mode
  logging.info('check_mode: %s ', self._check_mode)
  # tf.estimator.ModeKeys.*; None here.
  # NOTE(review): presumably assigned when self._build runs -- confirm.
  self._mode = None

  self._has_ev = (
      pipeline_config is not None and
      pipeline_config.model_config.HasField('ev_params'))

  if self._data_config.auto_expand_input_fields:
    # snapshot the declared fields, empty the repeated field, then re-add
    # one copy of each declared field per expanded name
    declared_fields = list(self._data_config.input_fields)
    while len(self._data_config.input_fields) > 0:
      self._data_config.input_fields.pop()
    for field in declared_fields:
      for expanded_name in config_util.auto_expand_names(field.input_name):
        one_field = DatasetConfig.Field()
        one_field.CopyFrom(field)
        one_field.input_name = expanded_name
        self._data_config.input_fields.append(one_field)

  # per-column metadata, all in the order of data_config.input_fields
  self._input_fields = [x.input_name for x in data_config.input_fields]
  self._input_dims = [x.input_dim for x in data_config.input_fields]
  self._input_field_types = [x.input_type for x in data_config.input_fields]
  self._input_field_defaults = [
      x.default_val for x in data_config.input_fields
  ]

  self._label_fields = list(data_config.label_fields)
  self._feature_fields = list(data_config.feature_fields)
  self._label_sep = list(data_config.label_sep)
  self._label_dim = list(data_config.label_dim)
  # labels without an explicit dim get dim 1; a non-positive count
  # extends by nothing
  self._label_dim.extend(
      [1] * (len(self._label_fields) - len(self._label_dim)))

  # input_name -> function loaded for fields that set user_define_fn
  self._label_udf_map = {}
  for config in self._data_config.input_fields:
    if config.HasField('user_define_fn'):
      self._label_udf_map[config.input_name] = self._load_label_fn(config)

  self._batch_size = data_config.batch_size
  self._prefetch_size = data_config.prefetch_size

  self._feature_configs = list(feature_configs)
  self._task_index = task_index
  self._task_num = task_num

  self._input_path = input_path

  # find out effective fields
  self._effective_fields = []

  # for multi value inputs, the types maybe different
  # from the types defined in input_fields
  # it is used in create_multi_placeholders
  self._multi_value_types = {}
  self._multi_value_fields = set()

  self._normalizer_fn = {}

  for fc in self._feature_configs:
    for input_name in fc.input_names:
      assert input_name in self._input_fields, 'invalid input_name in %s' % str(
          fc)
      if input_name not in self._effective_fields:
        self._effective_fields.append(input_name)

    if fc.feature_type in [fc.TagFeature, fc.SequenceFeature]:
      # string when hashing or a vocabulary is configured, int64 otherwise
      needs_string = (
          fc.hash_bucket_size > 0 or len(fc.vocab_list) > 0 or
          fc.HasField('vocab_file'))
      self._multi_value_types[
          fc.input_names[0]] = tf.string if needs_string else tf.int64
      self._multi_value_fields.add(fc.input_names[0])
      if len(fc.input_names) > 1:
        # a second input of a tag/sequence feature is typed float32
        self._multi_value_types[fc.input_names[1]] = tf.float32
        self._multi_value_fields.add(fc.input_names[1])

    if fc.feature_type == fc.RawFeature and fc.raw_input_dim > 1:
      self._multi_value_types[fc.input_names[0]] = tf.float32
      self._multi_value_fields.add(fc.input_names[0])

    if fc.HasField('normalizer_fn'):
      feature_name = fc.feature_name if fc.HasField(
          'feature_name') else fc.input_names[0]
      self._normalizer_fn[feature_name] = load_by_path(fc.normalizer_fn)

  # add sample weight to effective fields
  if self._data_config.HasField('sample_weight'):
    self._effective_fields.append(self._data_config.sample_weight)

  def _add_metric_fields(metrics_set):
    # add uid_field of GAUC and session_id_field of SessionAUC
    for metric in metrics_set:
      metric_name = metric.WhichOneof('metric')
      if metric_name == 'gauc':
        metric_field = metric.gauc.uid_field
      elif metric_name == 'session_auc':
        metric_field = metric.session_auc.session_id_field
      else:
        continue
      if metric_field not in self._effective_fields:
        self._effective_fields.append(metric_field)

  if self._pipeline_config is not None:
    _add_metric_fields(self._pipeline_config.eval_config.metrics_set)
    # check multi task model's metrics
    model_config = self._pipeline_config.model_config
    model_name = model_config.WhichOneof('model')
    if model_name in {'mmoe', 'esmm', 'dbmtl', 'simple_multi_task', 'ple'}:
      model = getattr(model_config, model_name)
      # esmm exposes two named towers; the other models expose task_towers
      towers = [model.ctr_tower, model.cvr_tower
               ] if model_name == 'esmm' else model.task_towers
      for tower in towers:
        _add_metric_fields(tower.metrics_set)

  # Deduplicate the column indices and sort them from small to large.
  # A set alone does not guarantee ascending iteration order, so the
  # sort has to be explicit.
  self._effective_fids = sorted(
      {self._input_fields.index(x) for x in self._effective_fields})
  # rebuild the names so they line up with the sorted indices
  self._effective_fields = [
      self._input_fields[x] for x in self._effective_fids
  ]

  self._label_fids = [self._input_fields.index(x) for x in self._label_fields]

  # virtual fields generated by self._preprocess
  # which will be inputs to feature columns
  self._appended_fields = []

  # sampler
  self._sampler = None
  if input_path is not None:
    # build sampler only when train and eval
    self._sampler = sampler_lib.build(data_config)

  self.get_type_defaults = get_type_defaults