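"""
Quick smoke test for the EndpointHandler defined in handler.py.

Builds an OpenAI-style chat request, passes it to the handler, and prints the
JSON response. Example invocation (the script filename is illustrative):

    python test_inference.py --prompt "Hello there" --max_tokens 64
"""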
import json
import argparse
import sys
import traceback
from handler import EndpointHandler

def test_inference(model_path=".", prompt=None, max_tokens=150, temperature=0.7):
    """
    Test the inference endpoint handler with a sample request.
    
    Args:
        model_path: Path to the model directory
        prompt: Custom prompt to use (optional)
        max_tokens: Maximum number of tokens to generate
        temperature: Temperature for generation
    """
    try:
        print(f"Initializing handler with model path: {model_path}")
        handler = EndpointHandler(model_path)
        
        # Fall back to a default question when no custom prompt is given
        if prompt is None:
            prompt = "Explain quantum computing in simple terms."
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt}
        ]
        
        # Sample request with OpenAI-like format
        request = {
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": 0.95
        }
        
        print("Sending request to handler...")
        print(f"Request: {json.dumps(request, indent=2)}")
        
        # Generate response
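        # EndpointHandler is expected to be callable with a dict payload
        # (per the Hugging Face Inference Endpoints custom-handler convention).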
        response = handler(request)
        
        # Print response in a readable format
        print("\nResponse:")
        print(json.dumps(response, indent=2))
        
        return response
    
    except Exception as e:
        print(f"Error during inference: {e}", file=sys.stderr)
        traceback.print_exc()
        return {"error": str(e)}

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Test Phi-4 Mini inference")
    parser.add_argument("--model_path", type=str, default=".", help="Path to the model directory")
    parser.add_argument("--prompt", type=str, help="Custom prompt to use")
    parser.add_argument("--max_tokens", type=int, default=150, help="Maximum number of tokens to generate")
    parser.add_argument("--temperature", type=float, default=0.7, help="Temperature for generation")
    
    args = parser.parse_args()
    test_inference(
        model_path=args.model_path,
        prompt=args.prompt,
        max_tokens=args.max_tokens,
        temperature=args.temperature
    )