in community-content/vertex_model_garden/model_oss/tfvision/configs/hub_model.py [0:0]
def hub_model() -> cfg.ExperimentConfig:
  """Builds the experiment config for training a TF-Hub image classifier.

  The model is a 1000-class image classifier whose backbone is of type
  `hub_model`, loaded from `_HANDLE` and normalized with `_MEAN_RGB` /
  `_STDDEV_RGB`. Training data matches `train*` and validation data matches
  `valid*` under `image_classification.IMAGENET_INPUT_PATH_BASE`. Both use a
  global batch size of 8.

  Optimization uses SGD (momentum 0.9) with a cosine learning-rate schedule.
  The schedule starts at 0.001 and decays over the full `train_steps` budget.
  Checkpointing, summaries and validation all run once per training loop.

  Returns:
    A `cfg.ExperimentConfig` bundling the task, trainer and restrictions.
  """
  global_batch = 8
  total_steps = 625000
  # One "loop" length drives checkpoint, summary and validation cadence.
  loop_steps = 1250

  data_root = image_classification.IMAGENET_INPUT_PATH_BASE

  classifier = image_classification.ImageClassificationModel(
      num_classes=1000,
      input_size=_INPUT_SIZE,
      backbone=backbones.Backbone(
          type='hub_model',
          hub_model=backbones.HubModel(
              handle=_HANDLE,
              mean_rgb=_MEAN_RGB,
              stddev_rgb=_STDDEV_RGB,
          ),
      ),
      dropout_rate=0.0,
  )

  # No L2 weight decay; one-hot labels with 0.1 label smoothing.
  loss_settings = image_classification.Losses(
      l2_weight_decay=0.0,
      label_smoothing=0.1,
      one_hot=True,
  )

  train_input = image_classification.DataConfig(
      input_path=os.path.join(data_root, 'train*'),
      aug_type=None,  # no augmentation policy configured via `aug_type`
      dtype='float32',
      global_batch_size=global_batch,
      is_training=True,
      decode_jpeg_only=False,
  )

  eval_input = image_classification.DataConfig(
      input_path=os.path.join(data_root, 'valid*'),
      dtype='float32',
      global_batch_size=global_batch,
      is_training=False,
      decode_jpeg_only=False,
      drop_remainder=False,  # keep a trailing partial batch
  )

  # Cosine decay spans the whole training run.
  lr_schedule = optimization.LrConfig(
      type='cosine',
      cosine=optimization.lr_cfg.CosineLrConfig(
          decay_steps=total_steps,
          initial_learning_rate=0.001,
      ),
  )
  solver = optimization.OptimizerConfig(
      type='sgd',
      sgd=optimization.SGDConfig(momentum=0.9),
  )

  trainer_settings = cfg.TrainerConfig(
      # Export the checkpoint with the highest validation accuracy.
      best_checkpoint_eval_metric='accuracy',
      best_checkpoint_export_subdir='best_ckpt',
      best_checkpoint_metric_comp='higher',
      optimizer_config=optimization.OptimizationConfig(
          learning_rate=lr_schedule,
          optimizer=solver,
      ),
      checkpoint_interval=loop_steps,
      steps_per_loop=loop_steps,
      summary_interval=loop_steps,
      validation_interval=loop_steps,
      train_steps=total_steps,
      # NOTE(review): -1 presumably means "evaluate the full validation
      # set" per TrainerConfig semantics — confirm against the TFM docs.
      validation_steps=-1,
  )

  return cfg.ExperimentConfig(
      task=image_classification.ImageClassificationTask(
          model=classifier,
          losses=loss_settings,
          train_data=train_input,
          validation_data=eval_input,
      ),
      trainer=trainer_settings,
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None',
      ],
  )