def _convert_model_config()

in tzrec/tools/convert_easyrec_config_to_tzrec_config.py [0:0]


    def _convert_model_config(self, easyrec_model_config, tz_model_config):
        """Convert easyrec model config to tzrec model config.

        Looks up the message held in the ``model`` oneof of the EasyRec model
        config and, depending on ``model_class``, fills one sub-message of
        ``tz_model_config`` (``dbmtl``, ``simple_multi_task``, ``mmoe``,
        ``ple``, ``deepfm``, ``multi_tower`` or ``dssm``). Unsupported model
        classes, and configs whose ``model`` oneof is unset, are reported with
        ``logger.error`` and leave ``tz_model_config`` untouched.

        Args:
            easyrec_model_config: EasyRec model config message. Its
                ``model_class`` string selects the conversion branch and its
                ``model`` oneof holds the model-specific settings.
            tz_model_config: tzrec model config message; updated in place.

        Returns:
            The same ``tz_model_config`` object that was passed in.
        """
        model_class = easyrec_model_config.model_class
        model_type = easyrec_model_config.WhichOneof("model")
        if model_type is None:
            # WhichOneof returns None when no member of the oneof is set;
            # getattr(config, None) would otherwise raise a bare TypeError.
            logger.error(f"model config of {model_class} is not set, skip converting")
            return tz_model_config
        # The model-specific EasyRec sub-message (e.g. the DBMTL settings).
        model_cfg = getattr(easyrec_model_config, model_type)

        if model_class == "DBMTL":
            tz_model_config_ob = multi_task_rank_pb2.DBMTL()
            # Only emit the shared-bottom / expert MLPs when they are configured;
            # CopyFrom on an unset source would still mark an empty MLP as set.
            if model_cfg.HasField("bottom_dnn"):
                tz_model_config_ob.bottom_mlp.CopyFrom(
                    self._easyrec_dnn_2_tzrec_mlp(model_cfg.bottom_dnn)
                )
            if model_cfg.HasField("expert_dnn"):
                tz_model_config_ob.expert_mlp.CopyFrom(
                    self._easyrec_dnn_2_tzrec_mlp(model_cfg.expert_dnn)
                )
            tz_model_config_ob.num_expert = model_cfg.num_expert
            for task_tower in model_cfg.task_towers:
                tz_model_config_ob.task_towers.append(
                    self._easyrec_bayes_tower_2_tzrec_bayes_tower(task_tower)
                )
            tz_model_config.dbmtl.CopyFrom(tz_model_config_ob)
        elif model_class == "SimpleMultiTask":
            tz_model_config_ob = multi_task_rank_pb2.SimpleMultiTask()
            for task_tower in model_cfg.task_towers:
                tz_model_config_ob.task_towers.append(
                    self._easyrec_task_tower_2_tzrec_task_tower(task_tower)
                )
            tz_model_config.simple_multi_task.CopyFrom(tz_model_config_ob)
        elif model_class == "MMoE":
            tz_model_config_ob = multi_task_rank_pb2.MMoE()
            expert_mlp = self._easyrec_dnn_2_tzrec_mlp(model_cfg.expert_dnn)
            tz_model_config_ob.expert_mlp.CopyFrom(expert_mlp)
            # NOTE(review): the expert MLP is reused as the gate MLP; presumably
            # because the EasyRec MMoE config has no separate gate DNN - confirm.
            tz_model_config_ob.gate_mlp.CopyFrom(expert_mlp)
            tz_model_config_ob.num_expert = model_cfg.num_expert
            for task_tower in model_cfg.task_towers:
                tz_model_config_ob.task_towers.append(
                    self._easyrec_task_tower_2_tzrec_task_tower(task_tower)
                )
            tz_model_config.mmoe.CopyFrom(tz_model_config_ob)
        elif model_class == "PLE":
            tz_model_config_ob = multi_task_rank_pb2.PLE()
            # Extraction networks must go into tz_model_config_ob: the final
            # CopyFrom below replaces tz_model_config.ple wholesale, so anything
            # appended directly to tz_model_config.ple would be discarded.
            for extraction_network in model_cfg.extraction_networks:
                tz_model_config_ob.extraction_networks.append(
                    self._easyrec_extraction_network_2_tzrec_extraction_network(
                        extraction_network
                    )
                )
            for task_tower in model_cfg.task_towers:
                tz_model_config_ob.task_towers.append(
                    self._easyrec_task_tower_2_tzrec_task_tower(task_tower)
                )
            tz_model_config.ple.CopyFrom(tz_model_config_ob)
        elif model_class == "DeepFM":
            tz_model_config_ob = rank_model_pb2.DeepFM()
            tz_model_config_ob.deep.CopyFrom(
                self._easyrec_dnn_2_tzrec_mlp(model_cfg.dnn)
            )
            # Only emit the final MLP when it is configured (see DBMTL note).
            if model_cfg.HasField("final_dnn"):
                tz_model_config_ob.final.CopyFrom(
                    self._easyrec_dnn_2_tzrec_mlp(model_cfg.final_dnn)
                )
            if model_cfg.HasField("wide_output_dim"):
                tz_model_config_ob.wide_embedding_dim = model_cfg.wide_output_dim
            tz_model_config.deepfm.CopyFrom(tz_model_config_ob)
        elif model_class == "MultiTower":
            tz_model_config_ob = rank_model_pb2.MultiTower()
            for tower in model_cfg.towers:
                tz_model_config_ob.towers.append(
                    self._easyrec_tower_2_tzrec_tower(tower)
                )
            # Only emit the final MLP when it is configured (see DBMTL note).
            if model_cfg.HasField("final_dnn"):
                tz_model_config_ob.final.CopyFrom(
                    self._easyrec_dnn_2_tzrec_mlp(model_cfg.final_dnn)
                )
            tz_model_config.multi_tower.CopyFrom(tz_model_config_ob)
        elif model_class == "DSSM":
            tz_model_config_ob = match_model_pb2.DSSM()
            tz_model_config_ob.user_tower.CopyFrom(
                self._easyrec_dssm_tower_2_tzrec_tower(model_cfg.user_tower)
            )
            tz_model_config_ob.item_tower.CopyFrom(
                self._easyrec_dssm_tower_2_tzrec_tower(model_cfg.item_tower)
            )
            # NOTE(review): output_dim is fixed at 32 rather than derived from
            # the EasyRec tower config - confirm this matches the towers' output.
            tz_model_config_ob.output_dim = 32
            # Not every EasyRec DSSM proto version is assumed to carry a
            # temperature field, hence the hasattr check before HasField.
            if hasattr(model_cfg, "temperature") and model_cfg.HasField(
                "temperature"
            ):
                tz_model_config_ob.temperature = model_cfg.temperature
            tz_model_config.dssm.CopyFrom(tz_model_config_ob)
        else:
            logger.error(
                f"{model_class} is not convert to tzrec model, please adaptation"
            )
        return tz_model_config