zero123 / app.py
oscnet
重构: 使用 Zero123++ 通过旋转输入生成工程六视图
d7cb9d0
#!/usr/bin/env python3
"""
Zero123++ 工程六视图生成器
通过旋转输入图片来模拟不同视角
"""
import gradio as gr
import torch
from PIL import Image
import numpy as np
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
# 全局变量存储 pipeline
pipeline = None
# 工程六视图配置
# Zero123++ 固定输出 6 个视图,我们选择最接近正交视图的角度
# 方位角: 30, 90, 150, 210, 270, 330
# 我们使用: 90° (右), 210° (后偏左), 330° (前偏右)
ENGINEERING_VIEWS = {
"主视图": {"rotate_input": 0, "select_index": 0, "position": (0, 0)}, # 使用原始输入
"右视图": {"rotate_input": 0, "select_index": 1, "position": (1, 0)}, # 90度视角
"后视图": {"rotate_input": 0, "select_index": 3, "position": (2, 0)}, # 210度视角
"左视图": {"rotate_input": 0, "select_index": 4, "position": (0, 1)}, # 270度视角
"俯视图": {"rotate_input": -90, "select_index": 0, "position": (1, 1)}, # 旋转输入-90度
"底视图": {"rotate_input": 90, "select_index": 0, "position": (2, 1)}, # 旋转输入+90度
}
def load_model():
"""加载 Zero123++ 模型"""
global pipeline
if pipeline is not None:
return
print("正在加载 Zero123++ 模型...")
# 检查 CUDA 可用性
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
try:
# 加载 Zero123++ pipeline
pipeline = DiffusionPipeline.from_pretrained(
"sudo-ai/zero123plus-v1.2",
custom_pipeline="sudo-ai/zero123plus-pipeline",
torch_dtype=dtype
)
# 设置调度器
pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
pipeline.scheduler.config,
timestep_spacing='trailing'
)
pipeline.to(device)
# 启用内存优化
if torch.cuda.is_available():
try:
pipeline.enable_attention_slicing()
pipeline.enable_vae_slicing()
except:
pass
print(f"✓ 模型加载完成 (设备: {device})")
except Exception as e:
print(f"错误: 无法加载 Zero123++ - {e}")
raise
def rotate_image(image, angle):
"""旋转图像"""
if angle == 0:
return image
elif angle == 90:
return image.rotate(-90, expand=True)
elif angle == -90:
return image.rotate(90, expand=True)
elif angle == 180:
return image.rotate(180, expand=True)
return image
def generate_multiview(input_image):
"""
生成多视图输出
输入:
input_image: PIL Image
输出:
PIL Image with 6 views (2x3 grid)
"""
global pipeline
if pipeline is None:
load_model()
# 预处理输入图像
img = input_image.resize((320, 320), Image.LANCZOS)
# 运行推理生成6个视图
result = pipeline(
img,
num_inference_steps=75
).images[0]
# Zero123++ 输出是 2x3 的网格,需要拆分
# 输出尺寸应该是 320*3 x 320*2 = 960x640
result_w, result_h = result.size
view_w = result_w // 3
view_h = result_h // 2
# 提取6个视图
views = []
for row in range(2):
for col in range(3):
x = col * view_w
y = row * view_h
view = result.crop((x, y, x + view_w, y + view_h))
views.append(view)
return views
def process_image(input_image, progress=gr.Progress()):
"""
处理输入图像,生成工程六视图
输入:
input_image: PIL Image
progress: Gradio Progress 跟踪器
输出:
result_image: PIL Image (六视图合成图)
"""
if input_image is None:
return None
try:
# 确保模型已加载
load_model()
# 预处理 - 转为正方形
img = input_image
if img.size[0] != img.size[1]:
size = min(img.size)
img = img.crop((
(img.size[0] - size) // 2,
(img.size[1] - size) // 2,
(img.size[0] + size) // 2,
(img.size[1] + size) // 2
))
progress(0.1, desc="生成水平视图...")
# 生成水平方向的视图 (主/右/后/左)
horizontal_views = generate_multiview(img)
progress(0.5, desc="生成俯视图...")
# 旋转输入-90度生成俯视图
img_rotated_up = rotate_image(img, -90)
top_views = generate_multiview(img_rotated_up)
progress(0.8, desc="生成底视图...")
# 旋转输入+90度生成底视图
img_rotated_down = rotate_image(img, 90)
bottom_views = generate_multiview(img_rotated_down)
# 组合成工程六视图
# 布局: 主/右/后
# 左/俯/底
view_size = 320
combined = Image.new('RGB', (view_size * 3, view_size * 2))
# 第一行: 主视图(0), 右视图(90), 后视图(210)
combined.paste(horizontal_views[0], (0 * view_size, 0)) # 主视图 (30度)
combined.paste(horizontal_views[1], (1 * view_size, 0)) # 右视图 (90度)
combined.paste(horizontal_views[3], (2 * view_size, 0)) # 后视图 (210度)
# 第二行: 左视图(270), 俯视图, 底视图
combined.paste(horizontal_views[4], (0 * view_size, view_size)) # 左视图 (270度)
combined.paste(top_views[0], (1 * view_size, view_size)) # 俯视图 (旋转输入-90)
combined.paste(bottom_views[0], (2 * view_size, view_size)) # 底视图 (旋转输入+90)
progress(1.0, desc="完成!")
print("✓ 所有视图生成完成")
return combined
except Exception as e:
print(f"错误: {e}")
import traceback
traceback.print_exc()
raise gr.Error(f"处理失败: {str(e)}")
# 创建 Gradio 界面
def create_demo():
with gr.Blocks(title="Zero123++ 工程六视图生成器") as demo:
gr.Markdown("""
# Zero123++ 工程六视图生成器
将单张主视图转换为工程六视图
**输入要求:**
- 建议使用正方形图片
- 推荐分辨率 >= 320x320
- 脚本会自动裁剪非正方形图片
**输出说明:**
生成工程六视图,排列为 2 行 3 列:
| 主视图 | 右视图 | 后视图 |
|-------|-------|-------|
| 左视图 | 俯视图 | 底视图 |
**技术原理:**
- 使用 Zero123++ v1.2 模型
- 通过 3 次推理 + 旋转输入实现工程六视图
- 每次推理约 30-60 秒,总计约 2-3 分钟
""")
with gr.Row():
with gr.Column():
input_image = gr.Image(
label="输入主视图",
type="pil",
height=400
)
generate_btn = gr.Button("生成工程六视图", variant="primary", size="lg")
gr.Markdown("""
**注意:**
- 需要运行 3 次推理(水平+俯视+底视)
- 总耗时约 2-3 分钟
- 请耐心等待
""")
with gr.Column():
output_image = gr.Image(
label="工程六视图输出",
type="pil",
height=400
)
gr.Markdown("""
### 视角说明
| 视图 | 方法 | 说明 |
|-----|------|------|
| 主视图 | Zero123++ 30° 视角 | 正面 |
| 右视图 | Zero123++ 90° 视角 | 右侧 |
| 后视图 | Zero123++ 210° 视角 | 背面 |
| 左视图 | Zero123++ 270° 视角 | 左侧 |
| 俯视图 | 输入旋转-90° → Zero123++ | 从上往下 |
| 底视图 | 输入旋转+90° → Zero123++ | 从下往上 |
### 技术说明
- 模型: [Zero123++ v1.2](https://huggingface.co/sudo-ai/zero123plus-v1.2)
- Zero123++ 固定输出 6 个环绕视图
- 通过选择合适的视角 + 旋转输入实现工程视图
- v1.2 改进: 更稳定的视角,FOV 30°
### 引用
```bibtex
@misc{shi2023zero123plus,
title={Zero123++: a Single Image to Consistent Multi-view Diffusion Base Model},
author={Ruoxi Shi and Hansheng Chen and others},
year={2023},
eprint={2310.15110},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```
""")
# 绑定事件
generate_btn.click(
fn=process_image,
inputs=[input_image],
outputs=output_image
)
return demo
if __name__ == "__main__":
# 预加载模型
print("=" * 50)
print("Zero123++ 工程六视图生成器")
print("=" * 50)
load_model()
# 启动 demo
demo = create_demo()
demo.queue(max_size=5)
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=False
)