def generate()

in models/wavenet.py [0:0]


    def generate(self, spectrograms: Tensor, training: bool = False) -> Tensor:
        """
        Autoregressively generate a waveform sample from this model.

        Args:
          spectrograms: Conditioning features, fed through the model's
            upsample network first (assumes layout compatible with
            ``upsample_net``, presumably (batch, mel_channels, frames) —
            TODO confirm against callers).
          training: If True, synthesize only the first second (22050
            samples) to keep in-training synthesis cheap.

        Returns:
          A 1D float tensor containing the flattened output waveform
          (all batch items concatenated).

        Raises:
          RuntimeError: If ``config.model.input_type`` is not one of
            "raw", "mulaw", or "mulaw-quantize".
        """
        self.model.eval()
        module = self.model.module
        module.clear_buffer()

        with torch.no_grad():
            spectrograms = module.upsample_net(spectrograms)
            # Synthesize only the first second during training, otherwise
            # the full upsampled length.
            seq_len = 22050 if training else spectrograms.size(-1)
            batch_size = spectrograms.shape[0]
            # (B, C, T) -> (B, T, C) so conditioning can be indexed per step.
            spectrograms = spectrograms.transpose(1, 2).contiguous()

            # Initial autoregressive input: one zero sample per batch item.
            if self.scalar_input:
                x = spectrograms.new_zeros(batch_size, 1, 1)
            else:
                x = spectrograms.new_zeros(
                    batch_size, 1, self.config.model.out_channels
                )

            # Loop invariants hoisted out of the per-timestep loop:
            # the skip-connection normalization factor, and the dispatch
            # for the output stack (some layers, e.g. activations, have no
            # incremental_forward). Resolving this once via hasattr also
            # avoids the old try/except AttributeError, which silently
            # masked AttributeErrors raised *inside* incremental_forward.
            skip_scale = math.sqrt(1.0 / len(module.conv_layers))
            last_ops = [
                f.incremental_forward if hasattr(f, "incremental_forward") else f
                for f in module.last_conv_layers
            ]

            output = []
            for t in tqdm(range(seq_len)):
                # Conditioning features for a single time step.
                ct = spectrograms[:, t, :].unsqueeze(1)
                x = module.first_conv.incremental_forward(x)
                # Accumulate skip connections across the residual stack.
                skips = 0
                for f in module.conv_layers:
                    x, h = f.incremental_forward(x, ct, None)
                    skips += h
                x = skips * skip_scale
                for op in last_ops:
                    x = op(x)
                # Sample the next input from the output distribution and
                # append the emitted sample to the history.
                x, output = self.get_x_from_dist(
                    self.config.model.output_distribution,
                    x,
                    history=output,
                    B=batch_size,
                )
        # (T, B, ...) -> (B, T, ...)
        output = torch.stack(output).transpose(0, 1)

        if self.config.model.input_type in ["mulaw", "mulaw-quantize"]:
            if self.config.model.input_type == "mulaw-quantize":
                output = torch.argmax(output, dim=2)
            else:
                output = self.float_2_label(output, self.config.model.quantize_channels)
            # Invert the mu-law companding back to a raw waveform.
            output = self.expand(output.long())
        elif self.config.model.input_type != "raw":
            raise RuntimeError(
                "Not supported input type: {}".format(self.config.model.input_type)
            )
        module.clear_buffer()
        self.model.train()

        return output.flatten()