Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from huggingface_hub import hf_hub_download | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torchvision import transforms | |
| from PIL import Image | |
class MNISTNet(nn.Module):
    """Small two-convolution CNN mapping single-channel 28x28 images to 10 class logits.

    The sub-module attribute names (conv1, conv2, dropout1, dropout2, fc1, fc2)
    determine the state_dict keys, so they must match the checkpoint that is
    loaded into this network with ``load_state_dict``.
    """

    def __init__(self):
        super(MNISTNet, self).__init__()
        # Feature extractor: 1 -> 32 -> 64 channels, 3x3 kernels, stride 1, no padding.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        # Regularisation after pooling and after the hidden dense layer.
        self.dropout1 = nn.Dropout(p=0.25)
        self.dropout2 = nn.Dropout(p=0.5)
        # 9216 = 64 channels * 12 * 12: a 28x28 input shrinks to 26x26, then 24x24,
        # then 12x12 after the 2x2 max-pool, so inputs must be 28x28.
        self.fc1 = nn.Linear(in_features=9216, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Return raw (un-normalised) logits of shape (N, 10) for a batch ``x``."""
        pooled = F.max_pool2d(F.relu(self.conv2(F.relu(self.conv1(x)))), 2)
        flat = torch.flatten(self.dropout1(pooled), 1)
        hidden = self.dropout2(F.relu(self.fc1(flat)))
        # No softmax here; callers apply it when they need probabilities.
        return self.fc2(hidden)
# Build the network and load the trained weights published on the Hugging Face Hub.
model = MNISTNet()
model_file = hf_hub_download(repo_id="Gaimundo/mnist-nn", filename="mnist_cnn.pt")
# The checkpoint comes from a remote repository, so restrict unpickling to
# tensors and plain containers (weights_only=True): a tampered file can then
# not execute arbitrary code on load. A plain state_dict loads fine this way.
# NOTE: requires PyTorch >= 1.13, where the weights_only parameter exists.
model.load_state_dict(
    torch.load(model_file, map_location="cpu", weights_only=True)
)
# Inference mode for the layers: dropout becomes a no-op so predictions are deterministic.
model.eval()
# Preprocessing applied to each input image before it reaches the model:
# force a single channel, resize to the 28x28 resolution that MNISTNet's
# fc1 size (9216 = 64 * 12 * 12) depends on, convert to a tensor, standardise.
# NOTE(review): 0.1307 / 0.3081 are the commonly used MNIST mean/std;
# presumably the checkpoint was trained with the same values — confirm.
transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize(size=(28, 28)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)
def predict(image):
    """Classify one digit image and return per-class probabilities.

    Args:
        image: PIL image provided by the Gradio input component, or None when
            the user submits without supplying an image.

    Returns:
        Dict mapping each class label as a string ("0", "1", ...) to its
        softmax probability as a Python float (the form passed to gr.Label),
        or None when ``image`` is None, so the app returns an empty result
        instead of raising inside the preprocessing pipeline.
    """
    if image is None:
        return None
    # Add a leading batch dimension: the conv layers consume (N, C, H, W) batches.
    batch = transform(image).unsqueeze(0)
    with torch.no_grad():
        logits = model(batch)
    # Softmax over the class dimension, take the single example, and convert the
    # whole vector to Python floats in one call instead of one call per class.
    probs = torch.softmax(logits, dim=1)[0].tolist()
    return {str(digit): p for digit, p in enumerate(probs)}
# Gradio UI: a single grayscale PIL image goes in; a probability label covering
# all ten digit classes comes out.
# NOTE(review): the input is a gr.Image labelled "Draw a digit"; whether it
# offers a drawing canvas depends on the installed Gradio version (newer
# releases provide drawing via gr.Sketchpad) — confirm.
digit_input = gr.Image(type="pil", image_mode="L", label="Draw a digit")
class_output = gr.Label(num_top_classes=10)
iface = gr.Interface(
    fn=predict,
    inputs=digit_input,
    outputs=class_output,
    title="MNIST Digit Classifier (PyTorch)",
)
# Starts the web app; this runs whenever the module executes, including on import.
iface.launch()