Spaces:
No application file
No application file
| from fastapi import FastAPI, File, UploadFile | |
| from fastapi.responses import JSONResponse | |
| from PIL import Image | |
| import io | |
| import torch | |
| from torchvision import models, transforms | |
# Load a ResNet-50 classifier with ImageNet-pretrained weights.
# torchvision >= 0.13 deprecates ``pretrained=True`` in favour of the
# ``weights=`` enum. IMAGENET1K_V1 is the weight set that ``pretrained=True``
# resolves to, so predictions are unchanged. Older torchvision releases have
# no ``ResNet50_Weights`` enum, so fall back to the legacy flag there.
_resnet50_weights = getattr(models, "ResNet50_Weights", None)
if _resnet50_weights is not None:
    model = models.resnet50(weights=_resnet50_weights.IMAGENET1K_V1)
else:
    model = models.resnet50(pretrained=True)
model.eval()  # switch to evaluation mode for inference
# Image preprocessing pipeline applied to every uploaded image:
# resize so the shorter side is 256 px, take the central 224x224 crop,
# convert to a tensor, then normalize each channel with the mean/std
# statistics expected by torchvision's ImageNet-pretrained classifiers.
_normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        _normalize,
    ]
)
# The FastAPI application object; routes are attached to it and the
# ``__main__`` entry point hands it to uvicorn.
app: FastAPI = FastAPI()
@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image with the module-level ResNet-50 model.

    The upload is decoded with Pillow, converted to RGB, passed through
    ``preprocess`` and ``model``, and the index of the highest-scoring class
    is returned.

    Args:
        file: The uploaded image (multipart form field named ``file``).

    Returns:
        JSONResponse: ``{"predicted_class": <int>}`` on success, or a 400
        response with ``{"error": ...}`` when the bytes cannot be decoded as
        an image.
    """
    contents = await file.read()
    try:
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except OSError:
        # Pillow raises UnidentifiedImageError (an OSError subclass) for
        # non-image data and OSError for truncated files. Report a client
        # error instead of letting it surface as an HTTP 500.
        return JSONResponse(
            status_code=400,
            content={"error": "Uploaded file is not a valid image."},
        )

    # Preprocess, then add a leading batch dimension of size 1.
    input_batch = preprocess(image).unsqueeze(0)

    # NOTE: inference runs synchronously inside this async handler, so it
    # blocks the event loop while the model executes.
    with torch.no_grad():  # inference only: no autograd bookkeeping
        output = model(input_batch)

    # Index of the highest score for the single image in the batch. Only the
    # index is returned; mapping it to a class name is left to the caller.
    predicted_idx = int(output.argmax(dim=1)[0])
    return JSONResponse(content={"predicted_class": predicted_idx})
# Script entry point: serve ``app`` with uvicorn on all interfaces, port 8000.
if __name__ == "__main__":
    from uvicorn import run as serve_app

    serve_app(app, host="0.0.0.0", port=8000)