File size: 10,039 Bytes
066a555
 
 
 
 
47b2864
066a555
 
e7f7b12
f4cf4da
4f6229c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
effc301
47b2864
4f6229c
 
 
 
 
 
 
 
 
 
 
f01e85a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f4cf4da
f01e85a
f4cf4da
066a555
 
 
 
e87a9f6
066a555
8e3da41
066a555
47b2864
 
4f6229c
066a555
 
 
 
 
 
 
8e3da41
066a555
 
b9938b4
 
 
 
066a555
 
 
 
 
 
 
 
 
8e3da41
066a555
 
 
 
 
 
 
8e3da41
066a555
 
 
2cabcff
 
 
968d5fe
2cabcff
 
 
 
968d5fe
2cabcff
968d5fe
2cabcff
 
066a555
 
 
 
 
 
 
 
 
 
 
6f1f301
0f4fde8
 
066a555
8e3da41
 
 
 
 
 
066a555
8e3da41
 
 
 
 
 
 
 
066a555
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f4cf4da
 
 
e87a9f6
066a555
 
 
f4cf4da
 
 
e87a9f6
066a555
 
 
 
 
 
e87a9f6
066a555
e87a9f6
066a555
519d68b
 
 
e87a9f6
519d68b
8e3da41
519d68b
066a555
 
 
 
8e3da41
066a555
 
 
 
 
 
 
 
 
8e3da41
066a555
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
import gradio as gr
import numpy as np
import random
import spaces
import torch
import types
from diffusers.pipelines.prx import PRXPipeline

# Monkey-patch the PRX pipeline module so it uses 1024-px aspect-ratio bins.
import diffusers.pipelines.prx.pipeline_prx as prx_mod
import math

def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
) -> torch.Tensor:
    """Create sinusoidal timestep embeddings.

    This matches the implementation in Denoising Diffusion Probabilistic
    Models. It is a local copy of the diffusers helper, used by this file's
    ``_compute_timestep_embedding``.

    Args:
        timesteps (torch.Tensor):
            A 1-D tensor of N indices, one per batch element. These may be
            fractional.
        embedding_dim (int):
            The dimension of the output. Odd values are zero-padded by one
            column at the end.
        flip_sin_to_cos (bool):
            Whether the embedding order should be ``cos, sin`` (if True) or
            ``sin, cos`` (if False).
        downscale_freq_shift (float):
            Controls the delta between frequencies of adjacent dimensions.
            The exponent is divided by ``half_dim - downscale_freq_shift``.
        scale (float):
            Scaling factor applied to the timestep * frequency products
            before sin/cos. Defaults to 1 (no scaling). The previous default
            of 0 collapsed every embedding to a constant.
        max_period (int):
            Controls the minimum frequency of the embeddings.

    Returns:
        torch.Tensor: an ``[N x embedding_dim]`` tensor of positional
        embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    # Geometric frequency ladder from 1 down towards 1 / max_period.
    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    )
    exponent = exponent / (half_dim - downscale_freq_shift)

    emb = torch.exp(exponent)
    # Outer product: (N, 1) * (1, half_dim) -> (N, half_dim).
    emb = timesteps[:, None].float() * emb[None, :]

    # scale embeddings
    emb = scale * emb

    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # zero pad so the width equals embedding_dim when it is odd
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
 
def _compute_timestep_embedding(self, timestep: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Embed ``timestep`` sinusoidally and project it through ``self.time_in``.

    Intended to be bound onto the PRX transformer instance as a replacement
    method. Reads ``self.time_max_period``, ``self.time_factor`` and
    ``self.time_in`` from that instance.

    Args:
        timestep: Timesteps passed to ``get_timestep_embedding`` (which
            requires a 1-D tensor).
        dtype: Dtype the 256-wide sinusoidal embedding is cast to before
            projection.

    Returns:
        Whatever ``self.time_in`` returns for the cast embedding.
    """
    sinusoidal = get_timestep_embedding(
        timesteps=timestep,
        embedding_dim=256,
        flip_sin_to_cos=True,  # emit (cos, sin) order to match the original model
        downscale_freq_shift=0.0,
        scale=self.time_factor,
        max_period=self.time_max_period,
    )
    return self.time_in(sinusoidal.to(dtype))
        
# Resolution bins of roughly one megapixel (about 1.01-1.05 MP each), all
# sides multiples of 32. Each key is a string label for the ratio
# first/second of its value, e.g. "0.49" -> 704 / 1440 ≈ 0.489.
# NOTE(review): values are presumably [height, width], as in diffusers'
# other ASPECT_RATIO_*_BIN tables — confirm against pipeline_prx.
CUSTOM_ASPECT_RATIO_512_BIN = {
    "0.49": [704, 1440],
    "0.52": [736, 1408],
    "0.53": [736, 1376],
    "0.57": [768, 1344],
    "0.59": [768, 1312],
    "0.62": [800, 1280],
    "0.67": [832, 1248],
    "0.68": [832, 1216],
    "0.78": [896, 1152],
    "0.83": [928, 1120],
    "0.94": [992, 1056],
    "1.0": [1024, 1024],
    "1.06": [1056, 992],
    "1.13": [1088, 960],
    "1.21": [1120, 928],
    "1.29": [1152, 896],
    "1.37": [1184, 864],
    "1.46": [1216, 832],
    "1.5": [1248, 832],
    "1.71": [1312, 768],
    "1.75": [1344, 768],
    "1.87": [1376, 736],
    "1.91": [1408, 736],
    "2.05": [1440, 704],
}

# Overwrite the module-level table in pipeline_prx; the "512" name is kept
# because that is the attribute being replaced. NOTE(review): this only
# takes effect if pipeline_prx looks the name up from its module globals at
# call time (not via a copy bound at import) — confirm.
prx_mod.ASPECT_RATIO_512_BIN = CUSTOM_ASPECT_RATIO_512_BIN

# Inference dtype and device. NOTE(review): bfloat16 is used even on the CPU
# fallback, which may be slow or unsupported for some ops — confirm.
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the pretrained PRX pipeline from the "Photoroom/prx-1024-t2i-beta"
# checkpoint in `dtype` and move it to `device`.
pipe = PRXPipeline.from_pretrained(
    "Photoroom/prx-1024-t2i-beta",
    torch_dtype=dtype
).to(device)

# Properly bind the method to the instance using types.MethodType
# (replaces the transformer's timestep embedding with this file's
# `_compute_timestep_embedding`; `self` will be pipe.transformer).
pipe.transformer._compute_timestep_embedding = types.MethodType(_compute_timestep_embedding, pipe.transformer)

# Largest seed offered in the UI / by randomization (max signed 32-bit int).
MAX_SEED = np.iinfo(np.int32).max
# NOTE(review): not referenced in the code visible here; the width/height
# sliders go up to 1440.
MAX_IMAGE_SIZE = 1024


@spaces.GPU()
def infer(
    prompt,
    negative_prompt="",
    seed=42,
    randomize_seed=False,
    width=1024,
    height=1024,
    num_inference_steps=28,
    guidance_scale=4.0,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image with the PRX pipeline.

    Args:
        prompt: Text prompt describing the desired image.
        negative_prompt: Text describing things to avoid; "" for none.
        seed: RNG seed, used unless ``randomize_seed`` is True.
        randomize_seed: If True, draw a fresh seed from ``[0, MAX_SEED]``.
        width: Requested output width in pixels.
        height: Requested output height in pixels.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Guidance strength forwarded to the pipeline.
        progress: Not read in the body. It is kept in the signature so
            Gradio can attach progress tracking (``track_tqdm=True``).

    Returns:
        ``(image, seed)``: the first generated image and the seed actually
        used, so the UI's seed slider can display it.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Slider values may arrive as floats. The generator seed, image size and
    # step count are integral quantities, so normalize them here.
    seed = int(seed)

    generator = torch.Generator(device=device).manual_seed(seed)

    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=int(width),
        height=int(height),
        num_inference_steps=int(num_inference_steps),
        generator=generator,
        guidance_scale=guidance_scale,
    ).images[0]

    return image, seed


# Example rows for gr.Examples: each is [prompt, negative_prompt], matching
# the `inputs=[prompt, negative_prompt]` wired up in the UI below.
examples = [
    # ["A massive black monolith standing alone in a mirror-like salt flat after rainfall, horizon dissolving into pastel pink and cyan, reflections perfect and infinite, minimalist 2.39:1 frame, cinematic atmosphere of silence, RED Komodo 6K capture, 35 mm lens, ND filter, high dynamic range, ultra-clean tones and soft ambient light.", ""],
    ["A turtle covered in vibrant ceramic mosaic tiles, tiny geometric patterns, resting on weathered stone steps in a Mediterranean town square, warm daylight, artisanal feel", ""],
     ["Hundreds of paper lanterns drifting along a quiet river at dusk, soft orange light piercing cold blue mist, reflections trembling across rippled water, camera at water level with shallow DOF, cinematic color contrast of warm and cool tones", ""],
    ["A woman standing ankle-deep in the ocean at dawn, gentle waves touching her feet, mist and pastel horizon, cinematic wide composition, calm and contemplative mood, filmic color grading reminiscent of Terrence Malick's imagery.", ""],
    # ["In the courtyard of a coastal house, white sheets flap slowly in the wind, a woman pauses between hanging clothes, eyes closed, light flickering through the fabric. A flock of seagulls turns sharply overhead, casting moving shadows on the walls. The sound of waves faintly audible, palette of whites, greys, and sun-bleached blues, evokes transience and memory.", ""],
    # ["A close-up portrait in a photography studio, multiple soft light sources creating gradients of shadow on her face, minimal background, cinematic 4 K realism, artistic focus on light and emotion rather than glamour.", ""],
    ["A cat sculpted from fine white porcelain with delicate blue floral motifs, standing gracefully in a minimalist contemporary art gallery, polished marble floor reflections, soft museum lighting, ultra-detailed ceramic gloss", ""],
    ["A whimsical fantasy dog made entirely from layered paper cutouts, textured handmade paper, watercolor patterns on its body, pastel tones, enchanted meadow, soft glow, playful mood, highly detailed illustration.", ""],
    ["A front-facing portrait of a lion on the golden savanna at sunset.", ""],
    ["An owl sculpted from layered book pages, faint text visible on feathers, perched on a wooden reading desk in a grand library, golden lamplight, quiet scholarly ambience", ""],
    # French-language prompt (a rusted old tram on a windswept beach in golden
    # evening light) — presumably included to showcase non-English prompting.
    ["Une peinture numérique d’un vieux tram rouillé reposant sur une plage de sable balayée par le vent, ses couleurs délavées brillant doucement sous la lumière dorée du soir", ""],
    ["A digital painting depicts a herd of African elephants traversing a dry, grassy savanna.", ""],
    ["A fox constructed from tightly coiled metal wire strands, intricate loops, semi-transparent silhouette, perched on an old brick rooftop in a calm city evening, soft warm window lights, poetic mood", ""]
]

# Page CSS: center the main column (elem_id="col-container") and cap its width.
css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""

# Build the Gradio UI. Components are laid out in the order they are created
# inside the `with` contexts, so statement order here defines the page.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# PRX 1.2B Image Generator")
        gr.Markdown("Generate high-quality images using the beta-preview of PRX.")
        gr.Markdown("Works best with very detailed prompts in natural language.")

        prompt = gr.Text(
            label="Prompt",
            show_label=True,
            max_lines=2,
            placeholder="Enter your prompt",
        )

        # Optional negative prompt, forwarded to infer() as `negative_prompt`.
        negative_prompt = gr.Text(
            label="Negative prompt",
            max_lines=2,
            placeholder="Things to avoid (e.g., blurry, low-res, extra limbs...)",
            value=""
        )

        with gr.Row():
            run_button = gr.Button("Run", scale=0)

        result = gr.Image(label="Result", show_label=False)

        with gr.Accordion("Advanced Settings", open=False):
            # Also an output: infer() returns the seed it used, which is
            # written back into this slider.
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )

            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            # NOTE(review): step=1 allows sizes that are not multiples of 32,
            # unlike every entry in CUSTOM_ASPECT_RATIO_512_BIN; presumably the
            # pipeline snaps requests to the nearest bin — confirm.
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=700,
                    maximum=1440,
                    step=1,
                    value=1024,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=700,
                    maximum=1440,
                    step=1,
                    value=1024,
                )

            with gr.Row():
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=28,
                )
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=1.0,
                    maximum=7.0,
                    step=0.1,
                    value=4.0,
                )

        # Examples pass only (prompt, negative_prompt); every other infer()
        # argument falls back to its default (seed=42, no randomization).
        gr.Examples(
            examples=examples,
            fn=infer,
            inputs=[prompt, negative_prompt],
            outputs=[result, seed],
            cache_examples="lazy"
        )

    # Run generation on button click or on Enter in the prompt box. The
    # input order must match infer()'s positional parameters.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            num_inference_steps,
            guidance_scale,
        ],
        outputs=[result, seed]
    )

# Start the Gradio server (runs at import/execution time of this script).
demo.launch()