model_generator/text_embedding_generator.py

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This file supports the preprocessing procedure for text embedding models."""

import enum

import tensorflow as tf
import tensorflow_hub as hub


@enum.unique
class TextEmbeddingModelType(str, enum.Enum):
  """The different text embedding model types that require signature addition."""

  NNLM = "nnlm"
  SWIVEL = "swivel"
  BERT = "bert"


@enum.unique
class TextEmbeddingModelLinks(str, enum.Enum):
  """The different text embedding model TensorFlow Hub links."""

  NNLM = "https://tfhub.dev/google/nnlm-en-dim50-with-normalization/2"
  SWIVEL = "https://tfhub.dev/google/tf2-preview/gnews-swivel-20dim/1"
  BERT_PREPROCESS = "https://tfhub.dev/tensorflow/bert_en_cased_preprocess/3"
  BERT_ENCODER = "https://tfhub.dev/tensorflow/bert_en_cased_L-12_H-768_A-12/4"


class TextEmbeddingModelGenerator:
  """Class to chain embedding models by integrating signature addition.

  It performs preprocessing for the NNLM, SWIVEL, and BERT model types.
  """

  def __init__(
      self,
      nnlm_link=TextEmbeddingModelLinks.NNLM,
      swivel_link=TextEmbeddingModelLinks.SWIVEL,
      bert_preprocess_link=TextEmbeddingModelLinks.BERT_PREPROCESS,
      bert_encoder_link=TextEmbeddingModelLinks.BERT_ENCODER,
  ):
    """Initializes a TextEmbeddingModelGenerator to generate text embedding models.

    Args:
      nnlm_link: Path (local or hub) that links to the NNLM model.
      swivel_link: Path (local or hub) that links to the SWIVEL model.
      bert_preprocess_link: Path (local or hub) that links to the BERT
        preprocess model.
      bert_encoder_link: Path (local or hub) that links to the BERT encoder
        model.
    """
    self._nnlm_link = nnlm_link
    self._swivel_link = swivel_link
    self._bert_preprocess_link = bert_preprocess_link
    self._bert_encoder_link = bert_encoder_link

  def generate_text_embedding_model(self, model_type, folder_path):
    """Generates the requested text embedding model and saves it to disk.

    Args:
      model_type: Text embedding model type (NNLM, SWIVEL, or BERT).
      folder_path: Folder path to save the model in.

    Raises:
      ValueError: If `model_type` is not one of the supported model types.
    """
    if model_type.lower() == TextEmbeddingModelType.NNLM:
      model = self._generate_nnlm()
    elif model_type.lower() == TextEmbeddingModelType.SWIVEL:
      model = self._generate_swivel()
    elif model_type.lower() == TextEmbeddingModelType.BERT:
      model = self._generate_bert()
    else:
      raise ValueError(
          f'"{model_type}" is not a valid model type. Please choose one from'
          ' (NNLM, SWIVEL, BERT)'
      )
    # Save the model with an explicit serving signature so the output tensor
    # is exposed under the "text_embedding" key.
    tf.saved_model.save(
        model, folder_path, signatures=self._construct_model_signature(model)
    )

  def _generate_nnlm(self) -> tf.keras.Model:
    """Generates the NNLM model from TensorFlow Hub.

    Returns:
      Generated NNLM model.
    """
    text_input = tf.keras.layers.Input(
        shape=(), dtype=tf.string, name="content"
    )
    preprocessor = hub.KerasLayer(self._nnlm_link)
    outputs = preprocessor(text_input)
    model = tf.keras.Model(text_input, outputs)
    return model

  def _generate_swivel(self) -> tf.keras.Model:
    """Generates the SWIVEL model from TensorFlow Hub.

    Returns:
      Generated SWIVEL model.
    """
    text_input = tf.keras.layers.Input(
        shape=(), dtype=tf.string, name="content"
    )
    preprocessor = hub.KerasLayer(self._swivel_link)
    outputs = preprocessor(text_input)
    model = tf.keras.Model(text_input, outputs)
    return model

  def _generate_bert(self) -> tf.keras.Model:
    """Generates the BERT model from TensorFlow Hub.

    Returns:
      Generated BERT model.
    """
    text_input = tf.keras.layers.Input(
        shape=(), dtype=tf.string, name="content"
    )
    # Chain the BERT preprocessing model into the encoder, then expose the
    # pooled (whole-sequence) output as the embedding.
    preprocessor = hub.KerasLayer(self._bert_preprocess_link)
    encoder_inputs = preprocessor(text_input)
    encoder = hub.KerasLayer(self._bert_encoder_link)
    outputs = encoder(encoder_inputs)
    pooled_output = outputs["pooled_output"]
    model = tf.keras.Model(text_input, pooled_output)
    return model

  def _construct_model_signature(self, model) -> dict:
    """Constructs the model signature in order to override the output tensor name.

    Args:
      model: NNLM, SWIVEL, or BERT tf.keras.Model.

    Returns:
      Constructed signature.
    """

    @tf.function
    def export_model_wrapper(embedding_model, **feature_specs):
      return {"text_embedding": embedding_model(feature_specs)}

    tensor_spec = {
        "content": tf.TensorSpec(shape=(None,), dtype=tf.string, name="content")
    }
    signature = {
        "serving_default": export_model_wrapper.get_concrete_function(
            model, **tensor_spec
        )
    }
    return signature

  def _save_model_with_path(self, model, folder_path):
    # Create the output directory if it does not already exist.
    if not tf.io.gfile.isdir(folder_path):
      tf.io.gfile.makedirs(folder_path)
    tf.saved_model.save(model, folder_path)