| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class TinyConv(nn.Module):
    """Small two-layer convolutional classifier.

    Two 3x3 convolutions (stride 1, padding 1, so the spatial size is
    preserved), each followed by ReLU, then global average pooling and a
    linear head. Because the pooling is adaptive, the network accepts any
    spatial size H x W (e.g. 28x28 images), not one fixed resolution.

    Args:
        num_classes: Number of output logits produced by the final layer.
        in_channels: Number of channels in the input images. Defaults to 1,
            which matches the previously hard-coded value.
    """

    def __init__(self, num_classes: int = 10, in_channels: int = 1):
        super().__init__()
        # (N, in_channels, H, W) -> (N, 8, H, W): padding=1 keeps H and W.
        self.conv1 = nn.Conv2d(in_channels, 8, 3, padding=1)
        # (N, 8, H, W) -> (N, 16, H, W)
        self.conv2 = nn.Conv2d(8, 16, 3, padding=1)
        # (N, 16, H, W) -> (N, 16, 1, 1), independent of H and W.
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(16, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of shape (N, in_channels, H, W) to logits (N, num_classes).

        No softmax is applied; the output is the raw linear-layer result.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        # Flatten the pooled (N, 16, 1, 1) tensor to (N, 16) for the head.
        x = self.pool(x).flatten(1)
        return self.fc(x)