def _convert_model_feature_group()

in tzrec/tools/convert_easyrec_config_to_tzrec_config.py [0:0]


    def _convert_model_feature_group(self, easyrec_feature_groups):
        """Translate EasyRec feature groups into tzrec feature group configs.

        For every EasyRec group a ``model_pb2.FeatureGroupConfig`` is built:

        * ``group_name`` and ``feature_names`` are copied over verbatim.
        * ``group_type`` is ``WIDE`` when the EasyRec ``wide_deep`` field equals
          ``WideOrDeep.WIDE``; any other value yields ``DEEP``.
        * Each entry of ``sequence_features`` produces one ``SeqGroupConfig``
          and one ``SeqEncoderConfig`` wrapping a ``DINEncoder``. Both are
          appended to the group in matching order.

        Args:
            easyrec_feature_groups: iterable of EasyRec feature group messages.

        Returns:
            list of ``model_pb2.FeatureGroupConfig``, one per input group, in
            input order.
        """
        converted = []
        for src_group in easyrec_feature_groups:
            dst_group = model_pb2.FeatureGroupConfig()
            dst_group.group_name = src_group.group_name
            dst_group.feature_names.extend(src_group.feature_names)
            dst_group.group_type = (
                model_pb2.FeatureGroupType.WIDE
                if src_group.wide_deep
                == easyrec_feature_config_pb2.WideOrDeep.WIDE  # NOQA
                else model_pb2.FeatureGroupType.DEEP
            )

            for idx, seq_cfg in enumerate(src_group.sequence_features):
                # NOTE(review): the fallback name uses the index within this
                # feature group only, so unnamed sequence groups in different
                # feature groups all start at "seq_0" — confirm tzrec accepts
                # that.
                seq_name = (
                    seq_cfg.group_name
                    if seq_cfg.HasField("group_name")
                    else f"seq_{idx}"
                )

                seq_group = model_pb2.SeqGroupConfig()
                seq_group.group_name = seq_name

                # The DIN encoder reads from the sequence group of the same
                # name. Its attention MLP comes from the EasyRec ``seq_dnn``.
                din = seq_encoder_pb2.DINEncoder()
                din.input = seq_name
                din.attn_mlp.CopyFrom(
                    self._easyrec_dnn_2_tzrec_mlp(seq_cfg.seq_dnn)
                )
                encoder = seq_encoder_pb2.SeqEncoderConfig()
                encoder.din_encoder.CopyFrom(din)

                # Flatten key, hist_seq and aux_hist_seq names of every
                # seq_att_map into one list, keeping map order and, within a
                # map, the key -> hist_seq -> aux_hist_seq order.
                seq_group.feature_names.extend(
                    name
                    for att_map in seq_cfg.seq_att_map
                    for name in (
                        *att_map.key,
                        *att_map.hist_seq,
                        *att_map.aux_hist_seq,
                    )
                )

                dst_group.sequence_groups.append(seq_group)
                dst_group.sequence_encoders.append(encoder)

            converted.append(dst_group)
        return converted