# parse_params() — excerpt from
# community-content/vertex_model_garden/model_oss/tfvision/train_hpt_oss.py


def _parse_size_pair(flag_values, label):
  """Converts a 2-element size flag into a `[height, width]` int list.

  Args:
    flag_values: Iterable of string/int elements from a multi-value flag.
    label: Flag name ('input' or 'output') interpolated into error messages.

  Returns:
    A list of two strictly positive ints.

  Raises:
    ValueError: If the flag does not contain exactly 2 elements, or if any
      element is not strictly positive.
  """
  size = [int(elem) for elem in flag_values]
  if len(size) != 2:
    raise ValueError('The %s size must contain 2 integers.' % label)
  # Use <= 0: zero is not a valid dimension. The previous check only
  # rejected negatives even though the error message promised "positive".
  if size[0] <= 0 or size[1] <= 0:
    raise ValueError('The %s size must be positive.' % label)
  return size


def parse_params() -> Any:
  """Builds the experiment config, applying flag overrides on top of gin.

  Loads gin files/bindings from FLAGS, builds the base config via
  `train_utils.parse_configuration` (unlocked so it can be mutated here),
  then layers individual command-line flag overrides onto it.

  Returns:
    The experiment `params` config object with all flag overrides applied.
  """
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  # lock_return=False keeps the config mutable for the overrides below.
  params = train_utils.parse_configuration(FLAGS, lock_return=False)
  if _INIT_CHECKPOINT.value:
    params.task.init_checkpoint = _INIT_CHECKPOINT.value
    if 'yolov7' in FLAGS.experiment:
      params.task.init_checkpoint_modules = ['backbone', 'decoder']
    else:
      # NOTE(review): other branch uses a list; confirm the config also
      # accepts a bare string here.
      params.task.init_checkpoint_modules = 'backbone'
  if _MODEL_NAME.value:
    # Only ViT-family experiments expose a backbone model-name knob.
    if FLAGS.experiment in [
        'deit_imagenet_pretrain',
        'vit_imagenet_pretrain',
        'vit_imagenet_finetune',
    ]:
      params.task.model.backbone.vit.model_name = _MODEL_NAME.value
  if _NUM_CLASSES.value:
    params.task.model.num_classes = _NUM_CLASSES.value
  if _INPUT_SIZE.value:
    input_size = _parse_size_pair(_INPUT_SIZE.value, 'input')
    # Models consume RGB images, so a fixed 3-channel dim is appended.
    params.task.model.input_size = [input_size[0], input_size[1], 3]
  # If users set input train/validation data path, we assume the data are
  # converted from data converter as tfrecord. Users can use tfds by writing
  # their own config directly, and no need to override this parameter.
  if _INPUT_TRAIN_DATA_PATH.value:
    params.task.train_data.input_path = _INPUT_TRAIN_DATA_PATH.value
    params.task.train_data.file_type = _FILE_TYPE_TFRECORD
    # Clear tfds_name so the explicit input path wins.
    params.task.train_data.tfds_name = ''
  if _INPUT_VALIDATION_DATA_PATH.value:
    params.task.validation_data.input_path = _INPUT_VALIDATION_DATA_PATH.value
    params.task.validation_data.file_type = _FILE_TYPE_TFRECORD
    params.task.validation_data.tfds_name = ''
  if _GLOBAL_BATCH_SIZE.value:
    params.task.train_data.global_batch_size = _GLOBAL_BATCH_SIZE.value
    params.task.validation_data.global_batch_size = _GLOBAL_BATCH_SIZE.value
  if _PREFETCH_BUFFER_SIZE.value:
    params.task.train_data.prefetch_buffer_size = _PREFETCH_BUFFER_SIZE.value
    params.task.validation_data.prefetch_buffer_size = (
        _PREFETCH_BUFFER_SIZE.value
    )

  # Use `get` method of train_utils.hyperparams.OneOfConfig to get learning
  # rate config.
  learning_rate = params.trainer.optimizer_config.learning_rate.get()

  if _TRAIN_STEPS.value:
    params.trainer.train_steps = _TRAIN_STEPS.value
    # Keep the LR schedule in sync with the new step budget, when the
    # schedule type has a decay_steps field.
    if hasattr(learning_rate, 'decay_steps'):
      learning_rate.decay_steps = _TRAIN_STEPS.value
  if (
      _BACKBONE_TRAINABLE.value is not None
      and params.task.model.backbone.type == 'hub_model'
  ):
    params.task.model.backbone.hub_model.trainable = _BACKBONE_TRAINABLE.value
  if _LEARNING_RATE.value:
    logging.info('Updating learning_rate: %s', _LEARNING_RATE.value)
    if hasattr(learning_rate, 'initial_learning_rate'):
      learning_rate.initial_learning_rate = _LEARNING_RATE.value

  # Weight decay lives under a different field name per optimizer type.
  if _WEIGHT_DECAY.value and 'yolo' in FLAGS.experiment:
    if 'sgd_torch' == params.trainer.optimizer_config.optimizer.type:
      params.trainer.optimizer_config.optimizer.sgd_torch.weight_decay = (
          _WEIGHT_DECAY.value
      )
    elif 'adamw' == params.trainer.optimizer_config.optimizer.type:
      params.trainer.optimizer_config.optimizer.adamw.weight_decay_rate = (
          _WEIGHT_DECAY.value
      )

  # Yolo models do not support anchor size.
  if _ANCHOR_SIZE.value and 'yolo' not in FLAGS.experiment:
    params.task.model.anchor.anchor_size = _ANCHOR_SIZE.value

  # Segmentation models will also set output size.
  if _OUTPUT_SIZE.value:
    output_size = _parse_size_pair(_OUTPUT_SIZE.value, 'output')
    params.task.train_data.output_size = output_size
    params.task.validation_data.output_size = output_size

  # Set default params for best checkpoints.
  params.trainer.best_checkpoint_export_subdir = constants.BEST_CKPT_DIRNAME
  params.trainer.best_checkpoint_metric_comp = constants.BEST_CKPT_METRIC_COMP
  params.trainer.best_checkpoint_eval_metric = get_best_eval_metric(
      _OBJECTIVE.value, params
  )
  return params