"""Digit-classification inference helpers built around a ``TinyConv`` model."""

import base64
import io
import logging
import os

import numpy as np
import torch
from PIL import Image
from safetensors.torch import load_file

from model import TinyConv

logger = logging.getLogger(__name__)

# Class labels, index-aligned with the model's output scores ("0" .. "9").
LABELS = [str(i) for i in range(10)]

_WEIGHTS_FILENAME = "model.safetensors"


def _weights_path(path: str) -> str:
    """Return the location of the weights file for *path*.

    Historically *path* was concatenated directly in front of the file name,
    so ``""`` and ``"dir/"`` both worked but ``"dir"`` produced
    ``"dirmodel.safetensors"``. An existing directory is now joined properly;
    any other value keeps the old concatenation semantics.
    """
    if path and os.path.isdir(path):
        return os.path.join(path, _WEIGHTS_FILENAME)
    return f"{path}{_WEIGHTS_FILENAME}"


def load_model(path: str = "") -> TinyConv:
    """Build a ``TinyConv`` in eval mode and load its safetensors weights.

    Args:
        path: Directory containing ``model.safetensors`` (or a prefix that is
            prepended to that file name). Defaults to the current directory.

    Returns:
        The model in eval mode. If the weights file is missing or cannot be
        loaded, a freshly initialized (random-weight) model is returned and
        the problem is logged instead of being silently ignored.
    """
    weights = _weights_path(path)
    if not os.path.isfile(weights):
        logger.warning(
            "Weights file %r not found; using randomly initialized weights.",
            weights,
        )
        return TinyConv().eval()

    model = TinyConv().eval()
    try:
        state = load_file(weights)
        model.load_state_dict(state, strict=True)
    except Exception:
        # load_state_dict copies matching tensors before raising on
        # missing/unexpected/mis-shaped keys, so `model` may now be partially
        # loaded. Hand back an untouched model instead of that silent mix.
        logger.exception(
            "Failed to load weights from %r; using randomly initialized weights.",
            weights,
        )
        return TinyConv().eval()
    return model


# Loaded once at import time and shared by every prediction call.
MODEL = load_model()


def preprocess_pil(img: Image.Image) -> torch.Tensor:
    """Turn a PIL image into a model-ready float tensor.

    The image is converted to 8-bit grayscale ("L"), resized to 28x28, and its
    pixel values are scaled from [0, 255] to [0.0, 1.0].

    Returns:
        A float32 tensor of shape [1, 1, 28, 28] (batch, channel, height, width).
    """
    gray = img.convert("L").resize((28, 28))
    # np.array copies into a writable float32 buffer, so torch.from_numpy
    # does not warn about a read-only array.
    pixels = torch.from_numpy(np.array(gray, dtype=np.float32)) / 255.0
    return pixels[None, None]  # [1, 1, 28, 28]


def predict_from_pil(img: Image.Image) -> "dict[str, float]":
    """Classify *img* and return a probability for every label in ``LABELS``.

    Raises:
        ValueError: If the model's score count does not match ``LABELS``.
    """
    x = preprocess_pil(img)
    with torch.inference_mode():
        logits = MODEL(x)
    # Softmax over the last axis; [0] selects the single image in the batch.
    probs = torch.softmax(logits, dim=-1)[0].tolist()
    if len(probs) != len(LABELS):
        raise ValueError(
            f"model produced {len(probs)} scores but {len(LABELS)} labels are defined"
        )
    return {label: float(p) for label, p in zip(LABELS, probs)}


def _strip_data_url(payload):
    """Drop a ``data:<mime>;base64,`` prefix from *payload* if present."""
    if isinstance(payload, str):
        marker, sep = "data:", ","
    elif isinstance(payload, (bytes, bytearray)):
        marker, sep = b"data:", b","
    else:
        return payload
    if payload.lstrip().startswith(marker):
        return payload.partition(sep)[2]
    return payload


def predict_from_base64(b64: str) -> "dict[str, float]":
    """Decode a base64-encoded image and classify it.

    Accepts raw base64 as well as browser-style data URLs
    (``data:image/png;base64,...``). Without stripping the prefix, the
    non-validating decoder would keep its letters and corrupt the payload.
    """
    raw = base64.b64decode(_strip_data_url(b64))
    with Image.open(io.BytesIO(raw)) as img:
        return predict_from_pil(img)