def forward()

in src/controlnet_aux/zoe/zoedepth/models/zoedepth_nk/zoedepth_nk_v1.py [0:0]


    def forward(self, x, return_final_centers=False, denorm=False, return_probs=False, **kwargs):
        """Predict metric depth, routing the whole batch through one domain head.

        ``self.core`` produces decoder features (and a relative depth map, which is
        not used here). A domain classifier then votes "nyu" vs "kitti" for the
        entire batch. The selected head's seed bin regressor, attractor stages and
        conditional log-binomial produce per-pixel bin probabilities and bin
        centers. Metric depth is the probability-weighted sum of the centers.

        Args:
            x (torch.Tensor): Input image tensor of shape (B, C, H, W). Per-image domain
                logits are summed over the batch before the vote, so every image is
                processed by the SAME head; the batch is assumed to come from one domain.
            return_final_centers (bool, optional): Whether to return the final centers of the attractors. Defaults to False.
            denorm (bool, optional): Whether to denormalize the input image; forwarded to ``self.core``. Defaults to False.
            return_probs (bool, optional): Whether to return the probabilities of the bins. Defaults to False.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            dict: Dictionary of outputs with keys:
                - "domain_logits": Domain logits, one row of 2 logits per image
                - "metric_depth": Metric depth map of shape (B, 1, h, w)
                - "bin_centers": Bin centers of shape (B, N, h, w). Present if
                  return_final_centers OR return_probs is True
                - "probs": Bin probabilities of shape (B, N, h, w). Present only if return_probs is True
            Here (h, w) is the spatial size of the decoder output activation
            (``out[0]`` from ``self.core``), which may differ from the input (H, W).
            The relative depth computed by ``self.core`` is NOT part of the output.

        Raises:
            ValueError: If the domain chosen by the vote has no entry in ``self.bin_conf``.

        Side effects:
            Stores the input width/height on ``self.orig_input_width`` and
            ``self.orig_input_height``.
        """
        _, _, h, w = x.shape
        self.orig_input_width = w
        self.orig_input_height = h
        # rel_depth is requested from the core but not consumed by the metric heads.
        rel_depth, out = self.core(x, denorm=denorm, return_rel_depth=True)

        outconv_activation = out[0]
        btlnck = out[1]
        x_blocks = out[2:]

        x = self.conv2(btlnck)

        # Predict which path to take. Logits are summed over the batch so that a
        # single head is chosen for all images.
        embedding = self.patch_transformer(x)[0]  # N, E
        domain_logits = self.mlp_classifier(embedding)  # N, 2
        domain_vote = torch.softmax(domain_logits.sum(dim=0, keepdim=True), dim=-1)  # 1, 2

        # The argmax index selects the head name; this list's order must match the
        # column order of the classifier logits.
        bin_conf_name = ["nyu", "kitti"][torch.argmax(domain_vote, dim=-1).squeeze().item()]

        # Subscript access works for plain dicts as well as attribute-style dicts,
        # and matches the conf['min_depth'] / conf['max_depth'] reads below.
        conf = next((cfg for cfg in self.bin_conf if cfg['name'] == bin_conf_name), None)
        if conf is None:
            raise ValueError(
                f"bin_conf_name {bin_conf_name} not found in bin_confs")

        min_depth = conf['min_depth']
        max_depth = conf['max_depth']

        seed_bin_regressor = self.seed_bin_regressors[bin_conf_name]
        _, seed_b_centers = seed_bin_regressor(x)
        if self.bin_centers_type in ('normed', 'hybrid2'):
            # These center types are rescaled by the selected head's depth range.
            b_prev = (seed_b_centers - min_depth) / (max_depth - min_depth)
        else:
            b_prev = seed_b_centers
        prev_b_embedding = self.seed_projector(x)

        # Refine the bins stage by stage over the decoder feature maps; each stage
        # sees the previous stage's bins and embedding.
        attractors = self.attractors[bin_conf_name]
        for projector, attractor, x_block in zip(self.projectors, attractors, x_blocks):
            b_embedding = projector(x_block)
            b_prev, b_centers = attractor(
                b_embedding, b_prev, prev_b_embedding, interpolate=True)
            prev_b_embedding = b_embedding

        last = outconv_activation

        # Bring the final bins/embedding to the decoder output resolution.
        b_centers = nn.functional.interpolate(
            b_centers, last.shape[-2:], mode='bilinear', align_corners=True)
        b_embedding = nn.functional.interpolate(
            b_embedding, last.shape[-2:], mode='bilinear', align_corners=True)

        clb = self.conditional_log_binomial[bin_conf_name]
        probs = clb(last, b_embedding)

        # Depth value is sum_i p_i * c_i over the bin dimension.
        metric_depth = torch.sum(probs * b_centers, dim=1, keepdim=True)

        output = dict(domain_logits=domain_logits, metric_depth=metric_depth)
        if return_final_centers or return_probs:
            output['bin_centers'] = b_centers
        if return_probs:
            output['probs'] = probs
        return output