in easy_rec/python/tools/feature_selection.py [0:0]
def _process_config(self, feature_importance_map):
    """Write pruned copies of the pipeline config and the optional fg json.

    For every group in ``feature_importance_map`` the entries located at
    position ``self._topk`` or later (in the mapping's iteration order) are
    marked for removal.  Names referenced by any sequence attention map
    (``key`` / ``hist_seq``) are never removed.  The remaining features are
    written back into the pipeline config, which is saved under
    ``self._output_dir`` with the same base name as ``self._config_path``.
    When ``self._fg_path`` is a non-empty value, the fg json is filtered the
    same way and saved next to the config.

    NOTE(review): presumably each per-group mapping iterates from most to
    least important; that ordering is not established here -- confirm
    against the caller.

    Args:
      feature_importance_map: mapping of group name to a mapping of
        feature name to importance value.
    """
    # Anything ranked at or beyond top-k in any group is a removal candidate.
    dropped = set()
    for _, importance in feature_importance_map.items():
        dropped.update(
            name
            for rank, (name, _) in enumerate(importance.items())
            if rank >= self._topk)

    config = config_util.get_configs_from_pipeline_file(self._config_path)

    # Sequence features and their side-infos must survive the pruning.
    protected = set()

    def _protect(seq_groups):
        """Record every key / hist_seq name used by the given groups."""
        for seq_group in seq_groups:
            for att_map in seq_group.seq_att_map:
                protected.update(att_map.key)
                protected.update(att_map.hist_seq)

    for group in config.model_config.feature_groups:
        _protect(group.sequence_features)
    # Model-level seq_att_groups are scanned too (DIN compatibility).
    _protect(config.model_config.seq_att_groups)
    dropped -= protected

    def _name_of(feature_cfg):
        """Return the explicit feature_name, else the first input name."""
        if feature_cfg.HasField('feature_name'):
            return feature_cfg.feature_name
        return feature_cfg.input_names[0]

    # Collect the survivors before clearing the field they are read from.
    kept_configs = [
        feature_cfg
        for feature_cfg in config_util.get_compatible_feature_configs(config)
        if _name_of(feature_cfg) not in dropped
    ]
    if config.feature_configs:
        config.ClearField('feature_configs')
        config.feature_configs.extend(kept_configs)
    else:
        config.feature_config.ClearField('features')
        config.feature_config.features.extend(kept_configs)

    # Rewrite each group's feature list without the dropped names.
    for group in config.model_config.feature_groups:
        kept_names = [
            name for name in group.feature_names if name not in dropped
        ]
        group.ClearField('feature_names')
        group.feature_names.extend(kept_names)

    config_out_path = os.path.join(self._output_dir,
                                   os.path.basename(self._config_path))
    config_util.save_message(config, config_out_path)

    # The fg json is optional; nothing more to do when no path was given.
    if self._fg_path is None or len(self._fg_path) == 0:
        return

    with tf.gfile.Open(self._fg_path) as fin:
        fg_json = json.load(fin, object_pairs_hook=OrderedDict)
    # Entries without a 'feature_name' key are always kept.
    fg_json['features'] = [
        entry for entry in fg_json['features']
        if 'feature_name' not in entry or
        entry['feature_name'] not in dropped
    ]

    fg_out_path = os.path.join(self._output_dir,
                               os.path.basename(self._fg_path))
    with tf.gfile.Open(fg_out_path, 'w') as fout:
        json.dump(fg_json, fout, indent=4)