people-and-planet-ai/image-classification/train_model.py
#!/usr/bin/env python
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Callable, Iterable
from datetime import datetime
import io
import logging
import random
import time
from typing import TypeVar

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from google.cloud.aiplatform.gapic import DatasetServiceClient
from google.cloud.aiplatform.gapic import PipelineServiceClient
from google.cloud.aiplatform.gapic.schema import trainingjob
from PIL import Image, ImageFile
import requests

T = TypeVar("T")


def run(
project: str,
region: str,
cloud_storage_path: str,
bigquery_dataset: str,
bigquery_table: str,
ai_platform_name_prefix: str,
min_images_per_class: int,
max_images_per_class: int,
budget_milli_node_hours: int,
pipeline_options: PipelineOptions | None = None,
) -> None:
"""Creates a balanced dataset and signals AI Platform to train a model.
Args:
project: Google Cloud Project ID.
region: Location for AI Platform resources.
bigquery_dataset: Dataset ID for the images database, the dataset must exist.
bigquery_table: Table ID for the images database, the table must exist.
ai_platform_name_prefix: Name prefix for AI Platform resources.
min_images_per_class: Minimum number of images required per class for training.
max_images_per_class: Maximum number of images allowed per class for training.
budget_milli_node_hours: Training budget.
pipeline_options: PipelineOptions for Apache Beam.
"""
with beam.Pipeline(options=pipeline_options) as pipeline:
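        # Build a balanced dataset: key each image by category, sample at most
        # max_images_per_class images per category, drop categories with fewer
        # than min_images_per_class images, and make sure every remaining image
        # exists in Cloud Storage (downloading it from LILA WCS if needed).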
images = (
pipeline
| "Read images info"
>> beam.io.ReadFromBigQuery(dataset=bigquery_dataset, table=bigquery_table)
| "Key by category" >> beam.WithKeys(lambda x: x["category"])
| "Random samples"
>> beam.combiners.Sample.FixedSizePerKey(max_images_per_class)
| "Remove key" >> beam.Values()
| "Discard small samples"
>> beam.Filter(lambda sample: len(sample) >= min_images_per_class)
| "Flatten elements" >> beam.FlatMap(lambda sample: sample)
| "Get image" >> beam.FlatMap(get_image, cloud_storage_path)
)
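
        # Write all the (category, image path) pairs into a single dataset CSV
        # file that AI Platform can import.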
dataset_csv_filename = f"{cloud_storage_path}/dataset.csv"
dataset_csv_file = (
pipeline
| "Dataset filename" >> beam.Create([dataset_csv_filename])
| "Write dataset file"
>> beam.Map(write_dataset_csv_file, images=beam.pvalue.AsIter(images))
)
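
        # Optionally create an AI Platform dataset from the CSV file, import the
        # images into it, and launch an AutoML model training pipeline.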
if ai_platform_name_prefix:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
(
dataset_csv_file
| "Create dataset"
>> beam.Map(
create_dataset,
project=project,
region=region,
dataset_name=f"{ai_platform_name_prefix}_{timestamp}",
)
| "Import images" >> beam.MapTuple(import_images_to_dataset)
| "Train model"
>> beam.Map(
train_model,
project=project,
region=region,
model_name=f"{ai_platform_name_prefix}_{timestamp}",
budget_milli_node_hours=budget_milli_node_hours,
)
)
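
# A minimal sketch of calling run() directly, outside the command line entry
# point below (all values are illustrative placeholders, not part of the sample);
# an empty ai_platform_name_prefix skips the dataset creation and training steps.
#
#   run(
#       project="my-project",
#       region="us-central1",
#       cloud_storage_path="gs://my-bucket/wildlife",
#       bigquery_dataset="wildlife_insights",
#       bigquery_table="wildlife_images_metadata",
#       ai_platform_name_prefix="",  # skip AI Platform dataset/model creation
#       min_images_per_class=1,
#       max_images_per_class=5,
#       budget_milli_node_hours=8000,
#       pipeline_options=PipelineOptions(["--project=my-project", "--region=us-central1"]),
#   )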


def get_image(
    image_info: dict[str, str], cloud_storage_path: str
) -> Iterable[tuple[str, str]]:
    """Makes sure an image exists in Cloud Storage.

    Checks if the image file_name exists in Cloud Storage.
    If it doesn't exist, it downloads it from the LILA WCS dataset.
    If the image can't be downloaded, it is skipped.

    Args:
        image_info: dict of {'category', 'file_name'}.
        cloud_storage_path: Cloud Storage path to look for and download images.

    Yields:
        A (category, image_gcs_path) tuple.
    """
    # Imported locally so the module is available when this function runs on
    # remote Apache Beam workers.
    import apache_beam as beam

    base_url = "https://lilablobssc.blob.core.windows.net/wcs-unzipped"
    category = image_info["category"]
    file_name = image_info["file_name"]

# If the image file does not exist, try downloading it.
image_gcs_path = f"{cloud_storage_path}/{file_name}"
logging.info(f"loading image: {image_gcs_path}")
if not beam.io.gcp.gcsio.GcsIO().exists(image_gcs_path):
image_url = f"{base_url}/{file_name}"
logging.info(f"image not found, downloading: {image_gcs_path} [{image_url}]")
try:
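            # Allow Pillow to load images even if the image data is truncated.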
ImageFile.LOAD_TRUNCATED_IMAGES = True
image = Image.open(io.BytesIO(url_get(image_url)))
with beam.io.gcp.gcsio.GcsIO().open(image_gcs_path, "w") as f:
image.save(f, format="JPEG")
except Exception as e:
logging.warning(f"Failed to load image [{image_url}]: {e}")
return
yield category, image_gcs_path
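
# For reference, get_image can also be called outside the pipeline. With
# hypothetical values:
#
#   list(get_image({"category": "lion", "file_name": "animals/0001.jpg"},
#                  "gs://my-bucket/wildlife"))
#
# returns [("lion", "gs://my-bucket/wildlife/animals/0001.jpg")] if the image
# exists in Cloud Storage or could be downloaded, and [] otherwise.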


def write_dataset_csv_file(
    dataset_csv_filename: str, images: Iterable[tuple[str, str]]
) -> str:
    """Writes the dataset image file names and categories in a CSV file.

    Each line in the output dataset CSV file is in the format:
        image_gcs_path,category

    For more information on the CSV format AI Platform expects, see:
        https://cloud.google.com/ai-platform-unified/docs/datasets/prepare-image#csv

    Args:
        dataset_csv_filename: Cloud Storage path for the output dataset CSV file.
        images: List of (category, image_gcs_path) tuples.

    Returns:
        The unchanged dataset_csv_filename.
    """
    # Imported locally so the module is available when this function runs on
    # remote Apache Beam workers.
    import apache_beam as beam

    logging.info(f"Writing dataset CSV file: {dataset_csv_filename}")
    with beam.io.gcp.gcsio.GcsIO().open(dataset_csv_filename, "w") as f:
for category, image_gcs_path in images:
f.write(f"{image_gcs_path},{category}\n".encode())
return dataset_csv_filename
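
# The resulting dataset CSV contains one image per line, for example
# (illustrative paths and categories):
#
#   gs://my-bucket/wildlife/animals/0001.jpg,lion
#   gs://my-bucket/wildlife/animals/0002.jpg,leopard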


def create_dataset(
    dataset_csv_filename: str, project: str, region: str, dataset_name: str
) -> tuple[str, str]:
    """Creates a dataset for AI Platform.

    For more information:
        https://cloud.google.com/ai-platform-unified/docs/datasets/create-dataset-api#create-dataset

    Args:
        dataset_csv_filename: Cloud Storage path for the dataset CSV file.
        project: Google Cloud Project ID.
        region: Location for AI Platform resources.
        dataset_name: Dataset name.

    Returns:
        A (dataset_full_path, dataset_csv_filename) tuple.
    """
    # Use the regional API endpoint that matches the requested region.
    client = DatasetServiceClient(
        client_options={"api_endpoint": f"{region}-aiplatform.googleapis.com"}
    )
response = client.create_dataset(
parent=f"projects/{project}/locations/{region}",
dataset={
"display_name": dataset_name,
"metadata_schema_uri": "gs://google-cloud-aiplatform/schema/dataset/metadata/image_1.0.0.yaml",
},
)
logging.info(f"Creating dataset, operation: {response.operation.name}")
dataset = response.result() # wait until the operation finishes
logging.info(f"Dataset created:\n{dataset}")
return dataset.name, dataset_csv_filename
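
# The returned dataset.name is the dataset's full resource path, in the form:
#   projects/<project>/locations/<region>/datasets/<dataset-id>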


def import_images_to_dataset(dataset_full_path: str, dataset_csv_filename: str) -> str:
    """Imports the images from the dataset CSV file into the AI Platform dataset.

    For more information:
        https://cloud.google.com/ai-platform-unified/docs/datasets/create-dataset-api#import-data

    Args:
        dataset_full_path: The AI Platform dataset full path.
        dataset_csv_filename: Cloud Storage path for the dataset CSV file.

    Returns:
        The dataset_full_path.
    """
    # The dataset's resource path includes its location, so reuse it to pick
    # the matching regional API endpoint.
    region = dataset_full_path.split("/")[3]
    client = DatasetServiceClient(
        client_options={"api_endpoint": f"{region}-aiplatform.googleapis.com"}
    )
response = client.import_data(
name=dataset_full_path,
import_configs=[
{
"gcs_source": {"uris": [dataset_csv_filename]},
"import_schema_uri": "gs://google-cloud-aiplatform/schema/dataset/ioformat/image_classification_single_label_io_format_1.0.0.yaml",
}
],
)
logging.info(f"Importing data into dataset, operation: {response.operation.name}")
_ = response.result() # wait until the operation finishes
logging.info(f"Data imported: {dataset_full_path}")
return dataset_full_path
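
# Note: import_data is a long-running operation; response.result() above blocks
# until the import operation finishes.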


def train_model(
    dataset_full_path: str,
    project: str,
    region: str,
    model_name: str,
    budget_milli_node_hours: int,
) -> str:
    """Starts a model training job.

    For more information:
        https://cloud.google.com/ai-platform-unified/docs/training/automl-api#training_an_automl_model_using_the_api

    Args:
        dataset_full_path: The AI Platform dataset full path.
        project: Google Cloud Project ID.
        region: Location for AI Platform resources.
        model_name: Model name.
        budget_milli_node_hours: Training budget in milli-node-hours.

    Returns:
        The training pipeline full path.
    """
    # Use the regional API endpoint that matches the requested region.
    client = PipelineServiceClient(
        client_options={"api_endpoint": f"{region}-aiplatform.googleapis.com"}
    )
training_pipeline = client.create_training_pipeline(
parent=f"projects/{project}/locations/{region}",
training_pipeline={
"display_name": model_name,
"input_data_config": {"dataset_id": dataset_full_path.split("/")[-1]},
"model_to_upload": {"display_name": model_name},
"training_task_definition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_classification_1.0.0.yaml",
"training_task_inputs": trainingjob.definition.AutoMlImageClassificationInputs(
model_type="CLOUD",
budget_milli_node_hours=budget_milli_node_hours,
).to_value(),
},
)
logging.info(f"Training model, training pipeline:\n{training_pipeline}")
return training_pipeline.name
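
# Unlike the dataset operations above, create_training_pipeline returns without
# waiting: training continues asynchronously in AI Platform. The returned
# training_pipeline.name is the pipeline's full resource path, in the form:
#   projects/<project>/locations/<region>/trainingPipelines/<pipeline-id>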


def url_get(url: str) -> bytes:
    """Sends an HTTP GET request with retries.

    Args:
        url: URL for the request.

    Returns:
        The response content bytes.
    """
    logging.info(f"url_get: {url}")
    return with_retries(lambda: requests.get(url).content)
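
# Note: requests.get() does not raise on HTTP error status codes, so an error
# page would be returned as content. A slightly stricter variant (a sketch, not
# part of the original sample) could raise so that with_retries retries instead:
#
#   def fetch() -> bytes:
#       response = requests.get(url)
#       response.raise_for_status()
#       return response.content
#
#   return with_retries(fetch)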


def with_retries(f: Callable[[], T], max_attempts: int = 3) -> T:
    """Runs a function with retries, using exponential backoff.

    For more information:
        https://developers.google.com/drive/api/v3/handle-errors?hl=pt-pt#exponential-backoff

    Args:
        f: A function that doesn't receive any input.
        max_attempts: The maximum number of attempts to run the function.

    Returns:
        The return value of `f`.

    Raises:
        The last exception raised by `f` if all max_attempts attempts fail.
    """
    for n in range(max_attempts):
        try:
            return f()
        except Exception as e:
            if n < max_attempts - 1:
                logging.warning(f"Got an error, attempt {n + 1} of {max_attempts}: {e}")
                time.sleep(2**n + random.random())  # 2^n seconds plus random jitter
            else:
                raise e
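
# Example usage (hypothetical URL): retry a flaky download up to 5 times.
#
#   data = with_retries(lambda: requests.get(some_url).content, max_attempts=5)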


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
parser.add_argument(
"--cloud-storage-path",
required=True,
help="Cloud Storage path to store the AI Platform dataset files.",
)
parser.add_argument(
"--bigquery-dataset",
required=True,
help="BigQuery dataset ID for the images database.",
)
parser.add_argument(
"--bigquery-table",
default="wildlife_images_metadata",
help="BigQuery table ID for the images database.",
)
parser.add_argument(
"--ai-platform-name-prefix",
default="wildlife_classifier",
help="Name prefix for AI Platform resources.",
)
parser.add_argument(
"--min-images-per-class",
type=int,
default=50,
help="Minimum number of images required per class for training.",
)
parser.add_argument(
"--max-images-per-class",
type=int,
default=80,
help="Maximum number of images allowed per class for training.",
)
parser.add_argument(
"--budget-milli-node-hours",
type=int,
default=8000,
help="Training budget, see: https://cloud.google.com/automl/docs/reference/rpc/google.cloud.automl.v1#imageclassificationmodelmetadata",
)
args, pipeline_args = parser.parse_known_args()
pipeline_options = PipelineOptions(pipeline_args, save_main_session=True)
project = pipeline_options.get_all_options().get("project")
if not project:
parser.error("please provide a Google Cloud project ID with --project")
region = pipeline_options.get_all_options().get("region")
if not region:
parser.error("please provide a Google Cloud compute region with --region")
run(
project=project,
region=region,
cloud_storage_path=args.cloud_storage_path,
bigquery_dataset=args.bigquery_dataset,
bigquery_table=args.bigquery_table,
ai_platform_name_prefix=args.ai_platform_name_prefix,
min_images_per_class=args.min_images_per_class,
max_images_per_class=args.max_images_per_class,
budget_milli_node_hours=args.budget_milli_node_hours,
pipeline_options=pipeline_options,
)
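
# Example invocation (bucket, project, and dataset names are placeholders):
#
#   python train_model.py \
#       --cloud-storage-path=gs://my-bucket/wildlife \
#       --bigquery-dataset=wildlife_insights \
#       --project=my-project \
#       --region=us-central1
#
# To run on Dataflow instead of the local DirectRunner, add standard Beam
# pipeline options such as:
#   --runner=DataflowRunner --temp_location=gs://my-bucket/temp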