Parthiban97 committed on
Commit
c111a70
·
verified ·
1 Parent(s): 737ea28

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +155 -0
  2. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ import tempfile
4
+ import time
5
+ import nbformat
6
+ from langchain_google_genai import ChatGoogleGenerativeAI
7
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
8
+ from langchain.chains.combine_documents import create_stuff_documents_chain
9
+ from langchain_core.prompts import ChatPromptTemplate
10
+ from langchain.chains import create_retrieval_chain
11
+ from langchain_community.vectorstores import FAISS
12
+ from langchain_google_genai import GoogleGenerativeAIEmbeddings
13
+ from dotenv import load_dotenv
14
+ from langchain_core.documents import Document
15
+
16
# Load environment variables from a local .env file, if one is present.
load_dotenv()

st.set_page_config(
    page_title="Chat with Notebooks",
    page_icon=":books:",
)
st.title("Chat Gemini Document Q&A with Jupyter Notebooks")

# Suffix appended to a user-supplied custom prompt so that it still carries
# the {context} and {input} placeholders the retrieval chain fills in.
custom_context_input = """
<context>
{context}
</context>
Questions:{input}
"""

# Prompt used when the user does not supply a custom one.
default_prompt_template = """
Answer the questions based on the provided context only.
Please provide the most accurate response based on the question
<context>
{context}
</context>
Questions:{input}
"""
39
+
40
def load_notebook(file_path):
    """Parse the Jupyter notebook stored at ``file_path``.

    The file is opened as UTF-8 text and handed to ``nbformat.read`` with
    ``as_version=4``; the resulting notebook object is returned.
    """
    with open(file_path, mode='r', encoding='utf-8') as notebook_file:
        return nbformat.read(notebook_file, as_version=4)
44
+
45
def _notebook_text(value):
    """Normalise a notebook text field to a single string.

    Accepts a plain string, a list/tuple of line fragments (joined without a
    separator), or ``None`` (treated as empty). Anything else is passed
    through ``str``.
    """
    if value is None:
        return ""
    if isinstance(value, str):
        return value
    if isinstance(value, (list, tuple)):
        return "".join(str(part) for part in value)
    return str(value)


def extract_text_from_notebook(notebook):
    """Collect the textual content of a notebook into one string.

    For every markdown and code cell the cell source is included. For code
    cells, ``stream`` outputs and the ``text/plain`` representation of
    ``execute_result`` outputs are included as well. Other cell types and
    other output types are skipped.

    Args:
        notebook: Object exposing ``cells``; each cell exposes ``cell_type``
            and ``source`` and supports ``'outputs' in cell``.

    Returns:
        The collected pieces joined with newlines ("" for a notebook with no
        markdown or code cells).
    """
    text = []
    for cell in notebook.cells:
        if cell.cell_type not in ('markdown', 'code'):
            continue
        text.append(_notebook_text(cell.source))

        # Only code cells carry outputs.
        if cell.cell_type != 'code' or 'outputs' not in cell:
            continue
        for output in cell.outputs:
            if output.output_type == 'stream':
                text.append(_notebook_text(output.text))
            elif output.output_type == 'execute_result' and 'data' in output:
                text.append(_notebook_text(output.data.get('text/plain', '')))
    return "\n".join(text)
59
+
60
def vector_embedding(ipynb_files):
    """Build a FAISS index from uploaded notebooks and keep it in session state.

    Each uploaded notebook is written to a temporary ``.ipynb`` file, parsed
    with ``load_notebook``, flattened with ``extract_text_from_notebook`` and
    wrapped in a ``Document``. The documents are then split into chunks and
    embedded. The index is rebuilt on every call, so a new upload replaces
    the previously stored index.

    Args:
        ipynb_files: Iterable of uploaded file objects exposing ``getvalue()``
            (the raw notebook bytes).

    Session state written:
        ``embeddings``, ``text_splitter``, ``final_documents`` and, on
        success, ``vectors``.

    Errors while splitting or embedding are reported with ``st.error`` and
    leave ``final_documents`` empty; errors while parsing a notebook
    propagate to the caller after the temporary file has been removed.
    """
    if "vectors" not in st.session_state:
        st.session_state.embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")

    documents = []
    for ipynb_file in ipynb_files:
        # Save the upload to disk so it can be read back by path. delete=False
        # lets the file be closed before it is reopened by load_notebook.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".ipynb") as tmp_file:
            tmp_file.write(ipynb_file.getvalue())
            tmp_file_path = tmp_file.name

        try:
            notebook = load_notebook(tmp_file_path)
            text = extract_text_from_notebook(notebook)
        finally:
            # Always clean up, even when the notebook cannot be parsed.
            os.remove(tmp_file_path)

        # Create a Document object instead of using plain text
        documents.append(Document(page_content=text))

    # Ensure documents are properly segmented or chunked
    st.session_state.text_splitter = RecursiveCharacterTextSplitter(chunk_size=10000, chunk_overlap=1000)
    try:
        segmented_documents = st.session_state.text_splitter.split_documents(documents)
        st.session_state.final_documents = segmented_documents

        if st.session_state.final_documents:
            # Embedding using FAISS
            st.session_state.vectors = FAISS.from_documents(
                st.session_state.final_documents,
                st.session_state.embeddings,
            )
            st.success("Document embedding is completed!")
        else:
            st.warning("No documents found to embed.")

    except Exception as e:
        st.error(f"Error splitting or embedding documents: {str(e)}")
        st.session_state.final_documents = []  # Handle empty documents or retry
96
+
97
# Define model options for Gemini
model_options = [
    "models/gemini-1.0-pro",
    "models/gemini-1.0-pro-001",
    "models/gemini-1.0-pro-latest",
    "models/gemini-1.0-pro-vision-latest",
    "models/gemini-1.5-flash-latest",
    "models/gemini-1.5-pro-latest",
    "models/gemini-pro",
    "models/gemini-pro-vision",
]

# Sidebar elements: API key, model choice, notebook upload, optional custom
# prompt and the button that triggers embedding.
with st.sidebar:
    st.header("Configuration")
    st.markdown("Enter your API key below:")
    google_api_key = st.text_input(
        "Enter your Google API Key",
        type="password",
        help="Get your API key from [Google AI Studio](https://aistudio.google.com/app/apikey)",
    )
    selected_model = st.selectbox("Select Gemini Model", model_options)
    # Only override the environment when a key was actually typed; otherwise
    # a GOOGLE_API_KEY already present (e.g. loaded by load_dotenv()) would be
    # replaced by an empty string.
    if google_api_key:
        os.environ["GOOGLE_API_KEY"] = str(google_api_key)

    st.markdown("Upload your .ipynb files:")
    uploaded_files = st.file_uploader("Choose .ipynb files", accept_multiple_files=True, type="ipynb")

    # Custom prompt text areas
    st.markdown("Enter a custom prompt template (optional):")
    custom_prompt_template = st.text_area("Custom Prompt Template", placeholder="Enter your custom prompt here...")

    if st.button("Start Document Embedding"):
        if uploaded_files:
            vector_embedding(uploaded_files)
            # vector_embedding reports its own errors; only announce the
            # store as ready when an index is actually available.
            if "vectors" in st.session_state:
                st.success("Vector Store DB is Ready")
        else:
            st.warning("Please upload at least one .ipynb file.")
130
+
131
# Main section for question input and results
prompt1 = st.text_area("Enter Your Question From Documents")

if prompt1 and "vectors" in st.session_state:
    # A custom prompt gets the context/question suffix appended so that it
    # still contains the {context} and {input} placeholders.
    if custom_prompt_template:
        custom_prompt = custom_prompt_template + custom_context_input
        prompt = ChatPromptTemplate.from_template(custom_prompt)
    else:
        prompt = ChatPromptTemplate.from_template(default_prompt_template)

    llm = ChatGoogleGenerativeAI(model=selected_model, temperature=0.3)
    document_chain = create_stuff_documents_chain(llm, prompt)
    retriever = st.session_state.vectors.as_retriever()
    retrieval_chain = create_retrieval_chain(retriever, document_chain)

    # Wall-clock timing: the call is dominated by waiting on the remote model,
    # which CPU-time clocks such as time.process_time() do not count.
    start = time.perf_counter()
    response = retrieval_chain.invoke({'input': prompt1})
    st.write("Response time:", time.perf_counter() - start)
    st.write(response['answer'])

    # With a Streamlit expander
    with st.expander("Document Similarity Search"):
        # Show the retrieved chunks that were passed to the model as context
        for doc in response["context"]:
            st.write(doc.page_content)
            st.write("--------------------------------")
elif prompt1:
    # A question was entered but no index has been built yet.
    st.warning("Please upload your notebooks and click 'Start Document Embedding' before asking a question.")
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ langchain
3
+ openai
4
+ nbformat
5
+ faiss-cpu
6
+ langchain-google-genai
7
+ langchain-groq
8
+ langchain_community
9
+ python-dotenv