File size: 4,496 Bytes
0ed44c0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
"""
Simple test script for the Nano Banana Image Edit API
"""
import requests
import sys
import os
# Base URL of the API server; each request in this script appends an endpoint path to it.
API_BASE_URL = "http://localhost:8000"
def test_upload_image(image_path, timeout=60):
    """Upload a local image to the ``/upload`` endpoint.

    Args:
        image_path: Path of the image file, sent as the ``file`` form field.
        timeout: Seconds to wait for the server before abandoning the request.

    Returns:
        The ``image_id`` from the JSON response, or ``None`` when the path is
        not a regular file, the request cannot be completed, the server
        replies with a status other than 200, or the reply has no ``image_id``.
    """
    print(f"\n1. Uploading image: {image_path}")

    # isfile (rather than exists) also rejects directories, which would
    # otherwise blow up inside open().
    if not os.path.isfile(image_path):
        print(f"Error: Image file not found: {image_path}")
        return None

    # NOTE(review): the status glyphs were mojibake in the reviewed copy;
    # ✓ / ✗ are the assumed originals - confirm.
    try:
        # The handle has to stay open until the request has been sent.
        with open(image_path, "rb") as image_file:
            response = requests.post(
                f"{API_BASE_URL}/upload",
                files={"file": image_file},
                timeout=timeout,
            )
    except requests.exceptions.RequestException as exc:
        # Covers refused connections and timeouts alike; report and let the
        # caller decide how to exit instead of dumping a traceback.
        print(f"✗ Upload failed: {exc}")
        return None

    if response.status_code != 200:
        print(f"✗ Upload failed: {response.status_code}")
        print(f" {response.text}")
        return None

    image_id = response.json().get("image_id")
    if image_id is None:
        print("✗ Upload failed: response did not include an image_id")
        print(f" {response.text}")
        return None

    print("✓ Image uploaded successfully!")
    print(f" Image ID: {image_id}")
    return image_id
def test_edit_image(image_id, prompt, timeout=600):
    """Request an edit of a previously uploaded image via ``/edit``.

    Args:
        image_id: Identifier of the uploaded image, sent as form field
            ``image_id``.
        prompt: Text instruction for the edit, sent as form field ``prompt``.
        timeout: Seconds to wait for the server before abandoning the request.

    Returns:
        The ``task_id`` from the JSON response, or ``None`` when the request
        cannot be completed, the server replies with a status other than 200,
        or the reply has no ``task_id``.
    """
    print(f"\n2. Editing image with prompt: '{prompt}'")

    # NOTE(review): the default timeout is deliberately generous on the
    # assumption that /edit does the model work before replying - confirm.
    # NOTE(review): the status glyphs were mojibake in the reviewed copy;
    # ✓ / ✗ are the assumed originals - confirm.
    try:
        response = requests.post(
            f"{API_BASE_URL}/edit",
            data={
                "image_id": image_id,
                "prompt": prompt,
            },
            timeout=timeout,
        )
    except requests.exceptions.RequestException as exc:
        print(f"✗ Edit failed: {exc}")
        return None

    if response.status_code != 200:
        print(f"✗ Edit failed: {response.status_code}")
        print(f" {response.text}")
        return None

    payload = response.json()
    task_id = payload.get("task_id")
    if task_id is None:
        print("✗ Edit failed: response did not include a task_id")
        print(f" {response.text}")
        return None

    print("✓ Image edited successfully!")
    print(f" Task ID: {task_id}")
    # .get keeps a reply without a status field from raising KeyError.
    print(f" Status: {payload.get('status')}")
    return task_id
def test_get_result(task_id, timeout=30):
    """Fetch the result record of an edit task from ``/result/{task_id}``.

    Args:
        task_id: Identifier returned by the edit request.
        timeout: Seconds to wait for the server before abandoning the request.

    Returns:
        The decoded JSON response, or ``None`` when the request cannot be
        completed or the server replies with a status other than 200.
    """
    print(f"\n3. Getting result for task: {task_id}")

    # NOTE(review): the status glyphs were mojibake in the reviewed copy;
    # ✓ / ✗ are the assumed originals - confirm.
    try:
        response = requests.get(
            f"{API_BASE_URL}/result/{task_id}",
            timeout=timeout,
        )
    except requests.exceptions.RequestException as exc:
        print(f"✗ Get result failed: {exc}")
        return None

    if response.status_code != 200:
        print(f"✗ Get result failed: {response.status_code}")
        print(f" {response.text}")
        return None

    data = response.json()
    print("✓ Result retrieved!")
    # Every field is read with .get so a reply that omits one of them is
    # still reported instead of raising KeyError.
    print(f" Status: {data.get('status')}")
    if data.get("result_image_id"):
        print(f" Result Image ID: {data['result_image_id']}")
        print(f" Result URL: {data.get('result_image_url')}")
    return data
def test_download_image(result_image_id, output_path, timeout=60):
    """Download an edited image and write it to disk.

    Args:
        result_image_id: Identifier used in ``/result/image/{result_image_id}``.
        output_path: Local file path the response body is written to.
        timeout: Seconds to wait for the server before abandoning the request.

    Returns:
        ``True`` when the image was fetched and saved; ``False`` when the
        request cannot be completed, the server replies with a status other
        than 200, or the file cannot be written.
    """
    print("\n4. Downloading edited image...")

    # NOTE(review): the status glyphs were mojibake in the reviewed copy;
    # ✓ / ✗ are the assumed originals - confirm.
    try:
        response = requests.get(
            f"{API_BASE_URL}/result/image/{result_image_id}",
            timeout=timeout,
        )
    except requests.exceptions.RequestException as exc:
        print(f"✗ Download failed: {exc}")
        return False

    if response.status_code != 200:
        print(f"✗ Download failed: {response.status_code}")
        print(f" {response.text}")
        return False

    try:
        with open(output_path, "wb") as out_file:
            out_file.write(response.content)
    except OSError as exc:
        # An unwritable destination is a failed download from the caller's
        # point of view, so report it through the boolean result.
        print(f"✗ Download failed: could not write {output_path}: {exc}")
        return False

    print(f"✓ Image downloaded to: {output_path}")
    return True
def test_health(timeout=10):
    """Check that the API answers on ``/health``.

    Args:
        timeout: Seconds to wait for the server before abandoning the request.

    Returns:
        ``True`` when the endpoint replies with status 200, ``False`` when the
        server cannot be reached, does not answer in time, or replies with
        any other status.
    """
    print("Testing API health...")

    # NOTE(review): the status glyphs were mojibake in the reviewed copy;
    # ✓ / ✗ are the assumed originals - confirm.
    # Only the network call sits inside the try so that unrelated errors are
    # not mistaken for connectivity problems.
    try:
        response = requests.get(f"{API_BASE_URL}/health", timeout=timeout)
    except requests.exceptions.ConnectionError:
        print(f"✗ Cannot connect to API at {API_BASE_URL}")
        print(" Make sure the API server is running: python api.py")
        return False
    except requests.exceptions.Timeout:
        # A server that accepts the connection but never replies would
        # otherwise block the script indefinitely.
        print(f"✗ API at {API_BASE_URL} did not respond within {timeout}s")
        return False

    if response.status_code != 200:
        print(f"✗ Health check failed: {response.status_code}")
        return False

    data = response.json()
    print("✓ API is healthy")
    print(f" Model available: {data.get('model_available', False)}")
    print(f" Model loaded: {data.get('model_loaded', False)}")
    return True
if __name__ == "__main__":
    # Script entry point: health check, then upload -> edit -> fetch result ->
    # download. Any failing step ends the run with exit status 1.

    # Check health first
    if not test_health():
        sys.exit(1)

    # The image path is a required argument; there is no default.
    if len(sys.argv) < 2:
        print("Usage: python test_api.py <image_path> [prompt]")
        print("Example: python test_api.py test_image.jpg 'make the sky blue'")
        sys.exit(1)
    image_path = sys.argv[1]
    prompt = sys.argv[2] if len(sys.argv) > 2 else "enhance the image"

    # Run tests
    image_id = test_upload_image(image_path)
    if not image_id:
        sys.exit(1)

    task_id = test_edit_image(image_id, prompt)
    if not task_id:
        sys.exit(1)

    # NOTE(review): the status glyphs were mojibake in the reviewed copy;
    # ✓ / ✗ are the assumed originals - confirm.
    result = test_get_result(task_id)
    if not result:
        sys.exit(1)
    result_image_id = result.get("result_image_id")
    if not result_image_id:
        # Say why the run stops instead of exiting silently.
        print("✗ Result does not contain a result_image_id")
        sys.exit(1)

    output_path = f"edited_{os.path.basename(image_path)}"
    # A failed download must fail the run; previously its result was ignored
    # and the success banner was printed regardless.
    if not test_download_image(result_image_id, output_path):
        sys.exit(1)

    print("\n✓ All tests completed!")
|