in community-content/vertex_model_garden/model_oss/movinet/train.py [0:0]
def parse_params() -> Any:
  """Builds the experiment config and applies command-line overrides.

  The base config is produced by parsing the gin files/bindings and then
  `train_utils.parse_configuration(FLAGS, lock_return=False)`. Each override
  flag handled below is applied only when its value is truthy, so unset flags
  leave the parsed value as-is. The best-checkpoint export settings on the
  trainer are always overwritten with the values from `constants`.

  Returns:
    The experiment config (returned unlocked) with all overrides applied.
  """
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  params = train_utils.parse_configuration(FLAGS, lock_return=False)

  task = params.task
  train_data = task.train_data
  validation_data = task.validation_data
  both_splits = (train_data, validation_data)

  if _INIT_CHECKPOINT.value:
    task.init_checkpoint = _INIT_CHECKPOINT.value
    # Only the backbone weights are restored from the initial checkpoint.
    task.init_checkpoint_modules = 'backbone'

  num_classes = _NUM_CLASSES.value
  if num_classes:
    task.model.num_classes = num_classes
    for data_cfg in both_splits:
      data_cfg.num_classes = num_classes

  # An explicit data path is taken to mean TFRecord files produced by the data
  # converter, so the TFDS dataset name is cleared. Users who want TFDS should
  # write their own config instead of setting these path flags.
  path_overrides = (
      (train_data, _INPUT_TRAIN_DATA_PATH.value),
      (validation_data, _INPUT_VALIDATION_DATA_PATH.value),
  )
  for data_cfg, input_path in path_overrides:
    if not input_path:
      continue
    data_cfg.input_path = input_path
    data_cfg.file_type = _FILE_TYPE_TFRECORD
    data_cfg.tfds_name = ''

  batch_size = _GLOBAL_BATCH_SIZE.value
  if batch_size:
    for data_cfg in both_splits:
      data_cfg.global_batch_size = batch_size

  prefetch_size = _PREFETCH_BUFFER_SIZE.value
  if prefetch_size:
    for data_cfg in both_splits:
      data_cfg.prefetch_buffer_size = prefetch_size

  # Shuffling is only configured for the training split.
  if _SHUFFLE_BUFFER_SIZE.value:
    train_data.shuffle_buffer_size = _SHUFFLE_BUFFER_SIZE.value

  trainer = params.trainer
  if _TRAIN_STEPS.value:
    trainer.train_steps = _TRAIN_STEPS.value

  new_learning_rate = _LEARNING_RATE.value
  if new_learning_rate:
    logging.info('Updating learning_rate: %s', new_learning_rate)
    # `learning_rate` is a hyperparams OneOfConfig; `get()` returns the config
    # of whichever schedule is currently selected.
    schedule_cfg = trainer.optimizer_config.learning_rate.get()
    if not hasattr(schedule_cfg, 'initial_learning_rate'):
      # The selected schedule has no single initial rate to override.
      logging.warning('Cannot set learning rate for %s', schedule_cfg)
    else:
      schedule_cfg.initial_learning_rate = new_learning_rate

  # Best-checkpoint export settings are always forced to the shared defaults.
  trainer.best_checkpoint_export_subdir = constants.BEST_CKPT_DIRNAME
  trainer.best_checkpoint_metric_comp = constants.BEST_CKPT_METRIC_COMP
  trainer.best_checkpoint_eval_metric = (
      constants.VIDEO_CLASSIFICATION_BEST_EVAL_METRIC
  )
  return params