def sample_multiple()

in ar-cnn/inference.py [0:0]


    def sample_multiple(self, input_tensor, temperature,
                        max_removal_percentage, max_notes_to_add,
                        number_of_iterations):
        """
        Repeatedly resample a tensor using the model's predictions.

        Each pass hands the current tensor to ``sample_notes_from_model``
        together with the running counts of removed and added notes.
        The tensor and counts it returns feed the next pass.
        After the last pass the tensor is reshaped and returned.

        Parameters
        ----------
        input_tensor : 2d numpy array
            original tensor (i.e. user input melody)
        temperature : float
            temperature to apply before softmax during inference
        max_removal_percentage : float
            percentage (0-100 scale) of the input's non-zero entries that
            may be removed; converted to an absolute count by truncation
        max_notes_to_add : int
            maximum number of notes that can be added to the original input
        number_of_iterations : int
            number of sampling passes; with 0 passes the input is returned
            unchanged apart from the reshape

        Returns
        -------
        2d numpy array
            output tensor (i.e. new composition), reshaped to
            (self.number_of_timesteps, Constants.number_of_pitches)
        """

        # Turn the percentage into an absolute note budget (truncated to int).
        nonzero_total = np.count_nonzero(input_tensor)
        removal_budget = int(max_removal_percentage * nonzero_total / 100)

        removed_so_far = 0
        added_so_far = 0

        # Index sets taken from the caller's original input.
        initial_ones = self.get_indices(input_tensor, 1)
        initial_zeros = self.get_indices(input_tensor, 0)

        # Deep copies, presumably so the sampler can mutate the "current"
        # sets without disturbing the originals -- confirm in the helper.
        live_ones = copy.deepcopy(initial_ones)
        live_zeros = copy.deepcopy(initial_zeros)

        tensor = input_tensor
        for _pass in range(number_of_iterations):
            tensor, removed_so_far, added_so_far = self.sample_notes_from_model(
                tensor,
                removal_budget,
                max_notes_to_add,
                temperature,
                removed_so_far,
                added_so_far,
                initial_ones,
                initial_zeros,
                # NOTE(review): current sets go zeros-then-ones, the reverse
                # of the original pair above; confirm against the callee.
                live_zeros,
                live_ones,
            )

        return tensor.reshape(self.number_of_timesteps,
                              Constants.number_of_pitches)