import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.profiler
import torchaudio

from .config import SAMPLE_RATE, N_MELS, HOP_LENGTH


# --- PREPROCESSING (match training) ---
def preprocess_audio(audio_path):
    wav, sr = torchaudio.load(audio_path)
    wav = wav.mean(dim=0)  # downmix to mono
    if sr != SAMPLE_RATE:
        wav = torchaudio.functional.resample(wav, sr, SAMPLE_RATE)
    wav = wav / (wav.abs().max() + 1e-8)  # peak-normalize the audio
    mel_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=SAMPLE_RATE,
        n_mels=N_MELS,
        hop_length=HOP_LENGTH,
        n_fft=2048,
    )
    mel = mel_transform(wav)
    return mel  # (N_MELS, T_mel)
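

# For a sense of scale: if, say, SAMPLE_RATE were 22050 and HOP_LENGTH 256
# (illustrative values only; the real ones come from .config), a 60 s clip
# would yield T_mel ≈ 60 * 22050 / 256 ≈ 5168 frames, so `mel` would be
# roughly (N_MELS, 5168).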

# --- INFERENCE ---
def run_inference(model, mel_input, nps_input, difficulty_input, level_input, device):
    model.eval()
    with torch.no_grad():
        mel = mel_input.to(device).unsqueeze(0)  # (1, N_MELS, T_mel)
        nps = nps_input.to(device).unsqueeze(0)  # (1,)
        difficulty = difficulty_input.to(device).unsqueeze(0)  # (1,)
        level = level_input.to(device).unsqueeze(0)  # (1,)
        mel_cnn_input = mel.unsqueeze(1)  # (1, 1, N_MELS, T_mel)
        conformer_lengths = torch.tensor(
            [mel_cnn_input.shape[-1]], dtype=torch.long, device=device
        )
        # Profile the forward pass (CUDA activity only when running on GPU)
        with torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                *(
                    [torch.profiler.ProfilerActivity.CUDA]
                    if device.type == "cuda"
                    else []
                ),
            ],
            record_shapes=True,
            profile_memory=True,
            with_stack=False,
            with_flops=True,
        ) as prof:
            out_dict = model(mel_cnn_input, conformer_lengths, nps, difficulty, level)
        print(
            prof.key_averages().table(
                sort_by=(
                    "self_cuda_memory_usage"
                    if device.type == "cuda"
                    else "self_cpu_time_total"
                ),
                row_limit=20,
            )
        )
        energies = out_dict["presence"].squeeze(0).cpu().numpy()
        don_energy = energies[:, 0]
        ka_energy = energies[:, 1]
        drumroll_energy = energies[:, 2]
    return don_energy, ka_energy, drumroll_energy
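

# A minimal usage sketch for run_inference. The model class and checkpoint
# path below are hypothetical placeholders, not this project's actual API:
#
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   model = TaikoModel().to(device)  # hypothetical model class
#   model.load_state_dict(torch.load("checkpoint.pt", map_location=device))
#   mel = preprocess_audio("song.ogg")
#   don, ka, drum = run_inference(
#       model,
#       mel,
#       torch.tensor(8.0),  # notes-per-second conditioning (example value)
#       torch.tensor(3),    # difficulty index (example value)
#       torch.tensor(9),    # level (example value)
#       device,
#   )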

# --- DECODE TO ONSETS ---
def decode_onsets(
    don_energy,
    ka_energy,
    drumroll_energy,
    hop_sec,
    threshold=0.5,
    min_distance_frames=3,
):
    results = []
    T_out = len(don_energy)
    last_onset_frame = -min_distance_frames
    # Type mapping: 1: Don, 2: Ka, 5: Drumroll
    energy_series = {1: don_energy, 2: ka_energy, 5: drumroll_energy}
    for i in range(1, T_out - 1):  # skip the edges so each frame has both neighbors
        if i < last_onset_frame + min_distance_frames:
            continue
        # Try candidates in descending energy order so the strongest peak wins;
        # breaking after the first hit enforces one onset type per frame.
        sorted_types = sorted(
            energy_series, key=lambda typ: energy_series[typ][i], reverse=True
        )
        for onset_type in sorted_types:
            series = energy_series[onset_type]
            energy_val = series[i]
            # Accept a local maximum that clears the detection threshold
            if (
                energy_val > threshold
                and energy_val > series[i - 1]
                and energy_val > series[i + 1]
            ):
                results.append((i * hop_sec, onset_type))
                last_onset_frame = i
                break  # only one onset type per frame
    return results
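

# Example of the decoder's timing parameters: if hop_sec were 256 / 22050
# ≈ 0.0116 s (illustrative; the caller derives the real value from config),
# min_distance_frames=3 would suppress any onset within ~35 ms of the
# previous one, and a peak at frame i is stamped at i * hop_sec seconds.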

# --- VISUALIZATION ---
def plot_results(
    mel_spectrogram,
    don_energy,
    ka_energy,
    drumroll_energy,
    onsets,
    hop_sec,
    out_path=None,
):
    # mel_spectrogram is (N_MELS, T_mel)
    T_mel = mel_spectrogram.shape[1]
    T_out = len(don_energy)  # model output time dimension
    # Time axes: the mel axis advances by HOP_LENGTH / SAMPLE_RATE per frame.
    time_axis_mel = np.arange(T_mel) * (HOP_LENGTH / SAMPLE_RATE)
    # Note on frame rates: the CNN in model.py does not stride in time, so the
    # conformer (and thus the `presence` head) produces T_out = T_mel frames;
    # TIME_SUB only affects label generation in preprocess.py. The `hop_sec`
    # supplied by the caller is therefore treated as the model-output frame
    # period (HOP_LENGTH / SAMPLE_RATE when there is no time downsampling)
    # and used directly for the energy time axis.
    time_axis_energies = np.arange(T_out) * hop_sec
    fig, ax1 = plt.subplots(figsize=(100, 10))
    # Plot the mel spectrogram on ax1
    mel_db = torchaudio.functional.amplitude_to_DB(
        mel_spectrogram, multiplier=10.0, amin=1e-10, db_multiplier=0.0
    )
    img = ax1.imshow(
        mel_db.numpy(),
        aspect="auto",
        origin="lower",
        cmap="magma",
        extent=[time_axis_mel[0], time_axis_mel[-1], 0, N_MELS],
    )
    ax1.set_title("Mel Spectrogram with Predicted Energies and Onsets")
    ax1.set_xlabel("Time (s)")
    ax1.set_ylabel("Mel Bin")
    fig.colorbar(img, ax=ax1, format="%+2.0f dB")
    # Second y-axis for the predicted energies
    ax2 = ax1.twinx()
    ax2.plot(time_axis_energies, don_energy, label="Don Energy", color="red")
    ax2.plot(time_axis_energies, ka_energy, label="Ka Energy", color="blue")
    ax2.plot(
        time_axis_energies, drumroll_energy, label="Drumroll Energy", color="green"
    )
    ax2.set_ylabel("Energy")
    ax2.set_ylim(0, 1.2)  # energies are assumed to stay roughly within [0, 1]
    # Overlay onsets from decode_onsets (t is already in seconds)
    labeled_types = set()
    # Group drumrolls into segments (same gap rule as write_tja)
    drumroll_times = [t_sec for t_sec, typ in onsets if typ == 5]
    drumroll_times.sort()
    drumroll_segments = []
    if drumroll_times:
        seg_start = drumroll_times[0]
        prev = drumroll_times[0]
        for t in drumroll_times[1:]:
            if t - prev <= hop_sec * 6:  # up to a 5-frame gap
                prev = t
            else:
                drumroll_segments.append((seg_start, prev))
                seg_start = t
                prev = t
        drumroll_segments.append((seg_start, prev))
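    # Example of the gap rule: with hop_sec ≈ 0.0116 s (illustrative), drumroll
    # onsets at 1.000 s and 1.046 s merge (gap 0.046 <= 6 * hop_sec ≈ 0.070),
    # while a third onset at 1.200 s starts a new segment.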
    # Plot Don/Ka onsets as vertical lines
    for t_sec, typ in onsets:
        if typ == 5:
            continue  # drumrolls are drawn as shaded spans below
        color_map = {1: "darkred", 2: "darkblue"}
        label_map = {1: "Don Onset", 2: "Ka Onset"}
        line_color = color_map.get(typ, "black")
        line_label = label_map.get(typ, f"Type {typ} Onset")
        if typ not in labeled_types:
            ax1.axvline(
                t_sec, color=line_color, linestyle="--", alpha=0.9, label=line_label
            )
            labeled_types.add(typ)
        else:
            ax1.axvline(t_sec, color=line_color, linestyle="--", alpha=0.9)
    # Plot drumroll segments as shaded regions
    for seg_start, seg_end in drumroll_segments:
        ax1.axvspan(
            seg_start,
            seg_end + hop_sec,
            color="green",
            alpha=0.2,
            label="Drumroll Segment" if "drumroll" not in labeled_types else None,
        )
        labeled_types.add("drumroll")
    # Combine legends from both axes
    lines, labels = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax2.legend(lines + lines2, labels + labels2, loc="upper right")
    fig.tight_layout()
    # Save to file if a path is provided; otherwise return the live figure
    if out_path:
        plt.savefig(out_path)
        print(f"Saved plot to {out_path}")
        plt.close(fig)
        return out_path
    else:
        return fig

def write_tja(onsets, out_path=None, bpm=160, quantize=96, audio="audio.wav", offset=0):
    # TJA note types: 0: rest, 1: Don, 2: Ka, 3: big Don, 4: big Ka,
    # 5: drumroll start, 8: drumroll end
    # Model output types: 1: Don, 2: Ka, 5: Drumroll (interpreted as start/single)
    sec_per_beat = 60 / bpm
    beats_per_measure = 4  # assume 4/4 time
    sec_per_measure = sec_per_beat * beats_per_measure
    # Step 1: Map each onset to (measure_idx, slot, typ) on the quantize grid
    slot_events = []
    for t, typ in onsets:
        measure_idx = int(t // sec_per_measure)
        t_in_measure = t % sec_per_measure
        slot = int(round(t_in_measure / sec_per_measure * quantize))
        if slot >= quantize:
            slot = quantize - 1
        slot_events.append((measure_idx, slot, typ))
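    # Worked example: at bpm=160, sec_per_measure = 4 * 60 / 160 = 1.5 s, so an
    # onset at t = 2.0 s lands in measure 1 at slot round(0.5 / 1.5 * 96) = 32.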
    # Step 2: Build the measure/slot grid
    if slot_events:
        max_measure_idx = max(m for m, _, _ in slot_events)
    else:
        max_measure_idx = -1
    measures = {i: [0] * quantize for i in range(max_measure_idx + 1)}
    # Step 3: Place Don/Ka notes, collect drumroll slots
    drumroll_slots = set()
    for m, s, typ in slot_events:
        if typ in (1, 2):
            measures[m][s] = typ
        elif typ == 5:
            drumroll_slots.add((m, s))
    # Step 4: Group drumroll slots into contiguous regions, mark 5 (start) and 8 (end)
    drumroll_list = sorted(drumroll_slots)
    # Group into regions, allowing a gap of up to 5 slots between hits
    grouped = []
    group = []
    for ms in drumroll_list:
        if not group:
            group = [ms]
        else:
            last_m, last_s = group[-1]
            m, s = ms
            # Slot distance, accounting for a measure boundary between hits
            if m == last_m:
                slot_dist = s - last_s
            elif m == last_m + 1:
                slot_dist = (quantize - last_s) + s
            else:
                slot_dist = None
            # A gap of up to 5 slots (slot_dist <= 6) keeps the region going
            if slot_dist is not None and 1 <= slot_dist <= 6:
                group.append(ms)
            else:
                grouped.append(group)
                group = [ms]
    if group:
        grouped.append(group)
    # Mark 5 (start) and 8 (end) for each region
    for region in grouped:
        if len(region) == 1:
            m, s = region[0]
            measures[m][s] = 5
            # Place the 8 terminator in the next slot (or the next measure if at the end)
            if s < quantize - 1:
                measures[m][s + 1] = 8
            elif m < max_measure_idx:
                measures[m + 1][0] = 8
        else:
            m_start, s_start = region[0]
            m_end, s_end = region[-1]
            measures[m_start][s_start] = 5
            measures[m_end][s_end] = 8
            # Middle slots remain 0 by default
    # Step 5: Generate the TJA content
    tja_content = []
    tja_content.append(f"TITLE:{audio} (TC6, {time.strftime('%Y-%m-%d %H:%M:%S')})")
    tja_content.append(f"BPM:{bpm}")
    tja_content.append(f"WAVE:{audio}")
    tja_content.append(f"OFFSET:{offset}")
    tja_content.append("COURSE:Oni\nLEVEL:9\n")
    tja_content.append("#START")
    for i in range(max_measure_idx + 1):
        notes = measures.get(i, [0] * quantize)
        line = "".join(str(n) for n in notes)
        tja_content.append(line + ",")
    tja_content.append("#END")
    tja_string = "\n".join(tja_content)
    # Write to file as well if out_path is provided
    if out_path:
        with open(out_path, "w", encoding="utf-8") as f:
            f.write(tja_string)
        print(f"TJA chart saved to {out_path}")
    return tja_string
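

# --- EXAMPLE USAGE ---
# A minimal end-to-end sketch, assuming `model` is already loaded and that the
# model's output frame rate equals the mel frame rate (see the note in
# plot_results). File names are placeholders:
#
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   mel = preprocess_audio("audio.wav")
#   don, ka, drum = run_inference(
#       model, mel, torch.tensor(8.0), torch.tensor(3), torch.tensor(9), device
#   )
#   hop_sec = HOP_LENGTH / SAMPLE_RATE  # assumes no time downsampling
#   onsets = decode_onsets(don, ka, drum, hop_sec, threshold=0.5)
#   plot_results(mel, don, ka, drum, onsets, hop_sec, out_path="chart.png")
#   write_tja(onsets, out_path="chart.tja", bpm=160, audio="audio.wav")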