import gradio as gr
from PIL import Image
import numpy as np
import os
import torch
import torchvision.transforms as transforms
from models import get_model
# Normalization constants
MEAN = {
"imagenet": [0.485, 0.456, 0.406],
"clip": [0.48145466, 0.4578275, 0.40821073]
}
STD = {
"imagenet": [0.229, 0.224, 0.225],
"clip": [0.26862954, 0.26130258, 0.27577711]
}
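# MEAN/STD above are the standard ImageNet and CLIP preprocessing statistics;
# the detector picks whichever pair matches its backbone.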
def collect_examples_grouped(base_dir="examples"):
"""
扫描 base_dir 下的 real/ 和 fake/ 子目录,分别返回两个二维列表:
Returns:
real_examples: [["examples/real/0001.jpg"], ...]
fake_examples: [["examples/fake/0001.jpg"], ...]
"""
exts = (".jpg", ".jpeg", ".png", ".bmp", ".webp")
real_examples = []
fake_examples = []
for cls in ["real", "fake"]:
subdir = os.path.join(base_dir, cls)
if not os.path.isdir(subdir):
continue
for fname in sorted(os.listdir(subdir)):
if fname.lower().endswith(exts):
path = [os.path.join(subdir, fname)]
if cls == "real":
real_examples.append(path)
else:
fake_examples.append(path)
return real_examples, fake_examples
class ForgeryDetector:
def __init__(self, arch='CLIP:ViT-L/14', ckpt_path='checkpoints/model_epoch_best.pth', device='cuda'):
"""
初始化伪造检测器
Args:
arch: 模型架构 (如 'res50', 'CLIP:ViT-B/32' 等)
ckpt_path: 预训练模型权重路径
device: 运行设备 ('cuda' 或 'cpu')
"""
self.arch = arch
self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        # Pick normalization statistics based on the architecture
stat_from = "imagenet" if arch.lower().startswith("imagenet") else "clip"
        # Image preprocessing pipeline
self.transform = transforms.Compose([
            transforms.Resize(256),  # keep consistent with the validation code
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=MEAN[stat_from], std=STD[stat_from]),
])
        # Build the model options object (mimicking TrainOptions)
self.opt = self._create_opt(arch)
        # Load the model
self.model = self._load_model(ckpt_path)
def _create_opt(self, arch):
"""创建模型配置选项"""
class Options:
pass
opt = Options()
opt.arch = arch
opt.head_type = "attention" if arch.startswith("CLIP:") else "fc"
opt.shuffle = True
opt.shuffle_times = 1
opt.original_times = 1
opt.patch_size = [14]
opt.penultimate_feature = False
opt.patch_base = False
return opt
def _load_model(self, ckpt_path):
"""加载预训练模型"""
print(f"正在加载模型架构: {self.arch}")
model = get_model(self.opt)
# 加载权重
print(f"正在加载权重: {ckpt_path}")
state_dict = torch.load(ckpt_path, map_location=self.device)
# 根据模型类型加载不同的权重
if self.opt.head_type == "fc":
model.fc.load_state_dict(state_dict)
elif self.opt.head_type == "attention":
model.attention_head.load_state_dict(state_dict)
else:
model.load_state_dict(state_dict)
model.eval()
model.to(self.device)
print("模型加载完成!")
return model
    def preprocess_image(self, image: Image.Image):
        """
        Preprocess a single image.
        Args:
            image: input PIL image
        Returns:
            the preprocessed tensor on the target device
        """
        try:
            image_tensor = self.transform(image).unsqueeze(0)  # add batch dimension
            return image_tensor.to(self.device)
        except Exception as e:
            raise ValueError(f"Image preprocessing failed: {e}")
def predict(self, image: Image.Image, return_probabilities=True):
"""
对单张图片进行伪造检测
Args:
image: Image.Image
return_probabilities: 是否返回概率值,否则返回类别
Returns:
如果return_probabilities=True: 返回 (real_prob, fake_prob)
如果return_probabilities=False: 返回 "real" 或 "fake"
"""
        # Preprocess the image
image_tensor = self.preprocess_image(image)
        # Run inference
with torch.no_grad():
output = self.model(image_tensor)
        # Handle the different output formats
if output.shape[-1] == 2:
            # Two-logit (binary classification) output
probs = torch.softmax(output, dim=1)[0]
real_prob = probs[0].item()
fake_prob = probs[1].item()
else:
            # Single-logit output (typically used with a sigmoid)
fake_prob = torch.sigmoid(output).item()
real_prob = 1 - fake_prob
if return_probabilities:
return real_prob, fake_prob
else:
return "real" if real_prob > fake_prob else "fake"
def batch_predict(self, image_paths, return_probabilities=True):
"""
批量预测多张图片
Args:
image_paths: 图片路径列表
return_probabilities: 是否返回概率值
Returns:
预测结果列表
"""
results = []
for image_path in image_paths:
try:
result = self.predict(image_path, return_probabilities)
results.append({
'image_path': image_path,
'result': result,
'status': 'success'
})
except Exception as e:
results.append({
'image_path': image_path,
'result': None,
'status': f'error: {e}'
})
return results
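# Instantiate the detector once at import time. This assumes the default
# CLIP:ViT-L/14 architecture and that checkpoints/model_epoch_best.pth ships
# with the Space; adjust the constructor arguments if your weights differ.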
model = ForgeryDetector()
# ------- 1) Prediction wrapper for the Gradio UI -------
def predict(img: Image.Image):
result = model.predict(img)
probs = np.array(result)
return {"Real Image": float(probs[0]), "Fake Image": float(probs[1])}
real_examples, fake_examples = collect_examples_grouped()
# ------- 2) UI -------
with gr.Blocks() as demo:
gr.Markdown("<h1 style='text-align: center;'>Face Forgery Detection</h1>")
with gr.Row():
        # -- Left column: image + buttons --
with gr.Column(scale=1):
img_input = gr.Image(label="Input", type="pil")
            # ✨ Example image gallery is added further below
            # Keep the buttons in the **same column**, right under the image
with gr.Row():
submit_btn = gr.Button("Submit", variant="primary")
clear_btn = gr.ClearButton(value="Clear", components=[img_input])
gr.Markdown("### 🖼 Example Images")
with gr.Row():
with gr.Column():
gr.Markdown("#### 🟢 Real Examples")
gr.Examples(examples=real_examples, inputs=img_input, cache_examples=False)
with gr.Column():
gr.Markdown("#### 🔴 Fake Examples")
gr.Examples(examples=fake_examples, inputs=img_input, cache_examples=False)
        # -- Right column: result confidence bars --
with gr.Column(scale=1):
            label_output = gr.Label(label="Output", num_top_classes=2)
    # Wire up the interactions
submit_btn.click(predict, inputs=img_input, outputs=label_output)
    clear_btn.add([label_output])  # also clear the result panel on the right
if __name__ == "__main__":
demo.launch()
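# A minimal sketch of non-UI usage (hypothetical paths, same checkpoint layout assumed):
#   detector = ForgeryDetector()
#   real_prob, fake_prob = detector.predict(Image.open("examples/real/0001.jpg").convert("RGB"))
#   results = detector.batch_predict(["examples/fake/0001.jpg", "examples/real/0001.jpg"])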