|
|
|
|
|
""" |
|
|
Zero123++ 工程六视图生成器 |
|
|
通过旋转输入图片来模拟不同视角 |
|
|
""" |
|
|
|
|
|
import gradio as gr |
|
|
import torch |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler |
|
|
|
|
|
|
|
|
# Module-level Zero123++ pipeline; stays None until load_model() fills it in.
pipeline = None

# Engineering six-view layout, keyed by view name. Each entry holds:
#   rotate_input - in-plane rotation (degrees) applied to the input image
#                  before running Zero123++
#   select_index - which of the six views returned by one Zero123++ pass
#                  to keep for this cell
#   position     - (column, row) cell in the 3-column x 2-row output sheet
ENGINEERING_VIEWS = {
    view_name: {"rotate_input": rotation, "select_index": tile, "position": cell}
    for view_name, rotation, tile, cell in (
        # name       rotate  tile  (col, row)
        ("主视图",      0,     0,   (0, 0)),
        ("右视图",      0,     1,   (1, 0)),
        ("后视图",      0,     3,   (2, 0)),
        ("左视图",      0,     4,   (0, 1)),
        ("俯视图",    -90,     0,   (1, 1)),
        ("底视图",     90,     0,   (2, 1)),
    )
}
|
|
|
|
|
def load_model():
    """Load the Zero123++ v1.2 pipeline into the module-level ``pipeline``.

    Does nothing if ``pipeline`` is already set. Runs on CUDA in float16 when
    CUDA is available, otherwise on CPU in float32, and replaces the default
    scheduler with an Euler-ancestral scheduler using
    ``timestep_spacing='trailing'``.

    The global is assigned only after the pipeline has been fully configured
    and moved to its device, so a failure part-way through leaves
    ``pipeline`` as None and the next call retries the load instead of
    reusing a half-initialised object.

    Raises:
        Exception: any error from loading, configuring or moving the pipeline
            is printed and re-raised unchanged.
    """
    global pipeline

    if pipeline is not None:
        return

    print("正在加载 Zero123++ 模型...")

    use_cuda = torch.cuda.is_available()
    device = 'cuda' if use_cuda else 'cpu'
    # float16 only on GPU; the CPU path stays in float32.
    dtype = torch.float16 if use_cuda else torch.float32

    try:
        # Build into a local and publish to the global only on full success.
        pipe = DiffusionPipeline.from_pretrained(
            "sudo-ai/zero123plus-v1.2",
            custom_pipeline="sudo-ai/zero123plus-pipeline",
            torch_dtype=dtype,
        )

        pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
            pipe.scheduler.config,
            timestep_spacing='trailing',
        )

        pipe.to(device)

        if use_cuda:
            # Memory savers are best-effort. Try each one independently so a
            # failure in one (e.g. if the pipeline lacks the method) does not
            # skip the other, and report it instead of silently ignoring it.
            for method_name in ("enable_attention_slicing", "enable_vae_slicing"):
                try:
                    getattr(pipe, method_name)()
                except Exception as e:
                    print(f"警告: {method_name} 不可用 - {e}")

        pipeline = pipe
        print(f"✓ 模型加载完成 (设备: {device})")

    except Exception as e:
        print(f"错误: 无法加载 Zero123++ - {e}")
        raise
|
|
|
|
|
def rotate_image(image, angle):
    """Rotate *image* clockwise by *angle* degrees, enlarging the canvas to fit.

    Positive angles turn clockwise and negative angles counter-clockwise.
    PIL's ``Image.rotate`` turns counter-clockwise, hence the sign flip.
    Previously only 0, 90, -90 and 180 were honoured and every other angle
    (e.g. 270) was silently ignored; any angle is now applied. Non-right
    angles use PIL's default resampling and fill for exposed corners.

    Args:
        image: PIL image (any object with a PIL-compatible ``rotate`` method).
        angle: rotation in degrees, clockwise-positive.

    Returns:
        The rotated image, or *image* itself (not a copy) when *angle* is a
        whole number of turns (0, 360, -360, ...).
    """
    if angle % 360 == 0:
        return image
    return image.rotate(-angle, expand=True)
|
|
|
|
|
def generate_multiview(input_image, *, num_inference_steps=75):
    """Run one Zero123++ pass and return its six novel views as separate tiles.

    Args:
        input_image: PIL image used as the conditioning view; it is resized to
            320x320 (LANCZOS) before inference.
        num_inference_steps: denoising steps passed to the pipeline.

    Returns:
        list of 6 PIL images cropped from the pipeline's output sheet in
        row-major order (left-to-right, then top-to-bottom).
    """
    if pipeline is None:
        load_model()

    cond = input_image.resize((320, 320), Image.LANCZOS)

    sheet = pipeline(
        cond,
        num_inference_steps=num_inference_steps,
    ).images[0]

    sheet_w, sheet_h = sheet.size
    # The sheet packs six square tiles. Zero123++ lays them out as 3 rows x 2
    # columns (a portrait sheet); infer the layout from the aspect ratio so
    # each crop is one whole, square tile rather than a slice across tiles.
    rows, cols = (3, 2) if sheet_h > sheet_w else (2, 3)
    tile_w = sheet_w // cols
    tile_h = sheet_h // rows

    return [
        sheet.crop((c * tile_w, r * tile_h, (c + 1) * tile_w, (r + 1) * tile_h))
        for r in range(rows)
        for c in range(cols)
    ]
|
|
|
|
|
def _prepare_input(image):
    """Center-crop *image* to a square and normalise its mode to RGB/RGBA."""
    width, height = image.size
    if width != height:
        side = min(width, height)
        left = (width - side) // 2
        top = (height - side) // 2
        image = image.crop((left, top, left + side, top + side))

    if image.mode not in ("RGB", "RGBA"):
        # NOTE(review): the Zero123++ custom pipeline appears to accept only
        # RGB/RGBA conditioning images - confirm against the pipeline source.
        # Keep transparency where the source mode carries it.
        has_alpha = "A" in image.getbands() or "transparency" in image.info
        image = image.convert("RGBA" if has_alpha else "RGB")
    return image


def process_image(input_image, progress=gr.Progress()):
    """Build the engineering six-view sheet for one input image.

    Runs Zero123++ three times: on the (square-cropped) input, on the input
    rotated by -90 degrees and on the input rotated by +90 degrees. Tiles are
    then picked and placed according to ``ENGINEERING_VIEWS``.

    Args:
        input_image: PIL image from the Gradio input, or None if empty.
        progress: Gradio progress tracker.

    Returns:
        PIL RGB image whose grid size follows the largest ``position`` in
        ``ENGINEERING_VIEWS`` (3 columns x 2 rows for the current table), or
        None when no image was given.

    Raises:
        gr.Error: wraps any failure so Gradio shows it to the user; the
            original exception is chained as the cause.
    """
    if input_image is None:
        return None

    try:
        load_model()

        img = _prepare_input(input_image)

        progress(0.1, desc="生成水平视图...")
        horizontal_views = generate_multiview(img)

        progress(0.5, desc="生成俯视图...")
        top_views = generate_multiview(rotate_image(img, -90))

        progress(0.8, desc="生成底视图...")
        bottom_views = generate_multiview(rotate_image(img, 90))

        # Outputs keyed by the input rotation that produced them, matching the
        # "rotate_input" field of ENGINEERING_VIEWS.
        views_by_rotation = {0: horizontal_views, -90: top_views, 90: bottom_views}

        # Size the canvas from the actual tiles rather than assuming 320 px.
        tile_w, tile_h = horizontal_views[0].size
        specs = list(ENGINEERING_VIEWS.values())
        n_cols = 1 + max(spec["position"][0] for spec in specs)
        n_rows = 1 + max(spec["position"][1] for spec in specs)
        combined = Image.new('RGB', (tile_w * n_cols, tile_h * n_rows))

        for spec in specs:
            tile = views_by_rotation[spec["rotate_input"]][spec["select_index"]]
            col, row = spec["position"]
            combined.paste(tile, (col * tile_w, row * tile_h))

        progress(1.0, desc="完成!")
        print("✓ 所有视图生成完成")
        return combined

    except Exception as e:
        print(f"错误: {e}")
        import traceback
        traceback.print_exc()
        raise gr.Error(f"处理失败: {str(e)}") from e
|
|
|
|
|
|
|
|
def create_demo():
    """Build the Gradio Blocks UI for the engineering six-view generator.

    Layout: an intro Markdown panel, then one row with two columns. The left
    column holds the input image, the generate button and a timing note. The
    right column holds the output image and reference notes. Clicking the
    button calls ``process_image`` with the input image and shows its result
    in the output image.

    Returns:
        The ``gr.Blocks`` instance (not yet queued or launched).
    """
    with gr.Blocks(title="Zero123++ 工程六视图生成器") as demo:
        # Intro / usage text shown above the controls.
        gr.Markdown("""


        # Zero123++ 工程六视图生成器





        将单张主视图转换为工程六视图





        **输入要求:**


        - 建议使用正方形图片


        - 推荐分辨率 >= 320x320


        - 脚本会自动裁剪非正方形图片





        **输出说明:**


        生成工程六视图,排列为 2 行 3 列:





        | 主视图 | 右视图 | 后视图 |


        |-------|-------|-------|


        | 左视图 | 俯视图 | 底视图 |





        **技术原理:**


        - 使用 Zero123++ v1.2 模型


        - 通过 3 次推理 + 旋转输入实现工程六视图


        - 每次推理约 30-60 秒,总计约 2-3 分钟


        """)

        with gr.Row():
            # Left column: input image, trigger button and timing note.
            with gr.Column():
                # type="pil" makes Gradio hand process_image a PIL image.
                input_image = gr.Image(
                    label="输入主视图",
                    type="pil",
                    height=400
                )

                generate_btn = gr.Button("生成工程六视图", variant="primary", size="lg")

                gr.Markdown("""


                **注意:**


                - 需要运行 3 次推理(水平+俯视+底视)


                - 总耗时约 2-3 分钟


                - 请耐心等待


                """)

            # Right column: generated sheet plus explanatory/reference notes.
            with gr.Column():
                output_image = gr.Image(
                    label="工程六视图输出",
                    type="pil",
                    height=400
                )

                gr.Markdown("""


                ### 视角说明





                | 视图 | 方法 | 说明 |


                |-----|------|------|


                | 主视图 | Zero123++ 30° 视角 | 正面 |


                | 右视图 | Zero123++ 90° 视角 | 右侧 |


                | 后视图 | Zero123++ 210° 视角 | 背面 |


                | 左视图 | Zero123++ 270° 视角 | 左侧 |


                | 俯视图 | 输入旋转-90° → Zero123++ | 从上往下 |


                | 底视图 | 输入旋转+90° → Zero123++ | 从下往上 |





                ### 技术说明


                - 模型: [Zero123++ v1.2](https://huggingface.co/sudo-ai/zero123plus-v1.2)


                - Zero123++ 固定输出 6 个环绕视图


                - 通过选择合适的视角 + 旋转输入实现工程视图


                - v1.2 改进: 更稳定的视角,FOV 30°





                ### 引用


                ```bibtex


                @misc{shi2023zero123plus,


                    title={Zero123++: a Single Image to Consistent Multi-view Diffusion Base Model},


                    author={Ruoxi Shi and Hansheng Chen and others},


                    year={2023},


                    eprint={2310.15110},


                    archivePrefix={arXiv},


                    primaryClass={cs.CV}


                }


                ```


                """)

        # Wire the button: input image -> process_image -> output image.
        generate_btn.click(
            fn=process_image,
            inputs=[input_image],
            outputs=output_image
        )

    return demo
|
|
|
|
|
if __name__ == "__main__":
    banner_rule = "=" * 50
    for banner_line in (banner_rule, "Zero123++ 工程六视图生成器", banner_rule):
        print(banner_line)

    # Load the model eagerly at startup so the first request does not pay
    # the download/initialisation cost.
    load_model()

    demo = create_demo()
    # At most 5 requests may wait in Gradio's queue at once.
    demo.queue(max_size=5)
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
|
|
|