import numpy as np
import gradio as gr
import random
import torch
import spaces
import os
import base64
import json
import torchvision
from PIL import Image
from diffusers import ControlNetModel, DPMSolverMultistepScheduler
from transformers import Blip2Processor, Blip2ForConditionalGeneration
from src.pipelines.pipeline_stable_diffusion_outpaint import OutpaintPipeline
from src.pipelines.pipeline_controlnet_outpaint import ControlNetOutpaintPipeline
from src.schedulers.scheduling_pndm import CustomScheduler
from src.models.unet import U_Net
from src.models.light_source_regressor import LightSourceRegressor
from utils.dataset import HFCustomImageLoader
from utils.utils import (
blend_with_alpha,
load_mfdnet_checkpoint,
predict_flare_from_6_channel,
predict_flare_from_3_channel,
blend_light_source,
)
from SIFR_models.flare7kpp.model import Uformer
intro = """
NOTICE: This demo is limited to CPU-only inference. For a better experience, please run the code locally with a GPU.
"""
css = """
#col-container {
margin: 0 auto;
max-width: 1024px;
}
#edit_text{
margin-top: -62px !important
}
"""
def encode_image(pil_image):
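    """Encode a PIL image as a base64 PNG string."""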
import io
buffered = io.BytesIO()
pil_image.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode("utf-8")
# --- UI Constants and Helpers ---
MAX_SEED = np.iinfo(np.int32).max
## --- Model Loading --- ##
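# Five components are loaded below: a ControlNet conditioned on a light-source mask,
# a Stable Diffusion 2 inpainting pipeline used for outpainting, a BLIP-2 captioner
# for automatic text prompts, a light-source regressor that predicts the control mask,
# and a Uformer-based SIFR model for flare removal.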
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16
# controlnet
controlnet = ControlNetModel.from_pretrained(
"RayTsai-030/LightsOut-controlnet", torch_dtype=dtype
)
# outpainter
pipe = ControlNetOutpaintPipeline.from_pretrained(
"stabilityai/stable-diffusion-2-inpainting",
controlnet=controlnet,
torch_dtype=dtype,
).to(device)
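# swap in the project's custom scheduler and load the LoRA attention weights for light outpainting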
pipe.scheduler = CustomScheduler.from_config(pipe.scheduler.config)
pipe.unet.load_attn_procs("./weights/light_outpaint_lora", use_safetensors=True)
# blip
processor = Blip2Processor.from_pretrained(
"Salesforce/blip2-opt-2.7b", revision="51572668da0eb669e01a189dc22abe6088589a24"
)
blip2 = Blip2ForConditionalGeneration.from_pretrained(
"Salesforce/blip2-opt-2.7b",
torch_dtype=dtype,
revision="51572668da0eb669e01a189dc22abe6088589a24",
)
blip2 = blip2.to(device)
# light regressor
lsr_module = LightSourceRegressor()
ckpt = torch.load(
"./weights/light_regress/model.pth", map_location="cpu" if device == "cpu" else None
)
lsr_module.load_state_dict(ckpt["model"])
lsr_module.to(device)
lsr_module.eval()
# SIFR model
sifr_model = Uformer(img_size=512, img_ch=3, output_ch=6).to(device)
sifr_model.load_state_dict(
    torch.load(
        "./weights/net_g_last.pth", map_location="cpu" if device == "cpu" else None
    )
)
sifr_model.eval()
# --- Main Inference Function (with hardcoded negative prompt) ---
@spaces.GPU(duration=120)
def infer(
image,
seed=42,
cfg=7.5,
num_inference_steps=50,
repeat_time=4,
left_outpaint=64,
right_outpaint=64,
up_outpaint=64,
down_outpaint=64,
progress=gr.Progress(track_tqdm=True),
):
"""
Generates an image
"""
    # dataset: prepare the extended canvas, the outpaint mask, and the box marking the original image
dataset = HFCustomImageLoader(
image, left_outpaint, right_outpaint, up_outpaint, down_outpaint
)
data = dataset[0]
# generator
generator = torch.Generator(device=device).manual_seed(seed)
    # normalize to [-1, 1] for the light-source regressor
transform = torchvision.transforms.Compose(
[
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(mean=[0.5], std=[0.5]),
]
)
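    # the SIFR model expects 512x512 inputs in [0, 1]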
sifr_transform = torchvision.transforms.Compose(
[
torchvision.transforms.ToTensor(),
torchvision.transforms.Resize((512, 512)),
]
)
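    # threshold for binarizing the predicted light-source mask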
threshold = 0.5
with torch.no_grad():
input_img = data["input_img"]
input_img = transform(input_img).unsqueeze(0).to(device)
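        # predict a soft light-source mask and binarize it at the threshold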
pred_mask = lsr_module.forward_render(input_img)
pred_mask = (pred_mask > threshold).float()
        # .cpu() is a no-op for tensors already on the CPU
        pred_mask = pred_mask.cpu().numpy()
data["control_img"] = Image.fromarray((pred_mask[0, 0] * 255).astype(np.uint8))
print("Finish light source detection...")
    # prepare the text prompt: caption the image with BLIP-2, then append lighting/flare keywords
inputs = processor(data["blip_img"], return_tensors="pt").to(
device=device, dtype=dtype
)
generate_id = blip2.generate(**inputs, max_new_tokens=20)
generated_text = processor.batch_decode(generate_id, skip_special_tokens=True)[
0
].strip()
generated_text += (
", dynamic lighting, intense light source, prominent lens flare, best quality, high resolution, masterpiece, intricate details"
# ", full light sources with lens flare, best quality, high resolution"
)
print(f"Generated text prompt: {generated_text}")
# Blur mask
# data["mask_img"] = data["mask_img"].filter(ImageFilter.GaussianBlur(15))
    # run the ControlNet-guided outpainting pipeline
outpaint_result = pipe(
prompt=generated_text,
negative_prompt="NSFW, (word:1.5), watermark, blurry, missing body, amputation, mutilation",
image=data["input_img"],
mask_image=data["mask_img"],
control_image=data["control_img"],
num_inference_steps=num_inference_steps,
guidance_scale=cfg,
generator=generator,
repeat_time=repeat_time,
).images[0]
    # paste the original pixels back into the outpainted canvas and blend the seam
outpaint_result = np.array(outpaint_result)
input_img = np.array(data["input_img"])
box = data["box"]
input_img2 = outpaint_result.copy()
input_img2[box[2] : box[3] + 1, box[0] : box[1] + 1] = input_img[
box[2] : box[3] + 1, box[0] : box[1] + 1
]
outpaint_result = blend_with_alpha(outpaint_result, input_img2, box, blur_size=31)
outpaint_result = Image.fromarray(outpaint_result.astype(np.uint8))
print("Finish outpainting...")
# flare removal
img = sifr_transform(outpaint_result).unsqueeze(0).to(device)
with torch.no_grad():
output_img = sifr_model(img)
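        # split the 6-channel SIFR output into the flare-free image and the flare (gamma 2.2)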
gamma = torch.Tensor([2.2])
# flare7k++
deflare_result, _, _ = predict_flare_from_6_channel(output_img, gamma, device)
# # mfdnet
# flare_mask = torch.zeros_like(img)
# deflare_img, _ = predict_flare_from_3_channel(
# output_img, flare_mask, output_img, img, img, gamma
# )
# deflare_img = blend_light_source(img, deflare_img, 0.999)
    # move to CPU (no-op if already there), convert to an HWC uint8 image,
    # crop back to the original (pre-outpaint) region, and resize for display
    deflare_result = deflare_result.cpu().squeeze(0).permute(1, 2, 0).numpy()
    deflare_result = np.clip(deflare_result, 0.0, 1.0)
    deflare_result = (deflare_result * 255).astype(np.uint8)
    deflare_result = deflare_result[box[2] : box[3] + 1, box[0] : box[1] + 1, :]
    deflare_result = Image.fromarray(deflare_result).resize((512, 512), Image.LANCZOS)
    print("Finish flare removal...")
    return data["control_img"], outpaint_result, deflare_result
# --- Examples and UI Layout ---
examples = []
with gr.Blocks(css=css) as demo:
with gr.Column(elem_id="col-container"):
gr.HTML(intro)
with gr.Row():
input_image = gr.Image(label="Input Image", show_label=False, type="pil")
# with gr.Row():
# with gr.Column():
with gr.Accordion("Advanced Settings", open=False):
# Negative prompt UI element is removed here
seed = gr.Slider(
label="Seed",
minimum=0,
maximum=MAX_SEED,
step=1,
value=42,
)
with gr.Column():
left_outpaint = gr.Slider(
label="Left outpaint (px)",
minimum=32,
maximum=128,
step=32,
value=64,
)
right_outpaint = gr.Slider(
label="Right outpaint (px)",
minimum=32,
maximum=128,
step=32,
value=64,
)
up_outpaint = gr.Slider(
label="Up outpaint (px)", minimum=32, maximum=128, step=32, value=64
)
down_outpaint = gr.Slider(
label="Down outpaint (px)",
minimum=32,
maximum=128,
step=32,
value=64,
)
# randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
with gr.Row():
true_guidance_scale = gr.Slider(
label="Guidance scale",
minimum=1.0,
maximum=10.0,
step=0.5,
value=7.5,
)
num_inference_steps = gr.Slider(
label="Steps",
minimum=30,
maximum=70,
step=1,
value=50,
)
repeat_time = gr.Slider(
label="Repeat time",
minimum=0,
maximum=4,
step=1,
value=4,
)
with gr.Row():
with gr.Column():
lightmask_result = gr.Image(
label="Lightmask Result", show_label=True, type="pil"
)
with gr.Column():
outpainted_result = gr.Image(
label="Outpainted Result", show_label=True, type="pil"
)
flarefree_result = gr.Image(
label="Flare-free Result", show_label=True, type="pil"
)
with gr.Row():
run_button = gr.Button("Edit!", variant="primary")
gr.on(
triggers=[run_button.click],
fn=infer,
inputs=[
input_image,
seed,
true_guidance_scale,
num_inference_steps,
repeat_time,
left_outpaint,
right_outpaint,
up_outpaint,
down_outpaint,
],
outputs=[lightmask_result, outpainted_result, flarefree_result],
)
if __name__ == "__main__":
demo.launch()