in tzrec/tools/convert_easyrec_config_to_tzrec_config.py [0:0]
def _convert_model_feature_group(self, easyrec_feature_groups):
"""Convert easyrec feature group to tzrec feature group."""
tz_feature_groups = []
for easy_feature_group in easyrec_feature_groups:
tz_feature_group = model_pb2.FeatureGroupConfig()
tz_feature_group.group_name = easy_feature_group.group_name
tz_feature_group.feature_names.extend(easy_feature_group.feature_names)
if (
easy_feature_group.wide_deep
== easyrec_feature_config_pb2.WideOrDeep.WIDE # NOQA
):
tz_feature_group.group_type = model_pb2.FeatureGroupType.WIDE
else:
tz_feature_group.group_type = model_pb2.FeatureGroupType.DEEP
for i, easyrec_sequence_group in enumerate(
easy_feature_group.sequence_features
):
tz_seq_group = model_pb2.SeqGroupConfig()
tz_seq_encoder = seq_encoder_pb2.SeqEncoderConfig()
seq_encoder = seq_encoder_pb2.DINEncoder()
if easyrec_sequence_group.HasField("group_name"):
group_name = easyrec_sequence_group.group_name
else:
group_name = f"seq_{i}"
tz_seq_group.group_name = group_name
seq_encoder.input = group_name
mlp = self._easyrec_dnn_2_tzrec_mlp(easyrec_sequence_group.seq_dnn)
seq_encoder.attn_mlp.CopyFrom(mlp)
tz_seq_encoder.din_encoder.CopyFrom(seq_encoder)
for seq_att_map in easyrec_sequence_group.seq_att_map:
tz_seq_group.feature_names.extend(seq_att_map.key)
tz_seq_group.feature_names.extend(seq_att_map.hist_seq)
tz_seq_group.feature_names.extend(seq_att_map.aux_hist_seq)
tz_feature_group.sequence_groups.append(tz_seq_group)
tz_feature_group.sequence_encoders.append(tz_seq_encoder)
tz_feature_groups.append(tz_feature_group)
return tz_feature_groups