def sample_notes_from_model()

in ar-cnn/inference.py [0:0]


    def sample_notes_from_model(self,
                                input_tensor,
                                max_original_notes_to_remove,
                                max_notes_to_add,
                                temperature,
                                notes_removed_count,
                                notes_added_count,
                                original_input_one_indices,
                                original_input_zero_indices,
                                current_input_zero_indices,
                                current_input_one_indices,
                                num_notes=1):
        """
        Sample one pixel from the model prediction and toggle it in the input.

        Runs the model on ``input_tensor``, applies temperature/softmax, masks
        the pixels that may no longer be toggled because an edit budget is
        exhausted, samples one index and flips that pixel (1 -> 0 removes a
        note, 0 -> 1 adds a note).

        Modifies ``current_input_zero_indices``, ``current_input_one_indices``
        and ``input_tensor`` in place. The returned tensor is a boolean copy
        made after the in-place write; when nothing can be sampled the
        original ``input_tensor`` object is returned untouched.

        Parameters
        ----------
        input_tensor : 2d numpy array
            input tensor to feed into the model
        max_original_notes_to_remove : int
            maximum number of notes to remove from the original input
        max_notes_to_add : int
            maximum number of notes that can be added to the original input
        temperature : float
            temperature to apply before softmax during inference
        notes_removed_count : int
            number of original notes that have been removed from input
        notes_added_count : int
            number of new notes that have been added to the input
        original_input_one_indices : set of tuples
            indices which have value 1 in original input
        original_input_zero_indices : set of tuples
            indices which have value 0 in original input
            (not read by this method)
        current_input_zero_indices : set of tuples
            indices which have value 0 and were not part of the original input
        current_input_one_indices : set of tuples
            indices which have value 1 and were part of the original input
        num_notes : int, optional
            currently unused; kept so existing callers keep working

        Returns
        -------
        input_tensor : 2d numpy array
            output after sampling from the model prediction
        notes_removed_count : int
            updated number of original notes removed
        notes_added_count : int
            updated number of new notes added
        """

        output_tensor = self.model.predict([input_tensor])

        # Apply temperature and softmax
        output_tensor = self.get_softmax(output_tensor, temperature)

        if notes_removed_count >= max_original_notes_to_remove:
            # Removal budget exhausted: mask all pixels that both have a note
            # and were once part of the original input
            output_tensor = self.mask_not_allowed_notes(current_input_one_indices, output_tensor)

        # ">=" (not ">"): once the budget is reached no further note may be
        # added, otherwise the count could end up at max_notes_to_add + 1.
        if notes_added_count >= max_notes_to_add:
            # Mask all pixels that both do not have a note and were not once
            # part of the original input
            output_tensor = self.mask_not_allowed_notes(current_input_zero_indices, output_tensor)

        # Nothing left that is allowed to change: return the input unchanged.
        if np.count_nonzero(output_tensor) == 0:
            return input_tensor, notes_removed_count, notes_added_count

        sampled_index = self.get_sampled_index(output_tensor)
        # Hashable form of the sampled position, usable with the index sets.
        # NOTE(review): assumes get_sampled_index returns one array per axis
        # (np.where-style) holding a single position - confirm.
        sampled_index_transpose = tuple(np.array(sampled_index).T[0])
        is_original_note = sampled_index_transpose in original_input_one_indices

        if input_tensor[sampled_index]:
            # A note is present at the sampled position: remove it.
            if is_original_note:
                # Only reachable with budget left if the mask above excludes
                # these pixels (presumably it does - verify the helper).
                if notes_removed_count < max_original_notes_to_remove:
                    notes_removed_count += 1
                    current_input_one_indices.remove(sampled_index_transpose)
            else:
                # Removing a previously added note frees one unit of add budget.
                notes_added_count -= 1
                current_input_zero_indices.add(sampled_index_transpose)
            input_tensor[sampled_index] = 0
        else:
            # No note at the sampled position: add one.
            if not is_original_note:
                notes_added_count += 1
                current_input_zero_indices.remove(sampled_index_transpose)
            else:
                # Restoring an original note undoes one earlier removal.
                notes_removed_count -= 1
                current_input_one_indices.add(sampled_index_transpose)
            input_tensor[sampled_index] = 1
        input_tensor = input_tensor.astype(np.bool_)
        return input_tensor, notes_removed_count, notes_added_count