# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generatively change image objects or background."""
import argparse
import base64
import configparser
import io
import json
import subprocess
import cv2
from google.cloud import vision
import numpy as np
import requests
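# Optional .ini file providing default values for the command-line arguments.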
CONFIG_FILE = "config.ini"
def parse_config_args(config_file):
"""Parses config.ini and command line args.
Parses first the .ini -style config file, and then command line args.
CLI args overwrite default values from the config file.
Args:
config_file: Path to config file
Returns:
Configparser config
Raises:
None
"""
# Create a config parser object
config = configparser.ConfigParser()
config["DEFAULT"] = {
"mask": "mask.png",
"invert_mask": "False",
"output_json": "output.json",
}
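    # Values in [DEFAULT] act as fallbacks; the [parameters] section is
    # filled in below with the merged config-file and command-line values.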
config["parameters"] = {}
# Create an argument parser
parser = argparse.ArgumentParser()
# Read the configuration file
read_config = config.read(config_file)
if not read_config:
print("{} not found. Using command line args only".format(config_file))
# Add arguments for each configuration value using hardcoded defaults
parser.add_argument(
"--mask",
default=config["DEFAULT"]["mask"],
type=str,
help="Output mask file",
)
parser.add_argument(
"--invert_mask",
default=config["DEFAULT"]["invert_mask"],
type=str,
help="Invert mask; replace the background",
)
parser.add_argument(
"--output_json",
default=config["DEFAULT"]["output_json"],
type=str,
help="Output JSON file name for GenAI response",
)
parser.add_argument(
"--input", required=True, type=str, help="Original image file"
)
parser.add_argument(
"--label",
required=True,
type=str,
help="Object to detect e.g car | cat | tree",
)
parser.add_argument(
"--prompt",
required=True,
type=str,
help="Imagen prompt for image generation",
)
parser.add_argument(
"--project_id",
required=True,
type=str,
help="Google Cloud Project ID string",
)
else:
        # Add arguments for each config value, with the file values as defaults
parser.add_argument(
"--input",
default=config["parameters"]["input"],
type=str,
help="Original image file",
)
parser.add_argument(
"--label",
default=config["parameters"]["label"],
type=str,
help="Object to detect e.g car | cat",
)
parser.add_argument(
"--prompt",
default=config["parameters"]["prompt"],
type=str,
help="Imagen prompt for image generation",
)
parser.add_argument(
"--project_id",
default=config["parameters"]["project_id"],
type=str,
help="Google Cloud Project ID string",
)
parser.add_argument(
"--mask",
default=config["parameters"]["mask"],
type=str,
help="Output mask file",
)
parser.add_argument(
"--invert_mask",
default=config["parameters"]["invert_mask"],
type=str,
help="Invert mask; replace the background",
)
parser.add_argument(
"--output_json",
default=config["parameters"]["output_json"],
type=str,
help="Output JSON file name for GenAI response",
)
# Parse the arguments
args = parser.parse_args()
# Update the configuration values with the command line arguments
for arg in vars(args):
config["parameters"][arg] = getattr(args, arg)
print(dict(config["parameters"]))
# Check for required values
if not config["parameters"]["project_id"]:
print("error: the following arguments are required: --project_id")
exit(1)
return config
def query_vision_api(input_img):
"""Queries Cloud Vision API for object classification.
Uses Vision API to find objects in the input image.
Args:
        input_img: user-supplied source image content to analyze
Returns:
objects found by Vision API
Raises:
None
"""
# Create a Vision client object.
client = vision.ImageAnnotatorClient()
# Create an Image object from the image content.
image = vision.Image(content=input_img)
# Perform object detection on the image.
objects = client.object_localization(
image=image
).localized_object_annotations
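    # Each LocalizedObjectAnnotation has a name, a confidence score, and a
    # bounding_poly whose vertices are normalized to the [0, 1] range.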
return objects
def draw_mask_image(input_file, objects, mask_file, label, invert):
"""Draws mask image based on selected objects' coordinates.
Draws a mask image for Imagen. Iterates through the objects
found by Vision API, and draws a mask for each object's coordinates,
if the object label matches the desired one.
Args:
input_file: user-supplied original image file name
objects: list of objects found by Vision API
mask_file: Imagen mask file name
label: type of object set by the user
invert: whether to invert the mask or not, for Imagen
Returns:
Boolean flag if the mask file was created or not
Raises:
None
"""
# Create the mask image with same resolution as original
orig_img = cv2.imread(input_file)
    mask_img = np.zeros(orig_img.shape, dtype=np.uint8)
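    # Mask convention: white (255) pixels mark the area Imagen is allowed to
    # regenerate, black (0) pixels are preserved, so an inverted mask replaces
    # the background instead of the detected objects.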
if invert:
mask_img.fill(255)
h, w, _ = orig_img.shape
# draw the mask using Vision API bounding boxes
masks = 0
for object_ in objects:
print("{} (confidence: {})".format(object_.name, object_.score))
if object_.name.casefold() == label.casefold():
masks += 1
print("{} found! drawing mask".format(label))
vertices = object_.bounding_poly.normalized_vertices
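            # normalized_vertices are in the [0, 1] range and are scaled here
            # to pixel coordinates; vertex 0 is taken as the top-left and
            # vertex 2 as the bottom-right corner of the bounding box.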
x1 = int(vertices[0].x * w)
y1 = int(vertices[0].y * h)
x2 = int(vertices[2].x * w)
y2 = int(vertices[2].y * h)
if invert:
cv2.rectangle(mask_img, (x1, y1), (x2, y2), (0, 0, 0), -1)
else:
cv2.rectangle(mask_img, (x1, y1), (x2, y2), (255, 255, 255), -1)
    mask_created = masks > 0
    if mask_created:
        cv2.imwrite(mask_file, mask_img)
        print("Wrote mask to: {}".format(mask_file))
    else:
        print("No {} found in {}. Exiting".format(label, input_file))
return mask_created
def query_imagen(prompt, input_img, mask_img, output_json, token, project_id):
"""Queries GenAI Imagen API for mask-based image editing.
Uses Imagen to replace parts of the original image. The image mask
restricts the image generation work area.
Args:
prompt: the Imagen mask-based editing GenAI prompt
input_img: the original source image
mask_img: the editing mask image
output_json: output file for writing Imagen response
        token: gcloud access token for the GCP project
project_id: GCP project ID
Returns:
Boolean, whether Imagen query was successful or not
Raises:
None
"""
# Base64 encode the image and mask files
img = base64.b64encode(input_img)
mask_img = base64.b64encode(mask_img)
# Create the JSON request body
data = {
"instances": [
{
"prompt": prompt,
"image": {"bytesBase64Encoded": img.decode("utf-8")},
"mask": {
"image": {
"bytesBase64Encoded": mask_img.decode("utf-8"),
}
},
}
],
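        # sampleCount requests four candidate edits from Imagen;
        # sampleImageSize sets the requested output resolution.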
"parameters": {"sampleCount": 4, "sampleImageSize": "1024"},
}
# Make the API request
headers = {
"Authorization": "Bearer {}".format(token),
"Content-Type": "application/json",
"User-Agent": "Mozilla/5.0",
"Accept-Encoding": "identity",
"Accept": "*/*",
}
url = (
"https://us-central1-aiplatform.googleapis.com/v1/projects/"
+ project_id
+ "/locations/us-central1/publishers/google/models/"
+ "imagegeneration@002:predict"
)
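    # The region (us-central1) and model version (imagegeneration@002) are
    # hardcoded in the endpoint URL; change them here if your project uses a
    # different location or model.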
print("Querying Imagen...")
response = requests.post(
url,
headers=headers,
data=json.dumps(
data, sort_keys=False, indent=2, separators=(",", ": ")
),
verify=True,
timeout=None,
)
imagen_success = True
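    # A requests.Response evaluates as False when the HTTP status is 4xx/5xx.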
if not response:
print("No or empty response from Imagen. Exiting")
imagen_success = False
else:
with open(output_json, mode="w", encoding="utf-8") as f:
f.write(response.text)
print("Wrote Imagen response to: {}".format(output_json))
return imagen_success
def get_gcloud_auth_token():
"""Gets the session authentication token using Gcloud tool.
Uses the Gcloud tool to get the current project's authentication token.
The token is used for authenticating the HTTP POST request to Imagen.
The function uses subprocess to execute gcloud from the shell.
Args: None
Returns:
String, containing the authentication token
Raises:
None
"""
cmd = ("gcloud", "auth", "print-access-token")
p = subprocess.run(cmd, capture_output=True, text=True, check=False)
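    # check=False: if gcloud is missing or not authenticated this returns an
    # empty string, and the later Imagen request will then fail to authorize.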
return p.stdout.strip()
def write_images(output_json):
"""Parses Imagen response JSON and writes images to files.
    Parses the JSON response from Imagen, decodes each prediction
    payload, and writes it as an image file on disk.
    Args:
        output_json: file containing the Imagen JSON response
    Returns:
        Number of images written to disk
Raises:
None
"""
with open(output_json, mode="r", encoding="utf-8") as f:
data = json.load(f)
i = 0
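    # Each prediction carries one generated image as a base64-encoded payload.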
for prediction in data["predictions"]:
image_data = base64.b64decode(prediction["bytesBase64Encoded"])
filename = "image" + str(i) + ".png"
with open(filename, mode="wb") as outfile:
outfile.write(image_data)
i += 1
return i
def main():
config = parse_config_args(CONFIG_FILE)
label = config["parameters"]["label"]
prompt = config["parameters"]["prompt"]
input_file = config["parameters"]["input"]
mask_file = config["parameters"]["mask"]
output_json = config["parameters"]["output_json"]
project_id = config["parameters"]["project_id"]
if "True" in config["parameters"]["invert_mask"]:
invert = True
else:
invert = False
print(
"Target label:",
label,
"source image:",
input_file,
"project_id:",
project_id,
)
# Read the image file into memory.
with io.open(input_file, mode="rb") as f:
input_img = f.read()
# find label location(s) bounding boxes with Cloud Vision API
objects = query_vision_api(input_img)
print("Number of objects found: {}".format(len(objects)))
mask_created = draw_mask_image(
input_file, objects, mask_file, label, invert
)
# Use GenAI Imagen to replace the object(s) or their background
if mask_created:
token = get_gcloud_auth_token()
# Read the mask file into memory.
with io.open(mask_file, mode="rb") as f:
mask_img = f.read()
imagen_success = query_imagen(
prompt, input_img, mask_img, output_json, token, project_id
)
# Extract and write generated images
if imagen_success:
written = write_images(output_json)
if written:
print("Wrote {} output images".format(written))
else:
exit(1)
if __name__ == "__main__":
main()