| | import gradio as gr |
| | import torch |
| | import torch.nn.functional as F |
| | from sentence_transformers import SentenceTransformer |
| |
|
| | |
# Sentence-embedding model used for both the reference phrases and incoming
# queries. `revision` is forwarded unchanged to SentenceTransformer; it is
# left as None here (presumably selecting the default revision - confirm
# against the sentence-transformers docs before pinning).
revision = None
model = SentenceTransformer(
    "avsolatorio/GIST-small-Embedding-v0",
    revision=revision,
)
| |
|
| | |
# Reference phrases the classifier matches queries against. Each entry has the
# form "<category>: <example phrasings>"; the text before the first ": " is the
# label returned by find_query_type.
#
# Every entry must end with a comma: adjacent string literals without one are
# silently concatenated by Python into a single (wrong) reference phrase.
ref_texts = [
    "Theatro App: Hello John. Hey John. Hi John. Call John",
    "Theatro App: Message John. Message for John. Leave a message for John",
    "Theatro App: Play messages. Listen to messages",
    "Theatro App: What time is it?",
    "Theatro App: Cashier Backup. Backup Cashier. Register backup. Register assistance.",
    "Theatro App: repeat",
    "Theatro App: Check inventory",
    "Theatro App: Check Sales",
    "Theatro App: Curbside Pickup",
    "Theatro App: Replay last message.",
    "Theatro App: Post it. Post it for group",
    "Theatro App: Announcement. Announcement for the group",
    "Open question: This is about products sold in TractorSupply.",
    "Open question: This is about pet care.",
    "Open question: What is the weather like?",
    "Open question: What's 15% off from $79.99?",
    "Open question: Can you look up the skew for 1091784?",
]
| |
|
# Embed all reference phrases once at module load; find_query_type compares
# each incoming query against this tensor.
ref_embeddings = model.encode(
    ref_texts,
    convert_to_tensor=True,
)
| |
|
def find_query_type(query):
    """Return the category label of the reference phrase closest to *query*.

    The query is embedded with the module-level ``model`` and scored by cosine
    similarity against ``ref_embeddings``. The label is the portion of the
    best-scoring ``ref_texts`` entry that precedes the first ``": "``.
    """
    # Wrapped in a list so the query is encoded as a one-element batch.
    query_vec = model.encode([query], convert_to_tensor=True)
    similarities = F.cosine_similarity(query_vec, ref_embeddings, dim=-1)

    best = int(similarities.argmax())
    label, _, _ = ref_texts[best].partition(": ")
    return label
| |
|
| | import gradio as gr |
def predict(query):
    """Classify *query* and return its category label (used as the Gradio callback)."""
    return find_query_type(query)
| |
|
# Web UI: a two-line textbox whose content is passed to `predict`; the returned
# label is rendered as plain text.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=2, placeholder="Enter your query here..."),
    outputs="text",
    title="Query Type Classifier",
    description="This model classifies the type of your query. Just input your query and get the predicted category.",
)

# Launched unconditionally when this module is executed.
iface.launch()
| |
|