Gaimundo commited on
Commit
b07da2d
·
verified ·
1 Parent(s): 83c9a4c

Deploy MNIST CNN Gradio app

Browse files
Files changed (2) hide show
  1. app.py +54 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import gradio as gr
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torchvision import transforms
7
+ from PIL import Image
8
+
9
+ class MNISTNet(nn.Module):
10
+ def __init__(self):
11
+ super().__init__()
12
+ self.conv1 = nn.Conv2d(1, 32, 3, 1)
13
+ self.conv2 = nn.Conv2d(32, 64, 3, 1)
14
+ self.dropout1 = nn.Dropout(0.25)
15
+ self.dropout2 = nn.Dropout(0.5)
16
+ self.fc1 = nn.Linear(9216, 128)
17
+ self.fc2 = nn.Linear(128, 10)
18
+ def forward(self, x):
19
+ x = F.relu(self.conv1(x))
20
+ x = F.relu(self.conv2(x))
21
+ x = F.max_pool2d(x, 2)
22
+ x = self.dropout1(x)
23
+ x = torch.flatten(x, 1)
24
+ x = F.relu(self.fc1(x))
25
+ x = self.dropout2(x)
26
+ x = self.fc2(x)
27
+ return x
28
+
29
+ model = MNISTNet()
30
+ model.load_state_dict(torch.hub.load_state_dict_from_url('https://huggingface.co/Gaimundo/mnist-nn/resolve/main/mnist_cnn.pt', map_location='cpu'))
31
+ model.eval()
32
+
33
+ transform = transforms.Compose([
34
+ transforms.Grayscale(),
35
+ transforms.Resize((28,28)),
36
+ transforms.ToTensor(),
37
+ transforms.Normalize((0.1307,), (0.3081,))
38
+ ])
39
+
40
+ def predict(image):
41
+ image = transform(image).unsqueeze(0)
42
+ with torch.no_grad():
43
+ output = model(image)
44
+ probs = torch.softmax(output, dim=1)[0]
45
+ return {str(i): float(probs[i]) for i in range(10)}
46
+
47
+ iface = gr.Interface(
48
+ fn=predict,
49
+ inputs=gr.Image(type="pil", shape=(28,28), image_mode="L", label="Draw a digit"),
50
+ outputs=gr.Label(num_top_classes=10),
51
+ title="MNIST Digit Classifier (PyTorch)"
52
+ )
53
+
54
+ iface.launch()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch
2
+ torchvision
3
+ gradio