{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2eec5cc39a59"
},
"outputs": [],
"source": [
"# Copyright 2024 Google LLC\n",
"#\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "594c6f39b5d1"
},
"source": [
"# Generate training dataset for Cloud Translation API NMT (Neural Machine Translation) model training\n",
"\n",
"<table align=\"left\">\n",
" <td style=\"text-align: center\">\n",
" <a href=\"https://colab.research.google.com/github/GoogleCloudPlatform/generative-ai/blob/main/translation/translation_training_data_tsv_generator.ipynb\">\n",
" <img width=\"32px\" src=\"https://www.gstatic.com/pantheon/images/bigquery/welcome_page/colab-logo.svg\" alt=\"Google Colaboratory logo\"><br> Open in Colab\n",
" </a>\n",
" </td>\n",
" <td style=\"text-align: center\">\n",
" <a href=\"https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fgenerative-ai%2Fmain%2Ftranslation%2Ftranslation_training_data_tsv_generator.ipynb\">\n",
" <img width=\"32px\" src=\"https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN\" alt=\"Google Cloud Colab Enterprise logo\"><br> Open in Colab Enterprise\n",
" </a>\n",
" </td>\n",
" <td style=\"text-align: center\">\n",
" <a href=\"https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/generative-ai/main/translation/translation_training_data_tsv_generator.ipynb\">\n",
" <img src=\"https://www.gstatic.com/images/branding/gcpiconscolors/vertexai/v1/32px.svg\" alt=\"Vertex AI logo\"><br> Open in Vertex AI Workbench\n",
" </a>\n",
" </td>\n",
" <td style=\"text-align: center\">\n",
" <a href=\"https://github.com/GoogleCloudPlatform/generative-ai/blob/main/translation/translation_training_data_tsv_generator.ipynb\">\n",
" <img width=\"32px\" src=\"https://www.svgrepo.com/download/217753/github.svg\" alt=\"GitHub logo\"><br> View on GitHub\n",
" </a>\n",
" </td>\n",
"</table>\n",
"\n",
"<div style=\"clear: both;\"></div>\n",
"\n",
"<b>Share to:</b>\n",
"\n",
"<a href=\"https://www.linkedin.com/sharing/share-offsite/?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/translation/translation_training_data_tsv_generator.ipynb\" target=\"_blank\">\n",
" <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/8/81/LinkedIn_icon.svg\" alt=\"LinkedIn logo\">\n",
"</a>\n",
"\n",
"<a href=\"https://bsky.app/intent/compose?text=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/translation/translation_training_data_tsv_generator.ipynb\" target=\"_blank\">\n",
" <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/7/7a/Bluesky_Logo.svg\" alt=\"Bluesky logo\">\n",
"</a>\n",
"\n",
"<a href=\"https://twitter.com/intent/tweet?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/translation/translation_training_data_tsv_generator.ipynb\" target=\"_blank\">\n",
" <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/5a/X_icon_2.svg\" alt=\"X logo\">\n",
"</a>\n",
"\n",
"<a href=\"https://reddit.com/submit?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/translation/translation_training_data_tsv_generator.ipynb\" target=\"_blank\">\n",
" <img width=\"20px\" src=\"https://redditinc.com/hubfs/Reddit%20Inc/Brand/Reddit_Logo.png\" alt=\"Reddit logo\">\n",
"</a>\n",
"\n",
"<a href=\"https://www.facebook.com/sharer/sharer.php?u=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/translation/translation_training_data_tsv_generator.ipynb\" target=\"_blank\">\n",
" <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/51/Facebook_f_logo_%282019%29.svg\" alt=\"Facebook logo\">\n",
"</a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "86c1d4a789c2"
},
"source": [
"| | |\n",
"|-|-|\n",
"|Author | [Abhijat Gupta](https://github.com/abhijat-gupta)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "25e10371ed0e"
},
"source": [
"## **Overview**\n",
"\n",
"[Cloud Translation API](https://cloud.google.com/translate/docs) uses Google's neural machine translation technology to let you dynamically translate text through the API using a Google pre-trained, custom model, or a translation specialized large language model (LLMs). \n",
"\n",
"It comes in [Basic and Advanced](https://cloud.google.com/translate/docs/editions) editions. Both provide fast and dynamic translation, but Advanced offers customization features, such as domain-specific translation, formatted document translation, and batch translation.\n",
"\n",
"[AutoML Translation](https://cloud.google.com/translate/docs/advanced/automl-beginner) lets you build custom models (without writing code) that are tailored for your domain-specific content compared to the default Google Neural Machine Translation (NMT) model\n",
"\n",
"The first 500,000 characters sent to the API to process (Basic and Advanced combined) per month are free (not applicable to LLMs).\n",
"\n",
"## Objective\n",
"\n",
"### Key Features\n",
"1. Paragraphs are converted into line-pairs of less than 200 words.\n",
"2. Tables in documents are converted into a line-pair with each row as a separate line-pair.\n",
"3. Limit of 200 words per line is handled.\n",
"4. Empty or blank lines are not added to the TSV.\n",
"\n",
"This notebook enables you to generate a TSV file out of documents (docx) for training NMT (neural machine translation) model. The generated TSV file will contain the source and target line pairs for 2 languages in 2 columns respectively. Limit of 200 words for a line is handled within the code. Example: If a line is exceeding 200 words, it won't be added to the training dataset, but will be captured and returned in a dictionary so that you can decide on how to convert it to line-pair of less than 200 words.\n",
"The code also removes any blank or empty lines in a document from both source and reference before making line-pairs. This makes sure that both the documents do not mismatch with line-pairs due to empty lines.\n",
"\n",
"\n",
"## How to use the notebook\n",
"\n",
"##### input: a dictionary containing source and reference GCS paths.\n",
"\n",
"##### output: a single TSV file, 2 dictionaries\n",
"\n",
"##### Steps to follow:\n",
"- Provide as many source and reference files in the input dictionary: `source_ref_dictionary`, *key* being the source file path and reference file path as its *value*\n",
"- Trigger all the cells after providing the input.\n",
"- The TSV gets created in your local path.\n",
"\n",
"\n",
"\n",
"## Costs\n",
"\n",
"Learn about [Translation pricing](https://cloud.google.com/translate/pricing) and use the [Pricing Calculator](https://cloud.google.com/products/calculator/) to generate a cost estimate based on your projected usage."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "628e815b6b1f"
},
"source": [
"## **Getting Started**\n",
"### Install docx SDK for Python"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "a6a9e6e0448d"
},
"outputs": [],
"source": [
"%pip install --proxy \"\" docx --quiet\n",
"%pip install --proxy \"\" python-docx --quiet"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2d4000d88ad8"
},
"source": [
"### Restart kernel"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0c5492fd0156"
},
"outputs": [],
"source": [
"# Restart kernel after installs so that your environment can access the new packages\n",
"import IPython\n",
"\n",
"app = IPython.Application.instance()\n",
"app.kernel.do_shutdown(True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6582b5d47c28"
},
"source": [
"### Authenticate your notebook environment (Colab only)\n",
"\n",
"If you are running this notebook on Google Colab, run the following cell to authenticate your environment. This step is not required if you are using [Vertex AI Workbench](https://cloud.google.com/vertex-ai-notebooks?hl=en)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4788c6f28f01"
},
"outputs": [],
"source": [
"import sys\n",
"\n",
"# Additional authentication is required for Google Colab\n",
"if \"google.colab\" in sys.modules:\n",
" # Authenticate user to Google Cloud\n",
" from google.colab import auth\n",
"\n",
" auth.authenticate_user()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9ccc6635848a"
},
"source": [
"### imports"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {
"id": "9eb336b1b801"
},
"outputs": [],
"source": [
"import json\n",
"import os\n",
"\n",
"import docx\n",
"from docx.document import Document as _Document\n",
"from docx.oxml.table import CT_Tbl\n",
"from docx.oxml.text.paragraph import CT_P\n",
"from docx.table import Table, _Cell\n",
"from docx.text.paragraph import Paragraph\n",
"import google.auth\n",
"from google.auth.credentials import Credentials\n",
"from google.cloud import storage\n",
"import requests"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bf8a23062635"
},
"source": [
"### output TSV file name"
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {
"id": "df55cc8143be"
},
"outputs": [],
"source": [
"# file name for the output tabular TSV.\n",
"tsv_file_name = \"your_tsv_file_name.tsv\" # @param {type:\"string\"}\n",
"PROJECT_ID = \"your project id\" # @param {type:\"string\"}\n",
"LOCATION = \"us-central1\" # @param {type:\"string\"}\n",
"DEFAULT_SOURCE_LANG_CODE = \"<source_language>\" # @param {type:\"string\"}\n",
"DEFAULT_DATASET_PREFIX = \"<your_dataset_prefix>\" # @param {type:\"string\"}\n",
"DEFAULT_DATASET_SUFFIX = \"<your_dataset_suffix>\" # @param {type:\"string\"}\n",
"\n",
"url = (\n",
" f\"https://translation.googleapis.com/v3/projects/{PROJECT_ID}/locations/{LOCATION}\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a072ff983a3c"
},
"source": [
"### source and reference paths"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {
"id": "51d246d095b8"
},
"outputs": [],
"source": [
"source_ref_dictionary = {\n",
" \"source_path1.docx\": \"reference_path1.docx\",\n",
" \"source_path2.docx\": \"reference_path2.docx\",\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7a8d61c373c5"
},
"source": [
"### Generate TSV"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {
"id": "a56ca18f530a"
},
"outputs": [],
"source": [
"def get_document_objects(\n",
" src_path: str, ref_path: str, source_bucket_name: str\n",
") -> tuple[_Document, _Document]:\n",
" \"\"\"Fetches a source document and its translated/reference version from GCS bucket.\"\"\"\n",
"\n",
" client = storage.Client()\n",
" ref_file_name = ref_path.split(source_bucket_name + \"/\")[1]\n",
" file_name = src_path.split(source_bucket_name + \"/\")[1]\n",
"\n",
" try:\n",
" bucket = client.get_bucket(source_bucket_name)\n",
" src_blob = bucket.get_blob(file_name)\n",
" ref_blob = bucket.get_blob(ref_file_name)\n",
" except TypeError as te:\n",
" return te\n",
"\n",
" src_file_downloaded_name = file_name.split(\"source/\")[1]\n",
" ref_file_downloaded_name = ref_file_name.split(\"reference/\")[1]\n",
"\n",
" src_filepath = os.path.join(os.getcwd(), src_file_downloaded_name + \"_local.docx\")\n",
" ref_filepath = os.path.join(os.getcwd(), ref_file_downloaded_name + \"_local.docx\")\n",
"\n",
" with open(src_filepath, \"wb\") as src_f:\n",
" src_blob.download_to_file(src_f)\n",
" src_f.close()\n",
"\n",
" with open(ref_filepath, \"wb\") as ref_f:\n",
" ref_blob.download_to_file(ref_f)\n",
" ref_f.close()\n",
"\n",
" source = docx.Document(src_filepath)\n",
" reference = docx.Document(ref_filepath)\n",
"\n",
" return source, reference\n",
"\n",
"\n",
"def iter_block_items(parent: _Document) -> Paragraph or Table:\n",
" \"\"\"\n",
" Generate a reference to each paragraph and table child within *parent*,\n",
" in document order. Each returned value is an instance of either Table or\n",
" Paragraph. *parent* would most commonly be a reference to a main\n",
" Document object, but also works for a _Cell object, which itself can\n",
" contain paragraphs and tables.\n",
" \"\"\"\n",
" if isinstance(parent, _Document):\n",
" parent_elm = parent.element.body\n",
" elif isinstance(parent, _Cell):\n",
" parent_elm = parent._tc\n",
" elif isinstance(parent, _Row):\n",
" parent_elm = parent._tr\n",
" else:\n",
" raise ValueError(\"something's not right\")\n",
" for child in parent_elm.iterchildren():\n",
" if isinstance(child, CT_P):\n",
" yield Paragraph(child, parent)\n",
" elif isinstance(child, CT_Tbl):\n",
" yield Table(child, parent)\n",
"\n",
"\n",
"def make_tsv(source_ref_dictionary: dict, tsv_file_name: str) -> tuple[dict, dict]:\n",
" \"\"\"\n",
" - This function reads the source and reference/translated documents from local paths iteratively, block-by-block.\n",
" - A page blocks can be: Paragraphs and Tables.\n",
" - In order to generate correct pairs, the type of blocks should be same for both source and reference.\n",
" - If a block don't match, it get captured in mismatched_block dictionary and will not be added to the TSV. The Iteration stops and a TSV is created uptill the matching blocks.\n",
" - ONLY docx format is supported.\n",
" - Creates and saves the TSV in local path(Can be configured to save in GCS bucket).\n",
" - Returns the mismatched blocks from the documents as a dictionary.\n",
" \"\"\"\n",
"\n",
" for src_path, ref_path in source_ref_dictionary.items():\n",
" if src_path is None or src_path == \"\":\n",
" return \"source file path is invalid.\"\n",
" if ref_path is None or ref_path == \"\":\n",
" return \"translated/reference file path is invalid.\"\n",
" if src_path.split(\".\", -1)[::-1][0] != ref_path.split(\".\", -1)[::-1][0]:\n",
" return \"source and translated versions are in different format.\"\n",
"\n",
" tsv_file = os.path.join(os.getcwd(), tsv_file_name)\n",
" if \".pdf\" in src_path.split(src_path.split(\"gs://\")[1].split(\"/\")[0] + \"/\")[1]:\n",
" return \"PDFs are not supported. Process exited.\"\n",
"\n",
" try:\n",
" mismatched_block = {}\n",
" more_than_200_words = {}\n",
" for source_path, reference_path in source_ref_dictionary.items():\n",
" source_bucket_name = source_path.split(\"gs://\")[1].split(\"/\")[0]\n",
" source, reference = get_document_objects(\n",
" source_path, reference_path, source_bucket_name\n",
" )\n",
"\n",
" with open(tsv_file, \"a\") as tsv_f:\n",
" for para in source.paragraphs:\n",
" if len(para.text.strip()) == 0:\n",
" p = para._element\n",
" p.getparent().remove(p)\n",
" p._p = p._element = None\n",
" for para in reference.paragraphs:\n",
" if len(para.text.strip()) == 0:\n",
" p = para._element\n",
" p.getparent().remove(p)\n",
" p._p = p._element = None\n",
"\n",
" for src_block, ref_block in zip(\n",
" iter_block_items(source), iter_block_items(reference)\n",
" ):\n",
" if (\n",
" isinstance(src_block, Paragraph)\n",
" and isinstance(ref_block, Paragraph)\n",
" and src_block.text is not None\n",
" and ref_block.text is not None\n",
" ):\n",
" try:\n",
" tsv_f.write(src_block.text + \"\\t\" + ref_block.text)\n",
" tsv_f.write(\"\\n\")\n",
" except Exception as e:\n",
" print(e)\n",
" elif isinstance(src_block, Table) and isinstance(ref_block, Table):\n",
" try:\n",
" for src_row, ref_row in zip(src_block.rows, ref_block.rows):\n",
" src_row_data = []\n",
" ref_row_data = []\n",
" for cell in src_row.cells:\n",
" for paragraph in cell.paragraphs:\n",
" src_row_data.append(paragraph.text)\n",
" for cell in ref_row.cells:\n",
" for paragraph in cell.paragraphs:\n",
" ref_row_data.append(paragraph.text)\n",
" if len(src_row_data) >= 200 or len(ref_row_data) >= 200:\n",
" print(\n",
" \"Length of a pair detected to be greater than 200 words.\"\n",
" )\n",
" print(\"this pair will be skipped\")\n",
" more_than_200_words[\" \".join(src_row_data)] = (\n",
" \" \".join(ref_row_data)\n",
" )\n",
" else:\n",
" tsv_f.write(\n",
" \" \".join(src_row_data)\n",
" + \"\\t\"\n",
" + \" \".join(ref_row_data)\n",
" )\n",
" tsv_f.write(\"\\n\")\n",
" except Exceptio as e:\n",
" print(e)\n",
" else:\n",
" try:\n",
" mismatched_block[src_block.text] = ref_block\n",
" except:\n",
" mismatched_block[src_block] = ref_block.text\n",
" break\n",
"\n",
" tsv_f.close()\n",
" print(f\"Generated TSV stored at {tsv_file}\")\n",
" return mismatched_block, more_than_200_words\n",
" except Exception as e:\n",
" print(e)"
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {
"id": "c242ace80a3b"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Generated TSV stored at /home/jupyter/src/your_tsv_file_name.tsv\n"
]
}
],
"source": [
"mismatched_block, more_than_200_words = make_tsv(source_ref_dictionary, tsv_file_name)"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {
"id": "79de10f3c921"
},
"outputs": [
{
"data": {
"text/plain": [
"{}"
]
},
"execution_count": 73,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mismatched_block"
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {
"id": "75fecf1a251a"
},
"outputs": [
{
"data": {
"text/plain": [
"{}"
]
},
"execution_count": 74,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"more_than_200_words"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8ecb64ddb0cd"
},
"source": [
"## Custom model training"
]
},
{
"cell_type": "code",
"execution_count": 85,
"metadata": {
"id": "8297f0a3814f"
},
"outputs": [],
"source": [
"def generate_access_token() -> Credentials:\n",
" \"\"\"Generates access token to call translate APIs.\"\"\"\n",
" creds, project = google.auth.default()\n",
"\n",
" auth_req = google.auth.transport.requests.Request()\n",
" creds.refresh(auth_req)\n",
" return creds.token\n",
"\n",
"\n",
"def create_dataset(\n",
" target_lang_code: str,\n",
" url: str,\n",
" source_lang_code: str | None = DEFAULT_SOURCE_LANG_CODE,\n",
") -> dict or None:\n",
" \"\"\"Creates a dataset.\"\"\"\n",
" ACCESS_TOKEN = generate_access_token()\n",
" headers = {\n",
" \"Authorization\": f\"Bearer {ACCESS_TOKEN}\",\n",
" \"Content-Type\": \"application/json; charset=UTF-8\",\n",
" }\n",
"\n",
" if DEFAULT_DATASET_SUFFIX != \"\" and DEFAULT_DATASET_SUFFIX is not None:\n",
" dataset_display_name = f\"{DEFAULT_DATASET_PREFIX}_{source_lang_code}_to_{target_lang_code}_{DEFAULT_DATASET_SUFFIX}\"\n",
" else:\n",
" dataset_display_name = (\n",
" f\"{DEFAULT_DATASET_PREFIX}_{source_lang_code}_to_{target_lang_code}\"\n",
" )\n",
"\n",
" data = {\n",
" \"display_name\": dataset_display_name,\n",
" \"source_language_code\": source_lang_code,\n",
" \"target_language_code\": target_lang_code,\n",
" }\n",
" dataset_url = f\"{url}/datasets\"\n",
" try:\n",
" response = requests.post(dataset_url, data=json.dumps(data), headers=headers)\n",
" data_create_response = json.loads(response.text)\n",
" return data_create_response\n",
" except Exception as e:\n",
" return e\n",
"\n",
"\n",
"def fetch_dataset_id(name: str, url: str) -> str or None:\n",
" \"\"\"Fetches dataset id for the given dataset name.\"\"\"\n",
" ACCESS_TOKEN = generate_access_token()\n",
" headers = {\n",
" \"Authorization\": f\"Bearer {ACCESS_TOKEN}\",\n",
" \"Content-Type\": \"application/json; charset=UTF-8\",\n",
" }\n",
" print(f\"dataset name provided: {name}\")\n",
"\n",
" fetch_dataset_url = f\"{url}/datasets\"\n",
" datasets = requests.get(fetch_dataset_url, headers=headers)\n",
" dataset_list = json.loads(datasets.text)\n",
" all_datasets = dataset_list[\"datasets\"]\n",
"\n",
" for dataset_details in all_datasets:\n",
" if name.lower() == dataset_details[\"displayName\"].lower():\n",
" print(dataset_details[\"name\"].split(\"/\", -1)[::-1][0])\n",
" return dataset_details[\"name\"].split(\"/\", -1)[::-1][0]\n",
" return\n",
"\n",
"\n",
"def import_data(url: str, dataset_id: str, tsv_uri: str) -> dict or None:\n",
" \"\"\"Imports TSV into a translation dataset.\"\"\"\n",
" if dataset_id is None:\n",
" return \"valid Dataset not found. Exiting.\"\n",
"\n",
" ACCESS_TOKEN = generate_access_token()\n",
" headers = {\n",
" \"Authorization\": f\"Bearer {ACCESS_TOKEN}\",\n",
" \"Content-Type\": \"application/json; charset=UTF-8\",\n",
" }\n",
"\n",
" print(f\"Dataset used: {dataset_id}\")\n",
"\n",
" data = {\n",
" \"input_config\": {\n",
" \"input_files\": [\n",
" {\n",
" \"display_name\": \"training_data.tsv\",\n",
" \"usage\": \"UNASSIGNED\",\n",
" \"gcs_source\": {\"input_uri\": tsv_uri},\n",
" }\n",
" ]\n",
" }\n",
" }\n",
"\n",
" importDataset_url = f\"{url}/datasets/{dataset_id}:importData\"\n",
" response = requests.post(importDataset_url, data=json.dumps(data), headers=headers)\n",
" try:\n",
" data_import_response = json.loads(response.text)\n",
" return data_import_response\n",
" except Exception as e:\n",
" print(\"Service unavailable!\", 500)\n",
" return e\n",
"\n",
"\n",
"def train_model(\n",
" model_name: str, project_id: str, location: str, dataset_id: str, url: str\n",
") -> dict:\n",
" \"\"\"Creates a custom model on top of NMT model\"\"\"\n",
" if dataset_id is None:\n",
" return \"valid dataset not found. Exiting.\"\n",
"\n",
" ACCESS_TOKEN = generate_access_token()\n",
" headers = {\n",
" \"Authorization\": f\"Bearer {ACCESS_TOKEN}\",\n",
" \"Content-Type\": \"application/json; charset=UTF-8\",\n",
" }\n",
"\n",
" data = {\n",
" \"display_name\": model_name,\n",
" \"dataset\": f\"projects/{project_id}/locations/{location}/datasets/{dataset_id}\",\n",
" }\n",
" models_url = f\"{url}/models\"\n",
" print(\n",
" f\"\"\"Model training details:\n",
" \n",
" 'model display name': {model_name},\n",
" 'dataset': {dataset_id}\n",
" \n",
" \"\"\"\n",
" )\n",
" response = requests.post(models_url, data=json.dumps(data), headers=headers)\n",
" try:\n",
" model_training_response = json.loads(response.text)\n",
" return model_training_response\n",
" except Exception as e:\n",
" print(\"Service unavailable!\", 500)\n",
" return e"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "b2bd00d9c381"
},
"source": [
"### Create a dataset\n",
"\n",
"Creates a Translation dataset. View in [console](https://console.cloud.google.com/translation/datasets)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "f23b1f449fe6"
},
"outputs": [],
"source": [
"create_dataset(\"de\", url, \"en\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "967544145ee2"
},
"source": [
"### Import data\n",
"Imports data into a Translation dataset. View in [console](https://console.cloud.google.com/translation/datasets)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "759dc4a0b2bd"
},
"outputs": [],
"source": [
"import_data(\n",
" url,\n",
" fetch_dataset_id(\n",
" name=(\n",
" f\"{DEFAULT_DATASET_PREFIX}_en_to_de{DEFAULT_DATASET_SUFFIX}\"\n",
" if DEFAULT_DATASET_SUFFIX is not None\n",
" else f\"{DEFAULT_DATASET_PREFIX}_en_to_de\"\n",
" ),\n",
" url=url,\n",
" ),\n",
" f\"<your cloud storage bucket here>/{tsv_file_name}\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "f82ee49970f2"
},
"source": [
"### Train a model\n",
"\n",
"Triggers training for the given dataset name. View in [console](https://console.cloud.google.com/translation/locations/us-central1/datasets/1372e4ac8f9fa3a9/train)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "e9597d13b3be"
},
"outputs": [],
"source": [
"train_model(\n",
" \"test_model\",\n",
" PROJECT_ID,\n",
" LOCATION,\n",
" fetch_dataset_id(\n",
" name=(\n",
" f\"{DEFAULT_DATASET_PREFIX}_en_to_de{DEFAULT_DATASET_SUFFIX}\"\n",
" if DEFAULT_DATASET_SUFFIX is not None\n",
" else f\"{DEFAULT_DATASET_PREFIX}_en_to_de\"\n",
" ),\n",
" url=url,\n",
" ),\n",
" url,\n",
")"
]
}
],
"metadata": {
"colab": {
"name": "translation_training_data_tsv_generator.ipynb",
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}