seed/make_qa_only_image_multiple_pdf.ipynb (281 lines of code) (raw):

{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generate QnA synthetic dataset from multiple PDFs - Image-heavy PDF\n",
    "\n",
    "For every `img-*.pdf` file in `pdf_dir`, this notebook renders each page (except the cover) to an image, asks the Azure OpenAI deployment to summarize each page image, generates question/answer pairs from those summaries, and merges the results into a single JSONL file for fine-tuning.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Load configuration and list the input PDFs\n",
    "\n",
    "---\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from dotenv import load_dotenv\n",
    "load_dotenv()\n",
    "\n",
    "aoai_api_endpoint = os.getenv(\"AZURE_OPENAI_ENDPOINT\")\n",
    "aoai_api_key = os.getenv(\"AZURE_OPENAI_API_KEY\")\n",
    "aoai_api_version = os.getenv(\"AZURE_OPENAI_API_VERSION\")\n",
    "aoai_deployment_name = os.getenv(\"AZURE_OPENAI_DEPLOYMENT_NAME\")\n",
    "\n",
    "# Fall back to the generic variable names when the Azure-specific ones are not set.\n",
    "if not aoai_api_version:\n",
    "    aoai_api_version = os.getenv(\"OPENAI_API_VERSION\")\n",
    "if not aoai_deployment_name:\n",
    "    aoai_deployment_name = os.getenv(\"DEPLOYMENT_NAME\")\n",
    "\n",
    "print(f\"aoai_api_endpoint: {aoai_api_endpoint}\")\n",
    "# Never echo the key itself: notebook outputs are easily committed or shared.\n",
    "print(f\"aoai_api_key: {'<set>' if aoai_api_key else '<not set>'}\")\n",
    "print(f\"aoai_api_version: {aoai_api_version}\")\n",
    "print(f\"aoai_deployment_name: {aoai_deployment_name}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import glob\n",
    "import pandas as pd\n",
    "import shutil, random\n",
    "from langchain_community.document_loaders.csv_loader import CSVLoader\n",
    "from util.preprocess import convert_html_to_md, remove_short_sentences, remove_small_images\n",
    "from util.common_utils import get_language_code\n",
    "\n",
    "DOMAIN = \"Advertising\"\n",
    "LANGUAGE = \"English\" # You can change your language here. e.g., \"Korean\", \"Japanese\", \"Chinese\"\n",
    "LANGUAGE_CODE = get_language_code(LANGUAGE)\n",
    "NUM_QUESTIONS_PER_PAGE = 3  # QA pairs requested from the LLM for each page summary\n",
    "MAX_CONCURRENCY = 8  # parallel LLM requests per batch call\n",
    "print(f\"Domain: {DOMAIN}, Language: {LANGUAGE}, Language Code: {LANGUAGE_CODE}\")\n",
    "\n",
    "raw_data_dir = \"../raw_data\"\n",
    "pdf_dir = f\"{raw_data_dir}/pdf\"\n",
    "image_tmp_dir = \"./image\"  # page images of the PDF currently being processed\n",
    "dataset_tmp_dir = \"dataset_tmp\"  # one JSONL file per PDF, merged at the end\n",
    "output_dir = \"./dataset\"  # destination of the merged JSONL file\n",
    "\n",
    "# Sorted so the per-PDF index used in the temporary file names is stable across runs.\n",
    "all_files = sorted(glob.glob(os.path.join(pdf_dir, \"img-*.pdf\")))\n",
    "print(all_files)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import fitz\n",
    "from langchain.schema.output_parser import StrOutputParser\n",
    "from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate\n",
    "from langchain_openai import AzureChatOpenAI\n",
    "\n",
    "from util.preprocess import encode_image_base64\n",
    "from langchain_core.runnables import RunnablePassthrough, RunnableLambda\n",
    "from langchain_core.output_parsers import JsonOutputParser\n",
    "from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n",
    "from util.qa_pair import get_qna_prompt_template, QAPair\n",
    "from util.common_utils import convert_to_oai_format, save_jsonl\n",
    "\n",
    "max_tokens = 1024\n",
    "\n",
    "# Endpoint and key are not passed explicitly; presumably AzureChatOpenAI reads\n",
    "# AZURE_OPENAI_ENDPOINT / AZURE_OPENAI_API_KEY from the environment loaded above - confirm.\n",
    "llm = AzureChatOpenAI(\n",
    "    temperature=0,\n",
    "    max_tokens=max_tokens,\n",
    "    openai_api_version=aoai_api_version,\n",
    "    azure_deployment=aoai_deployment_name\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Preprocess each PDF file\n",
    "\n",
    "---\n",
    "\n",
    "The chains are built once and reused for every PDF: `summarize_chain` turns one rendered page image into a text summary, and `qa_chain` turns one summary into QA pairs.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Page-summary chain: one rendered page image (base64) -> one text summary.\n",
    "image_summary_system_prompt = f\"You are an assistant tasked with describing table or image, specialized in {DOMAIN}.\"\n",
    "image_summary_human_prompt = [\n",
    "    {\n",
    "        \"type\": \"image_url\",\n",
    "        \"image_url\": {\n",
    "            # Pages are saved as .jpg below, so declare a matching MIME type.\n",
    "            \"url\": \"data:image/jpeg;base64,\" + \"{image_base64}\",\n",
    "        },\n",
    "    },\n",
    "    {\n",
    "        \"type\": \"text\",\n",
    "        \"text\": f\"Given image, give a concise summary in {LANGUAGE}. Don't insert any XML tag such as <text> and </text> when answering.\"\n",
    "    },\n",
    "]\n",
    "summary_prompt = ChatPromptTemplate.from_messages(\n",
    "    [\n",
    "        SystemMessagePromptTemplate.from_template(image_summary_system_prompt),\n",
    "        HumanMessagePromptTemplate.from_template(image_summary_human_prompt),\n",
    "    ]\n",
    ")\n",
    "summarize_chain = summary_prompt | llm | StrOutputParser()\n",
    "\n",
    "# QA chain: one page summary -> QA pairs parsed from the LLM's JSON output.\n",
    "qa_parser = JsonOutputParser(pydantic_object=QAPair)\n",
    "qa_chain = get_qna_prompt_template() | llm | qa_parser\n",
    "\n",
    "# System message attached to every record written for fine-tuning.\n",
    "system_prompt_msg = f\"\"\"You are the SME (Subject Matter Expert) in {DOMAIN}. Please answer the questions accurately. If the question is in {LANGUAGE}, write your answer in {LANGUAGE}.\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start from an empty temp dir so files left by an earlier run are not merged below.\n",
    "if os.path.isdir(dataset_tmp_dir): shutil.rmtree(dataset_tmp_dir)\n",
    "os.makedirs(dataset_tmp_dir, exist_ok=True)\n",
    "\n",
    "clip_x, clip_y = 30, 30  # margin trimmed from each page edge before rendering\n",
    "\n",
    "for idx, file_path in enumerate(all_files):\n",
    "\n",
    "    print(f\"\\n##### Idx {idx} - Processing {file_path}...\")\n",
    "\n",
    "    # Re-create the image dir so pages of the previous PDF are not picked up.\n",
    "    if os.path.isdir(image_tmp_dir): shutil.rmtree(image_tmp_dir)\n",
    "    os.makedirs(image_tmp_dir, exist_ok=True)\n",
    "\n",
    "    ### Render every page except the cover to an image\n",
    "    with fitz.open(file_path) as pdf_doc:\n",
    "        pdf_doc.delete_page(0) # 1st page is the cover page, so we delete it.\n",
    "        for i, page in enumerate(pdf_doc):\n",
    "            x0, y0, x1, y1 = page.rect  # Rect is (x0, y0, x1, y1), not (x, y, width, height)\n",
    "            page.set_cropbox(fitz.Rect(x0+clip_x, y0+clip_y, x1-clip_x, y1-clip_y))\n",
    "            pix = page.get_pixmap()\n",
    "            pix.save(f\"{image_tmp_dir}/page_{i:03d}.jpg\")\n",
    "\n",
    "    images = sorted(glob.glob(os.path.join(image_tmp_dir, \"*.jpg\")))\n",
    "    if not images:\n",
    "        print(f\"No pages left after removing the cover page; skipping {file_path}\")\n",
    "        continue\n",
    "\n",
    "    ### Generate image summaries\n",
    "    print(f\"### Generating image summaries using LLM - path: {file_path}\")\n",
    "    start = time.time()\n",
    "    base64_images = [encode_image_base64(img_path) for img_path in images]\n",
    "    image_summaries = summarize_chain.batch(base64_images, {\"max_concurrency\": MAX_CONCURRENCY})\n",
    "    image_summaries = remove_short_sentences(image_summaries)\n",
    "    end = time.time()\n",
    "    print(f\"Elapsed {end - start:.5f} secs for generating image summaries using LLM\")\n",
    "\n",
    "    ### Generate QA pairs\n",
    "    print(f\"### Generating QA pairs using LLM - path: {file_path}\")\n",
    "    start = time.time()\n",
    "    input_batch = [\n",
    "        {\"context\": summary, \"domain\": DOMAIN, \"num_questions\": str(NUM_QUESTIONS_PER_PAGE)}\n",
    "        for summary in image_summaries\n",
    "    ]\n",
    "    qa_pair = qa_chain.batch(input_batch, {\"max_concurrency\": MAX_CONCURRENCY})\n",
    "    end = time.time()\n",
    "    print(f\"Elapsed {end - start:.5f} secs for generating QA pairs using LLM\")\n",
    "\n",
    "    ### Save to jsonl for fine-tuning\n",
    "    print(\"### Saving QA pairs to jsonl\")\n",
    "    oai_qa_pair = convert_to_oai_format(qa_pair, system_prompt_msg=system_prompt_msg)\n",
    "    save_jsonl(oai_qa_pair, f\"{dataset_tmp_dir}/{idx}-oai.jsonl\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Merge the generated jsonl files into a single jsonl file.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sorted so the merged file has a deterministic record order.\n",
    "oai_jsonl_files = sorted(glob.glob(os.path.join(dataset_tmp_dir, \"*-oai.jsonl\")))\n",
    "\n",
    "result = []\n",
    "for f in oai_jsonl_files:\n",
    "    # utf-8-sig transparently strips a leading BOM if the file has one.\n",
    "    with open(f, \"r\", encoding=\"utf-8-sig\") as infile:\n",
    "        for line in infile:\n",
    "            if not line.strip():\n",
    "                continue  # ignore blank lines instead of reporting them as errors\n",
    "            try:\n",
    "                result.append(json.loads(line))  # one JSON record per line\n",
    "            except ValueError:\n",
    "                print(f\"Skipping malformed JSON line in {f}\")\n",
    "\n",
    "save_filename = \"advertising-multiple\"\n",
    "save_path = f\"{output_dir}/{save_filename}-oai.jsonl\"\n",
    "\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "# utf-8-sig writes a BOM; presumably intentional for the fine-tuning upload - confirm before changing.\n",
    "with open(save_path, \"w\", encoding=\"utf-8-sig\") as outfile:\n",
    "    for entry in result:\n",
    "        outfile.write(json.dumps(entry, ensure_ascii=False) + \"\\n\")\n",
    "print(f\"Saved {len(result)} records to {save_path}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Clean up temporary directories\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!rm -rf pdf_image_tmp pdf_mixed_tmp outputs_tmp images {image_tmp_dir} {dataset_tmp_dir}"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10 - SDK v2",
   "language": "python",
   "name": "python310-sdkv2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}