def parse_params()

in community-content/vertex_model_garden/model_oss/movinet/train.py [0:0]


def parse_params() -> Any:
  """Builds the experiment configuration from gin files and CLI flags.

  Loads the gin bindings, parses the base configuration, then overlays any
  command-line overrides (init checkpoint, class count, data paths, batch
  and buffer sizes, train steps, learning rate) on top of it. Finally pins
  the best-checkpoint export settings to the project-wide defaults.

  Returns:
    The (unlocked) experiment params with all overrides applied.
  """
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  params = train_utils.parse_configuration(FLAGS, lock_return=False)

  if _INIT_CHECKPOINT.value:
    # Warm-start only the backbone weights from the given checkpoint.
    params.task.init_checkpoint = _INIT_CHECKPOINT.value
    params.task.init_checkpoint_modules = 'backbone'

  num_classes = _NUM_CLASSES.value
  if num_classes:
    params.task.model.num_classes = num_classes
    params.task.train_data.num_classes = num_classes
    params.task.validation_data.num_classes = num_classes

  # An explicit train/validation path implies tfrecord files produced by the
  # data converter; clearing tfds_name disables tfds-based input. Users who
  # want tfds can configure it directly in their own config instead, and no
  # override of this parameter is needed.
  if _INPUT_TRAIN_DATA_PATH.value:
    params.task.train_data.input_path = _INPUT_TRAIN_DATA_PATH.value
    params.task.train_data.file_type = _FILE_TYPE_TFRECORD
    params.task.train_data.tfds_name = ''
  if _INPUT_VALIDATION_DATA_PATH.value:
    params.task.validation_data.input_path = _INPUT_VALIDATION_DATA_PATH.value
    params.task.validation_data.file_type = _FILE_TYPE_TFRECORD
    params.task.validation_data.tfds_name = ''

  batch_size = _GLOBAL_BATCH_SIZE.value
  if batch_size:
    # The same global batch size is applied to both splits.
    params.task.train_data.global_batch_size = batch_size
    params.task.validation_data.global_batch_size = batch_size

  prefetch_size = _PREFETCH_BUFFER_SIZE.value
  if prefetch_size:
    params.task.train_data.prefetch_buffer_size = prefetch_size
    params.task.validation_data.prefetch_buffer_size = prefetch_size

  if _SHUFFLE_BUFFER_SIZE.value:
    params.task.train_data.shuffle_buffer_size = _SHUFFLE_BUFFER_SIZE.value

  if _TRAIN_STEPS.value:
    params.trainer.train_steps = _TRAIN_STEPS.value

  if _LEARNING_RATE.value:
    logging.info('Updating learning_rate: %s', _LEARNING_RATE.value)
    # `get()` on train_utils.hyperparams.OneOfConfig resolves the currently
    # selected learning-rate schedule config.
    lr_config = params.trainer.optimizer_config.learning_rate.get()
    if hasattr(lr_config, 'initial_learning_rate'):
      lr_config.initial_learning_rate = _LEARNING_RATE.value
    else:
      # Not every schedule exposes initial_learning_rate; warn rather than
      # fail so unaffected runs still proceed.
      logging.warning('Cannot set learning rate for %s', lr_config)

  # Always export the best checkpoint using the project-wide defaults.
  params.trainer.best_checkpoint_export_subdir = constants.BEST_CKPT_DIRNAME
  params.trainer.best_checkpoint_metric_comp = constants.BEST_CKPT_METRIC_COMP
  params.trainer.best_checkpoint_eval_metric = (
      constants.VIDEO_CLASSIFICATION_BEST_EVAL_METRIC
  )
  return params