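"""Gradio demo for TaikoConformer 5/6/7.

Takes an uploaded song, predicts don/ka/drumroll onsets with the selected
model, writes a TJA chart, and renders a playable mix through the
ryanlinjui/taiko-music-generator Space.
"""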
import spaces  # HF ZeroGPU helper; kept as the first import, per ZeroGPU convention
import tempfile

import gradio as gr
import torch
from gradio_client import Client, handle_file

from tc5.config import SAMPLE_RATE, HOP_LENGTH
from tc5.model import TaikoConformer5
from tc5 import infer as tc5infer
from tc6.model import TaikoConformer6
from tc6 import infer as tc6infer
from tc7.model import TaikoConformer7
from tc7 import infer as tc7infer

GPU_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
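
# Each model is loaded twice: one copy on GPU_DEVICE and one pinned to the CPU,
# so the "Use GPU" checkbox can pick a device per request without moving weights.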
# TC5: GPU and CPU copies
tc5 = TaikoConformer5.from_pretrained("JacobLinCool/taiko-conformer-5")
tc5.to(GPU_DEVICE)
tc5.eval()
tc5_cpu = TaikoConformer5.from_pretrained("JacobLinCool/taiko-conformer-5")
tc5_cpu.to("cpu")
tc5_cpu.eval()

# TC6: GPU and CPU copies
tc6 = TaikoConformer6.from_pretrained("JacobLinCool/taiko-conformer-6")
tc6.to(GPU_DEVICE)
tc6.eval()
tc6_cpu = TaikoConformer6.from_pretrained("JacobLinCool/taiko-conformer-6")
tc6_cpu.to("cpu")
tc6_cpu.eval()

# TC7: GPU and CPU copies
tc7 = TaikoConformer7.from_pretrained("JacobLinCool/taiko-conformer-7")
tc7.to(GPU_DEVICE)
tc7.eval()
tc7_cpu = TaikoConformer7.from_pretrained("JacobLinCool/taiko-conformer-7")
tc7_cpu.to("cpu")
tc7_cpu.eval()

# Remote Space that renders a playable audio mix from a TJA chart plus the song
synthesizer = Client("ryanlinjui/taiko-music-generator")

def infer_tc5(audio, nps, bpm, offset, DEVICE, MODEL):
    audio_path = audio
    filename = audio_path.split("/")[-1]

    # Preprocess
    mel_input, nps_input = tc5infer.preprocess_audio(audio_path, nps)

    # Inference
    don_energy, ka_energy, drumroll_energy = tc5infer.run_inference(
        MODEL, mel_input, nps_input, DEVICE
    )

    # Decode frame-level energies into discrete onsets
    output_frame_hop_sec = HOP_LENGTH / SAMPLE_RATE
    onsets = tc5infer.decode_onsets(
        don_energy,
        ka_energy,
        drumroll_energy,
        output_frame_hop_sec,
        threshold=0.3,
        min_distance_frames=3,
    )

    # Generate plot
    plot = tc5infer.plot_results(
        mel_input,
        don_energy,
        ka_energy,
        drumroll_energy,
        onsets,
        output_frame_hop_sec,
    )

    # Generate TJA content
    tja_content = tc5infer.write_tja(onsets, bpm=bpm, audio=filename, offset=offset)

    # Write TJA content to a temporary file (delete=False keeps it around for upload)
    with tempfile.NamedTemporaryFile(delete=False, suffix=".tja") as temp_tja_file:
        temp_tja_file.write(tja_content.encode("utf-8"))
        tja_path = temp_tja_file.name

    # Render the chart: course name plus the numeric settings the remote
    # /handle endpoint expects
    result = synthesizer.predict(
        param_0=handle_file(tja_path),
        param_1=handle_file(audio_path),
        param_2="達人譜面 / Master",
        param_3=16,
        param_4=7,
        param_5=5,
        param_6=5,
        param_7=5,
        param_8=5,
        param_9=5,
        param_10=5,
        param_11=5,
        param_12=5,
        param_13=5,
        param_14=5,
        param_15=5,
        api_name="/handle",
    )
    oni_audio = result[1]
    return oni_audio, plot, tja_content
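
# TC6 follows the same pipeline as TC5, but also conditions on difficulty and level.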
def infer_tc6(audio, nps, bpm, offset, difficulty, level, DEVICE, MODEL):
    audio_path = audio
    filename = audio_path.split("/")[-1]

    # Preprocess
    mel_input = tc6infer.preprocess_audio(audio_path)
    nps_input = torch.tensor(nps, dtype=torch.float32).to(DEVICE)
    difficulty_input = torch.tensor(difficulty, dtype=torch.float32).to(DEVICE)
    level_input = torch.tensor(level, dtype=torch.float32).to(DEVICE)

    # Inference
    don_energy, ka_energy, drumroll_energy = tc6infer.run_inference(
        MODEL, mel_input, nps_input, difficulty_input, level_input, DEVICE
    )

    # Decode frame-level energies into discrete onsets
    output_frame_hop_sec = HOP_LENGTH / SAMPLE_RATE
    onsets = tc6infer.decode_onsets(
        don_energy,
        ka_energy,
        drumroll_energy,
        output_frame_hop_sec,
        threshold=0.3,
        min_distance_frames=3,
    )

    # Generate plot
    plot = tc6infer.plot_results(
        mel_input,
        don_energy,
        ka_energy,
        drumroll_energy,
        onsets,
        output_frame_hop_sec,
    )

    # Generate TJA content
    tja_content = tc6infer.write_tja(onsets, bpm=bpm, audio=filename, offset=offset)

    # Write TJA content to a temporary file (delete=False keeps it around for upload)
    with tempfile.NamedTemporaryFile(delete=False, suffix=".tja") as temp_tja_file:
        temp_tja_file.write(tja_content.encode("utf-8"))
        tja_path = temp_tja_file.name

    result = synthesizer.predict(
        param_0=handle_file(tja_path),
        param_1=handle_file(audio_path),
        param_2="達人譜面 / Master",
        param_3=16,
        param_4=7,
        param_5=5,
        param_6=5,
        param_7=5,
        param_8=5,
        param_9=5,
        param_10=5,
        param_11=5,
        param_12=5,
        param_13=5,
        param_14=5,
        param_15=5,
        api_name="/handle",
    )
    oni_audio = result[1]
    return oni_audio, plot, tja_content
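
# TC7 shares TC6's interface (difficulty/level conditioning) with its own checkpoint.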
def infer_tc7(audio, nps, bpm, offset, difficulty, level, DEVICE, MODEL):
    audio_path = audio
    filename = audio_path.split("/")[-1]

    # Preprocess
    mel_input = tc7infer.preprocess_audio(audio_path)
    nps_input = torch.tensor(nps, dtype=torch.float32).to(DEVICE)
    difficulty_input = torch.tensor(difficulty, dtype=torch.float32).to(DEVICE)
    level_input = torch.tensor(level, dtype=torch.float32).to(DEVICE)

    # Inference
    don_energy, ka_energy, drumroll_energy = tc7infer.run_inference(
        MODEL, mel_input, nps_input, difficulty_input, level_input, DEVICE
    )

    # Decode frame-level energies into discrete onsets
    output_frame_hop_sec = HOP_LENGTH / SAMPLE_RATE
    onsets = tc7infer.decode_onsets(
        don_energy,
        ka_energy,
        drumroll_energy,
        output_frame_hop_sec,
        threshold=0.3,
        min_distance_frames=3,
    )

    # Generate plot
    plot = tc7infer.plot_results(
        mel_input,
        don_energy,
        ka_energy,
        drumroll_energy,
        onsets,
        output_frame_hop_sec,
    )

    # Generate TJA content
    tja_content = tc7infer.write_tja(onsets, bpm=bpm, audio=filename, offset=offset)

    # Write TJA content to a temporary file (delete=False keeps it around for upload)
    with tempfile.NamedTemporaryFile(delete=False, suffix=".tja") as temp_tja_file:
        temp_tja_file.write(tja_content.encode("utf-8"))
        tja_path = temp_tja_file.name

    result = synthesizer.predict(
        param_0=handle_file(tja_path),
        param_1=handle_file(audio_path),
        param_2="達人譜面 / Master",
        param_3=16,
        param_4=7,
        param_5=5,
        param_6=5,
        param_7=5,
        param_8=5,
        param_9=5,
        param_10=5,
        param_11=5,
        param_12=5,
        param_13=5,
        param_14=5,
        param_15=5,
        api_name="/handle",
    )
    oni_audio = result[1]
    return oni_audio, plot, tja_content
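
# Dispatchers: route to the pre-loaded (model, device) pair for the chosen architecture.
# NOTE: the decorator below is an assumption — `spaces` is imported but otherwise
# unused, and on ZeroGPU Spaces the GPU entry point is normally decorated this way.
@spaces.GPU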
def run_inference_gpu(audio, model_choice, nps, bpm, offset, difficulty, level):
    if model_choice == "TC5":
        return infer_tc5(audio, nps, bpm, offset, GPU_DEVICE, tc5)
    elif model_choice == "TC6":
        return infer_tc6(audio, nps, bpm, offset, difficulty, level, GPU_DEVICE, tc6)
    else:  # TC7
        return infer_tc7(audio, nps, bpm, offset, difficulty, level, GPU_DEVICE, tc7)


def run_inference_cpu(audio, model_choice, nps, bpm, offset, difficulty, level):
    DEVICE = torch.device("cpu")
    if model_choice == "TC5":
        return infer_tc5(audio, nps, bpm, offset, DEVICE, tc5_cpu)
    elif model_choice == "TC6":
        return infer_tc6(audio, nps, bpm, offset, difficulty, level, DEVICE, tc6_cpu)
    else:  # TC7
        return infer_tc7(audio, nps, bpm, offset, difficulty, level, DEVICE, tc7_cpu)


def run_inference(with_gpu, audio, model_choice, nps, bpm, offset, difficulty, level):
    if with_gpu:
        return run_inference_gpu(
            audio, model_choice, nps, bpm, offset, difficulty, level
        )
    else:
        return run_inference_cpu(
            audio, model_choice, nps, bpm, offset, difficulty, level
        )
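
# Gradio UI: song upload, model choice, chart parameters, and result displays.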
with gr.Blocks() as demo:
    gr.Markdown("# Taiko Conformer 5/6/7 Demo")
    with gr.Row():
        audio_input = gr.Audio(sources="upload", type="filepath", label="Input Audio")
    with gr.Row():
        model_choice = gr.Dropdown(
            choices=["TC5", "TC6", "TC7"],
            value="TC7",
            label="Model Selection",
            info="Choose between TaikoConformer 5, 6 or 7",
        )
    with gr.Row():
        nps = gr.Slider(
            value=5.0,
            minimum=0.5,
            maximum=11.0,
            step=0.5,
            label="NPS (Notes Per Second)",
        )
        bpm = gr.Slider(
            value=240,
            minimum=160,
            maximum=640,
            step=1,
            label="BPM (Used by TJA Quantization)",
        )
        offset = gr.Slider(
            value=0.0,
            minimum=-5.0,
            maximum=5.0,
            step=0.01,
            label="Offset (in seconds)",
            info="Adjust the offset for TJA",
        )
    with gr.Row():
        difficulty = gr.Slider(
            value=3.0,
            minimum=1.0,
            maximum=3.0,
            step=1.0,
            label="Difficulty",
            visible=False,
            info="1=Normal, 2=Hard, 3=Oni",
        )
        level = gr.Slider(
            value=8.0,
            minimum=1.0,
            maximum=10.0,
            step=1.0,
            label="Level",
            visible=False,
            info="Difficulty level from 1 to 10",
        )
    with gr.Row():
        with_gpu = gr.Checkbox(
            value=True,
            label="Use GPU for Inference",
            info="Enable this to use GPU for faster inference (if available)",
        )
    run_btn = gr.Button("Run Inference", variant="primary")
    audio_output = gr.Audio(label="Generated Audio", type="filepath")
    plot_output = gr.Plot(label="Onset/Energy Plot")
    tja_output = gr.Textbox(label="TJA File Content", show_copy_button=True)
    # Show the difficulty/level controls only for the models that use them (TC6/TC7)
    def update_visibility(model_choice):
        if model_choice in ("TC6", "TC7"):
            return gr.update(visible=True), gr.update(visible=True)
        else:
            return gr.update(visible=False), gr.update(visible=False)

    model_choice.change(
        update_visibility, inputs=[model_choice], outputs=[difficulty, level]
    )
    run_btn.click(
        run_inference,
        inputs=[
            with_gpu,
            audio_input,
            model_choice,
            nps,
            bpm,
            offset,
            difficulty,
            level,
        ],
        outputs=[audio_output, plot_output, tja_output],
    )

if __name__ == "__main__":
    demo.launch()