# src/models/rag.py
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Optional
import os
from time import time
from glob import glob
from typing import List, Dict, Literal
import numpy as np
import pdb
from llama_index.core import VectorStoreIndex, Document, SimpleKeywordTableIndex, Settings
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.postprocessor import SentenceTransformerRerank
from llama_index.core.schema import QueryBundle, QueryType, NodeWithScore
from llama_index.core import StorageContext, load_index_from_storage
from llama_index.core.retrievers import VectorIndexRetriever, KeywordTableSimpleRetriever
from llama_index.core.storage.index_store.simple_index_store import SimpleIndexStore
from llama_index.core.vector_stores.simple import SimpleVectorStore
from llama_index.core.vector_stores.types import VectorStoreQueryMode
Settings.llm = None
class LlamaIndexRetriever(object):
    """Dense (optionally hybrid dense + keyword) passage retriever built on LlamaIndex.

    On first construction (empty ``storage_path``) the vector index is built from
    ``chunk_list`` and persisted to disk; later constructions load the persisted
    index and ignore ``chunk_list``.
    """

    def __init__(self, chunk_list: Optional[List[Dict[str, str]]] = None,
                 storage_path: str = './data/textbooks/rag_storage',
                 emb_model_path: str = "local:../../models/rag_embedding/bge-m3",
                 chunk_size: int = 1024, similarity_top_k: int = 3, hybrid_search: bool = False,
                 reranker_path: Optional[str] = None, rerank_top_n: int = 3,
                 **kwargs):
        """Build or load the index.

        Args:
            chunk_list: List of ``{'data': text, 'idx': doc_id}`` dicts. Required
                (non-empty) only when ``storage_path`` holds no persisted index.
            storage_path: Directory where the LlamaIndex storage is persisted.
            emb_model_path: Embedding model spec passed to LlamaIndex.
            chunk_size: Sentence-splitter chunk size used when building the index.
            similarity_top_k: Number of nodes returned by each retriever.
            hybrid_search: If True, also build a keyword index and union its
                results with the dense retriever's in :meth:`retrieve`.
            reranker_path: Optional cross-encoder model path; enables reranking.
            rerank_top_n: Number of nodes kept after reranking.
            **kwargs: Forwarded to ``VectorStoreIndex`` when building from scratch.

        Raises:
            ValueError: If the storage is empty and ``chunk_list`` is empty too.
        """
        # Avoid the mutable-default-argument pitfall: normalize None -> [].
        chunk_list = chunk_list if chunk_list is not None else []
        os.makedirs(storage_path, exist_ok=True)
        if len(os.listdir(storage_path)) == 0:
            # First run: build the index from raw chunks and persist it.
            if not chunk_list:
                raise ValueError('chunk_list must be non-empty when building a new index')
            documents = [Document(text=chunk['data'], doc_id=chunk['idx']) for chunk in chunk_list]
            node_parser = SentenceSplitter(chunk_size=chunk_size, chunk_overlap=0)
            nodes = node_parser.get_nodes_from_documents(documents)
            self.index = VectorStoreIndex(nodes, embed_model=emb_model_path, show_progress=True, **kwargs)
            self.index.storage_context.persist(storage_path)
        else:
            # Subsequent runs: load the persisted index (much faster than rebuilding).
            print('Loading LlamaIndex Storage ...')
            t0 = time()
            storage_context = StorageContext.from_defaults(persist_dir=storage_path)
            self.index = load_index_from_storage(storage_context, embed_model=emb_model_path)
            # ':.1f' (one decimal place) — the original ':.1' meant one significant
            # digit and printed e.g. '1e+01' for 12.3 seconds.
            print(f'Done in {time() - t0:.1f} seconds.')
        self.retriever: VectorIndexRetriever = self.index.as_retriever(similarity_top_k=similarity_top_k)
        if hybrid_search:
            # Keyword index is rebuilt from the docstore each run (it is not persisted).
            nodes = list(self.index.storage_context.docstore.docs.values())
            self.keyword_index = SimpleKeywordTableIndex(nodes, show_progress=True)
            self.keyword_retriever: KeywordTableSimpleRetriever = \
                self.keyword_index.as_retriever(num_chunks_per_query=similarity_top_k)
        else:
            self.keyword_retriever = None
        if reranker_path:
            self.rerank = SentenceTransformerRerank(top_n=rerank_top_n, model=reranker_path)
        else:
            self.rerank = None

    def retrieve(self, query, return_text=True):
        """Retrieve nodes for ``query``.

        Dense results are optionally unioned (deduplicated by node id) with
        keyword-retriever results, then optionally reranked.

        Args:
            query: Query string.
            return_text: If True return chunk texts (``List[str]``); otherwise
                return the ``NodeWithScore`` objects.
        """
        nodes: List[NodeWithScore] = self.retriever.retrieve(query)
        if self.keyword_retriever:  # hybrid search
            kw_nodes: List[NodeWithScore] = self.keyword_retriever.retrieve(query)
            # Deduplicate by node id; NOTE(review): union order is dense-first,
            # then keyword-only hits, and scores are not re-normalized.
            node_dict = {node.node.node_id: node for node in (nodes + kw_nodes)}
            nodes = list(node_dict.values())  # union
        if self.rerank:
            nodes = self.rerank.postprocess_nodes(nodes, query_bundle=QueryBundle(query))
        if return_text:
            nodes = [node.get_text() for node in nodes]
        return nodes

    def set_topk(self, topk):
        """Set the number of results returned by :meth:`retrieve`.

        With a reranker, only the reranker's ``top_n`` changes (the underlying
        retrievers still fetch their original candidate counts); otherwise the
        retrievers' own top-k values are updated.
        """
        if self.rerank:
            self.rerank.top_n = topk
        else:
            self.retriever.similarity_top_k = topk
            if self.keyword_retriever:
                self.keyword_retriever.num_chunks_per_query = topk
if __name__ == '__main__':
    # Smoke test: build/load an index for one chunk-size configuration and
    # run a single retrieval, printing basic stats about the top hit.
    chunk_size = 512
    prefix = '_sp'
    # NOTE(review): `data_dir` is not a declared parameter — it flows through
    # **kwargs into VectorStoreIndex on a fresh build; confirm this is intended.
    retriever = LlamaIndexRetriever(
        data_dir=f'./data/textbooks/rag_doc{prefix}_{chunk_size}',
        storage_path=f'./data/textbooks/rag_storage{prefix}_{chunk_size}',
        chunk_size=chunk_size
        # reranker_path='../../models/rag_embedding/bge-reranker-v2-m3'
    )
    # Alternative sample queries (various languages), kept for manual testing:
    # query = "A 79-year-old man presents to the office due to shortness of breath with moderate exertion and a slightly productive cough. He has a medical history of 25 years of heavy smoking. His vitals include: heart rate 89/min, respiratory rate 27/min, and blood pressure 120/90 mm Hg. The physical exam shows increased resonance to percussion, decreased breath sounds, and crackles at the lung base. Chest radiography shows signs of pulmonary hyperinflation. Spirometry shows a forced expiratory volume in the first second (FEV1) of 48%, a forced vital capacity (FVC) of 85%, and an FEV1/FVC ratio of 56%. According to these results, what is the most likely diagnosis?"
    # query = "经调查证实出现医院感染流行时,医院应报告当地卫生行政部门的时间是( )。"
    # query = "Parmi les techniques voltampérométriques, on trouve:"
    # query = 'Противокашлевое действие бутамирата цитрата обусловлено главным образом воздействием на кашлевой центр в мозге?'
    query = 'El complejo proteico responsable de la ramificación de los filamentos de actina es:'
    hits = retriever.retrieve(query, return_text=False)
    print(len(hits))
    top_text = hits[0].get_text()
    print(len(top_text.split()))
    # print(top_text)
    print(hits[0].metadata['file_name'])  # file_idx