Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -7,69 +7,39 @@ import streamlit as st
|
|
| 7 |
|
| 8 |
is_colab = utils.is_google_colab()
|
| 9 |
|
| 10 |
-
|
| 11 |
-
|
|
|
|
| 12 |
|
| 13 |
-
model_id_or_path = "CompVis/stable-diffusion-v1-4"
|
| 14 |
-
pipe = CycleDiffusionPipeline.from_pretrained(model_id_or_path,
|
| 15 |
-
|
| 16 |
-
|
| 17 |
|
| 18 |
-
if torch.cuda.is_available():
|
| 19 |
-
|
| 20 |
|
| 21 |
device = "GPU 🔥" if torch.cuda.is_available() else "CPU 🥶"
|
| 22 |
|
| 23 |
|
| 24 |
-
def inference(
|
|
|
|
| 25 |
|
| 26 |
-
|
| 27 |
-
for model in models:
|
| 28 |
-
if model.name == model_name:
|
| 29 |
-
current_model = model
|
| 30 |
-
model_path = current_model.path
|
| 31 |
|
| 32 |
-
generator = torch.Generator('cuda').manual_seed(seed) if seed != 0 else None
|
| 33 |
-
|
| 34 |
-
if img is not None:
|
| 35 |
-
return img_to_img(model_path, prompt, neg_prompt, img, strength, guidance, steps, width, height, generator)
|
| 36 |
-
else:
|
| 37 |
-
return txt_to_img(model_path, prompt, neg_prompt, guidance, steps, width, height, generator, inpaint_image)
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
def img_to_img(model_path, prompt, neg_prompt, img, strength, guidance, steps, width, height, generator=None):
|
| 41 |
-
|
| 42 |
-
global last_mode
|
| 43 |
-
global pipe
|
| 44 |
-
global current_model_path
|
| 45 |
-
if model_path != current_model_path or last_mode != "img2img":
|
| 46 |
-
current_model_path = model_path
|
| 47 |
-
|
| 48 |
-
if is_colab or current_model == custom_model:
|
| 49 |
-
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(current_model_path, torch_dtype=torch.float16, scheduler=scheduler)
|
| 50 |
-
else:
|
| 51 |
-
pipe.to("cpu")
|
| 52 |
-
pipe = current_model.pipe_i2i
|
| 53 |
-
|
| 54 |
-
if torch.cuda.is_available():
|
| 55 |
-
pipe = pipe.to("cuda")
|
| 56 |
-
last_mode = "img2img"
|
| 57 |
-
|
| 58 |
-
prompt = current_model.prefix + prompt
|
| 59 |
ratio = min(height / img.height, width / img.width)
|
| 60 |
-
img = img.resize((int(img.width * ratio), int(img.height * ratio))
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
return replace_nsfw_images(result)
|
| 74 |
|
| 75 |
|
|
@@ -99,60 +69,52 @@ with gr.Blocks(css=css) as demo:
|
|
| 99 |
"""
|
| 100 |
)
|
| 101 |
with gr.Row():
|
| 102 |
-
|
| 103 |
-
with gr.Column(scale=55):
|
| 104 |
-
with gr.Group():
|
| 105 |
-
|
| 106 |
-
with gr.Row():
|
| 107 |
-
prompt = gr.Textbox(label="Prompt", show_label=False, max_lines=2,placeholder="Enter prompt. Style applied automatically").style(container=False)
|
| 108 |
-
generate = gr.Button(value="Generate").style(rounded=(False, True, True, False))
|
| 109 |
|
| 110 |
-
|
| 111 |
-
# gallery = gr.Gallery(
|
| 112 |
-
# label="Generated images", show_label=False, elem_id="gallery"
|
| 113 |
-
# ).style(grid=[1], height="auto")
|
| 114 |
-
|
| 115 |
-
with gr.Column(scale=45):
|
| 116 |
-
with gr.Tab("Options"):
|
| 117 |
with gr.Group():
|
| 118 |
-
neg_prompt = gr.Textbox(label="Negative prompt", placeholder="What to exclude from the image")
|
| 119 |
|
| 120 |
-
|
|
|
|
|
|
|
| 121 |
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
|
|
|
| 125 |
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
|
|
|
|
|
|
| 129 |
|
| 130 |
-
|
|
|
|
|
|
|
| 131 |
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
strength = gr.Slider(label="Transformation strength", minimum=0, maximum=1, step=0.01, value=0.5)
|
| 136 |
|
| 137 |
-
|
| 138 |
-
|
|
|
|
| 139 |
|
| 140 |
-
|
| 141 |
-
if is_colab:
|
| 142 |
-
custom_model_path.change(custom_model_changed, inputs=custom_model_path, outputs=None)
|
| 143 |
-
# n_images.change(lambda n: gr.Gallery().style(grid=[2 if n > 1 else 1], height="auto"), inputs=n_images, outputs=gallery)
|
| 144 |
|
| 145 |
-
inputs = [
|
|
|
|
| 146 |
prompt.submit(inference, inputs=inputs, outputs=image_out)
|
| 147 |
generate.click(inference, inputs=inputs, outputs=image_out)
|
| 148 |
|
| 149 |
-
ex = gr.Examples(
|
| 150 |
-
[
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
[
|
| 155 |
-
|
|
|
|
| 156 |
|
| 157 |
gr.Markdown('''
|
| 158 |
Models by [@nitrosocke](https://huggingface.co/nitrosocke), [@haruu1367](https://twitter.com/haruu1367), [@Helixngc7293](https://twitter.com/DGSpitzer) and others. ❤️<br>
|
|
|
|
| 7 |
|
| 8 |
is_colab = utils.is_google_colab()
|
| 9 |
|
| 10 |
+
if False:
|
| 11 |
+
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
|
| 12 |
+
num_train_timesteps=1000, clip_sample=False, set_alpha_to_one=False)
|
| 13 |
|
| 14 |
+
model_id_or_path = "CompVis/stable-diffusion-v1-4"
|
| 15 |
+
pipe = CycleDiffusionPipeline.from_pretrained(model_id_or_path,
|
| 16 |
+
use_auth_token=st.secrets["USER_TOKEN"],
|
| 17 |
+
scheduler=scheduler)
|
| 18 |
|
| 19 |
+
if torch.cuda.is_available():
|
| 20 |
+
pipe = pipe.to("cuda")
|
| 21 |
|
| 22 |
device = "GPU 🔥" if torch.cuda.is_available() else "CPU 🥶"
|
| 23 |
|
| 24 |
|
| 25 |
+
def inference(source_prompt, target_prompt, source_guidance_scale=1, guidance_scale=5, num_inference_steps=100,
|
| 26 |
+
width=512, height=512, seed=0, img=None, strength=0.7):
|
| 27 |
|
| 28 |
+
torch.manual_seed(seed)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
ratio = min(height / img.height, width / img.width)
|
| 31 |
+
img = img.resize((int(img.width * ratio), int(img.height * ratio)))
|
| 32 |
+
|
| 33 |
+
result = pipe(prompt=target_prompt,
|
| 34 |
+
source_prompt=source_prompt,
|
| 35 |
+
init_image=img,
|
| 36 |
+
num_inference_steps=num_inference_steps,
|
| 37 |
+
eta=0.1,
|
| 38 |
+
strength=strength,
|
| 39 |
+
guidance_scale=guidance_scale,
|
| 40 |
+
source_guidance_scale=source_guidance_scale,
|
| 41 |
+
).images[0]
|
| 42 |
+
|
|
|
|
| 43 |
return replace_nsfw_images(result)
|
| 44 |
|
| 45 |
|
|
|
|
| 69 |
"""
|
| 70 |
)
|
| 71 |
with gr.Row():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
+
with gr.Column(scale=55):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
with gr.Group():
|
|
|
|
| 75 |
|
| 76 |
+
with gr.Row():
|
| 77 |
+
generate = gr.Button(value="Generate").style(rounded=(False, True, True, False))
|
| 78 |
+
image = gr.Image(label="Source image", height=256, tool="editor", type="pil")
|
| 79 |
|
| 80 |
+
image_out = gr.Image(height=512)
|
| 81 |
+
# gallery = gr.Gallery(
|
| 82 |
+
# label="Generated images", show_label=False, elem_id="gallery"
|
| 83 |
+
# ).style(grid=[1], height="auto")
|
| 84 |
|
| 85 |
+
with gr.Column(scale=45):
|
| 86 |
+
with gr.Tab("Options"):
|
| 87 |
+
with gr.Group():
|
| 88 |
+
source_prompt = gr.Textbox(label="Source prompt", placeholder="Source prompt describes the input image")
|
| 89 |
+
target_prompt = gr.Textbox(label="Target prompt", placeholder="Target prompt describes the output image")
|
| 90 |
|
| 91 |
+
with gr.Row():
|
| 92 |
+
source_guidance_scale = gr.Slider(label="Source guidance scale", value=1, minimum=1, maximum=10)
|
| 93 |
+
guidance_scale = gr.Slider(label="Target guidance scale", value=5, minimum=1, maximum=10)
|
| 94 |
|
| 95 |
+
with gr.Row():
|
| 96 |
+
steps = gr.Slider(label="Number of inference steps", value=100, minimum=25, maximum=500, step=1)
|
| 97 |
+
strength = gr.Slider(label="Strength", value=0.7, minimum=0.5, maximum=1, step=0.01)
|
|
|
|
| 98 |
|
| 99 |
+
with gr.Row():
|
| 100 |
+
width = gr.Slider(label="Width", value=512, minimum=64, maximum=1024, step=8)
|
| 101 |
+
height = gr.Slider(label="Height", value=512, minimum=64, maximum=1024, step=8)
|
| 102 |
|
| 103 |
+
seed = gr.Slider(0, 2147483647, label='Seed', value=0, step=1)
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
+
inputs = [source_prompt, target_prompt, source_guidance_scale, guidance_scale, num_inference_steps,
|
| 106 |
+
width, height, seed, img, strength]
|
| 107 |
prompt.submit(inference, inputs=inputs, outputs=image_out)
|
| 108 |
generate.click(inference, inputs=inputs, outputs=image_out)
|
| 109 |
|
| 110 |
+
ex = gr.Examples(
|
| 111 |
+
[
|
| 112 |
+
["A", "B", 7.5, 50, None], # TODO: load image from a file.
|
| 113 |
+
[],
|
| 114 |
+
],
|
| 115 |
+
[source_prompt, target_prompt, source_guidance_scale, guidance_scale, num_inference_steps,
|
| 116 |
+
width, height, seed, img, strength],
|
| 117 |
+
image_out, inference, cache_examples=False)
|
| 118 |
|
| 119 |
gr.Markdown('''
|
| 120 |
Models by [@nitrosocke](https://huggingface.co/nitrosocke), [@haruu1367](https://twitter.com/haruu1367), [@Helixngc7293](https://twitter.com/DGSpitzer) and others. ❤️<br>
|