import os import csv import torch import diffusers import gradio as gr import numpy as np import onnxruntime as ort from PIL import Image from transformers import BlipProcessor, BlipForConditionalGeneration from huggingface_hub import snapshot_download from pathlib import Path # ----------------------------- # 環境與效能設定(CPU 友善) # ----------------------------- os.environ.setdefault("HF_HUB_DISABLE_TELEMETRY", "1") os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") # 限制 CPU 執行緒數(避免在 Spaces 上搶太多資源) torch.set_num_threads(max(1, min(4, os.cpu_count() or 2))) DEVICE = "cpu" # 若有 GPU 可改為 "cuda" DTYPE = torch.float32 # ----------------------------- # 載入模型(首次啟動會自動下載) # ----------------------------- def load_sd_pipe(): model_id = "Laxhar/noobai-XL-1.0" pipe = diffusers.StableDiffusionXLPipeline.from_pretrained( model_id, torch_dtype=DTYPE, use_safetensors=True ) pipe = pipe.to(DEVICE) # 省記憶體設定(對 CPU/低資源環境友好) pipe.enable_attention_slicing() pipe.enable_vae_tiling() return pipe def load_blip(): processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base") model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(DEVICE) return processor, model def load_wd_tagger(): """ 下載並載入 SmilingWolf WD Tagger (ONNX) + tags.csv 回傳: (onnxruntime.InferenceSession, tags_list, categories_list) """ repo_id = "SmilingWolf/wd-v1-4-swinv2-tagger-v2" cache_dir = snapshot_download(repo_id, local_files_only=False) cache = Path(cache_dir) # 嘗試找出 onnx 與 tags 檔 onnx_files = list(cache.glob("*.onnx")) if not onnx_files: raise FileNotFoundError("WD Tagger .onnx not found in repo.") onnx_path = str(onnx_files[0]) # 新版通常叫 selected_tags.csv,舊版叫 tags.csv csv_path = None for name in ["selected_tags.csv", "tags.csv"]: p = cache / name if p.exists(): csv_path = p break if csv_path is None: raise FileNotFoundError("WD Tagger tags csv not found.") # 讀取 tags 與類別 tags, cats = [], [] with open(csv_path, "r", encoding="utf-8") as f: reader = csv.DictReader(f) # 欄位通常至少有: "name","category" for row in reader: tags.append(row["name"]) # category: 0=general,1=character,2=copyright,3=artist,4=rating cats.append(int(row.get("category", 0))) # 建 session(CPU) sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"]) return sess, tags, cats SD_PIPE = load_sd_pipe() BLIP_PROCESSOR, BLIP_MODEL = load_blip() WD_SESS, WD_TAGS, WD_CATS = load_wd_tagger() # ----------------------------- # Scheduler 工具函式 # ----------------------------- # 提供多種取樣方法(sampler/scheduler) SCHEDULER_FACTORIES = { "DPM++ 2M": lambda cfg: diffusers.DPMSolverMultistepScheduler.from_config(cfg), "Euler": lambda cfg: diffusers.EulerDiscreteScheduler.from_config(cfg), "Euler A": lambda cfg: diffusers.EulerAncestralDiscreteScheduler.from_config(cfg), "LMS": lambda cfg: diffusers.LMSDiscreteScheduler.from_config(cfg), "Heun": lambda cfg: diffusers.HeunDiscreteScheduler.from_config(cfg), "PNDM": lambda cfg: diffusers.PNDMScheduler.from_config(cfg), "DDIM": lambda cfg: diffusers.DDIMScheduler.from_config(cfg), "UniPC": lambda cfg: diffusers.UniPCMultistepScheduler.from_config(cfg), } def apply_scheduler(pipe, name: str, use_karras: bool): name = name if name in SCHEDULER_FACTORIES else "DPM++ 2M" new_scheduler = SCHEDULER_FACTORIES[name](pipe.scheduler.config) if hasattr(new_scheduler, "config") and hasattr(new_scheduler.config, "use_karras_sigmas"): new_scheduler.config.use_karras_sigmas = bool(use_karras) pipe.scheduler = new_scheduler return pipe # ----------------------------- # 功能:Text → Image(Stable Diffusion) # ----------------------------- def txt2img( prompt: str, neg_prompt: str, steps: int, guidance: float, seed: int | None, scheduler_name: str, use_karras: bool, width: int, height: int ): if not prompt or prompt.strip() == "": return None, "請輸入 prompt。" # 每次呼叫前套用最新 scheduler 設定 apply_scheduler(SD_PIPE, scheduler_name, use_karras) generator = None if seed is not None and seed >= 0: generator = torch.Generator(device=DEVICE).manual_seed(int(seed)) with torch.inference_mode(): image = SD_PIPE( prompt=prompt, negative_prompt=neg_prompt or None, num_inference_steps=int(steps), guidance_scale=float(guidance), generator=generator, height=int(height), width=int(width), ).images[0] return image, "生成完成" # ----------------------------- # 功能:Image → Text(BLIP Caption) # ----------------------------- def caption_image(image: Image.Image, max_len: int = 50): if image is None: return "請先提供圖片" image = image.convert("RGB") inputs = BLIP_PROCESSOR(image, return_tensors="pt").to(DEVICE) with torch.inference_mode(): out = BLIP_MODEL.generate(**inputs, max_length=int(max_len)) caption = BLIP_PROCESSOR.decode(out[0], skip_special_tokens=True) return caption # ----------------------------- # 功能:Image → Tags(WD Tagger) # ----------------------------- def _wd_preprocess(img: Image.Image, target=448, layout="NHWC"): """將圖片等比縮放+邊緣填充到正方形,回傳 layout 指定形狀的 float32 陣列。""" img = img.convert("RGB") w, h = img.size scale = target / max(w, h) nw, nh = int(w * scale), int(h * scale) # resize + pad 到 target x target img = img.resize((nw, nh), Image.Resampling.LANCZOS) canvas = Image.new("RGB", (target, target), (255, 255, 255)) left = (target - nw) // 2 top = (target - nh) // 2 canvas.paste(img, (left, top)) arr = np.asarray(canvas, dtype=np.float32) / 255.0 # HWC, [0,1] if layout == "NCHW": arr = arr.transpose(2, 0, 1) # HWC -> CHW arr = np.expand_dims(arr, 0) # -> NCHW else: # NHWC arr = np.expand_dims(arr, 0) # -> NHWC return arr def wd_tagger_predict( image: Image.Image, general_thresh: float = 0.35, char_thresh: float = 0.35, ip_thresh: float = 0.35, artist_thresh: float = 0.5, rating_thresh: float = 0.5, top_k: int = 50, exclude: str = "", output_mode: str = "Comma (Danbooru style)" ): if image is None: return "請先上傳圖片。" # 🔧 動態偵測 ONNX 輸入佈局 inp = WD_SESS.get_inputs()[0] shp = inp.shape layout = "NHWC" if (len(shp) == 4 and (shp[-1] == 3 or str(shp[-1]) == '3')) else "NCHW" x = _wd_preprocess(image, target=448, layout=layout).astype(np.float32) in_name = inp.name out = WD_SESS.run(None, {in_name: x})[0][0] # (num_tags,) probs = out.tolist() # === 修正: 動態類別 === # 已知對應(若 CSV 裡有新類別,會 fallback 到 cat_) KNOWN_CAT_NAMES = {0: "general", 1: "character", 2: "copyright", 3: "artist", 4: "rating"} # 實際出現的類別集合 unique_cats = sorted(set(WD_CATS)) # 類別名稱表(未知類別用 "cat_") CAT_NAMES = {c: KNOWN_CAT_NAMES.get(c, f"cat_{c}") for c in unique_cats} # 門檻表: 既有類別用對應門檻,未知類別預設用 general 門檻 THRESH = {c: general_thresh for c in unique_cats} THRESH.update({ 0: general_thresh, 1: char_thresh, 2: ip_thresh, 3: artist_thresh, 4: rating_thresh }) # 依實際類別建立容器 grouped = {c: [] for c in unique_cats} exclude_set = set([t.strip() for t in exclude.split(",") if t.strip()]) # 蒐集符合門檻的標籤 for idx, (tag, cat) in enumerate(zip(WD_TAGS, WD_CATS)): p = probs[idx] thr = THRESH.get(cat, general_thresh) # 未知類別走 general 門檻 if p >= thr and tag not in exclude_set: grouped[cat].append((tag, p)) # 各類別排序與裁切 for c in grouped: grouped[c] = sorted(grouped[c], key=lambda x: x[1], reverse=True)[:top_k] # 輸出 if output_mode == "Comma (Danbooru style)": flat = [t for c in unique_cats for (t, _) in grouped[c]] return ", ".join(flat) if flat else "(無符合門檻的標籤)" else: # 分類分段輸出(未知類別一樣顯示) parts = [] for c in unique_cats: items = grouped[c] title = CAT_NAMES.get(c, f"cat_{c}") if not items: parts.append(f"### {title}\n(無)\n") else: lines = [f"{tag} ({p:.3f})" for tag, p in items] parts.append(f"### {title}\n" + "\n".join(lines) + "\n") return "\n".join(parts) # ----------------------------- # Gradio 介面 # ----------------------------- with gr.Blocks(title="NoobAI-XL + BLIP", theme=gr.themes.Soft()) as demo: gr.Markdown( """ # SDXL (Laxhar/noobai-XL-1.0) + BLIP + WD Tagger 動漫風格的 Stable Diffusion XL 生圖,BLIP 圖片描述,與 WD 標籤自動生成。\n 提示:在 CPU 上建議 Steps 20 左右、Guidance Scale 7 左右、解析度 512~768。 """ ) with gr.Tabs(): # Tab 1: Text -> Image with gr.TabItem("Text → Image"): with gr.Row(): with gr.Column(scale=1): prompt = gr.Textbox(label="Prompt", lines=3, placeholder="正向提示詞 Prompt") neg_prompt = gr.Textbox(label="Negative Prompt", lines=3, placeholder="反向提示詞 Negative Prompt") steps = gr.Slider(1, 100, value=20, step=1, label="Steps") guidance = gr.Slider(1.0, 20.0, value=7, step=0.5, label="Guidance Scale") width = gr.Slider(256, 1536, value=768, step=8, label="寬度 Width") height = gr.Slider(256, 1536, value=768, step=8, label="高度 Height") seed = gr.Number(label="Seed(-1 表示隨機)", value=-1, precision=0) scheduler_name = gr.Dropdown( choices=list(SCHEDULER_FACTORIES.keys()), value="DPM++ 2M", label="取樣方法 (Sampler / Scheduler)" ) use_karras = gr.Checkbox(value=True, label="使用 Karras Sigmas") run_btn = gr.Button("圖片生成", variant="primary") with gr.Column(scale=1): out_img = gr.Image(label="Generated Image", format="png") status = gr.Markdown() def _run_txt2img(prompt, neg_prompt, steps, guidance, seed, scheduler_name, use_karras, width, height): prog = None try: prog = gr.Progress(track_tqdm=True) prog(0, desc="圖片生成中...") except Exception: prog = None # 某些舊版 gradio 沒有這個介面 image, msg = txt2img( prompt, neg_prompt, steps, guidance, int(seed) if seed is not None else None, scheduler_name, bool(use_karras), int(width), int(height) ) try: if prog is not None: prog(1, desc="完成") except Exception: pass return image, msg run_btn.click( _run_txt2img, [prompt, neg_prompt, steps, guidance, seed, scheduler_name, use_karras, width, height], [out_img, status] ) # Tab 2: Image -> Caption with gr.TabItem("Image → Caption"): with gr.Row(): with gr.Column(scale=1): in_img = gr.Image(label="Upload Image", type="pil") max_len = gr.Slider(10, 100, value=50, step=5, label="Caption 長度上限") cap_btn = gr.Button("產生描述") with gr.Column(scale=1): caption_out = gr.Textbox(label="BLIP Caption") cap_btn.click(caption_image, [in_img, max_len], [caption_out]) # Tab 3: Image -> WD Tags with gr.TabItem("WD Tagger"): gr.Markdown("上傳圖片,產生 Danbooru/WD 風格的標籤。") with gr.Row(): with gr.Column(scale=1): tag_img = gr.Image(label="Upload Image", type="pil") general_thresh = gr.Slider(0.0, 1.0, value=0.35, step=0.01, label="General 門檻") char_thresh = gr.Slider(0.0, 1.0, value=0.35, step=0.01, label="Character 門檻") ip_thresh = gr.Slider(0.0, 1.0, value=0.35, step=0.01, label="IP/版權 門檻") artist_thresh = gr.Slider(0.0, 1.0, value=0.50, step=0.01, label="Artist 門檻") rating_thresh = gr.Slider(0.0, 1.0, value=0.50, step=0.01, label="Rating 門檻") top_k = gr.Slider(1, 200, value=50, step=1, label="每類 Top-K") exclude_tags = gr.Textbox(label="排除標籤(逗號分隔)", placeholder="nsfw, lowres, ...") output_mode = gr.Radio(choices=["Comma (Danbooru style)", "分類分段 (含信心分數)"], value="Comma (Danbooru style)", label="輸出格式") tag_btn = gr.Button("產生標籤", variant="primary") with gr.Column(scale=1): tag_output = gr.Textbox(label="Tags Output", lines=18) def _run_wd_tagger(img, gt, ct, it, at, rt, k, exc, mode): return wd_tagger_predict( img, general_thresh=float(gt), char_thresh=float(ct), ip_thresh=float(it), artist_thresh=float(at), rating_thresh=float(rt), top_k=int(k), exclude=exc or "", output_mode=mode ) tag_btn.click( _run_wd_tagger, [tag_img, general_thresh, char_thresh, ip_thresh, artist_thresh, rating_thresh, top_k, exclude_tags, output_mode], [tag_output] ) gr.Examples( examples=[ ["examples/miku.png", "masterpiece, best quality, amazing quality, highres, absurdres, 1girl, hatsune miku, white pupils, power elements, microphone, vibrant blue color palette, abstract,abstract background, dreamlike atmosphere, delicate linework, wind-swept hair, energy", "worst quality, old, early, low quality, lowres, signature, username, logo, bad hands, mutated hands, mammal, anthro, furry, ambiguous form, feral, semi-anthro, nsfw", 28, 7.0, 24, "DPM++ 2M", True, 768, 768], ["examples/gura.png", "masterpiece, best quality, amazing quality, highres, absurdres, 1girl, virtual youtuber, gawr gura, solo, blue hair, grey hair, medium hair, long hair, multicolored hair, streaked hair, bangs, blunt bangs, side ponytail, blue eyes, sharp teeth, teeth, fang, cat ears, hair ornament, blue nails, shark tail, nail polish, tail, fish tail, teeth, shark girl, sleeveless, dress, shirt, :d, blush, smile, paw pose, open mouth, full body, simple_background, looking_at_viewer, upper_body, front view, bubble, flat color, no lineart", "worst quality, old, early, low quality, lowres, signature, username, logo, bad hands, mutated hands, mammal, anthro, furry, ambiguous form, feral, semi-anthro, nsfw", 20, 7.0, 42, "DPM++ 2M", True, 768, 768], ["examples/cat_girl.png", "masterpiece, best quality, amazing quality, highres, absurdres, 1girl, clean lineart, usnr, :d, cat_pose, blush, kawaii, white hair, long hair, hair between eyes, bangs, pink eyes, fang, pink_nails, open mouth, white fake cat ears, simple_background, hugging a cat, light pink dress, hair ornament, long sleeves, puffy long sleeves, pink bow, puffy sleeves, black choker, upper body, looking_at_viewer, front view", "worst quality, old, early, low quality, lowres, signature, username, logo, bad hands, mutated hands, mammal, anthro, furry, ambiguous form, feral, semi-anthro, nsfw", 20, 8.0, 1234, "DPM++ 2M", True, 768, 768], ], inputs=[in_img, prompt, neg_prompt, steps, guidance, seed, scheduler_name, use_karras, width, height], label="Prompt 範例(可直接點選)", ) if __name__ == "__main__": demo.launch()