def plot_latent_embeddings()

in compert/plotting.py [0:0]


    def plot_latent_embeddings(self,
                               emb,
                               titlename='Example',
                               kind='perturbations',
                               palette=None,
                               labels=None,
                               dimred='KernelPCA',
                               filename=None,
                               show_text=True
                               ):
        """
        Plot a 2-D projection of an embedding and its pairwise similarity.

        Two figures are produced: a low-dimensional projection of ``emb``
        (``fast_dimred`` followed by ``plot_embedding``) and a similarity plot
        of the raw embedding (``plot_similarity``). If ``emb`` holds fewer than
        two vectors, a message is printed and nothing is plotted.

        Parameters
        ----------
        emb : np.array
            Multi-dimensional embedding of perturbations or covariates.
        titlename : str, optional (default: 'Example')
            Title of the embedding plot.
        kind : str, optional (default: 'perturbations')
            Specify if this is an embedding of perturbations
            ('perturbations'), covariates ('covars') or something else. For
            'perturbations' and 'covars', when ``labels`` is None the saved
            default labels and color dictionary of that kind are used (the
            default color dictionary then replaces ``palette``).
        palette : dict, optional (default: None)
            Color dictionary for the embedding; intended for kinds other than
            perturbations or covariates, or when explicit ``labels`` are given.
        labels : list, optional (default: None)
            Labels for the embeddings.
        dimred : str, optional (default: 'KernelPCA')
            Dimensionality reduction method for plotting low dimensional
            representations. Options: 'KernelPCA', 'UMAPpre', 'UMAPcos', None.
            If None, uses first 2 dimensions of the embedding.
        filename : str, optional (default: None)
            Path of the file for the embedding plot. The similarity plot file
            name is derived from it by replacing the extension with
            '_similarity.png'. If None, both names are generated from
            ``self.fileprefix``; if that is also None, None is passed to the
            plotting helpers as the file name (presumably meaning "do not
            save" — confirm in ``plot_embedding``/``plot_similarity``).
        show_text : bool, optional (default: True)
            Forwarded to ``plot_embedding`` as ``show_text``.
        """
        import os  # local import: keeps this method self-contained

        if filename is None:
            if self.fileprefix is None:
                file_name_similarity = None
            else:
                # NOTE: the 'emebdding' spelling is kept on purpose so that
                # previously generated output file names do not change.
                filename = f'{self.fileprefix}_emebdding.png'
                file_name_similarity = f'{self.fileprefix}_emebdding_similarity.png'
        else:
            # Strip only the final extension. Splitting on the first '.'
            # breaks paths such as './figs/emb.png' or 'out/v1.2/emb.png'.
            file_name_similarity = os.path.splitext(filename)[0] + '_similarity.png'

        if labels is None:
            # Fall back to the saved defaults for the two known kinds.
            if kind == 'perturbations':
                palette = self.perts_palette
                labels = self.unique_perts
            elif kind == 'covars':
                palette = self.covars_palette
                labels = self.unique_covars

        if len(emb) < 2:
            print(f'Embedding contains only {len(emb)} vectors. Not enough to plot.')
            return

        plot_embedding(
                fast_dimred(emb, method=dimred),
                labels,
                show_lines=True,
                show_text=show_text,
                col_dict=palette,
                title=titlename,
                file_name=filename,
                fontsize=self.fontsize
                )

        plot_similarity(
                emb,
                labels,
                col_dict=palette,
                fontsize=self.fontsize,
                file_name=file_name_similarity
                )