in research/delf/delf/python/training/train.py [0:0]
def main(argv):
if len(argv) > 1:
raise app.UsageError('Too many command-line arguments.')
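  # absl passes non-flag command-line arguments through in argv (argv[0] is
  # the program name); this script expects none.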
  # ------------------------------------------------------------
# Log flags used.
logging.info('Running training script with\n')
logging.info('logdir= %s', FLAGS.logdir)
logging.info('initial_lr= %f', FLAGS.initial_lr)
logging.info('block3_strides= %s', str(FLAGS.block3_strides))
# ------------------------------------------------------------
# Create the strategy.
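  # MirroredStrategy replicates the model on every GPU visible to this host
  # and splits each global batch evenly across the replicas.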
strategy = tf.distribute.MirroredStrategy()
logging.info('Number of devices: %d', strategy.num_replicas_in_sync)
if FLAGS.debug:
print('Number of devices:', strategy.num_replicas_in_sync)
max_iters = FLAGS.max_iters
global_batch_size = FLAGS.batch_size
image_size = FLAGS.image_size
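  # Batches needed for one full evaluation sweep; 50000 is presumably the
  # size of the validation split.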
  num_eval_batches = 50000 // global_batch_size
report_interval = 100
eval_interval = 1000
save_interval = 1000
initial_lr = FLAGS.initial_lr
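  # Global-norm threshold used for gradient clipping in _backprop_loss below.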
clip_val = tf.constant(10.0)
if FLAGS.debug:
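    # Run tf.functions eagerly so the steps below can be inspected with
    # prints and breakpoints, and shrink sizes/intervals for a quick run.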
tf.config.run_functions_eagerly(True)
global_batch_size = 4
max_iters = 100
num_eval_batches = 1
save_interval = 1
report_interval = 10
# Determine the number of classes based on the version of the dataset.
gld_info = gld.GoogleLandmarksInfo()
num_classes = gld_info.num_classes[FLAGS.dataset_version]
# ------------------------------------------------------------
# Create the distributed train/validation sets.
train_dataset = gld.CreateDataset(
file_pattern=FLAGS.train_file_pattern,
batch_size=global_batch_size,
image_size=image_size,
augmentation=FLAGS.use_augmentation,
seed=FLAGS.seed)
validation_dataset = gld.CreateDataset(
file_pattern=FLAGS.validation_file_pattern,
batch_size=global_batch_size,
image_size=image_size,
augmentation=False,
seed=FLAGS.seed)
train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
validation_dist_dataset = strategy.experimental_distribute_dataset(
validation_dataset)
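  # experimental_distribute_dataset splits every batch across the replicas,
  # so each device receives global_batch_size / num_replicas examples.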
train_iter = iter(train_dist_dataset)
validation_iter = iter(validation_dist_dataset)
# Create a checkpoint directory to store the checkpoints.
checkpoint_prefix = os.path.join(FLAGS.logdir, 'delf_tf2-ckpt')
# ------------------------------------------------------------
# Finally, we do everything in distributed scope.
with strategy.scope():
# Compute loss.
# Set reduction to `none` so we can do the reduction afterwards and divide
# by global batch size.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
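    # tf.nn.compute_average_loss divides by the *global* batch size, so the
    # per-replica losses can simply be SUM-reduced across replicas (see
    # distributed_train_step) to recover the mean over the full batch. E.g.
    # with 2 replicas and global batch 32, each replica divides its 16
    # per-example losses by 32; the SUM over both replicas is the mean of 32.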
def compute_loss(labels, predictions):
per_example_loss = loss_object(labels, predictions)
return tf.nn.compute_average_loss(
per_example_loss, global_batch_size=global_batch_size)
# Set up metrics.
desc_validation_loss = tf.keras.metrics.Mean(name='desc_validation_loss')
attn_validation_loss = tf.keras.metrics.Mean(name='attn_validation_loss')
desc_train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
name='desc_train_accuracy')
attn_train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
name='attn_train_accuracy')
desc_validation_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
name='desc_validation_accuracy')
attn_validation_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
name='attn_validation_accuracy')
# ------------------------------------------------------------
# Setup DELF model and optimizer.
model = create_model(num_classes)
logging.info('Model, datasets loaded.\nnum_classes= %d', num_classes)
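    # Plain SGD with momentum; the learning rate is overwritten every step by
    # _learning_rate_schedule in the training loop below.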
optimizer = tf.keras.optimizers.SGD(learning_rate=initial_lr, momentum=0.9)
# Setup summary writer.
summary_writer = tf.summary.create_file_writer(
os.path.join(FLAGS.logdir, 'train_logs'), flush_millis=10000)
# Setup checkpoint directory.
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
manager = tf.train.CheckpointManager(
checkpoint,
checkpoint_prefix,
max_to_keep=10,
keep_checkpoint_every_n_hours=3)
    # Restore the latest checkpoint, if one exists.
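    # restore() is a no-op when latest_checkpoint is None (i.e. a fresh run).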
checkpoint.restore(manager.latest_checkpoint)
# ------------------------------------------------------------
# Train step to run on one GPU.
def train_step(inputs):
"""Train one batch."""
images, labels = inputs
# Temporary workaround to avoid some corrupted labels.
labels = tf.clip_by_value(labels, 0, model.num_classes)
      def _backprop_loss(tape, loss, weights):
        """Backpropagate losses using clipped gradients.

        Args:
          tape: gradient tape.
          loss: scalar Tensor, loss value.
          weights: keras model weights.
        """
gradients = tape.gradient(loss, weights)
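        # Rescale all gradients jointly so that their global norm is at most
        # clip_val, guarding against exploding gradients.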
clipped, _ = tf.clip_by_global_norm(gradients, clip_norm=clip_val)
optimizer.apply_gradients(zip(clipped, weights))
# Record gradients and loss through backbone.
with tf.GradientTape() as gradient_tape:
# Make a forward pass to calculate prelogits.
(desc_prelogits, attn_prelogits, attn_scores, backbone_blocks,
dim_expanded_features, _) = model.global_and_local_forward_pass(images)
# Calculate global loss by applying the descriptor classifier.
if FLAGS.delg_global_features:
desc_logits = model.desc_classification(desc_prelogits, labels)
else:
desc_logits = model.desc_classification(desc_prelogits)
desc_loss = compute_loss(labels, desc_logits)
# Calculate attention loss by applying the attention block classifier.
attn_logits = model.attn_classification(attn_prelogits)
attn_loss = compute_loss(labels, attn_logits)
# Calculate reconstruction loss between the attention prelogits and the
# backbone.
if FLAGS.use_autoencoder:
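          # The reconstruction target is detached from the graph, so this
          # loss term does not backpropagate into the backbone via the target.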
block3 = tf.stop_gradient(backbone_blocks['block3'])
reconstruction_loss = tf.math.reduce_mean(
tf.keras.losses.MSE(block3, dim_expanded_features))
else:
reconstruction_loss = 0
        # Combine the descriptor loss, attention loss and reconstruction loss.
total_loss = (
desc_loss + FLAGS.attention_loss_weight * attn_loss +
FLAGS.reconstruction_loss_weight * reconstruction_loss)
# Perform backpropagation through the descriptor and attention layers
# together. Note that this will increment the number of iterations of
# "optimizer".
_backprop_loss(gradient_tape, total_loss, model.trainable_weights)
# Step number, for summary purposes.
global_step = optimizer.iterations
# Input image-related summaries.
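      # Input images are expected in [-1, 1]; remap to [0, 1] for display.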
tf.summary.image('batch_images', (images + 1.0) / 2.0, step=global_step)
tf.summary.scalar(
'image_range/max', tf.reduce_max(images), step=global_step)
tf.summary.scalar(
'image_range/min', tf.reduce_min(images), step=global_step)
# Attention and sparsity summaries.
_attention_summaries(attn_scores, global_step)
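      # Fraction of zero activations in each backbone block, a proxy for
      # ReLU sparsity.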
activations_zero_fractions = {
'sparsity/%s' % k: tf.nn.zero_fraction(v)
for k, v in backbone_blocks.items()
}
for k, v in activations_zero_fractions.items():
tf.summary.scalar(k, v, step=global_step)
# Scaling factor summary for cosine logits for a DELG model.
if FLAGS.delg_global_features:
tf.summary.scalar(
'desc/scale_factor', model.scale_factor, step=global_step)
# Record train accuracies.
_record_accuracy(desc_train_accuracy, desc_logits, labels)
_record_accuracy(attn_train_accuracy, attn_logits, labels)
return desc_loss, attn_loss, reconstruction_loss
# ------------------------------------------------------------
def validation_step(inputs):
"""Validate one batch."""
images, labels = inputs
labels = tf.clip_by_value(labels, 0, model.num_classes)
# Get descriptor predictions.
blocks = {}
prelogits = model.backbone(
images, intermediates_dict=blocks, training=False)
if FLAGS.delg_global_features:
logits = model.desc_classification(prelogits, labels, training=False)
else:
logits = model.desc_classification(prelogits, training=False)
softmax_probabilities = tf.keras.layers.Softmax()(logits)
validation_loss = loss_object(labels, logits)
desc_validation_loss.update_state(validation_loss)
desc_validation_accuracy.update_state(labels, softmax_probabilities)
# Get attention predictions.
block3 = blocks['block3'] # pytype: disable=key-error
prelogits, _, _ = model.attention(block3, training=False)
logits = model.attn_classification(prelogits, training=False)
softmax_probabilities = tf.keras.layers.Softmax()(logits)
validation_loss = loss_object(labels, logits)
attn_validation_loss.update_state(validation_loss)
attn_validation_accuracy.update_state(labels, softmax_probabilities)
      return (desc_validation_accuracy.result(),
              attn_validation_accuracy.result())
# `run` replicates the provided computation and runs it
# with the distributed input.
@tf.function
def distributed_train_step(dataset_inputs):
"""Get the actual losses."""
      # Per-replica losses: descriptor crossentropy, attention crossentropy
      # and reconstruction MSE.
desc_per_replica_loss, attn_per_replica_loss, recon_per_replica_loss = (
strategy.run(train_step, args=(dataset_inputs,)))
# Reduce over the replicas.
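      # SUM is exact for the crossentropy losses because compute_loss already
      # divided by the global batch size.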
desc_global_loss = strategy.reduce(
tf.distribute.ReduceOp.SUM, desc_per_replica_loss, axis=None)
attn_global_loss = strategy.reduce(
tf.distribute.ReduceOp.SUM, attn_per_replica_loss, axis=None)
recon_global_loss = strategy.reduce(
tf.distribute.ReduceOp.SUM, recon_per_replica_loss, axis=None)
return desc_global_loss, attn_global_loss, recon_global_loss
@tf.function
def distributed_validation_step(dataset_inputs):
return strategy.run(validation_step, args=(dataset_inputs,))
# ------------------------------------------------------------
# *** TRAIN LOOP ***
with summary_writer.as_default():
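      # Record summaries only every report_interval steps to keep the event
      # files small.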
record_cond = lambda: tf.equal(optimizer.iterations % report_interval, 0)
with tf.summary.record_if(record_cond):
global_step_value = optimizer.iterations.numpy()
# TODO(dananghel): try to load pretrained weights at backbone creation.
# Load pretrained weights for ResNet50 trained on ImageNet.
if (FLAGS.imagenet_checkpoint is not None) and (not global_step_value):
logging.info('Attempting to load ImageNet pretrained weights.')
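          # Run one training step first so that the model variables are
          # created; otherwise there is nothing to restore the weights into.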
input_batch = next(train_iter)
_, _, _ = distributed_train_step(input_batch)
model.backbone.restore_weights(FLAGS.imagenet_checkpoint)
logging.info('Done.')
else:
logging.info('Skip loading ImageNet pretrained weights.')
if FLAGS.debug:
model.backbone.log_weights()
last_summary_step_value = None
last_summary_time = None
while global_step_value < max_iters:
# input_batch : images(b, h, w, c), labels(b,).
try:
input_batch = next(train_iter)
except tf.errors.OutOfRangeError:
# Break if we run out of data in the dataset.
logging.info('Stopping training at global step %d, no more data',
global_step_value)
break
          # Set the learning rate for this iteration and run one distributed
          # training step across all the replicas.
optimizer.learning_rate = _learning_rate_schedule(
optimizer.iterations.numpy(), max_iters, initial_lr)
desc_dist_loss, attn_dist_loss, recon_dist_loss = (
distributed_train_step(input_batch))
# Step number, to be used for summary/logging.
global_step = optimizer.iterations
global_step_value = global_step.numpy()
# LR, losses and accuracies summaries.
tf.summary.scalar(
'learning_rate', optimizer.learning_rate, step=global_step)
tf.summary.scalar(
'loss/desc/crossentropy', desc_dist_loss, step=global_step)
tf.summary.scalar(
'loss/attn/crossentropy', attn_dist_loss, step=global_step)
if FLAGS.use_autoencoder:
tf.summary.scalar(
'loss/recon/mse', recon_dist_loss, step=global_step)
tf.summary.scalar(
'train_accuracy/desc',
desc_train_accuracy.result(),
step=global_step)
tf.summary.scalar(
'train_accuracy/attn',
attn_train_accuracy.result(),
step=global_step)
# Summary for number of global steps taken per second.
current_time = time.time()
if (last_summary_step_value is not None and
last_summary_time is not None):
tf.summary.scalar(
'global_steps_per_sec',
(global_step_value - last_summary_step_value) /
(current_time - last_summary_time),
step=global_step)
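          # Advance the baseline only when a summary was actually recorded,
          # so the rate spans the full interval between recorded summaries.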
if tf.summary.should_record_summaries().numpy():
last_summary_step_value = global_step_value
last_summary_time = current_time
# Print to console if running locally.
if FLAGS.debug:
if global_step_value % report_interval == 0:
print(global_step.numpy())
print('desc:', desc_dist_loss.numpy())
print('attn:', attn_dist_loss.numpy())
          # Run validation every eval_interval steps.
if global_step_value % eval_interval == 0:
for i in range(num_eval_batches):
try:
validation_batch = next(validation_iter)
desc_validation_result, attn_validation_result = (
distributed_validation_step(validation_batch))
except tf.errors.OutOfRangeError:
logging.info('Stopping eval at batch %d, no more data', i)
break
# Log validation results to tensorboard.
tf.summary.scalar(
'validation/desc', desc_validation_result, step=global_step)
tf.summary.scalar(
'validation/attn', attn_validation_result, step=global_step)
            logging.info('\nValidation(%d)\n', global_step_value)
logging.info(': desc: %f\n', desc_validation_result.numpy())
logging.info(': attn: %f\n', attn_validation_result.numpy())
# Print to console.
if FLAGS.debug:
print('Validation: desc:', desc_validation_result.numpy())
print(' : attn:', attn_validation_result.numpy())
          # Save a checkpoint every save_interval steps, and at the last
          # iteration.
# TODO(andrearaujo): save only in one of the two ways. They are
# identical, the only difference is that the manager adds some extra
# prefixes and variables (eg, optimizer variables).
if (global_step_value % save_interval
== 0) or (global_step_value >= max_iters):
save_path = manager.save(checkpoint_number=global_step_value)
logging.info('Saved (%d) at %s', global_step_value, save_path)
file_path = '%s/delf_weights' % FLAGS.logdir
model.save_weights(file_path, save_format='tf')
logging.info('Saved weights (%d) at %s', global_step_value,
file_path)
# Reset metrics for next step.
desc_train_accuracy.reset_states()
attn_train_accuracy.reset_states()
desc_validation_loss.reset_states()
attn_validation_loss.reset_states()
desc_validation_accuracy.reset_states()
attn_validation_accuracy.reset_states()
logging.info('Finished training for %d steps.', max_iters)