I am trying to make a private LLM with RAG capabilities. I successfully followed a few tutorials and made one, but I wish to view the context the MultiVectorRetriever used when LangChain invokes my query.
This is my code:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.storage import InMemoryStore
from langchain_community.chat_models import ChatOllama
from langchain_community.embeddings import GPT4AllEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_core.runnables import RunnablePassthrough
from PIL import Image
import io
import os
import uuid
import json
import base64


def convert_bytes_to_base64(image_bytes):
    encoded_string = base64.b64encode(image_bytes).decode("utf-8")
    return "data:image/jpeg;base64," + encoded_string


# Load Retriever
path = "./vectorstore/pdf_test_file.pdf"

# Load from JSON files
texts = json.load(open(os.path.join(path, "json", "texts.json")))
text_summaries = json.load(open(os.path.join(path, "json", "text_summaries.json")))
tables = json.load(open(os.path.join(path, "json", "tables.json")))
table_summaries = json.load(open(os.path.join(path, "json", "table_summaries.json")))
img_summaries = json.load(open(os.path.join(path, "json", "img_summaries.json")))

# Load from figures
images_base64_list = []
for image in os.listdir(os.path.join(path, "figures")):
    img = Image.open(os.path.join(path, "figures", image))
    buffered = io.BytesIO()
    img.save(buffered, format="png")
    image_base64 = convert_bytes_to_base64(buffered.getvalue())
    # Warning: this section of the code does not support external IDEs like Spyder
    # and will break. Run it locally in the native terminal.
    images_base64_list.append(image_base64)

# Add to vectorstore
# The vectorstore to use to index the child chunks
vectorstore = Chroma(
    collection_name="summaries",
    embedding_function=GPT4AllEmbeddings(),
)

# The storage layer for the parent documents
store = InMemoryStore()  # <- Can we extend this to images?
id_key = "doc_id"

# The retriever (empty to start)
retriever = MultiVectorRetriever(
    vectorstore=vectorstore,
    docstore=store,
    id_key=id_key,
)

# Add texts
doc_ids = [str(uuid.uuid4()) for _ in texts]
summary_texts = [
    Document(page_content=s, metadata={id_key: doc_ids[i]})
    for i, s in enumerate(text_summaries)
]
retriever.vectorstore.add_documents(summary_texts)
retriever.docstore.mset(list(zip(doc_ids, texts)))

# Add tables
table_ids = [str(uuid.uuid4()) for _ in tables]
summary_tables = [
    Document(page_content=s, metadata={id_key: table_ids[i]})
    for i, s in enumerate(table_summaries)
]
retriever.vectorstore.add_documents(summary_tables)
retriever.docstore.mset(list(zip(table_ids, tables)))

# Add images
img_ids = [str(uuid.uuid4()) for _ in img_summaries]
summary_img = [
    Document(page_content=s, metadata={id_key: img_ids[i]})
    for i, s in enumerate(img_summaries)
]
retriever.vectorstore.add_documents(summary_img)
retriever.docstore.mset(list(zip(img_ids, img_summaries)))  # Store the image summary as the raw document

img_summaries_ids_and_images_base64 = []
count = 0
for img in images_base64_list:
    new_summary = [img_ids[count], img]
    img_summaries_ids_and_images_base64.append(new_summary)
    count += 1

# Check Response
# Question Example: "What is the issues plaguing the acres?"

"""
Testing Retrieval
print("\nTesting Retrieval: \n")
prompt = "Images / figures with playful and creative examples"
response = retriever.get_relevant_documents(prompt)[0]
print(response)
"""

"""
retriever.vectorstore.similarity_search("What is the issues plaguing the acres? show any relevant tables", k=10)
"""

# Prompt template
template = """Answer the question based only on the following context, which can include text, tables and images/figures:
{context}
Question: {question}
"""
prompt = ChatPromptTemplate.from_template(template)

# Multi-modal LLM
# model = LLaVA
model = ChatOllama(model="custom-mistral")

# RAG pipeline
chain = (
    {"context": retriever, "question": RunnablePassthrough()}
    | prompt
    | model
    | StrOutputParser()
)

print("\n\n\nTesting Response: \n")
print(chain.invoke("What is the issues plaguing the acres? show any relevant tables"))
The output will look something like this:
Testing Response:

In the provided text, the main issue with acres is related to wildfires and their impact on various lands and properties. The text discusses the number of fires, acreage burned, and the level of destruction caused by wildfires in the United States from 2018 to 2022. It also highlights that most wildfires are human-caused (89% of the average number of wildfires from 2018 to 2022) and that fires caused by lightning tend to be slightly larger and burn more acreage than those caused by humans.

Here's the table provided in the text, which shows the number of fires and acres burned on federal lands (by different organizations), other non-federal lands, and total:

| Year | Number of Fires (thousands) | Acres Burned (millions) |
|------|-----------------------------|--------------------------|
| 2018 | 58.1 | 8.8 |
| 2019 | 58.1 | 4.7 |
| 2020 | 58.1 | 10.1 |
| 2021 | 58.1 | 10.1 |
| 2022 | 58.1 | 3.6 |

The table also breaks down the acreage burned by federal lands (DOI and FS) and other non-federal lands, as well as showing the total acreage burned each year.<|im_end|>
From the RAG pipeline I wish to print out the context the retriever used. The retriever stores tons of vector embeddings, and I wish to know which ones it uses for a given query, something like:
chain.invoke("What is the issues plagueing the acres? show any relevant tables").get_context_used()
I know there are functions like
retriever.get_relevant_documents(prompt)
and
retriever.vectorstore.similarity_search(prompt)
which provide the most relevant context for the query, but I'm unsure whether the invoke function pulls the same context as the other two functions.
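
For reference, something like this is how I'd inspect the two paths separately (just a debugging sketch; k=4 and the truncated prints are arbitrary). My understanding, which may be wrong, is that the vectorstore matches on the summary documents while the retriever maps their doc_id back to the parent entries in the docstore:

query = "What is the issues plaguing the acres? show any relevant tables"

# Hits against the Chroma index: the summary Documents, each carrying the
# doc_id of its parent entry.
for doc in retriever.vectorstore.similarity_search(query, k=4):
    print(doc.metadata.get(id_key), "->", doc.page_content[:80])

# What the MultiVectorRetriever itself returns: the parent entries looked up
# in the InMemoryStore by those doc_ids.
for parent in retriever.get_relevant_documents(query):
    print(str(parent)[:120])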
The retriever I'm using from LangChain is the MultiVectorRetriever.
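
One thing I'm aware of (as far as I know this is the supported way to see intermediate steps) is LangChain's global debug mode, which dumps every step's inputs and outputs, including whatever the retriever returns, but it's very noisy and I'd rather get the documents programmatically:

from langchain.globals import set_debug

set_debug(True)  # logs each step of the chain, including the retriever's output
print(chain.invoke("What is the issues plaguing the acres? show any relevant tables"))
set_debug(False)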