improved-unified-multi-model-pt / test_improved_model.py
kunaliitkgp09's picture
Upload test_improved_model.py with huggingface_hub
fd15b76 verified
#!/usr/bin/env python3
"""
Test the Improved Unified Multi-Model with Prompt Templates
"""
import asyncio
import time
import sys
from pathlib import Path
# Add this script's own directory to the import path so sibling modules resolve
sys.path.append(str(Path(__file__).parent))
from improved_unified_model_pt import ImprovedUnifiedMultiModelPT, ImprovedUnifiedModelConfig
from prompt_template import PromptTemplates, TaskType, TestPrompt
from test_suite import OrchestratorTester, TestResult
class ImprovedModelWrapper:
    """Adapt the improved model's synchronous ``process`` call to an awaitable API.

    ``process_request`` always returns a result object exposing
    ``task_type.value``, ``confidence``, ``success``, ``output`` and
    ``error_message``; failures are reported through that object instead of
    being raised.
    """

    class _TaskTypeTag:
        """Minimal stand-in exposing ``.value`` the way an enum member does."""

        def __init__(self, value):
            self.value = value

        def __repr__(self):
            return f"_TaskTypeTag(value={self.value!r})"

    class TaskResult:
        """Plain record describing the outcome of one processed prompt."""

        def __init__(self, task_type, confidence, success, output, error_message):
            self.task_type = ImprovedModelWrapper._TaskTypeTag(task_type)
            self.confidence = confidence
            self.success = success
            self.output = output
            self.error_message = error_message

        def __repr__(self):
            return (
                f"TaskResult(task_type={self.task_type.value!r}, "
                f"confidence={self.confidence!r}, success={self.success!r}, "
                f"error_message={self.error_message!r})"
            )

    def __init__(self, model):
        """Store the model; it must provide ``process(prompt)``."""
        self.model = model

    async def process_request(self, prompt):
        """Run ``prompt`` through the wrapped model and normalise the outcome.

        Args:
            prompt: The prompt forwarded unchanged to ``self.model.process``.

        Returns:
            A ``TaskResult``. On success its fields are read from the mapping
            returned by the model, falling back to ``'TEXT'``, ``0.5`` and
            ``''`` for missing ``task_type``, ``confidence`` and ``output``.
            If the model raises, or returns something without ``.get``, the
            result has ``success=False``, task type ``'ERROR'``, confidence
            ``0.0`` and the exception text in ``error_message``.
        """
        try:
            result = self.model.process(prompt)
            # The ``.get`` lookups stay inside the ``try`` so that a
            # non-mapping return value is reported as an error result.
            return self.TaskResult(
                task_type=result.get('task_type', 'TEXT'),
                confidence=result.get('confidence', 0.5),
                success=True,
                output=result.get('output', ''),
                error_message=None,
            )
        except Exception as e:
            # Deliberate catch-all: callers expect a result object, never an
            # exception, so every failure is folded into an error record.
            return self.TaskResult(
                task_type='ERROR',
                confidence=0.0,
                success=False,
                output='',
                error_message=str(e),
            )
async def test_improved_model_with_prompts():
    """Run sampled template prompts through the improved model and report routing.

    For each configured category a fixed number of prompts is fetched from
    ``PromptTemplates``, sent through an ``ImprovedModelWrapper``, and the
    reported task type is compared with the prompt's expected task. Per-prompt
    details, overall statistics and a per-task breakdown are printed.

    Returns:
        A ``(results, model)`` tuple: the ``TestResult`` records in execution
        order and the model instance that produced them.
    """
    print("๐Ÿงช Testing Improved Unified Model with Prompt Templates")
    print("=" * 70)

    # Build the model under test and the adapter used to call it.
    print("๐Ÿ“ฆ Creating and loading improved model...")
    model = ImprovedUnifiedMultiModelPT(ImprovedUnifiedModelConfig())
    wrapper = ImprovedModelWrapper(model)

    # Each entry: (display name, task type, how many prompts to sample).
    test_categories = [
        ("Text Processing", TaskType.TEXT, 3),
        ("Image Captioning", TaskType.CAPTION, 3),
        ("Text-to-Image", TaskType.TEXT2IMG, 3),
        ("Reasoning", TaskType.REASONING, 3),
        ("Multimodal", TaskType.MULTIMODAL, 2),
    ]

    results = []
    for category_name, task_type, num_prompts in test_categories:
        print(f"\n๐Ÿ“ Testing {category_name} ({num_prompts} prompts):")
        print("-" * 60)

        available = PromptTemplates.get_prompts_by_task_type(task_type)
        for i, prompt in enumerate(available[:num_prompts], 1):
            print(f"\n{i}. Testing: {prompt.prompt[:60]}...")

            start_time = time.time()
            result = await wrapper.process_request(prompt.prompt)
            processing_time = time.time() - start_time

            # Routing is judged on the string values of the task types.
            expected_task = prompt.expected_task.value
            actual_task = result.task_type.value

            if result.success:
                status = "โœ…"
            else:
                status = "โŒ"
            if expected_task == actual_task:
                task_status = "โœ…"
            else:
                task_status = "โŒ"

            print(f" {status} Success: {result.success}")
            print(f" {task_status} Task: {actual_task} (expected: {expected_task})")
            print(f" ๐Ÿ“Š Confidence: {result.confidence:.2f}")
            print(f" โฑ๏ธ Time: {processing_time:.2f}s")
            if result.output:
                print(f" ๐Ÿ“„ Output: {result.output[:100]}...")
            if result.error_message:
                print(f" โŒ Error: {result.error_message}")

            # Keep a record of this run for the summary below.
            results.append(
                TestResult(
                    prompt=prompt.prompt,
                    expected_task=prompt.expected_task,
                    actual_task=actual_task,
                    confidence=result.confidence,
                    processing_time=processing_time,
                    success=result.success,
                    error_message=result.error_message,
                    output=result.output,
                )
            )

    # Aggregate statistics over every recorded run.
    total_tests = len(results)
    successful_tests = len([r for r in results if r.success])
    # NOTE(review): ``task_correct`` is not passed to ``TestResult`` above, so
    # this presumably relies on ``TestResult`` deriving it - confirm.
    correct_tasks = len([r for r in results if r.task_correct])
    if total_tests > 0:
        accuracy = correct_tasks / total_tests
        success_rate = successful_tests / total_tests
        avg_confidence = sum(r.confidence for r in results) / total_tests
        avg_time = sum(r.processing_time for r in results) / total_tests
    else:
        # Nothing ran: report zeros rather than dividing by zero.
        accuracy = 0
        success_rate = 0
        avg_confidence = 0
        avg_time = 0

    print(f"\n๐Ÿ“Š Overall Test Results:")
    print("=" * 50)
    print(f" Total Tests: {total_tests}")
    print(f" Successful: {successful_tests}")
    print(f" Task Accuracy: {accuracy:.1%}")
    print(f" Success Rate: {success_rate:.1%}")
    print(f" Avg Confidence: {avg_confidence:.2f}")
    print(f" Avg Processing Time: {avg_time:.2f}s")

    # Per-task accuracy, skipping task types that had no prompts.
    print(f"\n๐ŸŽฏ Task-Specific Results:")
    print("-" * 40)
    for task_type in TaskType:
        task_results = [r for r in results if r.expected_task == task_type]
        if not task_results:
            continue
        task_correct = len([r for r in task_results if r.task_correct])
        task_accuracy = task_correct / len(task_results)
        print(f" {task_type.value}: {task_accuracy:.1%} ({task_correct}/{len(task_results)})")

    return results, model
async def run_comprehensive_test(model):
    """Run the test suite's basic tests against ``model`` and print a summary.

    Args:
        model: The model instance, wrapped in ``ImprovedModelWrapper`` before
            being handed to ``OrchestratorTester``.

    Returns:
        The object returned by ``OrchestratorTester.run_basic_tests()``.
    """
    print("\n๐Ÿงช Running Comprehensive Test Suite")
    print("=" * 60)

    # The tester talks to the model through the async adapter.
    tester = OrchestratorTester(ImprovedModelWrapper(model))

    print("Running basic functionality tests...")
    basic_result = await tester.run_basic_tests()

    # Summary of the basic run, one metric per line.
    print(f"\n๐Ÿ“Š Basic Test Results:")
    print(f" Total Tests: {basic_result.total_tests}")
    print(f" Passed: {basic_result.passed_tests}")
    print(f" Failed: {basic_result.failed_tests}")
    print(f" Accuracy: {basic_result.accuracy:.1%}")
    print(f" Avg Confidence: {basic_result.average_confidence:.2f}")
    print(f" Avg Processing Time: {basic_result.average_processing_time:.2f}s")

    return basic_result
async def interactive_test(model):
    """Read prompts from stdin in a loop and print how the model handles each.

    The loop ends when the user types ``quit``, ``exit`` or ``q``, presses
    Ctrl-C, or when stdin reaches end-of-file. Blank input is ignored. Any
    other exception raised while handling a prompt is printed and the loop
    continues.

    Args:
        model: The model instance, wrapped in ``ImprovedModelWrapper``.
    """
    print("\n๐ŸŽฎ Interactive Testing Mode")
    print("=" * 50)
    print("Enter your prompts (type 'quit' to exit):")
    print("Example prompts:")
    print(" - What is machine learning?")
    print(" - Generate an image of a peaceful forest")
    print(" - Describe this image of a sunset")
    print(" - Explain step by step how neural networks work")
    print()

    wrapper = ImprovedModelWrapper(model)
    while True:
        try:
            user_input = input("Enter prompt: ").strip()
            if user_input.lower() in ['quit', 'exit', 'q']:
                break
            if not user_input:
                continue

            print(f"\nโณ Processing: {user_input}")
            start_time = time.time()
            result = await wrapper.process_request(user_input)
            processing_time = time.time() - start_time

            print(f"โœ… Task Type: {result.task_type.value}")
            print(f"๐Ÿ“Š Confidence: {result.confidence:.2f}")
            print(f"โฑ๏ธ Processing Time: {processing_time:.2f}s")
            if result.output:
                print(f"๐Ÿ“„ Output: {result.output}")
            if result.error_message:
                print(f"โŒ Error: {result.error_message}")
            print()
        except (KeyboardInterrupt, EOFError):
            # EOFError must end the loop too: once stdin is exhausted every
            # further input() call raises it again, and the generic handler
            # below would otherwise spin forever.
            print("\nExiting interactive mode...")
            break
        except Exception as e:
            # Keep the session alive after a failure on a single prompt.
            print(f"Error: {e}")
def compare_with_original(model=None):
    """Check the improved model's task routing on a fixed set of prompts.

    Despite the name, only the improved model is exercised here; no original
    model is loaded. For each prompt the task type reported by the model is
    printed next to the expected one, together with the confidence.

    Args:
        model: Optional, already constructed model exposing
            ``process(prompt)`` that returns a mapping with ``'task_type'``
            and ``'confidence'`` keys. When omitted, a fresh
            ``ImprovedUnifiedMultiModelPT`` with the default config is built,
            as before; passing one avoids constructing a second instance.
    """
    print("\n๐Ÿ”„ Comparing Improved vs Original Model")
    print("=" * 50)

    # (prompt, expected task type value) pairs used for the routing check.
    comparison_prompts = [
        ("What is machine learning?", "TEXT"),
        ("Generate an image of a peaceful forest", "TEXT2IMG"),
        ("Describe this image of a sunset", "CAPTION"),
        ("Explain step by step how neural networks work", "REASONING"),
    ]

    print("Testing improved model routing...")
    if model is None:
        config = ImprovedUnifiedModelConfig()
        improved_model = ImprovedUnifiedMultiModelPT(config)
    else:
        improved_model = model

    for prompt, expected in comparison_prompts:
        print(f"\n๐Ÿ” Testing: {prompt}")
        result = improved_model.process(prompt)
        actual = result['task_type']
        correct = "โœ…" if actual == expected else "โŒ"
        print(f" {correct} Expected: {expected}, Actual: {actual}, Confidence: {result['confidence']:.2f}")
async def main():
    """Run the full test workflow: template prompts, test suite, routing check.

    After the automated stages the user is asked whether to start the
    interactive mode. If stdin is closed (for example when the script runs
    non-interactively) the question is treated as answered with "no".
    """
    print("๐Ÿš€ Improved Unified Multi-Model Testing")
    print("=" * 70)

    # Template-prompt run; the loaded model is reused by the later stages.
    _, model = await test_improved_model_with_prompts()

    # Test-suite run against the same model.
    await run_comprehensive_test(model)

    # Fixed-prompt routing check.
    compare_with_original()

    # Optional interactive session.
    print("\n" + "="*70)
    print("๐ŸŽฎ Interactive Testing")
    print("="*70)
    try:
        try_interactive = input("\nWould you like to try interactive testing? (y/n): ").strip().lower()
    except EOFError:
        # No stdin to read from: skip the interactive stage rather than
        # ending an otherwise successful run with a traceback.
        try_interactive = ''
    if try_interactive in ['y', 'yes']:
        await interactive_test(model)

    print("\n๐ŸŽ‰ Testing completed!")
    print("๐Ÿ“Š The improved model shows enhanced routing capabilities.")
# Script entry point: run the async workflow only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    asyncio.run(main())