| import streamlit as st |
| from PIL import Image |
| from function import bounding_box |
| from tempfile import NamedTemporaryFile |
| import os |
| from function import ImageCaptionTools, ObjectDetectionTool |
| from langchain.agents import initialize_agent, AgentType |
| from langchain_google_genai import ChatGoogleGenerativeAI |
| from langchain.memory import ConversationBufferWindowMemory |
| from htmlTemplate import css, bot_template, user_template |
| import random |
|
|
# Root folder under which every browser session gets its own scratch directory.
DIR = './temp'
# makedirs(exist_ok=True) replaces the exists()/mkdir() pair, which could raise
# FileExistsError when two sessions start at the same moment.
os.makedirs(DIR, exist_ok=True)


# Pick the per-session directory name once and remember it in session state.
# The guard is keyed on "dirpath" itself, so a rerun that follows a failed
# first run reuses the same directory instead of creating another one.
if "dirpath" not in st.session_state:
    st.session_state.dirpath = os.path.join(DIR, str(random.randint(1, 999999999)))
DIR_PATH = st.session_state.dirpath
# Cheap no-op when the directory already exists; recreates it if it was removed.
os.makedirs(DIR_PATH, exist_ok=True)
|
|
def delete_temp_files():
    """Remove every regular file from the current session's scratch directory.

    The directory is read from ``st.session_state.dirpath``. Entries for which
    ``os.path.isfile`` is false (for example sub-directories) are left alone.
    """
    session_dir = st.session_state.dirpath
    for entry_name in os.listdir(session_dir):
        candidate = os.path.join(session_dir, entry_name)
        if not os.path.isfile(candidate):
            continue
        os.unlink(candidate)
|
|
|
|
|
|
| |
def agent_init():
    """Create and return a new conversational agent for image questions.

    The agent is built with ``initialize_agent`` from:
      * the two tools ``ImageCaptionTools`` and ``ObjectDetectionTool``,
      * a ``ChatGoogleGenerativeAI`` model named "gemini-pro",
      * a ``ConversationBufferWindowMemory`` stored under the key
        'chat_history' with ``k=5``.

    Every call constructs fresh objects, so the returned agent starts with an
    empty memory.
    """
    toolbox = [ImageCaptionTools(), ObjectDetectionTool()]
    chat_model = ChatGoogleGenerativeAI(model="gemini-pro")
    window_memory = ConversationBufferWindowMemory(
        memory_key='chat_history',
        k=5,
        return_messages=True,
    )
    return initialize_agent(
        tools=toolbox,
        llm=chat_model,
        agent=AgentType.CONVERSATIONAL_REACT_DESCRIPTION,
        memory=window_memory,
        max_iterations=5,
        verbose=True,
    )
|
|
|
|
|
|
def main():
    """Render the "Chat with an Image" Streamlit page.

    Left column: upload an image, save it into the session's scratch
    directory, run ``bounding_box`` on it and show the result.
    Right column: ask a question about the processed image and display the
    agent's answer.

    State kept in ``st.session_state``: ``agent``, ``reloaded``,
    ``image_processed``, ``result_bounding`` and ``image_path``.

    NOTE(review): the text this function was recovered from had lost its
    indentation; the nesting below is reconstructed from statement order and
    should be confirmed against the original layout.
    """
    # Local import: only needed to escape text placed into the HTML templates.
    import html

    st.set_page_config(
        page_title="Chat with an Image",
        page_icon="🖼️",
        layout="wide"
    )
    st.write(css, unsafe_allow_html=True)
    st.title("Chat with an Image 🖼️")

    # Streamlit re-executes the script on every interaction. Building the
    # agent each time would throw away its conversation memory, so it is
    # created once per session and reused afterwards.
    if "agent" not in st.session_state:
        st.session_state.agent = agent_init()
    agent = st.session_state.agent

    # False on the first run of a session, True on every later rerun.
    st.session_state.reloaded = 'reloaded' in st.session_state

    if "image_processed" not in st.session_state:
        st.session_state.image_processed = None

    if "result_bounding" not in st.session_state:
        st.session_state.result_bounding = None

    col1, col2 = st.columns([1, 1])
    with col1:
        image_upload = st.file_uploader(label="Please Upload Your Image", type=['jpg', 'png', 'jpeg'])
        if not image_upload:
            st.warning("Please upload your image")
        else:
            st.image(
                image_upload,
                use_column_width=True
            )
        click_process = st.button("Process Image", disabled=not image_upload)
        if click_process:
            # Drop the previous upload before storing the new one.
            delete_temp_files()
            # delete=False keeps the file on disk after the handle is closed
            # so later steps can open it by path.
            with NamedTemporaryFile(dir=st.session_state.dirpath, delete=False) as f:
                f.write(image_upload.getbuffer())
                st.session_state.image_path = f.name
            st.session_state.image_processed = True

        # Recompute when a new image was just processed, or when an image is
        # available but no bounding-box result has been stored yet.
        if (st.session_state.image_processed and st.session_state.result_bounding is None) or click_process:
            with st.spinner("Please Wait"):
                st.session_state.result_bounding = bounding_box(st.session_state.image_path)

        if st.session_state.result_bounding is not None:
            with st.expander("Show Image (Bounding Box)"):
                st.image(st.session_state.result_bounding)

    with col2:
        user_question = st.text_area("Ask About your image",
                                     disabled=not st.session_state.image_processed,
                                     max_chars=150)
        click_ask = st.button("Ask Question", disabled=not st.session_state.image_processed)
        if click_ask:
            if not user_question.strip():
                # Nothing to ask: skip the model call entirely.
                st.warning("Please enter a question about your image")
            else:
                # The templates are rendered with unsafe_allow_html=True, so
                # user and model text must be escaped before substitution.
                st.write(user_template.replace("{{MSG}}", html.escape(user_question)), unsafe_allow_html=True)
                with st.spinner("Doraemon Searching for Answer🔎"):
                    chat_history = agent.invoke({"input": f"{user_question}, this is the image path: {st.session_state.image_path}"})
                    response = chat_history['output']
                st.write(bot_template.replace("{{MSG}}", html.escape(str(response))), unsafe_allow_html=True)
|
|
# Build the page only when this file is executed as a script; importing the
# module does not call main().
if __name__ == "__main__":
    main()
|
|