def create()

in tensorflow_examples/lite/model_maker/core/task/image_classifier.py [0:0]


  def create(cls,
             train_data,
             model_spec='efficientnet_lite0',
             validation_data=None,
             batch_size=None,
             epochs=None,
             steps_per_epoch=None,
             train_whole_model=None,
             dropout_rate=None,
             learning_rate=None,
             momentum=None,
             shuffle=False,
             use_augmentation=False,
             use_hub_library=True,
             warmup_steps=None,
             model_dir=None,
             do_train=True):
    """Loads data and retrains the model based on data for image classification.

    Args:
      train_data: Training data. Its `index_to_label` is used as the label
        vocabulary of the classifier, and it is also passed to the classifier
        as `representative_data`.
      model_spec: Specification for the model (resolved through `ms.get`).
      validation_data: Validation data. If None, skips validation process.
        Only used when `do_train` is True.
      batch_size: Number of samples per training step.
      epochs: Number of epochs for training.
      steps_per_epoch: Integer or None. Total number of steps (batches of
        samples) before declaring one epoch finished and starting the next
        epoch. If `steps_per_epoch` is None, the epoch will run until the input
        dataset is exhausted. Only used when `do_train` is True.
      train_whole_model: If true, the Hub module is trained together with the
        classification layer on top. Otherwise, only train the top
        classification layer.
      dropout_rate: The rate for dropout.
      learning_rate: Base learning rate. If `use_hub_library` is False, it
        represents the base learning rate when train batch size is 256 and it
        is scaled linearly with the batch size.
      momentum: a Python float forwarded to the optimizer. Only used when
        `use_hub_library` is True; a warning is logged if it is set otherwise.
      shuffle: Whether the data should be shuffled.
      use_augmentation: Use data augmentation for preprocessing.
      use_hub_library: Use `make_image_classifier_lib` from tensorflow hub to
        retrain the model.
      warmup_steps: Number of warmup steps for warmup schedule on learning rate.
        If None, the default warmup_steps is used which is the total training
        steps in two epochs. Only used when `use_hub_library` is False; a
        warning is logged if it is set otherwise.
      model_dir: The location of the model checkpoint files. Only used when
        `use_hub_library` is False; a warning is logged if it is set otherwise.
      do_train: Whether to run training. If False, the model is only built
        (with loss and metrics) so that it can be used for evaluation.

    Returns:
      An instance based on ImageClassifier.

    Raises:
      ValueError: If the current TF behavior version is not listed in
        `model_spec.compat_tf_versions`.
    """
    model_spec = ms.get(model_spec)
    if compat.get_tf_behavior() not in model_spec.compat_tf_versions:
      raise ValueError('Incompatible versions. Expect {}, but got {}.'.format(
          model_spec.compat_tf_versions, compat.get_tf_behavior()))

    if use_hub_library:
      # `warmup_steps` and `model_dir` are only consumed by the non-hub
      # training path; tell the caller instead of silently dropping them.
      if warmup_steps is not None or model_dir is not None:
        tf.compat.v1.logging.warning(
            '`warmup_steps` and `model_dir` are ignored when '
            '`use_hub_library` is True.')
      hparams = get_hub_lib_hparams(
          batch_size=batch_size,
          train_epochs=epochs,
          do_fine_tuning=train_whole_model,
          dropout_rate=dropout_rate,
          learning_rate=learning_rate,
          momentum=momentum)
    else:
      # `momentum` is only forwarded on the hub-library path.
      if momentum is not None:
        tf.compat.v1.logging.warning(
            '`momentum` is ignored when `use_hub_library` is False.')
      hparams = train_image_classifier_lib.HParams.get_hparams(
          batch_size=batch_size,
          train_epochs=epochs,
          do_fine_tuning=train_whole_model,
          dropout_rate=dropout_rate,
          learning_rate=learning_rate,
          warmup_steps=warmup_steps,
          model_dir=model_dir)

    image_classifier = cls(
        model_spec,
        train_data.index_to_label,
        shuffle=shuffle,
        hparams=hparams,
        use_augmentation=use_augmentation,
        representative_data=train_data)

    if do_train:
      tf.compat.v1.logging.info('Retraining the models...')
      image_classifier.train(train_data, validation_data, steps_per_epoch)
    else:
      # Used in evaluation: build the model without training it.
      image_classifier.create_model(with_loss_and_metrics=True)

    return image_classifier