fokan commited on
Commit
f91a057
·
verified ·
1 Parent(s): 6281b8a

Upload 3 files

Browse files
Files changed (3) hide show
  1. README.md +5 -4
  2. app.py +14 -17
  3. requirements.txt +5 -4
README.md CHANGED
@@ -24,7 +24,8 @@ Zero-shot image classification for medical imagery powered by **google/medsiglip
24
  - Zero-shot predictions using the MedSigLIP vision-language model without fine-tuning.
25
  - Smart Modality Router v2 blends filename heuristics, simple color statistics, and a lightweight fallback classifier to choose the best label bank.
26
  - CT, Ultrasound, Musculoskeletal, chest X-ray, brain MRI, fundus, histopathology, skin, cardiovascular, and general label libraries curated from MedSigLIP prompts and clinical references.
27
- - CPU-optimized inference with single model load, float32 execution on CPU, capped torch threads, cached results, and batched label scoring.
 
28
  - Gradio interface ready for local execution or deployment to Hugging Face Spaces (verified on Gradio 4.44.1+, API disabled by default to avoid schema bugs).
29
 
30
 
@@ -104,10 +105,10 @@ Each label file contains 100-200 modality-specific diagnostic phrases reflecting
104
 
105
 
106
  ## Performance Considerations
107
- - Loads the MedSigLIP processor and model once at startup, keeps the model in `eval()` mode, and pins execution to a single CPU thread with `torch.set_num_threads(1)`.
108
  - Leverages the `cached_inference` utility (LRU cache of five items) to reuse results for repeated requests without re-running the full forward pass.
109
- - Splits label scoring into batches of 50 within the cache manager, applies softmax over the concatenated logits, and returns the top five predictions.
110
- - Executes in float32 on CPU (float16 on GPU when available) to balance precision and memory consumption.
111
  - Avoids `transformers.pipeline()` to retain full control over preprocessing, batching, and device placement.
112
 
113
 
 
24
  - Zero-shot predictions using the MedSigLIP vision-language model without fine-tuning.
25
  - Smart Modality Router v2 blends filename heuristics, simple color statistics, and a lightweight fallback classifier to choose the best label bank.
26
  - CT, Ultrasound, Musculoskeletal, chest X-ray, brain MRI, fundus, histopathology, skin, cardiovascular, and general label libraries curated from MedSigLIP prompts and clinical references.
27
+ - CPU-optimized inference with single model load, float32 execution on CPU, capped torch threads via `psutil`, cached results, and batched label scoring.
28
+ - Automatic image downscaling to 448×448 before scoring to keep memory usage predictable.
29
  - Gradio interface ready for local execution or deployment to Hugging Face Spaces (verified on Gradio 4.44.1+, API disabled by default to avoid schema bugs).
30
 
31
 
 
105
 
106
 
107
  ## Performance Considerations
108
+ - Loads the MedSigLIP processor and model once at startup, keeps the model in `eval()` mode, and limits PyTorch threading with `torch.set_num_threads(min(psutil.cpu_count(logical=False), 4))`.
109
  - Leverages the `cached_inference` utility (LRU cache of five items) to reuse results for repeated requests without re-running the full forward pass.
110
+ - Downscales incoming images to 448×448 prior to tokenization and splits label scoring into batches of 50, applying softmax over concatenated logits before returning the top five predictions.
111
+ - Executes the transformer in float32 for deterministic CPU inference while still supporting GPU acceleration when available.
112
  - Avoids `transformers.pipeline()` to retain full control over preprocessing, batching, and device placement.
113
 
114
 
app.py CHANGED
@@ -4,11 +4,12 @@ from functools import lru_cache
4
  from pathlib import Path
5
  from typing import Dict, List, Tuple
6
 
 
7
  import torch
8
  import gradio as gr
9
  from transformers import AutoModelForZeroShotImageClassification, AutoProcessor
10
 
11
- from utils.cache_manager import cached_inference
12
  from utils.modality_router import detect_modality
13
 
14
 
@@ -19,19 +20,25 @@ MODEL_ID = "google/medsiglip-448"
19
 
20
  HF_TOKEN = os.getenv("HF_TOKEN")
21
 
22
- torch.set_num_threads(1)
 
23
 
24
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
- model_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
26
 
27
- processor = AutoProcessor.from_pretrained(MODEL_ID, token=HF_TOKEN)
 
 
 
 
28
  model = AutoModelForZeroShotImageClassification.from_pretrained(
29
  MODEL_ID,
30
  token=HF_TOKEN,
31
- torch_dtype=model_dtype,
32
  ).to(device)
33
  model.eval()
34
 
 
 
35
 
36
  LABEL_OVERRIDES = {
37
  "xray": "chest_labels.json",
@@ -64,7 +71,7 @@ def classify_medical_image(image_path: str) -> Dict[str, float]:
64
  return {}
65
 
66
  candidate_labels = get_candidate_labels(image_path)
67
- scores = cached_inference(image_path, candidate_labels, model, processor)
68
 
69
  if not scores:
70
  return {}
@@ -81,18 +88,8 @@ demo = gr.Interface(
81
  outputs=gr.Label(num_top_classes=5, label="🧠 Top Predictions"),
82
  title="🩻 MedSigLIP Smart Medical Classifier",
83
  description="Zero-shot model with automatic label filtering for different modalities.",
84
- allow_api=False,
85
  )
86
 
87
 
88
  if __name__ == "__main__":
89
- server_name = os.getenv("SERVER_NAME", "0.0.0.0")
90
- port_env = os.getenv("SERVER_PORT") or os.getenv("PORT") or "7860"
91
- share_env = os.getenv("GRADIO_SHARE", "false").lower()
92
-
93
- demo.launch(
94
- server_name=server_name,
95
- server_port=int(port_env),
96
- share=share_env in {"1", "true", "yes"},
97
- show_api=False,
98
- )
 
4
  from pathlib import Path
5
  from typing import Dict, List, Tuple
6
 
7
+ import psutil
8
  import torch
9
  import gradio as gr
10
  from transformers import AutoModelForZeroShotImageClassification, AutoProcessor
11
 
12
+ from utils.cache_manager import cached_inference, configure_cache
13
  from utils.modality_router import detect_modality
14
 
15
 
 
20
 
21
  HF_TOKEN = os.getenv("HF_TOKEN")
22
 
23
+ physical_cores = psutil.cpu_count(logical=False) or psutil.cpu_count() or 1
24
+ torch.set_num_threads(min(physical_cores, 4))
25
 
26
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
27
 
28
+ processor = AutoProcessor.from_pretrained(
29
+ MODEL_ID,
30
+ token=HF_TOKEN,
31
+ use_fast=True,
32
+ )
33
  model = AutoModelForZeroShotImageClassification.from_pretrained(
34
  MODEL_ID,
35
  token=HF_TOKEN,
36
+ torch_dtype=torch.float32,
37
  ).to(device)
38
  model.eval()
39
 
40
+ configure_cache(model, processor)
41
+
42
 
43
  LABEL_OVERRIDES = {
44
  "xray": "chest_labels.json",
 
71
  return {}
72
 
73
  candidate_labels = get_candidate_labels(image_path)
74
+ scores = cached_inference(image_path, candidate_labels)
75
 
76
  if not scores:
77
  return {}
 
88
  outputs=gr.Label(num_top_classes=5, label="🧠 Top Predictions"),
89
  title="🩻 MedSigLIP Smart Medical Classifier",
90
  description="Zero-shot model with automatic label filtering for different modalities.",
 
91
  )
92
 
93
 
94
  if __name__ == "__main__":
95
+ demo.launch(server_name="0.0.0.0", server_port=7860, queue=True)
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,10 +1,11 @@
1
- torch
2
- transformers>=4.44.0
3
- gradio>=4.44.1
4
- huggingface_hub>=0.24.0
5
  sentencepiece
6
  Pillow
7
  numpy
8
  scikit-image
9
  timm
10
  tensorflow
 
 
1
+ torch>=2.4.0
2
+ transformers>=4.45.0
3
+ gradio>=4.44.0
4
+ huggingface_hub>=0.25.0
5
  sentencepiece
6
  Pillow
7
  numpy
8
  scikit-image
9
  timm
10
  tensorflow
11
+ psutil