seed/make_qa_only_image_pdf.ipynb
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Generate QnA synthetic dataset from a PDF - Image-heavy PDF\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import os\n",
"from dotenv import load_dotenv\n",
"load_dotenv()\n",
"\n",
"aoai_api_endpoint = os.getenv(\"AZURE_OPENAI_ENDPOINT\")\n",
"aoai_api_key = os.getenv(\"AZURE_OPENAI_API_KEY\")\n",
"aoai_api_version = os.getenv(\"AZURE_OPENAI_API_VERSION\")\n",
"aoai_deployment_name = os.getenv(\"AZURE_OPENAI_DEPLOYMENT_NAME\")\n",
"\n",
"if not aoai_api_version:\n",
" aoai_api_version = os.getenv(\"OPENAI_API_VERSION\")\n",
"if not aoai_deployment_name:\n",
" aoai_deployment_name = os.getenv(\"DEPLOYMENT_NAME\")\n",
" \n",
"print(f\"aoai_api_endpoint: {aoai_api_endpoint}\")\n",
"print(f\"aoai_api_key: {aoai_api_key}\")\n",
"print(f\"aoai_api_version: {aoai_api_version}\")\n",
"print(f\"aoai_deployment_name: {aoai_deployment_name}\")"
]
},
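{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell above reads the Azure OpenAI settings from a `.env` file. A minimal example with placeholder values (replace them with your own resource details; the API version shown is just one commonly used value):\n",
"\n",
"```\n",
"AZURE_OPENAI_ENDPOINT=https://<your-resource-name>.openai.azure.com/\n",
"AZURE_OPENAI_API_KEY=<your-api-key>\n",
"AZURE_OPENAI_API_VERSION=2024-02-15-preview\n",
"AZURE_OPENAI_DEPLOYMENT_NAME=<your-gpt-4o-deployment>\n",
"```\n"
]
},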
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Split PDF into individual pages\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# import fitz\n",
"# raw_data_dir = \"../raw_data\"\n",
"\n",
"# file_path = f\"{raw_data_dir}/pdf/img-advertising-generated-by-ai.pdf\"\n",
"\n",
"# # Open the first PDF document\n",
"# doc1 = fitz.open(file_path)\n",
"# #split_pages = [(4, 122), (4, 194)]\n",
"# split_pages = [(1, 5)]\n",
"# for idx, s in enumerate(split_pages):\n",
"# # Create a new empty PDF document\n",
"# doc2 = fitz.open()\n",
"\n",
"# # Insert the first 2 pages of doc1 into doc2\n",
"# doc2.insert_pdf(doc1, from_page=s[0], to_page=s[1])\n",
"\n",
"# # Save the modified document\n",
"# doc2.save(f\"{raw_data_dir}/part{idx}.pdf\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os, shutil, random\n",
"from langchain_community.document_loaders.csv_loader import CSVLoader\n",
"from util.preprocess import remove_short_sentences, remove_small_images\n",
"from util.common_utils import get_language_code\n",
"\n",
"image_dir = \"./image\"\n",
"raw_data_dir = \"../raw_data\"\n",
"\n",
"if os.path.isdir(image_dir): shutil.rmtree(image_dir)\n",
"os.makedirs(image_dir, exist_ok=True)\n",
"\n",
"DOMAIN = \"Advertising\"\n",
"LANGUAGE = \"English\" # You can change your language here. e.g., \"Korean\", \"Japanese\", \"Chinese\"\n",
"LANGUAGE_CODE = get_language_code(LANGUAGE)\n",
"print(f\"Domain: {DOMAIN}, Language: {LANGUAGE}, Language Code: {LANGUAGE_CODE}\")"
]
},
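{
"cell_type": "markdown",
"metadata": {},
"source": [
"`get_language_code` is a repo utility from `util/common_utils.py`. A minimal sketch of the contract this notebook relies on (the mapping below is an illustrative assumption, not the repo's actual implementation):\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical stand-in for util.common_utils.get_language_code, for reference only.\n",
"# def get_language_code(language: str) -> str:\n",
"#     # Map a language name to its ISO 639-1 code; default to English.\n",
"#     mapping = {\"English\": \"en\", \"Korean\": \"ko\", \"Japanese\": \"ja\", \"Chinese\": \"zh\"}\n",
"#     return mapping.get(language, \"en\")\n"
]
},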
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Preprocess PDF file (image part)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import fitz\n",
"from glob import glob\n",
"\n",
"file_path = f\"{raw_data_dir}/pdf/img-advertising-generated-by-ai.pdf\"\n",
"\n",
"doc = fitz.open(file_path)\n",
"clip_x, clip_y = 30, 30\n",
"\n",
"for i, page in enumerate(doc):\n",
" x, y, w, h = page.rect\n",
" clip = fitz.Rect(x+clip_x, y+clip_y, w-clip_x, h-clip_y)\n",
" page.set_cropbox(clip)\n",
" pix = page.get_pixmap()\n",
" pix.save(f\"{image_dir}/page_{i:03d}.jpg\")\n",
"\n",
"images = sorted(glob(os.path.join(image_dir, \"*.jpg\")))\n",
"max_tokens = 1024"
]
},
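{
"cell_type": "markdown",
"metadata": {},
"source": [
"`page.get_pixmap()` renders at the default 72 DPI, which can make small text on image-heavy pages hard for the vision model to read. If the summaries below look poor, you can re-render at a higher zoom factor; the 2x zoom here is a suggestion, not a tuned value:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: re-render the cropped pages at 2x resolution (~144 DPI).\n",
"# zoom = 2.0\n",
"# for i, page in enumerate(doc):\n",
"#     pix = page.get_pixmap(matrix=fitz.Matrix(zoom, zoom))\n",
"#     pix.save(f\"{image_dir}/page_{i:03d}.jpg\")\n"
]
},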
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.schema.output_parser import StrOutputParser\n",
"from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate\n",
"\n",
"from langchain_openai import AzureChatOpenAI\n",
"llm = AzureChatOpenAI(\n",
" temperature=0, \n",
" max_tokens=max_tokens,\n",
" openai_api_version=aoai_api_version,\n",
" azure_deployment=aoai_deployment_name \n",
")\n",
"\n",
"human_prompt_main = f\"Given image, give a concise summary in {LANGUAGE}. Don't insert any XML tag such as <text> and </text> when answering.\"\n",
"\n",
"system_prompt = \"You are an assistant tasked with describing table or image, specialized in Smartphone product.\"\n",
"system_message_template = SystemMessagePromptTemplate.from_template(system_prompt)\n",
"human_prompt = [\n",
" {\n",
" \"type\": \"image_url\",\n",
" \"image_url\": {\n",
" \"url\": \"data:image/png;base64,\" + \"{image_base64}\",\n",
" },\n",
" },\n",
" {\n",
" \"type\": \"text\",\n",
" \"text\": human_prompt_main\n",
" },\n",
"]\n",
"human_message_template = HumanMessagePromptTemplate.from_template(human_prompt)\n",
"\n",
"prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" system_message_template,\n",
" human_message_template\n",
" ]\n",
")\n",
"\n",
"summarize_chain = prompt | llm | StrOutputParser()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"from util.preprocess import encode_image_base64\n",
"#images = glob(os.path.join(image_dir, \"*.jpg\"))\n",
"base64_images = [encode_image_base64(img_path) for img_path in images]\n",
"image_summaries = summarize_chain.batch(base64_images, {\"max_concurrency\": 8})\n",
"image_summaries = remove_short_sentences(image_summaries)"
]
},
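{
"cell_type": "markdown",
"metadata": {},
"source": [
"`encode_image_base64` and `remove_short_sentences` come from `util/preprocess.py`. If you want to follow the logic without the repo utilities, minimal stand-ins could look like the sketch below (an assumption about their behavior, not the actual implementations):\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical stand-ins for the util.preprocess helpers, for reference only.\n",
"# import base64\n",
"#\n",
"# def encode_image_base64(image_path: str) -> str:\n",
"#     # Read the image file and return its contents as a base64 string.\n",
"#     with open(image_path, \"rb\") as f:\n",
"#         return base64.b64encode(f.read()).decode(\"utf-8\")\n",
"#\n",
"# def remove_short_sentences(texts, min_words=5):\n",
"#     # Drop summaries too short to serve as useful QnA context.\n",
"#     return [t for t in texts if t and len(t.split()) >= min_words]\n"
]
},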
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"image_summaries[:3]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Preprocess PDF file (text part)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.document_loaders import PyMuPDFLoader\n",
"\n",
"loader = PyMuPDFLoader(file_path)\n",
"docs = loader.load()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import re\n",
"import tiktoken\n",
"from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter\n",
"\n",
"tokenizer = tiktoken.get_encoding('o200k_base')\n",
"\n",
"# create the length function\n",
"def tiktoken_len(text):\n",
" tokens = tokenizer.encode(\n",
" text,\n",
" disallowed_special=()\n",
" )\n",
" return len(tokens)\n",
"\n",
"text_splitter = RecursiveCharacterTextSplitter(\n",
" # Set a really small chunk size, just to show.\n",
" chunk_size=1024,\n",
" chunk_overlap=100,\n",
" length_function=tiktoken_len,\n",
" separators=[\n",
" \"\\n\\n\",\n",
" \"\\n\",\n",
" \" \",\n",
" \".\",\n",
" \",\",\n",
" \"\\u200b\", # Zero-width space\n",
" \"\\uff0c\", # Fullwidth comma\n",
" \"\\u3001\", # Ideographic comma\n",
" \"\\uff0e\", # Fullwidth full stop\n",
" \"\\u3002\", # Ideographic full stop\n",
" \"\",\n",
" ], \n",
")\n",
"\n",
"# split_docs = text_splitter.split_documents(docs)\n",
"# print(f'Number of splitted docs: {len(split_docs)}')\n",
"\n",
"a = [re.sub(' +', ' ', doc.page_content) for doc in docs]\n",
"joined_docs = '\\n\\n'.join(a)\n",
"\n",
"split_docs = text_splitter.split_text(joined_docs)\n",
"print(f'Number of splitted docs: {len(split_docs)}')\n"
]
},
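{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check on the chunking: every chunk should stay at or below the 1024-token `chunk_size` under the same `tiktoken_len` measure the splitter uses.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Inspect token counts of the split chunks to confirm they fit the token budget.\n",
"chunk_token_counts = [tiktoken_len(chunk) for chunk in split_docs]\n",
"print(f\"min/max/avg tokens per chunk: {min(chunk_token_counts)} / {max(chunk_token_counts)} / {sum(chunk_token_counts) / len(chunk_token_counts):.1f}\")\n"
]
},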
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"split_docs[:5]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Construct QnA Pairs\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain_openai import AzureChatOpenAI\n",
"from langchain_core.runnables import RunnablePassthrough, RunnableLambda\n",
"from langchain_core.output_parsers import JsonOutputParser\n",
"from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n",
"\n",
"from util.qa_pair import get_qna_prompt_template, QAPair\n",
"\n",
"llm = AzureChatOpenAI(\n",
" temperature=0, \n",
" max_tokens=1024,\n",
" openai_api_version=aoai_api_version,\n",
" azure_deployment=aoai_deployment_name \n",
")\n",
"\n",
"parser = JsonOutputParser(pydantic_object=QAPair)\n",
"prompt = get_qna_prompt_template(LANGUAGE)\n",
"\n",
"chain = prompt | llm | parser"
]
},
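{
"cell_type": "markdown",
"metadata": {},
"source": [
"`get_qna_prompt_template` and `QAPair` come from `util/qa_pair.py`. The JSON parser only needs a Pydantic schema describing the expected output; a sketch of what such a schema could look like follows (field names and descriptions are illustrative assumptions, not the repo's actual class):\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical sketch of a QAPair-style schema, for reference only.\n",
"# from pydantic import BaseModel, Field\n",
"#\n",
"# class QAPair(BaseModel):\n",
"#     question: str = Field(description=\"Question generated from the given context\")\n",
"#     answer: str = Field(description=\"Answer to the generated question\")\n"
]
},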
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"input_batch = []\n",
"\n",
"for doc in split_docs:\n",
" dic = {\"context\": doc, \"domain\": DOMAIN, \"num_questions\": \"3\"}\n",
" input_batch.append(dic)\n",
"\n",
"#for doc in image_summaries_tiktoken:\n",
"for doc in image_summaries:\n",
" dic = {\"context\": doc, \"domain\": DOMAIN, \"num_questions\": \"3\"}\n",
" input_batch.append(dic)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"qa_pair = chain.batch(input_batch, {\"max_concurrency\": 8})"
]
},
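{
"cell_type": "markdown",
"metadata": {},
"source": [
"If a few contexts fail JSON parsing, the whole batch call raises. One alternative is to pass `return_exceptions=True` to `batch()` and filter out the failures afterwards:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Alternative: keep going when individual contexts fail to parse.\n",
"# results = chain.batch(input_batch, {\"max_concurrency\": 8}, return_exceptions=True)\n",
"# qa_pair = [r for r in results if not isinstance(r, Exception)]\n"
]
},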
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Save to jsonl\n",
"\n",
"---\n",
"\n",
"If you want to augment dataset, you can try Evovle-Instruct or other data augmentation techniques.<br>\n",
"Please refer to `../evolve-instruct` and `../glan-instruct` for more details.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"from util.common_utils import convert_to_oai_format, save_jsonl\n",
"\n",
"output_dir = './dataset'\n",
"os.makedirs(output_dir, exist_ok=True)\n",
"\n",
"system_prompt_msg = f\"\"\"You are the SME (Subject Matter Expert) in {DOMAIN}. Please answer the questions accurately. If the question is in {LANGUAGE}, write your answer in {LANGUAGE}.\"\"\"\n",
"\n",
"save_filename = \"advertising\"\n",
"oai_qa_pair = convert_to_oai_format(qa_pair, system_prompt_msg=system_prompt_msg)\n",
"\n",
"#save_jsonl(qa_pair, f\"{output_dir}/{save_filename}.jsonl\")\n",
"save_jsonl(oai_qa_pair, f\"{output_dir}/{save_filename}-oai.jsonl\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!rm -rf pdf_image_tmp outputs_tmp image"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.10 - SDK v2",
"language": "python",
"name": "python310-sdkv2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 4
}