tzrec/tools/convert_easyrec_config_to_tzrec_config.py (689 lines of code) (raw):
# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import io
import json
import os
import sys
import tarfile
import tempfile
import zipfile
from collections import OrderedDict
import requests
from google.protobuf import descriptor_pool, symbol_database, text_format
from tzrec.constant import EASYREC_VERSION
from tzrec.protos import feature_pb2 as tzrec_feature_pb2
from tzrec.protos import (
loss_pb2,
metric_pb2,
model_pb2,
module_pb2,
seq_encoder_pb2,
tower_pb2,
)
from tzrec.protos import pipeline_pb2 as tzrec_pipeline_pb2
from tzrec.protos.data_pb2 import DatasetType
from tzrec.protos.models import match_model_pb2, multi_task_rank_pb2, rank_model_pb2
from tzrec.utils.logging_util import logger
def _get_easyrec(pkg_path=None):
    """Fetch an EasyRec package, unpack it and expose its protobuf modules.

    The package is read from a local path or downloaded over HTTP, extracted
    into a fresh temporary directory that is appended to ``sys.path``, and the
    ``pipeline_pb2``, ``feature_config_pb2`` and ``loss_pb2`` modules of
    ``easy_rec.python.protos`` are published as the module globals
    ``easyrec_pipeline_pb2``, ``easyrec_feature_config_pb2`` and
    ``easyrec_loss_pb2``.

    Args:
        pkg_path (str, optional): local path or http(s) url of the package.
            A path containing ``.tar`` is opened as a tar archive, anything
            else as a zip/whl. Defaults to the released wheel of
            ``EASYREC_VERSION``.

    Raises:
        RuntimeError: if the archive cannot be extracted.
        requests.HTTPError: if the download answers with an error status.
    """
    local_cache_dir = tempfile.mkdtemp(prefix="tzrec_tmp")
    if pkg_path is None:
        pkg_path = (
            f"https://easyrec.oss-cn-beijing.aliyuncs.com/release/whls/"
            f"easy_rec-{EASYREC_VERSION}-py2.py3-none-any.whl"
        )

    if pkg_path.startswith("http"):
        logger.info(f"downloading easyrec from {pkg_path}")
        r = requests.get(pkg_path)
        # Fail on 4xx/5xx here; otherwise the error page would be handed to the
        # archive reader and reported as a corrupt package.
        r.raise_for_status()
        content = r.content
    else:
        with open(pkg_path, "rb") as f:
            content = f.read()

    if ".tar" in pkg_path:
        try:
            with tarfile.open(fileobj=io.BytesIO(content)) as tar:
                # NOTE(review): extractall trusts the member paths stored in
                # the archive; only feed it packages from a trusted source.
                tar.extractall(path=local_cache_dir)
        except Exception as e:
            raise RuntimeError(f"invalid {pkg_path} tar.") from e
    else:
        try:
            with zipfile.ZipFile(io.BytesIO(content)) as zf:
                zf.extractall(local_cache_dir)
        except zipfile.BadZipFile as e:
            raise RuntimeError(f"invalid {pkg_path} whl.") from e

    # Overwrite the package __init__ with an empty file; presumably this keeps
    # importing the protos from pulling in the rest of EasyRec - TODO confirm.
    with open(os.path.join(local_cache_dir, "easy_rec/__init__.py"), "w") as f:
        f.write("")
    sys.path.append(local_cache_dir)

    # Give the default symbol database a fresh descriptor pool before the
    # EasyRec protos are imported; presumably this avoids clashes with the
    # tzrec protos that are already registered - TODO confirm.
    _sym = symbol_database.Default()
    _sym.pool = descriptor_pool.DescriptorPool()

    from easy_rec.python.protos import feature_config_pb2 as _feature_config_pb2
    from easy_rec.python.protos import loss_pb2 as _loss_pb2
    from easy_rec.python.protos import pipeline_pb2 as _pipeline_pb2

    globals()["easyrec_pipeline_pb2"] = _pipeline_pb2
    globals()["easyrec_feature_config_pb2"] = _feature_config_pb2
    globals()["easyrec_loss_pb2"] = _loss_pb2
class ConvertConfig(object):
"""Convert EasyRec config to tzrec config.
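Args:
easyrec_config_path (str): path of the EasyRec config file to convert.
output_tzrec_config_path (str): path of the tzrec config file to create.
fg_json_path (str, optional): path of the fg.json file used by EasyRec.
easyrec_package_path (str, optional): path or url of the EasyRec whl or tar package.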
"""
def __init__(
    self,
    easyrec_config_path,
    output_tzrec_config_path,
    fg_json_path=None,
    easyrec_package_path=None,
):
    """Load the EasyRec config and, when given, index the fg.json file.

    Args:
        easyrec_config_path (str): EasyRec config file to read.
        output_tzrec_config_path (str): file that ``build`` writes.
        fg_json_path (str, optional): fg.json file to analyse.
        easyrec_package_path (str, optional): EasyRec package location that is
            passed to ``_get_easyrec`` when the protos are not loaded yet.
    """
    # The EasyRec proto modules live in module globals once they are loaded,
    # so the package only has to be fetched for the first converter.
    protos_loaded = "easyrec_pipeline_pb2" in globals()
    if not protos_loaded:
        _get_easyrec(easyrec_package_path)

    self.output_tzrec_config_path = output_tzrec_config_path
    self.easyrec_config = self.load_easyrec_config(easyrec_config_path)

    # Lookup tables filled by analyse_fg; they stay empty without fg.json.
    self.feature_to_fg = {}
    self.sub_sequence_to_group = {}
    self.sequence_feature_to_fg = {}
    if fg_json_path is None:
        return
    self.analyse_fg(self.load_easyrec_fg_json(fg_json_path))
def analyse_fg(self, fg_json):
    """Index the entries of a parsed fg.json.

    Plain entries are stored in ``feature_to_fg`` under their feature name.
    Entries carrying a ``sequence_name`` are stored in
    ``sequence_feature_to_fg``, and each of their sub features is registered
    in ``sub_sequence_to_group`` as ``<sequence_name>__<feature_name>``.

    Args:
        fg_json (dict): parsed fg.json with a ``features`` list.
    """
    for entry in fg_json["features"]:
        if "sequence_name" not in entry:
            self.feature_to_fg[entry["feature_name"]] = entry
            continue
        group = entry["sequence_name"]
        self.sub_sequence_to_group.update(
            (f"{group}__{sub['feature_name']}", group) for sub in entry["features"]
        )
        self.sequence_feature_to_fg[group] = entry
def load_easyrec_config(self, path):
    """Parse the text-format EasyRec pipeline config stored at ``path``.

    Args:
        path (str): EasyRec config file, read as utf-8.

    Returns:
        the populated EasyRec ``EasyRecConfig`` message.
    """
    # easyrec_pipeline_pb2 is injected into module globals by _get_easyrec.
    parsed_config = easyrec_pipeline_pb2.EasyRecConfig()  # NOQA
    with open(path, "r", encoding="utf-8") as config_file:
        config_text = config_file.read()
    text_format.Merge(config_text, parsed_config)
    return parsed_config
def load_easyrec_fg_json(self, path):
    """Read the fg.json file at ``path`` (utf-8) and return the parsed content."""
    with open(path, "r", encoding="utf-8") as fg_file:
        return json.load(fg_file)
def _create_train_config(self, pipeline_config):
    """Give ``pipeline_config`` a default tzrec train config if it has none.

    The default uses Adam (lr 0.001) with a constant learning rate for both
    the sparse and the dense optimizer, one epoch, and tensorboard disabled.

    Args:
        pipeline_config: tzrec pipeline config message, updated in place.

    Returns:
        the same ``pipeline_config`` object.
    """
    if pipeline_config.HasField("train_config"):
        return pipeline_config

    # Text-format proto; kept flush-left so the literal stays unchanged.
    default_train_config = """
train_config {
sparse_optimizer {
adam_optimizer {
lr: 0.001
}
constant_learning_rate {
}
}
dense_optimizer {
adam_optimizer {
lr: 0.001
}
constant_learning_rate {
}
}
num_epochs: 1
use_tensorboard: false
}"""
    text_format.Merge(default_train_config, pipeline_config)
    return pipeline_config
def _create_eval_config(self, pipeline_config):
    """Give ``pipeline_config`` an empty tzrec eval config if it has none.

    Args:
        pipeline_config: tzrec pipeline config message, updated in place.

    Returns:
        the same ``pipeline_config`` object.
    """
    if pipeline_config.HasField("eval_config"):
        return pipeline_config
    text_format.Merge("eval_config {}", pipeline_config)
    return pipeline_config
def _create_data_config(self, pipeline_config):
    """Fill the tzrec data config from the EasyRec data config.

    Batch size and label fields are taken from the EasyRec config; the dataset
    type is always ``OdpsDataset`` and ``num_workers`` is always 8.

    Args:
        pipeline_config: tzrec pipeline config message, updated in place.

    Returns:
        the same ``pipeline_config`` object.
    """
    labels = list(self.easyrec_config.data_config.label_fields)
    target = pipeline_config.data_config
    target.batch_size = self.easyrec_config.data_config.batch_size
    target.dataset_type = DatasetType.OdpsDataset
    target.label_fields.extend(labels)
    target.num_workers = 8
    return pipeline_config
def _create_feature_config(self, pipeline_config):
    """Create tzrec feature configs from EasyRec feature configs and fg.json.

    Each EasyRec feature config is matched with its fg.json entry, first by
    feature name and then by its first input name. Sequence features that
    belong to an fg.json sequence group are collected and emitted as one tzrec
    ``SequenceFeature`` per group after all other features.

    Id, tag and raw features without an fg.json entry are skipped with an
    error log, because their expression can only come from fg.json.

    Args:
        pipeline_config: tzrec pipeline config message, updated in place.

    Returns:
        the same ``pipeline_config`` object.

    Raises:
        ValueError: if an input of a combo or lookup feature is not in fg.json.
    """
    easyrec_feature_config = easyrec_feature_config_pb2.FeatureConfig()  # NOQA
    # Feature types whose conversion reads the feature's own fg.json entry.
    fg_backed_types = (
        easyrec_feature_config.IdFeature,
        easyrec_feature_config.TagFeature,
        easyrec_feature_config.RawFeature,
    )
    seq_group_cfg = OrderedDict()
    for cfg in self.easyrec_config.feature_configs:
        input_names = cfg.input_names
        feature_name = cfg.feature_name if cfg.feature_name else input_names[0]
        feature_type = cfg.feature_type

        # Reset for every feature so that a missing entry can never fall back
        # to the entry found for the previous feature.
        fg_json = None
        if feature_name in self.feature_to_fg:
            fg_json = self.feature_to_fg[feature_name]
        elif feature_name in self.sub_sequence_to_group:
            pass  # resolved through its sequence group further down
        elif input_names[0] in self.feature_to_fg:
            fg_json = self.feature_to_fg[input_names[0]]
        else:
            logger.error(f"in easyrec config {feature_name} not in fg.json")

        if fg_json is None and feature_type in fg_backed_types:
            logger.error(f"{feature_name} has no fg.json entry, can't converted")
            continue

        feature_config = None
        if feature_type == easyrec_feature_config.IdFeature:
            feature_config = tzrec_feature_pb2.FeatureConfig()
            feature = tzrec_feature_pb2.IdFeature()
            feature.feature_name = feature_name
            feature.expression = fg_json["expression"]
            feature.embedding_dim = cfg.embedding_dim
            feature.hash_bucket_size = cfg.hash_bucket_size
            feature_config.ClearField("feature")
            feature_config.id_feature.CopyFrom(feature)
        elif feature_type == easyrec_feature_config.TagFeature:
            # Tag features become id features; a kv separator marks them as
            # weighted.
            feature_config = tzrec_feature_pb2.FeatureConfig()
            feature = tzrec_feature_pb2.IdFeature()
            feature.feature_name = feature_name
            feature.expression = fg_json["expression"]
            feature.embedding_dim = cfg.embedding_dim
            feature.hash_bucket_size = cfg.hash_bucket_size
            if cfg.HasField("kv_separator"):
                feature.weighted = True
            feature_config.ClearField("feature")
            feature_config.id_feature.CopyFrom(feature)
        elif feature_type == easyrec_feature_config.SequenceFeature:
            if feature_name in self.sub_sequence_to_group:
                # Part of an fg.json sequence group: convert with the group.
                sequence_name = self.sub_sequence_to_group[feature_name]
                seq_group_cfg.setdefault(sequence_name, []).append(cfg)
            elif feature_name in self.feature_to_fg:
                feature_config = tzrec_feature_pb2.FeatureConfig()
                expression = self.feature_to_fg[feature_name]["expression"]
                if cfg.sub_feature_type == easyrec_feature_config.IdFeature:
                    feature = tzrec_feature_pb2.SequenceIdFeature()
                    feature.feature_name = feature_name
                    feature.expression = expression
                    feature.embedding_dim = cfg.embedding_dim
                    feature.hash_bucket_size = cfg.hash_bucket_size
                    feature_config.ClearField("feature")
                    feature_config.sequence_id_feature.CopyFrom(feature)
                else:
                    feature = tzrec_feature_pb2.SequenceRawFeature()
                    feature.feature_name = feature_name
                    feature.expression = expression
                    feature.embedding_dim = cfg.embedding_dim
                    boundaries = list(cfg.boundaries)
                    if boundaries:
                        feature.boundaries.extend(boundaries)
                    feature_config.ClearField("feature")
                    feature_config.sequence_raw_feature.CopyFrom(feature)
            else:
                logger.error(f"sequences feature: {feature_name} can't converted")
        elif feature_type == easyrec_feature_config.RawFeature:
            feature_config = tzrec_feature_pb2.FeatureConfig()
            boundaries = list(cfg.boundaries)
            if fg_json["feature_type"] == "lookup_feature":
                # A raw feature that fg.json computes with a lookup keeps the
                # lookup's map and key.
                feature = tzrec_feature_pb2.LookupFeature()
                feature.feature_name = feature_name
                feature.map = fg_json["map"]
                feature.key = fg_json["key"]
                feature.embedding_dim = cfg.embedding_dim
                if boundaries:
                    feature.boundaries.extend(boundaries)
                feature_config.ClearField("feature")
                feature_config.lookup_feature.CopyFrom(feature)
            else:
                feature = tzrec_feature_pb2.RawFeature()
                feature.feature_name = feature_name
                feature.expression = fg_json["expression"]
                feature.embedding_dim = cfg.embedding_dim
                if boundaries:
                    feature.boundaries.extend(boundaries)
                feature_config.ClearField("feature")
                feature_config.raw_feature.CopyFrom(feature)
        elif feature_type == easyrec_feature_config.ComboFeature:
            feature_config = tzrec_feature_pb2.FeatureConfig()
            feature = tzrec_feature_pb2.ComboFeature()
            feature.feature_name = feature_name
            for input_name in input_names:
                if input_name not in self.feature_to_fg:
                    raise ValueError(
                        f"{cfg} input_names:{input_name} not in fg json"
                    )
                feature.expression.append(
                    self.feature_to_fg[input_name]["expression"]
                )
            feature.embedding_dim = cfg.embedding_dim
            feature.hash_bucket_size = cfg.hash_bucket_size
            feature_config.ClearField("feature")
            feature_config.combo_feature.CopyFrom(feature)
        elif feature_type == easyrec_feature_config.LookupFeature:
            feature_config = tzrec_feature_pb2.FeatureConfig()
            feature = tzrec_feature_pb2.LookupFeature()
            feature.feature_name = feature_name
            map_f = input_names[0]
            key_f = input_names[1]
            if map_f not in self.feature_to_fg:
                raise ValueError(f"{cfg} input names: {map_f} not in fg.json")
            if key_f not in self.feature_to_fg:
                raise ValueError(f"{cfg} input names: {key_f} not in fg.json")
            feature.map = self.feature_to_fg[map_f]["expression"]
            feature.key = self.feature_to_fg[key_f]["expression"]
            feature.embedding_dim = cfg.embedding_dim
            boundaries = list(cfg.boundaries)
            if boundaries:
                feature.boundaries.extend(boundaries)
            feature_config.ClearField("feature")
            feature_config.lookup_feature.CopyFrom(feature)
        else:
            logger.error(f"{feature_name} can't converted")

        if feature_config is not None:
            pipeline_config.feature_configs.append(feature_config)

    # Emit one grouped sequence feature per fg.json sequence group.
    for seq_name, sub_cfgs in seq_group_cfg.items():
        sequence_fg = self.sequence_feature_to_fg[seq_name]
        feature_config = tzrec_feature_pb2.FeatureConfig()
        sequence_feature_config = tzrec_feature_pb2.SequenceFeature()
        sequence_feature_config.sequence_name = sequence_fg["sequence_name"]
        sequence_feature_config.sequence_length = sequence_fg["sequence_length"]
        sequence_feature_config.sequence_delim = sequence_fg["sequence_delim"]
        # Sub features are addressed as "<sequence_name>__<feature_name>".
        seq_feature_to_fg = {
            f"{seq_name}__{sub_fg['feature_name']}": sub_fg
            for sub_fg in sequence_fg["features"]
        }
        for cfg in sub_cfgs:
            feature_name = (
                cfg.feature_name if cfg.feature_name else cfg.input_names[0]
            )
            if feature_name not in seq_feature_to_fg:
                logger.error(
                    f"sequence feature: {feature_name} not config in fg.json"
                )
                continue
            seq_feature_fg = seq_feature_to_fg[feature_name]
            sub_feature_cfg = tzrec_feature_pb2.SeqFeatureConfig()
            if cfg.sub_feature_type == easyrec_feature_config.IdFeature:
                feature = tzrec_feature_pb2.IdFeature()
                feature.feature_name = seq_feature_fg["feature_name"]
                feature.expression = seq_feature_fg["expression"]
                feature.embedding_dim = cfg.embedding_dim
                feature.hash_bucket_size = cfg.hash_bucket_size
                sub_feature_cfg.ClearField("feature")
                sub_feature_cfg.id_feature.CopyFrom(feature)
            else:
                feature = tzrec_feature_pb2.RawFeature()
                feature.feature_name = seq_feature_fg["feature_name"]
                feature.expression = seq_feature_fg["expression"]
                feature.embedding_dim = cfg.embedding_dim
                boundaries = list(cfg.boundaries)
                if boundaries:
                    feature.boundaries.extend(boundaries)
                sub_feature_cfg.ClearField("feature")
                sub_feature_cfg.raw_feature.CopyFrom(feature)
            sequence_feature_config.features.append(sub_feature_cfg)
        feature_config.sequence_feature.CopyFrom(sequence_feature_config)
        pipeline_config.feature_configs.append(feature_config)
    return pipeline_config
def _create_feature_config_no_fg(self, pipeline_config):
    """Create tzrec feature configs when no fg.json entries are available.

    Expressions are synthesised as ``user:<input_name>`` from the input names
    of each EasyRec feature config. Feature types that are not handled are
    reported with an error log and left out.

    Args:
        pipeline_config: tzrec pipeline config message, updated in place.

    Returns:
        the same ``pipeline_config`` object.
    """
    easyrec_types = easyrec_feature_config_pb2.FeatureConfig()  # NOQA
    for cfg in self.easyrec_config.feature_configs:
        feature_name = cfg.feature_name or list(cfg.input_names)[0]
        input_names = cfg.input_names
        kind = cfg.feature_type
        converted = None

        if kind == easyrec_types.IdFeature:
            converted = tzrec_feature_pb2.FeatureConfig()
            id_feat = tzrec_feature_pb2.IdFeature()
            id_feat.feature_name = feature_name
            id_feat.expression = f"user:{input_names[0]}"
            id_feat.embedding_dim = cfg.embedding_dim
            id_feat.hash_bucket_size = cfg.hash_bucket_size
            converted.ClearField("feature")
            converted.id_feature.CopyFrom(id_feat)
        elif kind == easyrec_types.TagFeature:
            # Tag features map to id features; a kv separator marks them as
            # weighted.
            converted = tzrec_feature_pb2.FeatureConfig()
            tag_feat = tzrec_feature_pb2.IdFeature()
            tag_feat.feature_name = feature_name
            tag_feat.expression = f"user:{input_names[0]}"
            tag_feat.embedding_dim = cfg.embedding_dim
            tag_feat.hash_bucket_size = cfg.hash_bucket_size
            if cfg.HasField("kv_separator"):
                tag_feat.weighted = True
            converted.ClearField("feature")
            converted.id_feature.CopyFrom(tag_feat)
        elif kind == easyrec_types.SequenceFeature:
            converted = tzrec_feature_pb2.FeatureConfig()
            if cfg.sub_feature_type == easyrec_types.RawFeature:
                seq_raw = tzrec_feature_pb2.SequenceRawFeature()
                seq_raw.feature_name = feature_name
                seq_raw.expression = f"user:{input_names[0]}"
                seq_raw.sequence_length = cfg.sequence_length
                seq_raw.sequence_delim = cfg.separator
                seq_raw.embedding_dim = cfg.embedding_dim
                bucket_edges = list(cfg.boundaries)
                if len(bucket_edges) > 0:
                    seq_raw.boundaries.extend(bucket_edges)
                converted.ClearField("feature")
                converted.sequence_raw_feature.CopyFrom(seq_raw)
            else:
                seq_id = tzrec_feature_pb2.SequenceIdFeature()
                seq_id.feature_name = feature_name
                seq_id.expression = f"user:{input_names[0]}"
                seq_id.sequence_length = cfg.sequence_length
                seq_id.sequence_delim = cfg.separator
                seq_id.embedding_dim = cfg.embedding_dim
                if cfg.HasField("hash_bucket_size"):
                    seq_id.hash_bucket_size = cfg.hash_bucket_size
                if cfg.HasField("num_buckets"):
                    seq_id.num_buckets = cfg.num_buckets
                converted.ClearField("feature")
                converted.sequence_id_feature.CopyFrom(seq_id)
            # Only reported; the feature is still emitted.
            if cfg.sequence_length <= 1:
                logger.error(f"{feature_name} sequence_length is invalid !!!")
        elif kind == easyrec_types.RawFeature:
            converted = tzrec_feature_pb2.FeatureConfig()
            raw_feat = tzrec_feature_pb2.RawFeature()
            raw_feat.feature_name = feature_name
            raw_feat.expression = f"user:{input_names[0]}"
            bucket_edges = list(cfg.boundaries)
            if cfg.HasField("embedding_dim"):
                raw_feat.embedding_dim = cfg.embedding_dim
            if len(bucket_edges):
                raw_feat.boundaries.extend(bucket_edges)
            converted.ClearField("feature")
            converted.raw_feature.CopyFrom(raw_feat)
        elif kind == easyrec_types.ComboFeature:
            converted = tzrec_feature_pb2.FeatureConfig()
            combo_feat = tzrec_feature_pb2.ComboFeature()
            combo_feat.feature_name = feature_name
            for input_name in list(cfg.input_names):
                combo_feat.expression.append(f"user:{input_name}")
            combo_feat.embedding_dim = cfg.embedding_dim
            combo_feat.hash_bucket_size = cfg.hash_bucket_size
            converted.ClearField("feature")
            converted.combo_feature.CopyFrom(combo_feat)
        elif kind == easyrec_types.LookupFeature:
            converted = tzrec_feature_pb2.FeatureConfig()
            lookup_feat = tzrec_feature_pb2.LookupFeature()
            lookup_feat.feature_name = feature_name
            lookup_feat.map = f"user:{input_names[0]}"
            lookup_feat.key = f"user:{input_names[1]}"
            if cfg.HasField("embedding_dim"):
                lookup_feat.embedding_dim = cfg.embedding_dim
            if len(list(cfg.boundaries)):
                lookup_feat.boundaries.extend(list(cfg.boundaries))
            converted.ClearField("feature")
            converted.lookup_feature.CopyFrom(lookup_feat)
        else:
            logger.error(f"{feature_name} can't converted")

        if converted is None:
            continue
        logger.info(f"{feature_name} converted succeeded")
        pipeline_config.feature_configs.append(converted)
    return pipeline_config
def _easyrec_dnn_2_tzrec_mlp(self, dnn):
    """Build a tzrec ``MLP`` from an EasyRec dnn message.

    Copies the hidden units, the dropout ratios and the ``use_bn`` flag.
    """
    return module_pb2.MLP(
        hidden_units=dnn.hidden_units,
        dropout_ratio=dnn.dropout_ratio,
        use_bn=dnn.use_bn,
    )
def _easyrec_loss_2_tzrec_loss(self, easyrec_loss):
    """Map an EasyRec loss onto a tzrec ``LossConfig``.

    JRC, L2, softmax cross entropy and classification (converted to binary
    cross entropy) are handled; any other loss type is logged as an error and
    an empty ``LossConfig`` is returned.
    """
    converted = loss_pb2.LossConfig()
    kind = easyrec_loss.loss_type
    if kind == easyrec_loss_pb2.LossType.JRC_LOSS:  # NOQA
        converted.jrc_loss.CopyFrom(loss_pb2.JRCLoss())
        return converted
    if kind == easyrec_loss_pb2.LossType.L2_LOSS:  # NOQA
        converted.l2_loss.CopyFrom(loss_pb2.L2Loss())
        return converted
    if kind == easyrec_loss_pb2.LossType.SOFTMAX_CROSS_ENTROPY:  # NOQA
        converted.softmax_cross_entropy.CopyFrom(loss_pb2.SoftmaxCrossEntropy())
        return converted
    if kind == easyrec_loss_pb2.LossType.CLASSIFICATION:  # NOQA
        converted.binary_cross_entropy.CopyFrom(loss_pb2.BinaryCrossEntropy())
        return converted
    logger.error(
        f"{easyrec_loss} is not convert to tzrec loss, please adaptation"
    )
    return converted
def _easyrec_metrics_2_tzrec_metrics(self, easyrec_metric):
    """Map an EasyRec metric onto a tzrec ``MetricConfig``.

    Handles auc, gauc, recall_at_topk, mean_absolute_error, mean_squared_error
    and accuracy. Any other metric, including a message whose ``metric`` oneof
    is not set, is logged as an error and an empty ``MetricConfig`` is
    returned.
    """
    metric = metric_pb2.MetricConfig()
    # WhichOneof yields None when no metric is set; that case must end up in
    # the final else branch instead of being used as an attribute name.
    metric_type = easyrec_metric.WhichOneof("metric")
    if metric_type == "auc":
        metric.auc.CopyFrom(metric_pb2.AUC())
    elif metric_type == "gauc":
        # The grouping key is the only setting carried over from EasyRec.
        tzrec_metric_ob = metric_pb2.GroupedAUC(
            grouping_key=easyrec_metric.gauc.uid_field
        )
        metric.grouped_auc.CopyFrom(tzrec_metric_ob)
    elif metric_type == "recall_at_topk":
        # NOTE(review): the EasyRec metric's own settings are not copied, the
        # tzrec defaults apply - confirm that is intended.
        metric.recall_at_k.CopyFrom(metric_pb2.RecallAtK())
    elif metric_type == "mean_absolute_error":
        metric.mean_absolute_error.CopyFrom(metric_pb2.MeanAbsoluteError())
    elif metric_type == "mean_squared_error":
        metric.mean_squared_error.CopyFrom(metric_pb2.MeanSquaredError())
    elif metric_type == "accuracy":
        metric.accuracy.CopyFrom(metric_pb2.Accuracy())
    else:
        logger.error(
            f"{easyrec_metric} is not convert to tzrec metric, please adaptation"
        )
    return metric
def _easyrec_bayes_tower_2_tzrec_bayes_tower(self, easyrec_bayes_task_tower):
    """Build a tzrec ``BayesTaskTower`` from an EasyRec bayes task tower.

    Copies the names, class count and relation tower names, and converts the
    tower dnn, the relation dnn, the losses and the metrics.
    """
    source = easyrec_bayes_task_tower
    target = tower_pb2.BayesTaskTower()
    target.tower_name = source.tower_name
    target.label_name = source.label_name
    target.num_class = source.num_class
    target.relation_tower_names.extend(source.relation_tower_names)
    target.mlp.CopyFrom(self._easyrec_dnn_2_tzrec_mlp(source.dnn))
    target.relation_mlp.CopyFrom(
        self._easyrec_dnn_2_tzrec_mlp(source.relation_dnn)
    )
    for easy_loss in source.losses:
        converted_loss = self._easyrec_loss_2_tzrec_loss(easy_loss)
        target.losses.append(converted_loss)
    for easy_metric in source.metrics_set:
        converted_metric = self._easyrec_metrics_2_tzrec_metrics(easy_metric)
        target.metrics.append(converted_metric)
    return target
def _easyrec_task_tower_2_tzrec_task_tower(self, easyrec_task_tower):
    """Build a tzrec ``TaskTower`` from an EasyRec task tower.

    Copies the names and class count, and converts the dnn, the losses and the
    metrics.
    """
    source = easyrec_task_tower
    target = tower_pb2.TaskTower()
    target.tower_name = source.tower_name
    target.label_name = source.label_name
    target.num_class = source.num_class
    target.mlp.CopyFrom(self._easyrec_dnn_2_tzrec_mlp(source.dnn))
    for easy_loss in source.losses:
        converted_loss = self._easyrec_loss_2_tzrec_loss(easy_loss)
        target.losses.append(converted_loss)
    for easy_metric in source.metrics_set:
        converted_metric = self._easyrec_metrics_2_tzrec_metrics(easy_metric)
        target.metrics.append(converted_metric)
    return target
def _easyrec_tower_2_tzrec_tower(self, easyrec_tower):
    """Build a tzrec ``Tower`` from an EasyRec tower (its input and dnn)."""
    converted = tower_pb2.Tower()
    converted.input = easyrec_tower.input
    converted.mlp.CopyFrom(self._easyrec_dnn_2_tzrec_mlp(easyrec_tower.dnn))
    return converted
def _easyrec_dssm_tower_2_tzrec_tower(self, easyrec_dssm_tower):
    """Build a tzrec ``Tower`` from an EasyRec DSSM tower.

    The EasyRec tower ``id`` becomes the tzrec tower ``input``.
    """
    converted = tower_pb2.Tower()
    converted.input = easyrec_dssm_tower.id
    converted.mlp.CopyFrom(self._easyrec_dnn_2_tzrec_mlp(easyrec_dssm_tower.dnn))
    return converted
def _easyrec_extraction_network_2_tzrec_extraction_network(
    self, easyrec_extraction_network
):
    """Build a tzrec ``ExtractionNetwork`` from an EasyRec extraction network.

    Copies the name, the expert count per task and the share count, and
    converts the task expert and share expert dnns.
    """
    source = easyrec_extraction_network
    target = module_pb2.ExtractionNetwork()
    target.network_name = source.network_name
    target.expert_num_per_task = source.expert_num_per_task
    target.share_num = source.share_num
    target.task_expert_net.CopyFrom(
        self._easyrec_dnn_2_tzrec_mlp(source.task_expert_net)
    )
    target.share_expert_net.CopyFrom(
        self._easyrec_dnn_2_tzrec_mlp(source.share_expert_net)
    )
    return target
def _convert_model_feature_group(self, easyrec_feature_groups):
    """Convert EasyRec feature groups into tzrec feature group configs.

    Every group keeps its name and feature names and becomes WIDE or DEEP.
    Each sequence feature of a group is turned into a sequence group plus a
    DIN sequence encoder reading from it; unnamed sequence groups are called
    ``seq_<index>``.

    Returns:
        list of tzrec ``FeatureGroupConfig`` messages, in input order.
    """
    converted_groups = []
    for source_group in easyrec_feature_groups:
        target_group = model_pb2.FeatureGroupConfig()
        target_group.group_name = source_group.group_name
        target_group.feature_names.extend(source_group.feature_names)
        is_wide = (
            source_group.wide_deep
            == easyrec_feature_config_pb2.WideOrDeep.WIDE  # NOQA
        )
        if is_wide:
            target_group.group_type = model_pb2.FeatureGroupType.WIDE
        else:
            target_group.group_type = model_pb2.FeatureGroupType.DEEP

        for index, source_seq in enumerate(source_group.sequence_features):
            target_seq = model_pb2.SeqGroupConfig()
            encoder_config = seq_encoder_pb2.SeqEncoderConfig()
            din = seq_encoder_pb2.DINEncoder()
            seq_name = (
                source_seq.group_name
                if source_seq.HasField("group_name")
                else f"seq_{index}"
            )
            target_seq.group_name = seq_name
            din.input = seq_name
            din.attn_mlp.CopyFrom(self._easyrec_dnn_2_tzrec_mlp(source_seq.seq_dnn))
            encoder_config.din_encoder.CopyFrom(din)
            # Keys, history sequences and auxiliary history sequences all
            # become feature names of the sequence group.
            for att_map in source_seq.seq_att_map:
                target_seq.feature_names.extend(att_map.key)
                target_seq.feature_names.extend(att_map.hist_seq)
                target_seq.feature_names.extend(att_map.aux_hist_seq)
            target_group.sequence_groups.append(target_seq)
            target_group.sequence_encoders.append(encoder_config)
        converted_groups.append(target_group)
    return converted_groups
def _convert_model_config(self, easyrec_model_config, tz_model_config):
    """Convert the model section of an EasyRec config into tzrec.

    Supported model classes are DBMTL, SimpleMultiTask, MMoE, PLE, DeepFM,
    MultiTower and DSSM. Any other model class, or a config whose ``model``
    oneof is not set, is logged as an error and ``tz_model_config`` is returned
    without a model.

    Args:
        easyrec_model_config: EasyRec model config message.
        tz_model_config: tzrec ``ModelConfig`` message, updated in place.

    Returns:
        the same ``tz_model_config`` object.
    """
    model_class = easyrec_model_config.model_class
    model_type = easyrec_model_config.WhichOneof("model")
    if model_type is None:
        # Without a model section there is nothing to read the settings from.
        logger.error(
            f"easyrec model_config of {model_class} has no model set, "
            f"nothing converted"
        )
        return tz_model_config
    easyrec_model_config = getattr(easyrec_model_config, model_type)

    if model_class == "DBMTL":
        tz_model_config_ob = multi_task_rank_pb2.DBMTL()
        bottom_mlp = self._easyrec_dnn_2_tzrec_mlp(easyrec_model_config.bottom_dnn)
        expert_mlp = self._easyrec_dnn_2_tzrec_mlp(easyrec_model_config.expert_dnn)
        tz_model_config_ob.bottom_mlp.CopyFrom(bottom_mlp)
        tz_model_config_ob.expert_mlp.CopyFrom(expert_mlp)
        tz_model_config_ob.num_expert = easyrec_model_config.num_expert
        for task_tower in easyrec_model_config.task_towers:
            tz_task_tower = self._easyrec_bayes_tower_2_tzrec_bayes_tower(
                task_tower
            )
            tz_model_config_ob.task_towers.append(tz_task_tower)
        tz_model_config.dbmtl.CopyFrom(tz_model_config_ob)
    elif model_class == "SimpleMultiTask":
        tz_model_config_ob = multi_task_rank_pb2.SimpleMultiTask()
        for task_tower in easyrec_model_config.task_towers:
            tz_task_tower = self._easyrec_task_tower_2_tzrec_task_tower(task_tower)
            tz_model_config_ob.task_towers.append(tz_task_tower)
        tz_model_config.simple_multi_task.CopyFrom(tz_model_config_ob)
    elif model_class == "MMoE":
        tz_model_config_ob = multi_task_rank_pb2.MMoE()
        expert_mlp = self._easyrec_dnn_2_tzrec_mlp(easyrec_model_config.expert_dnn)
        tz_model_config_ob.expert_mlp.CopyFrom(expert_mlp)
        # NOTE(review): the gate mlp reuses the expert dnn settings - confirm
        # that this is the intended mapping.
        tz_model_config_ob.gate_mlp.CopyFrom(expert_mlp)
        tz_model_config_ob.num_expert = easyrec_model_config.num_expert
        for task_tower in easyrec_model_config.task_towers:
            tz_task_tower = self._easyrec_task_tower_2_tzrec_task_tower(task_tower)
            tz_model_config_ob.task_towers.append(tz_task_tower)
        tz_model_config.mmoe.CopyFrom(tz_model_config_ob)
    elif model_class == "PLE":
        tz_model_config_ob = multi_task_rank_pb2.PLE()
        for extraction_network in easyrec_model_config.extraction_networks:
            tz_extraction_network = (
                self._easyrec_extraction_network_2_tzrec_extraction_network(
                    extraction_network
                )
            )
            # Collect on the local message: the CopyFrom below replaces the
            # whole content of tz_model_config.ple.
            tz_model_config_ob.extraction_networks.append(tz_extraction_network)
        for task_tower in easyrec_model_config.task_towers:
            tz_task_tower = self._easyrec_task_tower_2_tzrec_task_tower(task_tower)
            tz_model_config_ob.task_towers.append(tz_task_tower)
        tz_model_config.ple.CopyFrom(tz_model_config_ob)
    elif model_class == "DeepFM":
        tz_model_config_ob = rank_model_pb2.DeepFM()
        deep = self._easyrec_dnn_2_tzrec_mlp(easyrec_model_config.dnn)
        final = self._easyrec_dnn_2_tzrec_mlp(easyrec_model_config.final_dnn)
        tz_model_config_ob.deep.CopyFrom(deep)
        tz_model_config_ob.final.CopyFrom(final)
        if easyrec_model_config.HasField("wide_output_dim"):
            tz_model_config_ob.wide_embedding_dim = (
                easyrec_model_config.wide_output_dim
            )
        tz_model_config.deepfm.CopyFrom(tz_model_config_ob)
    elif model_class == "MultiTower":
        tz_model_config_ob = rank_model_pb2.MultiTower()
        for tower in easyrec_model_config.towers:
            tz_tower = self._easyrec_tower_2_tzrec_tower(tower)
            tz_model_config_ob.towers.append(tz_tower)
        final = self._easyrec_dnn_2_tzrec_mlp(easyrec_model_config.final_dnn)
        tz_model_config_ob.final.CopyFrom(final)
        tz_model_config.multi_tower.CopyFrom(tz_model_config_ob)
    elif model_class == "DSSM":
        tz_model_config_ob = match_model_pb2.DSSM()
        user_tower = self._easyrec_dssm_tower_2_tzrec_tower(
            easyrec_model_config.user_tower
        )
        tz_model_config_ob.user_tower.CopyFrom(user_tower)
        item_tower = self._easyrec_dssm_tower_2_tzrec_tower(
            easyrec_model_config.item_tower
        )
        tz_model_config_ob.item_tower.CopyFrom(item_tower)
        # NOTE(review): output_dim is fixed to 32 and not derived from the
        # EasyRec config - adjust the generated config if needed.
        tz_model_config_ob.output_dim = 32
        # Older EasyRec versions may not define temperature at all.
        if hasattr(
            easyrec_model_config, "temperature"
        ) and easyrec_model_config.HasField("temperature"):
            tz_model_config_ob.temperature = easyrec_model_config.temperature
        tz_model_config.dssm.CopyFrom(tz_model_config_ob)
    else:
        logger.error(
            f"{model_class} is not convert to tzrec model, please adaptation"
        )
    return tz_model_config
def _create_model_config(self, pipeline_config):
    """Fill the tzrec model config from the EasyRec model config.

    Converts the feature groups first and the model itself afterwards.

    Args:
        pipeline_config: tzrec pipeline config message, updated in place.

    Returns:
        the same ``pipeline_config`` object.
    """
    target_model = model_pb2.ModelConfig()
    source_model = self.easyrec_config.model_config
    converted_groups = self._convert_model_feature_group(
        source_model.feature_groups
    )
    target_model.feature_groups.extend(converted_groups)
    target_model = self._convert_model_config(source_model, target_model)
    pipeline_config.model_config.CopyFrom(target_model)
    return pipeline_config
def build(self):
    """Convert the loaded EasyRec config and write the tzrec config file.

    Builds the train, eval, data, feature and model sections in that order.
    Features are converted with fg.json when ``analyse_fg`` found plain
    feature entries, otherwise without it. The result is written in protobuf
    text format to ``output_tzrec_config_path``.
    """
    tzrec_config = tzrec_pipeline_pb2.EasyRecConfig()
    tzrec_config = self._create_train_config(tzrec_config)
    tzrec_config = self._create_eval_config(tzrec_config)
    tzrec_config = self._create_data_config(tzrec_config)
    if self.feature_to_fg:
        tzrec_config = self._create_feature_config(tzrec_config)
    else:
        tzrec_config = self._create_feature_config_no_fg(tzrec_config)
    tzrec_config = self._create_model_config(tzrec_config)
    config_text = text_format.MessageToString(tzrec_config, as_utf8=True)
    # as_utf8=True leaves non-ASCII characters unescaped, so the file must be
    # written as utf-8 regardless of the platform's default encoding.
    with open(self.output_tzrec_config_path, "w", encoding="utf-8") as f:
        f.write(config_text)
if __name__ == "__main__":
    # Command line entry point: convert one EasyRec config into a tzrec config.
    parser = argparse.ArgumentParser()
    # Both paths are mandatory; argparse reports a missing one up front instead
    # of the converter failing later while opening ``None``.
    parser.add_argument(
        "--easyrec_config_path",
        type=str,
        required=True,
        help="easyrec model config path",
    )
    parser.add_argument(
        "--fg_json_path", type=str, default=None, help="easyrec use fg.json path"
    )
    parser.add_argument(
        "--output_tzrec_config_path",
        type=str,
        required=True,
        help="output tzrec config path",
    )
    parser.add_argument(
        "--easyrec_package_path",
        type=str,
        default=None,
        help="easyrec whl or tar package path or url",
    )
    # Unknown extra arguments are tolerated and ignored.
    args, _ = parser.parse_known_args()
    converter = ConvertConfig(
        args.easyrec_config_path,
        args.output_tzrec_config_path,
        args.fg_json_path,
        args.easyrec_package_path,
    )
    converter.build()