File size: 4,496 Bytes
0ed44c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
"""
Simple test script for the Nano Banana Image Edit API
"""
import requests
import sys
import os

# Base URL of the API server under test. Defaults to a local server; set the
# API_BASE_URL environment variable to target another host. Any trailing
# slash is stripped so f"{API_BASE_URL}/endpoint" never produces "//".
API_BASE_URL = os.environ.get("API_BASE_URL", "http://localhost:8000").rstrip("/")

def test_upload_image(image_path, *, timeout=60):
    """Upload a local image to the API's ``/upload`` endpoint.

    The file is sent as multipart form field ``file``.

    Args:
        image_path: Path to the image file to upload.
        timeout: Seconds ``requests`` waits for the server before giving up;
            ``None`` waits indefinitely.

    Returns:
        The ``image_id`` from the server's JSON reply, or ``None`` if the file
        is missing, the request fails, or the reply has no ``image_id``.
    """
    print(f"\n1. Uploading image: {image_path}")

    # isfile (not exists) so a directory path is reported here instead of
    # crashing inside open().
    if not os.path.isfile(image_path):
        print(f"Error: Image file not found: {image_path}")
        return None

    try:
        with open(image_path, "rb") as f:
            response = requests.post(
                f"{API_BASE_URL}/upload",
                files={"file": f},
                timeout=timeout,
            )
    except requests.exceptions.RequestException as exc:
        # Connection refused, timeout, etc.: report and fail this step cleanly.
        print(f"✗ Upload failed: {exc}")
        return None

    if response.status_code != 200:
        print(f"✗ Upload failed: {response.status_code}")
        print(f"  {response.text}")
        return None

    try:
        image_id = response.json()["image_id"]
    except (ValueError, KeyError):
        # ValueError covers a non-JSON body; KeyError a JSON body without the id.
        print("✗ Upload failed: response did not contain an image_id")
        print(f"  {response.text}")
        return None

    print("✓ Image uploaded successfully!")
    print(f"  Image ID: {image_id}")
    return image_id

def test_edit_image(image_id, prompt, *, timeout=300):
    """Request an edit of a previously uploaded image.

    Sends ``image_id`` and ``prompt`` as form fields to ``/edit``.

    Args:
        image_id: ID returned by the upload step.
        prompt: Text instruction describing the desired edit.
        timeout: Seconds ``requests`` waits for the server. The default is
            generous in case the server runs the edit before replying
            (NOTE(review): confirm against the API). ``None`` waits
            indefinitely.

    Returns:
        The ``task_id`` from the server's JSON reply, or ``None`` if the
        request fails or the reply has no ``task_id``.
    """
    print(f"\n2. Editing image with prompt: '{prompt}'")

    try:
        response = requests.post(
            f"{API_BASE_URL}/edit",
            data={
                "image_id": image_id,
                "prompt": prompt,
            },
            timeout=timeout,
        )
    except requests.exceptions.RequestException as exc:
        print(f"✗ Edit failed: {exc}")
        return None

    if response.status_code != 200:
        print(f"✗ Edit failed: {response.status_code}")
        print(f"  {response.text}")
        return None

    try:
        data = response.json()
        task_id = data["task_id"]
    except (ValueError, KeyError):
        # ValueError covers a non-JSON body; KeyError a JSON body without the id.
        print("✗ Edit failed: response did not contain a task_id")
        print(f"  {response.text}")
        return None

    print("✓ Image edited successfully!")
    print(f"  Task ID: {task_id}")
    print(f"  Status: {data.get('status')}")
    return task_id

def test_get_result(task_id, *, timeout=60):
    """Fetch the result record for an edit task from ``/result/{task_id}``.

    Args:
        task_id: ID returned by the edit step.
        timeout: Seconds ``requests`` waits for the server; ``None`` waits
            indefinitely.

    Returns:
        The decoded JSON reply, or ``None`` if the request fails or the
        reply is not JSON. The reply carries a ``result_image_id`` only when
        a result image is available.
    """
    print(f"\n3. Getting result for task: {task_id}")

    try:
        response = requests.get(f"{API_BASE_URL}/result/{task_id}", timeout=timeout)
    except requests.exceptions.RequestException as exc:
        print(f"✗ Get result failed: {exc}")
        return None

    if response.status_code != 200:
        print(f"✗ Get result failed: {response.status_code}")
        print(f"  {response.text}")
        return None

    try:
        data = response.json()
    except ValueError:
        print("✗ Get result failed: response was not valid JSON")
        print(f"  {response.text}")
        return None

    print("✓ Result retrieved!")
    # .get() so a reply missing optional fields is still reported, not crashed on.
    print(f"  Status: {data.get('status')}")
    if data.get("result_image_id"):
        print(f"  Result Image ID: {data['result_image_id']}")
        print(f"  Result URL: {data.get('result_image_url')}")
    return data

def test_download_image(result_image_id, output_path, *, timeout=60):
    """Download an edited image from ``/result/image/{result_image_id}``.

    Args:
        result_image_id: ID of the result image reported by the result step.
        output_path: Local file path the image bytes are written to
            (overwritten if it exists).
        timeout: Seconds ``requests`` waits for the server; ``None`` waits
            indefinitely.

    Returns:
        ``True`` if the image was downloaded and written, ``False`` otherwise.
    """
    print("\n4. Downloading edited image...")

    try:
        response = requests.get(
            f"{API_BASE_URL}/result/image/{result_image_id}",
            timeout=timeout,
        )
    except requests.exceptions.RequestException as exc:
        print(f"✗ Download failed: {exc}")
        return False

    if response.status_code != 200:
        print(f"✗ Download failed: {response.status_code}")
        print(f"  {response.text}")
        return False

    with open(output_path, "wb") as f:
        f.write(response.content)
    print(f"✓ Image downloaded to: {output_path}")
    return True

def test_health(*, timeout=10):
    """Check that the API server answers on ``/health``.

    Prints the ``model_available`` / ``model_loaded`` flags from the reply
    (``False`` when a flag is absent).

    Args:
        timeout: Seconds ``requests`` waits for the server; ``None`` waits
            indefinitely.

    Returns:
        ``True`` if ``/health`` returned HTTP 200 with a JSON body, else
        ``False``.
    """
    print("Testing API health...")
    try:
        response = requests.get(f"{API_BASE_URL}/health", timeout=timeout)
    except requests.exceptions.ConnectionError:
        print(f"✗ Cannot connect to API at {API_BASE_URL}")
        print("  Make sure the API server is running: python api.py")
        return False
    except requests.exceptions.Timeout:
        # Without a timeout an unresponsive server would hang the script forever.
        print(f"✗ API at {API_BASE_URL} did not respond within {timeout}s")
        return False

    if response.status_code != 200:
        print(f"✗ Health check failed: {response.status_code}")
        return False

    try:
        data = response.json()
    except ValueError:
        print("✗ Health check failed: response was not valid JSON")
        return False

    print("✓ API is healthy")
    print(f"  Model available: {data.get('model_available', False)}")
    print(f"  Model loaded: {data.get('model_loaded', False)}")
    return True

if __name__ == "__main__":
    # Validate arguments before any network call so the usage message is
    # shown even when the server is not running.
    if len(sys.argv) < 2:
        print("Usage: python test_api.py <image_path> [prompt]")
        print("Example: python test_api.py test_image.jpg 'make the sky blue'")
        sys.exit(1)

    image_path = sys.argv[1]
    prompt = sys.argv[2] if len(sys.argv) > 2 else "enhance the image"

    # Check health first
    if not test_health():
        sys.exit(1)

    # Run the upload -> edit -> result -> download pipeline, stopping with a
    # non-zero exit code at the first failing step.
    image_id = test_upload_image(image_path)
    if not image_id:
        sys.exit(1)

    task_id = test_edit_image(image_id, prompt)
    if not task_id:
        sys.exit(1)

    # NOTE(review): the result is fetched once, immediately after the edit.
    # If the API processes edits asynchronously, this may need polling.
    result = test_get_result(task_id)
    if not result:
        sys.exit(1)
    if not result.get("result_image_id"):
        print(f"✗ No result image available (status: {result.get('status')})")
        sys.exit(1)

    output_path = f"edited_{os.path.basename(image_path)}"
    # The download outcome must gate the success message and exit code.
    if not test_download_image(result["result_image_id"], output_path):
        sys.exit(1)

    print("\n✓ All tests completed!")