tzrec/tools/convert_easyrec_config_to_tzrec_config.py (689 lines of code) (raw):

# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#    http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import io
import json
import os
import sys
import tarfile
import tempfile
import zipfile
from collections import OrderedDict

import requests
from google.protobuf import descriptor_pool, symbol_database, text_format

from tzrec.constant import EASYREC_VERSION
from tzrec.protos import feature_pb2 as tzrec_feature_pb2
from tzrec.protos import (
    loss_pb2,
    metric_pb2,
    model_pb2,
    module_pb2,
    seq_encoder_pb2,
    tower_pb2,
)
from tzrec.protos import pipeline_pb2 as tzrec_pipeline_pb2
from tzrec.protos.data_pb2 import DatasetType
from tzrec.protos.models import match_model_pb2, multi_task_rank_pb2, rank_model_pb2
from tzrec.utils.logging_util import logger


def _get_easyrec(pkg_path=None):
    """Get easyrec whl and extract.

    The package is downloaded (http/https) or read from disk, extracted into a
    fresh temporary directory that is appended to ``sys.path``, and the
    EasyRec proto modules are published as the module globals
    ``easyrec_pipeline_pb2``, ``easyrec_feature_config_pb2`` and
    ``easyrec_loss_pb2``.

    Args:
        pkg_path (str, optional): local path or url of an EasyRec ``.whl`` or
            ``.tar*`` package. Defaults to the release wheel of
            ``EASYREC_VERSION``.

    Raises:
        RuntimeError: if the package cannot be downloaded or extracted.
    """
    # Not cleaned up on purpose: the directory stays on sys.path.
    local_cache_dir = tempfile.mkdtemp(prefix="tzrec_tmp")
    if pkg_path is None:
        pkg_path = (
            f"https://easyrec.oss-cn-beijing.aliyuncs.com/release/whls/"
            f"easy_rec-{EASYREC_VERSION}-py2.py3-none-any.whl"
        )
    if pkg_path.startswith("http"):
        logger.info(f"downloading easyrec from {pkg_path}")
        try:
            # timeout bounds each socket wait, not the whole download.
            r = requests.get(pkg_path, timeout=60)
            # Fail loudly on HTTP errors instead of trying to unzip an
            # error page.
            r.raise_for_status()
        except requests.RequestException as e:
            raise RuntimeError(f"download {pkg_path} failed.") from e
        content = r.content
    else:
        with open(pkg_path, "rb") as f:
            content = f.read()

    if ".tar" in pkg_path:
        try:
            with tarfile.open(fileobj=io.BytesIO(content)) as tar:
                # NOTE: extractall does not sanitize member paths unless an
                # extraction filter is in effect; only use trusted packages.
                tar.extractall(path=local_cache_dir)
        except Exception as e:
            raise RuntimeError(f"invalid {pkg_path} tar.") from e
    else:
        try:
            with zipfile.ZipFile(io.BytesIO(content)) as f:
                f.extractall(local_cache_dir)
        except zipfile.BadZipfile as e:
            raise RuntimeError(f"invalid {pkg_path} whl.") from e
    local_package_dir = local_cache_dir

    # Blank the package __init__ so importing easy_rec.python.protos does not
    # run it -- presumably to avoid pulling EasyRec's runtime dependencies;
    # only the generated *_pb2 modules are needed here.
    with open(os.path.join(local_package_dir, "easy_rec/__init__.py"), "w") as f:
        f.write("")
    sys.path.append(local_package_dir)

    # Swap in a fresh descriptor pool before importing EasyRec's protos,
    # presumably so they can register without clashing with TzRec's protos.
    _sym = symbol_database.Default()
    _sym.pool = descriptor_pool.DescriptorPool()
    from easy_rec.python.protos import feature_config_pb2 as _feature_config_pb2
    from easy_rec.python.protos import loss_pb2 as _loss_pb2
    from easy_rec.python.protos import pipeline_pb2 as _pipeline_pb2

    globals()["easyrec_pipeline_pb2"] = _pipeline_pb2
    globals()["easyrec_feature_config_pb2"] = _feature_config_pb2
    globals()["easyrec_loss_pb2"] = _loss_pb2


class ConvertConfig(object):
    """Convert EasyRec config to tzrec config.

    Args:
        easyrec_config_path (str): EasyRec config file path.
        output_tzrec_config_path (str): TzRec config file path will create.
        fg_json_path (str, optional): EasyRec use fg.json file path. When it is
            not given (or has no non-sequence feature), feature expressions are
            built as ``user:<input_name>``.
        easyrec_package_path (str, optional): EasyRec whl/tar package path or
            url; only used when the EasyRec protos are not loaded yet.
    """

    def __init__(
        self,
        easyrec_config_path,
        output_tzrec_config_path,
        fg_json_path=None,
        easyrec_package_path=None,
    ):
        # The EasyRec proto modules are injected into globals() lazily.
        if "easyrec_pipeline_pb2" not in globals():
            _get_easyrec(easyrec_package_path)
        self.output_tzrec_config_path = output_tzrec_config_path
        self.easyrec_config = self.load_easyrec_config(easyrec_config_path)
        # feature_name -> fg.json entry of a non-sequence feature.
        self.feature_to_fg = {}
        # "<sequence_name>__<sub feature_name>" -> sequence_name.
        self.sub_sequence_to_group = {}
        # sequence_name -> fg.json entry of the whole sequence feature.
        self.sequence_feature_to_fg = {}
        if fg_json_path is not None:
            fg_json = self.load_easyrec_fg_json(fg_json_path)
            self.analyse_fg(fg_json)

    def analyse_fg(self, fg_json):
        """Analysis fg.json.

        Entries of ``fg_json["features"]`` that carry a ``sequence_name`` are
        sequence features; their sub features are registered under
        ``<sequence_name>__<feature_name>``. All other entries are indexed by
        their ``feature_name``.
        """
        for feat in fg_json["features"]:
            if "sequence_name" in feat:
                sequence_name = feat["sequence_name"]
                for sub_feat in feat["features"]:
                    sub_name = f"{sequence_name}__{sub_feat['feature_name']}"
                    self.sub_sequence_to_group[sub_name] = sequence_name
                self.sequence_feature_to_fg[sequence_name] = feat
            else:
                self.feature_to_fg[feat["feature_name"]] = feat

    def load_easyrec_config(self, path):
        """Load easyrec config from a protobuf text file."""
        easyrec_config = easyrec_pipeline_pb2.EasyRecConfig()  # NOQA
        with open(path, "r", encoding="utf-8") as f:
            cfg_str = f.read()
        text_format.Merge(cfg_str, easyrec_config)
        return easyrec_config

    def load_easyrec_fg_json(self, path):
        """Load easyrec use fg.json."""
        with open(path, "r", encoding="utf-8") as f:
            fg_json = json.load(f)
        return fg_json

    def _create_train_config(self, pipeline_config):
        """Create tzrec train config (Adam, lr 0.001, 1 epoch) if absent."""
        if not pipeline_config.HasField("train_config"):
            train_config_str = """
            train_config {
                sparse_optimizer {
                    adam_optimizer {
                        lr: 0.001
                    }
                    constant_learning_rate {
                    }
                }
                dense_optimizer {
                    adam_optimizer {
                        lr: 0.001
                    }
                    constant_learning_rate {
                    }
                }
                num_epochs: 1
                use_tensorboard: false
            }"""
            text_format.Merge(train_config_str, pipeline_config)
        return pipeline_config

    def _create_eval_config(self, pipeline_config):
        """Create an empty tzrec eval config if absent."""
        if not pipeline_config.HasField("eval_config"):
            eval_config_str = "eval_config {}"
            text_format.Merge(eval_config_str, pipeline_config)
        return pipeline_config

    def _create_data_config(self, pipeline_config):
        """Create tzrec data config.

        Batch size and label fields come from EasyRec; the dataset type is
        always ``OdpsDataset`` with 8 workers.
        """
        label_fields = list(self.easyrec_config.data_config.label_fields)
        pipeline_config.data_config.batch_size = (
            self.easyrec_config.data_config.batch_size
        )
        pipeline_config.data_config.dataset_type = DatasetType.OdpsDataset
        pipeline_config.data_config.label_fields.extend(label_fields)
        pipeline_config.data_config.num_workers = 8
        return pipeline_config

    def _create_feature_config(self, pipeline_config):
        """Create tzrec feature config from easyrec feature configs and fg.json.

        Feature expressions are taken from fg.json. Easyrec features that are
        sub features of an fg.json sequence are collected per sequence and
        emitted afterwards as one ``sequence_feature`` each.

        Raises:
            ValueError: if an input of a combo/lookup feature is not in fg.json.
        """
        easyrec_feature_config = easyrec_feature_config_pb2.FeatureConfig()  # NOQA
        # Conversions of these types copy fields from the feature's own entry.
        fg_required_types = (
            easyrec_feature_config.IdFeature,
            easyrec_feature_config.TagFeature,
            easyrec_feature_config.RawFeature,
        )
        seq_group_cfg = OrderedDict()
        for cfg in self.easyrec_config.feature_configs:
            if cfg.feature_name:
                feature_name = cfg.feature_name
            else:
                feature_name = list(cfg.input_names)[0]
            input_names = cfg.input_names
            feature_type = cfg.feature_type

            # Reset for every feature so a feature missing from fg.json can
            # never pick up the entry of the feature converted before it.
            fg_json = None
            if feature_name in self.feature_to_fg:
                fg_json = self.feature_to_fg[feature_name]
            elif feature_name in self.sub_sequence_to_group:
                pass
            elif input_names[0] in self.feature_to_fg:
                fg_json = self.feature_to_fg[input_names[0]]
            else:
                logger.error(f"in easyrec config {feature_name} not in fg.json")

            if fg_json is None and feature_type in fg_required_types:
                logger.error(f"{feature_name} can't converted")
                continue

            feature_config = None
            if feature_type == easyrec_feature_config.IdFeature:
                feature = tzrec_feature_pb2.IdFeature()
                feature.feature_name = feature_name
                feature.expression = fg_json["expression"]
                feature.embedding_dim = cfg.embedding_dim
                feature.hash_bucket_size = cfg.hash_bucket_size
                feature_config = tzrec_feature_pb2.FeatureConfig()
                feature_config.id_feature.CopyFrom(feature)
            elif feature_type == easyrec_feature_config.TagFeature:
                feature = tzrec_feature_pb2.IdFeature()
                feature.feature_name = feature_name
                feature.expression = fg_json["expression"]
                feature.embedding_dim = cfg.embedding_dim
                feature.hash_bucket_size = cfg.hash_bucket_size
                # A kv separator means the tags carry weights.
                if cfg.HasField("kv_separator"):
                    feature.weighted = True
                feature_config = tzrec_feature_pb2.FeatureConfig()
                feature_config.id_feature.CopyFrom(feature)
            elif feature_type == easyrec_feature_config.SequenceFeature:
                if feature_name in self.sub_sequence_to_group:
                    # Sub feature of an fg.json sequence: emitted per group below.
                    sequence_name = self.sub_sequence_to_group[feature_name]
                    seq_group_cfg.setdefault(sequence_name, []).append(cfg)
                elif fg_json is not None:
                    feature_config = tzrec_feature_pb2.FeatureConfig()
                    if cfg.sub_feature_type == easyrec_feature_config.IdFeature:
                        feature = tzrec_feature_pb2.SequenceIdFeature()
                        feature.feature_name = feature_name
                        feature.expression = fg_json["expression"]
                        feature.embedding_dim = cfg.embedding_dim
                        feature.hash_bucket_size = cfg.hash_bucket_size
                        feature_config.sequence_id_feature.CopyFrom(feature)
                    else:
                        feature = tzrec_feature_pb2.SequenceRawFeature()
                        feature.feature_name = feature_name
                        feature.expression = fg_json["expression"]
                        feature.embedding_dim = cfg.embedding_dim
                        boundaries = list(cfg.boundaries)
                        if boundaries:
                            feature.boundaries.extend(boundaries)
                        feature_config.sequence_raw_feature.CopyFrom(feature)
                else:
                    logger.error(f"sequences feature: {feature_name} can't converted")
            elif feature_type == easyrec_feature_config.RawFeature:
                feature_config = tzrec_feature_pb2.FeatureConfig()
                boundaries = list(cfg.boundaries)
                if fg_json.get("feature_type") == "lookup_feature":
                    feature = tzrec_feature_pb2.LookupFeature()
                    feature.feature_name = feature_name
                    feature.map = fg_json["map"]
                    feature.key = fg_json["key"]
                    feature.embedding_dim = cfg.embedding_dim
                    if boundaries:
                        feature.boundaries.extend(boundaries)
                    feature_config.lookup_feature.CopyFrom(feature)
                else:
                    feature = tzrec_feature_pb2.RawFeature()
                    feature.feature_name = feature_name
                    feature.expression = fg_json["expression"]
                    feature.embedding_dim = cfg.embedding_dim
                    if boundaries:
                        feature.boundaries.extend(boundaries)
                    feature_config.raw_feature.CopyFrom(feature)
            elif feature_type == easyrec_feature_config.ComboFeature:
                feature = tzrec_feature_pb2.ComboFeature()
                feature.feature_name = feature_name
                for input_name in cfg.input_names:
                    if input_name not in self.feature_to_fg:
                        raise ValueError(
                            f"{cfg} input_names:{input_name} not in fg json"
                        )
                    feature.expression.append(
                        self.feature_to_fg[input_name]["expression"]
                    )
                feature.embedding_dim = cfg.embedding_dim
                feature.hash_bucket_size = cfg.hash_bucket_size
                feature_config = tzrec_feature_pb2.FeatureConfig()
                feature_config.combo_feature.CopyFrom(feature)
            elif feature_type == easyrec_feature_config.LookupFeature:
                feature = tzrec_feature_pb2.LookupFeature()
                feature.feature_name = feature_name
                # EasyRec lookup features list the map input first, key second.
                map_f = cfg.input_names[0]
                key_f = cfg.input_names[1]
                if map_f not in self.feature_to_fg:
                    raise ValueError(f"{cfg} input names: {map_f} not in fg.json")
                if key_f not in self.feature_to_fg:
                    raise ValueError(f"{cfg} input names: {key_f} not in fg.json")
                feature.map = self.feature_to_fg[map_f]["expression"]
                feature.key = self.feature_to_fg[key_f]["expression"]
                feature.embedding_dim = cfg.embedding_dim
                boundaries = list(cfg.boundaries)
                if boundaries:
                    feature.boundaries.extend(boundaries)
                feature_config = tzrec_feature_pb2.FeatureConfig()
                feature_config.lookup_feature.CopyFrom(feature)
            else:
                logger.error(f"{feature_name} can't converted")

            if feature_config is not None:
                pipeline_config.feature_configs.append(feature_config)

        for seq_name, sub_cfgs in seq_group_cfg.items():
            sequence_fg = self.sequence_feature_to_fg[seq_name]
            sequence_feature_config = tzrec_feature_pb2.SequenceFeature()
            sequence_feature_config.sequence_name = sequence_fg["sequence_name"]
            sequence_feature_config.sequence_length = sequence_fg["sequence_length"]
            sequence_feature_config.sequence_delim = sequence_fg["sequence_delim"]
            seq_feature_to_fg = {
                f"{seq_name}__{sub_fg['feature_name']}": sub_fg
                for sub_fg in sequence_fg["features"]
            }
            for cfg in sub_cfgs:
                feature_name = (
                    cfg.feature_name if cfg.feature_name else cfg.input_names[0]
                )
                if feature_name not in seq_feature_to_fg:
                    logger.error(
                        f"sequence feature: {feature_name} not config in fg.json"
                    )
                    continue
                seq_feature_fg = seq_feature_to_fg[feature_name]
                sub_feature_cfg = tzrec_feature_pb2.SeqFeatureConfig()
                if cfg.sub_feature_type == easyrec_feature_config.IdFeature:
                    feature = tzrec_feature_pb2.IdFeature()
                    feature.feature_name = seq_feature_fg["feature_name"]
                    feature.expression = seq_feature_fg["expression"]
                    feature.embedding_dim = cfg.embedding_dim
                    feature.hash_bucket_size = cfg.hash_bucket_size
                    sub_feature_cfg.id_feature.CopyFrom(feature)
                else:
                    feature = tzrec_feature_pb2.RawFeature()
                    feature.feature_name = seq_feature_fg["feature_name"]
                    feature.expression = seq_feature_fg["expression"]
                    feature.embedding_dim = cfg.embedding_dim
                    boundaries = list(cfg.boundaries)
                    if boundaries:
                        feature.boundaries.extend(boundaries)
                    sub_feature_cfg.raw_feature.CopyFrom(feature)
                sequence_feature_config.features.append(sub_feature_cfg)
            feature_config = tzrec_feature_pb2.FeatureConfig()
            feature_config.sequence_feature.CopyFrom(sequence_feature_config)
            pipeline_config.feature_configs.append(feature_config)
        return pipeline_config

    def _create_feature_config_no_fg(self, pipeline_config):
        """Create tzrec feature config no fg json.

        Every expression is built as ``user:<input_name>``.
        """
        easyrec_feature_config = easyrec_feature_config_pb2.FeatureConfig()  # NOQA
        for cfg in self.easyrec_config.feature_configs:
            if cfg.feature_name:
                feature_name = cfg.feature_name
            else:
                feature_name = list(cfg.input_names)[0]
            input_names = cfg.input_names
            feature_type = cfg.feature_type
            feature_config = None
            if feature_type == easyrec_feature_config.IdFeature:
                feature = tzrec_feature_pb2.IdFeature()
                feature.feature_name = feature_name
                feature.expression = f"user:{input_names[0]}"
                feature.embedding_dim = cfg.embedding_dim
                feature.hash_bucket_size = cfg.hash_bucket_size
                feature_config = tzrec_feature_pb2.FeatureConfig()
                feature_config.id_feature.CopyFrom(feature)
            elif feature_type == easyrec_feature_config.TagFeature:
                feature = tzrec_feature_pb2.IdFeature()
                feature.feature_name = feature_name
                feature.expression = f"user:{input_names[0]}"
                feature.embedding_dim = cfg.embedding_dim
                feature.hash_bucket_size = cfg.hash_bucket_size
                # A kv separator means the tags carry weights.
                if cfg.HasField("kv_separator"):
                    feature.weighted = True
                feature_config = tzrec_feature_pb2.FeatureConfig()
                feature_config.id_feature.CopyFrom(feature)
            elif feature_type == easyrec_feature_config.SequenceFeature:
                feature_config = tzrec_feature_pb2.FeatureConfig()
                if cfg.sub_feature_type == easyrec_feature_config.RawFeature:
                    feature = tzrec_feature_pb2.SequenceRawFeature()
                    feature.feature_name = feature_name
                    feature.expression = f"user:{input_names[0]}"
                    feature.sequence_length = cfg.sequence_length
                    feature.sequence_delim = cfg.separator
                    feature.embedding_dim = cfg.embedding_dim
                    boundaries = list(cfg.boundaries)
                    if boundaries:
                        feature.boundaries.extend(boundaries)
                    feature_config.sequence_raw_feature.CopyFrom(feature)
                else:
                    feature = tzrec_feature_pb2.SequenceIdFeature()
                    feature.feature_name = feature_name
                    feature.expression = f"user:{input_names[0]}"
                    feature.sequence_length = cfg.sequence_length
                    feature.sequence_delim = cfg.separator
                    feature.embedding_dim = cfg.embedding_dim
                    if cfg.HasField("hash_bucket_size"):
                        feature.hash_bucket_size = cfg.hash_bucket_size
                    if cfg.HasField("num_buckets"):
                        feature.num_buckets = cfg.num_buckets
                    feature_config.sequence_id_feature.CopyFrom(feature)
                if cfg.sequence_length <= 1:
                    logger.error(f"{feature_name} sequence_length is invalid !!!")
            elif feature_type == easyrec_feature_config.RawFeature:
                feature = tzrec_feature_pb2.RawFeature()
                feature.feature_name = feature_name
                feature.expression = f"user:{input_names[0]}"
                if cfg.HasField("embedding_dim"):
                    feature.embedding_dim = cfg.embedding_dim
                boundaries = list(cfg.boundaries)
                if boundaries:
                    feature.boundaries.extend(boundaries)
                feature_config = tzrec_feature_pb2.FeatureConfig()
                feature_config.raw_feature.CopyFrom(feature)
            elif feature_type == easyrec_feature_config.ComboFeature:
                feature = tzrec_feature_pb2.ComboFeature()
                feature.feature_name = feature_name
                for input_name in cfg.input_names:
                    feature.expression.append(f"user:{input_name}")
                feature.embedding_dim = cfg.embedding_dim
                feature.hash_bucket_size = cfg.hash_bucket_size
                feature_config = tzrec_feature_pb2.FeatureConfig()
                feature_config.combo_feature.CopyFrom(feature)
            elif feature_type == easyrec_feature_config.LookupFeature:
                feature = tzrec_feature_pb2.LookupFeature()
                feature.feature_name = feature_name
                # EasyRec lookup features list the map input first, key second.
                feature.map = f"user:{input_names[0]}"
                feature.key = f"user:{input_names[1]}"
                if cfg.HasField("embedding_dim"):
                    feature.embedding_dim = cfg.embedding_dim
                boundaries = list(cfg.boundaries)
                if boundaries:
                    feature.boundaries.extend(boundaries)
                feature_config = tzrec_feature_pb2.FeatureConfig()
                feature_config.lookup_feature.CopyFrom(feature)
            else:
                logger.error(f"{feature_name} can't converted")
            if feature_config is not None:
                logger.info(f"{feature_name} converted succeeded")
                pipeline_config.feature_configs.append(feature_config)
        return pipeline_config

    def _easyrec_dnn_2_tzrec_mlp(self, dnn):
        """Convert easyrec dnn to tzrec mlp."""
        mlp = module_pb2.MLP()
        mlp.hidden_units.extend(dnn.hidden_units)
        mlp.dropout_ratio.extend(dnn.dropout_ratio)
        mlp.use_bn = dnn.use_bn
        return mlp

    def _easyrec_loss_2_tzrec_loss(self, easyrec_loss):
        """Convert easyrec loss to tzrec loss.

        Unsupported loss types are logged and yield an empty LossConfig.
        """
        tzrec_loss = loss_pb2.LossConfig()
        loss_type = easyrec_loss.loss_type
        if loss_type == easyrec_loss_pb2.LossType.JRC_LOSS:  # NOQA
            tzrec_loss.jrc_loss.CopyFrom(loss_pb2.JRCLoss())
        elif loss_type == easyrec_loss_pb2.LossType.L2_LOSS:  # NOQA
            tzrec_loss.l2_loss.CopyFrom(loss_pb2.L2Loss())
        elif loss_type == easyrec_loss_pb2.LossType.SOFTMAX_CROSS_ENTROPY:  # NOQA
            tzrec_loss.softmax_cross_entropy.CopyFrom(loss_pb2.SoftmaxCrossEntropy())
        elif loss_type == easyrec_loss_pb2.LossType.CLASSIFICATION:  # NOQA
            tzrec_loss.binary_cross_entropy.CopyFrom(loss_pb2.BinaryCrossEntropy())
        else:
            logger.error(
                f"{easyrec_loss} is not convert to tzrec loss, please adaptation"
            )
        return tzrec_loss

    def _easyrec_metrics_2_tzrec_metrics(self, easyrec_metric):
        """Convert easyrec metric to tzrec metric.

        Unsupported (or unset) metrics are logged and yield an empty
        MetricConfig.
        """
        metric = metric_pb2.MetricConfig()
        # None when no metric is set; handled by the final else branch.
        metric_type = easyrec_metric.WhichOneof("metric")
        if metric_type == "auc":
            metric.auc.CopyFrom(metric_pb2.AUC())
        elif metric_type == "gauc":
            tzrec_metric_ob = metric_pb2.GroupedAUC(
                grouping_key=easyrec_metric.gauc.uid_field
            )
            metric.grouped_auc.CopyFrom(tzrec_metric_ob)
        elif metric_type == "recall_at_topk":
            metric.recall_at_k.CopyFrom(metric_pb2.RecallAtK())
        elif metric_type == "mean_absolute_error":
            metric.mean_absolute_error.CopyFrom(metric_pb2.MeanAbsoluteError())
        elif metric_type == "mean_squared_error":
            metric.mean_squared_error.CopyFrom(metric_pb2.MeanSquaredError())
        elif metric_type == "accuracy":
            metric.accuracy.CopyFrom(metric_pb2.Accuracy())
        else:
            logger.error(
                f"{easyrec_metric} is not convert to tzrec metric, please adaptation"
            )
        return metric

    def _easyrec_bayes_tower_2_tzrec_bayes_tower(self, easyrec_bayes_task_tower):
        """Convert easyrec bayes tower to tzrec bayes tower."""
        tzrec_bayes_task_tower = tower_pb2.BayesTaskTower()
        tzrec_bayes_task_tower.tower_name = easyrec_bayes_task_tower.tower_name
        tzrec_bayes_task_tower.label_name = easyrec_bayes_task_tower.label_name
        tzrec_bayes_task_tower.num_class = easyrec_bayes_task_tower.num_class
        tzrec_bayes_task_tower.relation_tower_names.extend(
            easyrec_bayes_task_tower.relation_tower_names
        )
        # Only copy DNNs that are configured; an unset DNN would otherwise
        # become an MLP without hidden units.
        if easyrec_bayes_task_tower.HasField("dnn"):
            mlp = self._easyrec_dnn_2_tzrec_mlp(easyrec_bayes_task_tower.dnn)
            tzrec_bayes_task_tower.mlp.CopyFrom(mlp)
        if easyrec_bayes_task_tower.HasField("relation_dnn"):
            relation_mlp = self._easyrec_dnn_2_tzrec_mlp(
                easyrec_bayes_task_tower.relation_dnn
            )
            tzrec_bayes_task_tower.relation_mlp.CopyFrom(relation_mlp)
        for loss in easyrec_bayes_task_tower.losses:
            tzrec_bayes_task_tower.losses.append(self._easyrec_loss_2_tzrec_loss(loss))
        for metric in easyrec_bayes_task_tower.metrics_set:
            tzrec_bayes_task_tower.metrics.append(
                self._easyrec_metrics_2_tzrec_metrics(metric)
            )
        return tzrec_bayes_task_tower

    def _easyrec_task_tower_2_tzrec_task_tower(self, easyrec_task_tower):
        """Convert easyrec task tower to tzrec task tower."""
        tzrec_task_tower = tower_pb2.TaskTower()
        tzrec_task_tower.tower_name = easyrec_task_tower.tower_name
        tzrec_task_tower.label_name = easyrec_task_tower.label_name
        tzrec_task_tower.num_class = easyrec_task_tower.num_class
        # An unset DNN would otherwise become an MLP without hidden units.
        if easyrec_task_tower.HasField("dnn"):
            mlp = self._easyrec_dnn_2_tzrec_mlp(easyrec_task_tower.dnn)
            tzrec_task_tower.mlp.CopyFrom(mlp)
        for loss in easyrec_task_tower.losses:
            tzrec_task_tower.losses.append(self._easyrec_loss_2_tzrec_loss(loss))
        for metric in easyrec_task_tower.metrics_set:
            tzrec_task_tower.metrics.append(
                self._easyrec_metrics_2_tzrec_metrics(metric)
            )
        return tzrec_task_tower

    def _easyrec_tower_2_tzrec_tower(self, easyrec_tower):
        """Convert easyrec tower to tzrec tower."""
        tzrec_tower = tower_pb2.Tower()
        tzrec_tower.input = easyrec_tower.input
        mlp = self._easyrec_dnn_2_tzrec_mlp(easyrec_tower.dnn)
        tzrec_tower.mlp.CopyFrom(mlp)
        return tzrec_tower

    def _easyrec_dssm_tower_2_tzrec_tower(self, easyrec_dssm_tower):
        """Convert easyrec dssm tower to tzrec tower (``id`` becomes ``input``)."""
        tzrec_tower = tower_pb2.Tower()
        tzrec_tower.input = easyrec_dssm_tower.id
        mlp = self._easyrec_dnn_2_tzrec_mlp(easyrec_dssm_tower.dnn)
        tzrec_tower.mlp.CopyFrom(mlp)
        return tzrec_tower

    def _easyrec_extraction_network_2_tzrec_extraction_network(
        self, easyrec_extraction_network
    ):
        """Convert easyrec extraction net to tzrec extraction net."""
        tzrec_extraction_network = module_pb2.ExtractionNetwork()
        tzrec_extraction_network.network_name = easyrec_extraction_network.network_name
        tzrec_extraction_network.expert_num_per_task = (
            easyrec_extraction_network.expert_num_per_task
        )
        tzrec_extraction_network.share_num = easyrec_extraction_network.share_num
        task_expert_net = self._easyrec_dnn_2_tzrec_mlp(
            easyrec_extraction_network.task_expert_net
        )
        tzrec_extraction_network.task_expert_net.CopyFrom(task_expert_net)
        share_expert_net = self._easyrec_dnn_2_tzrec_mlp(
            easyrec_extraction_network.share_expert_net
        )
        tzrec_extraction_network.share_expert_net.CopyFrom(share_expert_net)
        return tzrec_extraction_network

    def _convert_model_feature_group(self, easyrec_feature_groups):
        """Convert easyrec feature group to tzrec feature group.

        Every EasyRec sequence group becomes a tzrec sequence group encoded by
        a DIN encoder whose attention MLP comes from ``seq_dnn``.
        """
        tz_feature_groups = []
        for easy_feature_group in easyrec_feature_groups:
            tz_feature_group = model_pb2.FeatureGroupConfig()
            tz_feature_group.group_name = easy_feature_group.group_name
            tz_feature_group.feature_names.extend(easy_feature_group.feature_names)
            if (
                easy_feature_group.wide_deep
                == easyrec_feature_config_pb2.WideOrDeep.WIDE  # NOQA
            ):
                tz_feature_group.group_type = model_pb2.FeatureGroupType.WIDE
            else:
                tz_feature_group.group_type = model_pb2.FeatureGroupType.DEEP
            for i, easyrec_sequence_group in enumerate(
                easy_feature_group.sequence_features
            ):
                tz_seq_group = model_pb2.SeqGroupConfig()
                tz_seq_encoder = seq_encoder_pb2.SeqEncoderConfig()
                seq_encoder = seq_encoder_pb2.DINEncoder()
                if easyrec_sequence_group.HasField("group_name"):
                    group_name = easyrec_sequence_group.group_name
                else:
                    group_name = f"seq_{i}"
                tz_seq_group.group_name = group_name
                seq_encoder.input = group_name
                mlp = self._easyrec_dnn_2_tzrec_mlp(easyrec_sequence_group.seq_dnn)
                seq_encoder.attn_mlp.CopyFrom(mlp)
                tz_seq_encoder.din_encoder.CopyFrom(seq_encoder)
                for seq_att_map in easyrec_sequence_group.seq_att_map:
                    tz_seq_group.feature_names.extend(seq_att_map.key)
                    tz_seq_group.feature_names.extend(seq_att_map.hist_seq)
                    tz_seq_group.feature_names.extend(seq_att_map.aux_hist_seq)
                tz_feature_group.sequence_groups.append(tz_seq_group)
                tz_feature_group.sequence_encoders.append(tz_seq_encoder)
            tz_feature_groups.append(tz_feature_group)
        return tz_feature_groups

    def _convert_model_config(self, easyrec_model_config, tz_model_config):
        """Convert easyrec model config to tzrec model config.

        Supported model classes: DBMTL, SimpleMultiTask, MMoE, PLE, DeepFM,
        MultiTower and DSSM. Others are logged and left unconverted.
        """
        model_class = easyrec_model_config.model_class
        model_type = easyrec_model_config.WhichOneof("model")
        if model_type is None:
            logger.error(
                f"{model_class} is not convert to tzrec model, please adaptation"
            )
            return tz_model_config
        easyrec_model_config = getattr(easyrec_model_config, model_type)
        if model_class == "DBMTL":
            tz_model_config_ob = multi_task_rank_pb2.DBMTL()
            # Unset DNNs would otherwise become MLPs without hidden units.
            if easyrec_model_config.HasField("bottom_dnn"):
                bottom_mlp = self._easyrec_dnn_2_tzrec_mlp(
                    easyrec_model_config.bottom_dnn
                )
                tz_model_config_ob.bottom_mlp.CopyFrom(bottom_mlp)
            if easyrec_model_config.HasField("expert_dnn"):
                expert_mlp = self._easyrec_dnn_2_tzrec_mlp(
                    easyrec_model_config.expert_dnn
                )
                tz_model_config_ob.expert_mlp.CopyFrom(expert_mlp)
            tz_model_config_ob.num_expert = easyrec_model_config.num_expert
            for task_tower in easyrec_model_config.task_towers:
                tz_task_tower = self._easyrec_bayes_tower_2_tzrec_bayes_tower(
                    task_tower
                )
                tz_model_config_ob.task_towers.append(tz_task_tower)
            tz_model_config.dbmtl.CopyFrom(tz_model_config_ob)
        elif model_class == "SimpleMultiTask":
            tz_model_config_ob = multi_task_rank_pb2.SimpleMultiTask()
            for task_tower in easyrec_model_config.task_towers:
                tz_task_tower = self._easyrec_task_tower_2_tzrec_task_tower(task_tower)
                tz_model_config_ob.task_towers.append(tz_task_tower)
            tz_model_config.simple_multi_task.CopyFrom(tz_model_config_ob)
        elif model_class == "MMoE":
            tz_model_config_ob = multi_task_rank_pb2.MMoE()
            expert_mlp = self._easyrec_dnn_2_tzrec_mlp(easyrec_model_config.expert_dnn)
            tz_model_config_ob.expert_mlp.CopyFrom(expert_mlp)
            # NOTE(review): the gate MLP reuses expert_dnn; confirm intended.
            tz_model_config_ob.gate_mlp.CopyFrom(expert_mlp)
            tz_model_config_ob.num_expert = easyrec_model_config.num_expert
            for task_tower in easyrec_model_config.task_towers:
                tz_task_tower = self._easyrec_task_tower_2_tzrec_task_tower(task_tower)
                tz_model_config_ob.task_towers.append(tz_task_tower)
            tz_model_config.mmoe.CopyFrom(tz_model_config_ob)
        elif model_class == "PLE":
            tz_model_config_ob = multi_task_rank_pb2.PLE()
            for extraction_network in easyrec_model_config.extraction_networks:
                tz_extraction_network = (
                    self._easyrec_extraction_network_2_tzrec_extraction_network(
                        extraction_network
                    )
                )
                # Append to the PLE message that is copied in below; appending
                # to tz_model_config.ple would be overwritten by CopyFrom.
                tz_model_config_ob.extraction_networks.append(tz_extraction_network)
            for task_tower in easyrec_model_config.task_towers:
                tz_task_tower = self._easyrec_task_tower_2_tzrec_task_tower(task_tower)
                tz_model_config_ob.task_towers.append(tz_task_tower)
            tz_model_config.ple.CopyFrom(tz_model_config_ob)
        elif model_class == "DeepFM":
            tz_model_config_ob = rank_model_pb2.DeepFM()
            deep = self._easyrec_dnn_2_tzrec_mlp(easyrec_model_config.dnn)
            tz_model_config_ob.deep.CopyFrom(deep)
            if easyrec_model_config.HasField("final_dnn"):
                final = self._easyrec_dnn_2_tzrec_mlp(easyrec_model_config.final_dnn)
                tz_model_config_ob.final.CopyFrom(final)
            if easyrec_model_config.HasField("wide_output_dim"):
                tz_model_config_ob.wide_embedding_dim = (
                    easyrec_model_config.wide_output_dim
                )
            tz_model_config.deepfm.CopyFrom(tz_model_config_ob)
        elif model_class == "MultiTower":
            tz_model_config_ob = rank_model_pb2.MultiTower()
            for tower in easyrec_model_config.towers:
                tz_tower = self._easyrec_tower_2_tzrec_tower(tower)
                tz_model_config_ob.towers.append(tz_tower)
            if easyrec_model_config.HasField("final_dnn"):
                final = self._easyrec_dnn_2_tzrec_mlp(easyrec_model_config.final_dnn)
                tz_model_config_ob.final.CopyFrom(final)
            tz_model_config.multi_tower.CopyFrom(tz_model_config_ob)
        elif model_class == "DSSM":
            tz_model_config_ob = match_model_pb2.DSSM()
            user_tower = self._easyrec_dssm_tower_2_tzrec_tower(
                easyrec_model_config.user_tower
            )
            tz_model_config_ob.user_tower.CopyFrom(user_tower)
            item_tower = self._easyrec_dssm_tower_2_tzrec_tower(
                easyrec_model_config.item_tower
            )
            tz_model_config_ob.item_tower.CopyFrom(item_tower)
            # Fixed output dim; adjust the generated config if needed.
            tz_model_config_ob.output_dim = 32
            if hasattr(
                easyrec_model_config, "temperature"
            ) and easyrec_model_config.HasField("temperature"):
                tz_model_config_ob.temperature = easyrec_model_config.temperature
            tz_model_config.dssm.CopyFrom(tz_model_config_ob)
        else:
            logger.error(
                f"{model_class} is not convert to tzrec model, please adaptation"
            )
        return tz_model_config

    def _create_model_config(self, pipeline_config):
        """Convert easyrec model config (feature groups + model) to tzrec."""
        tz_model_config = model_pb2.ModelConfig()
        easyrec_model_config = self.easyrec_config.model_config
        easyrec_feature_groups = easyrec_model_config.feature_groups
        tz_feature_groups = self._convert_model_feature_group(easyrec_feature_groups)
        tz_model_config.feature_groups.extend(tz_feature_groups)
        tz_model_config = self._convert_model_config(
            easyrec_model_config, tz_model_config
        )
        pipeline_config.model_config.CopyFrom(tz_model_config)
        return pipeline_config

    def build(self):
        """Create tzrec model config order by easyrec config and fg file.

        The result is written as protobuf text (UTF-8) to
        ``output_tzrec_config_path``.
        """
        tzrec_config = tzrec_pipeline_pb2.EasyRecConfig()
        tzrec_config = self._create_train_config(tzrec_config)
        tzrec_config = self._create_eval_config(tzrec_config)
        tzrec_config = self._create_data_config(tzrec_config)
        if self.feature_to_fg:
            tzrec_config = self._create_feature_config(tzrec_config)
        else:
            tzrec_config = self._create_feature_config_no_fg(tzrec_config)
        tzrec_config = self._create_model_config(tzrec_config)
        config_text = text_format.MessageToString(tzrec_config, as_utf8=True)
        # as_utf8 may emit non-ASCII text, so do not rely on the locale encoding.
        with open(self.output_tzrec_config_path, "w", encoding="utf-8") as f:
            f.write(config_text)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--easyrec_config_path",
        type=str,
        default=None,
        help="easyrec model config path",
    )
    parser.add_argument(
        "--fg_json_path", type=str, default=None, help="easyrec use fg.json path"
    )
    parser.add_argument(
        "--output_tzrec_config_path",
        type=str,
        default=None,
        help="output tzrec config path",
    )
    parser.add_argument(
        "--easyrec_package_path",
        type=str,
        default=None,
        help="easyrec whl or tar package path or url",
    )
    args, extra_args = parser.parse_known_args()
    fs = ConvertConfig(
        args.easyrec_config_path,
        args.output_tzrec_config_path,
        args.fg_json_path,
        args.easyrec_package_path,
    )
    fs.build()