|
|
import base64
import io
import logging

import numpy as np
import torch
from PIL import Image
from safetensors.torch import load_file

from model import TinyConv
|
|
|
|
|
# Class labels: the ten digit characters "0" through "9".
LABELS = list("0123456789")
|
|
|
|
|
def load_model(path: str = "") -> "TinyConv":
    """Build a TinyConv in eval mode and load weights from ``{path}model.safetensors``.

    Loading is best-effort: if the checkpoint is missing or unreadable the
    function logs a warning and returns the randomly initialized model, so the
    service can still start. (Previously the failure was silently swallowed,
    which made serving an untrained model impossible to diagnose.)

    Args:
        path: Prefix prepended verbatim to ``model.safetensors`` — note this is
            string concatenation, so a directory prefix must end with ``/``.

    Returns:
        A TinyConv instance in eval mode, with checkpoint weights if available.
    """
    m = TinyConv().eval()
    weights_file = f"{path}model.safetensors"
    try:
        state = load_file(weights_file)
        m.load_state_dict(state, strict=True)
    except Exception as exc:
        # Deliberate fallback to untrained weights, but never silently:
        # a served-but-untrained model is otherwise very hard to debug.
        logging.warning(
            "load_model: could not load %s (%s); using untrained weights",
            weights_file,
            exc,
        )
    return m
|
|
|
|
|
# Module-level singleton: the model is loaded once at import time and shared
# by every prediction call below (import therefore touches the filesystem).
MODEL = load_model()
|
|
|
|
|
def preprocess_pil(img: Image.Image) -> torch.Tensor:
    """Convert a PIL image into a ``(1, 1, 28, 28)`` float tensor in ``[0, 1]``.

    The image is converted to grayscale, resized to the model's 28x28 input,
    scaled from byte values to the unit interval, and given batch and channel
    dimensions.
    """
    gray = img.convert("L").resize((28, 28))
    pixels = np.asarray(gray, dtype=np.float32) / 255.0
    # [None, None] prepends the batch and channel axes.
    return torch.from_numpy(pixels)[None, None]
|
|
|
|
|
def predict_from_pil(img: Image.Image):
    """Classify *img* with the shared MODEL.

    Returns:
        A dict mapping each digit label in ``LABELS`` to its softmax
        probability as a plain Python float.
    """
    batch = preprocess_pil(img)
    with torch.inference_mode():
        scores = torch.softmax(MODEL(batch), dim=-1)[0].tolist()
    return dict(zip(LABELS, (float(s) for s in scores)))
|
|
|
|
|
def predict_from_base64(b64: str):
    """Decode a base64-encoded image and classify it.

    Accepts either a raw base64 string or a full data URL
    (``data:image/png;base64,....``) as sent by browser canvas clients — the
    optional prefix up to the last comma is stripped. This is backward
    compatible: the standard base64 alphabet contains no comma, so plain
    base64 input is passed through unchanged.

    Returns:
        The ``{label: probability}`` dict from ``predict_from_pil``.
    """
    payload = b64.rsplit(",", 1)[-1]
    img = Image.open(io.BytesIO(base64.b64decode(payload)))
    return predict_from_pil(img)
|
|
|