| """ | |
| Gradio demo of image classification with OOD detection. | |
| If the image example is probably OOD, the model will abstain from the prediction. | |
| """ | |
import json
import logging
from glob import glob

import gradio as gr
import numpy as np
import timm
import torch
import torch.nn.functional as F
from gradio.components import JSON, Image, Label
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from torchvision.models.feature_extraction import create_feature_extractor
_logger = logging.getLogger(__name__)

device = "cuda" if torch.cuda.is_available() else "cpu"
TOPK = 3
# load model
print("Loading model...")
model = timm.create_model("resnet50.tv2_in1k", pretrained=True)
model.to(device)
model.eval()
# dataset labels (ImageNet-1K class index -> human-readable name)
with open("ilsvrc2012.json") as f:
    idx2label = {int(k): v for k, v in json.load(f).items()}
_logger.debug(idx2label)
# preprocessing transform resolved from the model's pretraining config
config = resolve_data_config({}, model=model)
config["is_training"] = False
transform = create_transform(**config)
# create a feature extractor exposing the penultimate features and the logits
penultimate_features_key = "global_pool.flatten"
logits_key = "fc"
features_names = [penultimate_features_key, logits_key]
feature_extractor = create_feature_extractor(model, features_names)
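# The extractor's forward pass returns a dict keyed by the requested node names:
# for a single ResNet-50 input, "global_pool.flatten" is the 2048-d penultimate
# embedding and "fc" holds the 1000 ImageNet class logits.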
# load centroids
centroids = torch.load("centroids_resnet50.tv2_in1k_igeood_logits.pt")
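# Assumption: this file holds one centroid per ImageNet class in logit space
# (precomputed offline, e.g. per-class mean logits), used by the Igeood score below.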
# OOD detector thresholds
msp_threshold = 0.3796
energy_threshold = 8
igeood_threshold = 2.4984
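# These thresholds are presumably calibrated offline on in-distribution validation data
# (e.g. at a fixed true-positive rate); a score below its threshold is flagged as OOD.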
def mahalanobis_penult(features):
    # Note: not used by the demo; scores a sample by the norm of its penultimate features.
    scores = torch.norm(features, dim=1, keepdim=True)
    s = torch.min(scores, dim=1)[0]
    return -s.item()

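# Maximum softmax probability (MSP) baseline: the confidence of the predicted class.
# A low maximum probability is treated as evidence that the input is OOD.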
def msp(logits):
    return torch.softmax(logits, dim=1).max(-1)[0].item()

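# Energy score: logsumexp of the logits (negative free energy). In-distribution inputs
# tend to get higher values, so low scores are flagged as OOD by the threshold above.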
def energy(logits):
    return torch.logsumexp(logits, dim=1).item()

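# Igeood (logits-only variant): Fisher-Rao geodesic distance between the temperature-scaled
# softmax of the input logits and each class centroid, averaged over the centroids.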
def igeoodlogits_vec(logits, temperature, centroids, epsilon=1e-12):
    logits = torch.sqrt(F.softmax(logits / temperature, dim=1))
    centroids = torch.sqrt(F.softmax(centroids / temperature, dim=1))
    mult = logits @ centroids.T
    stack = 2 * torch.acos(torch.clamp(mult, -1 + epsilon, 1 - epsilon))
    return stack.mean(dim=1).item()

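# Main inference entry point used by the Gradio interface: runs one forward pass,
# reports the top-k class probabilities, and computes the three OOD scores and decisions.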
def predict(image):
    # forward pass
    inputs = transform(image).unsqueeze(0)
    inputs = inputs.to(device)
    with torch.no_grad():
        features = feature_extractor(inputs)
    # top-k predictions
    probabilities = torch.softmax(features[logits_key], dim=-1)
    softmax, class_idxs = torch.topk(probabilities, TOPK)
    _logger.info(softmax)
    _logger.info(class_idxs)
    result = {idx2label[i.item()]: v.item() for i, v in zip(class_idxs.squeeze(), softmax.squeeze())}
    # OOD scores and per-detector decisions (score below threshold => flagged as OOD)
    msp_score = round(msp(features[logits_key]), 4)
    energy_score = round(energy(features[logits_key]), 4)
    igeood_score = round(igeoodlogits_vec(features[logits_key], 1, centroids), 4)
    ood_scores = {
        "MSP": msp_score,
        "MSP, is the input OOD?": msp_score < msp_threshold,
        "Energy": energy_score,
        "Energy, is the input OOD?": energy_score < energy_threshold,
        "Igeood": igeood_score,
        "Igeood, is the input OOD?": igeood_score < igeood_threshold,
    }
    _logger.info(ood_scores)
    return result, ood_scores

def main():
    # image examples for the demo (shuffling currently disabled)
    examples = glob("images/imagenet/*") + glob("images/ood/*")
    np.random.seed(42)
    # np.random.shuffle(examples)
    # gradio interface
    interface = gr.Interface(
        fn=predict,
        inputs=Image(type="pil"),
        outputs=[
            Label(num_top_classes=TOPK, label="Model prediction"),
            JSON(label="OOD scores"),
        ],
        examples=examples,
        examples_per_page=len(examples),
        allow_flagging="never",
        theme="default",
        title="OOD Detection 🧐",
        description=(
            "Out-of-distribution (OOD) detection is an essential safety measure for machine learning models. "
            "The objective of an OOD detector is to determine whether the input sample comes from the distribution known by the AI model. "
            "For instance, an input that does not belong to any of the known classes or is from a different domain should be flagged by the detector.\n"
            "In this demo we display the decisions of three OOD detectors on a ResNet-50 model trained to classify the ImageNet-1K dataset (top-1 accuracy 80%). "
            "This model can classify among 1000 classes from several categories, including `animals`, `vehicles`, `clothing`, `instruments`, `plants`, etc. "
            "For the complete hierarchy of classes, please check the website https://observablehq.com/@mbostock/imagenet-hierarchy. "
            "\n\n"
            "## Instructions:\n"
            "1. Upload an image of your choice or select one from the examples bar.\n"
            "2. The model will predict the top 3 most likely classes for the image.\n"
            "3. The OOD detectors will output their scores and decisions on the image. The smaller the score, the less confident the detector is that the sample is in-distribution.\n"
            "4. If the image is OOD, the model will abstain from the prediction and flag it to the practitioner.\n"
            "\n\n\nEnjoy the demo!"
        ),
        cache_examples=True,
    )
    interface.launch(server_port=7860)
    interface.close()

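# Running this script starts the Gradio server on http://localhost:7860 (see server_port above).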
if __name__ == "__main__":
    logging.basicConfig(level=logging.WARNING)
    gr.close_all()
    main()