in tensorflow_similarity/models/contrastive_model.py
def train_step(self, data):
    view1, view2 = self._parse_views(data)

    # Forward pass through the backbone, projector, and predictor.
    with tf.GradientTape() as tape:
        contrastive_loss, pred1, pred2, z1, z2 = self._forward_pass(
            view1, view2, training=True
        )
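        # Keras layers expose their weight/activity regularization terms
        # through the .losses property; summing them while the tape is still
        # recording lets them contribute to the gradients computed below.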
        regularization_loss = sum(self.backbone.losses)
        regularization_loss += sum(self.projector.losses)
        if self.predictor is not None:
            regularization_loss += sum(self.predictor.losses)
        combined_loss = contrastive_loss + regularization_loss
    # Collect trainable variables from the backbone and projector, and,
    # if present, the predictor.
    tvars = self.backbone.trainable_variables
    tvars = tvars + self.projector.trainable_variables
    if self.predictor is not None:
        tvars = tvars + self.predictor.trainable_variables

    # Compute gradients
    gradients = tape.gradient(combined_loss, tvars)

    # Update weights
    self.optimizer.apply_gradients(zip(gradients, tvars))
    # Update metrics
    # NOTE: these are contrastive metrics that take different inputs.
    # TODO: figure out interesting metrics -- z MAE?
    # TODO: check that the metrics are of the right type in compile?
    # self.compiled_metrics.update_state([z1, z2], [pred2, pred1])

    # Report the losses manually.
    self.loss_trackers["loss"].update_state(combined_loss)
    if self.reg_loss_len:
        self.loss_trackers["contrastive_loss"].update_state(contrastive_loss)
        self.loss_trackers["regularization_loss"].update_state(regularization_loss)
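    # NOTE: the loss trackers above are assumed to be exposed through
    # self.metrics (Keras auto-tracks Metric attributes), so their values
    # are picked up by the result-collection loop below.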
    # Collect metrics to return
    return_metrics = {}
    for metric in self.metrics:
        result = metric.result()
        if isinstance(result, dict):
            return_metrics.update(result)
        else:
            return_metrics[metric.name] = result
    return return_metrics
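
# For context: `_forward_pass` above returns the contrastive loss together with
# the predictor outputs (pred1, pred2) and the projector embeddings (z1, z2) for
# the two augmented views. The sketch below shows what a forward pass with that
# return shape could look like for a SimSiam-style objective. It is an
# illustration built from standard TensorFlow ops only -- not the library's
# actual implementation -- and the symmetric negative-cosine loss and the
# stop-gradient placement are assumptions.
import tensorflow as tf  # already imported elsewhere in this module


def _simsiam_forward_pass_sketch(backbone, projector, predictor, view1, view2, training=True):
    # Backbone -> projector embeddings for each augmented view.
    z1 = projector(backbone(view1, training=training), training=training)
    z2 = projector(backbone(view2, training=training), training=training)

    # Predictor outputs.
    p1 = predictor(z1, training=training)
    p2 = predictor(z2, training=training)

    def neg_cosine(p, z):
        # Negative cosine similarity with a stop-gradient on the target branch.
        p = tf.math.l2_normalize(p, axis=1)
        z = tf.math.l2_normalize(tf.stop_gradient(z), axis=1)
        return -tf.reduce_mean(tf.reduce_sum(p * z, axis=1))

    # Symmetrized loss, as in the original SimSiam formulation.
    contrastive_loss = 0.5 * (neg_cosine(p1, z2) + neg_cosine(p2, z1))
    return contrastive_loss, p1, p2, z1, z2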