in community-content/vertex_model_garden/model_oss/tfvision/train_hpt_oss.py [0:0]
def _parse_image_size(values: Any, name: str) -> list:
  """Converts a two-element size flag into a validated list of ints.

  Args:
    values: Iterable of values convertible with `int()`, e.g. the string
      items of a list flag.
    name: Which size is being parsed ('input' or 'output'); used only in the
      error messages.

  Returns:
    A list of exactly two strictly positive integers, in the order given.

  Raises:
    ValueError: If an element is not an integer, if there are not exactly
      two elements, or if either element is not strictly positive.
  """
  size = [int(elem) for elem in values]
  if len(size) != 2:
    raise ValueError(f'The {name} size must contain 2 integers.')
  # Zero is rejected as well: a zero-sized dimension is never a usable size,
  # and the error message promises "positive".
  if size[0] <= 0 or size[1] <= 0:
    raise ValueError(f'The {name} size must be positive.')
  return size


def parse_params() -> Any:
  """Parses the experiment config and applies command-line overrides.

  Gin files/bindings are parsed first. The experiment config is then built
  from FLAGS via `train_utils.parse_configuration`. Finally, every optional
  override flag that is set is written into that config.

  Returns:
    The experiment config (not locked) with all overrides applied.

  Raises:
    ValueError: If `_INPUT_SIZE` or `_OUTPUT_SIZE` is set but does not hold
      exactly two strictly positive integers.
  """
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  # lock_return=False keeps the config mutable so the overrides below can be
  # assigned into it.
  params = train_utils.parse_configuration(FLAGS, lock_return=False)

  if _INIT_CHECKPOINT.value:
    params.task.init_checkpoint = _INIT_CHECKPOINT.value
    # YOLOv7 experiments restore backbone and decoder; all others restore
    # only the backbone.
    if 'yolov7' in FLAGS.experiment:
      params.task.init_checkpoint_modules = ['backbone', 'decoder']
    else:
      params.task.init_checkpoint_modules = 'backbone'

  if _MODEL_NAME.value:
    # The model name is only applied for these ViT-backbone experiments.
    if FLAGS.experiment in [
        'deit_imagenet_pretrain',
        'vit_imagenet_pretrain',
        'vit_imagenet_finetune',
    ]:
      params.task.model.backbone.vit.model_name = _MODEL_NAME.value

  if _NUM_CLASSES.value:
    params.task.model.num_classes = _NUM_CLASSES.value

  if _INPUT_SIZE.value:
    input_size = _parse_image_size(_INPUT_SIZE.value, 'input')
    # A trailing channel dimension of 3 is appended (presumably RGB input).
    params.task.model.input_size = [input_size[0], input_size[1], 3]

  # If users set input train/validation data path, we assume the data are
  # converted from data converter as tfrecord. Users can use tfds by writing
  # their own config directly, and no need to override this parameter.
  if _INPUT_TRAIN_DATA_PATH.value:
    params.task.train_data.input_path = _INPUT_TRAIN_DATA_PATH.value
    params.task.train_data.file_type = _FILE_TYPE_TFRECORD
    params.task.train_data.tfds_name = ''
  if _INPUT_VALIDATION_DATA_PATH.value:
    params.task.validation_data.input_path = _INPUT_VALIDATION_DATA_PATH.value
    params.task.validation_data.file_type = _FILE_TYPE_TFRECORD
    params.task.validation_data.tfds_name = ''

  # Batch size and prefetch buffer are applied to train and validation alike.
  if _GLOBAL_BATCH_SIZE.value:
    params.task.train_data.global_batch_size = _GLOBAL_BATCH_SIZE.value
    params.task.validation_data.global_batch_size = _GLOBAL_BATCH_SIZE.value
  if _PREFETCH_BUFFER_SIZE.value:
    params.task.train_data.prefetch_buffer_size = _PREFETCH_BUFFER_SIZE.value
    params.task.validation_data.prefetch_buffer_size = (
        _PREFETCH_BUFFER_SIZE.value
    )

  # Use `get` method of train_utils.hyperparams.OneOfConfig to get learning
  # rate config.
  learning_rate = params.trainer.optimizer_config.learning_rate.get()
  if _TRAIN_STEPS.value:
    params.trainer.train_steps = _TRAIN_STEPS.value
    # Keep the decay schedule aligned with the total number of train steps,
    # for schedules that have a `decay_steps` field.
    if hasattr(learning_rate, 'decay_steps'):
      learning_rate.decay_steps = _TRAIN_STEPS.value

  # `is not None` so that an explicit False (freeze backbone) is honored.
  if (
      _BACKBONE_TRAINABLE.value is not None
      and params.task.model.backbone.type == 'hub_model'
  ):
    params.task.model.backbone.hub_model.trainable = _BACKBONE_TRAINABLE.value

  if _LEARNING_RATE.value:
    logging.info('Updating learning_rate: %s', _LEARNING_RATE.value)
    if hasattr(learning_rate, 'initial_learning_rate'):
      learning_rate.initial_learning_rate = _LEARNING_RATE.value
    else:
      # Previously this case was silently skipped after logging the update.
      logging.warning(
          'Learning rate config %s has no `initial_learning_rate`; the'
          ' learning rate override %s was not applied.',
          type(learning_rate).__name__,
          _LEARNING_RATE.value,
      )

  # Weight decay is only overridden for YOLO experiments, and only for the
  # sgd_torch and adamw optimizers; other optimizer types are left unchanged.
  if _WEIGHT_DECAY.value and 'yolo' in FLAGS.experiment:
    if 'sgd_torch' == params.trainer.optimizer_config.optimizer.type:
      params.trainer.optimizer_config.optimizer.sgd_torch.weight_decay = (
          _WEIGHT_DECAY.value
      )
    elif 'adamw' == params.trainer.optimizer_config.optimizer.type:
      params.trainer.optimizer_config.optimizer.adamw.weight_decay_rate = (
          _WEIGHT_DECAY.value
      )

  # Yolo models does not support anchor size.
  if _ANCHOR_SIZE.value and 'yolo' not in FLAGS.experiment:
    params.task.model.anchor.anchor_size = _ANCHOR_SIZE.value

  # Segmentation models will also set output size.
  if _OUTPUT_SIZE.value:
    output_size = _parse_image_size(_OUTPUT_SIZE.value, 'output')
    params.task.train_data.output_size = output_size
    params.task.validation_data.output_size = output_size

  # Set default params for best checkpoints.
  params.trainer.best_checkpoint_export_subdir = constants.BEST_CKPT_DIRNAME
  params.trainer.best_checkpoint_metric_comp = constants.BEST_CKPT_METRIC_COMP
  params.trainer.best_checkpoint_eval_metric = get_best_eval_metric(
      _OBJECTIVE.value, params
  )
  return params