in tensorflow_examples/lite/model_maker/core/task/model_spec/text_spec.py [0:0]
def _build_hub_pooled_output(hub_module_url, hub_module_trainable, is_tf2,
                             input_word_ids, input_mask, input_type_ids):
  """Wraps the hub module in a Keras layer and returns (pooled_output, layer)."""
  if is_tf2:
    bert_model = hub.KerasLayer(hub_module_url, trainable=hub_module_trainable)
    # The TF2 layer is called with a list of the three inputs and returns a
    # pair; only the first element (used as the pooled output) is kept.
    pooled_output, _ = bert_model([input_word_ids, input_mask, input_type_ids])
    return pooled_output, bert_model

  # Non-TF2 modules are wrapped through the 'tokens' signature, selecting the
  # 'pooled_output' entry, and are fed a dict keyed by the signature's names.
  bert_model = hub_loader.HubKerasLayerV1V2(
      hub_module_url,
      signature='tokens',
      output_key='pooled_output',
      trainable=hub_module_trainable)
  pooled_output = bert_model({
      'input_ids': input_word_ids,
      'input_mask': input_mask,
      'segment_ids': input_type_ids
  })
  return pooled_output, bert_model


def create_classifier_model(bert_config,
                            num_labels,
                            max_seq_length,
                            initializer=None,
                            hub_module_url=None,
                            hub_module_trainable=True,
                            is_tf2=True):
  """BERT classifier model in functional API style.

  Construct a Keras model for predicting `num_labels` outputs from an input with
  maximum sequence length `max_seq_length`.

  Args:
    bert_config: BertConfig, the config defines the core Bert model. Only
      `initializer_range` (when `initializer` is None) and
      `hidden_dropout_prob` are read here.
    num_labels: integer, the number of classes.
    max_seq_length: integer, the maximum input sequence length.
    initializer: Initializer for the kernel of the final classification dense
      layer. Defaulted to TruncatedNormal initializer with
      `stddev=bert_config.initializer_range`.
    hub_module_url: TF-Hub path/url to Bert module. Required; the `None`
      default is rejected with a ValueError.
    hub_module_trainable: True to finetune layers in the hub module.
    is_tf2: boolean, whether the hub module is in TensorFlow 2.x format.

  Returns:
    A tuple of:
      Combined prediction model (words, mask, type) -> (softmax probabilities
        over `num_labels` classes)
      The hub Keras layer wrapping the BERT module
        (words, mask, type) -> (bert_outputs)

  Raises:
    ValueError: If `hub_module_url` is None.
  """
  # Fail fast with a clear message instead of letting the hub layer
  # constructor fail on a None handle after the inputs are already built.
  if hub_module_url is None:
    raise ValueError(
        'hub_module_url must be provided to create the classifier model, '
        'got None.')

  if initializer is None:
    initializer = tf.keras.initializers.TruncatedNormal(
        stddev=bert_config.initializer_range)

  # Three int32 inputs, each of fixed length `max_seq_length`.
  input_word_ids = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
  input_mask = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
  input_type_ids = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')

  pooled_output, bert_model = _build_hub_pooled_output(
      hub_module_url, hub_module_trainable, is_tf2,
      input_word_ids, input_mask, input_type_ids)

  # Classification head: dropout followed by a softmax dense layer.
  output = tf.keras.layers.Dropout(rate=bert_config.hidden_dropout_prob)(
      pooled_output)
  # NOTE(review): dtype is pinned to float32, presumably so the softmax stays
  # in full precision under a mixed-precision policy -- confirm.
  output = tf.keras.layers.Dense(
      num_labels,
      kernel_initializer=initializer,
      name='output',
      activation='softmax',
      dtype=tf.float32)(
          output)
  return tf.keras.Model(
      inputs=[input_word_ids, input_mask, input_type_ids],
      outputs=output), bert_model