import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyConv(nn.Module):
    """Small two-layer convolutional classifier.

    Two 3x3 convolutions (padding=1, so spatial size is preserved), each
    followed by ReLU, then global average pooling down to a 16-dim feature
    vector and a linear classification head.

    Args:
        num_classes: Number of output logits produced by the final layer.
        in_channels: Number of channels in the input images. Defaults to 1
            (e.g. grayscale); set to 3 for RGB input.
    """

    def __init__(self, num_classes: int = 10, in_channels: int = 1):
        super().__init__()
        # Shape comments assume a 28x28 input for illustration only; the
        # adaptive pool below makes the model accept any H x W.
        self.conv1 = nn.Conv2d(in_channels, 8, 3, padding=1)  # C x28x28 -> 8x28x28
        self.conv2 = nn.Conv2d(8, 16, 3, padding=1)  # 8x28x28 -> 16x28x28
        # Collapse every spatial position to a single value per channel -> 16x1x1.
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(16, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute class logits.

        Args:
            x: Batched image tensor of shape (N, in_channels, H, W). A batch
                dimension is required: ``flatten(1)`` below relies on it to
                produce (N, 16) features for the linear head.

        Returns:
            Logits of shape (N, num_classes).
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        # (N, 16, 1, 1) -> (N, 16): keep the batch dim, drop the singleton spatial dims.
        x = self.pool(x).flatten(1)
        return self.fc(x)