import os
import gradio as gr
import numpy as np
import spaces
import torch
import random
import io
import base64
import json
from PIL import Image
from gradio_client import Client
from huggingface_hub import InferenceClient
from deep_translator import GoogleTranslator
from transformers import pipeline
from diffusers import Flux2Pipeline, Flux2Transformer2DModel
from datetime import date
# ==========================================
# 1. تنظیمات و پیکربندی سیستم (Configuration)
# ==========================================
# رنگها و تنظیمات ظاهری
USAGE_LIMIT = 5
DATA_FILE = "usage_data.json"
PREMIUM_PAGE_ID = '1149636'
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
device = "cuda" if torch.cuda.is_available() else "cpu"
# بارگذاری مدل تشخیص محتوای نامناسب (Safety Checker)
print("Loading Safety Checker...")
safety_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=-1)
# کلاینتهای هوش مصنوعی
hf_client = InferenceClient(api_key=os.environ.get("HF_TOKEN"))
VLM_MODEL = "baidu/ERNIE-4.5-VL-424B-A47B-Base-PT"
# پرامپتهای سیستمی برای بهبود متن
SYSTEM_PROMPT_TEXT_ONLY = """You are an expert prompt engineer for FLUX.2. Rewrite user prompts to be more descriptive while strictly preserving their core subject and intent. Add concrete visual specifics."""
SYSTEM_PROMPT_WITH_IMAGES = """You are FLUX.2 image-editing expert. Convert editing requests into one concise instruction (50-80 words)."""
# لیست کلمات ممنوعه (Strict Mode)
BANNED_WORDS = [
"nsfw", "nude", "naked", "sex", "porn", "erotic", "xxx", "18+", "adult",
"explicit", "uncensored", "sexual", "lewd", "sensual", "lust", "horny",
"breast", "breasts", "nipple", "nipples", "vagina", "pussy", "cunt",
"penis", "dick", "cock", "genital", "genitals", "groin", "pubic",
"ass", "butt", "buttocks", "anus", "anal", "rectum",
"intercourse", "masturbation", "orgasm", "blowjob", "bj", "cum", "sperm",
"ejaculation", "penetration", "fucking", "sucking", "licking",
"lingerie", "bikini", "swimwear", "underwear", "panties", "bra", "thong",
"topless", "bottomless", "undressed", "unclothed", "skimpy", "transparent",
"fetish", "bdsm", "bondage", "latex", "hentai", "ecchi", "ahegao",
"gore", "bloody", "blood", "kill", "murder", "dead", "torture", "abuse"
]
# ==========================================
# 2. بارگذاری مدل FLUX.2
# ==========================================
print("Loading FLUX.2 Pipeline...")
repo_id = "black-forest-labs/FLUX.2-dev"
dit = Flux2Transformer2DModel.from_pretrained(
repo_id,
subfolder="transformer",
torch_dtype=torch.bfloat16
)
pipe = Flux2Pipeline.from_pretrained(
repo_id,
text_encoder=None,
transformer=dit,
torch_dtype=torch.bfloat16
)
pipe.to(device)
# بهینهسازی ZeroGPU
spaces.aoti_blocks_load(pipe.transformer, "zerogpu-aoti/FLUX.2", variant="fa3")
# ==========================================
# 3. توابع کمکی (Helpers)
# ==========================================
def load_usage_data():
if os.path.exists(DATA_FILE):
try:
with open(DATA_FILE, 'r') as f:
return json.load(f)
except:
return {}
return {}
def save_usage_data(data):
try:
with open(DATA_FILE, 'w') as f:
json.dump(data, f)
except Exception as e:
print(f"Error saving data: {e}")
usage_data_cache = load_usage_data()
def is_image_nsfw(image):
if image is None: return False
try:
# اگر ورودی لیست گالری باشد، اولین تصویر را چک کن
img_to_check = image
if isinstance(image, list):
# هندل کردن فرمت گالری گرادیو
if len(image) > 0:
img_to_check = image[0][0] if isinstance(image[0], tuple) else image[0]
else:
return False
results = safety_classifier(img_to_check)
for result in results:
if result['label'] == 'nsfw' and result['score'] > 0.75:
return True
return False
except Exception as e:
print(f"Safety check error: {e}")
return False
def check_text_safety(text):
if not text: return True
text_lower = text.lower()
padded_text = f" {text_lower} "
for char in [".", ",", "!", "?", "-", "_", "(", ")", "[", "]", "{", "}"]:
padded_text = padded_text.replace(char, " ")
for word in BANNED_WORDS:
if f" {word} " in padded_text:
return False
return True
def translate_prompt(text):
if not text: return ""
try:
translated = GoogleTranslator(source='auto', target='en').translate(text)
return translated
except Exception as e:
print(f"Translation Error: {e}")
return text
def get_error_html(message):
return f"""
⛔{message}
"""
def get_success_html(message):
return f"""✅{message}
"""
def get_quota_exceeded_html():
return """💎
اعتبار رایگان امروز تمام شد
شما از ۵ تصویر رایگان امروز استفاده کردهاید.
برای ساخت تصاویر نامحدود و حرفهای، لطفا نسخه خود را ارتقا دهید.
"""
def get_user_record(fingerprint):
global usage_data_cache
if not fingerprint: return None
usage_data_cache = load_usage_data()
today_str = date.today().isoformat()
user_record = usage_data_cache.get(fingerprint)
if not user_record or user_record.get("last_reset") != today_str:
return {"count": 0, "last_reset": today_str}
return user_record
def consume_quota(fingerprint):
global usage_data_cache
today_str = date.today().isoformat()
usage_data_cache = load_usage_data()
user_record = usage_data_cache.get(fingerprint)
if not user_record or user_record.get("last_reset") != today_str:
user_record = {"count": 0, "last_reset": today_str}
user_record["count"] += 1
usage_data_cache[fingerprint] = user_record
save_usage_data(usage_data_cache)
return user_record["count"]
def check_initial_quota(fingerprint, subscription_status):
if not fingerprint: return gr.update(visible=True), gr.update(visible=False), None
if subscription_status == 'paid': return gr.update(visible=True), gr.update(visible=False), None
user_record = get_user_record(fingerprint)
current_usage = user_record["count"] if user_record else 0
if current_usage >= USAGE_LIMIT:
return gr.update(visible=False), gr.update(visible=True), get_quota_exceeded_html()
else:
return gr.update(visible=True), gr.update(visible=False), None
def image_to_data_uri(img):
buffered = io.BytesIO()
img.save(buffered, format="PNG")
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
return f"data:image/png;base64,{img_str}"
def remote_text_encoder(prompts):
client = Client("multimodalart/mistral-text-encoder")
result = client.predict(prompt=prompts, api_name="/encode_text")
prompt_embeds = torch.load(result[0])
return prompt_embeds
def upsample_prompt_logic(prompt, image_list):
try:
if image_list and len(image_list) > 0:
system_content = SYSTEM_PROMPT_WITH_IMAGES
user_content = [{"type": "text", "text": prompt}]
for img in image_list:
data_uri = image_to_data_uri(img)
user_content.append({"type": "image_url", "image_url": {"url": data_uri}})
messages = [{"role": "system", "content": system_content}, {"role": "user", "content": user_content}]
else:
system_content = SYSTEM_PROMPT_TEXT_ONLY
messages = [{"role": "system", "content": system_content}, {"role": "user", "content": prompt}]
completion = hf_client.chat.completions.create(model=VLM_MODEL, messages=messages, max_tokens=1024)
return completion.choices[0].message.content
except Exception as e:
print(f"Upsampling failed: {e}")
return prompt
def get_duration(prompt_embeds, image_list, width, height, num_inference_steps, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
num_images = 0 if image_list is None else len(image_list)
step_duration = 1 + 0.8 * num_images
return max(65, num_inference_steps * step_duration + 10)
# ==========================================
# 4. تابع اصلی GPU (Inference)
# ==========================================
@spaces.GPU(duration=get_duration)
def generate_image(prompt_embeds, image_list, width, height, num_inference_steps, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
prompt_embeds = prompt_embeds.to(device)
generator = torch.Generator(device=device).manual_seed(seed)
pipe_kwargs = {
"prompt_embeds": prompt_embeds,
"image": image_list,
"num_inference_steps": num_inference_steps,
"guidance_scale": guidance_scale,
"generator": generator,
"width": width,
"height": height,
}
if progress: progress(0, desc="Starting generation...")
image = pipe(**pipe_kwargs).images[0]
return image
def infer(
prompt, input_images, seed, randomize_seed, width, height,
num_inference_steps, guidance_scale, prompt_upsampling,
fingerprint, subscription_status,
progress=gr.Progress(track_tqdm=True)
):
# 1. بررسی اعتبار قبل از شروع
if subscription_status != 'paid':
user_record = get_user_record(fingerprint)
if user_record and user_record["count"] >= USAGE_LIMIT:
return None, seed, get_quota_exceeded_html(), gr.update(visible=False), gr.update(visible=True)
# 2. بررسیهای ایمنی (Safety Checks)
# الف) بررسی تصویر ورودی
image_list = None
if input_images is not None and len(input_images) > 0:
image_list = [item[0] for item in input_images]
if is_image_nsfw(image_list):
return None, seed, get_error_html("تصویر ورودی دارای محتوای نامناسب است."), gr.update(visible=True), gr.update(visible=False)
# ب) ترجمه و بررسی متن
progress(0.1, desc="Translating...")
english_prompt = translate_prompt(prompt)
if not check_text_safety(english_prompt):
return None, seed, get_error_html("متن درخواست شامل کلمات غیرمجاز است."), gr.update(visible=True), gr.update(visible=False)
# 3. کسر اعتبار (اگر کاربر رایگان است)
if subscription_status != 'paid':
consume_quota(fingerprint)
# 4. آمادهسازی تنظیمات
if randomize_seed:
seed = random.randint(0, MAX_SEED)
try:
# Upsampling Prompt (Optional)
final_prompt = english_prompt
if prompt_upsampling:
progress(0.2, desc="Enhancing prompt...")
final_prompt = upsample_prompt_logic(english_prompt, image_list)
# Text Encoding (CPU/Network)
progress(0.3, desc="Encoding...")
prompt_embeds = remote_text_encoder(final_prompt)
# Generation (GPU)
progress(0.4, desc="Generating...")
result_image = generate_image(
prompt_embeds, image_list, width, height,
num_inference_steps, guidance_scale, seed, progress
)
# 5. بررسی تصویر خروجی
if is_image_nsfw(result_image):
return None, seed, get_error_html("تصویر تولید شده حاوی محتوای نامناسب بود."), gr.update(visible=True), gr.update(visible=False)
# 6. محاسبه اعتبار باقیمانده
user_record = get_user_record(fingerprint)
remaining = USAGE_LIMIT - user_record["count"] if user_record else 0
success_msg = f"تصویر با موفقیت ساخته شد."
if subscription_status != 'paid':
success_msg += f" (اعتبار باقیمانده امروز: {remaining})"
btn_run_update = gr.update(visible=True)
btn_upg_update = gr.update(visible=False)
if subscription_status != 'paid' and remaining <= 0:
btn_run_update = gr.update(visible=False)
btn_upg_update = gr.update(visible=True)
return result_image, seed, get_success_html(success_msg), btn_run_update, btn_upg_update
except Exception as e:
error_str = str(e)
if "quota" in error_str.lower() or "exceeded" in error_str.lower():
raise e # Raise to be caught by JS
return None, seed, get_error_html(f"خطا در پردازش: {error_str}"), gr.update(visible=True), gr.update(visible=False)
def update_dimensions_from_image(image_list):
if image_list is None or len(image_list) == 0:
return 1024, 1024
img = image_list[0][0]
img_width, img_height = img.size
aspect_ratio = img_width / img_height
if aspect_ratio >= 1:
new_width = 1024
new_height = int(1024 / aspect_ratio)
else:
new_height = 1024
new_width = int(1024 * aspect_ratio)
new_width = round(new_width / 8) * 8
new_height = round(new_height / 8) * 8
return max(256, min(1024, new_width)), max(256, min(1024, new_height))
# ==========================================
# 5. جاوااسکریپت و CSS (UI/UX)
# ==========================================
js_download_func = """
async (image) => {
if (!image) { alert("لطفاً ابتدا تصویر را تولید کنید."); return; }
let fileUrl = image.url;
if (fileUrl && !fileUrl.startsWith('http')) { fileUrl = window.location.origin + fileUrl; }
else if (!fileUrl && image.path) { fileUrl = window.location.origin + "/file=" + image.path; }
window.parent.postMessage({ type: 'DOWNLOAD_REQUEST', url: fileUrl }, '*');
}
"""
js_upgrade_func = """() => { window.parent.postMessage({ type: 'NAVIGATE_TO_PREMIUM' }, '*'); }"""
js_global_content = """
"""
css_code = """
"""
# ==========================================
# 6. ساخت رابط کاربری (Gradio Blocks)
# ==========================================
# ******** این خط اصلاح شده است ********
with gr.Blocks() as demo:
gr.HTML(js_global_content + css_code)
fingerprint_box = gr.Textbox(elem_id="fingerprint_storage", visible=True)
status_box_input = gr.Textbox(elem_id="status_storage", visible=True)
with gr.Column(elem_id="col-container"):
gr.Markdown("# **ساخت تصویر با FLUX.2 (پیشرفته)**", elem_id="main-title")
gr.Markdown("با استفاده از مدل قدرتمند FLUX.2 متن فارسی خود را به تصاویر شگفتانگیز تبدیل کنید.", elem_id="main-description")
gr.HTML('
')
with gr.Row():
with gr.Column():
with gr.Row():
prompt = gr.Text(
label="توصیف تصویر (به فارسی)",
show_label=True,
max_lines=3,
placeholder="یک منظره زیبا از...",
rtl=True
)
with gr.Accordion("بارگذاری تصویر (اختیاری برای ویرایش/ایده)", open=False):
input_images = gr.Gallery(
label="تصاویر ورودی",
type="pil",
columns=3,
rows=1,
height=200
)
status_box = gr.HTML(label="وضعیت")
run_button = gr.Button("✨ ساخت تصویر", variant="primary", elem_classes="primary-btn", elem_id="run-btn", visible=True)
upgrade_button = gr.Button("💎 خرید نسخه نامحدود", variant="primary", elem_classes="upgrade-btn", elem_id="upgrade-btn", visible=False)
with gr.Accordion("تنظیمات پیشرفته", open=False):
prompt_upsampling = gr.Checkbox(label="بهبود خودکار پرامپت (هوشمند)", value=True)
seed = gr.Slider(label="دانه تصادفی (Seed)", minimum=0, maximum=MAX_SEED, step=1, value=0)
randomize_seed = gr.Checkbox(label="Seed تصادفی", value=True)
with gr.Row():
width = gr.Slider(label="عرض (Width)", minimum=256, maximum=MAX_IMAGE_SIZE, step=8, value=1024)
height = gr.Slider(label="ارتفاع (Height)", minimum=256, maximum=MAX_IMAGE_SIZE, step=8, value=1024)
with gr.Row():
num_inference_steps = gr.Slider(label="تعداد مراحل (Steps)", minimum=1, maximum=50, step=1, value=28)
guidance_scale = gr.Slider(label="میزان وفاداری (Guidance)", minimum=1.0, maximum=10.0, step=0.1, value=3.5)
with gr.Column():
result = gr.Image(label="تصویر نهایی", show_label=True, interactive=False)
download_button = gr.Button("📥 دانلود تصویر", variant="secondary", elem_id="download-btn")
# اتصال رویدادها
# 1. آپدیت ابعاد بر اساس تصویر آپلودی
input_images.upload(
fn=update_dimensions_from_image,
inputs=[input_images],
outputs=[width, height]
)
# 2. بررسی اولیه اعتبار
fingerprint_box.change(
fn=check_initial_quota,
inputs=[fingerprint_box, status_box_input],
outputs=[run_button, upgrade_button, status_box]
)
# 3. اجرای مدل
run_button.click(
fn=infer,
inputs=[
prompt, input_images, seed, randomize_seed, width, height,
num_inference_steps, guidance_scale, prompt_upsampling,
fingerprint_box, status_box_input
],
outputs=[result, seed, status_box, run_button, upgrade_button]
)
# 4. دکمههای دانلود و ارتقا
upgrade_button.click(fn=None, js=js_upgrade_func)
download_button.click(fn=None, inputs=[result], js=js_download_func)
if __name__ == "__main__":
demo.queue(max_size=30).launch(show_error=True)