Migueldiaz1 committed on
Commit 753a056 · verified · 1 Parent(s): 441c9e6

Update app.py

Files changed (1)
  1. app.py +44 -93
app.py CHANGED
@@ -1,4 +1,4 @@
- from fastapi import FastAPI, HTTPException  # <--- ADDED HTTPException
+ from fastapi import FastAPI, HTTPException
  from fastapi.staticfiles import StaticFiles
  from fastapi.responses import FileResponse
  from fastapi.responses import Response
@@ -19,7 +19,7 @@ import google.generativeai as genai
  from typing import Optional, List, Any, Dict, Union
  from diffusers import StableDiffusionPipeline, LCMScheduler

- app = FastAPI(title="Mirage Medical Search API - Lite Version")
+ app = FastAPI(title="MIRAGE")

  app.add_middleware(
      CORSMiddleware,
@@ -29,22 +29,22 @@ app.add_middleware(
      allow_headers=["*"],
  )

- # --- 1. MODEL CONFIGURATION ---
+ # Models
  MODEL_NAME = 'hf-hub:luhuitong/CLIP-ViT-L-14-448px-MedICaT-ROCO'
  HF_DATASET_ID = "mdwiratathya/ROCO-radiology"
  SPLIT = "train"
  device = "cpu"

- # Global Variables
+ # Global variables
  model = None
  tokenizer = None
- embeddings = None  # Image Embeddings (Visual Only)
+ embeddings = None
  metadata = None
  dataset_stream = None
  gemini_available = False
  pipe = None

- # --- 2. AUTHENTICATION ---
+ # Authentication
  try:
      hf_token = os.environ.get('HF_TOKEN')
      if hf_token:
@@ -58,7 +58,7 @@ try:
  except Exception as e:
      print(f"Error auth: {e}")

- # --- HELPER: PLACEHOLDER ---
+ # Placeholder image used when something goes wrong
  def create_placeholder_image(text="Image Error"):
      img = Image.new('RGB', (512, 512), color=(40, 40, 45))
      d = ImageDraw.Draw(img)
@@ -73,24 +73,15 @@ def create_placeholder_image(text="Image Error"):
      img.save(img_byte_arr, format='JPEG')
      return img_byte_arr.getvalue()

- # --- 3. DATA LOADING ---
+ # Load the data
  @app.on_event("startup")
  async def load_data():
      global model, tokenizer, embeddings, metadata, dataset_stream, pipe
-     print("--- STARTING MIRAGE BACKEND (Lite Version) ---")
-
-     # 1. LOAD CLIP
-     try:
-         print("👁️ Loading CLIP...")
-         model, _, _ = open_clip.create_model_and_transforms(MODEL_NAME, device=device)
-         tokenizer = open_clip.get_tokenizer(MODEL_NAME)
-         model.eval()
-         print("✅ CLIP loaded.")
-     except Exception as e:
-         print(f"❌ CLIP error: {e}")
+     model, _, _ = open_clip.create_model_and_transforms(MODEL_NAME, device=device)
+     tokenizer = open_clip.get_tokenizer(MODEL_NAME)
+     model.eval()

-     # 2. LOAD METADATA
-     print("📦 Loading metadata...")
+     # Load metadata
      if os.path.exists("metadata_text.json"):
          with open("metadata_text.json", 'r') as f:
              metadata = json.load(f)
@@ -98,53 +89,34 @@ async def load_data():
          with open("metadata.json", 'r') as f:
              metadata = json.load(f)
      else:
-         print("⚠️ METADATA NOT FOUND.")
+         print("no metadata file found")
          metadata = [{"dataset_index": 0, "filename": "error", "caption": "Error"}]

-     # 3. LOAD IMAGE EMBEDDINGS (IMAGE ONLY)
-     if os.path.exists("embeddings.npy"):
-         embeddings = np.load("embeddings.npy")
-         print(f"✅ Image embeddings ready: {embeddings.shape[0]} records.")
+     # Load the image embeddings (already precomputed)
+     embeddings = np.load("embeddings.npy")
+     print(f"✅ Image embeddings ready: {embeddings.shape[0]} records.")
+
+     # Load the dataset
+     dataset_stream = load_dataset(HF_DATASET_ID, split=SPLIT, streaming=False)
+
+     # Load the Stable Diffusion LCM pipeline
+     model_id = "Nihirc/Prompt2MedImage"
+     pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32)
+     pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")
+     pipe.fuse_lora()
+     pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, solver_order=2)
+     pipe.safety_checker = None
+     pipe.requires_safety_checker = False
+
+     if device == "cpu":
+         pipe = pipe.to("cpu")
+         pipe.enable_attention_slicing()
      else:
-         print("⚠️ IMAGE EMBEDDINGS NOT FOUND.")
-         embeddings = np.zeros((1, 768))
-
-     # 4. LOAD DATASET
-     try:
-         print("📦 Loading dataset into RAM (1-2 mins)...")
-         dataset_stream = load_dataset(HF_DATASET_ID, split=SPLIT, streaming=False)
-         print(f"✅ Dataset ready. Total: {len(dataset_stream)}")
-     except Exception as e:
-         print(f"❌ Dataset error: {e}")
-         dataset_stream = None
-
-     # 5. LOAD STABLE DIFFUSION (LCM)
-     print("🎨 Loading generative model (LCM mode)...")
-     try:
-         model_id = "Nihirc/Prompt2MedImage"
-         pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32)
-         print("⚡ Injecting LCM-LoRA weights...")
-         pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")
-         pipe.fuse_lora()
-         pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, solver_order=2)
-         pipe.safety_checker = None
-         pipe.requires_safety_checker = False
-
-         if device == "cpu":
-             pipe = pipe.to("cpu")
-             pipe.enable_attention_slicing()
-         else:
-             pipe = pipe.to("cuda")
-         print("✅ LCM generator ready.")
-     except Exception as e:
-         print(f"❌ Generator error: {e}")
-
-
-     # --- 4. CORE FUNCTIONS ---
+         pipe = pipe.to("cuda")

  def calculate_vector(text, add=None, sub=None):
      with torch.no_grad():
-         # Use the text exactly as the user provides it
+         # The user gives us text; we obtain its embedding using CLIP
          text_tokens = tokenizer([text]).to(device)
          vec = model.encode_text(text_tokens)
          vec /= vec.norm(dim=-1, keepdim=True)
@@ -160,17 +132,12 @@ def calculate_vector(text, add=None, sub=None):
      return vec

  def get_retrieval_and_context(query_vector, top_k):
-     """
-     Performs retrieval based EXCLUSIVELY on visual similarity.
-     Query text vector vs. image embeddings.
-     """
+     # Compare the query (text) embedding with the image embeddings to retrieve
      query_vec_np = query_vector.cpu().numpy()

-     # Visual similarity (query vs. image embeddings)
-     # query_vec_np is (1, 768), embeddings is (N, 768) -> result (N,)
+
+     # query_vec_np (1, 768), embeddings (N, 768) -> result (N,)
      sim_img = (query_vec_np @ embeddings.T).squeeze()
-
-     # Sort indices (descending)
      best_indices = sim_img.argsort()[-top_k:][::-1]

      real_matches = []
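A note on the retrieval hunk above: because both the CLIP text embedding and the precomputed image embeddings are unit-normalized, the matrix product is a cosine-similarity ranking, and the top_k results come from slicing the ascending argsort. A minimal, self-contained sketch with made-up toy vectors (only the shapes and the argsort/slice idiom mirror the code; the data and the top_k value are illustrative):

import numpy as np

rng = np.random.default_rng(0)
# Toy stand-ins for the real data: N image embeddings and one query embedding, L2-normalized.
embeddings = rng.normal(size=(5, 768))
embeddings /= np.linalg.norm(embeddings, axis=1, keepdims=True)      # (N, 768)
query_vec_np = rng.normal(size=(1, 768))
query_vec_np /= np.linalg.norm(query_vec_np, axis=1, keepdims=True)  # (1, 768)

# Dot product of unit vectors == cosine similarity.
sim_img = (query_vec_np @ embeddings.T).squeeze()    # (N,)

# argsort is ascending, so the last top_k entries (reversed) are the best matches.
top_k = 3
best_indices = sim_img.argsort()[-top_k:][::-1]
print(best_indices, sim_img[best_indices])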
@@ -185,7 +152,7 @@ def get_retrieval_and_context(query_vector, top_k):

          real_matches.append({
              "url": f"/image/{safe_index}",
-             "score": float(sim_img[idx]),  # Visual score only
+             "score": float(sim_img[idx]),
              "filename": meta.get("filename", "img"),
              "caption": meta.get("caption", ""),
              "index": safe_index
@@ -202,7 +169,6 @@ def generate_llm_prompt(captions, user_text):
          return user_text + ". " + (captions[0] if captions else "")
      try:
          llm = genai.GenerativeModel('gemini-2.5-flash')
-         # Prompt updated to use the user's text directly
          prompt = f"Using the following medical query: '{user_text}', synthesize these findings into a concise radiology description: {', '.join(captions[:3])}"
          res = llm.generate_content(prompt)
          return res.text.strip()
@@ -217,7 +183,7 @@ def generate_synthetic_image(prompt, steps=5, guidance=1.5):
      image = pipe(prompt[:77], height=512, width=512, num_inference_steps=steps, guidance_scale=guidance, negative_prompt=NEGATIVE_PROMPT).images[0]

      draw = ImageDraw.Draw(image)
-     text = "Created by MIRAGE LITE"
+     text = "Created by MIRAGE"
      try: font = ImageFont.load_default()
      except: font = None
      bbox = draw.textbbox((0, 0), text, font=font)
@@ -239,8 +205,8 @@ def fetch_image_from_stream(index):
          return dataset_stream[idx]['image']
      except Exception: return None

- # --- ENDPOINTS ---
- # IMPORTANT CHANGE: renamed to /api/health to keep the root "/" free
+
+ # ENDPOINTS
  @app.get("/api/health")
  def health_check():
      return {"status": "online", "version": "lite"}
@@ -261,7 +227,6 @@ def get_image(index: str):
      except Exception: pass
      return Response(content=create_placeholder_image("Error"), media_type="image/jpeg")

- # --- SIMPLIFIED PYDANTIC MODELS ---
  class GenerationRequest(BaseModel):
      original_text: str
      sub_concept: Optional[str] = None
@@ -272,17 +237,14 @@ class GenerationRequest(BaseModel):
      guidance_scale: float = 1.5
      num_inference_steps: int = 5

- # --- MAIN ENDPOINT ---
+ # This is the main endpoint
  @app.post("/generate_comparison")
  def generate_comparison(req: GenerationRequest):
-     if not model: raise HTTPException(status_code=503, detail="Loading...")  # HTTPException works now
+     if not model: raise HTTPException(status_code=503, detail="Loading...")
      try:
-         # DIRECT ASSIGNMENT, NO TRANSLATION
          final_query = req.original_text
          final_add = req.add_concept
          final_sub = req.sub_concept
-
-         print(f"⚡ Processing Lite (raw input): '{final_query}'")

          response_data = {
              "original_text": final_query,
@@ -292,13 +254,11 @@ def generate_comparison(req: GenerationRequest):
              "input_lang_detected": "raw"
          }

-         # 1. PROCESS ORIGINAL (always visual search)
          vec_orig = calculate_vector(final_query)
          match_orig, caps_orig = get_retrieval_and_context(vec_orig, req.top_k)

          prompt_orig = ""
          if req.gen_text:
-             # Pass the original text to the LLM
              prompt_orig = generate_llm_prompt(caps_orig, final_query)
          else:
              prompt_orig = "LLM generation skipped."
@@ -316,7 +276,6 @@ def generate_comparison(req: GenerationRequest):
              }
          }

-         # 2. PROCESS MODIFIED (dual search - embedding arithmetic)
          has_dual = (final_add and final_add.strip()) and (final_sub and final_sub.strip())
          if has_dual:
              vec_mod = calculate_vector(final_query, final_add, final_sub)
@@ -324,7 +283,6 @@ def generate_comparison(req: GenerationRequest):

              prompt_mod = ""
              if req.gen_text:
-                 # Build the arithmetic string without translation
                  prompt_mod = generate_llm_prompt(caps_mod, f"{final_query} + {final_add} - {final_sub}")
              else:
                  prompt_mod = "LLM generation skipped."
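The "dual search" branch in the two hunks above reuses calculate_vector with the add/sub concepts. The combination step itself is not visible in this diff; a common CLIP-style formulation (an assumption here, not necessarily what app.py does) is to add and subtract the unit-normalized concept embeddings and then re-normalize:

import torch

def combine_concepts(vec, add_vec=None, sub_vec=None):
    # Assumed embedding arithmetic: query + add_concept - sub_concept, re-normalized to unit length.
    out = vec.clone()
    if add_vec is not None:
        out = out + add_vec
    if sub_vec is not None:
        out = out - sub_vec
    return out / out.norm(dim=-1, keepdim=True)

# Toy unit vectors standing in for CLIP text embeddings.
q = torch.nn.functional.normalize(torch.randn(1, 768), dim=-1)
a = torch.nn.functional.normalize(torch.randn(1, 768), dim=-1)
s = torch.nn.functional.normalize(torch.randn(1, 768), dim=-1)
print(combine_concepts(q, a, s).norm())  # ~1.0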
@@ -354,27 +312,20 @@ def search(req: GenerationRequest):
      return generate_comparison(req)


- # --- SERVE FRONTEND ---
-
- # 1. Mount the static assets (JS, CSS generated by Vite)
+ # Serve the frontend
  app.mount("/assets", StaticFiles(directory="static/assets"), name="assets")

- # 2. Serve images if present in public
  if os.path.exists("static/images"):
      app.mount("/images", StaticFiles(directory="static/images"), name="images")

- # 3. Root route -> returns the main HTML
  @app.get("/")
  async def read_index():
      return FileResponse('static/index.html')

- # 4. Catch-all: any other route returns index.html (so React Router does not break on reload)
  @app.get("/{full_path:path}")
  async def catch_all(full_path: str):
-     # If the requested file exists (e.g. a .png), serve it
      if os.path.exists(f"static/{full_path}"):
          return FileResponse(f"static/{full_path}")
-     # Otherwise, return the React app
      return FileResponse('static/index.html')

  if __name__ == "__main__":
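For completeness, this is how a client could exercise the comparison endpoint once this change is deployed. The field names come from the GenerationRequest and handler hunks above (original_text, add_concept, sub_concept, top_k, gen_text, guidance_scale, num_inference_steps); the host/port and concrete values are placeholders, and the types/defaults of fields not shown in this diff (top_k, gen_text, add_concept) are assumptions:

import requests

payload = {
    "original_text": "chest x-ray with pleural effusion",  # free-text query, used as-is (no translation)
    "add_concept": "cardiomegaly",                          # optional: triggers the dual-search branch
    "sub_concept": "pleural effusion",                      # optional: concept to subtract
    "top_k": 5,                                             # number of retrieved images (assumed default)
    "gen_text": True,                                       # ask the LLM to synthesize a description (assumed bool)
    "guidance_scale": 1.5,
    "num_inference_steps": 5,
}

# Placeholder URL: point this at wherever the backend/Space is actually running.
r = requests.post("http://localhost:7860/generate_comparison", json=payload, timeout=300)
r.raise_for_status()
data = r.json()
print(data["original_text"], data["input_lang_detected"])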
 