import os
import sys
import cv2
import time
import copy
import random
import torch
import spaces
import requests
import subprocess
import gradio as gr
from PIL import Image
import importlib.util
from threading import Thread
from typing import Iterable, Optional, Tuple, List
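
# Install any missing runtime dependency at startup (used below to pin transformers).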
def check_and_install_package(package_name, import_name=None, pip_name=None):
    """Check if a package is installed, and if not, install it."""
    if import_name is None:
        import_name = package_name
    if pip_name is None:
        pip_name = package_name
    spec = importlib.util.find_spec(import_name)
    if spec is None:
        print(f"Installing {package_name}...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", pip_name, "-q"])
        print(f"✓ {package_name} installed successfully")
    return True

print("Checking and installing transformers==4.57.3 ...")
check_and_install_package("transformers", "transformers", "transformers==4.57.3")
print("Done!")
from transformers import (
    Qwen2_5_VLForConditionalGeneration,
    AutoModelForImageTextToText,
    AutoModelForCausalLM,
    AutoProcessor,
    TextIteratorStreamer,
)
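
# VibeVoice ships as a separate repository; guard the import so the app can
# still start (with a warning) if the package layout is missing.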
try:
    from vibevoice.modular.modeling_vibevoice_streaming_inference import (
        VibeVoiceStreamingForConditionalGenerationInference,
    )
    from vibevoice.processor.vibevoice_streaming_processor import (
        VibeVoiceStreamingProcessor,
    )
except ImportError:
    print("CRITICAL WARNING: 'vibevoice' modules not found. Ensure the vibevoice repository structure is present.")
    VibeVoiceStreamingForConditionalGenerationInference = None
    VibeVoiceStreamingProcessor = None
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes
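
# Custom orange-red palette registered on gradio's color registry and used by
# the Soft-derived theme below.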
colors.orange_red = colors.Color(
    name="orange_red",
    c50="#FFF0E5",
    c100="#FFE0CC",
    c200="#FFC299",
    c300="#FFA366",
    c400="#FF8533",
    c500="#FF4500",
    c600="#E63E00",
    c700="#CC3700",
    c800="#B33000",
    c900="#992900",
    c950="#802200",
)
class OrangeRedTheme(Soft):
    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        secondary_hue: colors.Color | str = colors.orange_red,
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
        ),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
        super().set(
            background_fill_primary="*primary_50",
            background_fill_primary_dark="*primary_900",
            body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
            body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
            button_primary_text_color="white",
            button_primary_text_color_hover="white",
            button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
            button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
            button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_700)",
            button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_600)",
            button_secondary_text_color="black",
            button_secondary_text_color_hover="white",
            button_secondary_background_fill="linear-gradient(90deg, *primary_300, *primary_300)",
            button_secondary_background_fill_hover="linear-gradient(90deg, *primary_400, *primary_400)",
            button_secondary_background_fill_dark="linear-gradient(90deg, *primary_500, *primary_600)",
            button_secondary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)",
            slider_color="*secondary_500",
            slider_color_dark="*secondary_600",
            block_title_text_weight="600",
            block_border_width="3px",
            block_shadow="*shadow_drop_lg",
            button_primary_shadow="*shadow_drop_lg",
            button_large_padding="11px",
            color_accent_soft="*primary_100",
            block_label_background_fill="*primary_200",
        )

orange_red_theme = OrangeRedTheme()
| css = """ | |
| #main-title h1 { | |
| font-size: 2.3em !important; | |
| } | |
| #output-title h2 { | |
| font-size: 2.1em !important; | |
| } | |
| .generating { | |
| border: 2px solid #4682B4; | |
| } | |
| """ | |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using Main Device: {device}")

QWEN_VL_MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"
print(f"Loading OCR Model: {QWEN_VL_MODEL_ID}...")
qwen_processor = AutoProcessor.from_pretrained(QWEN_VL_MODEL_ID, trust_remote_code=True)
qwen_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    QWEN_VL_MODEL_ID,
    # attn_implementation="flash_attention_2",
    trust_remote_code=True,
    torch_dtype=torch.float16,
).to(device).eval()
print("Model loaded successfully.")
TTS_MODEL_PATH = "microsoft/VibeVoice-Realtime-0.5B"
print(f"Loading TTS Model: {TTS_MODEL_PATH}...")
tts_processor = VibeVoiceStreamingProcessor.from_pretrained(TTS_MODEL_PATH)
tts_model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
    TTS_MODEL_PATH,
    torch_dtype=torch.float16,
    device_map="cuda",
    attn_implementation="sdpa",
)
tts_model.eval()
tts_model.set_ddpm_inference_steps(num_steps=5)
print("VibeVoice Model loaded successfully.")
class VoiceMapper:
    """Maps speaker names to voice file paths"""

    def __init__(self):
        self.setup_voice_presets()
        # Register simplified aliases (strip "_suffix" and "prefix-" parts)
        # so speaker names can be matched loosely.
        new_dict = {}
        for name, path in self.voice_presets.items():
            if "_" in name:
                name = name.split("_")[0]
            if "-" in name:
                name = name.split("-")[-1]
            new_dict[name] = path
        self.voice_presets.update(new_dict)

    def setup_voice_presets(self):
        voices_dir = os.path.join(os.path.dirname(__file__), "demo/voices/streaming_model")
        if not os.path.exists(voices_dir):
            print(f"Warning: Voices directory not found at {voices_dir}")
            self.voice_presets = {}
            self.available_voices = {}
            return
        self.voice_presets = {}
        pt_files = [f for f in os.listdir(voices_dir) if f.lower().endswith(".pt") and os.path.isfile(os.path.join(voices_dir, f))]
        for pt_file in pt_files:
            name = os.path.splitext(pt_file)[0]
            full_path = os.path.join(voices_dir, pt_file)
            self.voice_presets[name] = full_path
        self.voice_presets = dict(sorted(self.voice_presets.items()))
        self.available_voices = {name: path for name, path in self.voice_presets.items() if os.path.exists(path)}
        print(f"Found {len(self.available_voices)} voice files.")

    def get_voice_path(self, speaker_name: str) -> str:
        if speaker_name in self.voice_presets:
            return self.voice_presets[speaker_name]
        # Fall back to a case-insensitive partial match, then to the first preset.
        speaker_lower = speaker_name.lower()
        for preset_name, path in self.voice_presets.items():
            if preset_name.lower() in speaker_lower or speaker_lower in preset_name.lower():
                return path
        if self.voice_presets:
            return list(self.voice_presets.values())[0]
        return ""

VOICE_MAPPER = VoiceMapper()
print("TTS Model loaded successfully.")
@spaces.GPU  # ZeroGPU: request a GPU slot for the duration of this call
def process_pipeline(
    image: Image.Image,
    query: str,
    speaker_name: str,
    cfg_scale: float,
    ocr_max_tokens: int,
    ocr_temp: float,
    progress=gr.Progress()
):
    """
    Combined pipeline: Image -> Text (Qwen2.5-VL) -> TTS -> Audio (VibeVoice)
    """
    if image is None:
        return "Please upload an image.", None, "Error: No image provided."
    progress(0.2, desc="Analyzing Image...")
    if not query:
        query = "Analyze the content perfectly."
    messages = [{
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": query},
        ]
    }]
    prompt_full = qwen_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = qwen_processor(
        text=[prompt_full],
        images=[image],
        return_tensors="pt",
        padding=True
    ).to(device)
    generated_ids = qwen_model.generate(
        **inputs,
        max_new_tokens=ocr_max_tokens,
        do_sample=True,
        temperature=ocr_temp,
        top_p=0.9,
    )
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    extracted_text = qwen_processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    extracted_text = extracted_text.replace("<|im_end|>", "").strip()
    progress(0.5, desc=f"Analysis complete. Converting to speech ({len(extracted_text)} chars)...")
    if not extracted_text:
        return extracted_text, None, "Error: the vision model produced no text."
    try:
        # Normalize curly apostrophes and quotes to plain ASCII before synthesis.
        full_script = extracted_text.replace("\u2019", "'").replace("\u201c", '"').replace("\u201d", '"')
        voice_path = VOICE_MAPPER.get_voice_path(speaker_name)
        if not voice_path:
            return extracted_text, None, "Error: Voice file not found."
        all_prefilled_outputs = torch.load(voice_path, map_location="cuda", weights_only=False)
        tts_inputs = tts_processor.process_input_with_cached_prompt(
            text=full_script,
            cached_prompt=all_prefilled_outputs,
            padding=True,
            return_tensors="pt",
            return_attention_mask=True,
        )
        tts_model.to("cuda")
        for k, v in tts_inputs.items():
            if torch.is_tensor(v):
                tts_inputs[k] = v.to("cuda")
        with torch.cuda.amp.autocast():
            outputs = tts_model.generate(
                **tts_inputs,
                max_new_tokens=None,
                cfg_scale=cfg_scale,
                tokenizer=tts_processor.tokenizer,
                generation_config={"do_sample": False},
                verbose=False,
                all_prefilled_outputs=copy.deepcopy(all_prefilled_outputs)
            )
        tts_model.to("cpu")
        torch.cuda.empty_cache()
        if outputs.speech_outputs and outputs.speech_outputs[0] is not None:
            sample_rate = 24000
            output_dir = "./outputs"
            os.makedirs(output_dir, exist_ok=True)
            output_path = os.path.join(output_dir, f"generated_{int(time.time())}.wav")
            tts_processor.save_audio(
                outputs.speech_outputs[0].cpu(),
                output_path=output_path,
            )
            status = f"✅ Success! Text Length: {len(extracted_text)} chars."
            return extracted_text, output_path, status
        else:
            return extracted_text, None, "TTS Generation failed (no output)."
    except Exception as e:
        tts_model.to("cpu")
        torch.cuda.empty_cache()
        import traceback
        traceback.print_exc()
        return extracted_text, None, f"Error during TTS: {str(e)}"
url = "https://huggingface.co/datasets/strangervisionhf/image-examples/resolve/main/2.jpg?download=true"
example_image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
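
# Gradio UI: vision input and voice settings on the left, extracted text,
# generated audio, and status log on the right.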
with gr.Blocks(css=css, theme=orange_red_theme) as demo:
| gr.Markdown("# **Vision-to-VibeVoice-en**", elem_id="main-title") | |
| gr.Markdown("Perform vision-to-audio inference with [Qwen2.5VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) + [VibeVoice-Realtime-0.5B](https://huggingface.co/microsoft/VibeVoice-Realtime-0.5B).") | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| gr.Markdown("### 1. Vision Input") | |
| image_upload = gr.Image(type="pil", label="Upload Image", value=example_image, height=300) | |
| image_query = gr.Textbox(label="Enter the prompt", value="Give a short description indicating whether the image is safe or unsafe.", placeholder="E.g., Read this page...") | |
| gr.Markdown("### 2. Voice Settings") | |
| voice_choices = list(VOICE_MAPPER.available_voices.keys()) | |
| if not voice_choices: voice_choices = ["Default"] | |
| speaker_dropdown = gr.Dropdown( | |
| choices=voice_choices, | |
| value=voice_choices[0], | |
| label="Speaker Voice" | |
| ) | |
| cfg_slider = gr.Slider(minimum=1.0, maximum=3.0, value=1.5, step=0.1, label="CFG Scale (Speech Fidelity)") | |
| with gr.Accordion("Advanced Options", open=False): | |
| max_new_tokens = gr.Slider(label="Max Tokens", minimum=128, maximum=4096, step=128, value=1024) | |
| temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.1) | |
| submit_btn = gr.Button("Generate Speech", variant="primary", size="lg") | |
| with gr.Column(scale=1): | |
| gr.Markdown("### 3. Results", elem_id="output-title") | |
| text_output = gr.Textbox( | |
| label="Extracted Text (Editable)", | |
| interactive=True, | |
| lines=11, | |
| ) | |
| audio_output = gr.Audio( | |
| label="Generated Speech", | |
| type="filepath", | |
| interactive=False | |
| ) | |
| status_output = gr.Textbox(label="Status Log", lines=2) | |
| gr.Examples( | |
| examples=[["Caption the image...", "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/venice.jpg"]], | |
| inputs=[image_query, image_upload], | |
| label="Example" | |
| ) | |
    submit_btn.click(
        fn=process_pipeline,
        inputs=[
            image_upload,
            image_query,
            speaker_dropdown,
            cfg_slider,
            max_new_tokens,
            temperature
        ],
        outputs=[text_output, audio_output, status_output]
    )
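
# Queue incoming requests and launch the app (MCP server enabled, SSR disabled, errors shown in the UI).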
if __name__ == "__main__":
    demo.queue(max_size=40).launch(mcp_server=True, ssr_mode=False, show_error=True)