def train_step()

in tensorflow_similarity/models/contrastive_model.py [0:0]


    def train_step(self, data):
        """Run one optimization step of the contrastive model.

        Two views are parsed from ``data`` and passed to
        ``self._forward_pass`` under a gradient tape. The loss that gets
        differentiated is the contrastive loss plus the summed ``losses`` of
        the backbone, the projector and, when present, the predictor. Its
        gradients with respect to those sub-models' trainable variables are
        handed to ``self.optimizer.apply_gradients``.

        Args:
            data: A training batch in whatever form ``self._parse_views``
                accepts; it must produce exactly two views.

        Returns:
            A dict built from ``result()`` of every metric in
            ``self.metrics``. Dict-valued results are merged key by key;
            any other result is stored under the metric's ``name``.
        """
        first_view, second_view = self._parse_views(data)

        with tf.GradientTape() as tape:
            # Forward pass through the backbone, projector, and predictor.
            # Only the contrastive loss is used below; the other four
            # outputs are kept for the disabled metric update further down.
            (
                contrastive_loss,
                pred1,
                pred2,
                z1,
                z2,
            ) = self._forward_pass(first_view, second_view, training=True)

            # Ordered list of the sub-models that contribute both
            # regularization losses and trainable weights. The predictor
            # is optional.
            sub_models = [self.backbone, self.projector]
            if self.predictor is not None:
                sub_models.append(self.predictor)

            # Accumulate the regularization terms. The backbone's total
            # seeds the running sum, so no extra `0 +` op is introduced.
            head, *rest = sub_models
            regularization_loss = sum(head.losses)
            for sub_model in rest:
                regularization_loss += sum(sub_model.losses)

            combined_loss = contrastive_loss + regularization_loss

        # Trainable variables of the backbone, projector, and (optional)
        # predictor, in that order.
        tvars = []
        for sub_model in sub_models:
            tvars.extend(sub_model.trainable_variables)

        grads = tape.gradient(combined_loss, tvars)
        self.optimizer.apply_gradients(zip(grads, tvars))

        # NOTE: compiled metrics are intentionally not updated here. These
        # are contrastive metrics, which take different inputs.
        # TODO: figure out interesting metrics -- z MAE?
        # TODO: check metrics are of the right type in compile?
        # self.compiled_metrics.update_state([z1, z2], [pred2, pred1])

        # Losses are reported manually through the loss trackers. The split
        # into contrastive/regularization parts is reported only when
        # `self.reg_loss_len` is truthy.
        trackers = self.loss_trackers
        trackers["loss"].update_state(combined_loss)
        if self.reg_loss_len:
            trackers["contrastive_loss"].update_state(contrastive_loss)
            trackers["regularization_loss"].update_state(regularization_loss)

        # Flatten every metric result into one dict. A non-dict result is
        # wrapped as {name: result} so both cases merge the same way.
        results = {}
        for metric in self.metrics:
            value = metric.result()
            if not isinstance(value, dict):
                value = {metric.name: value}
            results.update(value)
        return results