File size: 9,182 Bytes
81a2e17
 
d7cb9d0
 
81a2e17
 
 
 
 
e0e224f
d7cb9d0
81a2e17
 
 
 
d7cb9d0
 
 
 
e0e224f
d7cb9d0
 
 
 
 
 
e0e224f
 
81a2e17
d7cb9d0
81a2e17
 
 
 
 
d7cb9d0
81a2e17
 
 
 
 
e0e224f
d7cb9d0
e0e224f
d7cb9d0
 
 
e0e224f
81a2e17
e0e224f
d7cb9d0
 
 
e0e224f
81a2e17
e0e224f
81a2e17
e0e224f
 
d7cb9d0
 
 
 
 
e0e224f
 
 
d7cb9d0
e0e224f
 
d7cb9d0
 
 
 
 
 
 
 
 
 
 
 
 
81a2e17
d7cb9d0
81a2e17
 
 
 
 
d7cb9d0
e0e224f
 
 
 
 
 
 
d7cb9d0
e0e224f
d7cb9d0
e0e224f
 
d7cb9d0
e0e224f
 
d7cb9d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e0e224f
 
 
 
 
 
 
 
 
 
 
81a2e17
 
 
 
 
e0e224f
 
81a2e17
e0e224f
 
81a2e17
 
 
 
 
 
 
 
 
d7cb9d0
 
 
81a2e17
d7cb9d0
 
 
 
e0e224f
d7cb9d0
 
 
 
e0e224f
d7cb9d0
 
 
 
e0e224f
 
d7cb9d0
 
 
 
 
 
 
 
 
e0e224f
d7cb9d0
 
e0e224f
81a2e17
 
 
e0e224f
 
81a2e17
 
 
 
d7cb9d0
81a2e17
d7cb9d0
81a2e17
d7cb9d0
81a2e17
 
 
d7cb9d0
 
81a2e17
 
d7cb9d0
 
 
 
 
 
 
 
 
 
81a2e17
 
 
 
 
e0e224f
81a2e17
 
 
 
e0e224f
81a2e17
e0e224f
 
d7cb9d0
 
 
e0e224f
81a2e17
 
 
d7cb9d0
81a2e17
 
 
 
 
d7cb9d0
 
 
 
 
 
 
 
 
 
 
81a2e17
d7cb9d0
 
 
 
81a2e17
 
 
d7cb9d0
 
 
 
 
 
 
81a2e17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e0e224f
d7cb9d0
e0e224f
81a2e17
 
 
 
e0e224f
81a2e17
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
#!/usr/bin/env python3
"""
Zero123++ 工程六视图生成器
通过旋转输入图片来模拟不同视角
"""

import gradio as gr
import torch
from PIL import Image
import numpy as np
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler

# 全局变量存储 pipeline
pipeline = None

# 工程六视图配置
# Zero123++ 固定输出 6 个视图,我们选择最接近正交视图的角度
# 方位角: 30, 90, 150, 210, 270, 330
# 我们使用: 90° (右), 210° (后偏左), 330° (前偏右)
ENGINEERING_VIEWS = {
    "主视图": {"rotate_input": 0, "select_index": 0, "position": (0, 0)},     # 使用原始输入
    "右视图": {"rotate_input": 0, "select_index": 1, "position": (1, 0)},     # 90度视角
    "后视图": {"rotate_input": 0, "select_index": 3, "position": (2, 0)},     # 210度视角
    "左视图": {"rotate_input": 0, "select_index": 4, "position": (0, 1)},     # 270度视角
    "俯视图": {"rotate_input": -90, "select_index": 0, "position": (1, 1)},   # 旋转输入-90度
    "底视图": {"rotate_input": 90, "select_index": 0, "position": (2, 1)},    # 旋转输入+90度
}

def load_model():
    """加载 Zero123++ 模型"""
    global pipeline

    if pipeline is not None:
        return

    print("正在加载 Zero123++ 模型...")

    # 检查 CUDA 可用性
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    try:
        # 加载 Zero123++ pipeline
        pipeline = DiffusionPipeline.from_pretrained(
            "sudo-ai/zero123plus-v1.2",
            custom_pipeline="sudo-ai/zero123plus-pipeline",
            torch_dtype=dtype
        )

        # 设置调度器
        pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
            pipeline.scheduler.config,
            timestep_spacing='trailing'
        )

        pipeline.to(device)

        # 启用内存优化
        if torch.cuda.is_available():
            try:
                pipeline.enable_attention_slicing()
                pipeline.enable_vae_slicing()
            except:
                pass

        print(f"✓ 模型加载完成 (设备: {device})")
    except Exception as e:
        print(f"错误: 无法加载 Zero123++ - {e}")
        raise

def rotate_image(image, angle):
    """旋转图像"""
    if angle == 0:
        return image
    elif angle == 90:
        return image.rotate(-90, expand=True)
    elif angle == -90:
        return image.rotate(90, expand=True)
    elif angle == 180:
        return image.rotate(180, expand=True)
    return image

def generate_multiview(input_image):
    """
    生成多视图输出

    输入:
        input_image: PIL Image

    输出:
        PIL Image with 6 views (2x3 grid)
    """
    global pipeline

    if pipeline is None:
        load_model()

    # 预处理输入图像
    img = input_image.resize((320, 320), Image.LANCZOS)

    # 运行推理生成6个视图
    result = pipeline(
        img,
        num_inference_steps=75
    ).images[0]

    # Zero123++ 输出是 2x3 的网格,需要拆分
    # 输出尺寸应该是 320*3 x 320*2 = 960x640
    result_w, result_h = result.size
    view_w = result_w // 3
    view_h = result_h // 2

    # 提取6个视图
    views = []
    for row in range(2):
        for col in range(3):
            x = col * view_w
            y = row * view_h
            view = result.crop((x, y, x + view_w, y + view_h))
            views.append(view)

    return views

def process_image(input_image, progress=gr.Progress()):
    """
    处理输入图像,生成工程六视图

    输入:
        input_image: PIL Image
        progress: Gradio Progress 跟踪器

    输出:
        result_image: PIL Image (六视图合成图)
    """
    if input_image is None:
        return None

    try:
        # 确保模型已加载
        load_model()

        # 预处理 - 转为正方形
        img = input_image
        if img.size[0] != img.size[1]:
            size = min(img.size)
            img = img.crop((
                (img.size[0] - size) // 2,
                (img.size[1] - size) // 2,
                (img.size[0] + size) // 2,
                (img.size[1] + size) // 2
            ))

        progress(0.1, desc="生成水平视图...")
        # 生成水平方向的视图 (主/右/后/左)
        horizontal_views = generate_multiview(img)

        progress(0.5, desc="生成俯视图...")
        # 旋转输入-90度生成俯视图
        img_rotated_up = rotate_image(img, -90)
        top_views = generate_multiview(img_rotated_up)

        progress(0.8, desc="生成底视图...")
        # 旋转输入+90度生成底视图
        img_rotated_down = rotate_image(img, 90)
        bottom_views = generate_multiview(img_rotated_down)

        # 组合成工程六视图
        # 布局: 主/右/后
        #       左/俯/底
        view_size = 320
        combined = Image.new('RGB', (view_size * 3, view_size * 2))

        # 第一行: 主视图(0), 右视图(90), 后视图(210)
        combined.paste(horizontal_views[0], (0 * view_size, 0))          # 主视图 (30度)
        combined.paste(horizontal_views[1], (1 * view_size, 0))          # 右视图 (90度)
        combined.paste(horizontal_views[3], (2 * view_size, 0))          # 后视图 (210度)

        # 第二行: 左视图(270), 俯视图, 底视图
        combined.paste(horizontal_views[4], (0 * view_size, view_size))  # 左视图 (270度)
        combined.paste(top_views[0], (1 * view_size, view_size))         # 俯视图 (旋转输入-90)
        combined.paste(bottom_views[0], (2 * view_size, view_size))      # 底视图 (旋转输入+90)

        progress(1.0, desc="完成!")
        print("✓ 所有视图生成完成")
        return combined

    except Exception as e:
        print(f"错误: {e}")
        import traceback
        traceback.print_exc()
        raise gr.Error(f"处理失败: {str(e)}")

# 创建 Gradio 界面
def create_demo():
    with gr.Blocks(title="Zero123++ 工程六视图生成器") as demo:
        gr.Markdown("""
        # Zero123++ 工程六视图生成器

        将单张主视图转换为工程六视图

        **输入要求:**
        - 建议使用正方形图片
        - 推荐分辨率 >= 320x320
        - 脚本会自动裁剪非正方形图片

        **输出说明:**
        生成工程六视图,排列为 2 行 3 列:

        | 主视图 | 右视图 | 后视图 |
        |-------|-------|-------|
        | 左视图 | 俯视图 | 底视图 |

        **技术原理:**
        - 使用 Zero123++ v1.2 模型
        - 通过 3 次推理 + 旋转输入实现工程六视图
        - 每次推理约 30-60 秒,总计约 2-3 分钟
        """)

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    label="输入主视图",
                    type="pil",
                    height=400
                )

                generate_btn = gr.Button("生成工程六视图", variant="primary", size="lg")

                gr.Markdown("""
                **注意:**
                - 需要运行 3 次推理(水平+俯视+底视)
                - 总耗时约 2-3 分钟
                - 请耐心等待
                """)

            with gr.Column():
                output_image = gr.Image(
                    label="工程六视图输出",
                    type="pil",
                    height=400
                )

        gr.Markdown("""
        ### 视角说明

        | 视图 | 方法 | 说明 |
        |-----|------|------|
        | 主视图 | Zero123++ 30° 视角 | 正面 |
        | 右视图 | Zero123++ 90° 视角 | 右侧 |
        | 后视图 | Zero123++ 210° 视角 | 背面 |
        | 左视图 | Zero123++ 270° 视角 | 左侧 |
        | 俯视图 | 输入旋转-90° → Zero123++ | 从上往下 |
        | 底视图 | 输入旋转+90° → Zero123++ | 从下往上 |

        ### 技术说明
        - 模型: [Zero123++ v1.2](https://huggingface.co/sudo-ai/zero123plus-v1.2)
        - Zero123++ 固定输出 6 个环绕视图
        - 通过选择合适的视角 + 旋转输入实现工程视图
        - v1.2 改进: 更稳定的视角,FOV 30°

        ### 引用
        ```bibtex
        @misc{shi2023zero123plus,
            title={Zero123++: a Single Image to Consistent Multi-view Diffusion Base Model},
            author={Ruoxi Shi and Hansheng Chen and others},
            year={2023},
            eprint={2310.15110},
            archivePrefix={arXiv},
            primaryClass={cs.CV}
        }
        ```
        """)

        # 绑定事件
        generate_btn.click(
            fn=process_image,
            inputs=[input_image],
            outputs=output_image
        )

    return demo

if __name__ == "__main__":
    # 预加载模型
    print("=" * 50)
    print("Zero123++ 工程六视图生成器")
    print("=" * 50)
    load_model()

    # 启动 demo
    demo = create_demo()
    demo.queue(max_size=5)
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )