# Welcome to the exciting world of Retrieval-Augmented Generation (RAG) systems!
# In this exercise, you'll build a powerful RAG system step by step.
# Get ready to dive into embeddings, vector databases, and AI-powered search!
import os
from dotenv import load_dotenv
from typing import List, Tuple
import sqlite3

import numpy as np
print(np.__version__)  # quick sanity check that numpy is importable

import faiss
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.memory import ConversationBufferMemory
from langchain_openai import ChatOpenAI
from langchain.agents import AgentExecutor, Tool
from langchain.agents.format_scratchpad import format_to_openai_function_messages
from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser
from langchain.tools.render import format_tool_to_openai_function
from langchain.schema.runnable import RunnablePassthrough
from langchain.tools import tool
from langchain.text_splitter import TokenTextSplitter
from langchain.schema import Document
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer

from read_pdf import read_pdf

# Let's start by setting up our environment and initializing our models
load_dotenv()
os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_KEY")

# Initialize SentenceTransformer and its underlying tokenizer
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')

def create_sqlite_tables(db_path: str) -> None:
    """
    Create SQLite tables for storing document chunks and their embeddings.

    This function sets up the foundation of our RAG system's database. It creates
    two tables: 'chunks' for storing text chunks and their metadata, and 'embeddings'
    for storing the vector representations of these chunks.

    Args:
        db_path (str): The file path where the SQLite database will be created or accessed.

    Returns:
        None

    Fun fact: SQLite is so reliable it's used in airplanes and smartphones!
    """
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    cursor.execute('''
        CREATE TABLE IF NOT EXISTS chunks (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            chunk_content TEXT,
            source_document TEXT,
            start_page INTEGER
        )
    ''')

    cursor.execute('''
        CREATE TABLE IF NOT EXISTS embeddings (
            chunk_id INTEGER,
            embedding BLOB,
            FOREIGN KEY (chunk_id) REFERENCES chunks (id)
        )
    ''')

    conn.commit()
    conn.close()

def chunk_document(pages: List[Document], source: str) -> List[Tuple[str, str, int]]:
    """
    Chunk the document pages, handling chunks that cross page boundaries.

    This function is like a master chef slicing a long document into bite-sized pieces.
    It ensures that each chunk is just the right size for our model to digest, while
    keeping track of where each chunk came from.

    Args:
        pages (List[Document]): List of Document objects, each representing a page.
        source (str): The source document name.

    Returns:
        List[Tuple[str, str, int]]: List of (chunk_text, source, start_page).
    """
    # initialization
    text_splitter = TokenTextSplitter(chunk_size=500, chunk_overlap=200)
    result = []  # where we would like to accumulate the chunks

    # variables to keep track of chunking across pages
    previous_last_chunk = ""  # stores any chunk that may have overflowed from past pages
    chunk_start_page = 1      # used to keep track of the page number

    for page in pages:
        page_content = page.page_content
        page_number = page.metadata['page']

        ########################################################################
        # TODO: concatenate the current page content with the last chunk of the
        # previous page. If the previous page was not exactly divisible by 500,
        # we wouldn't want to throw the leftover string away. Instead we treat
        # it as if it's part of the next page.
        ########################################################################
        ...

        ########################################################################
        # TODO: chunk this concatenated string
        # Hint: use the text_splitter.split_text() method
        ########################################################################
        ...

        ########################################################################
        # TODO: add all the chunks but the last one to the result
        ########################################################################
        ...

    # add the last chunk of the last page to the result
    if previous_last_chunk:
        result.append((previous_last_chunk, source, chunk_start_page))

    return result

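# One possible sketch for the three TODOs above (peek only if you're stuck!). It assumes
# text_splitter.split_text() returns a list of strings and that the leftover tail of each
# page is carried over into the next page:
#
#     combined = previous_last_chunk + page_content
#     page_chunks = text_splitter.split_text(combined)
#     for chunk in page_chunks[:-1]:
#         result.append((chunk, source, chunk_start_page))
#         chunk_start_page = page_number  # later chunks start on the current page
#     previous_last_chunk = page_chunks[-1] if page_chunks else ""
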
def embed_chunks(chunks: List[str], local: bool = True) -> np.ndarray:
    """
    Embed a list of text chunks using either a local SentenceTransformer model or OpenAI's embedding model.

    This function is like a translator, converting our text chunks into a language
    that our AI models can understand - the language of vectors!

    Args:
        chunks (List[str]): The list of text chunks to be embedded.
        local (bool): If True, use the local SentenceTransformer model. If False, use OpenAI's model.

    Returns:
        np.ndarray: The embedding vectors for the chunks.

    Exercise: Try implementing the OpenAI embedding method. How does it compare to the local model?
    """
    if local:
        ########################################################################
        # TODO: Implement the local SentenceTransformer embedding method here
        # Hint: You'll need to use the model.encode() method - check out its documentation!
        ########################################################################
        pass
    else:
        ########################################################################
        # (Optional) TODO: Implement the OpenAI embedding method here
        # Hint: You'll need to use the openai.Embedding.create() method - check out its documentation!
        ########################################################################
        pass

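# A minimal sketch for the local branch above: SentenceTransformer.encode() accepts a list
# of strings and, by default, returns a numpy array of shape (len(chunks), embedding_dim):
#
#     embeddings = model.encode(chunks)
#     return np.asarray(embeddings, dtype=np.float32)
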
def process_and_store_chunks(chunks: List[Tuple[str, str, int]], db_path: str, local: bool = True) -> None:
    """
    Process the input chunks, embed them, and store in the database.

    This function is like a librarian, taking our chunks of text, creating a special
    index for each (the embedding), and carefully storing both in our database.

    Args:
        chunks (List[Tuple[str, str, int]]): List of (chunk_text, source_document, start_page) tuples.
        db_path (str): Path to the SQLite database file.
        local (bool): Whether to use the local embedding model or OpenAI's.

    Returns:
        None

    Challenge: Can you modify this function to batch process chunks for better efficiency?
    """
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    for chunk_text, source_document, start_page in chunks:
        ########################################################################
        # TODO: Define the SQL query to insert the chunk into the database.
        # Ideally you should save information on
        #   (1) the chunk's text,
        #   (2) which document it comes from,
        #   (3) what page it starts at.
        # Hint: the sqlite3 cursor usage is of the form
        #   cursor.execute("INSERT INTO table VALUES (?, ?, ?)", (var1, var2, var3))
        ########################################################################
        insert_chunk_sql_query = ""
        cursor.execute(
            insert_chunk_sql_query,
            ()  # TODO: pass the required variables to your SQL query here
        )
        chunk_id = cursor.lastrowid

        ########################################################################
        # TODO: Embed the chunk using the embed_chunks function
        ########################################################################
        ...

        ########################################################################
        # TODO: Store the embedding in the database s.t. its unique ID is
        # the chunk_id you get from storing it in the database
        # Hint: You'll need to convert the embedding to bytes using the .tobytes() method
        ########################################################################
        insert_embed_sql_query = ""
        cursor.execute(
            insert_embed_sql_query,
            ()  # TODO: pass the required variables to your SQL query here
        )

    conn.commit()
    conn.close()

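# A possible shape for the two queries above, using the column names defined in
# create_sqlite_tables:
#
#     insert_chunk_sql_query = (
#         "INSERT INTO chunks (chunk_content, source_document, start_page) VALUES (?, ?, ?)"
#     )
#     cursor.execute(insert_chunk_sql_query, (chunk_text, source_document, start_page))
#     chunk_id = cursor.lastrowid
#     embedding = embed_chunks([chunk_text], local=local)[0]
#     insert_embed_sql_query = "INSERT INTO embeddings (chunk_id, embedding) VALUES (?, ?)"
#     cursor.execute(insert_embed_sql_query,
#                    (chunk_id, np.asarray(embedding, dtype=np.float32).tobytes()))
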
def create_faiss_index(db_path: str) -> faiss.Index:
    """
    Create a FAISS index from the stored embeddings in the database.

    This function is like building a high-tech library catalog. It takes all our
    stored embeddings and organizes them in a way that allows for super-fast searching!

    Args:
        db_path (str): Path to the SQLite database file.

    Returns:
        faiss.Index: The created FAISS index.

    Fun fact: FAISS can handle billions of vectors, making it perfect for large-scale search systems!
    """
    # create conn and cursor to load the database
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    ############################################################################
    # TODO: retrieve embeddings from the database using a SELECT SQL query
    ############################################################################
    select_embed_sql_query = ""
    cursor.execute(select_embed_sql_query)
    embeddings = [np.frombuffer(row[0], dtype=np.float32) for row in cursor.fetchall()]

    # close the database connection
    conn.close()

    ############################################################################
    # TODO: create the FAISS index using L2 distance
    # Hint: check out the documentation on the faiss.IndexFlatL2 function
    ############################################################################
    dimension = ...  # TODO: get the dimension of the embeddings
    index = ...      # TODO: create the L2 index

    ############################################################################
    # TODO: add the embeddings to the index
    # Hint: use the .add() method of the index
    ############################################################################
    ...

    return index

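# A sketch of the index construction above, assuming all stored embeddings share the
# same dimension and were inserted in chunk-id order:
#
#     select_embed_sql_query = "SELECT embedding FROM embeddings ORDER BY chunk_id"
#     dimension = embeddings[0].shape[0]      # length of each stored vector
#     index = faiss.IndexFlatL2(dimension)    # exact L2 (Euclidean) index
#     index.add(np.vstack(embeddings).astype(np.float32))
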
def process_pdf(file_path, db_path, local=True):
    # create a connection to the database
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    # check if document already exists in the database
    cursor.execute(
        "SELECT id FROM chunks WHERE source_document = ?",
        (os.path.basename(file_path),)
    )

    # close the connection
    conn.close()

    ############################################################################
    # TODO: read the pdf file, use the read_pdf function
    ############################################################################
    pages = ...
    source = os.path.basename(file_path)

    ############################################################################
    # TODO: get document chunks with the chunk_document function
    ############################################################################
    chunks = ...

    ############################################################################
    # TODO: process and store the chunks with the process_and_store_chunks function
    ############################################################################
    ...

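# One possible way to fill in the three TODOs above, assuming read_pdf(file_path)
# returns a list of Document pages (see the read_pdf module imported at the top):
#
#     pages = read_pdf(file_path)
#     chunks = chunk_document(pages, source)
#     process_and_store_chunks(chunks, db_path, local=local)
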
def search_engine(query: str, faiss_index: faiss.Index, db_path: str, k: int = 5) -> List[Tuple[str, float, str, int]]:
    """
    Search for relevant chunks using the query and FAISS index.

    This function is the heart of our RAG system. It takes a question, finds the most
    relevant information in our database, and returns it. It's like having a super-smart
    research assistant at your fingertips!

    Args:
        query (str): The search query.
        faiss_index (faiss.Index): The FAISS index for similarity search.
        db_path (str): Path to the SQLite database file.
        k (int): Number of top results to return.

    Returns:
        List[Tuple[str, float, str, int]]: List of (chunk_content, similarity_score, source_document, start_page).

    Exercise: Can you modify this function to also return the actual similarity scores?
    """
    ############################################################################
    # Implement the search functionality
    # Hint: You'll need to
    #   (1) embed the query
    #   (2) use faiss_index.search()
    #   (3) fetch corresponding chunks from the database
    # Note that here a query doesn't mean an SQL query but a user document
    # search query in the form of a NL string.
    ############################################################################

    ############################################################################
    # TODO: embed the query
    ############################################################################
    query_embedding = ...

    ############################################################################
    # TODO: use faiss_index.search() to find the relevant documents
    ############################################################################
    distances, indices = ...

    # connect the database to get the results
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    ############################################################################
    # TODO: fetch the corresponding chunks from the database with an SQL query
    ############################################################################
    results = []
    for i, distance in zip(indices[0], distances[0]):
        select_chunk_sql_query = ...
        cursor.execute(
            select_chunk_sql_query,
            ()  # TODO: pass the required variables to your SQL query here
        )
        chunkid, chunk_content = cursor.fetchone()
        results.append((chunkid, chunk_content))

    conn.close()
    return results

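# A sketch of the retrieval steps above. It assumes the FAISS rows were added in chunk-id
# order, so FAISS row i maps back to the chunk with autoincrement id i + 1:
#
#     query_embedding = embed_chunks([query]).astype(np.float32)   # shape (1, dim)
#     distances, indices = faiss_index.search(query_embedding, k)
#     # inside the loop over results:
#     select_chunk_sql_query = "SELECT id, chunk_content FROM chunks WHERE id = ?"
#     cursor.execute(select_chunk_sql_query, (int(i) + 1,))
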
# In the following code, you will implement the agent that uses the search engine to answer questions, using LangChain.
# Some examples and help can be found here: https://python.langchain.com/docs/how_to/agent_executor/
@tool
def search_tool(query: str) -> str:
    """
    Search for relevant information using the query.
    """
    ############################################################################
    # TODO: Implement this function. You have to find a way to let the LLM know
    # which chunk comes from where, so that we can add the sources in the end.
    # Hint: Use your search_engine function and return the formatted results
    ############################################################################
    pass

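# One way to format the results so the model can cite its sources: prefix every chunk with
# its id in the [[id]] form that run_agent_conversation() later parses. This assumes
# faiss_index and db_path are available at module level and that search_engine returns
# (chunk_id, chunk_content) pairs:
#
#     results = search_engine(query, faiss_index, db_path)
#     return "\n\n".join(f"[[{chunk_id}]] {chunk_content}" for chunk_id, chunk_content in results)
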
# create tools list containing search_tool as a single tool. Use Tool class from langchain
tools = [
    Tool(
        name="Search",
        func=search_tool,
        description="Search for legal information about EPFL"
    )
]

# load ChatOpenAI from LangChain
llm = ChatOpenAI(temperature=0, model='gpt-4o-mini')

############################################################################
# TODO: Create the prompt template in the file system_prompt.txt
############################################################################
# Get the directory of the current script
current_dir = os.path.dirname(os.path.abspath(__file__))
# Construct the full path to the system_prompt.txt file
system_prompt_path = os.path.join(current_dir, 'system_prompt.txt')
# Read the system prompt from the file
with open(system_prompt_path, 'r') as file:
    system_prompt = file.read().strip()

# Use ChatPromptTemplate.from_messages to create a prompt that instructs the AI
# on how to use the search tool and format its responses
prompt = ChatPromptTemplate.from_messages([
    ("system", system_prompt),
    MessagesPlaceholder(variable_name="chat_history"),
    ("human", "{input}"),
    MessagesPlaceholder(variable_name="agent_scratchpad"),
])

# Set up the memory
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)

################################################################################
# TODO: Create the agent
################################################################################
# Use the RunnablePassthrough, prompt, llm, and OpenAIFunctionsAgentOutputParser
# to create the agent. You can find more information here:
# https://github.com/langchain-ai/langchain/discussions/18591
agent = (
    {
        "input": ...,  # TODO: Implement the input format
        "chat_history": ...,  # TODO: Implement the chat history format
        "agent_scratchpad": ...,  # TODO: Implement the agent scratchpad format
    }
    | ...  # TODO: Use the prompt
    | ...  # TODO: Use the language model with tools
    | ...  # TODO: Use the output parser
)

################################################################################
# TODO: Create the agent executor
################################################################################
agent_executor = ...  # TODO: Use the AgentExecutor to create the agent executor

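# A sketch of one way to wire the agent with the OpenAI-functions helpers imported above
# (the linked discussion shows the same pattern). llm_with_tools is a name introduced here
# for illustration:
#
#     llm_with_tools = llm.bind(functions=[format_tool_to_openai_function(t) for t in tools])
#     agent = (
#         RunnablePassthrough.assign(
#             agent_scratchpad=lambda x: format_to_openai_function_messages(x["intermediate_steps"])
#         )
#         | prompt
#         | llm_with_tools
#         | OpenAIFunctionsAgentOutputParser()
#     )
#     agent_executor = AgentExecutor(agent=agent, tools=tools, memory=memory, verbose=True)
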
import re

def run_agent_conversation() -> None:
    """
    Run the LangChain agent in a console-based conversation loop.
    """
    print("Welcome to the RAG system. Type 'exit' to end the conversation.")

    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    while True:
        ########################################################################
        # TODO: Get user input with the builtin python function `input`
        ########################################################################
        user_input = ...
        if user_input.lower() == 'exit':
            break

        ########################################################################
        # TODO: Request a response from the executor
        ########################################################################
        response = ...

        # the output contains the sources in the format [[id]]. we use a regex
        # to extract the ids and get the sources
        ids = re.findall(r"\[\[(\d+)\]\]", response["output"])
        for id in ids:
            # 1. fetch the source and page from the database
            select_src_page_query = "SELECT source_document, start_page FROM chunks WHERE id = ?"
            cursor.execute(select_src_page_query, (id,))
            chunk_content = cursor.fetchone()

            ####################################################################
            # TODO: 2. replace the id with the source document
            # for the assistant response display
            ####################################################################
            response["output"] = ...

        print("Assistant:", response["output"])

    conn.close()

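# A possible way to fill in the loop above: input() reads from the console,
# agent_executor.invoke() runs one agent turn, and each [[id]] marker is swapped for the
# (source_document, start_page) pair fetched just before the TODO:
#
#     user_input = input("You: ")
#     response = agent_executor.invoke({"input": user_input})
#     # ... and, inside the loop over ids:
#     source_document, start_page = chunk_content
#     response["output"] = response["output"].replace(f"[[{id}]]", f"[{source_document}, p. {start_page}]")
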
if __name__ == "__main__":
    print("Welcome to your RAG system building adventure!")

    LOCAL = os.getenv("LOCAL", "True").lower() == "true"
    QUICK_DEMO = os.getenv("QUICK_DEMO", "False").lower() == "true"
    # In your .env file, set LOCAL to False if you want to use the openai
    # embedding model.
    # Set QUICK_DEMO to False if you want to run the code on the entirety
    # of the data, instead of a subset.
    # Namely: add the following lines to your .env file
    #   LOCAL=False
    #   QUICK_DEMO=False

    if LOCAL:
        db_path = "rag_database.sqlite"
    else:
        db_path = "rag_database_with_openai_embedding.sqlite"

    # Initialize the database and FAISS index
    create_sqlite_tables(db_path)

    # List all files in the data folder, to make sure you have the right path
    data_folder = './data'
    all_files = os.listdir(data_folder)

    if QUICK_DEMO:
        all_files = all_files[:2]

    for file in all_files:
        file_path = os.path.join(data_folder, file)
        # check if file is a pdf
        if file_path.endswith('.pdf'):
            process_pdf(file_path, db_path)

    # Create FAISS index
    faiss_index = create_faiss_index(db_path)

    # Run the conversation loop
    run_agent_conversation()