def new_forward()

in src/diffusers/hooks/faster_cache.py


    def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
        # Split the unconditional and conditional inputs. We only want to run inference on the conditional branch
        # when the requirements for skipping the unconditional branch, as described in the paper, are met.
        # We skip the unconditional branch only if the following conditions are met:
        #   1. We have completed at least one iteration of the denoiser
        #   2. The current timestep is within the range specified by the user. This is the optimal timestep range
        #      where approximating the unconditional branch from the computation of the conditional branch is possible
        #      without a significant loss in quality.
        #   3. The current iteration is not a multiple of the unconditional batch skip range. This is done so that
        #      we compute the unconditional branch at least once every few iterations to ensure minimal quality loss.
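        #   4. The model is not guidance-distilled. Guidance-distilled models fold classifier-free guidance into a
        #      single forward pass, so there is no separate unconditional branch to skip.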
        is_within_timestep_range = (
            self.unconditional_batch_timestep_skip_range[0]
            < self.current_timestep_callback()
            < self.unconditional_batch_timestep_skip_range[1]
        )
        should_skip_uncond = (
            self.state.iteration > 0
            and is_within_timestep_range
            and self.state.iteration % self.unconditional_batch_skip_range != 0
            and not self.is_guidance_distilled
        )

        if should_skip_uncond:
            is_any_kwarg_uncond = any(k in self.uncond_cond_input_kwargs_identifiers for k in kwargs.keys())
            if is_any_kwarg_uncond:
                logger.debug("FasterCache - Skipping unconditional branch computation")
                args = tuple([self._get_cond_input(arg) if torch.is_tensor(arg) else arg for arg in args])
                kwargs = {
                    k: v if k not in self.uncond_cond_input_kwargs_identifiers else self._get_cond_input(v)
                    for k, v in kwargs.items()
                }

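        # Run the wrapped forward. If the unconditional branch is being skipped, only the conditional half of the
        # batch was kept above, so this call computes the conditional branch alone.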
        output = self.fn_ref.original_forward(*args, **kwargs)

        if self.is_guidance_distilled:
            self.state.iteration += 1
            return output

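        # The denoiser may return a raw tensor, a tuple, or a Transformer2DModelOutput; in every supported case the
        # first element holds the predicted hidden states.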
        if torch.is_tensor(output):
            hidden_states = output
        elif isinstance(output, (tuple, Transformer2DModelOutput)):
            hidden_states = output[0]

        batch_size = hidden_states.size(0)

        if should_skip_uncond:
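            # Scale the cached low/high-frequency deltas by the weights returned by the configured callbacks before
            # reusing them to approximate the unconditional branch.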
            self.state.low_frequency_delta = self.state.low_frequency_delta * self.low_frequency_weight_callback(
                module
            )
            self.state.high_frequency_delta = self.state.high_frequency_delta * self.high_frequency_weight_callback(
                module
            )

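            # Normalize video layouts to (batch * frames, channels, height, width) so the frequency split below
            # operates on 2D frames.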
            if self.tensor_format == "BCFHW":
                hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
            if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW":
                hidden_states = hidden_states.flatten(0, 1)

            low_freq_cond, high_freq_cond = _split_low_high_freq(hidden_states.float())

            # Approximate the unconditional branch outputs as described in Equations 9 and 10 of the paper
            low_freq_uncond = self.state.low_frequency_delta + low_freq_cond
            high_freq_uncond = self.state.high_frequency_delta + high_freq_cond
            uncond_freq = low_freq_uncond + high_freq_uncond

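            # Transform the combined spectrum back to the spatial domain to obtain the approximated unconditional
            # hidden states.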
            uncond_states = torch.fft.ifftshift(uncond_freq)
            uncond_states = torch.fft.ifft2(uncond_states).real

            if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW":
                uncond_states = uncond_states.unflatten(0, (batch_size, -1))
                hidden_states = hidden_states.unflatten(0, (batch_size, -1))
            if self.tensor_format == "BCFHW":
                uncond_states = uncond_states.permute(0, 2, 1, 3, 4)
                hidden_states = hidden_states.permute(0, 2, 1, 3, 4)

            # Concatenate the approximated unconditional and predicted conditional branches
            uncond_states = uncond_states.to(hidden_states.dtype)
            hidden_states = torch.cat([uncond_states, hidden_states], dim=0)
        else:
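            # Both branches were computed this iteration: split them and cache the low/high-frequency deltas between
            # the unconditional and conditional outputs for reuse in later skipped iterations.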
            uncond_states, cond_states = hidden_states.chunk(2, dim=0)
            if self.tensor_format == "BCFHW":
                uncond_states = uncond_states.permute(0, 2, 1, 3, 4)
                cond_states = cond_states.permute(0, 2, 1, 3, 4)
            if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW":
                uncond_states = uncond_states.flatten(0, 1)
                cond_states = cond_states.flatten(0, 1)

            low_freq_uncond, high_freq_uncond = _split_low_high_freq(uncond_states.float())
            low_freq_cond, high_freq_cond = _split_low_high_freq(cond_states.float())
            self.state.low_frequency_delta = low_freq_uncond - low_freq_cond
            self.state.high_frequency_delta = high_freq_uncond - high_freq_cond

        self.state.iteration += 1
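        # Repackage the hidden states in the same container type that the wrapped forward returned.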
        if torch.is_tensor(output):
            output = hidden_states
        elif isinstance(output, tuple):
            output = (hidden_states, *output[1:])
        else:
            output.sample = hidden_states

        return output
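
The helper `_split_low_high_freq` called above is defined elsewhere in faster_cache.py. A minimal sketch of the idea
it implements is shown below (a radial low-pass mask applied to the shifted 2D FFT of the hidden states; the radius
ratio, mask shape, and function name here are illustrative assumptions, not the library's exact implementation):

    import torch

    def _split_low_high_freq_sketch(x: torch.Tensor, radius_ratio: float = 0.2):
        # 2D FFT over the last two (spatial) dimensions, shifted so the zero-frequency
        # component sits at the center of the spectrum.
        freq = torch.fft.fftshift(torch.fft.fft2(x))
        height, width = x.shape[-2], x.shape[-1]
        # Radial mask around the spectrum center: frequencies inside the radius are
        # treated as "low", everything else as "high".
        radius = int(min(height, width) * radius_ratio)
        ys = torch.arange(height, device=x.device) - height // 2
        xs = torch.arange(width, device=x.device) - width // 2
        mask = (ys[:, None] ** 2 + xs[None, :] ** 2) <= radius**2
        low_freq = freq * mask
        high_freq = freq * (~mask)
        # low_freq + high_freq reconstructs the full spectrum, which is why the skip
        # path above can recover spatial states with ifftshift followed by ifft2.
        return low_freq, high_freq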