Fadri committed on
Commit 2f41bd6 · verified · 1 Parent(s): 619cee0

Update app.py

Files changed (1)
app.py +38 -13
app.py CHANGED
```diff
@@ -1,20 +1,47 @@
 import gradio as gr
+import torch
+from PIL import Image
+from transformers import AutoImageProcessor, ViTForImageClassification
 from transformers import pipeline
 
-# Load models
-cifar10_classifier = pipeline("image-classification", model="Fadri/results")
-clip_detector = pipeline(model="openai/clip-vit-large-patch14", task="zero-shot-image-classification")
-
-# CIFAR-10 classes
+# CIFAR-10 class labels
 labels_cifar10 = [
     'airplane', 'automobile', 'bird', 'cat', 'deer',
     'dog', 'frog', 'horse', 'ship', 'truck'
 ]
 
+# Load model and processor separately
+processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
+model = ViTForImageClassification.from_pretrained("Fadri/results")
+
+# CLIP for zero-shot stays as before
+clip_detector = pipeline(model="openai/clip-vit-large-patch14", task="zero-shot-image-classification")
+
+def predict_cifar10(image_path):
+    # Load and preprocess the image
+    image = Image.open(image_path).convert("RGB")
+    inputs = processor(images=image, return_tensors="pt")
+
+    # Model prediction
+    with torch.no_grad():
+        outputs = model(**inputs)
+        logits = outputs.logits
+        predicted_class_idx = logits.argmax(-1).item()
+
+    # Top-3 results with probabilities
+    probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
+    top3_probs, top3_indices = torch.topk(probabilities, 3)
+
+    results = {}
+    for idx, prob in zip(top3_indices, top3_probs):
+        label = model.config.id2label[idx.item()]
+        results[label] = round(prob.item(), 4)
+
+    return results
+
 def classify_image(image):
-    # Classification with your trained CIFAR-10 model
-    cifar10_results = cifar10_classifier(image)
-    cifar10_output = {result['label']: result['score'] for result in cifar10_results}
+    # Classification with your model
+    cifar10_output = predict_cifar10(image)
 
     # Zero-shot classification with CLIP
     clip_results = clip_detector(image, candidate_labels=labels_cifar10)
@@ -25,14 +52,12 @@ def classify_image(image):
         "CLIP Zero-Shot Classification": clip_output
     }
 
-# Example images - you can adjust these later
+# Example images (adjust paths)
 example_images = [
     ["examples/airplane.jpg"],
     ["examples/car.jpg"],
     ["examples/dog.jpg"],
-    ["examples/cat.jpg"],
-    ["examples/ship.jpg"],
-    ["examples/truck.jpg"]
+    ["examples/cat.jpg"]
 ]
 
 # Gradio Interface
@@ -41,7 +66,7 @@ iface = gr.Interface(
     inputs=gr.Image(type="filepath"),
     outputs=gr.JSON(),
     title="CIFAR-10 Classification",
-    description="Upload an image and compare the results between your trained ViT model and the zero-shot CLIP model for CIFAR-10 classes.",
+    description="Upload an image and compare the results between your trained ViT model and CLIP.",
     examples=example_images
 )
```
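Between the first and second hunks, the diff elides a few unchanged context lines of `classify_image` (new file, roughly lines 48-51) that build `clip_output` and open the returned dict. The commit view does not show them, but from the surrounding context they presumably look something like this sketch (the comprehension and the first dict key are assumptions):

```python
# Hypothetical reconstruction of the unchanged lines elided between the hunks;
# the commit does not show them, so the exact names here are assumptions.
clip_output = {result['label']: result['score'] for result in clip_results}

return {
    "CIFAR-10 Classification": cifar10_output,    # assumed key, mirrors the one below
    "CLIP Zero-Shot Classification": clip_output  # this key appears as diff context
}
```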
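Loading the processor and model separately, instead of the previous `image-classification` pipeline, lets `predict_cifar10` control preprocessing and return rounded top-3 softmax scores. A quick way to sanity-check both code paths without starting the UI is a sketch like the following; it assumes `app.py` is importable (importing it also builds the Gradio interface at module level), the `Fadri/results` checkpoint is downloadable, and an example image exists locally. The `iface.launch()` call that starts the Space is presumably unchanged further down the file:

```python
# Minimal smoke test for the updated app.py (a sketch, not part of the commit).
from app import predict_cifar10, classify_image

# Top-3 CIFAR-10 labels with rounded probabilities from the fine-tuned ViT
print(predict_cifar10("examples/dog.jpg"))

# Combined output: fine-tuned ViT top-3 plus CLIP zero-shot scores
print(classify_image("examples/dog.jpg"))
```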