tiny-digits / model.py
rigelbar's picture
test
cdde7db verified
import torch
import torch.nn as nn
import torch.nn.functional as F
class TinyConv(nn.Module):
    """Two-layer convolutional classifier with global average pooling.

    Both convolutions use a 3x3 kernel with ``padding=1``, so they keep the
    spatial size of their input. The adaptive average pool then collapses
    each of the 16 feature maps to a single value, which means the linear
    head always receives 16 features regardless of the input height/width.

    Args:
        num_classes: Number of output logits produced by the linear head.
        in_channels: Number of channels in the input image. Defaults to 1,
            matching the original single-channel (e.g. 1x28x28) setup.
    """

    def __init__(self, num_classes: int = 10, *, in_channels: int = 1) -> None:
        super().__init__()
        # Submodule names are kept stable so existing state dicts still load.
        self.conv1 = nn.Conv2d(in_channels, 8, 3, padding=1)  # CxHxW -> 8xHxW
        self.conv2 = nn.Conv2d(8, 16, 3, padding=1)  # 8xHxW -> 16xHxW
        self.pool = nn.AdaptiveAvgPool2d((1, 1))  # 16xHxW -> 16x1x1
        self.fc = nn.Linear(16, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape ``(batch, num_classes)``.

        Args:
            x: Batched image tensor of shape ``(batch, in_channels, H, W)``.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        # flatten(1) keeps the batch axis and merges 16x1x1 into 16 features.
        x = self.pool(x).flatten(1)
        return self.fc(x)