community-efforts/prompt_translation/Translation_with_distilabel_gpt_4_turbo.ipynb (220 lines of code) (raw):
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"source": [
"# Setup"
],
"metadata": {
"id": "mTYjyCl_1dAO"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "MZhTFpbXzPYM"
},
"outputs": [],
"source": [
"# Fill in both values before running anything else in this notebook.\n",
"HF_ORG_NAME = None  # ID of the Hugging Face org you just created\n",
"LANGUAGE = None  # language your translation effort works on"
]
},
{
"cell_type": "code",
"source": [
"# Stop here if the configuration cell above was left at its None defaults.\n",
"assert (\n",
"    HF_ORG_NAME is not None\n",
"), \"Please set HF_ORG_NAME to the ID of the Hugging Face org you just created\"\n",
"assert (\n",
"    LANGUAGE is not None\n",
"), \"Please set LANGUAGE to the language your effort focuses on\""
],
"metadata": {
"id": "TVZF5-b3zRBJ"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import argilla as rg\n",
"\n",
"OWNER_API_KEY = \"owner.apikey\"  # default owner API key; replace it if you set up your own secret in the Space settings\n",
"assert OWNER_API_KEY is not None, \"Please set OWNER_API_KEY to the API token you just set in the Space settings\"\n",
"\n",
"# `homepage_url` is not assigned anywhere in this notebook, so on a fresh kernel the\n",
"# `rg.init` call below would fail with a NameError. A value already present in the kernel\n",
"# is kept; otherwise replace the next line with the URL of your Argilla Space.\n",
"homepage_url = globals().get(\"homepage_url\")\n",
"assert homepage_url is not None, \"Please set homepage_url to the URL of your Argilla Space\"\n",
"\n",
"# Point the Argilla client at the Space, authenticating with the owner API key.\n",
"rg.init(api_url=homepage_url, api_key=OWNER_API_KEY)"
],
"metadata": {
"id": "NdTtXc_v1YBD"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from openai import OpenAI\n",
"from google.colab import userdata\n",
"\n",
"from distilabel.llm.openai import OpenAILLM\n",
"from distilabel.tasks import TextGenerationTask\n",
"from distilabel.pipeline import Pipeline"
],
"metadata": {
"id": "cQG-OX9DzWmA"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Get the original dataset and translate it\n",
"\n",
"This assumes you have already pushed the untranslated dataset to your Argilla Space."
],
"metadata": {
"id": "nB9Mquww1gcD"
}
},
{
"cell_type": "code",
"source": [
"# Fetch this language's dataset from the \"admin\" workspace in Argilla.\n",
"argilla_ds = rg.FeedbackDataset.from_argilla(\n",
"    f\"DIBT Translation for {LANGUAGE}\",\n",
"    workspace=\"admin\",\n",
")\n",
"\n",
"# Convert it to a `datasets` object and rename `source` to `input` for distilabel.\n",
"hf_ds = argilla_ds.format_as(\"datasets\").rename_columns({\"source\": \"input\"})"
],
"metadata": {
"id": "WBwjwNdq0LN-"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# The OpenAI key is read from the Colab secret named OPENAI_API_KEY.\n",
"api_key = userdata.get(\"OPENAI_API_KEY\")\n",
"\n",
"# Translate into the language configured at the top of the notebook, so the prompt cannot\n",
"# drift from the dataset that was loaded via LANGUAGE. Override here if the prompt needs\n",
"# a different spelling of the language name.\n",
"target_lang = LANGUAGE\n",
"\n",
"llm = OpenAILLM(\n",
"    # Previously \"gpt-4-0613\": not a turbo model, and its 8,192-token context window cannot\n",
"    # hold a prompt plus the 8,192 new tokens that were being requested.\n",
"    model=\"gpt-4-turbo\",\n",
"    api_key=api_key,\n",
"    task=TextGenerationTask(system_prompt=f\"You will be provided with a text in English, and your task is to translate it into {target_lang}. If it's code please don't translate the actual code, only the comments and the explanation.\"),\n",
"    num_threads=8,\n",
"    max_new_tokens=4096,  # output-token limit of GPT-4 Turbo\n",
")\n",
"\n",
"# The pipeline only generates; no labelling step is attached.\n",
"pipe = Pipeline(\n",
"    generator=llm\n",
")"
],
"metadata": {
"id": "BygNfRFyzYWv"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Smoke test: translate only the first 10 rows to confirm the pipeline works.\n",
"ds = pipe.generate(dataset=hf_ds.select(range(10)), batch_size=4, display_progress_bar=True)\n",
"\n",
"# Inspect a few translations before launching the full run.\n",
"ds.to_pandas().head(5)"
],
"metadata": {
"id": "ZdeX71YdzbX_"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# The sample looked fine: translate every row. This overwrites `ds` from the sample run.\n",
"ds = pipe.generate(dataset=hf_ds, batch_size=4, display_progress_bar=True)"
],
"metadata": {
"id": "SGdugR9kzf79"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Update the translations in the Argilla Space\n"
],
"metadata": {
"id": "18GUbdg01lD4"
}
},
{
"cell_type": "code",
"source": [
"# Keep the first generation of each row as that row's translation.\n",
"translations = []\n",
"for row_generations in ds[\"generations\"]:\n",
"    translations.append(row_generations[0])\n",
"\n",
"len(translations)"
],
"metadata": {
"id": "yukaSFwFzk27"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"altered_records = []\n",
"\n",
"# Records and translations are paired by position; `zip` stops at the shorter of the two.\n",
"for record, suggested_translation in zip(argilla_ds.records, translations):\n",
"    suggestion = {\"question_name\": \"target\", \"value\": suggested_translation}\n",
"    # Replaces any suggestions already on the record with this single one.\n",
"    record.suggestions = [suggestion]\n",
"    altered_records.append(record)\n",
"\n",
"# Show the first updated record as a sanity check.\n",
"altered_records[0]"
],
"metadata": {
"id": "IJWw41v4zndL"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Send the records, now carrying the translation suggestions, back to Argilla.\n",
"argilla_ds.update_records(altered_records)"
],
"metadata": {
"id": "IgkY5M4oztQz"
},
"execution_count": null,
"outputs": []
}
]
}