{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Structured Generation from Images or Documents Using Vision Language Models\n",
"\n",
"We will be using the SmolVLM-Instruct model from HuggingFaceTB to extract structured information from documents. We will run the VLM using the Hugging Face Transformers library and the [Outlines library](https://github.com/dottxt-ai/outlines), which facilitates structured generation based on limiting token sampling probabilities. \n",
"\n",
"> This approach is based on a [Outlines tutorial](https://dottxt-ai.github.io/outlines/latest/cookbook/atomic_caption/).\n",
"\n",
"## Dependencies and imports\n",
"\n",
"First, let's install the necessary libraries."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install accelerate outlines transformers torch flash-attn datasets sentencepiece"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's continue with importing the necessary libraries."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import outlines\n",
"import torch\n",
"\n",
"from datasets import load_dataset\n",
"from outlines.models.transformers_vision import transformers_vision\n",
"from transformers import AutoModelForImageTextToText, AutoProcessor\n",
"from pydantic import BaseModel"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Initialising our model\n",
"\n",
"We will start by initialising our model from [HuggingFaceTB/SmolVLM-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM-Instruct). Outlines expects us to pass in a model class and processor class, so we will make this example a bit more generic by creating a function that returns those. Alternatively, you could look at the model and tokenizer config within the [Hub repo files](https://huggingface.co/HuggingFaceTB/SmolVLM-Instruct/tree/main), and import those classes directly."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some kwargs in processor config are unused and will not have any effect: image_seq_len. \n",
"Some kwargs in processor config are unused and will not have any effect: image_seq_len. \n"
]
}
],
"source": [
"model_name = \"HuggingFaceTB/SmolVLM-Instruct\"\n",
"\n",
"\n",
"def get_model_and_processor_class(model_name: str):\n",
" model = AutoModelForImageTextToText.from_pretrained(model_name)\n",
" processor = AutoProcessor.from_pretrained(model_name)\n",
" classes = model.__class__, processor.__class__\n",
" del model, processor\n",
" return classes\n",
"\n",
"\n",
"model_class, processor_class = get_model_and_processor_class(model_name)\n",
"\n",
"if torch.cuda.is_available():\n",
" device = \"cuda\"\n",
"elif torch.backends.mps.is_available():\n",
" device = \"mps\"\n",
"else:\n",
" device = \"cpu\"\n",
"\n",
"model = transformers_vision(\n",
" model_name,\n",
" model_class=model_class,\n",
" device=device,\n",
" model_kwargs={\"torch_dtype\": torch.bfloat16, \"device_map\": \"auto\"},\n",
" processor_kwargs={\"device\": device},\n",
" processor_class=processor_class,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Structured Generation\n",
"\n",
"Now, we are going to define a function that will define how the output of our model will be structured. We will be using the [openbmb/RLAIF-V-Dataset](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset), which contains a set of images along with questions and their chosen and rejected reponses. This is an okay dataset but we want to create additional text-image-to-text data on top of the images to get our own structured dataset, and potentially fine-tune our model on it. We will use the model to generate a caption, a question and a simple quality tag for the image. "
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"class ImageData(BaseModel):\n",
" quality: str\n",
" description: str\n",
" question: str\n",
"\n",
"structured_generator = outlines.generate.json(model, ImageData)"
]
},
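{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a side note: because Outlines constrains sampling to the JSON schema, we can move the allowed quality values out of the prompt and into the schema itself. A minimal variant (assuming you want the `good`/`okay`/`bad` tags enforced at generation time rather than by instruction alone) could look like this:\n",
"\n",
"```python\n",
"from typing import Literal\n",
"\n",
"from pydantic import BaseModel\n",
"\n",
"\n",
"class ConstrainedImageData(BaseModel):\n",
"    # Literal compiles to a JSON-schema enum, so the sampler can only\n",
"    # produce one of these three strings for the quality field\n",
"    quality: Literal[\"good\", \"okay\", \"bad\"]\n",
"    description: str\n",
"    question: str\n",
"\n",
"\n",
"constrained_generator = outlines.generate.json(model, ConstrainedImageData)\n",
"```"
]
},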
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, let's come up with an extraction prompt."
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"prompt = \"\"\"\n",
"You are an image analysis assisant.\n",
"\n",
"Provide a quality tag, a description and a question.\n",
"\n",
"The quality can either be \"good\", \"okay\" or \"bad\".\n",
"The question should be concise and objective.\n",
"\n",
"Return your response as a valid JSON object.\n",
"\"\"\".strip()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's load our image dataset."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Dataset({\n",
" features: ['ds_name', 'image', 'question', 'chosen', 'rejected', 'origin_dataset', 'origin_split', 'idx', 'image_path'],\n",
" num_rows: 10\n",
"})"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dataset = load_dataset(\"openbmb/RLAIF-V-Dataset\", split=\"train[:10]\")\n",
"dataset"
]
},
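{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before generating anything, it can help to sanity-check a single row. The `image` column is decoded as a PIL image, so displaying one in the notebook is straightforward:\n",
"\n",
"```python\n",
"# Inspect the first example: print the original question and display the image\n",
"sample = dataset[0]\n",
"print(sample[\"question\"])\n",
"sample[\"image\"]\n",
"```"
]
},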
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, let's define a function that will extract the structured information from the image. We will format the prompt using the `apply_chat_template` method and pass it to the model along with the image after that."
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/davidberenstein/Documents/programming/huggingface/cookbook/.venv/lib/python3.11/site-packages/dill/_dill.py:414: PicklingWarning: Cannot locate reference to <class '__main__.ImageData'>.\n",
" StockPickler.save(self, obj, save_persistent_id)\n",
"/Users/davidberenstein/Documents/programming/huggingface/cookbook/.venv/lib/python3.11/site-packages/dill/_dill.py:414: PicklingWarning: Cannot pickle <class '__main__.ImageData'>: __main__.ImageData has recursive self-references that trigger a RecursionError.\n",
" StockPickler.save(self, obj, save_persistent_id)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e1d431b922334b0297195415a11cf68a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Map: 0%| | 0/10 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"Dataset({\n",
" features: ['ds_name', 'image', 'question', 'chosen', 'rejected', 'origin_dataset', 'origin_split', 'idx', 'image_path', 'synthetic_question', 'synthetic_description', 'synthetic_quality'],\n",
" num_rows: 10\n",
"})"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def extract(row):\n",
" messages = [\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": [{\"type\": \"image\"}, {\"type\": \"text\", \"text\": prompt}],\n",
" },\n",
" ]\n",
"\n",
" formatted_prompt = model.processor.apply_chat_template(\n",
" messages, add_generation_prompt=True\n",
" )\n",
"\n",
" result = structured_generator(formatted_prompt, [row[\"image\"]])\n",
" row['synthetic_question'] = result.question\n",
" row['synthetic_description'] = result.description\n",
" row['synthetic_quality'] = result.quality\n",
" return row\n",
"\n",
"\n",
"dataset = dataset.map(lambda x: extract(x))\n",
"dataset"
]
},
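{
"cell_type": "markdown",
"metadata": {},
"source": [
"Constrained decoding guarantees schema-shaped output, but generation can still fail in edge cases, for example when the token limit is hit mid-object. When mapping over a larger dataset, a small hypothetical wrapper around `extract` keeps a single bad row from aborting the whole run:\n",
"\n",
"```python\n",
"def safe_extract(row):\n",
"    # Fall back to empty fields instead of crashing the map call\n",
"    try:\n",
"        return extract(row)\n",
"    except Exception:\n",
"        row[\"synthetic_question\"] = None\n",
"        row[\"synthetic_description\"] = None\n",
"        row[\"synthetic_quality\"] = None\n",
"        return row\n",
"\n",
"\n",
"# dataset = dataset.map(safe_extract)\n",
"```"
]
},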
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's now push our new dataset to the Hub."
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ab88b1b3bb1441498788bdc2c2b4cf30",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Uploading the dataset shards: 0%| | 0/1 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e5e359d02ede43959e92a9e5626f9ffd",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Map: 0%| | 0/10 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9f7f07dad09f47c5a8dfdeba403845f6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Creating parquet from Arrow format: 0%| | 0/1 [00:00<?, ?ba/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e47600c765b64b55aa6f93e9cf5d077e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"README.md: 0%| | 0.00/719 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"CommitInfo(commit_url='https://huggingface.co/datasets/davidberenstein1957/structured-generation-information-extraction-vlms-openbmb-RLAIF-V-Dataset/commit/373d6a25e8301077773fc6a37899b1598cf6f8cd', commit_message='Upload dataset', commit_description='', oid='373d6a25e8301077773fc6a37899b1598cf6f8cd', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/davidberenstein1957/structured-generation-information-extraction-vlms-openbmb-RLAIF-V-Dataset', endpoint='https://huggingface.co', repo_type='dataset', repo_id='davidberenstein1957/structured-generation-information-extraction-vlms-openbmb-RLAIF-V-Dataset'), pr_revision=None, pr_num=None)"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dataset.push_to_hub(\"davidberenstein1957/structured-generation-information-extraction-vlms-openbmb-RLAIF-V-Dataset\", split=\"train\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<iframe\n",
" src=\"https://huggingface.co/datasets/davidberenstein1957/structured-generation-information-extraction-vlms-openbmb-RLAIF-V-Dataset/embed/viewer/default/train?row=3\"\n",
" frameborder=\"0\"\n",
" width=\"100%\"\n",
" height=\"560px\"\n",
"></iframe>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The results are not perfect, but they are a good starting point to continue exploring with different models and prompts!"
]
},
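{
"cell_type": "markdown",
"metadata": {},
"source": [
"One simple way to act on the quality tags, assuming you only want the best rows for a later fine-tuning run, is a `datasets` filter:\n",
"\n",
"```python\n",
"# Keep only the rows the model tagged as \"good\"\n",
"good_rows = dataset.filter(lambda row: row[\"synthetic_quality\"] == \"good\")\n",
"```"
]
},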
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Conclusion\n",
"\n",
"We've seen how to extract structured information from documents using a vision language model. We can use similar extractive methods to extract structured information from documents, using somehting like `pdf2image` to convert the document to images and running information extraction on each image pdf of the page.\n",
"\n",
"```python\n",
"pdf_path = \"path/to/your/pdf/file.pdf\"\n",
"pages = convert_from_path(pdf_path)\n",
"for page in pages:\n",
" extract_objects = extract_objects(page, prompt)\n",
"```\n",
"\n",
"## Next Steps\n",
"\n",
"- Take a look at the [Outlines](https://github.com/outlines-ai/outlines) library for more information on how to use it. Explore the different methods and parameters.\n",
"- Explore extraction on your own usecase with your own model.\n",
"- Use a different method of extracting structured information from documents."
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}