vectorizer/main.py (59 lines of code) (raw):

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Vectorize sample flower images and upload the embeddings to Cloud Storage.

For one flower category, every image found under the sample-data prefix is
downloaded, run through an EfficientNetB0 model, and the resulting vectors are
written as a JSON-lines file (one ``{"id": ..., "embedding": ...}`` object per
line) to the destination bucket.

When run as a script the configuration is read from environment variables:

* ``DESTINATION_ROOT`` - ``gs://`` URI under which ``flowers/`` is created.
* ``FLOWERS`` - comma-separated flower category names.
* ``CLOUD_RUN_TASK_INDEX`` - index into ``FLOWERS`` selecting the category this
  process handles (defaults to ``0``).
"""

import json
import logging
import os
import posixpath
from tempfile import NamedTemporaryFile

import numpy as np
import tensorflow as tf
from google.cloud import storage

logger = logging.getLogger("vectorizer")

_GCS_SCHEME = "gs://"


def _split_gcs_uri(uri: str) -> tuple[str, str]:
    """Split a ``gs://bucket/path`` URI into ``(bucket, path)``.

    Leading and trailing slashes are removed from the path part, so the result
    can be joined with an object name without producing empty path segments.

    Raises:
        ValueError: if *uri* does not start with ``gs://`` or names no bucket.
    """
    if not uri.startswith(_GCS_SCHEME):
        raise ValueError(f"destination must start with {_GCS_SCHEME!r}: {uri!r}")
    bucket_name, _, base = uri[len(_GCS_SCHEME):].partition("/")
    if not bucket_name:
        raise ValueError(f"destination has no bucket name: {uri!r}")
    return bucket_name, base.strip("/")


class SampleDataVectorizer:
    """Turn the sample images of one flower category into embedding vectors."""

    # Source bucket and object-name prefix of the sample images.
    BUCKET = "cloud-samples-data"
    PREFIX = "ai-platform/flowers/"

    def __init__(self, flower: str, destination: str):
        """Prepare the source listing, the destination bucket and the model.

        Args:
            flower: category name; images are listed under
                ``PREFIX + flower + "/"`` in ``BUCKET``.
            destination: ``gs://bucket/base/path`` URI; the output object is
                written below that base path.

        Raises:
            ValueError: if *destination* is not a usable ``gs://`` URI.
        """
        # Validate first so a bad destination fails before any client or
        # model is created.
        dst_bucket_name, dst_base = _split_gcs_uri(destination)

        self._flower = flower
        self._client = storage.Client()
        self._blobs = self._client.list_blobs(
            self.BUCKET, prefix=f"{self.PREFIX}{flower}/"
        )
        self._dst_bucket = self._client.bucket(dst_bucket_name)
        self._dst_base = dst_base
        # Without the classification head and with average pooling, predict()
        # yields one flat vector per input image.
        self._model = tf.keras.applications.EfficientNetB0(
            include_top=False, pooling="avg"
        )

    def vectorize_and_upload(self) -> None:
        """Vectorize every listed image and upload the result as JSON lines.

        All images are processed before the destination object is opened, so
        a failure while downloading or vectorizing does not leave a partial
        output object behind.
        """
        data = []
        for blob in self._blobs:
            name = blob.name.split("/")[-1]
            if not name:
                # Object names ending in "/" are folder placeholders, not images.
                logger.info("skipping placeholder object %s", blob.name)
                continue
            logger.info("downloading %s", name)
            raw = self._download_as_tensor(blob)
            logger.info("vectorizing %s", name)
            embedding = self._vectorize(raw)
            data.append(
                {
                    "id": f"{self._flower}/{name}",
                    "embedding": embedding,
                }
            )

        if not data:
            logger.warning("no images found for %s", self._flower)

        file_name = f"{self._flower}.json"
        # An empty base path must not produce an object name starting with "/".
        dst_name = f"{self._dst_base}/{file_name}" if self._dst_base else file_name
        dst_blob = self._dst_bucket.blob(dst_name)
        with dst_blob.open(mode="w") as f:
            f.writelines(json.dumps(datapoint) + "\n" for datapoint in data)

    def _download_as_tensor(self, blob: storage.Blob) -> tf.Tensor:
        """Download *blob* to a temporary file and return its raw contents."""
        with NamedTemporaryFile(prefix="vectorizer") as temp:
            blob.download_to_filename(temp.name)
            # Read while still inside the ``with``: the file is removed on exit.
            return tf.io.read_file(temp.name)

    def _vectorize(self, raw: tf.Tensor) -> list[float]:
        """Decode *raw* as a 3-channel JPEG and return the model's output vector."""
        image = tf.image.decode_jpeg(raw, channels=3)
        # The model expects a batch, so wrap the single image in a batch of one.
        batch = np.array([image.numpy()])
        return self._model.predict(batch)[0].tolist()


def main(destination_root: str, flower: str) -> None:
    """Vectorize all sample images of *flower* into ``<destination_root>/flowers``.

    Args:
        destination_root: ``gs://`` URI of the output root,
            e.g. ``gs://my-bucket/index01/embeddings``.
        flower: flower category to process.
    """
    # GCS URIs always use "/" as separator, independent of the host platform.
    destination = posixpath.join(destination_root, "flowers")
    vectorizer = SampleDataVectorizer(flower, destination)
    vectorizer.vectorize_and_upload()
    logger.info("finished successfully 🤗")


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # e.g. gs://my-bucket/index01/embeddings
    destination = os.environ["DESTINATION_ROOT"]
    # Tolerate spaces and stray commas in the list ("daisy, roses,").
    flowers = [
        name.strip() for name in os.environ["FLOWERS"].split(",") if name.strip()
    ]
    task_index = int(os.environ.get("CLOUD_RUN_TASK_INDEX", "0"))
    if not 0 <= task_index < len(flowers):
        raise IndexError(
            f"CLOUD_RUN_TASK_INDEX={task_index} is out of range for "
            f"{len(flowers)} flower(s) in FLOWERS"
        )

    main(destination, flowers[task_index])