text/data/smoltalk/rewrite/pipeline/dataset.py (90 lines of code) (raw):

from typing import TYPE_CHECKING, Any from datasets import concatenate_datasets, load_dataset if TYPE_CHECKING: from datasets import Dataset REWRITE_CONCISE_SYSTEM_PROMPT = "You're an AI assistant for text re-writing. Rewrite the input text to make it more concise while preserving its core meaning." REWRITE_FRIENDLY_SYSTEM_PROMPT = "You're an AI assistant for text re-writing. Rewrite the input text to make it more friendly and approachable while maintaining its main points." REWRITE_PROFESSIONAL_SYSTEM_PROMPT = "You're an AI assistant for text re-writing. Rewrite the input text to make it more professional and formal while retaining its essential content." CONFIG_NAME_TO_SYSTEM_PROMPT = { "default": REWRITE_CONCISE_SYSTEM_PROMPT, "unfriendly_email_conversations": REWRITE_FRIENDLY_SYSTEM_PROMPT, "unprofessional_email_conversations": REWRITE_PROFESSIONAL_SYSTEM_PROMPT, } def get_dataset() -> "Dataset": concise = get_finepersonas_emails_dataset("default") unfriendly = get_finepersonas_emails_dataset("unfriendly_email_conversations") unprofessional = get_finepersonas_emails_dataset( "unprofessional_email_conversations" ) linkedin_posts_concise = get_finepersonas_linkedin_posts( system_prompt=REWRITE_CONCISE_SYSTEM_PROMPT ) linkedin_posts_professional = get_finepersonas_linkedin_posts( system_prompt=REWRITE_PROFESSIONAL_SYSTEM_PROMPT ) tweets_professional = get_finepersonas_tweets( system_prompt=REWRITE_PROFESSIONAL_SYSTEM_PROMPT ) return concatenate_datasets( [ concise, unfriendly, unprofessional, linkedin_posts_concise, linkedin_posts_professional, tweets_professional, ] ) def get_finepersonas_emails_dataset( config_name: str, n: int | None = None ) -> "Dataset": def get_first_email(row: dict[str, Any]) -> dict[str, Any]: formatted_emails = row["formatted_emails"] if not formatted_emails: return {"system_prompt": "", "instruction": "", "kind_of_content": ""} email = formatted_emails[0] return { "system_prompt": CONFIG_NAME_TO_SYSTEM_PROMPT[config_name], "instruction": email["body"], "kind_of_content": "email", } dataset: "Dataset" = load_dataset( # type: ignore "argilla/FinePersonas-Synthetic-Email-Conversations", name=config_name, split="train", ) dataset = dataset.map(get_first_email, remove_columns=dataset.column_names) dataset = dataset.filter(lambda x: x["system_prompt"] != "") if n is not None: dataset = select_n(dataset, n) return dataset def get_finepersonas_linkedin_posts( n: int | None = None, system_prompt: str = REWRITE_CONCISE_SYSTEM_PROMPT ) -> "Dataset": dataset: "Dataset" = load_dataset( # type: ignore "argilla-warehouse/FinePersonas-Extra-Stuff", name="linkedin_posts", split="train", ) dataset = dataset.select_columns(["formatted_linkedin_post"]) dataset = dataset.add_column("system_prompt", [system_prompt] * len(dataset)) dataset = dataset.rename_column("formatted_linkedin_post", "instruction") dataset = dataset.add_column("kind_of_content", ["linkedin_post"] * len(dataset)) if n is not None: dataset = select_n(dataset, n) return dataset def get_finepersonas_tweets( n: int | None = None, system_prompt: str = REWRITE_CONCISE_SYSTEM_PROMPT ) -> "Dataset": dataset: "Dataset" = load_dataset( # type: ignore "argilla-warehouse/FinePersonas-Extra-Stuff", name="tweets", split="train" ) dataset = dataset.select_columns(["formatted_tweet"]) dataset = dataset.add_column("system_prompt", [system_prompt] * len(dataset)) dataset = dataset.rename_column("formatted_tweet", "instruction") dataset = dataset.add_column("kind_of_content", ["tweet"] * len(dataset)) if n is not None: dataset = select_n(dataset, n) return dataset def select_n(dataset: "Dataset", n: int) -> "Dataset": return dataset.select(range(n))