def combine()

in shap_e/models/nerf/ray.py [0:0]


    def combine(self, cur: "RayVolumeIntegralResults") -> "RayVolumeIntegralResults":
        """
        Merge the integration results of `self`, which cover [t0, t1], with
        those of `cur`, which cover the adjacent interval [t1, t2], producing
        results that cover [t0, t2].

        This is the front-to-back compositing rule of equation (4) in NeRF++.
        Writing T(a, b) for the transmittance accumulated from a to b:

            integrate(
                lambda t: density(t) * channels(t) * T(t0, t),
                [t0, t2]
            )

          = integrate(
                lambda t: density(t) * channels(t) * T(t0, t),
                [t0, t1]
            ) + T(t0, t1) * integrate(
                lambda t: density(t) * channels(t) * T(t1, t),
                [t1, t2]
            )

        Each segment's integral is measured relative to its own start. The far
        segment's contribution is therefore attenuated by the near segment's
        transmittance (`self.transmittance`).

        :param cur: results for the segment that immediately follows `self`.
            Its `volume_range.t0` must match `self.volume_range.next_t0()`.
        :return: a new RayVolumeIntegralResults with:
            - the composited output,
            - the extended volume range,
            - the product of both segments' transmittances.
        :raises AssertionError: if the two ranges are not contiguous (checked
            with `torch.allclose`), or if `self.output.combine` hands the
            combine function a `None` value for `self`'s side.
        """
        # The far segment must start (within tolerance) exactly where the near
        # segment ends; otherwise compositing would leave a gap or an overlap.
        assert torch.allclose(self.volume_range.next_t0(), cur.volume_range.t0)

        # Transmittance across the near segment [t0, t1]. It scales every
        # value contributed by the far segment.
        near_transmittance = self.transmittance

        def _composite(
            prev_val: Optional[torch.Tensor],
            cur_val: Optional[torch.Tensor],
        ):
            assert prev_val is not None
            # cur_output.aux_losses are empty for the void_model, so a missing
            # far-segment value just keeps the near-segment value as-is.
            return prev_val if cur_val is None else prev_val + near_transmittance * cur_val

        merged_output = self.output.combine(cur.output, combine_fn=_composite)

        return RayVolumeIntegralResults(
            output=merged_output,
            volume_range=self.volume_range.extend(cur.volume_range),
            # Transmittances of consecutive segments multiply.
            transmittance=near_transmittance * cur.transmittance,
        )