in easy_rec/python/feature_column/feature_column.py [0:0]
def __init__(self,
             feature_configs,
             wide_deep_dict=None,
             wide_output_dim=-1,
             ev_params=None):
  """Initializes a `FeatureColumnParser`.

  Dispatches every feature config to the matching `parse_*` method, then
  builds shared embedding columns for each embedding_name used by more than
  one feature config, and finally swaps `SharedEmbedding` placeholders found
  in the deep/wide/sequence column dicts for the built shared columns.

  Args:
    feature_configs: collections of
      easy_rec.python.protos.feature_config_pb2.FeatureConfig
      or easy_rec.python.protos.feature_config_pb2.FeatureConfigV2.features
    wide_deep_dict: dict of {feature_name:WideOrDeep}, passed by
      easy_rec.python.layers.input_layer.InputLayer, it is defined in
      easy_rec.python.protos.easy_rec_model_pb2.EasyRecModel.feature_groups.
      When omitted (None), a fresh empty dict is used for this instance.
    wide_output_dim: output dimension for wide columns
    ev_params: params used by EmbeddingVariable, which is provided by pai-tf

  Raises:
    AssertionError: if a config is not a FeatureConfig, has an unsupported
      feature_type, or shares an embedding_name with another config whose
      embedding settings differ.
  """
  self._feature_configs = feature_configs
  self._wide_output_dim = wide_output_dim
  # Fresh dict per instance: a `{}` default would be one object shared by
  # every parser created without this argument.
  self._wide_deep_dict = {} if wide_deep_dict is None else wide_deep_dict
  self._deep_columns = {}
  self._wide_columns = {}
  self._sequence_columns = {}
  # embedding_name -> number of feature configs using it
  self._share_embed_names = {}
  # embedding_name -> copied config that defines the shared embedding
  self._share_embed_infos = {}
  self._vocab_size = {}
  self._global_ev_params = None
  if ev_params is not None:
    self._global_ev_params = self._build_ev_params(ev_params)

  def _cmp_embed_config(a, b):
    """Returns True if two configs agree on all shared-embedding settings."""
    return (a.embedding_dim == b.embedding_dim and a.combiner == b.combiner and
            a.initializer == b.initializer and
            a.max_partitions == b.max_partitions and
            a.embedding_name == b.embedding_name)

  # Count how many feature configs reference each embedding_name.
  for config in self._feature_configs:
    if not config.HasField('embedding_name'):
      continue
    embed_name = config.embedding_name
    if embed_name in self._share_embed_names:
      # Explicit raise (not `assert`) so the check survives `python -O`.
      if not _cmp_embed_config(config, self._share_embed_infos[embed_name]):
        raise AssertionError(
            'shared embed info of [%s] is not matched [%s] vs [%s]' %
            (embed_name, config, self._share_embed_infos[embed_name]))
      self._share_embed_names[embed_name] += 1
      # A sequence feature's config becomes the reference info, so its
      # settings (e.g. max_seq_len, read below) win for the shared embedding.
      if config.feature_type == FeatureConfig.FeatureType.SequenceFeature:
        self._share_embed_infos[embed_name] = copy_obj(config)
    else:
      self._share_embed_names[embed_name] = 1
      self._share_embed_infos[embed_name] = copy_obj(config)

  # remove not shared embedding names
  not_shared = [
      name for name, count in self._share_embed_names.items() if count == 1
  ]
  for embed_name in not_shared:
    del self._share_embed_names[embed_name]
    del self._share_embed_infos[embed_name]

  logging.info('shared embeddings[num=%d]', len(self._share_embed_names))
  for embed_name in self._share_embed_names:
    logging.info('\t%s: share_num[%d], share_info[%s]', embed_name,
                 self._share_embed_names[embed_name],
                 self._share_embed_infos[embed_name])

  self._deep_share_embed_columns = {
      embed_name: [] for embed_name in self._share_embed_names
  }
  self._wide_share_embed_columns = {
      embed_name: [] for embed_name in self._share_embed_names
  }
  self._feature_vocab_size = {}

  for config in self._feature_configs:
    if not isinstance(config, FeatureConfig):
      raise AssertionError('invalid feature config: %s' % type(config))
    try:
      if config.feature_type == config.IdFeature:
        self.parse_id_feature(config)
      elif config.feature_type == config.TagFeature:
        self.parse_tag_feature(config)
      elif config.feature_type == config.RawFeature:
        self.parse_raw_feature(config)
      elif config.feature_type == config.ComboFeature:
        self.parse_combo_feature(config)
      elif config.feature_type == config.LookupFeature:
        self.parse_lookup_feature(config)
      elif config.feature_type == config.SequenceFeature:
        self.parse_sequence_feature(config)
      elif config.feature_type == config.ExprFeature:
        self.parse_expr_feature(config)
      elif config.feature_type != config.PassThroughFeature:
        raise AssertionError('invalid feature type: %s' % config.feature_type)
    except FeatureKeyError:
      # NOTE(review): FeatureKeyError is defined outside this view; it is
      # presumably raised when a feature's input key is unavailable. Such
      # features are deliberately skipped instead of failing construction.
      pass

  def _make_shared_columns(columns, dimension, collection_name, combiner,
                           share_info, initializer, partitioner,
                           embed_ev_params):
    """Builds shared embedding columns and stamps their max_seq_length."""
    share_embed_fcs = feature_column.shared_embedding_columns(
        columns,
        dimension,
        initializer=initializer,
        shared_embedding_collection_name=collection_name,
        combiner=combiner,
        partitioner=partitioner,
        ev_params=embed_ev_params)
    max_seq_len = share_info.max_seq_len if share_info.HasField(
        'max_seq_len') else -1
    for fc in share_embed_fcs:
      fc.max_seq_length = max_seq_len
    return share_embed_fcs

  for embed_name in self._share_embed_names:
    share_info = self._share_embed_infos[embed_name]
    initializer = None
    if share_info.HasField('initializer'):
      initializer = hyperparams_builder.build_initializer(
          share_info.initializer)
    partitioner = self._build_partitioner(share_info)
    if share_info.HasField('ev_params'):
      embed_ev_params = self._build_ev_params(share_info.ev_params)
    else:
      embed_ev_params = self._global_ev_params
    # for handling share embedding columns
    if len(self._deep_share_embed_columns[embed_name]) > 0:
      self._deep_share_embed_columns[embed_name] = _make_shared_columns(
          self._deep_share_embed_columns[embed_name], share_info.embedding_dim,
          embed_name, share_info.combiner, share_info, initializer,
          partitioner, embed_ev_params)
    # for handling wide share embedding columns
    if len(self._wide_share_embed_columns[embed_name]) > 0:
      self._wide_share_embed_columns[embed_name] = _make_shared_columns(
          self._wide_share_embed_columns[embed_name], self._wide_output_dim,
          embed_name + '_wide', 'sum', share_info, initializer, partitioner,
          embed_ev_params)

  # Replace SharedEmbedding placeholders with the shared columns built above;
  # only wide columns pass deep=False, exactly as before.
  for columns, call_kwargs in ((self._deep_columns, {}),
                               (self._wide_columns, {'deep': False}),
                               (self._sequence_columns, {})):
    for fc_name, fc in list(columns.items()):
      if isinstance(fc, SharedEmbedding):
        columns[fc_name] = self._get_shared_embedding_column(fc, **call_kwargs)