in tensorflow/sagemakercv/training/trainers.py [0:0]
def __call__(self, data_batch, training=True, broadcast=False):
    """Run an evaluation forward pass or one complete training step.

    Args:
        data_batch: The input batch. Its layout depends on ``training``:
            * evaluation: a dict with a ``'features'`` entry (itself holding
              ``'source_ids'`` and ``'image_info'``) and an optional
              ``'labels'`` entry.
            * training: a sequence that is unpacked positionally into the
              model; element 0 must hold ``'source_ids'`` and
              ``'image_info'``.
        training: When falsy only the forward pass is run; otherwise
            gradients are computed and applied through ``self.optimizer``.
        broadcast: When true and ``self.dist`` is set, the model and
            optimizer variables are broadcast from rank 0 after the update.

    Returns:
        Evaluation: the model output dict, extended with ``'source_ids'``
        and ``'image_info'``.
        Training: a ``(losses, model_outputs)`` tuple, where ``losses`` holds
        every model output whose key contains ``"loss"`` and
        ``model_outputs`` is extended with ``'source_ids'`` and
        ``'image_info'``.
    """
    if not training:
        # TODO: this branch only works with the dict-style batches produced
        # by the validation pipeline; change the data pipeline so that
        # training and validation batches share one layout.
        features = data_batch['features']
        model_outputs = self.model(features, data_batch.get('labels'), training=training)
        model_outputs.update({
            'source_ids': features['source_ids'],
            'image_info': features['image_info'],
        })
        return model_outputs

    with tf.GradientTape() as tape:
        model_outputs = self.model(*data_batch, training=True, weight_decay=self.weight_decay)
        if self.fp16:
            # Scale inside the tape so the scaling op itself is differentiated.
            scaled_loss = self.optimizer.get_scaled_loss(model_outputs['total_loss'])

    if self.dist is not None:
        tape = self.dist.DistributedGradientTape(tape)

    # Read the variable list only after the forward pass has run, as the
    # original code did.
    trainable_variables = self.model.trainable_variables
    if self.fp16:
        scaled_gradients = tape.gradient(scaled_loss, trainable_variables)
        gradients = self.optimizer.get_unscaled_gradients(scaled_gradients)
    else:
        gradients = tape.gradient(model_outputs['total_loss'], trainable_variables)

    if self.global_gradient_clip_ratio > 0.0:
        # Variables that do not contribute to the loss get a ``None``
        # gradient; skip them here, exactly as the clipping and norm ops do.
        finite_checks = [
            tf.reduce_all(tf.math.is_finite(grad))
            for grad in gradients
            if grad is not None
        ]
        if finite_checks:
            all_are_finite = tf.reduce_all(finite_checks)
            # If any gradient is non-finite, substitute a norm of 1.0 rather
            # than computing a global norm from inf/NaN values.
            use_norm = tf.cond(
                all_are_finite,
                lambda: tf.linalg.global_norm(gradients),
                lambda: tf.constant(1.0),
            )
            gradients, _ = tf.clip_by_global_norm(
                gradients,
                clip_norm=self.global_gradient_clip_ratio,
                use_norm=use_norm,
            )

    # Gradients of variables whose name contains "bias" or "beta" are doubled
    # before being applied.
    doubled_name_patterns = ("bias", "beta")
    grads_and_vars = []
    for grad, var in zip(gradients, trainable_variables):
        if grad is not None and any(pattern in var.name for pattern in doubled_name_patterns):
            grad = 2.0 * grad
        grads_and_vars.append((grad, var))
    self.optimizer.apply_gradients(grads_and_vars)

    if self.dist is not None and broadcast:
        # Only rank 0 logs, but every rank must take part in the broadcast.
        if MPI_rank() == 0:
            logging.info("Broadcasting model")
        self.dist.broadcast_variables(self.model.variables, 0)
        self.dist.broadcast_variables(self.optimizer.variables(), 0)

    losses = {key: value for key, value in model_outputs.items() if "loss" in key}
    model_outputs.update({
        'source_ids': data_batch[0]['source_ids'],
        'image_info': data_batch[0]['image_info'],
    })
    return losses, model_outputs