seed/make_qa_only_image_pdf.ipynb

{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Generate QnA synthetic dataset from a PDF - Image-heavy PDF\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "import os\n", "from dotenv import load_dotenv\n", "load_dotenv()\n", "\n", "aoai_api_endpoint = os.getenv(\"AZURE_OPENAI_ENDPOINT\")\n", "aoai_api_key = os.getenv(\"AZURE_OPENAI_API_KEY\")\n", "aoai_api_version = os.getenv(\"AZURE_OPENAI_API_VERSION\")\n", "aoai_deployment_name = os.getenv(\"AZURE_OPENAI_DEPLOYMENT_NAME\")\n", "\n", "if not aoai_api_version:\n", " aoai_api_version = os.getenv(\"OPENAI_API_VERSION\")\n", "if not aoai_deployment_name:\n", " aoai_deployment_name = os.getenv(\"DEPLOYMENT_NAME\")\n", " \n", "print(f\"aoai_api_endpoint: {aoai_api_endpoint}\")\n", "print(f\"aoai_api_key: {aoai_api_key}\")\n", "print(f\"aoai_api_version: {aoai_api_version}\")\n", "print(f\"aoai_deployment_name: {aoai_deployment_name}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Split PDF into individual pages\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# import fitz\n", "# raw_data_dir = \"../raw_data\"\n", "\n", "# file_path = f\"{raw_data_dir}/pdf/img-advertising-generated-by-ai.pdf\"\n", "\n", "# # Open the first PDF document\n", "# doc1 = fitz.open(file_path)\n", "# #split_pages = [(4, 122), (4, 194)]\n", "# split_pages = [(1, 5)]\n", "# for idx, s in enumerate(split_pages):\n", "# # Create a new empty PDF document\n", "# doc2 = fitz.open()\n", "\n", "# # Insert the first 2 pages of doc1 into doc2\n", "# doc2.insert_pdf(doc1, from_page=s[0], to_page=s[1])\n", "\n", "# # Save the modified document\n", "# doc2.save(f\"{raw_data_dir}/part{idx}.pdf\")\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os, shutil, random\n", "from langchain_community.document_loaders.csv_loader import CSVLoader\n", "from util.preprocess import remove_short_sentences, remove_small_images\n", "from util.common_utils import get_language_code\n", "\n", "image_dir = \"./image\"\n", "raw_data_dir = \"../raw_data\"\n", "\n", "if os.path.isdir(image_dir): shutil.rmtree(image_dir)\n", "os.makedirs(image_dir, exist_ok=True)\n", "\n", "DOMAIN = \"Advertising\"\n", "LANGUAGE = \"English\" # You can change your language here. 
e.g., \"Korean\", \"Japanese\", \"Chinese\"\n", "LANGUAGE_CODE = get_language_code(LANGUAGE)\n", "print(f\"Domain: {DOMAIN}, Language: {LANGUAGE}, Language Code: {LANGUAGE_CODE}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Preprocess PDF file (image part)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import fitz\n", "from glob import glob\n", "\n", "file_path = f\"{raw_data_dir}/pdf/img-advertising-generated-by-ai.pdf\"\n", "\n", "doc = fitz.open(file_path)\n", "clip_x, clip_y = 30, 30\n", "\n", "for i, page in enumerate(doc):\n", " x, y, w, h = page.rect\n", " clip = fitz.Rect(x+clip_x, y+clip_y, w-clip_x, h-clip_y)\n", " page.set_cropbox(clip)\n", " pix = page.get_pixmap()\n", " pix.save(f\"{image_dir}/page_{i:03d}.jpg\")\n", "\n", "images = sorted(glob(os.path.join(image_dir, \"*.jpg\")))\n", "max_tokens = 1024" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from langchain.schema.output_parser import StrOutputParser\n", "from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate\n", "\n", "from langchain_openai import AzureChatOpenAI\n", "llm = AzureChatOpenAI(\n", " temperature=0, \n", " max_tokens=max_tokens,\n", " openai_api_version=aoai_api_version,\n", " azure_deployment=aoai_deployment_name \n", ")\n", "\n", "human_prompt_main = f\"Given image, give a concise summary in {LANGUAGE}. Don't insert any XML tag such as <text> and </text> when answering.\"\n", "\n", "system_prompt = \"You are an assistant tasked with describing table or image, specialized in Smartphone product.\"\n", "system_message_template = SystemMessagePromptTemplate.from_template(system_prompt)\n", "human_prompt = [\n", " {\n", " \"type\": \"image_url\",\n", " \"image_url\": {\n", " \"url\": \"data:image/png;base64,\" + \"{image_base64}\",\n", " },\n", " },\n", " {\n", " \"type\": \"text\",\n", " \"text\": human_prompt_main\n", " },\n", "]\n", "human_message_template = HumanMessagePromptTemplate.from_template(human_prompt)\n", "\n", "prompt = ChatPromptTemplate.from_messages(\n", " [\n", " system_message_template,\n", " human_message_template\n", " ]\n", ")\n", "\n", "summarize_chain = prompt | llm | StrOutputParser()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "from util.preprocess import encode_image_base64\n", "#images = glob(os.path.join(image_dir, \"*.jpg\"))\n", "base64_images = [encode_image_base64(img_path) for img_path in images]\n", "image_summaries = summarize_chain.batch(base64_images, {\"max_concurrency\": 8})\n", "image_summaries = remove_short_sentences(image_summaries)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "image_summaries[:3]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Preprocess PDF file (text part)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from langchain_community.document_loaders import PyMuPDFLoader\n", "\n", "loader = PyMuPDFLoader(file_path)\n", "docs = loader.load()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import re\n", "import tiktoken\n", "from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter\n", "\n", "tokenizer = tiktoken.get_encoding('o200k_base')\n", "\n", "# create the length function\n", "def tiktoken_len(text):\n", 
" tokens = tokenizer.encode(\n", " text,\n", " disallowed_special=()\n", " )\n", " return len(tokens)\n", "\n", "text_splitter = RecursiveCharacterTextSplitter(\n", " # Set a really small chunk size, just to show.\n", " chunk_size=1024,\n", " chunk_overlap=100,\n", " length_function=tiktoken_len,\n", " separators=[\n", " \"\\n\\n\",\n", " \"\\n\",\n", " \" \",\n", " \".\",\n", " \",\",\n", " \"\\u200b\", # Zero-width space\n", " \"\\uff0c\", # Fullwidth comma\n", " \"\\u3001\", # Ideographic comma\n", " \"\\uff0e\", # Fullwidth full stop\n", " \"\\u3002\", # Ideographic full stop\n", " \"\",\n", " ], \n", ")\n", "\n", "# split_docs = text_splitter.split_documents(docs)\n", "# print(f'Number of splitted docs: {len(split_docs)}')\n", "\n", "a = [re.sub(' +', ' ', doc.page_content) for doc in docs]\n", "joined_docs = '\\n\\n'.join(a)\n", "\n", "split_docs = text_splitter.split_text(joined_docs)\n", "print(f'Number of splitted docs: {len(split_docs)}')\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "split_docs[:5]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Construct QnA Pairs\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from langchain_openai import AzureChatOpenAI\n", "from langchain_core.runnables import RunnablePassthrough, RunnableLambda\n", "from langchain_core.output_parsers import JsonOutputParser\n", "from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n", "\n", "from util.qa_pair import get_qna_prompt_template, QAPair\n", "\n", "llm = AzureChatOpenAI(\n", " temperature=0, \n", " max_tokens=1024,\n", " openai_api_version=aoai_api_version,\n", " azure_deployment=aoai_deployment_name \n", ")\n", "\n", "parser = JsonOutputParser(pydantic_object=QAPair)\n", "prompt = get_qna_prompt_template(LANGUAGE)\n", "\n", "chain = prompt | llm | parser" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "input_batch = []\n", "\n", "for doc in split_docs:\n", " dic = {\"context\": doc, \"domain\": DOMAIN, \"num_questions\": \"3\"}\n", " input_batch.append(dic)\n", "\n", "#for doc in image_summaries_tiktoken:\n", "for doc in image_summaries:\n", " dic = {\"context\": doc, \"domain\": DOMAIN, \"num_questions\": \"3\"}\n", " input_batch.append(dic)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "qa_pair = chain.batch(input_batch, {\"max_concurrency\": 8})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Save to jsonl\n", "\n", "---\n", "\n", "If you want to augment dataset, you can try Evovle-Instruct or other data augmentation techniques.<br>\n", "Please refer to `../evolve-instruct` and `../glan-instruct` for more details.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import json\n", "from util.common_utils import convert_to_oai_format, save_jsonl\n", "\n", "output_dir = './dataset'\n", "os.makedirs(output_dir, exist_ok=True)\n", "\n", "system_prompt_msg = f\"\"\"You are the SME (Subject Matter Expert) in {DOMAIN}. Please answer the questions accurately. 
### Construct QnA Pairs

```python
from langchain_openai import AzureChatOpenAI
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.output_parsers import JsonOutputParser
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

from util.qa_pair import get_qna_prompt_template, QAPair

llm = AzureChatOpenAI(
    temperature=0,
    max_tokens=1024,
    openai_api_version=aoai_api_version,
    azure_deployment=aoai_deployment_name
)

parser = JsonOutputParser(pydantic_object=QAPair)
prompt = get_qna_prompt_template(LANGUAGE)

chain = prompt | llm | parser
```

```python
input_batch = []

# Text chunks
for doc in split_docs:
    dic = {"context": doc, "domain": DOMAIN, "num_questions": "3"}
    input_batch.append(dic)

# Image summaries
for doc in image_summaries:
    dic = {"context": doc, "domain": DOMAIN, "num_questions": "3"}
    input_batch.append(dic)
```

```python
%%time
qa_pair = chain.batch(input_batch, {"max_concurrency": 8})
```

### Save to jsonl

---

If you want to augment the dataset, you can try Evolve-Instruct or other data augmentation techniques.<br>
Please refer to `../evolve-instruct` and `../glan-instruct` for more details.

```python
import json
from util.common_utils import convert_to_oai_format, save_jsonl

output_dir = './dataset'
os.makedirs(output_dir, exist_ok=True)

system_prompt_msg = f"""You are the SME (Subject Matter Expert) in {DOMAIN}. Please answer the questions accurately. If the question is in {LANGUAGE}, write your answer in {LANGUAGE}."""

save_filename = "advertising"
oai_qa_pair = convert_to_oai_format(qa_pair, system_prompt_msg=system_prompt_msg)

# save_jsonl(qa_pair, f"{output_dir}/{save_filename}.jsonl")
save_jsonl(oai_qa_pair, f"{output_dir}/{save_filename}-oai.jsonl")
```

```python
# Clean up temporary working directories (the ./dataset folder is kept)
!rm -rf pdf_image_tmp outputs_tmp image
```
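As a final check, you can read the saved file back and inspect one record. The exact schema is produced by `convert_to_oai_format` in `util.common_utils`, so nothing below assumes specific field names; with the standard OpenAI chat fine-tuning layout you would expect a `messages` list of system/user/assistant turns.

```python
import json

# Optional: peek at the first record of the generated dataset.
with open("./dataset/advertising-oai.jsonl", encoding="utf-8") as f:
    first_record = json.loads(f.readline())

print(json.dumps(first_record, indent=2, ensure_ascii=False))
```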