import base64
import dataclasses
import io
import os
import sys
import time
from enum import auto, Enum
from pathlib import Path
from typing import Any, List, Optional, Tuple

import gradio as gr
import lancedb
from langchain_core.runnables import RunnableParallel, RunnablePassthrough, RunnableLambda
from moviepy.video.io.VideoFileClip import VideoFileClip
from PIL import Image

from mm_rag.embeddings.bridgetower_embeddings import BridgeTowerEmbeddings
from mm_rag.vectorstores.multimodal_lancedb import MultimodalLanceDB
from mm_rag.MLM.client import PredictionGuardClient
from mm_rag.MLM.lvlm import LVLM
from utility import (
    Conversation,
    encode_image,
    load_json_file,
    lvlm_inference_with_conversation,
    prediction_guard_llava_conv,
)
|
|
# Text placed in the assistant slot of the chat when the model call raises.
server_error_msg = (
    "**NETWORK ERROR DUE TO HIGH TRAFFIC. "
    "PLEASE REGENERATE OR REFRESH THIS PAGE.**"
)
|
|
| |
def split_video(video_path, timestamp_in_ms, output_video_path: str = "./shared_data/splitted_videos", output_video_name: str="video_tmp.mp4", play_before_sec: int=3, play_after_sec: int=3):
    """Write a short clip of ``video_path`` centred on a timestamp.

    Args:
        video_path: Path of the source video.
        timestamp_in_ms: Position of interest, in milliseconds. It is
            truncated to whole seconds before the window is computed.
        output_video_path: Directory for the clip; created if missing.
        output_video_name: File name of the clip inside that directory.
        play_before_sec: Seconds of video to keep before the timestamp.
        play_after_sec: Seconds of video to keep after the timestamp.

    Returns:
        The path of the written clip (directory joined with file name).
    """
    center_sec = int(timestamp_in_ms / 1000)

    # The output directory must exist before the clip is written.
    Path(output_video_path).mkdir(parents=True, exist_ok=True)
    clip_file = os.path.join(output_video_path, output_video_name)

    with VideoFileClip(video_path) as source_clip:
        # Clamp the window to [0, duration] of the source video.
        # NOTE(review): a timestamp far beyond the end of the video can
        # still yield clip_start > clip_end; that case is not handled here.
        clip_start = max(center_sec - play_before_sec, 0)
        clip_end = min(center_sec + play_after_sec, source_clip.duration)
        segment = source_clip.subclip(clip_start, clip_end)
        segment.write_videofile(clip_file, audio_codec='aac')
    return clip_file
|
|
|
|
# Prompt given to the LVLM: the transcript of the retrieved video segment
# followed by the user's query. Filled with str.format(transcript=..., user_query=...).
prompt_template = (
    "The transcript associated with the image is '{transcript}'. "
    "{user_query}"
)
|
|
| |
def get_default_rag_chain():
    """Assemble the default multimodal RAG chain used by the demo.

    The chain retrieves the single most similar entry from the
    ``demo_tbl`` LanceDB table, turns it into an LVLM prompt plus image
    path, and runs the LVLM on it.

    Returns:
        A runnable whose output is a dict with two keys:
        ``final_text_output`` (the LVLM module's result) and
        ``input_to_lvlm`` (the dict built by ``prompt_processing``:
        ``prompt``, ``image`` and ``metadata``).
    """
    lancedb_uri = "./shared_data/.lancedb"
    table_name = "demo_tbl"

    # The connection object is not used below; the call itself is kept so
    # the function has the same side effects as before.
    lancedb.connect(lancedb_uri)

    # Vector store over the demo table, embedded with BridgeTower.
    vectorstore = MultimodalLanceDB(
        uri=lancedb_uri,
        embedding=BridgeTowerEmbeddings(),
        table_name=table_name,
    )
    # k=1: only the best match is retrieved.
    retriever_module = vectorstore.as_retriever(
        search_type='similarity',
        search_kwargs={"k": 1},
    )

    # LVLM inference step, backed by a PredictionGuard client.
    lvlm_inference_module = LVLM(client=PredictionGuardClient())

    def prompt_processing(chain_input):
        """Build the LVLM input from the retrieval output and the query."""
        top_hit = chain_input['retrieved_results'][0]
        segment_metadata = top_hit.metadata['metadata']
        filled_prompt = prompt_template.format(
            transcript=segment_metadata['transcript'],
            user_query=chain_input['user_query'],
        )
        return {
            'prompt': filled_prompt,
            'image' : segment_metadata['extracted_frame_path'],
            'metadata' : segment_metadata,
        }

    # Step 1: run retrieval while passing the raw query through unchanged.
    retrieve_step = RunnableParallel({
        "retrieved_results": retriever_module,
        "user_query": RunnablePassthrough(),
    })
    # Step 3: run the LVLM and also expose what was fed into it.
    answer_step = RunnableParallel({
        'final_text_output': lvlm_inference_module,
        'input_to_lvlm' : RunnablePassthrough(),
    })
    return retrieve_step | RunnableLambda(prompt_processing) | answer_step
| |
class SeparatorStyle(Enum):
    """Separator styles a conversation can be tagged with."""

    # The only style defined in this module.
    SINGLE = auto()
| |
@dataclasses.dataclass
class GradioInstance:
    """Per-session state of the demo: chat history plus retrieved media.

    ``messages`` is a list of ``[role, message]`` pairs in the order they
    were appended; a message of ``None`` marks a reply that has not been
    produced yet.
    """
    # Stored and exported by dict(); not read by any method of this class.
    system: str
    # Role names; callers in this module use roles[0] for the user turn and
    # roles[1] for the reply.
    roles: List[str]
    # [role, message] pairs, see class docstring.
    messages: List[List[str]]
    # Number of leading messages skipped when rendering / building prompts.
    offset: int
    # Formatting-related settings; stored but not read by this class.
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "\n"
    sep2: Optional[str] = None
    version: str = "Unknown"
    # Path of the image associated with the conversation, once known.
    path_to_img: Optional[str] = None
    # Key into ``video_helper_map`` (see get_path_to_subvideos).
    video_title: Optional[str] = None
    # Path of the video clip associated with the conversation, once known.
    path_to_video: Optional[str] = None
    # Caption / transcript inserted into the prompt template.
    caption: Optional[str] = None
    # Runnable used for the first (retrieval) query.
    mm_rag_chain: Any = None

    # Set when the last submitted input should not trigger inference.
    skip_next: bool = False

    def _template_caption(self):
        """Return a sentence describing ``caption``, or "" when it is unset."""
        out = ""
        if self.caption is not None:
            out = f"The caption associated with the image is '{self.caption}'. "
        return out

    def get_prompt_for_rag(self):
        """Return the text of the first message, used as the RAG query.

        Only valid while the history is exactly one query plus its still
        empty reply slot.

        Raises:
            AssertionError: if the history does not have that shape.
        """
        messages = self.messages
        assert len(messages) == 2, "length of current conversation should be 2"
        assert messages[1][1] is None, "the first response message of current conversation should be None"
        return messages[0][1]

    def get_conversation_for_lvlm(self):
        """Build the conversation object passed to the LVLM for follow-ups.

        Starts from a copy of ``prediction_guard_llava_conv`` and appends
        the history from ``offset`` up to the first ``None`` message:

        * the first message is sent together with the base64-encoded image;
        * the second-to-last message (the newest query, when the last entry
          is the empty reply slot) is wrapped in ``prompt_template`` with
          ``caption`` as the transcript;
        * every other message is appended unchanged.
        """
        pg_conv = prediction_guard_llava_conv.copy()
        b64_img = encode_image(self.path_to_img)
        # Slice once; the history is not modified inside the loop.
        history = self.messages[self.offset:]
        newest_query_idx = len(history) - 2
        for i, (role, msg) in enumerate(history):
            if msg is None:
                break
            if i == 0:
                pg_conv.append_message(prediction_guard_llava_conv.roles[0], [msg, b64_img])
            elif i == newest_query_idx:
                pg_conv.append_message(role, [prompt_template.format(transcript=self.caption, user_query=msg)])
            else:
                pg_conv.append_message(role, [msg])
        return pg_conv

    def append_message(self, role, message):
        """Append one ``[role, message]`` pair to the history."""
        self.messages.append([role, message])

    def get_images(self, return_pil=False):
        """Return ``[path_to_img]``, or ``[]`` when no image is set.

        ``return_pil`` is accepted but has no effect: paths are always
        returned.
        """
        images = []
        if self.path_to_img is not None:
            images.append(self.path_to_img)
        return images

    @staticmethod
    def _render_image_message(msg):
        """Turn a ``(text, image, mode)`` tuple into an HTML string.

        The image is resized so that its short edge is at most 400 px and
        its long edge at most 800 px (never upscaling the short edge),
        encoded as JPEG and embedded as a base64 data URI in front of the
        text, from which any ``<image>`` placeholder is removed.
        """
        text, image, _image_process_mode = msg
        max_hw, min_hw = max(image.size), min(image.size)
        aspect_ratio = max_hw / min_hw
        max_len, min_len = 800, 400
        shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
        longest_edge = int(shortest_edge * aspect_ratio)
        W, H = image.size
        if H > W:
            H, W = longest_edge, shortest_edge
        else:
            H, W = shortest_edge, longest_edge
        image = image.resize((W, H))
        # JPEG has no alpha channel or palette; Pillow refuses to write
        # these modes, so convert them first.
        if image.mode in ("RGBA", "LA", "P"):
            image = image.convert("RGB")
        buffered = io.BytesIO()
        image.save(buffered, format="JPEG")
        img_b64_str = base64.b64encode(buffered.getvalue()).decode()
        # The payload is JPEG, so the data URI must say image/jpeg.
        img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
        return img_str + text.replace('<image>', '').strip()

    def to_gradio_chatbot(self):
        """Convert the history into ``[[query, reply], ...]`` pairs.

        Messages at even positions (counted from ``offset``) open a new
        pair; the following odd-position message fills its reply slot.
        Tuple messages are rendered with the image embedded as HTML.
        """
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if isinstance(msg, tuple):
                    msg = self._render_image_message(msg)
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        """Return a copy of this instance with an independent message list.

        Every field is carried over (including the image/video paths,
        caption and ``skip_next``); only ``messages`` is rebuilt so that
        appending to or editing pairs of the copy does not affect the
        original.
        """
        return dataclasses.replace(
            self,
            messages=[[role, msg] for role, msg in self.messages],
        )

    def dict(self):
        """Return selected fields as a plain dictionary."""
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
            "path_to_img": self.path_to_img,
            "video_title" : self.video_title,
            "path_to_video": self.path_to_video,
            "caption" : self.caption,
        }

    def get_path_to_subvideos(self):
        """Return the path of the video clip to display, or ``None``.

        When both ``video_title`` and ``path_to_img`` are set, the path is
        derived from ``video_helper_map`` and the index at the end of the
        image file name (``..._<index>.jpg`` -> ``<prefix><index>.mp4``).
        Otherwise ``path_to_video`` is returned when it is set.
        """
        if self.video_title is not None and self.path_to_img is not None:
            # NOTE(review): ``video_helper_map`` is not defined in the part
            # of the module visible here; this branch raises NameError
            # unless it is provided elsewhere. Confirm before relying on
            # ``video_title``.
            info = video_helper_map[self.video_title]
            path = info['path']
            prefix = info['prefix']
            vid_index = self.path_to_img.split('/')[-1]
            vid_index = vid_index.split('_')[-1]
            vid_index = vid_index.replace('.jpg', '')
            return os.path.join(path, f"{prefix}{vid_index}.mp4")
        elif self.path_to_video is not None:
            return self.path_to_video
        return None
|
|
def get_gradio_instance(mm_rag_chain=None):
    """Create an empty ``GradioInstance`` bound to a RAG chain.

    Args:
        mm_rag_chain: Chain to store on the instance. When ``None``, the
            chain from ``get_default_rag_chain()`` is built and used.

    Returns:
        A ``GradioInstance`` with no messages, offset 0, the roles of
        ``prediction_guard_llava_conv`` and no image, video title or
        caption.
    """
    chain = get_default_rag_chain() if mm_rag_chain is None else mm_rag_chain
    return GradioInstance(
        system="",
        roles=prediction_guard_llava_conv.roles,
        messages=[],
        offset=0,
        sep_style=SeparatorStyle.SINGLE,
        sep="\n",
        path_to_img=None,
        video_title=None,
        caption=None,
        mm_rag_chain=chain,
    )
|
|
# Register ./assets/ as a static path; the font in ``css`` and the header
# image in ``html_title`` are referenced from that directory.
gr.set_static_paths(paths=["./assets/"])

# Gradio theme for the demo: two blue palettes on top of the Base theme,
# plus a few overrides (mostly for dark mode).
theme = gr.themes.Base(
    primary_hue=gr.themes.Color(
        c50="#eff6ff",
        c100="#dbeafe",
        c200="#bfdbfe",
        c300="#93c5fd",
        c400="#60a5fa",
        c500="#0054ae",
        c600="#00377c",
        c700="#00377c",
        c800="#1e40af",
        c900="#1e3a8a",
        c950="#0a0c2b",
    ),
    secondary_hue=gr.themes.Color(
        c50="#eff6ff",
        c100="#dbeafe",
        c200="#bfdbfe",
        c300="#93c5fd",
        c400="#60a5fa",
        c500="#0054ae",
        c600="#0054ae",
        c700="#0054ae",
        c800="#1e40af",
        c900="#1e3a8a",
        c950="#1d3660",
    ),
).set(
    body_background_fill_dark='*primary_950',
    body_text_color_dark='*neutral_300',
    border_color_accent='*primary_700',
    border_color_accent_dark='*neutral_800',
    block_background_fill_dark='*primary_950',
    block_border_width='2px',
    block_border_width_dark='2px',
    button_primary_background_fill_dark='*primary_500',
    button_primary_border_color_dark='*primary_500',
)
|
|
# Extra CSS handed to gr.Blocks: declares the IntelOne font, sets the page
# background colour and removes table borders.
css = """
@font-face {
font-family: IntelOne;
src: url("/file=./assets/intelone-bodytext-font-family-regular.ttf");
}
.gradio-container {background-color: #0a0c2b}
table {
border-collapse: collapse;
border: none;
}
"""
|
|
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
|
|
# HTML header shown at the top of the page: a one-cell table holding the
# banner image served from ./assets/.
html_title = """
<table style="bordercolor=#0a0c2b; border=0">
<tr style="height:150px; border:0">
<td style="border:0"><img src="/file=./assets/header.png"></td>
</tr>
</table>

"""
|
|
| |
# Sample queries offered in the "Query" dropdown of the demo.
dropdown_list = [
    "What is the name of one of the astronauts?",
    "An astronaut's spacewalk",
    "What does the astronaut say?",
]
|
|
# Button values returned by the event handlers for the "Clear history"
# button slot of their outputs.
no_change_btn = gr.Button()                  # no properties specified
enable_btn = gr.Button(interactive=True)     # make the button clickable
disable_btn = gr.Button(interactive=False)   # make the button non-clickable
|
|
def clear_history(state, request: gr.Request):
    """Reset the session to an empty conversation.

    The RAG chain of the current state is reused for the new state.

    Returns:
        A 5-tuple: the fresh state, its (empty) chatbot history, ``""``
        for the query box, ``None`` for the video, and ``disable_btn``
        for the clear button.
    """
    fresh_state = get_gradio_instance(state.mm_rag_chain)
    return (fresh_state, fresh_state.to_gradio_chatbot(), "", None, disable_btn)
|
|
def add_text(state, text, request: gr.Request):
    """Record a submitted query in the conversation state.

    Args:
        state: The session's ``GradioInstance``.
        text: Content of the query box. May be ``None`` or empty when
            nothing was typed or selected.
        request: Gradio request object (unused).

    Returns:
        A 4-tuple matching the outputs this handler is wired to in
        ``get_demo`` (state, chatbot, query box, clear button): the
        updated state, the chatbot history, ``""`` to clear the query
        box, and a button value.
    """
    # ``not text`` also covers None, which len() would reject.
    if not text:
        state.skip_next = True
        # Same number of values as the normal path below; the handler is
        # wired to four outputs.
        return (state, state.to_gradio_chatbot(), "", no_change_btn)

    # Cap the query length.
    text = text[:1536]

    # Queue the query together with an empty slot for the reply.
    state.append_message(state.roles[0], text)
    state.append_message(state.roles[1], None)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "", disable_btn)
| |
def http_bot(
    state, request: gr.Request
):
    """Generate the reply for the query queued by ``add_text``.

    This is a generator. Every yielded value is a 4-tuple
    ``(state, chatbot_history, video_path, clear_button)``:

    1. a placeholder reply ("▌") with the clear button disabled, then
    2. either the final reply, or ``server_error_msg`` (with the video
       reset to ``None``) if inference raised; the clear button is
       enabled in both cases.

    When ``state.skip_next`` is set, a single tuple with the current state
    is yielded and nothing else happens.

    The first query of a conversation goes through ``state.mm_rag_chain``,
    which also provides the image, video clip and transcript stored on the
    state. Later queries are answered by
    ``lvlm_inference_with_conversation`` on the accumulated conversation.
    """
    if state.skip_next:
        # Nothing valid was submitted; leave everything as it is.
        yield (state, state.to_gradio_chatbot(), state.get_path_to_subvideos(), no_change_btn)
        return

    # A retrieval round starts from a clean state holding only the newest
    # query. That is the case for the first exchange, and also when an
    # earlier retrieval failed and left the history with no image: without
    # the restart, get_prompt_for_rag() below would reject the longer
    # history and every later query would fail until "Clear history".
    is_first_exchange = len(state.messages) == state.offset + 2
    retrieval_never_succeeded = state.path_to_img is None and len(state.messages) >= 2
    if is_first_exchange or retrieval_never_succeeded:
        new_state = get_gradio_instance(state.mm_rag_chain)
        new_state.append_message(new_state.roles[0], state.messages[-2][1])
        new_state.append_message(new_state.roles[1], None)
        state = new_state

    # No image yet means nothing has been retrieved: run the RAG chain.
    is_very_first_query = not state.get_images(return_pil=False)
    if is_very_first_query:
        prompt_or_conversation = state.get_prompt_for_rag()
        executor = state.mm_rag_chain
    else:
        prompt_or_conversation = state.get_conversation_for_lvlm()
        executor = lvlm_inference_with_conversation

    # Show a cursor-like placeholder while the model is working.
    state.messages[-1][-1] = "▌"
    yield (state, state.to_gradio_chatbot(), state.get_path_to_subvideos(), disable_btn)

    try:
        if is_very_first_query:
            response = executor.invoke(prompt_or_conversation)
            message = response['final_text_output']
            lvlm_input = response['input_to_lvlm']
            if 'metadata' not in lvlm_input:
                raise ValueError("Response's format is changed")
            metadata = lvlm_input['metadata']

            # Remember the retrieved frame, a clip around it, and its
            # transcript for display and for follow-up queries.
            if state.path_to_img is None and 'image' in lvlm_input:
                state.path_to_img = lvlm_input['image']

            if state.path_to_video is None and 'video_path' in metadata:
                state.path_to_video = split_video(metadata['video_path'], metadata['mid_time_ms'])

            if state.caption is None and 'transcript' in metadata:
                state.caption = metadata['transcript']
        else:
            message = executor(prompt_or_conversation)
    except Exception as e:
        # Top-level boundary of the handler: report and show a generic
        # error in the chat instead of crashing the event.
        print(e)
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot(), None, enable_btn)
        return

    state.messages[-1][-1] = message
    yield (state, state.to_gradio_chatbot(), state.get_path_to_subvideos(), enable_btn)
| |
def get_demo(rag_chain=None):
    """Build the Gradio Blocks application for the multimodal RAG demo.

    Args:
        rag_chain: Chain used for retrieval queries. When ``None``, the
            chain from ``get_default_rag_chain()`` is used.

    Returns:
        The ``gr.Blocks`` object (not launched).
    """
    chain = get_default_rag_chain() if rag_chain is None else rag_chain

    # Run in the browser on page load: if the URL has no ``__theme``
    # parameter, add ``__theme=dark`` and reload with it.
    force_dark_theme_js = (
        "\n"
        "() => {\n"
        "const params = new URLSearchParams(window.location.search);\n"
        "if (!params.has('__theme')) {\n"
        "params.set('__theme', 'dark');\n"
        "window.location.search = params.toString();\n"
        "}\n"
        "}"
    )

    with gr.Blocks(theme=theme, css=css) as demo:
        # Conversation state handed to every event handler.
        state = gr.State(get_gradio_instance(chain))
        demo.load(None, None, js=force_dark_theme_js)

        gr.HTML(value=html_title)

        # Top row: retrieved video clip on the left, chat on the right.
        with gr.Row():
            with gr.Column(scale=4):
                video = gr.Video(height=512, width=512, elem_id="video", interactive=False)
            with gr.Column(scale=7):
                chatbot = gr.Chatbot(
                    elem_id="chatbot",
                    label="Multimodal RAG Chatbot",
                    height=512,
                )

        # Middle row: query input (free text or a sample) and send button.
        with gr.Row():
            with gr.Column(scale=8):
                textbox = gr.Dropdown(
                    dropdown_list,
                    allow_custom_value=True,
                    label="Query",
                    info="Enter your query here or choose a sample from the dropdown list!",
                )
            with gr.Column(scale=1, min_width=50):
                submit_btn = gr.Button(value="Send", variant="primary", interactive=True)

        # Bottom row: clear button, disabled until handlers enable it.
        with gr.Row(elem_id="buttons"):
            clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)

        btn_list = [clear_btn]

        clear_btn.click(
            clear_history,
            [state],
            [state, chatbot, textbox, video, *btn_list],
        )
        # Sending first records the query, then streams the reply.
        query_recorded = submit_btn.click(
            add_text,
            [state, textbox],
            [state, chatbot, textbox, *btn_list],
        )
        query_recorded.then(
            http_bot,
            [state],
            [state, chatbot, video, *btn_list],
        )
    return demo
|
|
|
|