In [None]:
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Generate training dataset for Cloud Translation API NMT (Neural Machine Translation) model training

<table align="left">
  <td style="text-align: center">
    <a href="https://colab.research.google.com/github/GoogleCloudPlatform/generative-ai/blob/main/translation/translation_training_data_tsv_generator.ipynb">
      <img width="32px" src="https://www.gstatic.com/pantheon/images/bigquery/welcome_page/colab-logo.svg" alt="Google Colaboratory logo"><br> Open in Colab
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fgenerative-ai%2Fmain%2Ftranslation%2Ftranslation_training_data_tsv_generator.ipynb">
      <img width="32px" src="https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN" alt="Google Cloud Colab Enterprise logo"><br> Open in Colab Enterprise
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/generative-ai/main/translation/translation_training_data_tsv_generator.ipynb">
      <img src="https://www.gstatic.com/images/branding/gcpiconscolors/vertexai/v1/32px.svg" alt="Vertex AI logo"><br> Open in Vertex AI Workbench
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://github.com/GoogleCloudPlatform/generative-ai/blob/main/translation/translation_training_data_tsv_generator.ipynb">
      <img width="32px" src="https://www.svgrepo.com/download/217753/github.svg" alt="GitHub logo"><br> View on GitHub
    </a>
  </td>
</table>

<div style="clear: both;"></div>

| | |
|-|-|
|Author | [Abhijat Gupta](https://github.com/abhijat-gupta)

## **Overview**

[Cloud Translation API](https://cloud.google.com/translate/docs) uses Google's neural machine translation technology to let you dynamically translate text through the API using a Google pre-trained model, a custom model, or a translation-specialized large language model (LLM).

It comes in [Basic and Advanced](https://cloud.google.com/translate/docs/editions) editions. Both provide fast and dynamic translation, but Advanced offers customization features, such as domain-specific translation, formatted document translation, and batch translation.

[AutoML Translation](https://cloud.google.com/translate/docs/advanced/automl-beginner) lets you build custom models, without writing code, that are tailored to your domain-specific content rather than relying on the default Google Neural Machine Translation (NMT) model.

Each month, the first 500,000 characters sent to the API for processing (Basic and Advanced combined) are free. This free tier does not apply to LLMs.

## Objective

This notebook generates a TSV file from documents (docx) for training an NMT (neural machine translation) model. The generated TSV file contains the source and target line pairs for two languages, one language per column.

The 200-word limit per line is handled in the code. If a line has 200 words or more, it is not added to the training dataset. Instead, it is captured and returned in a dictionary so that you can decide how to convert it into line pairs of fewer than 200 words.

The code also removes blank or empty lines from both the source and the reference document before making line pairs, so that empty lines don't cause the two documents to fall out of alignment.

### Key Features
1. Paragraphs are converted into line pairs of fewer than 200 words.
2. Tables in documents are converted into line pairs, with each row as a separate line pair.
3. Lines that reach the 200-word limit are kept out of the TSV and returned separately.
4. Empty or blank lines are not added to the TSV.
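
The word-limit and blank-line rules described above can be sketched in a few lines. This is a minimal illustration rather than the notebook's implementation: `split_by_word_limit` and its `max_words` parameter are hypothetical names, and counting words by splitting on whitespace is an assumption.

```python
def split_by_word_limit(
    pairs: list[tuple[str, str]], max_words: int = 200
) -> tuple[list[tuple[str, str]], dict[str, str]]:
    """Separate line pairs under the word limit from those that reach it."""
    kept: list[tuple[str, str]] = []
    too_long: dict[str, str] = {}
    for source_line, target_line in pairs:
        source_line, target_line = source_line.strip(), target_line.strip()
        # A pair with a blank side carries no training signal, so drop it.
        if not source_line or not target_line:
            continue
        # Words are counted by whitespace splitting (an assumption of this sketch).
        longest = max(len(source_line.split()), len(target_line.split()))
        if longest >= max_words:
            too_long[source_line] = target_line
        else:
            kept.append((source_line, target_line))
    return kept, too_long


example_pairs = [
    ("Hello world", "Hallo Welt"),
    ("", "Leere Zeile"),
    ("word " * 200, "Wort " * 200),
]
kept, too_long = split_by_word_limit(example_pairs)
print(kept)
print(len(too_long))
```

Returning the over-limit pairs instead of discarding them leaves the decision on how to split them with you.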


## How to use the notebook

##### Input: a dictionary that maps source GCS paths to reference GCS paths.

##### Output: a single TSV file and two dictionaries (mismatched blocks and line pairs that reach the 200-word limit).

##### Steps to follow:
- Add as many source and reference files as you need to the input dictionary `source_ref_dictionary`, with the source file path as the *key* and the reference file path as its *value*. Use full `gs://` paths, and keep both files of a pair in the same bucket.
- Run all the cells after providing the input.
- The TSV is created in the notebook's current working directory.
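
Because every entry in the input dictionary is a Cloud Storage path, it can help to see how such a path breaks down into a bucket name and an object name. The sketch below is illustrative only: `parse_gcs_uri` is a hypothetical helper, and the bucket and file names are placeholders.

```python
def parse_gcs_uri(uri: str) -> tuple[str, str]:
    """Split a gs:// URI into (bucket name, object name)."""
    prefix = "gs://"
    if not uri.startswith(prefix):
        raise ValueError(f"Not a Cloud Storage URI: {uri!r}")
    # Everything up to the first slash is the bucket; the rest is the object name.
    bucket_name, _, object_name = uri[len(prefix):].partition("/")
    if not bucket_name or not object_name:
        raise ValueError(f"URI must contain a bucket and an object name: {uri!r}")
    return bucket_name, object_name


# Placeholder paths, shaped like the entries of the input dictionary.
example_pairs = {
    "gs://my-bucket/source/doc1.docx": "gs://my-bucket/reference/doc1.docx",
}
for source_uri, reference_uri in example_pairs.items():
    print(parse_gcs_uri(source_uri), parse_gcs_uri(reference_uri))
```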



## Costs

Learn about [Translation pricing](https://cloud.google.com/translate/pricing) and use the [Pricing Calculator](https://cloud.google.com/products/calculator/) to generate a cost estimate based on your projected usage.

## **Getting Started**
### Install the python-docx library

In [1]:
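# Install only python-docx. The separate "docx" package on PyPI is a legacy
# project that provides the same "docx" module name and can break the imports below.
%pip install --proxy "" python-docx --quiet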

### Restart kernel

In [None]:
# Restart kernel after installs so that your environment can access the new packages
import IPython

app = IPython.Application.instance()
app.kernel.do_shutdown(True)

### Authenticate your notebook environment (Colab only)

If you are running this notebook on Google Colab, run the following cell to authenticate your environment. This step is not required if you are using [Vertex AI Workbench](https://cloud.google.com/vertex-ai-notebooks?hl=en).

In [None]:
import sys

# Additional authentication is required for Google Colab
if "google.colab" in sys.modules:
    # Authenticate user to Google Cloud
    from google.colab import auth

    auth.authenticate_user()

### Imports

In [62]:
import json
import os

import docx
from docx.document import Document as _Document
from docx.oxml.table import CT_Tbl
from docx.oxml.text.paragraph import CT_P
from docx.table import Table, _Cell
from docx.text.paragraph import Paragraph
import google.auth
from google.auth.credentials import Credentials
from google.cloud import storage
import requests

### Output TSV file name and project settings

In [77]:
# file name for the output tabular TSV.
tsv_file_name = "your_tsv_file_name.tsv"  # @param {type:"string"}
PROJECT_ID = "your project id"  # @param {type:"string"}
LOCATION = "us-central1"  # @param {type:"string"}
DEFAULT_SOURCE_LANG_CODE = "<source_language>"  # @param {type:"string"}
DEFAULT_DATASET_PREFIX = "<your_dataset_prefix>"  # @param {type:"string"}
DEFAULT_DATASET_SUFFIX = "<your_dataset_suffix>"  # @param {type:"string"}

url = (
    f"https://translation.googleapis.com/v3/projects/{PROJECT_ID}/locations/{LOCATION}"
)

### Source and reference paths

In [58]:
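# Keys are source documents and values are their reference translations.
# Replace the placeholders with full gs:// paths; both files of a pair
# must be in the same bucket.
source_ref_dictionary = {
    "gs://<your-bucket>/source/source_path1.docx": "gs://<your-bucket>/reference/reference_path1.docx",
    "gs://<your-bucket>/source/source_path2.docx": "gs://<your-bucket>/reference/reference_path2.docx",
}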

### Generate TSV

In [71]:
def get_document_objects(
    src_path: str, ref_path: str, source_bucket_name: str
) -> tuple[_Document, _Document]:
    """Fetches a source document and its translated/reference version from a GCS bucket.

    Both files must be stored in the bucket named by source_bucket_name.
    """
    bucket_prefix = f"gs://{source_bucket_name}/"
    if not src_path.startswith(bucket_prefix) or not ref_path.startswith(bucket_prefix):
        raise ValueError(f"Both files must be in the bucket '{source_bucket_name}'.")
    src_blob_name = src_path[len(bucket_prefix) :]
    ref_blob_name = ref_path[len(bucket_prefix) :]

    client = storage.Client()
    bucket = client.get_bucket(source_bucket_name)
    # get_blob returns None when the object does not exist.
    src_blob = bucket.get_blob(src_blob_name)
    ref_blob = bucket.get_blob(ref_blob_name)
    if src_blob is None:
        raise FileNotFoundError(f"Source file not found: {src_path}")
    if ref_blob is None:
        raise FileNotFoundError(f"Reference file not found: {ref_path}")

    # Prefix the local file names so that a source and a reference file with
    # the same base name don't overwrite each other.
    src_filepath = os.path.join(
        os.getcwd(), "source_" + os.path.basename(src_blob_name)
    )
    ref_filepath = os.path.join(
        os.getcwd(), "reference_" + os.path.basename(ref_blob_name)
    )

    src_blob.download_to_filename(src_filepath)
    ref_blob.download_to_filename(ref_filepath)

    source = docx.Document(src_filepath)
    reference = docx.Document(ref_filepath)

    return source, reference


from collections.abc import Iterator  # needed only for the return annotation


def iter_block_items(parent: _Document | _Cell) -> Iterator[Paragraph | Table]:
    """
    Generate a reference to each paragraph and table child within *parent*,
    in document order. Each returned value is an instance of either Table or
    Paragraph. *parent* would most commonly be a reference to a main
    Document object, but also works for a _Cell object, which itself can
    contain paragraphs and tables.
    """
    # Table rows are not handled here: a row's direct children are cells,
    # not paragraphs or tables, so iterate over each cell instead.
    if isinstance(parent, _Document):
        parent_elm = parent.element.body
    elif isinstance(parent, _Cell):
        parent_elm = parent._tc
    else:
        raise ValueError(f"Unsupported parent type: {type(parent).__name__}")
    for child in parent_elm.iterchildren():
        if isinstance(child, CT_P):
            yield Paragraph(child, parent)
        elif isinstance(child, CT_Tbl):
            yield Table(child, parent)


def _clean_text(text: str) -> str:
    """Collapses tabs, newlines, and repeated spaces so the text fits in one TSV field."""
    return " ".join(text.split())


def _row_text(row) -> str:
    """Joins the text of every paragraph in every cell of a table row."""
    return " ".join(
        paragraph.text for cell in row.cells for paragraph in cell.paragraphs
    )


def _block_text(block: Paragraph | Table) -> str:
    """Returns a printable text version of a paragraph or table block."""
    if isinstance(block, Paragraph):
        return block.text
    return "[table] " + " | ".join(_row_text(row) for row in block.rows)


def make_tsv(source_ref_dictionary: dict, tsv_file_name: str) -> tuple[dict, dict]:
    """
    - Reads the source and reference/translated documents iteratively, block by block.
    - A block is either a Paragraph or a Table.
    - To generate correct pairs, the block types must be the same for both source and reference.
    - If a pair of blocks doesn't match, it is captured in the mismatched_block dictionary and is not
      added to the TSV. Iteration over that document stops, so the TSV contains the pairs up to the
      last matching block.
    - Pairs with 200 or more words on either side are skipped and captured in more_than_200_words.
    - ONLY the docx format is supported.
    - Creates and saves the TSV in the current working directory (can be adapted to save to a GCS bucket).
    - Returns the mismatched blocks and the over-limit pairs as two dictionaries.
    """

    # Validate every pair of paths before downloading anything.
    for src_path, ref_path in source_ref_dictionary.items():
        if not src_path or not src_path.startswith("gs://"):
            raise ValueError(f"Source file path is invalid: {src_path!r}")
        if not ref_path or not ref_path.startswith("gs://"):
            raise ValueError(f"Translated/reference file path is invalid: {ref_path!r}")
        src_ext = os.path.splitext(src_path)[1].lower()
        ref_ext = os.path.splitext(ref_path)[1].lower()
        if src_ext != ref_ext:
            raise ValueError("Source and translated versions are in different formats.")
        if src_ext != ".docx":
            raise ValueError(f"Only docx files are supported (PDFs are not): {src_path}")

    tsv_file = os.path.join(os.getcwd(), tsv_file_name)
    mismatched_block = {}
    more_than_200_words = {}

    def write_pair(tsv_f, src_text: str, ref_text: str) -> None:
        """Writes one line pair, or records it when it reaches the word limit."""
        src_text, ref_text = _clean_text(src_text), _clean_text(ref_text)
        if not src_text or not ref_text:
            return  # a pair with an empty side is not useful for training
        if len(src_text.split()) >= 200 or len(ref_text.split()) >= 200:
            print("A pair with 200 or more words was detected; it will be skipped.")
            more_than_200_words[src_text] = ref_text
        else:
            tsv_f.write(f"{src_text}\t{ref_text}\n")

    for source_path, reference_path in source_ref_dictionary.items():
        source_bucket_name = source_path.split("gs://")[1].split("/")[0]
        source, reference = get_document_objects(
            source_path, reference_path, source_bucket_name
        )

        # Remove empty paragraphs from both documents so they stay aligned.
        for document in (source, reference):
            for para in document.paragraphs:
                if not para.text.strip():
                    element = para._element
                    element.getparent().remove(element)

        # Append mode lets several document pairs share one TSV. Delete the
        # file before re-running to avoid duplicate rows.
        with open(tsv_file, "a", encoding="utf-8") as tsv_f:
            for src_block, ref_block in zip(
                iter_block_items(source), iter_block_items(reference)
            ):
                if isinstance(src_block, Paragraph) and isinstance(
                    ref_block, Paragraph
                ):
                    write_pair(tsv_f, src_block.text, ref_block.text)
                elif isinstance(src_block, Table) and isinstance(ref_block, Table):
                    # Each table row becomes one line pair.
                    for src_row, ref_row in zip(src_block.rows, ref_block.rows):
                        write_pair(tsv_f, _row_text(src_row), _row_text(ref_row))
                else:
                    # Block types differ: record the pair and stop this document.
                    mismatched_block[_block_text(src_block)] = _block_text(ref_block)
                    break

    print(f"Generated TSV stored at {tsv_file}")
    return mismatched_block, more_than_200_words

In [72]:
mismatched_block, more_than_200_words = make_tsv(source_ref_dictionary, tsv_file_name)

Generated TSV stored at /home/jupyter/src/your_tsv_file_name.tsv


In [73]:
mismatched_block

{}

In [74]:
more_than_200_words

{}
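
Before moving on to training, it may be worth checking that every row of the generated TSV has exactly two non-empty columns. The following is a minimal sketch using only the standard library; `find_malformed_rows` is a hypothetical helper, and the demo file is a throwaway created purely for illustration.

```python
import csv
import os
import tempfile


def find_malformed_rows(tsv_path: str, expected_columns: int = 2) -> list[int]:
    """Return 1-based row numbers that lack exactly `expected_columns` non-empty fields."""
    malformed = []
    with open(tsv_path, encoding="utf-8", newline="") as tsv_f:
        # QUOTE_NONE treats quote characters as ordinary text.
        reader = csv.reader(tsv_f, delimiter="\t", quoting=csv.QUOTE_NONE)
        for row_number, row in enumerate(reader, start=1):
            if len(row) != expected_columns or any(not field.strip() for field in row):
                malformed.append(row_number)
    return malformed


# Demo on a temporary file with one good row and two bad rows.
with tempfile.NamedTemporaryFile(
    "w", suffix=".tsv", delete=False, encoding="utf-8"
) as demo_f:
    demo_f.write("Hello\tHallo\n")
    demo_f.write("Only one column\n")
    demo_f.write("A\tB\tC\n")
    demo_path = demo_f.name

print(find_malformed_rows(demo_path))
os.remove(demo_path)
```

In the notebook, the same check could be run against the path printed by `make_tsv`.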

## Custom model training

In [85]:
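import google.auth.transport.requests  # must be imported explicitly before use


def generate_access_token() -> str:
    """Generates an access token to call the Translation API."""
    # The cloud-platform scope is needed when the default credentials
    # belong to a service account.
    creds, _ = google.auth.default(
        scopes=["https://www.googleapis.com/auth/cloud-platform"]
    )

    auth_req = google.auth.transport.requests.Request()
    creds.refresh(auth_req)
    return creds.token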


def create_dataset(
    target_lang_code: str,
    url: str,
    source_lang_code: str = DEFAULT_SOURCE_LANG_CODE,
) -> dict:
    """Creates a Translation dataset and returns the parsed API response."""
    access_token = generate_access_token()
    headers = {
        "Authorization": f"Bearer {access_token}",
        "Content-Type": "application/json; charset=UTF-8",
    }

    # Append the suffix only when one is configured.
    dataset_display_name = (
        f"{DEFAULT_DATASET_PREFIX}_{source_lang_code}_to_{target_lang_code}"
    )
    if DEFAULT_DATASET_SUFFIX:
        dataset_display_name += f"_{DEFAULT_DATASET_SUFFIX}"

    data = {
        "display_name": dataset_display_name,
        "source_language_code": source_lang_code,
        "target_language_code": target_lang_code,
    }
    dataset_url = f"{url}/datasets"
    # Network and JSON decoding errors propagate to the caller instead of
    # being returned as if they were a response.
    response = requests.post(dataset_url, json=data, headers=headers, timeout=60)
    return response.json()


def fetch_dataset_id(name: str, url: str) -> str | None:
    """Fetches the dataset ID for the given dataset display name, or None if not found."""
    access_token = generate_access_token()
    headers = {
        "Authorization": f"Bearer {access_token}",
        "Content-Type": "application/json; charset=UTF-8",
    }
    print(f"dataset name provided: {name}")

    fetch_dataset_url = f"{url}/datasets"
    response = requests.get(fetch_dataset_url, headers=headers, timeout=60)
    # The "datasets" key is missing when the response lists no datasets.
    # Only the first page of results is examined.
    all_datasets = response.json().get("datasets", [])

    for dataset_details in all_datasets:
        if name.lower() == dataset_details["displayName"].lower():
            # The resource name ends with the dataset ID.
            dataset_id = dataset_details["name"].rsplit("/", 1)[-1]
            print(dataset_id)
            return dataset_id
    return None


def import_data(url: str, dataset_id: str | None, tsv_uri: str) -> dict:
    """Imports a TSV from Cloud Storage into a Translation dataset."""
    if dataset_id is None:
        raise ValueError("Valid dataset not found.")

    access_token = generate_access_token()
    headers = {
        "Authorization": f"Bearer {access_token}",
        "Content-Type": "application/json; charset=UTF-8",
    }

    print(f"Dataset used: {dataset_id}")

    data = {
        "input_config": {
            "input_files": [
                {
                    "display_name": "training_data.tsv",
                    "usage": "UNASSIGNED",
                    "gcs_source": {"input_uri": tsv_uri},
                }
            ]
        }
    }

    import_dataset_url = f"{url}/datasets/{dataset_id}:importData"
    # Network and JSON decoding errors propagate to the caller.
    response = requests.post(
        import_dataset_url, json=data, headers=headers, timeout=60
    )
    return response.json()


def train_model(
    model_name: str, project_id: str, location: str, dataset_id: str | None, url: str
) -> dict:
    """Creates a custom model on top of the NMT model."""
    if dataset_id is None:
        raise ValueError("Valid dataset not found.")

    access_token = generate_access_token()
    headers = {
        "Authorization": f"Bearer {access_token}",
        "Content-Type": "application/json; charset=UTF-8",
    }

    data = {
        "display_name": model_name,
        "dataset": f"projects/{project_id}/locations/{location}/datasets/{dataset_id}",
    }
    models_url = f"{url}/models"
    print(
        f"""Model training details:

        'model display name': {model_name},
        'dataset': {dataset_id}

    """
    )
    # Network and JSON decoding errors propagate to the caller.
    response = requests.post(models_url, json=data, headers=headers, timeout=60)
    return response.json()

### Create a dataset

Creates a Translation dataset. View in [console](https://console.cloud.google.com/translation/datasets)

In [1]:
create_dataset("de", url, "en")

### Import data
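Imports data into a Translation dataset. The import reads the TSV from Cloud Storage, so upload the generated file to a bucket before running this step. View in [console](https://console.cloud.google.com/translation/datasets)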
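
Since `make_tsv` writes the TSV locally while `import_data` expects a Cloud Storage URI, the file has to be uploaded in between. One way to do that is sketched below with the `google-cloud-storage` client; `build_gcs_uri` and `upload_tsv` are hypothetical helper names, and the bucket and object names are placeholders you would replace with your own.

```python
def build_gcs_uri(bucket_name: str, blob_name: str) -> str:
    """Build the gs:// URI for an object in a bucket."""
    return f"gs://{bucket_name}/{blob_name.lstrip('/')}"


def upload_tsv(local_path: str, bucket_name: str, blob_name: str) -> str:
    """Upload a local TSV file to Cloud Storage and return its gs:// URI."""
    # Imported here so that build_gcs_uri works without the client library.
    from google.cloud import storage

    blob_name = blob_name.lstrip("/")
    client = storage.Client()
    blob = client.bucket(bucket_name).blob(blob_name)
    blob.upload_from_filename(local_path)
    return build_gcs_uri(bucket_name, blob_name)


# Placeholder names; no upload happens until upload_tsv is called.
print(build_gcs_uri("my-bucket", "training/your_tsv_file_name.tsv"))
```

The URI returned by `upload_tsv` is what the `tsv_uri` argument of `import_data` expects.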

In [2]:
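import_data(
    url,
    fetch_dataset_id(
        # Must match the display name built in create_dataset, including
        # the underscore before the suffix.
        name=(
            f"{DEFAULT_DATASET_PREFIX}_en_to_de_{DEFAULT_DATASET_SUFFIX}"
            if DEFAULT_DATASET_SUFFIX
            else f"{DEFAULT_DATASET_PREFIX}_en_to_de"
        ),
        url=url,
    ),
    # Replace the placeholder with the Cloud Storage location of the uploaded TSV.
    f"<your cloud storage bucket here>/{tsv_file_name}",
)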

### Train a model

Triggers training for the given dataset name. View in [console](https://console.cloud.google.com/translation/locations/us-central1/datasets/1372e4ac8f9fa3a9/train)

In [3]:
train_model(
    "test_model",
    PROJECT_ID,
    LOCATION,
    fetch_dataset_id(
        # Must match the display name built in create_dataset, including
        # the underscore before the suffix.
        name=(
            f"{DEFAULT_DATASET_PREFIX}_en_to_de_{DEFAULT_DATASET_SUFFIX}"
            if DEFAULT_DATASET_SUFFIX
            else f"{DEFAULT_DATASET_PREFIX}_en_to_de"
        ),
        url=url,
    ),
    url,
)
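
Dataset creation, data import, and model training typically come back as long-running operations, so each step may need to finish before the next one starts. The sketch below polls an operation until it reports `done`. It rests on assumptions: `wait_for_operation` and `fetch_operation_via_rest` are hypothetical helpers, and the operation is assumed to be readable with a GET request on `https://translation.googleapis.com/v3/{operation name}`.

```python
import time
from collections.abc import Callable


def wait_for_operation(
    fetch_operation: Callable[[str], dict],
    operation_name: str,
    poll_seconds: float = 30.0,
    timeout_seconds: float = 3600.0,
) -> dict:
    """Poll an operation until it is done, then return its final state."""
    deadline = time.monotonic() + timeout_seconds
    while True:
        operation = fetch_operation(operation_name)
        if operation.get("done"):
            if "error" in operation:
                raise RuntimeError(f"Operation failed: {operation['error']}")
            return operation
        if time.monotonic() >= deadline:
            raise TimeoutError(f"Operation still running: {operation_name}")
        time.sleep(poll_seconds)


def fetch_operation_via_rest(operation_name: str, access_token: str) -> dict:
    """Read an operation over REST (assumed endpoint, see the note above)."""
    import requests  # imported here so the polling helper has no third-party dependency

    response = requests.get(
        f"https://translation.googleapis.com/v3/{operation_name}",
        headers={"Authorization": f"Bearer {access_token}"},
        timeout=60,
    )
    return response.json()


# Offline demo with a fake fetcher that finishes on the second poll.
fake_states = iter(
    [
        {"name": "operations/demo", "done": False},
        {"name": "operations/demo", "done": True, "response": {"ok": True}},
    ]
)
final_state = wait_for_operation(
    lambda name: next(fake_states), "operations/demo", poll_seconds=0
)
print(final_state["response"])
```

In the notebook, the fetcher could wrap `fetch_operation_via_rest` with a token from `generate_access_token()`, and the operation name would come from the `name` field of the response returned by `create_dataset`, `import_data`, or `train_model`.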