research/delf/delf/python/training/global_features/train.py

def main(argv):
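  """Trains and evaluates the DELG global feature model.

  Args:
    argv: Command-line arguments; only the program name is expected.

  Raises:
    RuntimeError: If extra command-line arguments are given.
    ValueError: If an unknown test dataset, a missing ground truth file, an
      unsupported loss or an unsupported optimizer is specified.
  """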
  if len(argv) > 1:
    raise RuntimeError('Too many command-line arguments.')

  # Manually check if there are unknown test datasets and if the dataset
  # ground truth files are downloaded.
  for dataset in FLAGS.test_datasets:
    if dataset not in _TEST_DATASET_NAMES:
      raise ValueError(
          'Unsupported or unknown test dataset: {}.'.format(dataset))

    test_data_config = os.path.join(
        FLAGS.data_root, 'gnd_{}.pkl'.format(dataset))
    if not tf.io.gfile.exists(test_data_config):
      raise ValueError(
          '{} ground truth file at {} not found. Please download it '
          'according to the DELG instructions.'.format(
              dataset, FLAGS.data_root))

  # Check if the train dataset is downloaded, and download it if not found.
  dataset_download.download_train(FLAGS.data_root)

  # Create the model export directory if it does not exist.
  model_directory = global_features_utils.create_model_directory(
      FLAGS.training_dataset, FLAGS.arch, FLAGS.pool, FLAGS.whitening,
      FLAGS.pretrained, FLAGS.loss, FLAGS.loss_margin, FLAGS.optimizer,
      FLAGS.lr, FLAGS.weight_decay, FLAGS.neg_num, FLAGS.query_size,
      FLAGS.pool_size, FLAGS.batch_size, FLAGS.update_every,
      FLAGS.image_size, FLAGS.directory)

  # Set up the logging directory, the same as where the model is stored.
  logging.get_absl_handler().use_absl_log_file('absl_logging', model_directory)

  # Restrict training to the requested CUDA device(s).
  os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_id
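  # Note: CUDA_VISIBLE_DEVICES only takes effect if it is set before
  # TensorFlow initializes its GPU devices, so the device count logged below
  # already reflects this restriction.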
  global_features_utils.debug_and_log(
      '>> Num GPUs Available: {}'.format(
          len(tf.config.experimental.list_physical_devices('GPU'))),
      FLAGS.debug)

  # Set random seeds.
  tf.random.set_seed(0)
  np.random.seed(0)

  # Initialize the model.
  if FLAGS.pretrained:
    global_features_utils.debug_and_log(
        '>> Using pre-trained model \'{}\''.format(FLAGS.arch))
  else:
    global_features_utils.debug_and_log(
        '>> Using model from scratch (random weights) \'{}\'.'.format(
            FLAGS.arch))

  model_params = {'architecture': FLAGS.arch, 'pooling': FLAGS.pool,
                  'whitening': FLAGS.whitening, 'pretrained': FLAGS.pretrained,
                  'data_root': FLAGS.data_root}
  model = global_model.GlobalFeatureNet(**model_params)

  # Freeze the running mean and std in the batch normalization layers.
  # Training is run one image at a time to reduce the memory requirements of
  # the network, so the computed statistics would not be per-batch. Instead,
  # we freeze the batch norm layers, setting their parameters to
  # non-trainable (i.e., keeping the original ImageNet statistics).
  for layer in model.feature_extractor.layers:
    if isinstance(layer, tf.keras.layers.BatchNormalization):
      layer.trainable = False
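  # With `trainable = False`, a Keras BatchNormalization layer also runs in
  # inference mode: it normalizes with its moving mean/variance rather than
  # batch statistics, and its gamma/beta receive no gradient updates.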

  global_features_utils.debug_and_log('>> Network initialized.')
  global_features_utils.debug_and_log('>> Loss: {}.'.format(FLAGS.loss))

  # Define the loss function.
  if FLAGS.loss == 'contrastive':
    criterion = ranking_losses.ContrastiveLoss(margin=FLAGS.loss_margin)
  elif FLAGS.loss == 'triplet':
    criterion = ranking_losses.TripletLoss(margin=FLAGS.loss_margin)
  else:
    raise ValueError('Loss {} not available.'.format(FLAGS.loss))

  # Define parameters for the training.
  # When pre-computing whitening, evaluation is run once before the network
  # training (logged as epoch 0); regular training always starts from
  # epoch 1.
  start_epoch = 1
  exp_decay = math.exp(-0.01)
  decay_steps = FLAGS.query_size / FLAGS.batch_size

  # Define the learning rate decay schedule.
  lr_scheduler = tf.keras.optimizers.schedules.ExponentialDecay(
      initial_learning_rate=FLAGS.lr,
      decay_steps=decay_steps,
      decay_rate=exp_decay)
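  # With this schedule, the learning rate after `s` optimizer steps is
  # lr * exp(-0.01) ** (s / decay_steps); since `decay_steps` is the number
  # of batches drawn per epoch, the rate decays by a factor of e^-0.01 per
  # epoch.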

  # Define the optimizer.
  if FLAGS.optimizer == 'sgd':
    opt = tfa.optimizers.extend_with_decoupled_weight_decay(
        tf.keras.optimizers.SGD)
    optimizer = opt(weight_decay=FLAGS.weight_decay,
                    learning_rate=lr_scheduler, momentum=FLAGS.momentum)
  elif FLAGS.optimizer == 'adam':
    opt = tfa.optimizers.extend_with_decoupled_weight_decay(
        tf.keras.optimizers.Adam)
    optimizer = opt(weight_decay=FLAGS.weight_decay, learning_rate=lr_scheduler)
  else:
    raise ValueError('Optimizer {} not available.'.format(FLAGS.optimizer))

  # Initializing logging.
  writer = tf.summary.create_file_writer(model_directory)
  tf.summary.experimental.set_step(1)
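  # Summaries written without an explicit `step` argument use this default
  # step value until it is updated.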

  # Set up the checkpoint manager.
  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
  manager = tf.train.CheckpointManager(
      checkpoint,
      model_directory,
      max_to_keep=10,
      keep_checkpoint_every_n_hours=3)

  if FLAGS.resume:
    # Restore the latest checkpoint, if one exists.
    global_features_utils.debug_and_log('>> Continuing from a checkpoint.')
    checkpoint.restore(manager.latest_checkpoint)
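    # If no checkpoint has been written yet, `manager.latest_checkpoint` is
    # None and `restore` leaves the freshly initialized weights in place.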

  # Launch TensorBoard if required.
  if FLAGS.launch_tensorboard:
    tensorboard = tf.keras.callbacks.TensorBoard(model_directory)
    tensorboard.set_model(model=model)
    tensorboard_utils.launch_tensorboard(log_dir=model_directory)

  # Log the flags used.
  global_features_utils.debug_and_log('>> Running training script with:')
  global_features_utils.debug_and_log('>> logdir = {}'.format(model_directory))

  if FLAGS.training_dataset.startswith('retrieval-SfM-120k'):
    train_dataset = sfm120k.CreateDataset(
        data_root=FLAGS.data_root,
        mode='train',
        imsize=FLAGS.image_size,
        num_negatives=FLAGS.neg_num,
        num_queries=FLAGS.query_size,
        pool_size=FLAGS.pool_size
    )
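
    # Build a validation dataset when a validation type is requested. Passing
    # inf for `num_queries` and `pool_size` requests the full validation
    # split (no query or pool subsampling).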
    if FLAGS.validation_type is not None:
      val_dataset = sfm120k.CreateDataset(
          data_root=FLAGS.data_root,
          mode='val',
          imsize=FLAGS.image_size,
          num_negatives=FLAGS.neg_num,
          num_queries=float('Inf'),
          pool_size=float('Inf'),
          eccv2020=(FLAGS.validation_type == 'eccv2020')
      )
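
  # Each mined tuple yields 2 + neg_num float32 image tensors (the query, its
  # hard-positive and `neg_num` hard-negatives), followed by a single int32
  # label tensor.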
  train_dataset_output_types = [tf.float32 for _ in range(2 + FLAGS.neg_num)]
  train_dataset_output_types.append(tf.int32)

  global_features_utils.debug_and_log(
      '>> Training the {} network'.format(model_directory))
  global_features_utils.debug_and_log('>> GPU ids: {}'.format(FLAGS.gpu_id))

  with writer.as_default():
    # Precompute whitening if needed.
    if FLAGS.precompute_whitening is not None:
      epoch = 0
      train_utils.test_retrieval(
          FLAGS.test_datasets, model, writer=writer,
          epoch=epoch, model_directory=model_directory,
          precompute_whitening=FLAGS.precompute_whitening,
          data_root=FLAGS.data_root,
          multiscale=FLAGS.multiscale)

    for epoch in range(start_epoch, FLAGS.epochs + 1):
      # Set manual seeds per epoch.
      np.random.seed(epoch)
      tf.random.set_seed(epoch)

      # Find hard-negatives.
      # The pool of hard-positive pairs is fixed for the whole training
      # process, and pairs are randomly drawn from it every epoch;
      # hard-negatives depend on the current CNN parameters and are re-mined
      # once per epoch.
      avg_neg_distance = train_dataset.create_epoch_tuples(model)
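      # The returned average negative distance is informative only; it is not
      # read again in this loop.
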
      def _train_gen():
        return (inst for inst in train_dataset)

      train_loader = tf.data.Dataset.from_generator(
          _train_gen,
          output_types=tuple(train_dataset_output_types))
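      # The tf.data pipeline is rebuilt from a fresh generator every epoch
      # because the underlying tuples change after re-mining.
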
      loss = train_utils.train_val_one_epoch(
          loader=iter(train_loader), model=model,
          criterion=criterion, optimizer=optimizer, epoch=epoch,
          batch_size=FLAGS.batch_size, query_size=FLAGS.query_size,
          neg_num=FLAGS.neg_num, update_every=FLAGS.update_every,
          debug=FLAGS.debug)

      # Write a scalar summary.
      tf.summary.scalar('train_epoch_loss', loss, step=epoch)
      # Force the summary writer to send any buffered data to storage.
      writer.flush()

      # Evaluate on the validation set.
      if FLAGS.validation_type is not None and (epoch % FLAGS.test_freq == 0
                                                or epoch == 1):
        avg_neg_distance = val_dataset.create_epoch_tuples(
            model, model_directory)
        def _val_gen():
          return (inst for inst in val_dataset)

        val_loader = tf.data.Dataset.from_generator(
            _val_gen, output_types=tuple(train_dataset_output_types))
        loss = train_utils.train_val_one_epoch(
            loader=iter(val_loader), model=model,
            criterion=criterion, optimizer=None,
            epoch=epoch, train=False, batch_size=FLAGS.batch_size,
            query_size=FLAGS.query_size, neg_num=FLAGS.neg_num,
            update_every=FLAGS.update_every, debug=FLAGS.debug)
        tf.summary.scalar('val_epoch_loss', loss, step=epoch)
        writer.flush()

      # Evaluate on the test datasets every test_freq epochs.
      if epoch == 1 or epoch % FLAGS.test_freq == 0:
        train_utils.test_retrieval(
            FLAGS.test_datasets, model, writer=writer, epoch=epoch,
            model_directory=model_directory,
            precompute_whitening=FLAGS.precompute_whitening,
            data_root=FLAGS.data_root, multiscale=FLAGS.multiscale)

      # Save the checkpoint and the model weights.
      try:
        save_path = manager.save(checkpoint_number=epoch)
        global_features_utils.debug_and_log(
            'Saved ({}) at {}'.format(epoch, save_path))

        filename = os.path.join(model_directory,
                                'checkpoint_epoch_{}.h5'.format(epoch))
        model.save_weights(filename, save_format='h5')
        global_features_utils.debug_and_log(
            'Saved weights ({}) at {}'.format(epoch, filename))
      except Exception as ex:
        global_features_utils.debug_and_log(
            'Could not save checkpoint: {}'.format(ex))