def __init__()

in tensorflow_examples/lite/model_maker/core/task/model_spec/text_spec.py [0:0]


  def __init__(
      self,
      uri='https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/1',
      model_dir=None,
      seq_len=128,
      dropout_rate=0.1,
      initializer_range=0.02,
      learning_rate=3e-5,
      distribution_strategy='mirrored',
      num_gpus=-1,
      tpu='',
      trainable=True,
      do_lower_case=True,
      is_tf2=True,
      name='Bert',
      tflite_input_name=None,
      default_batch_size=32):
    """Initialize an instance with model parameters.

    Args:
      uri: TF-Hub path/url to the Bert module.
      model_dir: The location of the model checkpoint files. If None, a new
        temporary directory is created with `tempfile.mkdtemp()`.
      seq_len: Length of the sequence to feed into the model.
      dropout_rate: The rate for dropout. Also used as the
        `hidden_dropout_prob` of `self.bert_config`.
      initializer_range: The stdev of the truncated_normal_initializer for
        initializing all weight matrices. Also passed to `self.bert_config`.
      learning_rate: The initial learning rate for Adam.
      distribution_strategy: A string specifying which distribution strategy to
        use. Accepted values are 'off', 'one_device', 'mirrored',
        'parameter_server', 'multi_worker_mirrored', and 'tpu' -- case
        insensitive. 'off' means not to use Distribution Strategy; 'tpu' means
        to use TPUStrategy with the address given by the `tpu` argument.
      num_gpus: How many GPUs to use at each worker with the
        DistributionStrategies API. The default is -1, which means utilize all
        available GPUs. The value is resolved through `util.get_num_gpus`
        before the distribution strategy is created.
      tpu: TPU address to connect to. Forwarded to the distribution strategy
        as `tpu_address` and stored as `self.tpu`.
      trainable: boolean, whether the pretrained layer is trainable.
      do_lower_case: boolean, whether to lower case the input text. Should be
        True for uncased models and False for cased models.
      is_tf2: boolean, whether the hub module is in TensorFlow 2.x format.
      name: The name of the object.
      tflite_input_name: Dict, input names for the TFLite model. If None,
        defaults to a dict with keys 'ids', 'mask' and 'segment_ids' mapped to
        'serving_default_input_word_ids:0', 'serving_default_input_mask:0' and
        'serving_default_input_type_ids:0' respectively.
      default_batch_size: Default batch size for training.

    Raises:
      ValueError: If the TF behavior reported by `compat.get_tf_behavior()` is
        not one of `self.compat_tf_versions`.
    """
    # Fail fast, before a temp directory or a distribution strategy is created.
    # `compat_tf_versions` is expected to be defined on the class (not visible
    # in this method).
    tf_behavior = compat.get_tf_behavior()
    if tf_behavior not in self.compat_tf_versions:
      raise ValueError('Incompatible versions. Expect {}, but got {}.'.format(
          self.compat_tf_versions, tf_behavior))

    self.seq_len = seq_len
    self.dropout_rate = dropout_rate
    self.initializer_range = initializer_range
    self.learning_rate = learning_rate
    self.trainable = trainable

    # Fall back to a fresh temporary directory when no checkpoint dir is given.
    self.model_dir = model_dir
    if self.model_dir is None:
      self.model_dir = tempfile.mkdtemp()

    # Resolve the requested GPU count (e.g. -1) into a concrete number before
    # building the strategy.
    num_gpus = util.get_num_gpus(num_gpus)
    self.strategy = distribute_utils.get_distribution_strategy(
        distribution_strategy=distribution_strategy,
        num_gpus=num_gpus,
        tpu_address=tpu)
    self.tpu = tpu
    self.uri = uri
    self.do_lower_case = do_lower_case
    self.is_tf2 = is_tf2

    # NOTE(review): the first positional argument is presumably `vocab_size`;
    # it is left as 0, presumably because the vocabulary comes from the hub
    # module at `uri` -- confirm against bert_configs.BertConfig.
    self.bert_config = bert_configs.BertConfig(
        0,
        initializer_range=self.initializer_range,
        hidden_dropout_prob=self.dropout_rate)

    # Presumably flipped to True by a later build step elsewhere in the class.
    self.is_built = False
    self.name = name

    if tflite_input_name is None:
      tflite_input_name = {
          'ids': 'serving_default_input_word_ids:0',
          'mask': 'serving_default_input_mask:0',
          'segment_ids': 'serving_default_input_type_ids:0'
      }
    self.tflite_input_name = tflite_input_name
    self.default_batch_size = default_batch_size