notebooks/123_clipseg-zero-shot.ipynb (567 lines of code) (raw):

{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "cfN3e_6aUnDe" }, "source": [ "# Using CLIPSeg with Hugging Face Transformers" ] }, { "cell_type": "markdown", "metadata": { "id": "JE8CVZ86W2Ru" }, "source": [ "Using Hugging Face Transformers, you can easily download and run a pre-trained CLIPSeg model on your images. Let’s start by installing transformers." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sYigMTweWttc" }, "outputs": [], "source": [ "!pip install -q transformers" ] }, { "cell_type": "markdown", "metadata": { "id": "10YAoBoIW3zn" }, "source": [ "To download the model, simply instantiate it." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "eCdgxg4SXg_r" }, "outputs": [], "source": [ "from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation\n", "\n", "processor = CLIPSegProcessor.from_pretrained(\"CIDAS/clipseg-rd64-refined\")\n", "model = CLIPSegForImageSegmentation.from_pretrained(\"CIDAS/clipseg-rd64-refined\")" ] }, { "cell_type": "markdown", "metadata": { "id": "UOAKiQmEW_UE" }, "source": [ "Now we can load an image to try out the segmentation. We'll choose a picture of a delicious breakfast taken by [Calum Lewis](https://unsplash.com/@calumlewis)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "P4CQJ0dlW__z" }, "outputs": [], "source": [ "from PIL import Image\n", "import requests\n", "\n", "url = \"https://unsplash.com/photos/8Nc_oQsc2qQ/download?ixid=MnwxMjA3fDB8MXxhbGx8fHx8fHx8fHwxNjcxMjAwNzI0&force=true&w=640\"\n", "image = Image.open(requests.get(url, stream=True).raw)\n", "image" ] }, { "cell_type": "markdown", "metadata": { "id": "wjnCAKnaU04t" }, "source": [ "## Text prompting" ] }, { "cell_type": "markdown", "metadata": { "id": "qc4IuietEqr1" }, "source": [ "Let’s start by defining some text categories we want to segment." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cPvX5IE5EnCE" }, "outputs": [], "source": [ "prompts = [\"cutlery\", \"pancakes\", \"blueberries\", \"orange juice\"]" ] }, { "cell_type": "markdown", "metadata": { "id": "OVWkP-j0W8WO" }, "source": [ "Now that we have our inputs, we can process them and input them to the model." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XfCjKfgtW1_M" }, "outputs": [], "source": [ "import torch\n", "\n", "inputs = processor(text=prompts, images=[image] * len(prompts), padding=\"max_length\", return_tensors=\"pt\")\n", "# predict\n", "with torch.no_grad():\n", " outputs = model(**inputs)\n", "preds = outputs.logits.unsqueeze(1)" ] }, { "cell_type": "markdown", "metadata": { "id": "NYpV9474r-pS" }, "source": [ "Finally, let’s visualize the output." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xPney6KSS2lq" }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "\n", "_, ax = plt.subplots(1, len(prompts) + 1, figsize=(3*(len(prompts) + 1), 4))\n", "[a.axis('off') for a in ax.flatten()]\n", "ax[0].imshow(image)\n", "[ax[i+1].imshow(torch.sigmoid(preds[i][0])) for i in range(len(prompts))];\n", "[ax[i+1].text(0, -15, prompt) for i, prompt in enumerate(prompts)];" ] }, { "cell_type": "markdown", "metadata": { "id": "SyAF0WLGVBUM" }, "source": [ "## Visual prompting" ] }, { "cell_type": "markdown", "metadata": { "id": "1_JceT-gVZKz" }, "source": [ "As mentioned before, we can also use images as the input prompts (i.e. in place of the category names). This can be especially useful if it's not easy to describe the thing you want to segment. For this example, we'll use a picture of a coffee cup taken by [Daniel Hooper](https://unsplash.com/@dan_fromyesmorecontent)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Zqei2oKlVbe8" }, "outputs": [], "source": [ "url = \"https://unsplash.com/photos/Ki7sAc8gOGE/download?ixid=MnwxMjA3fDB8MXxzZWFyY2h8MTJ8fGNvZmZlJTIwdG8lMjBnb3xlbnwwfHx8fDE2NzExOTgzNDQ&force=true&w=640\"\n", "prompt = Image.open(requests.get(url, stream=True).raw)\n", "prompt" ] }, { "cell_type": "markdown", "metadata": { "id": "i0okRt3OVmMu" }, "source": [ "We can now process the input image and prompt image and input them to the model." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0ldgagAIVsuM" }, "outputs": [], "source": [ "encoded_image = processor(images=[image], return_tensors=\"pt\")\n", "encoded_prompt = processor(images=[prompt], return_tensors=\"pt\")\n", "# predict\n", "with torch.no_grad():\n", " outputs = model(**encoded_image, conditional_pixel_values=encoded_prompt.pixel_values)\n", "preds = outputs.logits.unsqueeze(1)\n", "preds = torch.transpose(preds, 0, 1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GlHJVJ6ZZFr2" }, "outputs": [], "source": [ "_, ax = plt.subplots(1, 2, figsize=(6, 4))\n", "[a.axis('off') for a in ax.flatten()]\n", "ax[0].imshow(image)\n", "ax[1].imshow(torch.sigmoid(preds[0]))" ] }, { "cell_type": "markdown", "metadata": { "id": "ATaBdwatfCsa" }, "source": [ "Let’s try one last time by using the visual prompting tips described in the paper, i.e. cropping the image and darkening the background." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "VTXmXSN4dSVK" }, "outputs": [], "source": [ "url = \"https://i.imgur.com/mRSORqz.jpg\"\n", "alternative_prompt = Image.open(requests.get(url, stream=True).raw)\n", "alternative_prompt\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dcwEclf1db-p" }, "outputs": [], "source": [ "encoded_alternative_prompt = processor(images=[alternative_prompt], return_tensors=\"pt\")\n", "# predict\n", "with torch.no_grad():\n", " outputs = model(**encoded_image, conditional_pixel_values=encoded_alternative_prompt.pixel_values)\n", "preds = outputs.logits.unsqueeze(1)\n", "preds = torch.transpose(preds, 0, 1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "15PDzaDedimS" }, "outputs": [], "source": [ "_, ax = plt.subplots(1, 2, figsize=(6, 4))\n", "[a.axis('off') for a in ax.flatten()]\n", "ax[0].imshow(image)\n", "ax[1].imshow(torch.sigmoid(preds[0]))" ] }, { "cell_type": "markdown", "metadata": { "id": "zDkUTz9hWUMv" }, "source": [ "In this case, the result is pretty much the same. This is probably because the coffee cup was already separated well from the background in the original image." ] }, { "cell_type": "markdown", "metadata": { "id": "W2cEony9VPEL" }, "source": [ "# Using CLIPSeg to pre-label images on Segments.ai" ] }, { "cell_type": "markdown", "metadata": { "id": "60IWZMWiXPF1" }, "source": [ "As you can see, the results from CLIPSeg are a little fuzzy and very low-res. If we want to obtain better results, you can fine-tune a state-of-the-art segmentation model, as explained in [our previous blogpost](https://huggingface.co/blog/fine-tune-segformer). To finetune the model, we'll need labeled data. In this section, we'll show you how you can use CLIPSeg to create some rough segmentation masks and then refine them on [Segments.ai](https://segments.ai/?utm_source=hf&utm_medium=colab&utm_campaign=clipseg), the best labeling platform for image segmentation. " ] }, { "cell_type": "markdown", "metadata": { "id": "gnAtvgwo7Sj8" }, "source": [ "First, create an account at [https://segments.ai/join](https://segments.ai/join?utm_source=hf&utm_medium=colab&utm_campaign=clipseg) and install the Segments Python SDK. Then you can initialize the Segments.ai Python client using an API key. This key can be found on [the account page](https://segments.ai/account)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "EOM6DjBf-cK2" }, "outputs": [], "source": [ "!pip install -q segments-ai" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "BWwGgvJ-7yZR" }, "outputs": [], "source": [ "from segments import SegmentsClient\n", "from getpass import getpass\n", "\n", "api_key = getpass('Enter your API key: ')\n", "segments_client = SegmentsClient(api_key)" ] }, { "cell_type": "markdown", "metadata": { "id": "r6WZ89IA9jgw" }, "source": [ "Next, let's load an image from a dataset using the Segments client. We'll use the [a2d2 self-driving dataset](https://www.a2d2.audi/a2d2/en.html). You can also create your own dataset by following [these instructions](https://docs.segments.ai/tutorials/getting-started)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PmEl1Vjd-Sh1" }, "outputs": [], "source": [ "samples = segments_client.get_samples(\"admin-tobias/clipseg\")\n", "\n", "# Use the last image as an example\n", "sample = samples[1]\n", "image = Image.open(requests.get(sample.attributes.image.url, stream=True).raw)\n", "image" ] }, { "cell_type": "markdown", "metadata": { "id": "h90uH9Ga_aGh" }, "source": [ "We also need to get the category names from the dataset attributes." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xYT6W6Pz_eol" }, "outputs": [], "source": [ "dataset = segments_client.get_dataset(\"admin-tobias/clipseg\")\n", "category_names = [category.name for category in dataset.task_attributes.categories]" ] }, { "cell_type": "markdown", "metadata": { "id": "3NiMOEwS_suH" }, "source": [ "Now we can use CLIPSeg on the image as before. This time, we'll also scale up the outputs so that they match the input image's size." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "t2OxnTQl_KrA" }, "outputs": [], "source": [ "from torch import nn\n", "\n", "inputs = processor(text=category_names, images=[image] * len(category_names), padding=\"max_length\", return_tensors=\"pt\")\n", "\n", "# predict\n", "with torch.no_grad():\n", " outputs = model(**inputs)\n", "\n", "# resize the outputs\n", "preds = nn.functional.interpolate(\n", " outputs.logits.unsqueeze(1),\n", " size=(image.size[1], image.size[0]),\n", " mode=\"bilinear\"\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "NrlA7MdDB3Hv" }, "source": [ "And we can visualize the results again." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vIaM7ZPVB2uy" }, "outputs": [], "source": [ "len_cats = len(category_names)\n", "_, ax = plt.subplots(1, len_cats + 1, figsize=(3*(len_cats + 1), 4))\n", "[a.axis('off') for a in ax.flatten()]\n", "ax[0].imshow(image)\n", "[ax[i+1].imshow(torch.sigmoid(preds[i][0])) for i in range(len_cats)];\n", "[ax[i+1].text(0, -15, category_name) for i, category_name in enumerate(category_names)];" ] }, { "cell_type": "markdown", "metadata": { "id": "lypMLy9dCO--" }, "source": [ "Now we have to combine the predictions to a single segmentated image. We'll simply do this by taking the category with the greatest sigmoid value for each patch. We'll also make sure that all the values under a certain threshold do not count." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-N0CorwmCvHj" }, "outputs": [], "source": [ "threshold = 0.1\n", "\n", "flat_preds = torch.sigmoid(preds.squeeze()).reshape((preds.shape[0], -1))\n", "\n", "# Initialize a dummy \"unlabeled\" mask with the threshold\n", "flat_preds_with_treshold = torch.full((preds.shape[0] + 1, flat_preds.shape[-1]), threshold)\n", "flat_preds_with_treshold[1:preds.shape[0]+1,:] = flat_preds\n", "\n", "# Get the top mask index for each pixel\n", "inds = torch.topk(flat_preds_with_treshold, 1, dim=0).indices.reshape((preds.shape[-2], preds.shape[-1]))" ] }, { "cell_type": "markdown", "metadata": { "id": "4wH2YfBwN5wn" }, "source": [ "Let's quickly visualize the result." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "bDL52Mo2H6QG" }, "outputs": [], "source": [ "plt.imshow(inds)" ] }, { "cell_type": "markdown", "metadata": { "id": "Q5qoXhvgN9ax" }, "source": [ "Lastly, we can upload the prediction to Segments.ai. To do that, we'll first convert the bitmap to a png file, then we'll upload this file to the Segments, and finally we'll add the label to the sample." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0vUihdVtG9KR" }, "outputs": [], "source": [ "from segments.utils import bitmap2file\n", "import numpy as np\n", "\n", "inds_np = inds.numpy().astype(np.uint32)\n", "unique_inds = np.unique(inds_np).tolist()\n", "f = bitmap2file(inds_np, is_segmentation_bitmap=True)\n", "\n", "asset = segments_client.upload_asset(f, \"clipseg_prediction.png\")\n", "\n", "attributes = {\n", " 'format_version': '0.1',\n", " 'annotations': [{\"id\": i, \"category_id\": i} for i in unique_inds if i != 0],\n", " 'segmentation_bitmap': { 'url': asset.url },\n", " }\n", "\n", "segments_client.add_label(sample.uuid, 'ground-truth', attributes)" ] }, { "cell_type": "markdown", "metadata": { "id": "fXoeWXZUt2wR" }, "source": [ "If you take a look at the [uploaded prediction on Segments.ai](https://segments.ai/admin-tobias/clipseg/samples/71a80d39-8cf3-4768-a097-e81e0b677517/ground-truth), you can see that it's not perfect. However, you can manually correct the biggest mistakes, and then you can use the corrected dataset to train a better model than CLIPSeg." ] } ], "metadata": { "accelerator": "GPU", "colab": { "provenance": [] }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.6.0 (default, Feb 23 2022, 10:45:45) \n[GCC 9.3.0]" }, "vscode": { "interpreter": { "hash": "c1330ff811d49988c6bd3f6d30257d79d8234a7d606d1a91e87e277837a80053" } } }, "nbformat": 4, "nbformat_minor": 0 }