| import gradio as gr |
| import os |
| from pathlib import Path |
| import autogen |
| import chromadb |
| import multiprocessing as mp |
| from autogen.retrieve_utils import TEXT_FORMATS, get_file_from_url, is_url |
| from autogen.agentchat.contrib.retrieve_assistant_agent import RetrieveAssistantAgent |
| from autogen.agentchat.contrib.retrieve_user_proxy_agent import ( |
| RetrieveUserProxyAgent, |
| PROMPT_CODE, |
| ) |
|
|
# Timeout in seconds: sent to the LLM client as "request_timeout" and used by
# chatbot_reply as the longest wait for the worker process's reply.
TIMEOUT: int = 60
|
|
|
|
def initialize_agents(config_list, docs_path=None):
    """Create a fresh Retrieve Chat agent pair.

    Args:
        config_list: Either a ``gr.State`` wrapping the LLM config list or the
            plain list itself. Only ``config_list[0]["model"]`` is read here.
        docs_path: Path or URL of the context document to index. When omitted,
            the AutoGen README on GitHub is used.

    Returns:
        A ``(RetrieveAssistantAgent, RetrieveUserProxyAgent)`` tuple.
    """
    configs = config_list.value if isinstance(config_list, gr.State) else config_list
    source = (
        "https://raw.githubusercontent.com/microsoft/autogen/main/README.md"
        if docs_path is None
        else docs_path
    )

    assistant = RetrieveAssistantAgent(
        name="assistant",
        system_message="You are a helpful assistant.",
    )

    retrieve_config = {
        "task": "code",
        "docs_path": source,
        "chunk_token_size": 2000,
        "model": configs[0]["model"],
        # update_context_url deletes this same collection ("autogen_rag" under
        # /tmp/chromadb) before calling this function again with a new document.
        "client": chromadb.PersistentClient(path="/tmp/chromadb"),
        "embedding_model": "all-mpnet-base-v2",
        "customized_prompt": PROMPT_CODE,
        "get_or_create": True,
        "collection_name": "autogen_rag",
    }
    ragproxyagent = RetrieveUserProxyAgent(
        name="ragproxyagent",
        human_input_mode="NEVER",
        max_consecutive_auto_reply=5,
        retrieve_config=retrieve_config,
    )

    return assistant, ragproxyagent
|
|
|
|
def initiate_chat(config_list, problem, queue, n_results=3):
    """Run one Retrieve Chat exchange and post the outcome to ``queue``.

    This is the target of the worker process started by ``chatbot_reply``. It
    drives the module-level ``assistant`` and ``ragproxyagent`` (read, not
    rebound) and puts exactly one list of strings on ``queue`` before returning.

    Args:
        config_list: A ``gr.State`` wrapping the LLM config list, or the list itself.
        problem: The user's question, passed to ``ragproxyagent.initiate_chat``.
        queue: Queue that receives the list of reply strings.
        n_results: Number of retrieved chunks requested for the chat.
    """
    if isinstance(config_list, gr.State):
        _config_list = config_list.value
    else:
        _config_list = config_list

    # Treat an empty config list or a missing/None/too-short key as "no
    # credentials". Raising here would end the worker before anything is
    # queued, leaving chatbot_reply blocked for the full TIMEOUT.
    api_key = (_config_list[0].get("api_key") if _config_list else None) or ""
    if len(api_key) < 2:
        queue.put(
            ["Hi, nice to meet you! Please enter your API keys in below text boxes."]
        )
        return

    assistant.llm_config.update(
        {
            "request_timeout": TIMEOUT,
            "config_list": _config_list,
            "use_cache": False,
        }
    )
    assistant.reset()
    try:
        ragproxyagent.initiate_chat(
            assistant, problem=problem, silent=False, n_results=n_results
        )
        # chat_messages is a dict of message lists; only the first one is used.
        conversation = list(ragproxyagent.chat_messages.values())[0]
        # NOTE(review): presumably the assistant's replies are the entries the
        # proxy records with role "user" — confirm against autogen's convention.
        messages = [m["content"] for m in conversation if m["role"] == "user"]
        print("messages: ", messages)
    except Exception as e:
        messages = [str(e)]
    queue.put(messages)
|
|
|
|
def chatbot_reply(input_text):
    """Answer ``input_text`` by running ``initiate_chat`` in a worker process.

    The worker gets ``TIMEOUT`` seconds to put its reply list on a queue. It is
    then stopped (if still running) and reaped, whether or not it answered.

    Returns:
        The list of message strings produced by the worker, or a one-element
        list describing why no reply arrived.
    """
    # NOTE(review): the worker reads this module's ``assistant``/``ragproxyagent``
    # globals; they carry over as-is only with the "fork" start method —
    # confirm behavior on platforms that default to "spawn".
    result_queue = mp.Queue()
    process = mp.Process(
        target=initiate_chat,
        args=(config_list, input_text, result_queue),
    )
    process.start()
    try:
        messages = result_queue.get(timeout=TIMEOUT)
    except Exception as e:
        # A get() timeout raises queue.Empty, whose str() is typically empty,
        # hence the fallback text.
        messages = [str(e) or "Invalid Request to OpenAI, please check your API keys."]
    finally:
        # Stop the worker if it is still running, then reap it so finished
        # workers do not linger as zombie processes.
        if process.is_alive():
            process.terminate()
        process.join(timeout=5)
        result_queue.close()
    return messages
|
|
|
|
def get_description_text():
    """Return the Markdown banner rendered at the top of the demo page."""
    link_table = (
        ("AutoGen", "https://github.com/microsoft/autogen"),
        ("Discord", "https://discord.gg/pAbnFJrkgZ"),
        ("Blog", "https://microsoft.github.io/autogen/blog/2023/10/18/RetrieveChat"),
        ("Paper", "https://arxiv.org/abs/2308.08155"),
        ("SourceCode", "https://github.com/thinkall/autogen-demos"),
    )
    links = " ".join(f"[{label}]({url})" for label, url in link_table)
    sections = [
        "",
        "# Microsoft AutoGen: Retrieve Chat Demo",
        "",
        "This demo shows how to use the RetrieveUserProxyAgent and RetrieveAssistantAgent to build a chatbot.",
        "",
        f"#### {links}",
        "",
    ]
    return "\n".join(sections)
|
|
|
|
def _default_config_list():
    """Return a new placeholder LLM config list (Azure type, empty key and base).

    Serves as the initial UI state and as the fallback in ``update_config``
    when no model config can be built from the environment.
    """
    return [
        {
            "api_key": "",
            "api_base": "",
            "api_type": "azure",
            "api_version": "2023-07-01-preview",
            "model": "gpt-35-turbo",
        }
    ]


# A ``with`` statement does not open a new scope, so ``config_list``,
# ``assistant`` and ``ragproxyagent`` assigned below are module globals; the
# callbacks that rebind them declare them ``global``.
with gr.Blocks() as demo:
    config_list = gr.State(_default_config_list())
    assistant, ragproxyagent = initialize_agents(config_list)

    gr.Markdown(get_description_text())
    chatbot = gr.Chatbot(
        [],
        elem_id="chatbot",
        bubble_full_width=False,
        avatar_images=(None, os.path.join(os.path.dirname(__file__), "autogen.png")),
    )

    txt_input = gr.Textbox(
        scale=4,
        show_label=False,
        placeholder="Enter text and press enter",
        container=False,
    )

    with gr.Row():

        def update_config(config_list):
            """Rebuild the LLM config from the environment and apply it to the agents.

            The ``config_list`` argument is accepted for call compatibility but
            not used: the config is always re-derived from the ``MODEL`` env var
            via ``autogen.config_list_from_models``.

            Returns:
                The new config list, or a fresh placeholder from
                ``_default_config_list`` when nothing could be built.
            """
            new_config_list = autogen.config_list_from_models(
                model_list=[os.environ.get("MODEL", "gpt-35-turbo")],
            )
            if not new_config_list:
                new_config_list = _default_config_list()
            assistant.llm_config.update(
                {"request_timeout": TIMEOUT, "config_list": new_config_list}
            )
            # Private attribute of RetrieveUserProxyAgent, kept in sync with the
            # selected model (presumably used for token budgeting — confirm).
            ragproxyagent._model = new_config_list[0]["model"]
            return new_config_list

        def set_params(model, oai_key, aoai_key, aoai_base):
            """Publish the UI's model choice and credentials as environment variables.

            Presumably these are what ``autogen.config_list_from_models`` reads
            in ``update_config`` — confirm against autogen.
            """
            os.environ["MODEL"] = model
            os.environ["OPENAI_API_KEY"] = oai_key
            os.environ["AZURE_OPENAI_API_KEY"] = aoai_key
            os.environ["AZURE_OPENAI_API_BASE"] = aoai_base
            return model, oai_key, aoai_key, aoai_base

        txt_model = gr.Dropdown(
            label="Model",
            choices=[
                "gpt-4",
                "gpt-35-turbo",
                "gpt-3.5-turbo",
            ],
            allow_custom_value=True,
            value="gpt-35-turbo",
            container=True,
        )
        txt_oai_key = gr.Textbox(
            label="OpenAI API Key",
            placeholder="Enter key and press enter",
            max_lines=1,
            show_label=True,
            value=os.environ.get("OPENAI_API_KEY", ""),
            container=True,
            type="password",
        )
        txt_aoai_key = gr.Textbox(
            label="Azure OpenAI API Key",
            placeholder="Enter key and press enter",
            max_lines=1,
            show_label=True,
            value=os.environ.get("AZURE_OPENAI_API_KEY", ""),
            container=True,
            type="password",
        )
        txt_aoai_base_url = gr.Textbox(
            label="Azure OpenAI API Base",
            placeholder="Enter base url and press enter",
            max_lines=1,
            show_label=True,
            value=os.environ.get("AZURE_OPENAI_API_BASE", ""),
            container=True,
            type="password",
        )

    clear = gr.ClearButton([txt_input, chatbot])

    with gr.Row():

        def upload_file(file):
            """Use an uploaded file (via its local temp path) as the chat context."""
            return update_context_url(file.name)

        upload_button = gr.UploadButton(
            "Click to upload a context file or enter a url in the right textbox",
            file_types=[f".{i}" for i in TEXT_FORMATS],
            file_count="single",
        )

        txt_context_url = gr.Textbox(
            label="Enter the url to your context file and chat on the context",
            info=f"File must be in the format of [{', '.join(TEXT_FORMATS)}]",
            max_lines=1,
            show_label=True,
            value="https://raw.githubusercontent.com/microsoft/autogen/main/README.md",
            container=True,
        )

    txt_prompt = gr.Textbox(
        label="Enter your prompt for Retrieve Agent and press enter to replace the default prompt",
        max_lines=40,
        show_label=True,
        value=PROMPT_CODE,
        container=True,
        show_copy_button=True,
    )

    def respond(message, chat_history, model, oai_key, aoai_key, aoai_base):
        """Handle a chat submission from the UI.

        Applies the credentials from the UI, refreshes the LLM config, asks the
        agents via ``chatbot_reply`` and appends ``(message, reply)`` to
        ``chat_history``.

        Returns:
            ``("", chat_history)``: clears the input box and updates the chatbot.
        """
        global config_list
        set_params(model, oai_key, aoai_key, aoai_base)
        config_list = update_config(config_list)
        messages = chatbot_reply(message)
        # Show the latest real reply; a bare "TERMINATE" is never shown as an answer.
        replies = [m for m in messages if m is not None and m != "TERMINATE"]
        if replies:
            _msg = replies[-1]
        else:
            _msg = "Context is not enough for answering the question. Please press `enter` in the context url textbox to make sure the context is activated for the chat."
        chat_history.append((message, _msg))
        return "", chat_history

    def update_prompt(prompt):
        """Replace the retrieve agent's prompt template with ``prompt``."""
        ragproxyagent.customized_prompt = prompt
        return prompt

    def update_context_url(context_url):
        """Switch the chat context to a new document and rebuild the agents.

        Args:
            context_url: A URL (downloaded into /tmp) or a local file path, such
                as the temp path of an uploaded file.

        Returns:
            Text for the context-url textbox: the URL, the base name of a local
            file, or an error message.
        """
        global assistant, ragproxyagent

        file_extension = Path(context_url).suffix
        print("file_extension: ", file_extension)
        if file_extension.lower() not in [f".{i}" for i in TEXT_FORMATS]:
            return f"File must be in the format of {TEXT_FORMATS}"

        if is_url(context_url):
            try:
                file_path = get_file_from_url(
                    context_url,
                    save_path=os.path.join("/tmp", os.path.basename(context_url)),
                )
            except Exception as e:
                return str(e)
        else:
            # NOTE(review): a local path typed into the textbox is read from the
            # server's filesystem; with share=True this lets remote users index
            # any server file with an allowed extension. Consider accepting
            # local paths only from uploads.
            if not os.path.isfile(context_url):
                return f"File not found: {context_url}"
            file_path = context_url
            context_url = os.path.basename(context_url)

        # Drop the previously indexed collection so the new document is indexed
        # into a fresh one (the agents are created with get_or_create=True).
        try:
            chromadb.PersistentClient(path="/tmp/chromadb").delete_collection(
                name="autogen_rag"
            )
        except Exception:
            # Best-effort: presumably raised when the collection does not exist yet.
            pass

        # initialize_agents installs PROMPT_CODE; carry over a prompt that was
        # customized through update_prompt so the textbox and the agent agree.
        custom_prompt = getattr(ragproxyagent, "customized_prompt", None)
        assistant, ragproxyagent = initialize_agents(config_list, docs_path=file_path)
        if custom_prompt is not None:
            ragproxyagent.customized_prompt = custom_prompt
        return context_url

    txt_input.submit(
        respond,
        [txt_input, chatbot, txt_model, txt_oai_key, txt_aoai_key, txt_aoai_base_url],
        [txt_input, chatbot],
    )
    txt_prompt.submit(update_prompt, [txt_prompt], [txt_prompt])
    txt_context_url.submit(update_context_url, [txt_context_url], [txt_context_url])
    upload_button.upload(upload_file, upload_button, [txt_context_url])
|
|
if __name__ == "__main__":
    # share=True requests a public Gradio share link and server_name="0.0.0.0"
    # listens on all interfaces, so the demo is reachable beyond localhost.
    launch_options = dict(share=True, server_name="0.0.0.0")
    demo.launch(**launch_options)