tiny-digits / inference.py
rigelbar's picture
test
cdde7db verified
raw
history blame
1.08 kB
import io, base64
import os
import warnings

import numpy as np
from PIL import Image
import torch
from safetensors.torch import load_file

from model import TinyConv
# Class labels "0" through "9"; predict_from_pil pairs them by position with
# the entries of the model's softmax output.
LABELS = list(map(str, range(10)))
def load_model(path: str = ""):
m = TinyConv().eval()
try:
state = load_file(f"{path}model.safetensors")
m.load_state_dict(state, strict=True)
except Exception:
# Fallback to randomly initialized weights if file missing
pass
return m
# Model shared by the predict_* helpers in this module. It is built once, at
# import time, from "model.safetensors" relative to the working directory.
# See load_model for what happens when that file cannot be loaded.
MODEL = load_model(path="")
def preprocess_pil(img: Image.Image) -> torch.Tensor:
    """Convert a PIL image into the model's input tensor.

    The image is converted to 8-bit grayscale (mode "L") and resized to
    28x28 with Pillow's default resampling filter. Pixel values are then
    scaled from 0..255 into the float range [0.0, 1.0].

    Returns:
        A float32 tensor of shape [1, 1, 28, 28]: two size-1 leading
        dimensions prepended to the 28x28 image.
    """
    gray = img.convert("L")
    small = gray.resize((28, 28))
    pixels = torch.from_numpy(np.array(small))  # uint8, shape [28, 28]
    scaled = pixels.to(torch.float32).div(255.0)
    return scaled[None, None]  # [1, 1, 28, 28]
def predict_from_pil(img: Image.Image):
    """Classify a PIL image with the shared MODEL.

    Returns:
        A dict mapping each label in LABELS to its softmax probability as a
        Python float. Labels are paired with output columns by position,
        using the first row of the output (the input batch holds one image).
    """
    batch = preprocess_pil(img)
    with torch.inference_mode():
        scores = MODEL(batch)
        row = scores.softmax(dim=-1)[0].tolist()
    return {name: float(row[pos]) for pos, name in enumerate(LABELS)}
def predict_from_base64(b64: str):
img = Image.open(io.BytesIO(base64.b64decode(b64)))
return predict_from_pil(img)