# Hugging Face Spaces app (page header captured with status: Sleeping)
import os

import gradio as gr
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image

from models import get_model
# Normalization statistics, keyed by the statistics family ("imagenet" or
# "clip"); ForgeryDetector passes MEAN[key] / STD[key] to transforms.Normalize.
MEAN = dict(
    imagenet=[0.485, 0.456, 0.406],
    clip=[0.48145466, 0.4578275, 0.40821073],
)
STD = dict(
    imagenet=[0.229, 0.224, 0.225],
    clip=[0.26862954, 0.26130258, 0.27577711],
)
def collect_examples_grouped(base_dir="examples"):
    """Collect example images from ``base_dir/real`` and ``base_dir/fake``.

    Each image path is wrapped in a one-element list, which is the nested-list
    row format accepted by ``gr.Examples`` for a single input component.

    Args:
        base_dir: Directory that may contain ``real/`` and ``fake/``
            sub-directories.

    Returns:
        A ``(real_examples, fake_examples)`` tuple, for example
        ``([["examples/real/0001.jpg"], ...], [["examples/fake/0001.jpg"], ...])``.
        Entries are sorted by file name. A missing sub-directory yields an
        empty list. Only regular files with a recognised image extension
        (case-insensitive) are included.
    """
    exts = (".jpg", ".jpeg", ".png", ".bmp", ".webp")

    def _scan(label):
        # Return the example rows found in base_dir/<label>.
        subdir = os.path.join(base_dir, label)
        if not os.path.isdir(subdir):
            return []
        rows = []
        for fname in sorted(os.listdir(subdir)):
            path = os.path.join(subdir, fname)
            # Skip directories (or other non-files) whose names merely end in
            # an image extension; they cannot be loaded as examples.
            if fname.lower().endswith(exts) and os.path.isfile(path):
                rows.append([path])
        return rows

    return _scan("real"), _scan("fake")
class ForgeryDetector:
    """Real-vs-fake image classifier wrapping a model built by ``get_model``.

    On construction it builds the preprocessing pipeline and a minimal options
    object for ``get_model``, then loads checkpoint weights into the model.
    ``predict`` classifies one PIL image. ``batch_predict`` classifies several
    images given as file paths or PIL images.
    """

    def __init__(self, arch='CLIP:ViT-L/14', ckpt_path='checkpoints/model_epoch_best.pth',
                 device='cuda'):
        """Build the detector and load its weights.

        Args:
            arch: Architecture string forwarded to ``get_model`` (e.g.
                'CLIP:ViT-L/14', 'CLIP:ViT-B/32'). Values starting with
                "imagenet" (case-insensitive) select ImageNet normalization
                statistics; every other value selects CLIP statistics.
            ckpt_path: Path of the checkpoint file to load.
            device: Requested device ('cuda' or 'cpu'). CPU is used whenever
                CUDA is unavailable, regardless of this value.
        """
        self.arch = arch
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        # Choose normalization statistics from the architecture string.
        stat_from = "imagenet" if arch.lower().startswith("imagenet") else "clip"
        # Preprocessing: resize, center-crop to 224x224, tensorize, normalize.
        self.transform = transforms.Compose([
            transforms.Resize(256),  # kept consistent with the validation code
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=MEAN[stat_from], std=STD[stat_from]),
        ])
        # Options object standing in for the training-time TrainOptions.
        self.opt = self._create_opt(arch)
        self.model = self._load_model(ckpt_path)

    def _create_opt(self, arch):
        """Return a minimal options object consumed by ``get_model``.

        ``head_type`` is "attention" for architectures prefixed with "CLIP:"
        and "fc" otherwise. The remaining fields are fixed values whose meaning
        is defined by ``get_model`` (not visible here).
        """
        class Options:
            pass

        opt = Options()
        opt.arch = arch
        opt.head_type = "attention" if arch.startswith("CLIP:") else "fc"
        opt.shuffle = True
        opt.shuffle_times = 1
        opt.original_times = 1
        opt.patch_size = [14]
        opt.penultimate_feature = False
        opt.patch_base = False
        return opt

    def _load_model(self, ckpt_path):
        """Build the model via ``get_model`` and load the checkpoint into it.

        The state dict is loaded into the sub-module selected by
        ``self.opt.head_type`` (``model.fc`` or ``model.attention_head``).
        Presumably the checkpoint holds only the trained head's weights
        (TODO: confirm).

        Returns:
            The model in eval mode, moved to ``self.device``.
        """
        print(f"正在加载模型架构: {self.arch}")
        model = get_model(self.opt)
        print(f"正在加载权重: {ckpt_path}")
        # NOTE: torch.load may unpickle arbitrary objects; only load trusted
        # checkpoint files.
        state_dict = torch.load(ckpt_path, map_location=self.device)
        if self.opt.head_type == "fc":
            model.fc.load_state_dict(state_dict)
        elif self.opt.head_type == "attention":
            model.attention_head.load_state_dict(state_dict)
        else:
            # Unreachable with _create_opt as written (it only produces "fc"
            # or "attention"); kept as a whole-model fallback.
            model.load_state_dict(state_dict)
        model.eval()
        model.to(self.device)
        print("模型加载完成!")
        return model

    def preprocess_image(self, image: Image.Image):
        """Turn a PIL image into a normalized single-image batch.

        Args:
            image: Input PIL image. It is converted to RGB first, so grayscale,
                palette, or RGBA inputs match the 3-channel normalization
                statistics instead of failing inside ``Normalize``.

        Returns:
            Tensor of shape (1, 3, 224, 224) on ``self.device``.

        Raises:
            ValueError: If conversion or any transform step fails (the original
                exception is chained as ``__cause__``).
        """
        try:
            image_tensor = self.transform(image.convert("RGB")).unsqueeze(0)  # add batch dim
            return image_tensor.to(self.device)
        except Exception as e:
            raise ValueError(f"图片预处理失败: {e}") from e

    def predict(self, image: Image.Image, return_probabilities=True):
        """Classify a single image as real or fake.

        Args:
            image: Input PIL image.
            return_probabilities: If True, return probabilities; otherwise
                return a label.

        Returns:
            ``(real_prob, fake_prob)`` as floats when ``return_probabilities``
            is True, else ``"real"`` or ``"fake"`` (a tie resolves to "fake").

        Raises:
            ValueError: If preprocessing fails.
        """
        image_tensor = self.preprocess_image(image)
        with torch.no_grad():
            output = self.model(image_tensor)
        if output.shape[-1] == 2:
            # Two-logit output: softmax over classes; index 0 = real, 1 = fake.
            probs = torch.softmax(output, dim=1)[0]
            real_prob = probs[0].item()
            fake_prob = probs[1].item()
        else:
            # Single-logit output: sigmoid gives the fake probability.
            fake_prob = torch.sigmoid(output).item()
            real_prob = 1 - fake_prob
        if return_probabilities:
            return real_prob, fake_prob
        return "real" if real_prob > fake_prob else "fake"

    def batch_predict(self, image_paths, return_probabilities=True):
        """Run :meth:`predict` on several images, capturing per-item errors.

        Args:
            image_paths: Iterable of image file paths (``str`` or
                ``os.PathLike``). Already-loaded PIL images are also accepted
                and passed through unchanged.
            return_probabilities: Forwarded to :meth:`predict`.

        Returns:
            One dict per input with keys ``'image_path'`` (the input item),
            ``'result'`` (the :meth:`predict` result, or None on failure) and
            ``'status'`` (``'success'`` or ``'error: <message>'``).
        """
        results = []
        for image_path in image_paths:
            try:
                if isinstance(image_path, (str, os.PathLike)):
                    # predict() needs a PIL image, so open the file first; the
                    # context manager closes the handle once inference is done.
                    with Image.open(image_path) as image:
                        result = self.predict(image, return_probabilities)
                else:
                    result = self.predict(image_path, return_probabilities)
                results.append({
                    'image_path': image_path,
                    'result': result,
                    'status': 'success'
                })
            except Exception as e:
                # Best-effort batch: record the failure and keep going.
                results.append({
                    'image_path': image_path,
                    'result': None,
                    'status': f'error: {e}'
                })
        return results
# Module-level detector used by every UI request; built once at import time
# with ForgeryDetector's default architecture, checkpoint path and device.
model = ForgeryDetector()


# ------- 1) Inference wrapper used by the UI -------
def predict(img: Image.Image):
    """Classify one image and format the result for a ``gr.Label`` output.

    Args:
        img: PIL image from the input component, or ``None`` when Submit is
            pressed without an uploaded image.

    Returns:
        Mapping of display label to probability, for example
        ``{"Real Image": 0.9, "Fake Image": 0.1}``.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable message
            instead of a generic preprocessing failure.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")
    real_prob, fake_prob = model.predict(img)
    return {"Real Image": float(real_prob), "Fake Image": float(fake_prob)}
real_examples, fake_examples = collect_examples_grouped()


# ------- 2) UI -------
def _example_column(title, examples, target):
    """Render a titled column of clickable examples that populate *target*."""
    with gr.Column():
        gr.Markdown(title)
        gr.Examples(examples=examples, inputs=target, cache_examples=False)


with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center;'>Face Forgery Detection</h1>")

    with gr.Row():
        # Left column: image input, action buttons, then the example galleries.
        with gr.Column(scale=1):
            img_input = gr.Image(label="Input", type="pil")

            # Buttons live in the same column, directly below the image.
            with gr.Row():
                submit_btn = gr.Button("Submit", variant="primary")
                clear_btn = gr.ClearButton(value="Clear", components=[img_input])

            gr.Markdown("### 🖼 Example Images")
            with gr.Row():
                _example_column("#### 🟢 Real Examples", real_examples, img_input)
                _example_column("#### 🔴 Fake Examples", fake_examples, img_input)

        # Right column: per-class probability bars.
        with gr.Column(scale=1):
            label_output = gr.Label(label="output", num_top_classes=2)

    # Event wiring: Submit runs inference; Clear also resets the result panel.
    submit_btn.click(predict, inputs=img_input, outputs=label_output)
    clear_btn.add([label_output])


if __name__ == "__main__":
    demo.launch()