projects/imagen-object-changer/imagen_object_changer.py

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Generatively change image objects or background."""

import argparse
import base64
import configparser
import io
import json
import subprocess

import cv2
from google.cloud import vision
import numpy as np
import requests

CONFIG_FILE = "config.ini"


def parse_config_args(config_file):
    """Parses config.ini and command line args.

    Parses the .ini-style config file first, and then the command line args.
    CLI args overwrite default values from the config file.

    Args:
        config_file: Path to config file

    Returns:
        Configparser config

    Raises:
        None
    """
    # Create a config parser object
    config = configparser.ConfigParser()
    config["DEFAULT"] = {
        "mask": "mask.png",
        "invert_mask": "False",
        "output_json": "output.json",
    }
    config["parameters"] = {}

    # Create an argument parser
    parser = argparse.ArgumentParser()

    # Read the configuration file
    read_config = config.read(config_file)
    if not read_config:
        print("{} not found. Using command line args only".format(config_file))
        # Add arguments for each configuration value using hardcoded defaults
        parser.add_argument(
            "--mask",
            default=config["DEFAULT"]["mask"],
            type=str,
            help="Output mask file",
        )
        parser.add_argument(
            "--invert_mask",
            default=config["DEFAULT"]["invert_mask"],
            type=str,
            help="Invert mask; replace the background",
        )
        parser.add_argument(
            "--output_json",
            default=config["DEFAULT"]["output_json"],
            type=str,
            help="Output JSON file name for GenAI response",
        )
        parser.add_argument(
            "--input", required=True, type=str, help="Original image file"
        )
        parser.add_argument(
            "--label",
            required=True,
            type=str,
            help="Object to detect, e.g. car | cat | tree",
        )
        parser.add_argument(
            "--prompt",
            required=True,
            type=str,
            help="Imagen prompt for image generation",
        )
        parser.add_argument(
            "--project_id",
            required=True,
            type=str,
            help="Google Cloud Project ID string",
        )
    else:
        # Add arguments for each cfg value using read file for fallback defaults
        parser.add_argument(
            "--input",
            default=config["parameters"]["input"],
            type=str,
            help="Original image file",
        )
        parser.add_argument(
            "--label",
            default=config["parameters"]["label"],
            type=str,
            help="Object to detect, e.g. car | cat",
        )
        parser.add_argument(
            "--prompt",
            default=config["parameters"]["prompt"],
            type=str,
            help="Imagen prompt for image generation",
        )
        parser.add_argument(
            "--project_id",
            default=config["parameters"]["project_id"],
            type=str,
            help="Google Cloud Project ID string",
        )
        parser.add_argument(
            "--mask",
            default=config["parameters"]["mask"],
            type=str,
            help="Output mask file",
        )
        parser.add_argument(
            "--invert_mask",
            default=config["parameters"]["invert_mask"],
            type=str,
            help="Invert mask; replace the background",
        )
        parser.add_argument(
            "--output_json",
            default=config["parameters"]["output_json"],
            type=str,
            help="Output JSON file name for GenAI response",
        )

    # Parse the arguments
    args = parser.parse_args()

    # Update the configuration values with the command line arguments
    for arg in vars(args):
        config["parameters"][arg] = getattr(args, arg)
    print(dict(config["parameters"]))

    # Check for required values
    if not config["parameters"]["project_id"]:
        print("error: the following arguments are required: --project_id")
        exit(1)

    return config

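# For reference, a minimal config.ini that parse_config_args() above can read.
# All values below are illustrative placeholders, not taken from the original
# repository; only input, label, prompt, and project_id are strictly required
# in the file, since the remaining keys fall back to the DEFAULT section set
# in code.
#
#   [parameters]
#   input = photo.jpg
#   label = car
#   prompt = a red vintage sports car
#   project_id = my-gcp-project
#   mask = mask.png
#   invert_mask = False
#   output_json = output.json
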
def query_vision_api(input_img):
    """Queries Cloud Vision API for object classification.

    Uses Vision API to find objects in the input image.

    Args:
        input_img: user-supplied source image to infer

    Returns:
        objects found by Vision API

    Raises:
        None
    """
    # Create a Vision client object.
    client = vision.ImageAnnotatorClient()

    # Create an Image object from the image content.
    image = vision.Image(content=input_img)

    # Perform object detection on the image.
    objects = client.object_localization(
        image=image
    ).localized_object_annotations

    return objects


def draw_mask_image(input_file, objects, mask_file, label, invert):
    """Draws mask image based on selected objects' coordinates.

    Draws a mask image for Imagen. Iterates through the objects found by
    Vision API, and draws a mask for each object's coordinates, if the object
    label matches the desired one.

    Args:
        input_file: user-supplied original image file name
        objects: list of objects found by Vision API
        mask_file: Imagen mask file name
        label: type of object set by the user
        invert: whether to invert the mask or not, for Imagen

    Returns:
        Boolean flag if the mask file was created or not

    Raises:
        None
    """
    # Create the mask image with same resolution as original
    orig_img = cv2.imread(input_file)
    mask_img = np.zeros((orig_img.shape), dtype=np.uint8)
    if invert:
        mask_img.fill(255)
    h, w, _ = orig_img.shape

    # draw the mask using Vision API bounding boxes
    masks = 0
    for object_ in objects:
        print("{} (confidence: {})".format(object_.name, object_.score))
        if object_.name.casefold() == label.casefold():
            masks += 1
            print("{} found! drawing mask".format(label))
            vertices = object_.bounding_poly.normalized_vertices
            x1 = int(vertices[0].x * w)
            y1 = int(vertices[0].y * h)
            x2 = int(vertices[2].x * w)
            y2 = int(vertices[2].y * h)
            if invert:
                cv2.rectangle(mask_img, (x1, y1), (x2, y2), (0, 0, 0), -1)
            else:
                cv2.rectangle(
                    mask_img, (x1, y1), (x2, y2), (255, 255, 255), -1
                )
            mask_created = True

    if masks > 0:
        cv2.imwrite(mask_file, mask_img)
        print("Wrote mask to: {}".format(mask_file))
    else:
        print("No {} found in {}. Exiting".format(label, input_file))
        mask_created = False

    return mask_created

def query_imagen(prompt, input_img, mask_img, output_json, token, project_id):
    """Queries GenAI Imagen API for mask-based image editing.

    Uses Imagen to replace parts of the original image. The image mask
    restricts the image generation work area.

    Args:
        prompt: the Imagen mask-based editing GenAI prompt
        input_img: the original source image
        mask_img: the editing mask image
        output_json: output file for writing Imagen response
        token: gcloud access token within the GCP project
        project_id: GCP project ID

    Returns:
        Boolean, whether Imagen query was successful or not

    Raises:
        None
    """
    # Base64 encode the image and mask files
    img = base64.b64encode(input_img)
    mask_img = base64.b64encode(mask_img)

    # Create the JSON request body
    data = {
        "instances": [
            {
                "prompt": prompt,
                "image": {"bytesBase64Encoded": img.decode("utf-8")},
                "mask": {
                    "image": {
                        "bytesBase64Encoded": mask_img.decode("utf-8"),
                    }
                },
            }
        ],
        "parameters": {"sampleCount": 4, "sampleImageSize": "1024"},
    }

    # Make the API request
    headers = {
        "Authorization": "Bearer {}".format(token),
        "Content-Type": "application/json",
        "User-Agent": "Mozilla/5.0",
        "Accept-Encoding": "identity",
        "Accept": "*/*",
    }
    url = (
        "https://us-central1-aiplatform.googleapis.com/v1/projects/"
        + project_id
        + "/locations/us-central1/publishers/google/models/"
        + "imagegeneration@002:predict"
    )

    print("Querying Imagen...")
    response = requests.post(
        url,
        headers=headers,
        data=json.dumps(
            data, sort_keys=False, indent=2, separators=(",", ": ")
        ),
        verify=True,
        timeout=None,
    )

    imagen_success = True
    if not response:
        print("No or empty response from Imagen. Exiting")
        imagen_success = False
    else:
        with open(output_json, mode="w", encoding="utf-8") as f:
            f.write(response.text)
        print("Wrote Imagen response to: {}".format(output_json))

    return imagen_success


def get_gcloud_auth_token():
    """Gets the session authentication token using the gcloud tool.

    Uses the gcloud tool to get the current project's authentication token.
    The token is used for authenticating the HTTP POST request to Imagen.
    The function uses subprocess to execute gcloud from the shell.

    Args:
        None

    Returns:
        String, containing the authentication token

    Raises:
        None
    """
    cmd = ("gcloud", "auth", "print-access-token")
    p = subprocess.run(cmd, capture_output=True, text=True, check=False)
    return p.stdout.strip()


def write_images(output_json):
    """Parses Imagen response JSON and writes images to files.

    Parses the output.json response from Imagen and, for each prediction
    payload within, decodes it and writes it as an image file on disk.

    Args:
        output_json: output file for writing Imagen response

    Returns:
        Integer, the number of images written

    Raises:
        None
    """
    with open(output_json, mode="r", encoding="utf-8") as f:
        data = json.load(f)

    i = 0
    for prediction in data["predictions"]:
        image_data = base64.b64decode(prediction["bytesBase64Encoded"])
        filename = "image" + str(i) + ".png"
        with open(filename, mode="wb") as outfile:
            outfile.write(image_data)
        i += 1

    return i

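# For reference, write_images() above assumes the saved Imagen response is
# shaped roughly like the JSON below (base64 payloads truncated; any other
# fields the API returns are simply ignored):
#
#   {
#     "predictions": [
#       {"bytesBase64Encoded": "iVBORw0KGgoAAAANSUhEUgAA..."},
#       {"bytesBase64Encoded": "iVBORw0KGgoAAAANSUhEUgAA..."}
#     ]
#   }
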
def main():
    config = parse_config_args(CONFIG_FILE)
    label = config["parameters"]["label"]
    prompt = config["parameters"]["prompt"]
    input_file = config["parameters"]["input"]
    mask_file = config["parameters"]["mask"]
    output_json = config["parameters"]["output_json"]
    project_id = config["parameters"]["project_id"]
    if "True" in config["parameters"]["invert_mask"]:
        invert = True
    else:
        invert = False

    print(
        "Target label:",
        label,
        "source image:",
        input_file,
        "project_id:",
        project_id,
    )

    # Read the image file into memory.
    with io.open(input_file, mode="rb") as f:
        input_img = f.read()

    # find label location(s) bounding boxes with Cloud Vision API
    objects = query_vision_api(input_img)
    print("Number of objects found: {}".format(len(objects)))
    mask_created = draw_mask_image(
        input_file, objects, mask_file, label, invert
    )

    # Use GenAI Imagen to replace the object(s) or their background
    if mask_created:
        token = get_gcloud_auth_token()

        # Read the mask file into memory.
        with io.open(mask_file, mode="rb") as f:
            mask_img = f.read()

        imagen_success = query_imagen(
            prompt, input_img, mask_img, output_json, token, project_id
        )

        # Extract and write generated images
        if imagen_success:
            written = write_images(output_json)
            if written:
                print("Wrote {} output images".format(written))
    else:
        exit(1)


if __name__ == "__main__":
    main()
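# Example invocation; the file name, label, prompt, and project ID below are
# placeholders. --mask, --invert_mask, and --output_json can be omitted, in
# which case they fall back to config.ini values or the built-in defaults.
#
#   python3 imagen_object_changer.py --input photo.jpg --label car \
#       --prompt "a red vintage sports car" --project_id my-gcp-project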