"""LightsOut Gradio demo: a light-source regressor predicts a light-source mask,
a ControlNet-guided Stable Diffusion inpainting pipeline outpaints the image using
that mask as conditioning, and a Uformer-based SIFR model removes the lens flare."""

import base64
import json
import os
import random

import gradio as gr
import numpy as np
import spaces
import torch
import torchvision
from diffusers import ControlNetModel, DPMSolverMultistepScheduler
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration

from src.pipelines.pipeline_stable_diffusion_outpaint import OutpaintPipeline
from src.pipelines.pipeline_controlnet_outpaint import ControlNetOutpaintPipeline
from src.schedulers.scheduling_pndm import CustomScheduler
from src.models.unet import U_Net
from src.models.light_source_regressor import LightSourceRegressor
from utils.dataset import HFCustomImageLoader
from utils.utils import (
    blend_with_alpha,
    load_mfdnet_checkpoint,
    predict_flare_from_6_channel,
    predict_flare_from_3_channel,
    blend_light_source,
)
from SIFR_models.flare7kpp.model import Uformer

intro = """

<h1>LightsOut</h1>

<p>[Project page]</p>
<p><b>NOTICE:</b> This demo is limited to CPU-only inference. For a better experience, please run the code locally with a GPU.</p>
""" css = """ #col-container { margin: 0 auto; max-width: 1024px; } #edit_text{ margin-top: -62px !important } """ def encode_image(pil_image): import io buffered = io.BytesIO() pil_image.save(buffered, format="PNG") return base64.b64encode(buffered.getvalue()).decode("utf-8") # --- UI Constants and Helpers --- MAX_SEED = np.iinfo(np.int32).max ## --- Model Loading --- ## device = "cuda" if torch.cuda.is_available() else "cpu" dtype = torch.bfloat16 # controlnet controlnet = ControlNetModel.from_pretrained( "RayTsai-030/LightsOut-controlnet", torch_dtype=dtype ) # outpainter pipe = ControlNetOutpaintPipeline.from_pretrained( "stabilityai/stable-diffusion-2-inpainting", controlnet=controlnet, torch_dtype=dtype, ).to(device) pipe.scheduler = CustomScheduler.from_config(pipe.scheduler.config) pipe.unet.load_attn_procs("./weights/light_outpaint_lora", use_safetensors=True) # blip processor = Blip2Processor.from_pretrained( "Salesforce/blip2-opt-2.7b", revision="51572668da0eb669e01a189dc22abe6088589a24" ) blip2 = Blip2ForConditionalGeneration.from_pretrained( "Salesforce/blip2-opt-2.7b", torch_dtype=dtype, revision="51572668da0eb669e01a189dc22abe6088589a24", ) blip2 = blip2.to(device) # light regressor lsr_module = LightSourceRegressor() ckpt = torch.load( "./weights/light_regress/model.pth", map_location="cpu" if device == "cpu" else None ) lsr_module.load_state_dict(ckpt["model"]) lsr_module.to(device) lsr_module.eval() # SIFR model sifr_model = Uformer(img_size=512, img_ch=3, output_ch=6).to(device) sifr_model.load_state_dict( torch.load( "./weights/net_g_last.pth", map_location="cpu" if device == "cpu" else None ) ) # --- Main Inference Function (with hardcoded negative prompt) --- @spaces.GPU(duration=120) def infer( image, seed=42, cfg=7.5, num_inference_steps=50, repeat_time=4, left_outpaint=64, right_outpaint=64, up_outpaint=64, down_outpaint=64, progress=gr.Progress(track_tqdm=True), ): """ Generates an image """ # dataset dataset = HFCustomImageLoader( image, left_outpaint, right_outpaint, up_outpaint, down_outpaint ) data = dataset[0] # generator generator = torch.Generator(device=device).manual_seed(seed) # transformation transform = torchvision.transforms.Compose( [ torchvision.transforms.ToTensor(), torchvision.transforms.Normalize(mean=[0.5], std=[0.5]), ] ) sifr_transform = torchvision.transforms.Compose( [ torchvision.transforms.ToTensor(), torchvision.transforms.Resize((512, 512)), ] ) threshold = 0.5 with torch.no_grad(): input_img = data["input_img"] input_img = transform(input_img).unsqueeze(0).to(device) pred_mask = lsr_module.forward_render(input_img) pred_mask = (pred_mask > threshold).float() if pred_mask.device != "cpu": pred_mask = pred_mask.cpu() pred_mask = pred_mask.numpy() data["control_img"] = Image.fromarray((pred_mask[0, 0] * 255).astype(np.uint8)) print("Finish light source detection...") # prepare text prompt inputs = processor(data["blip_img"], return_tensors="pt").to( device=device, dtype=dtype ) generate_id = blip2.generate(**inputs, max_new_tokens=20) generated_text = processor.batch_decode(generate_id, skip_special_tokens=True)[ 0 ].strip() generated_text += ( ", dynamic lighting, intense light source, prominent lens flare, best quality, high resolution, masterpiece, intricate details" # ", full light sources with lens flare, best quality, high resolution" ) print(f"Generated text prompt: {generated_text}") # Blur mask # data["mask_img"] = data["mask_img"].filter(ImageFilter.GaussianBlur(15)) # denoise outpaint_result = pipe( prompt=generated_text, 
negative_prompt="NSFW, (word:1.5), watermark, blurry, missing body, amputation, mutilation", image=data["input_img"], mask_image=data["mask_img"], control_image=data["control_img"], num_inference_steps=num_inference_steps, guidance_scale=cfg, generator=generator, repeat_time=repeat_time, ).images[0] # save result outpaint_result = np.array(outpaint_result) input_img = np.array(data["input_img"]) box = data["box"] input_img2 = outpaint_result.copy() input_img2[box[2] : box[3] + 1, box[0] : box[1] + 1] = input_img[ box[2] : box[3] + 1, box[0] : box[1] + 1 ] outpaint_result = blend_with_alpha(outpaint_result, input_img2, box, blur_size=31) outpaint_result = Image.fromarray(outpaint_result.astype(np.uint8)) print("Finish outpainting...") # flare removal img = sifr_transform(outpaint_result).unsqueeze(0).to(device) with torch.no_grad(): output_img = sifr_model(img) gamma = torch.Tensor([2.2]) # flare7k++ deflare_result, _, _ = predict_flare_from_6_channel(output_img, gamma, device) # # mfdnet # flare_mask = torch.zeros_like(img) # deflare_img, _ = predict_flare_from_3_channel( # output_img, flare_mask, output_img, img, img, gamma # ) # deflare_img = blend_light_source(img, deflare_img, 0.999) if deflare_result.device != "cpu": deflare_result = deflare_result.cpu() deflare_result = deflare_result.squeeze(0).permute(1, 2, 0).numpy() deflare_result = np.clip(deflare_result, 0.0, 1.0) deflare_result = (deflare_result * 255).astype(np.uint8) deflare_result = deflare_result[box[2] : box[3] + 1, box[0] : box[1] + 1, :] deflare_result = Image.fromarray(deflare_result).resize((512, 512), Image.LANCZOS) print("Finish flare removal...") return data["control_img"], outpaint_result, deflare_result # --- Examples and UI Layout --- examples = [] with gr.Blocks(css=css) as demo: with gr.Column(elem_id="col-container"): gr.HTML(intro) # gr.Markdown( # "[Learn more](https://github.com/QwenLM/Qwen-Image) about the Qwen-Image series. Try on [Qwen Chat](https://chat.qwen.ai/), or [download model](https://huggingface.co/Qwen/Qwen-Image-Edit) to run locally with ComfyUI or diffusers." 
    # flare removal
    img = sifr_transform(outpaint_result).unsqueeze(0).to(device)
    with torch.no_grad():
        output_img = sifr_model(img)
    gamma = torch.Tensor([2.2])

    # flare7k++
    deflare_result, _, _ = predict_flare_from_6_channel(output_img, gamma, device)

    # # mfdnet
    # flare_mask = torch.zeros_like(img)
    # deflare_img, _ = predict_flare_from_3_channel(
    #     output_img, flare_mask, output_img, img, img, gamma
    # )
    # deflare_img = blend_light_source(img, deflare_img, 0.999)

    deflare_result = deflare_result.cpu().squeeze(0).permute(1, 2, 0).numpy()
    deflare_result = np.clip(deflare_result, 0.0, 1.0)
    deflare_result = (deflare_result * 255).astype(np.uint8)
    deflare_result = deflare_result[box[2] : box[3] + 1, box[0] : box[1] + 1, :]
    deflare_result = Image.fromarray(deflare_result).resize((512, 512), Image.LANCZOS)
    print("Finish flare removal...")

    return data["control_img"], outpaint_result, deflare_result


# --- Examples and UI Layout ---
examples = []

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.HTML(intro)

        with gr.Row():
            input_image = gr.Image(label="Input Image", show_label=False, type="pil")

        # with gr.Row():
        #     with gr.Column():
        with gr.Accordion("Advanced Settings", open=False):
            # Negative prompt UI element is removed here (it is hardcoded in `infer`)
            seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)

            with gr.Column():
                left_outpaint = gr.Slider(
                    label="Left outpaint (px)", minimum=32, maximum=128, step=32, value=64
                )
                right_outpaint = gr.Slider(
                    label="Right outpaint (px)", minimum=32, maximum=128, step=32, value=64
                )
                up_outpaint = gr.Slider(
                    label="Up outpaint (px)", minimum=32, maximum=128, step=32, value=64
                )
                down_outpaint = gr.Slider(
                    label="Down outpaint (px)", minimum=32, maximum=128, step=32, value=64
                )

            # randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                true_guidance_scale = gr.Slider(
                    label="Guidance scale", minimum=1.0, maximum=10.0, step=0.5, value=7.5
                )
                num_inference_steps = gr.Slider(
                    label="Steps", minimum=30, maximum=70, step=1, value=50
                )
                repeat_time = gr.Slider(
                    label="Repeat time", minimum=0, maximum=4, step=1, value=4
                )

        with gr.Row():
            with gr.Column():
                lightmask_result = gr.Image(
                    label="Lightmask Result", show_label=True, type="pil"
                )
            with gr.Column():
                outpainted_result = gr.Image(
                    label="Outpainted Result", show_label=True, type="pil"
                )
                flarefree_result = gr.Image(
                    label="Flare-free Result", show_label=True, type="pil"
                )

        with gr.Row():
            run_button = gr.Button("Edit!", variant="primary")

    gr.on(
        triggers=[run_button.click],
        fn=infer,
        inputs=[
            input_image,
            seed,
            true_guidance_scale,
            num_inference_steps,
            repeat_time,
            left_outpaint,
            right_outpaint,
            up_outpaint,
            down_outpaint,
        ],
        outputs=[lightmask_result, outpainted_result, flarefree_result],
    )

if __name__ == "__main__":
    demo.launch()
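
# Illustrative sketch (not part of the app): `infer` can also be called
# programmatically once the models above are loaded. The image path below is
# hypothetical.
#
#   img = Image.open("my_flare_photo.png").convert("RGB")
#   light_mask, outpainted, flare_free = infer(
#       img, seed=42, cfg=7.5, num_inference_steps=50, repeat_time=4,
#       left_outpaint=64, right_outpaint=64, up_outpaint=64, down_outpaint=64,
#   )
#   flare_free.save("flare_free.png")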