in tzrec/tools/convert_easyrec_config_to_tzrec_config.py [0:0]
def _convert_model_config(self, easyrec_model_config, tz_model_config):
    """Convert easyrec model config to tzrec model config.

    Dispatches on ``easyrec_model_config.model_class`` and populates the
    matching field of ``tz_model_config`` (``dbmtl``, ``simple_multi_task``,
    ``mmoe``, ``ple``, ``deepfm``, ``multi_tower`` or ``dssm``). Each branch
    builds the tzrec sub-config in a fresh staging message and copies it into
    ``tz_model_config`` at the end. Unsupported model classes, or configs
    whose ``model`` oneof is unset, are logged and left unconverted.

    Args:
        easyrec_model_config: EasyRec model config message; its ``model``
            oneof holds the model-specific sub-config.
        tz_model_config: tzrec model config message to populate.

    Returns:
        ``tz_model_config``, modified in place.
    """
    model_class = easyrec_model_config.model_class
    model_type = easyrec_model_config.WhichOneof("model")
    if model_type is None:
        # Nothing in the ``model`` oneof: getattr(..., None) would raise.
        logger.error(
            f"{model_class} has no model config set, skip converting to tzrec model"
        )
        return tz_model_config
    easyrec_model_config = getattr(easyrec_model_config, model_type)

    # NOTE: optional sub-networks below are guarded with HasField because
    # ``msg.field.CopyFrom(x)`` marks ``field`` as present even when ``x`` is
    # empty, which would emit an MLP with no layers into the tzrec config.
    if model_class == "DBMTL":
        tz_model_config_ob = multi_task_rank_pb2.DBMTL()
        if easyrec_model_config.HasField("bottom_dnn"):
            bottom_mlp = self._easyrec_dnn_2_tzrec_mlp(
                easyrec_model_config.bottom_dnn
            )
            tz_model_config_ob.bottom_mlp.CopyFrom(bottom_mlp)
        if easyrec_model_config.HasField("expert_dnn"):
            expert_mlp = self._easyrec_dnn_2_tzrec_mlp(
                easyrec_model_config.expert_dnn
            )
            tz_model_config_ob.expert_mlp.CopyFrom(expert_mlp)
        tz_model_config_ob.num_expert = easyrec_model_config.num_expert
        for task_tower in easyrec_model_config.task_towers:
            tz_task_tower = self._easyrec_bayes_tower_2_tzrec_bayes_tower(
                task_tower
            )
            tz_model_config_ob.task_towers.append(tz_task_tower)
        tz_model_config.dbmtl.CopyFrom(tz_model_config_ob)
    elif model_class == "SimpleMultiTask":
        tz_model_config_ob = multi_task_rank_pb2.SimpleMultiTask()
        for task_tower in easyrec_model_config.task_towers:
            tz_task_tower = self._easyrec_task_tower_2_tzrec_task_tower(task_tower)
            tz_model_config_ob.task_towers.append(tz_task_tower)
        tz_model_config.simple_multi_task.CopyFrom(tz_model_config_ob)
    elif model_class == "MMoE":
        tz_model_config_ob = multi_task_rank_pb2.MMoE()
        expert_mlp = self._easyrec_dnn_2_tzrec_mlp(easyrec_model_config.expert_dnn)
        tz_model_config_ob.expert_mlp.CopyFrom(expert_mlp)
        # No separate gate network is read from the EasyRec config here; the
        # gate MLP reuses the expert network definition.
        tz_model_config_ob.gate_mlp.CopyFrom(expert_mlp)
        tz_model_config_ob.num_expert = easyrec_model_config.num_expert
        for task_tower in easyrec_model_config.task_towers:
            tz_task_tower = self._easyrec_task_tower_2_tzrec_task_tower(task_tower)
            tz_model_config_ob.task_towers.append(tz_task_tower)
        tz_model_config.mmoe.CopyFrom(tz_model_config_ob)
    elif model_class == "PLE":
        tz_model_config_ob = multi_task_rank_pb2.PLE()
        for extraction_network in easyrec_model_config.extraction_networks:
            tz_extraction_network = (
                self._easyrec_extraction_network_2_tzrec_extraction_network(
                    extraction_network
                )
            )
            # Append to the staging message: the CopyFrom below replaces the
            # whole ``ple`` field, so anything appended to
            # ``tz_model_config.ple`` directly would be discarded.
            tz_model_config_ob.extraction_networks.append(tz_extraction_network)
        for task_tower in easyrec_model_config.task_towers:
            tz_task_tower = self._easyrec_task_tower_2_tzrec_task_tower(task_tower)
            tz_model_config_ob.task_towers.append(tz_task_tower)
        tz_model_config.ple.CopyFrom(tz_model_config_ob)
    elif model_class == "DeepFM":
        tz_model_config_ob = rank_model_pb2.DeepFM()
        deep = self._easyrec_dnn_2_tzrec_mlp(easyrec_model_config.dnn)
        tz_model_config_ob.deep.CopyFrom(deep)
        if easyrec_model_config.HasField("final_dnn"):
            final = self._easyrec_dnn_2_tzrec_mlp(easyrec_model_config.final_dnn)
            tz_model_config_ob.final.CopyFrom(final)
        if easyrec_model_config.HasField("wide_output_dim"):
            tz_model_config_ob.wide_embedding_dim = (
                easyrec_model_config.wide_output_dim
            )
        tz_model_config.deepfm.CopyFrom(tz_model_config_ob)
    elif model_class == "MultiTower":
        tz_model_config_ob = rank_model_pb2.MultiTower()
        for tower in easyrec_model_config.towers:
            tz_tower = self._easyrec_tower_2_tzrec_tower(tower)
            tz_model_config_ob.towers.append(tz_tower)
        if easyrec_model_config.HasField("final_dnn"):
            final = self._easyrec_dnn_2_tzrec_mlp(easyrec_model_config.final_dnn)
            tz_model_config_ob.final.CopyFrom(final)
        tz_model_config.multi_tower.CopyFrom(tz_model_config_ob)
    elif model_class == "DSSM":
        tz_model_config_ob = match_model_pb2.DSSM()
        user_tower = self._easyrec_dssm_tower_2_tzrec_tower(
            easyrec_model_config.user_tower
        )
        tz_model_config_ob.user_tower.CopyFrom(user_tower)
        item_tower = self._easyrec_dssm_tower_2_tzrec_tower(
            easyrec_model_config.item_tower
        )
        tz_model_config_ob.item_tower.CopyFrom(item_tower)
        # NOTE(review): fixed value, not derived from the EasyRec tower
        # config — confirm it matches the intended embedding size.
        tz_model_config_ob.output_dim = 32
        # hasattr guards EasyRec proto versions without a ``temperature`` field.
        if hasattr(
            easyrec_model_config, "temperature"
        ) and easyrec_model_config.HasField("temperature"):
            tz_model_config_ob.temperature = easyrec_model_config.temperature
        tz_model_config.dssm.CopyFrom(tz_model_config_ob)
    else:
        logger.error(
            f"{model_class} is not convert to tzrec model, please adaptation"
        )
    return tz_model_config