in tensorflow_examples/lite/model_maker/core/task/model_spec/text_spec.py [0:0]
def __init__(
        self,
        uri='https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/1',
        model_dir=None,
        seq_len=128,
        dropout_rate=0.1,
        initializer_range=0.02,
        learning_rate=3e-5,
        distribution_strategy='mirrored',
        num_gpus=-1,
        tpu='',
        trainable=True,
        do_lower_case=True,
        is_tf2=True,
        name='Bert',
        tflite_input_name=None,
        default_batch_size=32):
    """Initializes an instance with model parameters.

    Validates the active TF behavior against `self.compat_tf_versions`, stores
    the hyperparameters on the instance, resolves the checkpoint directory,
    and builds the distribution strategy and the BERT config.

    Args:
      uri: TF-Hub path/url to the Bert module.
      model_dir: The location of the model checkpoint files. When None, a
        fresh temporary directory is created with `tempfile.mkdtemp()`.
      seq_len: Length of the sequence to feed into the model.
      dropout_rate: The rate for dropout.
      initializer_range: The stdev of the truncated_normal_initializer for
        initializing all weight matrices.
      learning_rate: The initial learning rate for Adam.
      distribution_strategy: A string specifying which distribution strategy
        to use. Accepted values are 'off', 'one_device', 'mirrored',
        'parameter_server', 'multi_worker_mirrored', and 'tpu' -- case
        insensitive. 'off' means not to use Distribution Strategy; 'tpu'
        means to use TPUStrategy with the address given in `tpu`.
      num_gpus: How many GPUs to use at each worker with the
        DistributionStrategies API. The default is -1, which means utilize
        all available GPUs.
      tpu: TPU address to connect to.
      trainable: boolean, whether the pretrained layer is trainable.
      do_lower_case: boolean, whether to lower case the input text. Should be
        True for uncased models and False for cased models.
      is_tf2: boolean, whether the hub module is in TensorFlow 2.x format.
      name: The name of the object.
      tflite_input_name: Dict, input names for the TFLite model. When None,
        the default 'serving_default_*' names for ids, mask and segment_ids
        are used.
      default_batch_size: Default batch size for training.

    Raises:
      ValueError: If the current TF behavior reported by
        `compat.get_tf_behavior()` is not in `self.compat_tf_versions`.
    """
    # Fail fast, before any state is stored or any directory is created.
    if compat.get_tf_behavior() not in self.compat_tf_versions:
        error_template = 'Incompatible versions. Expect {}, but got {}.'
        raise ValueError(
            error_template.format(
                self.compat_tf_versions, compat.get_tf_behavior()))

    # Plain training hyperparameters.
    self.seq_len = seq_len
    self.dropout_rate = dropout_rate
    self.initializer_range = initializer_range
    self.learning_rate = learning_rate
    self.trainable = trainable

    # Fall back to a throwaway directory when the caller gave no location.
    self.model_dir = tempfile.mkdtemp() if model_dir is None else model_dir

    # The requested GPU count is resolved by the helper before the strategy
    # is built (the docstring default of -1 stands for "all available").
    resolved_gpu_count = util.get_num_gpus(num_gpus)
    self.strategy = distribute_utils.get_distribution_strategy(
        distribution_strategy=distribution_strategy,
        num_gpus=resolved_gpu_count,
        tpu_address=tpu)
    self.tpu = tpu

    # Hub module location and text preprocessing options.
    self.uri = uri
    self.do_lower_case = do_lower_case
    self.is_tf2 = is_tf2

    # NOTE(review): the leading positional 0 is presumably a placeholder for
    # the config's first field (e.g. vocabulary size) -- confirm against
    # BertConfig.
    self.bert_config = bert_configs.BertConfig(
        0,
        initializer_range=initializer_range,
        hidden_dropout_prob=dropout_rate)

    self.is_built = False
    self.name = name

    if tflite_input_name is not None:
        self.tflite_input_name = tflite_input_name
    else:
        self.tflite_input_name = {
            'ids': 'serving_default_input_word_ids:0',
            'mask': 'serving_default_input_mask:0',
            'segment_ids': 'serving_default_input_type_ids:0'
        }
    self.default_batch_size = default_batch_size