#!/usr/bin/env python

# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Callable, Iterable
from datetime import datetime
import io
import logging
import random
import time
from typing import TypeVar

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from google.cloud.aiplatform.gapic import DatasetServiceClient
from google.cloud.aiplatform.gapic import PipelineServiceClient
from google.cloud.aiplatform.gapic.schema import trainingjob
from PIL import Image, ImageFile
import requests

# Generic type variable used by `with_retries` so its return type matches the
# return type of the function it wraps.
a = TypeVar("a")


def run(
    project: str,
    region: str,
    cloud_storage_path: str,
    bigquery_dataset: str,
    bigquery_table: str,
    ai_platform_name_prefix: str,
    min_images_per_class: int,
    max_images_per_class: int,
    budget_milli_node_hours: int,
    pipeline_options: PipelineOptions | None = None,
) -> None:
    """Creates a balanced dataset and signals AI Platform to train a model.

    The pipeline does the following:
    - Reads image records from BigQuery and groups them by their "category".
    - Keeps a random sample of at most `max_images_per_class` images per
      category and discards categories whose sample has fewer than
      `min_images_per_class` images.
    - Makes sure each selected image is in Cloud Storage.
    - Writes `{cloud_storage_path}/dataset.csv` listing the selected images.

    If `ai_platform_name_prefix` is non-empty, it then creates an AI Platform
    dataset, imports the CSV file into it and starts a model training job.

    Args:
        project: Google Cloud Project ID.
        region: Location for AI Platform resources.
        cloud_storage_path: Cloud Storage path where the images are looked up
            or downloaded to, and where the `dataset.csv` file is written.
        bigquery_dataset: Dataset ID for the images database, the dataset must exist.
        bigquery_table: Table ID for the images database, the table must exist.
        ai_platform_name_prefix: Name prefix for AI Platform resources. The
            dataset and the model are both named `{prefix}_{YYYYmmdd_HHMMSS}`.
            If empty, no AI Platform resources are created.
        min_images_per_class: Minimum number of images required per class for training.
        max_images_per_class: Maximum number of images allowed per class for training.
        budget_milli_node_hours: Training budget.
        pipeline_options: PipelineOptions for Apache Beam.

    """
    with beam.Pipeline(options=pipeline_options) as pipeline:
        images = (
            pipeline
            | "Read images info"
            >> beam.io.ReadFromBigQuery(dataset=bigquery_dataset, table=bigquery_table)
            | "Key by category" >> beam.WithKeys(lambda x: x["category"])
            | "Random samples"
            >> beam.combiners.Sample.FixedSizePerKey(max_images_per_class)
            | "Remove key" >> beam.Values()
            | "Discard small samples"
            >> beam.Filter(lambda sample: len(sample) >= min_images_per_class)
            | "Flatten elements" >> beam.FlatMap(lambda sample: sample)
            # NOTE: get_image skips images it cannot download, and that happens
            # after the size filter above. A class can therefore end up with
            # fewer than min_images_per_class images in the final dataset.
            | "Get image" >> beam.FlatMap(get_image, cloud_storage_path)
        )

        dataset_csv_filename = f"{cloud_storage_path}/dataset.csv"
        # A single-element PCollection holding the CSV path. The whole `images`
        # collection is passed as a side input, so the file is written once.
        dataset_csv_file = (
            pipeline
            | "Dataset filename" >> beam.Create([dataset_csv_filename])
            | "Write dataset file"
            >> beam.Map(write_dataset_csv_file, images=beam.pvalue.AsIter(images))
        )

        if ai_platform_name_prefix:
            # One timestamp shared by the dataset and model names.
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            (
                dataset_csv_file
                | "Create dataset"
                >> beam.Map(
                    create_dataset,
                    project=project,
                    region=region,
                    dataset_name=f"{ai_platform_name_prefix}_{timestamp}",
                )
                | "Import images" >> beam.MapTuple(import_images_to_dataset)
                | "Train model"
                >> beam.Map(
                    train_model,
                    project=project,
                    region=region,
                    model_name=f"{ai_platform_name_prefix}_{timestamp}",
                    budget_milli_node_hours=budget_milli_node_hours,
                )
            )


def get_image(
    image_info: dict[str, str], cloud_storage_path: str
) -> Iterable[tuple[str, str]]:
    """Makes sure an image exists in Cloud Storage.

    Checks if the image file_name exists in Cloud Storage.
    If it doesn't exist, it downloads it from the LILA WCS dataset, re-encodes
    it as JPEG and uploads it to Cloud Storage.
    If the image can't be downloaded, decoded or uploaded, it is skipped.

    Args:
        image_info: dict with (at least) the 'category' and 'file_name' keys.
        cloud_storage_path: Cloud Storage path to look for and download images.

    Yields:
        A single (category, image_gcs_path) tuple when the image is available
        in Cloud Storage; nothing when the image had to be skipped.
    """
    # Imported inside the function, presumably so the name resolves on remote
    # Beam workers.
    import apache_beam as beam

    base_url = "https://lilablobssc.blob.core.windows.net/wcs-unzipped"
    category = image_info["category"]
    file_name = image_info["file_name"]

    # If the image file does not exist, try downloading it.
    image_gcs_path = f"{cloud_storage_path}/{file_name}"
    logging.info(f"loading image: {image_gcs_path}")
    gcs = beam.io.gcp.gcsio.GcsIO()
    if not gcs.exists(image_gcs_path):
        image_url = f"{base_url}/{file_name}"
        logging.info(f"image not found, downloading: {image_gcs_path} [{image_url}]")
        try:
            ImageFile.LOAD_TRUNCATED_IMAGES = True
            image = Image.open(io.BytesIO(url_get(image_url)))
            # Pillow can only write these modes as JPEG. Other modes (e.g.
            # RGBA or palette "P") would fail to save, so convert them to RGB.
            if image.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
                image = image.convert("RGB")
            # Encode in memory before touching Cloud Storage. A decode/encode
            # failure then never opens a writer at image_gcs_path. Such a writer
            # could otherwise commit a partial object that the exists() check
            # above would accept as a valid image on later runs.
            jpeg_buffer = io.BytesIO()
            image.save(jpeg_buffer, format="JPEG")
            with gcs.open(image_gcs_path, "w") as f:
                f.write(jpeg_buffer.getvalue())
        except Exception as e:
            # Best effort: an image that can't be fetched is skipped, not fatal.
            logging.warning(f"Failed to load image [{image_url}]: {e}")
            return

    yield category, image_gcs_path


def write_dataset_csv_file(
    dataset_csv_filename: str, images: Iterable[tuple[str, str]]
) -> str:
    """Stores the dataset as a CSV file in Cloud Storage.

    One line is written per image, in the form:
        image_gcs_path,category

    This matches the CSV import layout AI Platform expects, documented at:
        https://cloud.google.com/ai-platform-unified/docs/datasets/prepare-image#csv

    Args:
        dataset_csv_filename: Cloud Storage path the CSV file is written to.
        images: Iterable of (category, image_gcs_path) tuples.

    Returns:
        dataset_csv_filename, passed through unchanged.
    """
    import apache_beam as beam

    logging.info(f"Writing dataset CSV file: {dataset_csv_filename}")
    storage = beam.io.gcp.gcsio.GcsIO()
    with storage.open(dataset_csv_filename, "w") as csv_out:
        for label, gcs_path in images:
            row = f"{gcs_path},{label}\n"
            csv_out.write(row.encode())
    return dataset_csv_filename


def create_dataset(
    dataset_csv_filename: str, project: str, region: str, dataset_name: str
) -> tuple[str, str]:
    """Creates an image dataset for AI Platform and waits until it exists.

    For more information:
        https://cloud.google.com/ai-platform-unified/docs/datasets/create-dataset-api#create-dataset

    Args:
        dataset_csv_filename: Cloud Storage path for the dataset CSV file.
            It is not read here, only passed through to the next step.
        project: Google Cloud Project ID.
        region: Location for AI Platform resources.
        dataset_name: Dataset name.

    Returns:
        A (dataset_full_path, dataset_csv_filename) tuple.
    """
    # AI Platform is served from regional endpoints, and the endpoint must match
    # the location in `parent`. A hard-coded us-central1 endpoint breaks every
    # other region.
    client = DatasetServiceClient(
        client_options={"api_endpoint": f"{region}-aiplatform.googleapis.com"}
    )

    response = client.create_dataset(
        parent=f"projects/{project}/locations/{region}",
        dataset={
            "display_name": dataset_name,
            "metadata_schema_uri": "gs://google-cloud-aiplatform/schema/dataset/metadata/image_1.0.0.yaml",
        },
    )
    logging.info(f"Creating dataset, operation: {response.operation.name}")
    dataset = response.result()  # wait until the operation finishes
    logging.info(f"Dataset created:\n{dataset}")
    return dataset.name, dataset_csv_filename


def import_images_to_dataset(dataset_full_path: str, dataset_csv_filename: str) -> str:
    """Imports the images from the dataset CSV file into the AI Platform dataset.

    Blocks until the import operation finishes.

    For more information:
        https://cloud.google.com/ai-platform-unified/docs/datasets/create-dataset-api#import-data

    Args:
        dataset_full_path: The AI Platform dataset full path, expected to look
            like "projects/{project}/locations/{region}/datasets/{dataset_id}".
        dataset_csv_filename: Cloud Storage path for the dataset CSV file.

    Returns:
        The dataset_full_path.
    """
    # The regional API endpoint must match the dataset's location. Take the
    # location from the resource name. If that segment is missing, fall back
    # to the previously hard-coded region.
    region = "us-central1"
    path_parts = dataset_full_path.split("/")
    if "locations" in path_parts:
        location_index = path_parts.index("locations") + 1
        if location_index < len(path_parts) and path_parts[location_index]:
            region = path_parts[location_index]

    client = DatasetServiceClient(
        client_options={"api_endpoint": f"{region}-aiplatform.googleapis.com"}
    )

    response = client.import_data(
        name=dataset_full_path,
        import_configs=[
            {
                "gcs_source": {"uris": [dataset_csv_filename]},
                "import_schema_uri": "gs://google-cloud-aiplatform/schema/dataset/ioformat/image_classification_single_label_io_format_1.0.0.yaml",
            }
        ],
    )
    logging.info(f"Importing data into dataset, operation: {response.operation.name}")
    _ = response.result()  # wait until the operation finishes
    logging.info(f"Data imported: {dataset_full_path}")
    return dataset_full_path


def train_model(
    dataset_full_path: str,
    project: str,
    region: str,
    model_name: str,
    budget_milli_node_hours: int,
) -> str:
    """Starts a model training job.

    The training pipeline is only created here; this function does not wait
    for training to complete.

    For more information:
        https://cloud.google.com/ai-platform-unified/docs/training/automl-api#training_an_automl_model_using_the_api

    Args:
        dataset_full_path: The AI Platform dataset full path. Its last "/"
            segment is used as the dataset ID.
        project: Google Cloud Project ID.
        region: Location for AI Platform resources.
        model_name: Model name, also used as the training pipeline name.
        budget_milli_node_hours: Training budget.

    Returns:
        The training pipeline full path.
    """
    # Use the regional endpoint matching `parent`, not a hard-coded us-central1.
    client = PipelineServiceClient(
        client_options={
            "api_endpoint": f"{region}-aiplatform.googleapis.com",
        }
    )

    training_pipeline = client.create_training_pipeline(
        parent=f"projects/{project}/locations/{region}",
        training_pipeline={
            "display_name": model_name,
            "input_data_config": {"dataset_id": dataset_full_path.split("/")[-1]},
            "model_to_upload": {"display_name": model_name},
            "training_task_definition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_classification_1.0.0.yaml",
            "training_task_inputs": trainingjob.definition.AutoMlImageClassificationInputs(
                model_type="CLOUD",
                budget_milli_node_hours=budget_milli_node_hours,
            ).to_value(),
        },
    )
    logging.info(f"Training model, training pipeline:\n{training_pipeline}")
    return training_pipeline.name


def url_get(url: str, timeout: float = 60.0) -> bytes:
    """Sends an HTTP GET request with retries.

    Connection errors, timeouts, HTTP 429 and HTTP 5xx responses are retried
    through `with_retries`. Other HTTP error statuses (e.g. 404) are raised
    immediately, because retrying them is not expected to help.

    Args:
        url: URL for the request.
        timeout: Seconds to wait for the server (passed to `requests.get`).
            Without a timeout a stalled connection could block forever.

    Returns:
        The response content bytes.

    Raises:
        requests.HTTPError: If the final response has an HTTP error status.
        requests.RequestException: If the request still fails after retries.
    """
    logging.info(f"url_get: {url}")

    def fetch() -> requests.Response:
        response = requests.get(url, timeout=timeout)
        # Raise (and therefore retry) only on statuses that may be transient.
        if response.status_code == 429 or response.status_code >= 500:
            response.raise_for_status()
        return response

    response = with_retries(fetch)
    # Non-retryable HTTP errors: fail instead of returning an error page body.
    response.raise_for_status()
    return response.content


def with_retries(
    f: Callable[[], a], max_attempts: int = 3, initial_delay: float = 1.0
) -> a:
    """Runs a function with retries, using exponential backoff.

    `f` is called at most `max_attempts` times. After failed attempt number n
    (counting from 0) the function sleeps `initial_delay * (2**n + jitter)`
    seconds before retrying, where jitter is uniformly random in [0, 1).

    For more information:
        https://developers.google.com/drive/api/v3/handle-errors?hl=pt-pt#exponential-backoff

    Args:
        f: A function that doesn't receive any input.
        max_attempts: The maximum number of attempts to run the function.
            Values below 1 are treated as 1, so `f` always runs at least once.
        initial_delay: Scale factor, in seconds, applied to every backoff delay.

    Returns:
        The return value of `f` from the first successful attempt.

    Raises:
        Exception: Whatever `f` raised on its last attempt, once all
            `max_attempts` attempts have failed.
    """
    attempts = max(1, max_attempts)
    for n in range(attempts):
        try:
            return f()
        except Exception as e:
            if n + 1 >= attempts:
                raise  # out of attempts: propagate the last error unchanged
            logging.warning(f"Got an error, {n+1} of {attempts} attempts: {e}")
            time.sleep(initial_delay * (2**n + random.random()))  # 2^n s + jitter


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--cloud-storage-path",
        required=True,
        help="Cloud Storage path to store the AI Platform dataset files.",
    )
    parser.add_argument(
        "--bigquery-dataset",
        required=True,
        help="BigQuery dataset ID for the images database.",
    )
    parser.add_argument(
        "--bigquery-table",
        default="wildlife_images_metadata",
        help="BigQuery table ID for the images database.",
    )
    parser.add_argument(
        "--ai-platform-name-prefix",
        default="wildlife_classifier",
        help="Name prefix for AI Platform resources.",
    )
    parser.add_argument(
        "--min-images-per-class",
        type=int,
        default=50,
        help="Minimum number of images required per class for training.",
    )
    parser.add_argument(
        "--max-images-per-class",
        type=int,
        default=80,
        help="Maximum number of images allowed per class for training.",
    )
    parser.add_argument(
        "--budget-milli-node-hours",
        type=int,
        default=8000,
        help="Training budget, see: https://cloud.google.com/automl/docs/reference/rpc/google.cloud.automl.v1#imageclassificationmodelmetadata",
    )
    # Unknown arguments are forwarded to Beam as pipeline options.
    args, pipeline_args = parser.parse_known_args()

    # A per-class sample never holds more than max_images_per_class images.
    # If the minimum were larger, every class would be discarded. An empty
    # dataset CSV would still be written and, with a name prefix set, sent to
    # AI Platform for training.
    if args.min_images_per_class > args.max_images_per_class:
        parser.error(
            "--min-images-per-class must be less than or equal to "
            "--max-images-per-class"
        )

    pipeline_options = PipelineOptions(pipeline_args, save_main_session=True)
    all_options = pipeline_options.get_all_options()
    project = all_options.get("project")
    if not project:
        parser.error("please provide a Google Cloud project ID with --project")
    region = all_options.get("region")
    if not region:
        parser.error("please provide a Google Cloud compute region with --region")

    run(
        project=project,
        region=region,
        cloud_storage_path=args.cloud_storage_path,
        bigquery_dataset=args.bigquery_dataset,
        bigquery_table=args.bigquery_table,
        ai_platform_name_prefix=args.ai_platform_name_prefix,
        min_images_per_class=args.min_images_per_class,
        max_images_per_class=args.max_images_per_class,
        budget_milli_node_hours=args.budget_milli_node_hours,
        pipeline_options=pipeline_options,
    )
