"""
Simple test script for the Nano Banana Image Edit API.

Runs an end-to-end smoke test against a locally running API server:
health check -> upload -> edit -> fetch result -> download edited image.

Usage:
    python test_api.py <image_path> [prompt]
"""

import os
import sys

import requests

API_BASE_URL = "http://localhost:8000"

# Seconds to wait while establishing a TCP connection to the API server.
CONNECT_TIMEOUT = 10
# Seconds to wait for a response on requests that do not run model inference.
READ_TIMEOUT = 120


def _report_failure(action, response):
    """Print a uniform failure message for a non-200 API response."""
    print(f"✗ {action} failed: {response.status_code}")
    print(f" {response.text}")


def test_upload_image(image_path):
    """Upload a local image file to ``/upload``.

    Args:
        image_path: Path to the image file on disk.

    Returns:
        The ``image_id`` from the server's JSON response, or ``None`` if the
        path is not an existing file or the server returns a non-200 status.
    """
    print(f"\n1. Uploading image: {image_path}")

    # isfile (not exists): a directory would pass exists() and then crash in open().
    if not os.path.isfile(image_path):
        print(f"Error: Image file not found: {image_path}")
        return None

    with open(image_path, "rb") as f:
        response = requests.post(
            f"{API_BASE_URL}/upload",
            files={"file": f},
            timeout=(CONNECT_TIMEOUT, READ_TIMEOUT),
        )

    if response.status_code != 200:
        _report_failure("Upload", response)
        return None

    data = response.json()
    print("✓ Image uploaded successfully!")
    print(f" Image ID: {data['image_id']}")
    return data["image_id"]


def test_edit_image(image_id, prompt):
    """Request an edit of a previously uploaded image via ``/edit``.

    Args:
        image_id: Identifier returned by ``test_upload_image``.
        prompt: Natural-language edit instruction sent as form data.

    Returns:
        The ``task_id`` from the server's JSON response, or ``None`` on a
        non-200 status.
    """
    print(f"\n2. Editing image with prompt: '{prompt}'")

    response = requests.post(
        f"{API_BASE_URL}/edit",
        data={"image_id": image_id, "prompt": prompt},
        # Only bound the connect phase: inference time is model/hardware
        # dependent, so no read timeout is imposed here.
        timeout=(CONNECT_TIMEOUT, None),
    )

    if response.status_code != 200:
        _report_failure("Edit", response)
        return None

    data = response.json()
    print("✓ Image edited successfully!")
    print(f" Task ID: {data['task_id']}")
    print(f" Status: {data['status']}")
    return data["task_id"]


def test_get_result(task_id):
    """Fetch the result record for an edit task from ``/result/<task_id>``.

    Args:
        task_id: Identifier returned by ``test_edit_image``.

    Returns:
        The decoded JSON result dict, or ``None`` on a non-200 status.
    """
    print(f"\n3. Getting result for task: {task_id}")

    response = requests.get(
        f"{API_BASE_URL}/result/{task_id}",
        timeout=(CONNECT_TIMEOUT, READ_TIMEOUT),
    )

    if response.status_code != 200:
        _report_failure("Get result", response)
        return None

    data = response.json()
    print("✓ Result retrieved!")
    print(f" Status: {data['status']}")
    if data.get("result_image_id"):
        print(f" Result Image ID: {data['result_image_id']}")
        # .get: a missing URL should not crash a script that only needs the ID.
        print(f" Result URL: {data.get('result_image_url')}")
    return data


def test_download_image(result_image_id, output_path):
    """Download the edited image and write its bytes to ``output_path``.

    Args:
        result_image_id: Identifier of the edited image from the result record.
        output_path: Local file path to write the downloaded bytes to.

    Returns:
        ``True`` if the image was downloaded and written, ``False`` otherwise.
    """
    print("\n4. Downloading edited image...")

    response = requests.get(
        f"{API_BASE_URL}/result/image/{result_image_id}",
        timeout=(CONNECT_TIMEOUT, READ_TIMEOUT),
    )

    if response.status_code != 200:
        _report_failure("Download", response)
        return False

    with open(output_path, "wb") as f:
        f.write(response.content)
    print(f"✓ Image downloaded to: {output_path}")
    return True


def test_health():
    """Check that the API server is reachable and report model status.

    Returns:
        ``True`` if ``/health`` answers 200, ``False`` on any other status,
        connection failure, or timeout.
    """
    print("Testing API health...")
    try:
        response = requests.get(
            f"{API_BASE_URL}/health",
            timeout=(CONNECT_TIMEOUT, READ_TIMEOUT),
        )
    except requests.exceptions.ConnectionError:
        print(f"✗ Cannot connect to API at {API_BASE_URL}")
        print(" Make sure the API server is running: python api.py")
        return False
    except requests.exceptions.Timeout:
        print(f"✗ API at {API_BASE_URL} did not respond in time")
        return False

    if response.status_code != 200:
        print(f"✗ Health check failed: {response.status_code}")
        return False

    data = response.json()
    print("✓ API is healthy")
    print(f" Model available: {data.get('model_available', False)}")
    print(f" Model loaded: {data.get('model_loaded', False)}")
    return True


def _run_pipeline(image_path, prompt):
    """Run upload -> edit -> result -> download; return a process exit code."""
    image_id = test_upload_image(image_path)
    if not image_id:
        return 1

    task_id = test_edit_image(image_id, prompt)
    if not task_id:
        return 1

    # NOTE(review): this expects /edit to have finished by the time it returns.
    # If the server processes edits asynchronously, this fails while the task is
    # still pending; polling /result would be needed in that case.
    result = test_get_result(task_id)
    if not result or not result.get("result_image_id"):
        return 1

    output_path = f"edited_{os.path.basename(image_path)}"
    if not test_download_image(result["result_image_id"], output_path):
        return 1

    print("\n✓ All tests completed!")
    return 0


def main(argv=None):
    """Parse CLI arguments, run the smoke test, and return an exit code.

    Args:
        argv: Argument list excluding the program name; defaults to
            ``sys.argv[1:]``.
    """
    args = sys.argv[1:] if argv is None else argv

    # Validate arguments before touching the network so usage help works
    # even when the server is down.
    if not args:
        print("Usage: python test_api.py <image_path> [prompt]")
        print("Example: python test_api.py test_image.jpg 'make the sky blue'")
        return 1

    image_path = args[0]
    prompt = args[1] if len(args) > 1 else "enhance the image"

    if not test_health():
        return 1

    try:
        return _run_pipeline(image_path, prompt)
    except requests.exceptions.RequestException as exc:
        # Timeouts / dropped connections mid-run: report instead of a traceback.
        print(f"✗ Request to {API_BASE_URL} failed: {exc}")
        return 1


if __name__ == "__main__":
    sys.exit(main())