seed/make_qa_only_image_multiple_pdf.ipynb
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Generate QnA synthetic dataset from multiple PDFs - Image-heavy PDF\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from dotenv import load_dotenv\n",
"load_dotenv()\n",
"\n",
"aoai_api_endpoint = os.getenv(\"AZURE_OPENAI_ENDPOINT\")\n",
"aoai_api_key = os.getenv(\"AZURE_OPENAI_API_KEY\")\n",
"aoai_api_version = os.getenv(\"AZURE_OPENAI_API_VERSION\")\n",
"aoai_deployment_name = os.getenv(\"AZURE_OPENAI_DEPLOYMENT_NAME\")\n",
"\n",
"if not aoai_api_version:\n",
" aoai_api_version = os.getenv(\"OPENAI_API_VERSION\")\n",
"if not aoai_deployment_name:\n",
" aoai_deployment_name = os.getenv(\"DEPLOYMENT_NAME\")\n",
" \n",
"print(f\"aoai_api_endpoint: {aoai_api_endpoint}\")\n",
"print(f\"aoai_api_key: {aoai_api_key}\")\n",
"print(f\"aoai_api_version: {aoai_api_version}\")\n",
"print(f\"aoai_deployment_name: {aoai_deployment_name}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"import glob\n",
"import pandas as pd\n",
"import shutil, random\n",
"from langchain_community.document_loaders.csv_loader import CSVLoader\n",
"from util.preprocess import convert_html_to_md, remove_short_sentences, remove_small_images\n",
"from util.common_utils import get_language_code\n",
"\n",
"DOMAIN = \"Advertising\"\n",
"LANGUAGE = \"English\" # You can change your language here. e.g., \"Korean\", \"Japanese\", \"Chinese\"\n",
"LANGUAGE_CODE = get_language_code(LANGUAGE)\n",
"print(f\"Domain: {DOMAIN}, Language: {LANGUAGE}, Language Code: {LANGUAGE_CODE}\")\n",
"\n",
"raw_data_dir = \"../raw_data\"\n",
"pdf_dir = f\"{raw_data_dir}/pdf\"\n",
"dataset_tmp_dir = \"dataset_tmp\"\n",
"\n",
"all_files = glob.glob(os.path.join(pdf_dir, \"img-*.pdf\"))\n",
"print(all_files)"
]
},
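{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optional sanity check (an added sketch, not part of the original pipeline): confirm that the glob matched at least one PDF and that each file opens cleanly with PyMuPDF before spending LLM calls on it.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: fail fast if no input PDFs were found, and peek at page counts.\n",
"import fitz\n",
"\n",
"assert len(all_files) > 0, f\"No 'img-*.pdf' files found under {pdf_dir}\"\n",
"\n",
"for f in all_files:\n",
"    with fitz.open(f) as pdf_doc:\n",
"        print(f\"{f}: {pdf_doc.page_count} pages\")"
]
},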
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import fitz\n",
"from glob import glob\n",
"from langchain.schema.output_parser import StrOutputParser\n",
"from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate\n",
"from langchain_openai import AzureChatOpenAI\n",
"\n",
"from util.preprocess import encode_image_base64\n",
"from langchain_core.runnables import RunnablePassthrough, RunnableLambda\n",
"from langchain_core.output_parsers import JsonOutputParser\n",
"from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n",
"from util.qa_pair import get_qna_prompt_template, QAPair\n",
"from util.common_utils import convert_to_oai_format, save_jsonl\n",
"\n",
"max_tokens = 1024\n",
"\n",
"llm = AzureChatOpenAI(\n",
" temperature=0, \n",
" max_tokens=max_tokens,\n",
" openai_api_version=aoai_api_version,\n",
" azure_deployment=aoai_deployment_name \n",
")"
]
},
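{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optional smoke test (an added sketch, not in the original notebook): invoke the deployed model once with a trivial prompt to verify the credentials and deployment name before the batch runs below.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: one cheap call to confirm the Azure OpenAI setup works end to end.\n",
"response = llm.invoke(\"Reply with the single word: OK\")\n",
"print(response.content)"
]
},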
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Preprocess each PDF file\n",
"\n",
"---\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for idx, file_path in enumerate(all_files):\n",
"\n",
" print(f\"\\n##### Idx {idx} - Processing {file_path}...\")\n",
"\n",
" image_path = \"./image\"\n",
" if os.path.isdir(image_path): shutil.rmtree(image_path)\n",
" os.makedirs(image_path, exist_ok=True)\n",
"\n",
" doc = fitz.open(file_path)\n",
" doc.delete_page(0) # 1st page is the cover page, so we delete it.\n",
" clip_x, clip_y = 30, 30\n",
"\n",
" for i, page in enumerate(doc):\n",
" x, y, w, h = page.rect\n",
" clip = fitz.Rect(x+clip_x, y+clip_y, w-clip_x, h-clip_y)\n",
" page.set_cropbox(clip)\n",
" pix = page.get_pixmap()\n",
" pix.save(f\"{image_path}/page_{i:03d}.jpg\")\n",
"\n",
" images = sorted(glob(os.path.join(image_path, \"*.jpg\")))\n",
"\n",
" ### Generate image summariesd\n",
" print(f\"### Generating image summaries using LLM - path: {file_path}\")\n",
"\n",
" start = time.time()\n",
"\n",
" human_prompt_main = f\"Given image, give a concise summary in {LANGUAGE}. Don't insert any XML tag such as <text> and </text> when answering.\"\n",
"\n",
" system_prompt = \"You are an assistant tasked with describing table or image, specialized in Smartphone product.\"\n",
" system_message_template = SystemMessagePromptTemplate.from_template(system_prompt)\n",
" human_prompt = [\n",
" {\n",
" \"type\": \"image_url\",\n",
" \"image_url\": {\n",
" \"url\": \"data:image/png;base64,\" + \"{image_base64}\",\n",
" },\n",
" },\n",
" {\n",
" \"type\": \"text\",\n",
" \"text\": human_prompt_main\n",
" },\n",
" ]\n",
" human_message_template = HumanMessagePromptTemplate.from_template(human_prompt)\n",
"\n",
" prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" system_message_template,\n",
" human_message_template\n",
" ]\n",
" )\n",
"\n",
" summarize_chain = prompt | llm | StrOutputParser()\n",
" base64_images = [encode_image_base64(img_path) for img_path in images]\n",
" image_summaries = summarize_chain.batch(base64_images, {\"max_concurrency\": 8})\n",
" image_summaries = remove_short_sentences(image_summaries)\n",
" end = time.time()\n",
"\n",
" print(f\"Elasped {end - start:.5f} ses for generating image summaries using LLM\")\n",
"\n",
" ### Generate QA pair\n",
" print(f\"### Generating QA pairs using LLM - path: {file_path}\")\n",
" start = time.time()\n",
"\n",
" parser = JsonOutputParser(pydantic_object=QAPair)\n",
" prompt = get_qna_prompt_template()\n",
" #prompt = get_qna_repair_cost_prompt_template()\n",
" chain = prompt | llm | parser\n",
"\n",
" input_batch = []\n",
"\n",
" for doc in image_summaries:\n",
" dic = {\"context\": doc, \"domain\": \"Mobile phone\", \"num_questions\": \"3\"}\n",
" input_batch.append(dic)\n",
"\n",
"\n",
" qa_pair = chain.batch(input_batch, {\"max_concurrency\": 8})\n",
" end = time.time()\n",
"\n",
" print(f\"Elasped {end - start:.5f} ses for generating image summaries using LLM\")\n",
"\n",
" ### Save to jsonl for fine-tuning\n",
" print(f\"### Saving QA pairs to jsonl\")\n",
" os.makedirs(dataset_tmp_dir, exist_ok=True)\n",
"\n",
" system_prompt_msg = f\"\"\"You are the SME (Subject Matter Expert) in {DOMAIN}. Please answer the questions accurately. If the question is in {LANGUAGE}, write your answer in {LANGUAGE}.\"\"\"\n",
"\n",
" oai_qa_pair = convert_to_oai_format(qa_pair, system_prompt_msg=system_prompt_msg)\n",
"\n",
" #save_jsonl(qa_pair, f\"{dataset_tmp_dir}/{idx}.jsonl\")\n",
" save_jsonl(oai_qa_pair, f\"{dataset_tmp_dir}/{idx}-oai.jsonl\")"
]
},
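{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optional inspection (an added sketch): load one of the per-file jsonl outputs and print the first record, to eyeball the OpenAI chat format before merging. It assumes `utf-8-sig` encoding, matching the merge cell below.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: print the first generated record for a visual check.\n",
"import json\n",
"import glob\n",
"\n",
"sample_file = sorted(glob.glob(os.path.join(dataset_tmp_dir, \"*-oai.jsonl\")))[0]\n",
"with open(sample_file, \"r\", encoding=\"utf-8-sig\") as f:\n",
"    first = json.loads(f.readline())\n",
"print(json.dumps(first, indent=2, ensure_ascii=False))"
]
},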
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Merge the generated jsonl files into a single jsonl file.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os, shutil, random\n",
"from util.preprocess import convert_html_to_md\n",
"import json\n",
"import glob\n",
"import pandas as pd\n",
"\n",
"all_files = glob.glob(os.path.join(dataset_tmp_dir, \"*-oai.jsonl\"))\n",
"\n",
"result = []\n",
"for f in all_files:\n",
" with open(f, \"r\", encoding=\"utf-8-sig\") as infile:\n",
" for line in infile.readlines():\n",
" try:\n",
" result.append(json.loads(line)) # read each line of the file\n",
" except ValueError:\n",
" print(f)\n",
"\n",
"save_filename = \"advertising-multiple\"\n",
"\n",
"output_dir = './dataset'\n",
"with open(f\"{output_dir}/{save_filename}-oai.jsonl\", \"w\", encoding=\"utf-8-sig\") as outfile:\n",
" for entry in result:\n",
" outfile.write(json.dumps(entry, ensure_ascii=False) + \"\\n\")"
]
},
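{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optional validation (an added sketch): re-read the merged file and check that every record has the minimal OpenAI chat fine-tuning shape - a `messages` list whose entries carry `role` and `content` keys. These are the standard chat-format fields, assumed here to be what `convert_to_oai_format` emits.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: structural check on the merged dataset.\n",
"merged_path = f\"{output_dir}/{save_filename}-oai.jsonl\"\n",
"with open(merged_path, \"r\", encoding=\"utf-8-sig\") as f:\n",
"    records = [json.loads(line) for line in f]\n",
"\n",
"for i, rec in enumerate(records):\n",
"    assert \"messages\" in rec, f\"record {i} has no 'messages' key\"\n",
"    for msg in rec[\"messages\"]:\n",
"        assert {\"role\", \"content\"} <= msg.keys(), f\"record {i} has a malformed message\"\n",
"\n",
"print(f\"{len(records)} records in {merged_path} look structurally valid\")"
]
},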
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!rm -rf pdf_image_tmp pdf_mixed_tmp outputs_tmp images {dataset_tmp_dir}"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.10 - SDK v2",
"language": "python",
"name": "python310-sdkv2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 4
}