{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 画像とテキストをVLMで処理する\n",
"\n",
"このノートブックでは、`HuggingFaceTB/SmolVLM-Instruct` 4bit量子化モデルを使用して、以下のようなさまざまなマルチモーダルタスクを実行する方法を示します:\n",
"- 視覚質問応答(VQA):画像の内容に基づいて質問に答える。\n",
"- テキスト認識(OCR):画像内のテキストを抽出して解釈する。\n",
"- ビデオ理解:連続フレーム解析を通じてビデオを説明する。\n",
"\n",
"プロンプトを効果的に構築することで、シーン理解、ドキュメント分析、動的視覚推論など、多くのアプリケーションにモデルを活用できます。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Google Colabでの要件をインストール\n",
"# !pip install transformers datasets trl huggingface_hub bitsandbytes\n",
"\n",
"# Hugging Faceに認証\n",
"from huggingface_hub import notebook_login\n",
"notebook_login()\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"`low_cpu_mem_usage` was None, now default to True since model is quantized.\n",
"You shouldn't move a model that is dispatched using accelerate hooks.\n",
"Some kwargs in processor config are unused and will not have any effect: image_seq_len. \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'longest_edge': 1536}\n"
]
}
],
"source": [
"import torch, PIL\n",
"from transformers import AutoProcessor, AutoModelForVision2Seq, BitsAndBytesConfig\n",
"from transformers.image_utils import load_image\n",
"\n",
"device = (\n",
" \"cuda\"\n",
" if torch.cuda.is_available()\n",
" else \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n",
")\n",
"\n",
"quantization_config = BitsAndBytesConfig(load_in_4bit=True)\n",
"model_name = \"HuggingFaceTB/SmolVLM-Instruct\"\n",
"model = AutoModelForVision2Seq.from_pretrained(\n",
" model_name,\n",
" quantization_config=quantization_config,\n",
").to(device)\n",
"processor = AutoProcessor.from_pretrained(\"HuggingFaceTB/SmolVLM-Instruct\")\n",
"\n",
"print(processor.image_processor.size)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 画像の処理"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"まず、画像のキャプションを生成し、画像に関する質問に答えることから始めましょう。また、複数の画像を処理する方法も探ります。\n",
"### 1. 単一画像のキャプション生成"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<img src=\"https://cdn.pixabay.com/photo/2024/11/20/09/14/christmas-9210799_1280.jpg\"/>"
],
"text/plain": [
"<IPython.core.display.Image object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<img src=\"https://cdn.pixabay.com/photo/2024/11/23/08/18/christmas-9218404_1280.jpg\"/>"
],
"text/plain": [
"<IPython.core.display.Image object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from IPython.display import Image, display\n",
"\n",
"image_url1 = \"https://cdn.pixabay.com/photo/2024/11/20/09/14/christmas-9210799_1280.jpg\"\n",
"display(Image(url=image_url1))\n",
"\n",
"image_url2 = \"https://cdn.pixabay.com/photo/2024/11/23/08/18/christmas-9218404_1280.jpg\"\n",
"display(Image(url=image_url2))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/duydl/Miniconda3/envs/py310/lib/python3.10/site-packages/bitsandbytes/nn/modules.py:451: UserWarning: Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"['User:<image>Can you describe the image?\\nAssistant: The image is a scene of a person walking in a forest. The person is wearing a coat and a cap. The person is holding the hand of another person. The person is walking on a path. The path is covered with dry leaves. The background of the image is a forest with trees.']\n"
]
}
],
"source": [
"# 画像を1枚読み込む\n",
"image1 = load_image(image_url1)\n",
"\n",
"# 入力メッセージを作成\n",
"messages = [\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": [\n",
" {\"type\": \"image\"},\n",
" {\"type\": \"text\", \"text\": \"Can you describe the image?\"}\n",
" ]\n",
" },\n",
"]\n",
"\n",
"# 入力を準備\n",
"prompt = processor.apply_chat_template(messages, add_generation_prompt=True)\n",
"inputs = processor(text=prompt, images=[image1], return_tensors=\"pt\")\n",
"inputs = inputs.to(device)\n",
"\n",
"# 出力を生成\n",
"generated_ids = model.generate(**inputs, max_new_tokens=500)\n",
"generated_texts = processor.batch_decode(\n",
" generated_ids,\n",
" skip_special_tokens=True,\n",
")\n",
"\n",
"print(generated_texts)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2. 複数画像の比較\n",
"モデルは複数の画像を処理して比較することもできます。2つの画像の共通のテーマを判断してみましょう。"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['User:<image>What event do they both represent?\\nAssistant: Christmas.']\n"
]
}
],
"source": [
"\n",
"# 画像を読み込む\n",
"image2 = load_image(image_url2)\n",
"\n",
"# 入力メッセージを作成\n",
"messages = [\n",
" # {\n",
" # \"role\": \"user\",\n",
" # \"content\": [\n",
" # {\"type\": \"image\"},\n",
" # {\"type\": \"image\"},\n",
" # {\"type\": \"text\", \"text\": \"Can you describe the two images?\"}\n",
" # ]\n",
" # },\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": [\n",
" {\"type\": \"image\"},\n",
" {\"type\": \"image\"},\n",
" {\"type\": \"text\", \"text\": \"What event do they both represent?\"}\n",
" ]\n",
" },\n",
"]\n",
"\n",
"# 入力を準備\n",
"prompt = processor.apply_chat_template(messages, add_generation_prompt=True)\n",
"inputs = processor(text=prompt, images=[image1, image2], return_tensors=\"pt\")\n",
"inputs = inputs.to(device)\n",
"\n",
"# 出力を生成\n",
"generated_ids = model.generate(**inputs, max_new_tokens=500)\n",
"generated_texts = processor.batch_decode(\n",
" generated_ids,\n",
" skip_special_tokens=True,\n",
")\n",
"\n",
"print(generated_texts)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 🔠 テキスト認識(OCR)\n",
"VLMは画像内のテキストを認識して解釈することもでき、ドキュメント分析などのタスクに適しています。\n",
"テキストが密集している画像で実験してみてください。"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<img src=\"https://cdn.pixabay.com/photo/2020/11/30/19/23/christmas-5792015_960_720.png\"/>"
],
"text/plain": [
"<IPython.core.display.Image object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"['User:<image>What is written?\\nAssistant: MERRY CHRISTMAS AND A HAPPY NEW YEAR']\n"
]
}
],
"source": [
"document_image_url = \"https://cdn.pixabay.com/photo/2020/11/30/19/23/christmas-5792015_960_720.png\"\n",
"display(Image(url=document_image_url))\n",
"\n",
"# ドキュメント画像を読み込む\n",
"document_image = load_image(document_image_url)\n",
"\n",
"# 分析のための入力メッセージを作成\n",
"messages = [\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": [\n",
" {\"type\": \"image\"},\n",
" {\"type\": \"text\", \"text\": \"What is written?\"}\n",
" ]\n",
" }\n",
"]\n",
"\n",
"# 入力を準備\n",
"prompt = processor.apply_chat_template(messages, add_generation_prompt=True)\n",
"inputs = processor(text=prompt, images=[document_image], return_tensors=\"pt\")\n",
"inputs = inputs.to(device)\n",
"\n",
"# 出力を生成\n",
"generated_ids = model.generate(**inputs, max_new_tokens=500)\n",
"generated_texts = processor.batch_decode(\n",
" generated_ids,\n",
" skip_special_tokens=True,\n",
")\n",
"\n",
"print(generated_texts)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## ビデオの処理\n",
"\n",
"視覚言語モデル(VLM)は、キーフレームを抽出し、時間順に推論することで間接的にビデオを処理できます。VLMは専用のビデオモデルの時間的モデリング機能を欠いていますが、それでも以下のことが可能です:\n",
"- サンプリングされたフレームを順次分析することで、アクションやイベントを説明する。\n",
"- 代表的なキーフレームに基づいてビデオに関する質問に答える。\n",
"- 複数のフレームのテキスト記述を組み合わせてビデオの内容を要約する。\n",
"\n",
"例を使って実験してみましょう:\n",
"\n",
"<video width=\"640\" height=\"360\" controls>\n",
" <source src=\"https://cdn.pixabay.com/video/2023/10/28/186794-879050032_large.mp4\" type=\"video/mp4\">\n",
" Your browser does not support the video tag.\n",
"</video>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# !pip install opencv-python"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import Video\n",
"import cv2\n",
"import numpy as np\n",
"\n",
"def extract_frames(video_path, max_frames=50, target_size=None):\n",
" cap = cv2.VideoCapture(video_path)\n",
" if not cap.isOpened():\n",
" raise ValueError(f\"Could not open video: {video_path}\")\n",
" \n",
" total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))\n",
" frame_indices = np.linspace(0, total_frames - 1, max_frames, dtype=int)\n",
"\n",
" frames = []\n",
" for idx in frame_indices:\n",
" cap.set(cv2.CAP_PROP_POS_FRAMES, idx)\n",
" ret, frame = cap.read()\n",
" if ret:\n",
" frame = PIL.Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))\n",
" if target_size:\n",
" frames.append(resize_and_crop(frame, target_size))\n",
" else:\n",
" frames.append(frame)\n",
" cap.release()\n",
" return frames\n",
"\n",
"def resize_and_crop(image, target_size):\n",
" width, height = image.size\n",
" scale = target_size / min(width, height)\n",
" image = image.resize((int(width * scale), int(height * scale)), PIL.Image.Resampling.LANCZOS)\n",
" left = (image.width - target_size) // 2\n",
" top = (image.height - target_size) // 2\n",
" return image.crop((left, top, left + target_size, top + target_size))\n",
"\n",
"# ビデオリンク\n",
"video_link = \"https://cdn.pixabay.com/video/2023/10/28/186794-879050032_large.mp4\""
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Response: User: Following are the frames of a video in temporal order.<image>Describe what the woman is doing.\n",
"Assistant: The woman is hanging an ornament on a Christmas tree.\n"
]
}
],
"source": [
"question = \"Describe what the woman is doing.\"\n",
"\n",
"def generate_response(model, processor, frames, question):\n",
"\n",
" image_tokens = [{\"type\": \"image\"} for _ in frames]\n",
" messages = [\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": [{\"type\": \"text\", \"text\": \"Following are the frames of a video in temporal order.\"}, *image_tokens, {\"type\": \"text\", \"text\": question}]\n",
" }\n",
" ]\n",
" inputs = processor(\n",
" text=processor.apply_chat_template(messages, add_generation_prompt=True),\n",
" images=frames,\n",
" return_tensors=\"pt\"\n",
" ).to(model.device)\n",
"\n",
" outputs = model.generate(\n",
" **inputs, max_new_tokens=100, num_beams=5, temperature=0.7, do_sample=True, use_cache=True\n",
" )\n",
" return processor.decode(outputs[0], skip_special_tokens=True)\n",
"\n",
"\n",
"# ビデオからフレームを抽出\n",
"frames = extract_frames(video_link, max_frames=15, target_size=384)\n",
"\n",
"processor.image_processor.size = (384, 384)\n",
"processor.image_processor.do_resize = False\n",
"# 応答を生成\n",
"response = generate_response(model, processor, frames, question)\n",
"\n",
"# 結果を表示\n",
"# print(\"Question:\", question)\n",
"print(\"Response:\", response)"
]
},
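{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "The third approach listed above, captioning each frame separately and then combining the captions, is not covered by the example above, so the cell below gives a minimal sketch of it. It reuses the `model`, `processor`, and `frames` objects defined earlier in this notebook; the prompts, the frame stride, and the token limits are illustrative choices rather than part of the original example, and the text-only summary step assumes the processor accepts text input without images."
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Sketch: per-frame captioning followed by a text-only summary.\n",
  "# Assumes `model`, `processor`, and `frames` from the cells above; the prompts\n",
  "# and the frame stride below are illustrative choices.\n",
  "def caption_frame(frame):\n",
  "    messages = [\n",
  "        {\n",
  "            \"role\": \"user\",\n",
  "            \"content\": [\n",
  "                {\"type\": \"image\"},\n",
  "                {\"type\": \"text\", \"text\": \"Describe this frame in one sentence.\"}\n",
  "            ]\n",
  "        }\n",
  "    ]\n",
  "    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)\n",
  "    inputs = processor(text=prompt, images=[frame], return_tensors=\"pt\").to(model.device)\n",
  "    ids = model.generate(**inputs, max_new_tokens=50)\n",
  "    text = processor.decode(ids[0], skip_special_tokens=True)\n",
  "    # Keep only the assistant's reply from the decoded chat transcript\n",
  "    return text.split(\"Assistant:\")[-1].strip()\n",
  "\n",
  "# Caption a subset of the sampled frames to keep generation time reasonable\n",
  "captions = [caption_frame(frame) for frame in frames[::5]]\n",
  "\n",
  "# Combine the captions into a single text prompt and ask for a summary\n",
  "summary_request = (\n",
  "    \"These are captions of video frames in temporal order: \"\n",
  "    + \" | \".join(captions)\n",
  "    + \" Summarize what happens in the video in one sentence.\"\n",
  ")\n",
  "summary_messages = [\n",
  "    {\"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": summary_request}]}\n",
  "]\n",
  "prompt = processor.apply_chat_template(summary_messages, add_generation_prompt=True)\n",
  "inputs = processor(text=prompt, return_tensors=\"pt\").to(model.device)\n",
  "summary_ids = model.generate(**inputs, max_new_tokens=80)\n",
  "print(processor.decode(summary_ids[0], skip_special_tokens=True))"
 ]
},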
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 💐 完了です!\n",
"\n",
"このノートブックでは、マルチモーダルタスクのプロンプトをフォーマットする方法など、視覚言語モデル(VLM)を使用する方法を示しました。ここで説明した手順に従うことで、VLMとそのアプリケーションを実験できます。\n",
"\n",
"### 次のステップ\n",
"- VLMのさらなる使用例を試してみてください。\n",
"- 同僚と協力して、彼らのプルリクエスト(PR)をレビューしてください。\n",
"- 新しい使用例、例、概念を導入するために、問題を開いたりPRを提出したりして、このコース資料の改善に貢献してください。\n",
"\n",
"楽しい探求を!🌟"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "py310",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.16"
}
},
"nbformat": 4,
"nbformat_minor": 2
}