import os
import spaces  # imported early for Hugging Face Spaces GPU support

# Only the Hugging Face hub path is exercised for debugging right now
REPO_TYPE = "hf"
if REPO_TYPE not in ["hf", "ms"]:
    raise ValueError("REPO_TYPE must be either 'hf' for Hugging Face or 'ms' for ModelScope.")

if REPO_TYPE == "hf":
    from huggingface_hub import snapshot_download
else:
    from modelscope.hub.snapshot_download import snapshot_download

# 1. Define local paths and remote repository IDs
MODEL_CACHE_DIR = "./models"
FUN_ASR_NANO_LOCAL_PATH = os.path.join(MODEL_CACHE_DIR, "Fun-ASR-Nano")
SENSE_VOICE_SMALL_LOCAL_PATH = os.path.join(MODEL_CACHE_DIR, "SenseVoiceSmall")
VAD_MODEL_LOCAL_PATH = os.path.join(MODEL_CACHE_DIR, "fsmn-vad")

# Create the model cache directory
os.makedirs(MODEL_CACHE_DIR, exist_ok=True)

# Point the ModelScope environment variable at the local cache
os.environ['MODELSCOPE_CACHE'] = MODEL_CACHE_DIR
# Optional: disable remote downloads to guarantee that only local models are used
# os.environ['MODELSCOPE_DISABLE_REMOTE'] = '1'
print(f"ModelScope cache directory set to: {MODEL_CACHE_DIR}")

if REPO_TYPE == "ms":
    FUN_ASR_NANO_REPO_ID = "FunAudioLLM/Fun-ASR-Nano-2512"
    SENSE_VOICE_SMALL_REPO_ID = "iic/SenseVoiceSmall"
    VAD_MODEL_REPO_ID = "iic/speech_fsmn_vad_zh-cn-16k-common-pytorch"
else:
    FUN_ASR_NANO_REPO_ID = "FunAudioLLM/Fun-ASR-Nano-2512"
    SENSE_VOICE_SMALL_REPO_ID = "FunAudioLLM/SenseVoiceSmall"
    VAD_MODEL_REPO_ID = "funasr/fsmn-vad"

# 2. Check whether each model exists locally; download it if not
def download_model_if_not_exists(repo_id, local_path, model_name):
    """Download the model if it does not already exist locally."""
    if not os.path.exists(local_path):
        print(f"Downloading model {model_name} to {local_path} ...")
        snapshot_download(
            repo_id=repo_id,
            local_dir=local_path,
            ignore_patterns=["*.onnx"],  # skip ONNX files to save time and space if they are not needed
        )
        print(f"{model_name} model download complete!")
    else:
        print(f"Found local {model_name} model files; skipping download.")

# Download all required models
download_model_if_not_exists(FUN_ASR_NANO_REPO_ID, FUN_ASR_NANO_LOCAL_PATH, "Fun-ASR-Nano")
download_model_if_not_exists(SENSE_VOICE_SMALL_REPO_ID, SENSE_VOICE_SMALL_LOCAL_PATH, "SenseVoiceSmall")
download_model_if_not_exists(VAD_MODEL_REPO_ID, VAD_MODEL_LOCAL_PATH, "VAD Model")
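
# After this step the three model snapshots live under MODEL_CACHE_DIR, i.e.
# ./models/Fun-ASR-Nano, ./models/SenseVoiceSmall and ./models/fsmn-vad, and the
# *_LOCAL_PATH constants above can be passed anywhere a model ID is expected.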

import gradio as gr
import time
import sys
import io
import tempfile
import subprocess
import requests
from urllib.parse import urlparse
from pydub import AudioSegment
import logging
import torch
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess

# Model configurations for local deployment
FUN_ASR_NANO_MODEL_PATH_LIST = [
    FUN_ASR_NANO_LOCAL_PATH,  # local path
]
SENSEVOICE_MODEL_PATH_LIST = [
    SENSE_VOICE_SMALL_LOCAL_PATH,  # local path
]

class LogCapture(io.StringIO):
    """StringIO subclass that mirrors every write to a callback."""
    def __init__(self, callback):
        super().__init__()
        self.callback = callback

    def write(self, s):
        super().write(s)
        self.callback(s)
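
# LogCapture is not wired up anywhere in this app yet. A minimal (hypothetical)
# use would be redirecting stdout so UI code can collect log text:
#   captured = []
#   sys.stdout = LogCapture(captured.append)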

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Check for CUDA availability
device = "cuda:0" if torch.cuda.is_available() else "cpu"
logging.info(f"Using device: {device}")

def download_audio(url, method_choice, proxy_url, proxy_username, proxy_password):
    """
    Downloads audio from a given URL using the specified method and proxy settings.

    Args:
        url (str): The URL of the audio.
        method_choice (str): The method to use for downloading audio.
        proxy_url (str): Proxy URL if needed.
        proxy_username (str): Proxy username.
        proxy_password (str): Proxy password.

    Returns:
        tuple: (path to the downloaded audio file, is_temp_file), or (None, False) if failed.
    """
    parsed_url = urlparse(url)
    logging.info(f"Downloading audio from URL: {url} using method: {method_choice}")
    try:
        if 'youtube.com' in parsed_url.netloc or 'youtu.be' in parsed_url.netloc:
            logging.error("YouTube download is not supported. Please use direct audio URLs instead.")
            return None, False
        elif parsed_url.scheme == 'rtsp':
            audio_file = download_rtsp_audio(url, proxy_url)
            if not audio_file:
                logging.error(f"Failed to download RTSP audio from {url}")
                return None, False
        else:
            audio_file = download_direct_audio(url, method_choice, proxy_url, proxy_username, proxy_password)
            if not audio_file:
                logging.error(f"Failed to download audio from {url} using method {method_choice}")
                return None, False
        return audio_file, True
    except Exception as e:
        logging.error(f"Error downloading audio from {url} using method {method_choice}: {str(e)}")
        return None, False

def download_rtsp_audio(url, proxy_url):
    """
    Downloads audio from an RTSP URL using FFmpeg.

    Args:
        url (str): The RTSP URL.
        proxy_url (str): Proxy URL if needed.

    Returns:
        str: Path to the downloaded audio file, or None if failed.
    """
    logging.info("Using FFmpeg to download RTSP stream")
    output_file = tempfile.mktemp(suffix='.mp3')
    command = ['ffmpeg', '-i', url, '-acodec', 'libmp3lame', '-ab', '192k', '-y', output_file]
    env = os.environ.copy()
    if proxy_url and len(proxy_url.strip()) > 0:
        env['http_proxy'] = proxy_url
        env['https_proxy'] = proxy_url
    try:
        subprocess.run(command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env)
        logging.info(f"Downloaded RTSP audio to: {output_file}")
        return output_file
    except subprocess.CalledProcessError as e:
        logging.error(f"FFmpeg error: {e.stderr.decode()}")
        return None
    except Exception as e:
        logging.error(f"Error downloading RTSP audio: {str(e)}")
        return None

def download_direct_audio(url, method_choice, proxy_url, proxy_username, proxy_password):
    """
    Downloads audio from a direct URL using the specified method.

    Args:
        url (str): The direct URL of the audio file.
        method_choice (str): The method to use for downloading.
        proxy_url (str): Proxy URL if needed.
        proxy_username (str): Proxy username.
        proxy_password (str): Proxy password.

    Returns:
        str: Path to the downloaded audio file, or None if failed.
    """
    logging.info(f"Downloading direct audio from: {url} using method: {method_choice}")
    methods = {
        'wget': wget_method,
        'requests': requests_method,
        'ffmpeg': ffmpeg_method,
        'aria2': aria2_method,
    }
    method = methods.get(method_choice, requests_method)
    try:
        audio_file = method(url, proxy_url, proxy_username, proxy_password)
        if not audio_file or not os.path.exists(audio_file):
            logging.error(f"Failed to download direct audio from {url} using method {method_choice}")
            return None
        return audio_file
    except Exception as e:
        logging.error(f"Error downloading direct audio with {method_choice}: {str(e)}")
        return None

def requests_method(url, proxy_url, proxy_username, proxy_password):
    """
    Downloads audio using the requests library.

    Args:
        url (str): The URL of the audio file.
        proxy_url (str): Proxy URL if needed.
        proxy_username (str): Proxy username.
        proxy_password (str): Proxy password.

    Returns:
        str: Path to the downloaded audio file, or None if failed.
    """
    try:
        proxies = None
        auth = None
        if proxy_url and len(proxy_url.strip()) > 0:
            proxies = {
                "http": proxy_url,
                "https": proxy_url
            }
            if proxy_username and proxy_password:
                auth = (proxy_username, proxy_password)
        # timeout added so an unresponsive server cannot hang the app indefinitely
        response = requests.get(url, stream=True, proxies=proxies, auth=auth, timeout=60)
        if response.status_code == 200:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_file:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        temp_file.write(chunk)
            logging.info(f"Downloaded direct audio to: {temp_file.name}")
            return temp_file.name
        else:
            logging.error(f"Failed to download audio from {url} with status code {response.status_code}")
            return None
    except Exception as e:
        logging.error(f"Error in requests_method: {str(e)}")
        return None

def wget_method(url, proxy_url, proxy_username, proxy_password):
    """
    Downloads audio using the wget command-line tool.

    Args:
        url (str): The URL of the audio file.
        proxy_url (str): Proxy URL if needed.
        proxy_username (str): Proxy username.
        proxy_password (str): Proxy password.

    Returns:
        str: Path to the downloaded audio file, or None if failed.
    """
    logging.info("Using wget method")
    output_file = tempfile.mktemp(suffix='.mp3')
    command = ['wget', '-O', output_file, url]
    env = os.environ.copy()
    if proxy_url and len(proxy_url.strip()) > 0:
        env['http_proxy'] = proxy_url
        env['https_proxy'] = proxy_url
    try:
        subprocess.run(command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env)
        logging.info(f"Downloaded audio to: {output_file}")
        return output_file
    except subprocess.CalledProcessError as e:
        logging.error(f"Wget error: {e.stderr.decode()}")
        return None
    except Exception as e:
        logging.error(f"Error in wget_method: {str(e)}")
        return None

def ffmpeg_method(url, proxy_url, proxy_username, proxy_password):
    """
    Downloads audio using FFmpeg.

    Args:
        url (str): The URL of the audio file.
        proxy_url (str): Proxy URL if needed.
        proxy_username (str): Proxy username.
        proxy_password (str): Proxy password.

    Returns:
        str: Path to the downloaded audio file, or None if failed.
    """
    logging.info("Using ffmpeg method")
    output_file = tempfile.mktemp(suffix='.mp3')
    command = ['ffmpeg', '-i', url, '-vn', '-acodec', 'libmp3lame', '-q:a', '2', output_file]
    env = os.environ.copy()
    if proxy_url and len(proxy_url.strip()) > 0:
        env['http_proxy'] = proxy_url
        env['https_proxy'] = proxy_url
    try:
        subprocess.run(command, check=True, capture_output=True, text=True, env=env)
        logging.info(f"Downloaded and converted audio to: {output_file}")
        return output_file
    except subprocess.CalledProcessError as e:
        logging.error(f"FFmpeg error: {e.stderr}")
        return None
    except Exception as e:
        logging.error(f"Error in ffmpeg_method: {str(e)}")
        return None

def aria2_method(url, proxy_url, proxy_username, proxy_password):
    """
    Downloads audio using aria2.

    Args:
        url (str): The URL of the audio file.
        proxy_url (str): Proxy URL if needed.
        proxy_username (str): Proxy username.
        proxy_password (str): Proxy password.

    Returns:
        str: Path to the downloaded audio file, or None if failed.
    """
    logging.info("Using aria2 method")
    output_file = tempfile.mktemp(suffix='.mp3')
    # aria2c interprets --out relative to --dir, so split the temp path into the two flags
    command = ['aria2c', '--split=4', '--max-connection-per-server=4',
               '--dir', os.path.dirname(output_file), '--out', os.path.basename(output_file), url]
    if proxy_url and len(proxy_url.strip()) > 0:
        command.extend(['--all-proxy', proxy_url])
    try:
        subprocess.run(command, check=True, capture_output=True, text=True)
        logging.info(f"Downloaded audio to: {output_file}")
        return output_file
    except subprocess.CalledProcessError as e:
        logging.error(f"Aria2 error: {e.stderr}")
        return None
    except Exception as e:
        logging.error(f"Error in aria2_method: {str(e)}")
        return None
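
# Rough guidance on method choice: 'requests' needs no external tools; 'wget' and
# 'aria2' shell out to their CLIs (aria2 downloads with multiple connections);
# 'ffmpeg' also transcodes to MP3, so it handles stream URLs that are not plain files.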

def trim_audio(audio_path, start_time, end_time):
    """
    Trims an audio file to the specified start and end times.

    Args:
        audio_path (str): Path to the audio file.
        start_time (float): Start time in seconds.
        end_time (float): End time in seconds.

    Returns:
        str: Path to the trimmed audio file.

    Raises:
        gr.Error: If invalid start or end times are provided.
    """
    try:
        logging.info(f"Trimming audio from {start_time} to {end_time}")
        audio = AudioSegment.from_file(audio_path)
        audio_duration = len(audio) / 1000  # Duration in seconds

        # Default start and end times if None
        start_time = max(0, start_time) if start_time is not None else 0
        end_time = min(audio_duration, end_time) if end_time is not None else audio_duration

        # Validate times
        if start_time >= end_time:
            raise gr.Error("End time must be greater than start time.")

        trimmed_audio = audio[int(start_time * 1000):int(end_time * 1000)]
        with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_audio_file:
            trimmed_audio.export(temp_audio_file.name, format="wav")
        logging.info(f"Trimmed audio saved to: {temp_audio_file.name}")
        return temp_audio_file.name
    except Exception as e:
        logging.error(f"Error trimming audio: {str(e)}")
        raise gr.Error(f"Error trimming audio: {str(e)}")
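
# Example (path is illustrative): trim_audio("speech.mp3", 5, 10) returns a new
# temporary .wav containing seconds 5-10 of the input; the caller owns cleanup.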

def save_transcription(transcription):
    """
    Saves the transcription text to a temporary file.

    Args:
        transcription (str): The transcription text.

    Returns:
        str: The path to the transcription file.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix='.txt', mode='w', encoding='utf-8') as temp_file:
        temp_file.write(transcription)
    logging.info(f"Transcription saved to: {temp_file.name}")
    return temp_file.name

def get_model_options(pipeline_type):
    """
    Returns a list of model IDs based on the selected pipeline type.

    Args:
        pipeline_type (str): The type of pipeline.

    Returns:
        list: A list of model IDs.
    """
    if pipeline_type == "fun-asr-nano":
        return FUN_ASR_NANO_MODEL_PATH_LIST
    elif pipeline_type == "sensevoice":
        return SENSEVOICE_MODEL_PATH_LIST
    else:
        return []

# Dictionary to store loaded models
loaded_models = {}
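
# Models are cached across requests, keyed by (pipeline_type, model_id), so
# repeated transcriptions and pipeline switches do not reload model weights.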

def transcribe_audio(audio_input, audio_url, proxy_url, proxy_username, proxy_password, pipeline_type, model_id, download_method, start_time=None, end_time=None, verbose=False):
    """
    Transcribes audio from a given source using the selected pipeline.

    Args:
        audio_input (str): Path to uploaded audio file or recorded audio.
        audio_url (str): URL of audio.
        proxy_url (str): Proxy URL if needed.
        proxy_username (str): Proxy username.
        proxy_password (str): Proxy password.
        pipeline_type (str): Type of pipeline to use ('fun-asr-nano' or 'sensevoice').
        model_id (str): The ID (or local path) of the model to use.
        download_method (str): Method to use for downloading audio.
        start_time (float, optional): Start time in seconds for trimming audio.
        end_time (float, optional): End time in seconds for trimming audio.
        verbose (bool, optional): Whether to output verbose logging.

    Yields:
        Tuple[str, str, str or None]: Metrics and messages, transcription text, path to transcription file.
    """
    # Initialize before the try block so the except/finally clauses can always reference these
    audio_path = None
    is_temp_file = False
    verbose_messages = ""
    try:
        if verbose:
            logging.getLogger().setLevel(logging.INFO)
        else:
            logging.getLogger().setLevel(logging.WARNING)
        logging.info(f"Transcription parameters: pipeline_type={pipeline_type}, model_id={model_id}, download_method={download_method}")
        verbose_messages = f"Starting transcription with parameters:\nPipeline Type: {pipeline_type}\nModel ID: {model_id}\nDownload Method: {download_method}\n"
        if verbose:
            yield verbose_messages, "", None
        # Determine the audio source
        if audio_input is not None and len(audio_input) > 0:
            # audio_input is a filepath to uploaded or recorded audio
            audio_path = audio_input
            is_temp_file = False
        elif audio_url is not None and len(audio_url.strip()) > 0:
            # audio_url is provided
            audio_path, is_temp_file = download_audio(audio_url, download_method, proxy_url, proxy_username, proxy_password)
            if not audio_path:
                error_msg = f"Error downloading audio from {audio_url} using method {download_method}. Check logs for details."
                logging.error(error_msg)
                yield verbose_messages + error_msg, "", None
                return
            else:
                verbose_messages += f"Successfully downloaded audio from {audio_url}\n"
                if verbose:
                    yield verbose_messages, "", None
        else:
            error_msg = "No audio source provided. Please upload an audio file, record audio, or enter a URL."
            logging.error(error_msg)
            yield verbose_messages + error_msg, "", None
            return

        # Convert start_time and end_time to float or None
        start_time = float(start_time) if start_time else None
        end_time = float(end_time) if end_time else None
        if start_time is not None or end_time is not None:
            audio_path = trim_audio(audio_path, start_time, end_time)
            is_temp_file = True  # The trimmed audio is a temporary file
            verbose_messages += f"Audio trimmed from {start_time} to {end_time}\n"
            if verbose:
                yield verbose_messages, "", None
        # Model caching
        model_key = (pipeline_type, model_id)
        if model_key in loaded_models:
            model = loaded_models[model_key]
            logging.info("Loaded model from cache")
        else:
            if pipeline_type == "fun-asr-nano":
                model = AutoModel(
                    model=model_id,
                    trust_remote_code=True,
                    remote_code="./Fun-ASR/model.py",
                    vad_model=VAD_MODEL_LOCAL_PATH,  # Use local VAD model path
                    vad_kwargs={"max_single_segment_time": 30000},
                    device=device,
                    disable_update=True,
                    hub='ms',
                )
            elif pipeline_type == "sensevoice":
                model = AutoModel(
                    model=model_id,
                    trust_remote_code=False,
                    vad_model=VAD_MODEL_LOCAL_PATH,  # Use local VAD model path
                    vad_kwargs={"max_single_segment_time": 30000},
                    device=device,
                    disable_update=True,
                    hub='ms',
                )
            else:
                error_msg = "Invalid pipeline type. Only 'fun-asr-nano' and 'sensevoice' are supported."
                logging.error(error_msg)
                yield verbose_messages + error_msg, "", None
                return
            loaded_models[model_key] = model
        # Perform the transcription
        start_time_perf = time.time()
        if pipeline_type == "fun-asr-nano":
            logging.info(f"Transcribing with Fun-ASR-Nano: {audio_path}")
            res = model.generate(
                input=[audio_path],
                use_itn=True,
                batch_size=1,
            )
        elif pipeline_type == "sensevoice":
            res = model.generate(
                input=audio_path,
                cache={},
                language="auto",  # "zh", "en", "yue", "ja", "ko", "nospeech"
                use_itn=True,
                batch_size_s=60,
                merge_vad=True,
                merge_length_s=15,
            )
        transcription = rich_transcription_postprocess(res[0]["text"])
        end_time_perf = time.time()
        # Calculate metrics
        transcription_time = end_time_perf - start_time_perf
        audio_file_size = os.path.getsize(audio_path) / (1024 * 1024)
        metrics_output = (
            f"Transcription time: {transcription_time:.2f} seconds\n"
            f"Audio file size: {audio_file_size:.2f} MB\n"
        )

        # Save the transcription to a file
        transcription_file = save_transcription(transcription)

        # Always yield the final result, regardless of verbose setting
        yield verbose_messages + metrics_output, transcription, transcription_file
    except Exception as e:
        error_msg = f"An error occurred during transcription: {str(e)}"
        logging.error(error_msg)
        yield verbose_messages + error_msg, "", None
    finally:
        # Clean up temporary audio files
        if audio_path and is_temp_file and os.path.exists(audio_path):
            os.remove(audio_path)
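
# transcribe_audio is a generator: iterate it to receive progress updates and the
# final result. A minimal sketch, assuming a local file "sample.wav" exists:
#   for metrics, text, path in transcribe_audio("sample.wav", "", "", "", "",
#           "sensevoice", SENSEVOICE_MODEL_PATH_LIST[0], "requests"):
#       print(metrics, text)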

with gr.Blocks() as iface:
    gr.Markdown("# Audio Transcription")
    gr.Markdown("Transcribe audio with the Fun-ASR-Nano or SenseVoice models, with multilingual support.")
    with gr.Row():
        audio_input = gr.Audio(label="Upload or Record Audio", sources=["upload", "microphone"], type="filepath")
        audio_url = gr.Textbox(label="Or Enter URL of audio file (direct link only, no YouTube)")
    transcribe_button = gr.Button("Transcribe")
    with gr.Accordion("Advanced Options", open=False):
        with gr.Row():
            proxy_url = gr.Textbox(label="Proxy URL", placeholder="Enter proxy URL if needed", value="", lines=1)
            proxy_username = gr.Textbox(label="Proxy Username", placeholder="Proxy username (optional)", value="", lines=1)
            proxy_password = gr.Textbox(label="Proxy Password", placeholder="Proxy password (optional)", value="", lines=1, type="password")
        with gr.Row():
            pipeline_type = gr.Dropdown(
                choices=["sensevoice", "fun-asr-nano"],
                label="Pipeline Type",
                value="fun-asr-nano"
            )
            model_id = gr.Dropdown(
                label="Model",
                choices=get_model_options("fun-asr-nano"),
                value=FUN_ASR_NANO_MODEL_PATH_LIST[0]  # default to the local Fun-ASR-Nano snapshot
            )
        with gr.Row():
            download_method = gr.Dropdown(
                choices=["requests", "ffmpeg", "aria2", "wget"],
                label="Download Method",
                value="requests"
            )
        with gr.Row():
            start_time = gr.Number(label="Start Time (seconds)", value=None, minimum=0)
            end_time = gr.Number(label="End Time (seconds)", value=None, minimum=0)
            verbose = gr.Checkbox(label="Verbose Output", value=False)
    with gr.Row():
        metrics_output = gr.Textbox(label="Transcription Metrics and Verbose Messages", lines=10)
        transcription_output = gr.Textbox(label="Transcription", lines=10)
    transcription_file = gr.File(label="Download Transcription")

    def update_model_dropdown(pipeline_type):
        """
        Updates the model dropdown choices based on the selected pipeline type.

        Args:
            pipeline_type (str): The selected pipeline type.

        Returns:
            gr.update: Updated model dropdown component.
        """
        try:
            model_choices = get_model_options(pipeline_type)
            logging.info(f"Model choices for {pipeline_type}: {model_choices}")
            if model_choices:
                return gr.update(choices=model_choices, value=model_choices[0], visible=True)
            else:
                return gr.update(choices=["No models available"], value=None, visible=False)
        except Exception as e:
            logging.error(f"Error in update_model_dropdown: {str(e)}")
            return gr.update(choices=["Error"], value="Error", visible=True)

    # Event handler for pipeline_type change
    pipeline_type.change(update_model_dropdown, inputs=[pipeline_type], outputs=[model_id])

    def transcribe_with_progress(*args):
        # Thin generator wrapper (audio_input is the first argument)
        yield from transcribe_audio(*args)

    transcribe_button.click(
        transcribe_with_progress,
        inputs=[audio_input, audio_url, proxy_url, proxy_username, proxy_password, pipeline_type, model_id, download_method, start_time, end_time, verbose],
        outputs=[metrics_output, transcription_output, transcription_file]
    )
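
    # Because the click handler is a generator, each yielded (metrics, text, file)
    # tuple updates the three outputs in place: progress messages stream first,
    # and the final transcription arrives with the last yield.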

    # Note: for examples, use local audio files or upload your own;
    # example paths may not exist on every machine.
    gr.Markdown("""
### Usage Examples:
1. **Upload Audio**: Click the "Upload or Record Audio" button to select your audio file
2. **Select Pipeline Type**: Choose from the available pipelines:
   - **Fun-ASR-Nano** (default) - large-language-model-based ASR
   - **SenseVoice** - CTC-based ASR with VAD
3. **Local Testing**: For development, you can use local model paths as shown above

Supported languages:
- Fun-ASR-Nano: more than 50 languages and Chinese dialects.
- SenseVoiceSmall: Chinese (zh), English (en), Cantonese (yue), Japanese (ja), Korean (ko).
""")

iface.queue().launch(share=False, debug=True)