Commit 812b01c · 1 Parent(s): 7948b62
Implement TaikoConformer7 model, loss function, preprocessing, and training pipeline
- Added TaikoLoss class for custom loss calculation with NPS penalties (a hedged sketch of the idea follows this list).
- Developed TaikoConformer7 model architecture using Conformer and CNN layers.
- Created preprocessing functions to handle audio data and generate labels.
- Implemented training script with data loading, model training, and validation.
- Integrated TensorBoard logging for loss and energy comparisons during training.
- Added support for sliding NPS labels in preprocessing and loss calculation.
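The tc7 loss file is listed among the changes below, but its diff is not rendered on this page, so the following is only a rough, hypothetical sketch of what an NPS penalty term can look like; the function name, thresholding proxy, and weighting are assumptions, not the commit's actual TaikoLoss:

import torch.nn.functional as F

def nps_penalty(pred_energies, target_nps, hop_sec, threshold=0.5):
    # Hypothetical sketch, NOT the commit's TaikoLoss. Crude proxy: count
    # frames where any energy channel exceeds the threshold, convert to a
    # notes-per-second estimate, and penalize deviation from the
    # conditioning NPS target.
    pred_notes = (pred_energies > threshold).any(dim=2).float().sum(dim=1)  # (B,)
    duration_sec = pred_energies.shape[1] * hop_sec
    pred_nps = pred_notes / duration_sec
    return F.mse_loss(pred_nps, target_nps)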
- .gitignore +1 -0
- app.py +300 -0
- requirements.txt +14 -0
- tc5/__init__.py +0 -0
- tc5/config.py +25 -0
- tc5/dataset.py +21 -0
- tc5/infer.py +356 -0
- tc5/loss.py +65 -0
- tc5/model.py +133 -0
- tc5/preprocess.py +215 -0
- tc5/train.py +323 -0
- tc6/__init__.py +0 -0
- tc6/config.py +25 -0
- tc6/dataset.py +21 -0
- tc6/infer.py +354 -0
- tc6/loss.py +65 -0
- tc6/model.py +166 -0
- tc6/preprocess.py +258 -0
- tc6/train.py +336 -0
- tc7/__init__.py +0 -0
- tc7/config.py +27 -0
- tc7/dataset.py +21 -0
- tc7/infer.py +354 -0
- tc7/loss.py +94 -0
- tc7/model.py +166 -0
- tc7/preprocess.py +400 -0
- tc7/train.py +300 -0
.gitignore
ADDED
@@ -0,0 +1 @@
__pycache__
app.py
ADDED
@@ -0,0 +1,300 @@
import gradio as gr
import torch
from tc5.config import SAMPLE_RATE, HOP_LENGTH
from tc5.model import TaikoConformer5
from tc5 import infer as tc5infer
from tc6.model import TaikoConformer6
from tc6 import infer as tc6infer
from tc7.model import TaikoConformer7
from tc7 import infer as tc7infer
from gradio_client import Client, handle_file
import tempfile

DEVICE = torch.device("cpu")

# Load model once
tc5 = TaikoConformer5.from_pretrained("JacobLinCool/taiko-conformer-5")
tc5.to(DEVICE)
tc5.eval()

# Load TC6 model
tc6 = TaikoConformer6.from_pretrained("JacobLinCool/taiko-conformer-6")
tc6.to(DEVICE)
tc6.eval()

# Load TC7 model
tc7 = TaikoConformer7.from_pretrained("JacobLinCool/taiko-conformer-7")
tc7.to(DEVICE)
tc7.eval()

synthesizer = Client("ryanlinjui/taiko-music-generator")


def infer_tc5(audio, nps, bpm):
    audio_path = audio
    filename = audio_path.split("/")[-1]
    # Preprocess
    mel_input, nps_input = tc5infer.preprocess_audio(audio_path, nps)
    # Inference
    don_energy, ka_energy, drumroll_energy = tc5infer.run_inference(
        tc5, mel_input, nps_input, DEVICE
    )
    output_frame_hop_sec = HOP_LENGTH / SAMPLE_RATE
    onsets = tc5infer.decode_onsets(
        don_energy,
        ka_energy,
        drumroll_energy,
        output_frame_hop_sec,
        threshold=0.3,
        min_distance_frames=3,
    )
    # Generate plot
    plot = tc5infer.plot_results(
        mel_input,
        don_energy,
        ka_energy,
        drumroll_energy,
        onsets,
        output_frame_hop_sec,
    )
    # Generate TJA content
    tja_content = tc5infer.write_tja(onsets, bpm=bpm, audio=filename)

    # write TJA content to a temporary file
    with tempfile.NamedTemporaryFile(delete=False, suffix=".tja") as temp_tja_file:
        temp_tja_file.write(tja_content.encode("utf-8"))
        tja_path = temp_tja_file.name

    result = synthesizer.predict(
        param_0=handle_file(tja_path),
        param_1=handle_file(audio_path),
        param_2="達人譜面 / Master",
        param_3=16,
        param_4=5,
        param_5=5,
        param_6=5,
        param_7=5,
        param_8=5,
        param_9=5,
        param_10=5,
        param_11=5,
        param_12=5,
        param_13=5,
        param_14=5,
        param_15=5,
        api_name="/handle",
    )

    oni_audio = result[1]

    return oni_audio, plot, tja_content


def infer_tc6(audio, nps, bpm, difficulty, level):
    audio_path = audio
    filename = audio_path.split("/")[-1]
    # Preprocess
    mel_input = tc6infer.preprocess_audio(audio_path)
    nps_input = torch.tensor(nps, dtype=torch.float32).to(DEVICE)
    difficulty_input = torch.tensor(difficulty, dtype=torch.float32).to(DEVICE)
    level_input = torch.tensor(level, dtype=torch.float32).to(DEVICE)
    # Inference
    don_energy, ka_energy, drumroll_energy = tc6infer.run_inference(
        tc6, mel_input, nps_input, difficulty_input, level_input, DEVICE
    )
    output_frame_hop_sec = HOP_LENGTH / SAMPLE_RATE
    onsets = tc6infer.decode_onsets(
        don_energy,
        ka_energy,
        drumroll_energy,
        output_frame_hop_sec,
        threshold=0.3,
        min_distance_frames=3,
    )
    # Generate plot
    plot = tc6infer.plot_results(
        mel_input,
        don_energy,
        ka_energy,
        drumroll_energy,
        onsets,
        output_frame_hop_sec,
    )
    # Generate TJA content
    tja_content = tc6infer.write_tja(onsets, bpm=bpm, audio=filename)

    # write TJA content to a temporary file
    with tempfile.NamedTemporaryFile(delete=False, suffix=".tja") as temp_tja_file:
        temp_tja_file.write(tja_content.encode("utf-8"))
        tja_path = temp_tja_file.name

    result = synthesizer.predict(
        param_0=handle_file(tja_path),
        param_1=handle_file(audio_path),
        param_2="達人譜面 / Master",
        param_3=16,
        param_4=5,
        param_5=5,
        param_6=5,
        param_7=5,
        param_8=5,
        param_9=5,
        param_10=5,
        param_11=5,
        param_12=5,
        param_13=5,
        param_14=5,
        param_15=5,
        api_name="/handle",
    )

    oni_audio = result[1]

    return oni_audio, plot, tja_content


def infer_tc7(audio, nps, bpm, difficulty, level):
    audio_path = audio
    filename = audio_path.split("/")[-1]
    # Preprocess
    mel_input = tc7infer.preprocess_audio(audio_path)
    nps_input = torch.tensor(nps, dtype=torch.float32).to(DEVICE)
    difficulty_input = torch.tensor(difficulty, dtype=torch.float32).to(DEVICE)
    level_input = torch.tensor(level, dtype=torch.float32).to(DEVICE)
    # Inference
    don_energy, ka_energy, drumroll_energy = tc7infer.run_inference(
        tc7, mel_input, nps_input, difficulty_input, level_input, DEVICE
    )
    output_frame_hop_sec = HOP_LENGTH / SAMPLE_RATE
    onsets = tc7infer.decode_onsets(
        don_energy,
        ka_energy,
        drumroll_energy,
        output_frame_hop_sec,
        threshold=0.3,
        min_distance_frames=3,
    )
    # Generate plot
    plot = tc7infer.plot_results(
        mel_input,
        don_energy,
        ka_energy,
        drumroll_energy,
        onsets,
        output_frame_hop_sec,
    )
    # Generate TJA content
    tja_content = tc7infer.write_tja(onsets, bpm=bpm, audio=filename)

    # write TJA content to a temporary file
    with tempfile.NamedTemporaryFile(delete=False, suffix=".tja") as temp_tja_file:
        temp_tja_file.write(tja_content.encode("utf-8"))
        tja_path = temp_tja_file.name

    result = synthesizer.predict(
        param_0=handle_file(tja_path),
        param_1=handle_file(audio_path),
        param_2="達人譜面 / Master",
        param_3=16,
        param_4=5,
        param_5=5,
        param_6=5,
        param_7=5,
        param_8=5,
        param_9=5,
        param_10=5,
        param_11=5,
        param_12=5,
        param_13=5,
        param_14=5,
        param_15=5,
        api_name="/handle",
    )

    oni_audio = result[1]

    return oni_audio, plot, tja_content


def run_inference(audio, model_choice, nps, bpm, difficulty, level):
    if model_choice == "TC5":
        return infer_tc5(audio, nps, bpm)
    elif model_choice == "TC6":
        return infer_tc6(audio, nps, bpm, difficulty, level)
    else:  # TC7
        return infer_tc7(audio, nps, bpm, difficulty, level)


with gr.Blocks() as demo:
    gr.Markdown("# Taiko Conformer 5/6/7 Demo")
    with gr.Row():
        audio_input = gr.Audio(sources="upload", type="filepath", label="Input Audio")

    with gr.Row():
        model_choice = gr.Dropdown(
            choices=["TC5", "TC6", "TC7"],
            value="TC7",
            label="Model Selection",
            info="Choose between TaikoConformer 5, 6 or 7",
        )

    with gr.Row():
        nps = gr.Slider(
            value=5.0,
            minimum=0.5,
            maximum=11.0,
            step=0.5,
            label="NPS (Notes Per Second)",
        )
        bpm = gr.Slider(
            value=240,
            minimum=160,
            maximum=640,
            step=1,
            label="BPM (Used by TJA Quantization)",
        )

    with gr.Row():
        difficulty = gr.Slider(
            value=3.0,
            minimum=1.0,
            maximum=3.0,
            step=1.0,
            label="Difficulty",
            visible=False,
            info="1=Normal, 2=Hard, 3=Oni",
        )
        level = gr.Slider(
            value=8.0,
            minimum=1.0,
            maximum=10.0,
            step=1.0,
            label="Level",
            visible=False,
            info="Difficulty level from 1 to 10",
        )

    audio_output = gr.Audio(label="Generated Audio", type="filepath")
    plot_output = gr.Plot(label="Onset/Energy Plot")
    tja_output = gr.Textbox(label="TJA File Content", show_copy_button=True)
    run_btn = gr.Button("Run Inference")

    # Update visibility of TC6/TC7-specific controls based on model selection
    def update_visibility(model_choice):
        if model_choice == "TC7" or model_choice == "TC6":
            return gr.update(visible=True), gr.update(visible=True)
        else:
            return gr.update(visible=False), gr.update(visible=False)

    model_choice.change(
        update_visibility, inputs=[model_choice], outputs=[difficulty, level]
    )

    run_btn.click(
        run_inference,
        inputs=[audio_input, model_choice, nps, bpm, difficulty, level],
        outputs=[audio_output, plot_output, tja_output],
    )

if __name__ == "__main__":
    demo.launch()
requirements.txt
ADDED
@@ -0,0 +1,14 @@
torch
torchaudio
datasets
huggingface_hub
librosa
soundfile
matplotlib
tensorboard
black
tqdm
safetensors
accelerate
tja
spaces
tc5/__init__.py
ADDED
File without changes
tc5/config.py
ADDED
@@ -0,0 +1,25 @@
import torch

# ─── 1) CONFIG ─────────────────────────────────────────────────────
SAMPLE_RATE = 22050
N_MELS = 80
HOP_LENGTH = 256  # ~86 fps
TIME_SUB = 1
CNN_CH = 128
N_HEADS = 4
D_MODEL = 256
FF_DIM = 512
N_LAYERS = 4
DEPTHWISE_CONV_KERNEL_SIZE = 31
DROPOUT = 0.1
HIDDEN_DIM = 64
N_TYPES = 7
BATCH_SIZE = 4
GRAD_ACCUM_STEPS = 4
LR = 3e-4
EPOCHS = 30
DEVICE = (
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)
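A quick sanity check of the frame rate implied by these values (illustrative, not part of the commit):

from tc5.config import SAMPLE_RATE, HOP_LENGTH

fps = SAMPLE_RATE / HOP_LENGTH               # ≈ 86.13, matching the "~86 fps" comment
frame_ms = 1000 * HOP_LENGTH / SAMPLE_RATE   # ≈ 11.61 ms per output frame
print(f"{fps:.2f} fps, {frame_ms:.2f} ms/frame")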
tc5/dataset.py
ADDED
@@ -0,0 +1,21 @@
from datasets import load_dataset, concatenate_datasets

# ds1 = load_dataset("JacobLinCool/taiko-2023-1.1", split="train")
# ds2 = load_dataset("JacobLinCool/taiko-2023-1.2", split="train")
# ds3 = load_dataset("JacobLinCool/taiko-2023-1.3", split="train")
# ds4 = load_dataset("JacobLinCool/taiko-2023-1.4", split="train")
# ds5 = load_dataset("JacobLinCool/taiko-2023-1.5", split="train")
# ds6 = load_dataset("JacobLinCool/taiko-2023-1.6", split="train")
# ds7 = load_dataset("JacobLinCool/taiko-2023-1.7", split="train")
# ds = concatenate_datasets([ds1, ds2, ds3, ds4, ds5, ds6, ds7]).with_format("torch")

# good = list(range(len(ds)))
# good.remove(1079)  # 1079 has file problem
# ds = ds.select(good)

# for local test
ds = (
    load_dataset("JacobLinCool/taiko-2023-1.6", split="train")
    .with_format("torch")
    .select(range(10))
)
tc5/infer.py
ADDED
@@ -0,0 +1,356 @@
import time
import torch
import torchaudio
import matplotlib.pyplot as plt
import numpy as np
from .config import SAMPLE_RATE, N_MELS, HOP_LENGTH
import torch.profiler


# --- PREPROCESSING (match training) ---
def preprocess_audio(audio_path, nps=5.0):
    wav, sr = torchaudio.load(audio_path)
    wav = wav.mean(dim=0)  # mono
    if sr != SAMPLE_RATE:
        wav = torchaudio.functional.resample(wav, sr, SAMPLE_RATE)
    wav = wav / (wav.abs().max() + 1e-8)  # Normalize audio

    nps_tensor = torch.tensor(nps, dtype=torch.float32)

    mel_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=SAMPLE_RATE,
        n_mels=N_MELS,
        hop_length=HOP_LENGTH,
        n_fft=2048,
    )
    mel = mel_transform(wav)
    # mel shape is (n_mels, T_mel); unsqueeze for batch later in run_inference
    return mel, nps_tensor  # mel is (N_MELS, T_mel)


# --- INFERENCE ---
def run_inference(model, mel_input, nps_input, device):
    model.eval()
    with torch.no_grad():
        mel = mel_input.to(device).unsqueeze(0)  # (1, N_MELS, T_mel)
        nps = nps_input.to(device).unsqueeze(0)  # (1,)

        mel_cnn_input = mel.unsqueeze(1)  # (1, 1, N_MELS, T_mel)

        conformer_lengths = torch.tensor(
            [mel_cnn_input.shape[-1]], dtype=torch.long, device=device
        )

        with torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                *(
                    [torch.profiler.ProfilerActivity.CUDA]
                    if device.type == "cuda"
                    else []
                ),
            ],
            record_shapes=True,
            profile_memory=True,
            with_stack=False,
            with_flops=True,
        ) as prof:
            out_dict = model(mel_cnn_input, conformer_lengths, nps)
        print(
            prof.key_averages().table(
                sort_by=(
                    "self_cuda_memory_usage"
                    if device.type == "cuda"
                    else "self_cpu_time_total"
                ),
                row_limit=20,
            )
        )

        energies = out_dict["presence"].squeeze(0).cpu().numpy()

        don_energy = energies[:, 0]
        ka_energy = energies[:, 1]
        drumroll_energy = energies[:, 2]

        return don_energy, ka_energy, drumroll_energy


# --- DECODE TO ONSETS ---
def decode_onsets(
    don_energy,
    ka_energy,
    drumroll_energy,
    hop_sec,
    threshold=0.5,
    min_distance_frames=3,
):
    results = []
    T_out = len(don_energy)
    last_onset_frame = -min_distance_frames

    for i in range(1, T_out - 1):  # Iterate considering neighbors for peak detection
        if i < last_onset_frame + min_distance_frames:
            continue

        e_don, e_ka, e_drum = don_energy[i], ka_energy[i], drumroll_energy[i]
        energies_at_i = {
            1: e_don,
            2: e_ka,
            5: e_drum,
        }  # Type mapping: 1:Don, 2:Ka, 5:Drumroll

        # Find which energy is max and whether it is a peak above threshold.
        # Sort by energy value descending to prioritize higher energy in case
        # of ties for the peak condition.
        sorted_types_by_energy = sorted(
            energies_at_i.keys(), key=lambda x: energies_at_i[x], reverse=True
        )

        detected_this_frame = False
        for onset_type in sorted_types_by_energy:
            current_energy_series = None
            if onset_type == 1:
                current_energy_series = don_energy
            elif onset_type == 2:
                current_energy_series = ka_energy
            elif onset_type == 5:
                current_energy_series = drumroll_energy

            energy_val = current_energy_series[i]

            if (
                energy_val > threshold
                and energy_val > current_energy_series[i - 1]
                and energy_val > current_energy_series[i + 1]
            ):
                # The "highest energy wins" rule is implicitly handled by
                # iterating sorted_types_by_energy and breaking after the
                # first detection.
                results.append((i * hop_sec, onset_type))
                last_onset_frame = i
                detected_this_frame = True
                break  # Only one onset type per frame

    return results


# --- VISUALIZATION ---
def plot_results(
    mel_spectrogram,
    don_energy,
    ka_energy,
    drumroll_energy,
    onsets,
    hop_sec,
    out_path=None,
):
    # mel_spectrogram is (N_MELS, T_mel)
    T_mel = mel_spectrogram.shape[1]
    T_out = len(don_energy)  # Length of energy arrays (model output time dimension)

    # Time axes
    time_axis_mel = np.arange(T_mel) * (HOP_LENGTH / SAMPLE_RATE)
    # The CNN frontend does not stride in time, so the model output length
    # T_out equals T_mel and each output frame spans HOP_LENGTH / SAMPLE_RATE
    # seconds. TIME_SUB only affects label generation in preprocess.py, not
    # the feature rate seen by the Conformer, so the hop_sec passed in by the
    # caller is the right scale for the energy time axis.
    time_axis_energies = np.arange(T_out) * hop_sec

    fig, ax1 = plt.subplots(figsize=(100, 10))

    # Plot Mel Spectrogram on ax1
    mel_db = torchaudio.functional.amplitude_to_DB(
        mel_spectrogram, multiplier=10.0, amin=1e-10, db_multiplier=0.0
    )
    img = ax1.imshow(
        mel_db.numpy(),
        aspect="auto",
        origin="lower",
        cmap="magma",
        extent=[time_axis_mel[0], time_axis_mel[-1], 0, N_MELS],
    )
    ax1.set_title("Mel Spectrogram with Predicted Energies and Onsets")
    ax1.set_xlabel("Time (s)")
    ax1.set_ylabel("Mel Bin")
    fig.colorbar(img, ax=ax1, format="%+2.0f dB")

    # Create a second y-axis for energies
    ax2 = ax1.twinx()
    ax2.plot(time_axis_energies, don_energy, label="Don Energy", color="red")
    ax2.plot(time_axis_energies, ka_energy, label="Ka Energy", color="blue")
    ax2.plot(
        time_axis_energies, drumroll_energy, label="Drumroll Energy", color="green"
    )
    ax2.set_ylabel("Energy")
    ax2.set_ylim(0, 1.2)  # Assuming energies are somewhat normalized or bounded

    # Overlay onsets from decode_onsets (t is already in seconds)
    labeled_types = set()
    # Group drumrolls into segments (reuse logic from write_tja)
    drumroll_times = [t_sec for t_sec, typ in onsets if typ == 5]
    drumroll_times.sort()
    drumroll_segments = []
    if drumroll_times:
        seg_start = drumroll_times[0]
        prev = drumroll_times[0]
        for t in drumroll_times[1:]:
            if t - prev <= hop_sec * 6:  # up to 5-frame gap
                prev = t
            else:
                drumroll_segments.append((seg_start, prev))
                seg_start = t
                prev = t
        drumroll_segments.append((seg_start, prev))
    # Plot Don/Ka onsets as vertical lines
    for t_sec, typ in onsets:
        if typ == 5:
            continue  # skip drumroll onsets
        color_map = {1: "darkred", 2: "darkblue"}
        label_map = {1: "Don Onset", 2: "Ka Onset"}
        line_color = color_map.get(typ, "black")
        line_label = label_map.get(typ, f"Type {typ} Onset")
        if typ not in labeled_types:
            ax1.axvline(
                t_sec, color=line_color, linestyle="--", alpha=0.9, label=line_label
            )
            labeled_types.add(typ)
        else:
            ax1.axvline(t_sec, color=line_color, linestyle="--", alpha=0.9)
    # Plot drumroll segments as shaded regions
    for seg_start, seg_end in drumroll_segments:
        ax1.axvspan(
            seg_start,
            seg_end + hop_sec,
            color="green",
            alpha=0.2,
            label="Drumroll Segment" if "drumroll" not in labeled_types else None,
        )
        labeled_types.add("drumroll")

    # Combine legends from both axes
    lines, labels = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax2.legend(lines + lines2, labels + labels2, loc="upper right")

    fig.tight_layout()

    # Save to file if a path is provided; otherwise return the figure
    if out_path:
        plt.savefig(out_path)
        print(f"Saved plot to {out_path}")
        plt.close(fig)
        return out_path
    else:
        # Return plot as an in-memory figure
        return fig


def write_tja(onsets, out_path=None, bpm=160, quantize=96, audio="audio.wav"):
    # TJA types: 0:no note, 1:Don, 2:Ka, 3:BigDon, 4:BigKa, 5:DrumrollStart, 8:DrumrollEnd
    # Model output types: 1:Don, 2:Ka, 5:Drumroll (interpreted as start/single)
    sec_per_beat = 60 / bpm
    beats_per_measure = 4  # Assuming 4/4 time signature
    sec_per_measure = sec_per_beat * beats_per_measure
    # Step 1: Map onsets to (measure_idx, slot, typ)
    slot_events = []
    for t, typ in onsets:
        measure_idx = int(t // sec_per_measure)
        t_in_measure = t % sec_per_measure
        slot = int(round(t_in_measure / sec_per_measure * quantize))
        if slot >= quantize:
            slot = quantize - 1
        slot_events.append((measure_idx, slot, typ))
    # Step 2: Build measure/slot grid
    if slot_events:
        max_measure_idx = max(m for m, _, _ in slot_events)
    else:
        max_measure_idx = -1
    measures = {i: [0] * quantize for i in range(max_measure_idx + 1)}
    # Step 3: Place Don/Ka, collect drumrolls
    drumroll_slots = set()
    for m, s, typ in slot_events:
        if typ in [1, 2]:
            measures[m][s] = typ
        elif typ == 5:
            drumroll_slots.add((m, s))
    # Step 4: Process drumrolls into contiguous regions, mark 5 (start) and 8 (end)
    # Flatten all slots to a sorted list of (measure, slot)
    drumroll_list = sorted(list(drumroll_slots))
    # Group into contiguous regions (allowing a gap of 5 slots)
    grouped = []
    group = []
    for ms in drumroll_list:
        if not group:
            group = [ms]
        else:
            last_m, last_s = group[-1]
            m, s = ms
            # Calculate slot distance, considering measure wrap
            slot_dist = None
            if m == last_m:
                slot_dist = s - last_s
            elif m == last_m + 1 and last_s <= quantize - 1:
                slot_dist = (quantize - 1 - last_s) + s + 1
            else:
                slot_dist = None
            # Allow gap of up to 5 slots (slot_dist <= 6)
            if slot_dist is not None and 1 <= slot_dist <= 6:
                group.append(ms)
            else:
                grouped.append(group)
                group = [ms]
    if group:
        grouped.append(group)
    # Mark 5 (start) and 8 (end) for each group
    for region in grouped:
        if len(region) == 1:
            m, s = region[0]
            measures[m][s] = 5
            # Place 8 in next slot (or next measure if at end)
            if s < quantize - 1:
                measures[m][s + 1] = 8
            elif m < max_measure_idx:
                measures[m + 1][0] = 8
        else:
            m_start, s_start = region[0]
            m_end, s_end = region[-1]
            measures[m_start][s_start] = 5
            measures[m_end][s_end] = 8
            # Middle slots stay 0 (already 0 by default)

    # Step 5: Generate TJA content
    tja_content = []
    tja_content.append(f"TITLE:{audio} (TC5, {time.strftime('%Y-%m-%d %H:%M:%S')})")
    tja_content.append(f"BPM:{bpm}")
    tja_content.append(f"WAVE:{audio}")
    tja_content.append("OFFSET:0")
    tja_content.append("COURSE:Oni\nLEVEL:9\n")
    tja_content.append("#START")
    for i in range(max_measure_idx + 1):
        notes = measures.get(i, [0] * quantize)
        line = "".join(str(n) for n in notes)
        tja_content.append(line + ",")
    tja_content.append("#END")

    tja_string = "\n".join(tja_content)

    # If out_path is provided, also write to file
    if out_path:
        with open(out_path, "w", encoding="utf-8") as f:
            f.write(tja_string)
        print(f"TJA chart saved to {out_path}")

    return tja_string
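Taken together, these helpers form the tc5 path that app.py wires into the demo. A minimal end-to-end sketch (illustrative; "song.ogg" is a placeholder filename, and the checkpoint is the one app.py loads):

import torch
from tc5.config import SAMPLE_RATE, HOP_LENGTH
from tc5.model import TaikoConformer5
from tc5 import infer

device = torch.device("cpu")
model = TaikoConformer5.from_pretrained("JacobLinCool/taiko-conformer-5").to(device)

mel, nps = infer.preprocess_audio("song.ogg", nps=5.0)  # "song.ogg" is hypothetical
don, ka, drumroll = infer.run_inference(model, mel, nps, device)
hop_sec = HOP_LENGTH / SAMPLE_RATE
onsets = infer.decode_onsets(don, ka, drumroll, hop_sec, threshold=0.3, min_distance_frames=3)
print(infer.write_tja(onsets, bpm=240, audio="song.ogg")[:200])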
tc5/loss.py
ADDED
@@ -0,0 +1,65 @@
import torch
import torch.nn as nn


class TaikoEnergyLoss(nn.Module):
    def __init__(self, reduction="mean"):
        super().__init__()
        # Use 'none' reduction to get element-wise losses, then manually apply masking and reduction
        self.mse_loss = nn.MSELoss(reduction="none")
        self.reduction = reduction

    def forward(self, outputs, batch):
        """
        Calculates the MSE loss for energy-based predictions.

        Args:
            outputs (dict): Model output, containing 'presence' tensor.
                outputs['presence'] shape: (B, T, 3) for don, ka, drumroll energies.
            batch (dict): Batch data from collate_fn, containing true labels and lengths.
                batch['don_labels'], batch['ka_labels'], batch['drumroll_labels'] shape: (B, T)
                batch['lengths'] shape: (B,) - valid sequence lengths for time dimension T.
        Returns:
            torch.Tensor: The calculated loss.
        """
        pred_energies = outputs["presence"]  # (B, T, 3)

        true_don = batch["don_labels"]  # (B, T)
        true_ka = batch["ka_labels"]  # (B, T)
        true_drumroll = batch["drumroll_labels"]  # (B, T)

        # Stack true labels to match the structure of pred_energies (B, T, 3)
        true_energies = torch.stack([true_don, true_ka, true_drumroll], dim=2)

        B, T, _ = pred_energies.shape

        # Create a mask based on batch['lengths'] to ignore padded parts of
        # sequences; batch['lengths'] gives the actual length of each sequence
        # in the batch. mask shape: (B, T)
        mask_2d = (
            torch.arange(T, device=pred_energies.device).expand(B, T)
            < batch["lengths"].unsqueeze(1)
        )
        # Expand mask to (B, T, 1) to broadcast across the 3 energy channels
        mask_3d = mask_2d.unsqueeze(2)

        # Calculate element-wise MSE loss
        loss_elementwise = self.mse_loss(pred_energies, true_energies)  # (B, T, 3)

        # Apply the mask to the loss
        masked_loss = loss_elementwise * mask_3d

        if self.reduction == "mean":
            # Sum the loss over all valid (unmasked) elements and divide by the number of valid elements
            total_loss = masked_loss.sum()
            num_valid_elements = mask_3d.sum()  # Total number of unmasked float values
            if num_valid_elements > 0:
                return total_loss / num_valid_elements
            else:
                # Avoid division by zero if there are no valid elements (e.g., empty batch or all lengths are 0)
                return torch.tensor(
                    0.0, device=pred_energies.device, requires_grad=True
                )
        elif self.reduction == "sum":
            return masked_loss.sum()
        else:  # 'none' or any other case
            return masked_loss
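A small smoke test of the masking behavior with dummy tensors (illustrative only, not part of the commit):

import torch
from tc5.loss import TaikoEnergyLoss

B, T = 2, 8
outputs = {"presence": torch.rand(B, T, 3)}
batch = {
    "don_labels": torch.rand(B, T),
    "ka_labels": torch.rand(B, T),
    "drumroll_labels": torch.rand(B, T),
    "lengths": torch.tensor([8, 5]),  # second sample has 3 padded frames
}
loss = TaikoEnergyLoss(reduction="mean")(outputs, batch)
print(loss)  # scalar; the padded frames of sample 2 are excluded from the mean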
tc5/model.py
ADDED
@@ -0,0 +1,133 @@
import torch
import torch.nn as nn
from torchaudio.models import Conformer
from huggingface_hub import PyTorchModelHubMixin
from .config import (
    N_MELS,
    CNN_CH,
    N_HEADS,
    D_MODEL,
    FF_DIM,
    N_LAYERS,
    DROPOUT,
    DEPTHWISE_CONV_KERNEL_SIZE,
    HIDDEN_DIM,
    DEVICE,
)


class TaikoConformer5(nn.Module, PyTorchModelHubMixin):
    def __init__(self):
        super().__init__()
        # 1) CNN frontend: frequency-only pooling
        self.cnn = nn.Sequential(
            nn.Conv2d(1, CNN_CH, 3, stride=(2, 1), padding=1),
            nn.BatchNorm2d(CNN_CH),
            nn.GELU(),
            nn.Dropout2d(DROPOUT),
            nn.Conv2d(CNN_CH, CNN_CH, 3, stride=(2, 1), padding=1),
            nn.BatchNorm2d(CNN_CH),
            nn.GELU(),
            nn.Dropout2d(DROPOUT),
        )
        feat_dim = CNN_CH * (N_MELS // 4)

        # 2) Linear projection to model dimension
        self.proj = nn.Linear(feat_dim, D_MODEL)

        # 3) FiLM conditioning for notes_per_second
        self.film = nn.Linear(1, 2 * D_MODEL)

        # 4) Conformer encoder
        self.encoder = Conformer(
            input_dim=D_MODEL,
            num_heads=N_HEADS,
            ffn_dim=FF_DIM,
            num_layers=N_LAYERS,
            depthwise_conv_kernel_size=DEPTHWISE_CONV_KERNEL_SIZE,
            dropout=DROPOUT,
            use_group_norm=False,
            convolution_first=False,
        )

        # 5) Presence regressor head
        self.presence_regressor = nn.Sequential(
            nn.Dropout(DROPOUT),
            nn.Linear(D_MODEL, HIDDEN_DIM),
            nn.GELU(),
            nn.Dropout(DROPOUT),
            nn.Linear(HIDDEN_DIM, 3),  # Don, Ka, DrumRoll energy
            nn.Sigmoid(),  # Output between 0 and 1
        )

        # 6) Initialize weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(
        self, mel: torch.Tensor, lengths: torch.Tensor, notes_per_second: torch.Tensor
    ):
        """
        Args:
            mel: (B, 1, N_MELS, T_mel)
            lengths: (B,) lengths after CNN
            notes_per_second: (B,) stream of control values
        Returns:
            Dict with:
                'presence': (B, T_cnn_out, 3)
                'lengths': lengths
        """
        # CNN frontend
        x = self.cnn(mel)  # (B, C, F, T)
        B, C, F, T = x.size()
        x = x.permute(0, 3, 1, 2).reshape(B, T, C * F)

        # Project to model dimension
        x = self.proj(x)  # (B, T, D_MODEL)

        # FiLM conditioning
        nps = notes_per_second.unsqueeze(-1)  # (B, 1)
        gamma_beta = self.film(nps)  # (B, 2*D_MODEL)
        gamma, beta = gamma_beta.chunk(2, dim=-1)
        x = gamma.unsqueeze(1) * x + beta.unsqueeze(1)

        # Conformer encoder
        x, _ = self.encoder(x, lengths=lengths)

        # Presence prediction
        presence = self.presence_regressor(x)
        return {"presence": presence, "lengths": lengths}


if __name__ == "__main__":
    model = TaikoConformer5().to(device=DEVICE)
    print(model)

    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"{name}: {param.numel():,}")

    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {params / 1e6:.2f}M")

    batch_size = 4
    mel_time_steps = 1024
    input_mel = torch.randn(batch_size, 1, N_MELS, mel_time_steps).to(DEVICE)

    conformer_lengths = torch.tensor(
        [mel_time_steps] * batch_size, dtype=torch.long
    ).to(DEVICE)

    notes_per_second_input = torch.tensor([10.0] * batch_size, dtype=torch.float32).to(
        DEVICE
    )

    output = model(input_mel, conformer_lengths, notes_per_second_input)
    print("Output shapes:")
    for key, value in output.items():
        print(f"{key}: {value.shape}")
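Beyond the built-in __main__ check, the frequency-only pooling can be verified in isolation: the two stride-(2, 1) convolutions halve the mel axis twice (80 → 40 → 20) while leaving the time axis untouched, so feat_dim = CNN_CH * (N_MELS // 4) = 128 * 20 = 2560 features per frame before the projection to D_MODEL = 256. A quick check (illustrative):

import torch
from tc5.model import TaikoConformer5

m = TaikoConformer5()
x = torch.randn(1, 1, 80, 100)  # (B, 1, N_MELS, T_mel)
print(m.cnn(x).shape)           # torch.Size([1, 128, 20, 100]): time axis preserved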
tc5/preprocess.py
ADDED
@@ -0,0 +1,215 @@
import math
import torch
import torch.nn as nn
import torchaudio
from torchaudio.transforms import FrequencyMasking
from .config import N_TYPES, SAMPLE_RATE, N_MELS, HOP_LENGTH, TIME_SUB
from .model import TaikoConformer5


mel_transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=SAMPLE_RATE,
    n_mels=N_MELS,
    hop_length=HOP_LENGTH,
    n_fft=2048,
)


freq_mask = FrequencyMasking(freq_mask_param=15)


def preprocess(example, difficulty="oni"):
    wav_tensor = example["audio"]["array"]
    sr = example["audio"]["sampling_rate"]
    # 1) load & resample
    if sr != SAMPLE_RATE:
        wav_tensor = torchaudio.functional.resample(wav_tensor, sr, SAMPLE_RATE)
    # normalize audio
    wav_tensor = wav_tensor / (wav_tensor.abs().max() + 1e-8)
    # add random Gaussian noise
    if torch.rand(1).item() < 0.5:
        wav_tensor = wav_tensor + 0.005 * torch.randn_like(wav_tensor)
    # 2) mel: (1, N_MELS, T)
    mel = mel_transform(wav_tensor).unsqueeze(0)
    # apply SpecAugment; no time masking, since we don't want the model to
    # predict notes at frames that are masked out
    mel = freq_mask(mel)
    _, _, T = mel.shape
    # 3) build label sequence of length ceil(T / TIME_SUB)
    T_sub = math.ceil(T / TIME_SUB)

    # Initialize energy-based labels for Don, Ka, Drumroll
    don_labels = torch.zeros(T_sub, dtype=torch.float32)
    ka_labels = torch.zeros(T_sub, dtype=torch.float32)
    drumroll_labels = torch.zeros(T_sub, dtype=torch.float32)

    # Define exponential decay tail parameters
    tail_length = 40  # number of frames for decay tail
    decay_rate = 8.0  # decay rate parameter, adjust as needed
    tail_kernel = torch.exp(
        -torch.arange(0, tail_length, dtype=torch.float32) / decay_rate
    )

    fps = SAMPLE_RATE / HOP_LENGTH
    num_valid_notes = 0
    for onset in example[difficulty]:
        typ, t_start, t_end, *_ = onset

        # Assuming N_TYPES in config is appropriately set (e.g., 7 or more)
        if typ < 1 or typ > N_TYPES:  # Filter out invalid types
            continue

        num_valid_notes += 1

        f = int(round(t_start.item() * fps))
        idx = f // TIME_SUB
        if 0 <= idx < T_sub:
            # Apply exponential decay kernel to the corresponding energy channel
            # Types 1 and 3 are Don
            if typ == 1 or typ == 3:
                for i, val in enumerate(tail_kernel):
                    target_idx = idx + i
                    if 0 <= target_idx < T_sub:
                        don_labels[target_idx] = max(
                            don_labels[target_idx].item(), val.item()
                        )
            # Types 2 and 4 are Ka
            elif typ == 2 or typ == 4:
                for i, val in enumerate(tail_kernel):
                    target_idx = idx + i
                    if 0 <= target_idx < T_sub:
                        ka_labels[target_idx] = max(
                            ka_labels[target_idx].item(), val.item()
                        )
            # Types 5, 6, 7 are Drumroll
            elif typ >= 5 and typ <= 7:
                f_end = int(round(t_end.item() * fps))
                idx_end = f_end // TIME_SUB

                for dr in range(idx, idx_end):
                    if 0 <= dr < T_sub:
                        drumroll_labels[dr] = 1.0

                for i, val in enumerate(tail_kernel):
                    target_idx = idx_end + i
                    if 0 <= target_idx < T_sub:
                        drumroll_labels[target_idx] = max(
                            drumroll_labels[target_idx].item(), val.item()
                        )

    duration_seconds = wav_tensor.shape[-1] / SAMPLE_RATE
    nps = num_valid_notes / duration_seconds if duration_seconds > 0 else 0.0
    print(
        f"Processed {num_valid_notes} notes in {duration_seconds:.2f} seconds, NPS: {nps:.2f}"
    )

    return {
        "mel": mel,
        "don_labels": don_labels,
        "ka_labels": ka_labels,
        "drumroll_labels": drumroll_labels,
        "nps": torch.tensor(nps, dtype=torch.float32),
        "duration_seconds": torch.tensor(duration_seconds, dtype=torch.float32),
    }


def collate_fn(batch):
    mels_list = [b["mel"].squeeze(0).transpose(0, 1) for b in batch]  # (T, N_MELS)
    # Extract the energy-based labels
    don_labels_list = [b["don_labels"] for b in batch]
    ka_labels_list = [b["ka_labels"] for b in batch]
    drumroll_labels_list = [b["drumroll_labels"] for b in batch]

    nps_list = [b["nps"] for b in batch]  # Extract NPS
    durations_list = [b["duration_seconds"] for b in batch]  # Extract durations

    # Pad mels
    padded_mels = nn.utils.rnn.pad_sequence(
        mels_list, batch_first=True
    )  # (B, T_max, N_MELS)
    # Reshape for CNN: (B, 1, N_MELS, T_max)
    reshaped_mels = padded_mels.transpose(1, 2).unsqueeze(1)

    # Simulate the CNN frontend to get the output time dimension
    dummy_model_for_shape_inference = TaikoConformer5()
    dummy_cnn = dummy_model_for_shape_inference.cnn
    with torch.no_grad():
        cnn_out = dummy_cnn(reshaped_mels)  # Use reshaped_mels that has batch dim
        _, _, _, T_cnn = cnn_out.shape

    padded_don_labels = []
    padded_ka_labels = []
    padded_drumroll_labels = []
    # lengths = []  # superseded by conformer_input_lengths below

    for i in range(len(batch)):
        d_labels = don_labels_list[i]
        k_labels = ka_labels_list[i]
        dr_labels = drumroll_labels_list[i]

        # Assuming all label types have the same original length
        item_original_T_sub = d_labels.shape[0]
        out_len = T_cnn  # Target length for labels is T_cnn

        # Pad or truncate don_labels
        if item_original_T_sub < out_len:
            pad_d = torch.full(
                (out_len - item_original_T_sub,),
                0,  # Pad with 0 for energy labels
                dtype=d_labels.dtype,
                device=d_labels.device,
            )
            padded_d = torch.cat([d_labels, pad_d], dim=0)
        else:
            padded_d = d_labels[:out_len]
        padded_don_labels.append(padded_d)

        # Pad or truncate ka_labels
        if item_original_T_sub < out_len:
            pad_k = torch.full(
                (out_len - item_original_T_sub,),
                0,  # Pad with 0 for energy labels
                dtype=k_labels.dtype,
                device=k_labels.device,
            )
            padded_k = torch.cat([k_labels, pad_k], dim=0)
        else:
            padded_k = k_labels[:out_len]
        padded_ka_labels.append(padded_k)

        # Pad or truncate drumroll_labels
        if item_original_T_sub < out_len:
            pad_dr = torch.full(
                (out_len - item_original_T_sub,),
                0,  # Pad with 0 for energy labels
                dtype=dr_labels.dtype,
                device=dr_labels.device,
            )
            padded_dr = torch.cat([dr_labels, pad_dr], dim=0)
        else:
            padded_dr = dr_labels[:out_len]
        padded_drumroll_labels.append(padded_dr)

    # Conformer input lengths: the CNN does not subsample in time, so T_cnn is
    # effectively the padded mel length. The TIME_SUB-based calculation below
    # is kept from the original logic (with TIME_SUB == 1 it reduces to the
    # raw mel lengths), clamped to T_cnn.
    conformer_input_lengths = [
        math.ceil(mels_list[i].shape[0] / TIME_SUB) for i in range(len(batch))
    ]
    conformer_input_lengths = torch.tensor(
        [min(l, T_cnn) for l in conformer_input_lengths], dtype=torch.long
    )

    return {
        "mel": reshaped_mels,
        "don_labels": torch.stack(padded_don_labels),
        "ka_labels": torch.stack(padded_ka_labels),
        "drumroll_labels": torch.stack(padded_drumroll_labels),
        "lengths": conformer_input_lengths,  # These are for the Conformer model
        "nps": torch.stack(nps_list),
        "durations": torch.stack(durations_list),
    }
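For intuition, the decay tail used for the onset labels spans 40 frames (~0.46 s at ~86 fps) and falls off as exp(-i / 8). Its first few values (illustrative check, not part of the commit):

import torch

tail_kernel = torch.exp(-torch.arange(0, 40, dtype=torch.float32) / 8.0)
print(tail_kernel[:5])  # tensor([1.0000, 0.8825, 0.7788, 0.6873, 0.6065])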
tc5/train.py
ADDED
@@ -0,0 +1,323 @@
from accelerate.utils import set_seed

set_seed(1024)


import math
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from datasets import concatenate_datasets
import matplotlib.pyplot as plt
import numpy as np
from .config import (
    BATCH_SIZE,
    DEVICE,
    EPOCHS,
    LR,
    GRAD_ACCUM_STEPS,
    HOP_LENGTH,
    SAMPLE_RATE,
)
from .model import TaikoConformer5
from .dataset import ds
from .preprocess import preprocess, collate_fn
from .loss import TaikoEnergyLoss
from huggingface_hub import upload_folder


# --- Helper function to log energy plots ---
def log_energy_plots_to_tensorboard(
    writer,
    tag_prefix,
    epoch,
    pred_don,
    pred_ka,
    pred_drumroll,
    true_don,
    true_ka,
    true_drumroll,
    valid_length,  # Actual valid length of the sequence (before padding)
    hop_sec,
):
    """
    Logs a plot of predicted vs. true energies for one sample to TensorBoard.
    Energies should be 1D numpy arrays for the single sample, up to valid_length.
    """
    # Ensure data is on CPU and converted to numpy, and select only the valid part
    pred_don = pred_don[:valid_length].detach().cpu().numpy()
    pred_ka = pred_ka[:valid_length].detach().cpu().numpy()
    pred_drumroll = pred_drumroll[:valid_length].detach().cpu().numpy()
    true_don = true_don[:valid_length].cpu().numpy()
    true_ka = true_ka[:valid_length].cpu().numpy()
    true_drumroll = true_drumroll[:valid_length].cpu().numpy()

    time_axis = np.arange(valid_length) * hop_sec

    fig, axs = plt.subplots(3, 1, figsize=(15, 10), sharex=True)
    fig.suptitle(f"{tag_prefix} - Epoch {epoch}", fontsize=16)

    axs[0].plot(time_axis, true_don, label="True Don", color="blue", linestyle="--")
    axs[0].plot(time_axis, pred_don, label="Pred Don", color="lightblue", alpha=0.8)
    axs[0].set_ylabel("Don Energy")
    axs[0].legend()
    axs[0].grid(True)

    axs[1].plot(time_axis, true_ka, label="True Ka", color="red", linestyle="--")
    axs[1].plot(time_axis, pred_ka, label="Pred Ka", color="lightcoral", alpha=0.8)
    axs[1].set_ylabel("Ka Energy")
    axs[1].legend()
    axs[1].grid(True)

    axs[2].plot(
        time_axis, true_drumroll, label="True Drumroll", color="green", linestyle="--"
    )
    axs[2].plot(
        time_axis, pred_drumroll, label="Pred Drumroll", color="lightgreen", alpha=0.8
    )
    axs[2].set_ylabel("Drumroll Energy")
    axs[2].set_xlabel("Time (s)")
    axs[2].legend()
    axs[2].grid(True)

    plt.tight_layout(rect=[0, 0, 1, 0.96])  # Adjust layout to make space for suptitle
    writer.add_figure(f"{tag_prefix}/Energy_Comparison", fig, epoch)
    plt.close(fig)


def main():
    global ds

    # Calculate hop seconds for model output frames
    # This assumes the model output time dimension corresponds to the mel spectrogram time dimension
    output_frame_hop_sec = HOP_LENGTH / SAMPLE_RATE

    best_val_loss = float("inf")
    patience = 10  # Increased patience a bit
    pat_count = 0

    ds_oni = ds.map(
        preprocess,
        remove_columns=ds.column_names,
        fn_kwargs={"difficulty": "oni"},
        writer_batch_size=10,
    )
    ds_hard = ds.map(
        preprocess,
        remove_columns=ds.column_names,
        fn_kwargs={"difficulty": "hard"},
        writer_batch_size=10,
    )
    ds_normal = ds.map(
        preprocess,
        remove_columns=ds.column_names,
        fn_kwargs={"difficulty": "normal"},
        writer_batch_size=10,
    )
    ds = concatenate_datasets([ds_oni, ds_hard, ds_normal])

    ds_train_test = ds.train_test_split(test_size=0.1, seed=42)
    train_loader = DataLoader(
        ds_train_test["train"],
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=2,
    )
    val_loader = DataLoader(
        ds_train_test["test"],
        batch_size=BATCH_SIZE,
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=2,
    )

    model = TaikoConformer5().to(DEVICE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

    criterion = TaikoEnergyLoss(reduction="mean").to(DEVICE)

    # Adjust scheduler steps for gradient accumulation
    num_optimizer_steps_per_epoch = math.ceil(len(train_loader) / GRAD_ACCUM_STEPS)
    total_optimizer_steps = EPOCHS * num_optimizer_steps_per_epoch

    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=LR, total_steps=total_optimizer_steps
    )

    writer = SummaryWriter()

    for epoch in range(1, EPOCHS + 1):
        model.train()
        total_epoch_loss = 0.0
        optimizer.zero_grad()

        for idx, batch in enumerate(tqdm(train_loader, desc=f"Train Epoch {epoch}")):
            mel = batch["mel"].to(DEVICE)
| 158 |
+
# Unpack new energy-based labels
|
| 159 |
+
don_labels = batch["don_labels"].to(DEVICE)
|
| 160 |
+
ka_labels = batch["ka_labels"].to(DEVICE)
|
| 161 |
+
drumroll_labels = batch["drumroll_labels"].to(DEVICE)
|
| 162 |
+
lengths = batch["lengths"].to(
|
| 163 |
+
DEVICE
|
| 164 |
+
) # These are for the Conformer model output
|
| 165 |
+
nps = batch["nps"].to(DEVICE)
|
| 166 |
+
|
| 167 |
+
output_dict = model(mel, lengths, nps)
|
| 168 |
+
# output_dict["presence"] is now (B, T_out, 3) for don, ka, drumroll energies
|
| 169 |
+
pred_energies_batch = output_dict["presence"] # (B, T_out, 3)
|
| 170 |
+
|
| 171 |
+
loss_input_batch = {
|
| 172 |
+
"don_labels": don_labels,
|
| 173 |
+
"ka_labels": ka_labels,
|
| 174 |
+
"drumroll_labels": drumroll_labels,
|
| 175 |
+
"lengths": lengths, # Pass lengths for masking within the loss function
|
| 176 |
+
}
|
| 177 |
+
loss = criterion(output_dict, loss_input_batch)
|
| 178 |
+
|
| 179 |
+
(loss / GRAD_ACCUM_STEPS).backward()
|
| 180 |
+
|
| 181 |
+
if (idx + 1) % GRAD_ACCUM_STEPS == 0 or (idx + 1) == len(train_loader):
|
| 182 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
| 183 |
+
optimizer.step()
|
| 184 |
+
scheduler.step()
|
| 185 |
+
optimizer.zero_grad()
|
| 186 |
+
|
| 187 |
+
total_epoch_loss += loss.item()
|
| 188 |
+
|
| 189 |
+
# Log plot for the first sample of the first batch in each training epoch
|
| 190 |
+
if idx == 0:
|
| 191 |
+
first_sample_pred_don = pred_energies_batch[0, :, 0]
|
| 192 |
+
first_sample_pred_ka = pred_energies_batch[0, :, 1]
|
| 193 |
+
first_sample_pred_drumroll = pred_energies_batch[0, :, 2]
|
| 194 |
+
|
| 195 |
+
first_sample_true_don = don_labels[0, :]
|
| 196 |
+
first_sample_true_ka = ka_labels[0, :]
|
| 197 |
+
first_sample_true_drumroll = drumroll_labels[0, :]
|
| 198 |
+
|
| 199 |
+
first_sample_length = lengths[
|
| 200 |
+
0
|
| 201 |
+
].item() # Get the valid length of the first sample
|
| 202 |
+
|
| 203 |
+
log_energy_plots_to_tensorboard(
|
| 204 |
+
writer,
|
| 205 |
+
"Train/Sample_0",
|
| 206 |
+
epoch,
|
| 207 |
+
first_sample_pred_don,
|
| 208 |
+
first_sample_pred_ka,
|
| 209 |
+
first_sample_pred_drumroll,
|
| 210 |
+
first_sample_true_don,
|
| 211 |
+
first_sample_true_ka,
|
| 212 |
+
first_sample_true_drumroll,
|
| 213 |
+
first_sample_length,
|
| 214 |
+
output_frame_hop_sec,
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
avg_train_loss = total_epoch_loss / len(train_loader)
|
| 218 |
+
writer.add_scalar("Loss/Train_Avg", avg_train_loss, epoch)
|
| 219 |
+
|
| 220 |
+
# Validation
|
| 221 |
+
model.eval()
|
| 222 |
+
total_val_loss = 0.0
|
| 223 |
+
# Removed storage for classification logits/labels and confusion matrix components
|
| 224 |
+
|
| 225 |
+
with torch.no_grad():
|
| 226 |
+
for val_idx, batch in enumerate(
|
| 227 |
+
tqdm(val_loader, desc=f"Val Epoch {epoch}")
|
| 228 |
+
):
|
| 229 |
+
mel = batch["mel"].to(DEVICE)
|
| 230 |
+
don_labels = batch["don_labels"].to(DEVICE)
|
| 231 |
+
ka_labels = batch["ka_labels"].to(DEVICE)
|
| 232 |
+
drumroll_labels = batch["drumroll_labels"].to(DEVICE)
|
| 233 |
+
lengths = batch["lengths"].to(DEVICE)
|
| 234 |
+
nps = batch["nps"].to(DEVICE) # Ground truth NPS from batch
|
| 235 |
+
|
| 236 |
+
output_dict = model(mel, lengths, nps)
|
| 237 |
+
pred_energies_val_batch = output_dict["presence"] # (B, T_out, 3)
|
| 238 |
+
|
| 239 |
+
val_loss_input_batch = {
|
| 240 |
+
"don_labels": don_labels,
|
| 241 |
+
"ka_labels": ka_labels,
|
| 242 |
+
"drumroll_labels": drumroll_labels,
|
| 243 |
+
"lengths": lengths,
|
| 244 |
+
}
|
| 245 |
+
val_loss = criterion(output_dict, val_loss_input_batch)
|
| 246 |
+
total_val_loss += val_loss.item()
|
| 247 |
+
|
| 248 |
+
# Log plot for the first sample of the first batch in each validation epoch
|
| 249 |
+
if val_idx == 0:
|
| 250 |
+
first_val_sample_pred_don = pred_energies_val_batch[0, :, 0]
|
| 251 |
+
first_val_sample_pred_ka = pred_energies_val_batch[0, :, 1]
|
| 252 |
+
first_val_sample_pred_drumroll = pred_energies_val_batch[0, :, 2]
|
| 253 |
+
|
| 254 |
+
first_val_sample_true_don = don_labels[0, :]
|
| 255 |
+
first_val_sample_true_ka = ka_labels[0, :]
|
| 256 |
+
first_val_sample_true_drumroll = drumroll_labels[0, :]
|
| 257 |
+
|
| 258 |
+
first_val_sample_length = lengths[0].item()
|
| 259 |
+
|
| 260 |
+
log_energy_plots_to_tensorboard(
|
| 261 |
+
writer,
|
| 262 |
+
"Eval/Sample_0",
|
| 263 |
+
epoch,
|
| 264 |
+
first_val_sample_pred_don,
|
| 265 |
+
first_val_sample_pred_ka,
|
| 266 |
+
first_val_sample_pred_drumroll,
|
| 267 |
+
first_val_sample_true_don,
|
| 268 |
+
first_val_sample_true_ka,
|
| 269 |
+
first_val_sample_true_drumroll,
|
| 270 |
+
first_val_sample_length,
|
| 271 |
+
output_frame_hop_sec,
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
# Log ground truth NPS for reference during validation if needed
|
| 275 |
+
# writer.add_scalar("NPS/GT_Val_Batch_Avg", nps.mean().item(), epoch * len(val_loader) + idx)
|
| 276 |
+
|
| 277 |
+
avg_val_loss = total_val_loss / len(val_loader)
|
| 278 |
+
writer.add_scalar("Loss/Val_Avg", avg_val_loss, epoch)
|
| 279 |
+
|
| 280 |
+
# Log learning rate
|
| 281 |
+
current_lr = optimizer.param_groups[0]["lr"]
|
| 282 |
+
writer.add_scalar("LR/learning_rate", current_lr, epoch)
|
| 283 |
+
|
| 284 |
+
# Log ground truth NPS from the last validation batch (or mean over epoch)
|
| 285 |
+
if "nps" in batch: # Check if nps is in the last batch
|
| 286 |
+
writer.add_scalar(
|
| 287 |
+
"NPS/GT_Val_LastBatch_Avg", batch["nps"].mean().item(), epoch
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
print(
|
| 291 |
+
f"Epoch {epoch:02d} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | LR: {current_lr:.2e}"
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
if avg_val_loss < best_val_loss:
|
| 295 |
+
best_val_loss = avg_val_loss
|
| 296 |
+
pat_count = 0
|
| 297 |
+
torch.save(model.state_dict(), "best_model.pt") # Changed model save name
|
| 298 |
+
print(f"Saved new best model to best_model.pt at epoch {epoch}")
|
| 299 |
+
else:
|
| 300 |
+
pat_count += 1
|
| 301 |
+
if pat_count >= patience:
|
| 302 |
+
print("Early stopping!")
|
| 303 |
+
break
|
| 304 |
+
writer.close()
|
| 305 |
+
|
| 306 |
+
model_id = "JacobLinCool/taiko-conformer-5"
|
| 307 |
+
try:
|
| 308 |
+
model.push_to_hub(model_id, commit_message="Upload trained model")
|
| 309 |
+
upload_folder(
|
| 310 |
+
repo_id=model_id,
|
| 311 |
+
folder_path="runs",
|
| 312 |
+
path_in_repo=".",
|
| 313 |
+
commit_message="Upload training logs",
|
| 314 |
+
ignore_patterns=["*.txt", "*.json", "*.csv"],
|
| 315 |
+
)
|
| 316 |
+
print(f"Model and logs uploaded to {model_id}")
|
| 317 |
+
except Exception as e:
|
| 318 |
+
print(f"Error uploading to Hugging Face Hub: {e}")
|
| 319 |
+
print("Make sure you have the correct permissions and try again.")
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
if __name__ == "__main__":
|
| 323 |
+
main()
|
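Note on the scheduler sizing above: OneCycleLR must be stepped once per optimizer step, not once per batch, which is why total_steps divides the loader length by GRAD_ACCUM_STEPS. A quick sanity check, a minimal sketch with a made-up loader size (900 batches is hypothetical, for illustration only):

import math

batches_per_epoch = 900  # hypothetical
GRAD_ACCUM_STEPS = 8
EPOCHS = 200

# One optimizer step per GRAD_ACCUM_STEPS batches, plus a final partial step.
steps_per_epoch = math.ceil(batches_per_epoch / GRAD_ACCUM_STEPS)  # 113
total_steps = EPOCHS * steps_per_epoch  # 22600, the value handed to OneCycleLR
print(steps_per_epoch, total_steps)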
tc6/__init__.py
ADDED

File without changes

tc6/config.py
ADDED

@@ -0,0 +1,25 @@
import torch

# ─── 1) CONFIG ─────────────────────────────────────────────────────
SAMPLE_RATE = 22050
N_MELS = 80
HOP_LENGTH = 256
TIME_SUB = 1
CNN_CH = 256
N_HEADS = 8
D_MODEL = 512
FF_DIM = 1024
N_LAYERS = 6
DEPTHWISE_CONV_KERNEL_SIZE = 31
DROPOUT = 0.1
HIDDEN_DIM = 64
N_TYPES = 7
BATCH_SIZE = 2
GRAD_ACCUM_STEPS = 8
LR = 3e-4
EPOCHS = 200
DEVICE = (
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)
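With these values, each mel frame (and, since TIME_SUB = 1, each label and model-output frame) covers about 11.6 ms, roughly 86 frames per second. A quick derivation from the constants above:

SAMPLE_RATE = 22050
HOP_LENGTH = 256
TIME_SUB = 1

hop_sec = HOP_LENGTH / SAMPLE_RATE   # ~0.01161 s per mel frame
fps = SAMPLE_RATE / HOP_LENGTH       # ~86.13 frames per second
label_hop_sec = hop_sec * TIME_SUB   # labels share the frame rate when TIME_SUB == 1
print(f"{hop_sec * 1000:.2f} ms/frame, {fps:.2f} fps")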
tc6/dataset.py
ADDED

@@ -0,0 +1,21 @@
from datasets import load_dataset, concatenate_datasets

ds1 = load_dataset("JacobLinCool/taiko-2023-1.1", split="train")
ds2 = load_dataset("JacobLinCool/taiko-2023-1.2", split="train")
ds3 = load_dataset("JacobLinCool/taiko-2023-1.3", split="train")
ds4 = load_dataset("JacobLinCool/taiko-2023-1.4", split="train")
ds5 = load_dataset("JacobLinCool/taiko-2023-1.5", split="train")
ds6 = load_dataset("JacobLinCool/taiko-2023-1.6", split="train")
ds7 = load_dataset("JacobLinCool/taiko-2023-1.7", split="train")
ds = concatenate_datasets([ds1, ds2, ds3, ds4, ds5, ds6, ds7]).with_format("torch")

good = list(range(len(ds)))
good.remove(1079)  # 1079 has file problem
ds = ds.select(good)

# for local test
# ds = (
#     load_dataset("JacobLinCool/taiko-2023-1.6", split="train")
#     .with_format("torch")
#     .select(range(10))
# )
tc6/infer.py
ADDED

@@ -0,0 +1,354 @@
import time
import torch
import torchaudio
import matplotlib.pyplot as plt
import numpy as np
from .config import SAMPLE_RATE, N_MELS, HOP_LENGTH
import torch.profiler


# --- PREPROCESSING (match training) ---
def preprocess_audio(audio_path):
    wav, sr = torchaudio.load(audio_path)
    wav = wav.mean(dim=0)  # mono
    if sr != SAMPLE_RATE:
        wav = torchaudio.functional.resample(wav, sr, SAMPLE_RATE)
    wav = wav / (wav.abs().max() + 1e-8)  # Normalize audio

    mel_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=SAMPLE_RATE,
        n_mels=N_MELS,
        hop_length=HOP_LENGTH,
        n_fft=2048,
    )
    mel = mel_transform(wav)
    return mel  # mel is (N_MELS, T_mel)


# --- INFERENCE ---
def run_inference(model, mel_input, nps_input, difficulty_input, level_input, device):
    model.eval()
    with torch.no_grad():
        mel = mel_input.to(device).unsqueeze(0)  # (1, N_MELS, T_mel)
        nps = nps_input.to(device).unsqueeze(0)  # (1,)
        difficulty = difficulty_input.to(device).unsqueeze(0)  # (1,)
        level = level_input.to(device).unsqueeze(0)  # (1,)

        mel_cnn_input = mel.unsqueeze(1)  # (1, 1, N_MELS, T_mel)

        conformer_lengths = torch.tensor(
            [mel_cnn_input.shape[-1]], dtype=torch.long, device=device
        )

        with torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                *(
                    [torch.profiler.ProfilerActivity.CUDA]
                    if device.type == "cuda"
                    else []
                ),
            ],
            record_shapes=True,
            profile_memory=True,
            with_stack=False,
            with_flops=True,
        ) as prof:
            out_dict = model(mel_cnn_input, conformer_lengths, nps, difficulty, level)
        print(
            prof.key_averages().table(
                sort_by=(
                    "self_cuda_memory_usage"
                    if device.type == "cuda"
                    else "self_cpu_time_total"
                ),
                row_limit=20,
            )
        )

        energies = out_dict["presence"].squeeze(0).cpu().numpy()

        don_energy = energies[:, 0]
        ka_energy = energies[:, 1]
        drumroll_energy = energies[:, 2]

        return don_energy, ka_energy, drumroll_energy


# --- DECODE TO ONSETS ---
def decode_onsets(
    don_energy,
    ka_energy,
    drumroll_energy,
    hop_sec,
    threshold=0.5,
    min_distance_frames=3,
):
    results = []
    T_out = len(don_energy)
    last_onset_frame = -min_distance_frames

    for i in range(1, T_out - 1):  # Iterate considering neighbors for peak detection
        if i < last_onset_frame + min_distance_frames:
            continue

        e_don, e_ka, e_drum = don_energy[i], ka_energy[i], drumroll_energy[i]
        energies_at_i = {
            1: e_don,
            2: e_ka,
            5: e_drum,
        }  # Type mapping: 1:Don, 2:Ka, 5:Drumroll

        # Check the channels in descending energy order, so the strongest channel
        # that forms a local peak above the threshold wins this frame.
        sorted_types_by_energy = sorted(
            energies_at_i.keys(), key=lambda x: energies_at_i[x], reverse=True
        )

        for onset_type in sorted_types_by_energy:
            if onset_type == 1:
                current_energy_series = don_energy
            elif onset_type == 2:
                current_energy_series = ka_energy
            else:
                current_energy_series = drumroll_energy

            energy_val = current_energy_series[i]

            # An onset fires where the energy is a local peak above the threshold.
            if (
                energy_val > threshold
                and energy_val > current_energy_series[i - 1]
                and energy_val > current_energy_series[i + 1]
            ):
                results.append((i * hop_sec, onset_type))
                last_onset_frame = i
                break  # Only one onset type per frame

    return results


# --- VISUALIZATION ---
def plot_results(
    mel_spectrogram,
    don_energy,
    ka_energy,
    drumroll_energy,
    onsets,
    hop_sec,
    out_path=None,
):
    # mel_spectrogram is (N_MELS, T_mel)
    T_mel = mel_spectrogram.shape[1]
    T_out = len(don_energy)  # Length of energy arrays (model output time dimension)

    # Time axes. The CNN frontend pools only in frequency, so the model emits one
    # output frame per input mel frame (and with TIME_SUB == 1 the labels share
    # that rate); hop_sec is the seconds-per-frame that decode_onsets used, so
    # both axes below line up.
    time_axis_mel = np.arange(T_mel) * (HOP_LENGTH / SAMPLE_RATE)
    time_axis_energies = np.arange(T_out) * hop_sec

    fig, ax1 = plt.subplots(figsize=(100, 10))

    # Plot Mel Spectrogram on ax1
    mel_db = torchaudio.functional.amplitude_to_DB(
        mel_spectrogram, multiplier=10.0, amin=1e-10, db_multiplier=0.0
    )
    img = ax1.imshow(
        mel_db.numpy(),
        aspect="auto",
        origin="lower",
        cmap="magma",
        extent=[time_axis_mel[0], time_axis_mel[-1], 0, N_MELS],
    )
    ax1.set_title("Mel Spectrogram with Predicted Energies and Onsets")
    ax1.set_xlabel("Time (s)")
    ax1.set_ylabel("Mel Bin")
    fig.colorbar(img, ax=ax1, format="%+2.0f dB")

    # Create a second y-axis for energies
    ax2 = ax1.twinx()
    ax2.plot(time_axis_energies, don_energy, label="Don Energy", color="red")
    ax2.plot(time_axis_energies, ka_energy, label="Ka Energy", color="blue")
    ax2.plot(
        time_axis_energies, drumroll_energy, label="Drumroll Energy", color="green"
    )
    ax2.set_ylabel("Energy")
    ax2.set_ylim(0, 1.2)  # Assuming energies are somewhat normalized or bounded

    # Overlay onsets from decode_onsets (t is already in seconds)
    labeled_types = set()
    # Group drumrolls into segments (reuse logic from write_tja)
    drumroll_times = [t_sec for t_sec, typ in onsets if typ == 5]
    drumroll_times.sort()
    drumroll_segments = []
    if drumroll_times:
        seg_start = drumroll_times[0]
        prev = drumroll_times[0]
        for t in drumroll_times[1:]:
            if t - prev <= hop_sec * 6:  # up to 5-frame gap
                prev = t
            else:
                drumroll_segments.append((seg_start, prev))
                seg_start = t
                prev = t
        drumroll_segments.append((seg_start, prev))
    # Plot Don/Ka onsets as vertical lines
    for t_sec, typ in onsets:
        if typ == 5:
            continue  # skip drumroll onsets
        color_map = {1: "darkred", 2: "darkblue"}
        label_map = {1: "Don Onset", 2: "Ka Onset"}
        line_color = color_map.get(typ, "black")
        line_label = label_map.get(typ, f"Type {typ} Onset")
        if typ not in labeled_types:
            ax1.axvline(
                t_sec, color=line_color, linestyle="--", alpha=0.9, label=line_label
            )
            labeled_types.add(typ)
        else:
            ax1.axvline(t_sec, color=line_color, linestyle="--", alpha=0.9)
    # Plot drumroll segments as shaded regions
    for seg_start, seg_end in drumroll_segments:
        ax1.axvspan(
            seg_start,
            seg_end + hop_sec,
            color="green",
            alpha=0.2,
            label="Drumroll Segment" if "drumroll" not in labeled_types else None,
        )
        labeled_types.add("drumroll")

    # Combine legends from both axes
    lines, labels = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax2.legend(lines + lines2, labels + labels2, loc="upper right")

    fig.tight_layout()

    # Save to file if a path is provided, otherwise return the figure
    if out_path:
        plt.savefig(out_path)
        print(f"Saved plot to {out_path}")
        plt.close(fig)
        return out_path
    else:
        return fig


def write_tja(onsets, out_path=None, bpm=160, quantize=96, audio="audio.wav"):
    # TJA types: 0:no note, 1:Don, 2:Ka, 3:BigDon, 4:BigKa, 5:DrumrollStart, 8:DrumrollEnd
    # Model output types: 1:Don, 2:Ka, 5:Drumroll (interpreted as start/single)
    sec_per_beat = 60 / bpm
    beats_per_measure = 4  # Assuming 4/4 time signature
    sec_per_measure = sec_per_beat * beats_per_measure
    # Step 1: Map onsets to (measure_idx, slot, typ)
    slot_events = []
    for t, typ in onsets:
        measure_idx = int(t // sec_per_measure)
        t_in_measure = t % sec_per_measure
        slot = int(round(t_in_measure / sec_per_measure * quantize))
        if slot >= quantize:
            slot = quantize - 1
        slot_events.append((measure_idx, slot, typ))
    # Step 2: Build measure/slot grid
    if slot_events:
        max_measure_idx = max(m for m, _, _ in slot_events)
    else:
        max_measure_idx = -1
    measures = {i: [0] * quantize for i in range(max_measure_idx + 1)}
    # Step 3: Place Don/Ka, collect drumrolls
    drumroll_slots = set()
    for m, s, typ in slot_events:
        if typ in [1, 2]:
            measures[m][s] = typ
        elif typ == 5:
            drumroll_slots.add((m, s))
    # Step 4: Process drumrolls into contiguous regions, mark 5 (start) and 8 (end)
    drumroll_list = sorted(drumroll_slots)
    # Group into contiguous regions (allowing a gap of 5 slots)
    grouped = []
    group = []
    for ms in drumroll_list:
        if not group:
            group = [ms]
        else:
            last_m, last_s = group[-1]
            m, s = ms
            # Calculate slot distance, considering measure wrap
            if m == last_m:
                slot_dist = s - last_s
            elif m == last_m + 1 and last_s <= quantize - 1:
                slot_dist = (quantize - 1 - last_s) + s + 1
            else:
                slot_dist = None
            # Allow gap of up to 5 slots (slot_dist <= 6)
            if slot_dist is not None and 1 <= slot_dist <= 6:
                group.append(ms)
            else:
                grouped.append(group)
                group = [ms]
    if group:
        grouped.append(group)
    # Mark 5 (start) and 8 (end) for each group
    for region in grouped:
        if len(region) == 1:
            m, s = region[0]
            measures[m][s] = 5
            # Place 8 in next slot (or next measure if at end)
            if s < quantize - 1:
                measures[m][s + 1] = 8
            elif m < max_measure_idx:
                measures[m + 1][0] = 8
        else:
            m_start, s_start = region[0]
            m_end, s_end = region[-1]
            measures[m_start][s_start] = 5
            measures[m_end][s_end] = 8
            # Middle slots stay 0 (already 0 by default)
    # Step 5: Generate TJA content
    tja_content = []
    tja_content.append(f"TITLE:{audio} (TC6, {time.strftime('%Y-%m-%d %H:%M:%S')})")
    tja_content.append(f"BPM:{bpm}")
    tja_content.append(f"WAVE:{audio}")
    tja_content.append("OFFSET:0")
    tja_content.append("COURSE:Oni\nLEVEL:9\n")
    tja_content.append("#START")
    for i in range(max_measure_idx + 1):
        notes = measures.get(i, [0] * quantize)
        line = "".join(str(n) for n in notes)
        tja_content.append(line + ",")
    tja_content.append("#END")

    tja_string = "\n".join(tja_content)

    # If out_path is provided, also write to file
    if out_path:
        with open(out_path, "w", encoding="utf-8") as f:
            f.write(tja_string)
        print(f"TJA chart saved to {out_path}")

    return tja_string
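For reference, the pieces above chain together roughly as follows. This is a minimal sketch, not part of the commit: the checkpoint repo id, file names, and the NPS/difficulty/level values are made up for illustration (from_pretrained itself is provided by the PyTorchModelHubMixin the model inherits).

import torch
from tc6.config import SAMPLE_RATE, HOP_LENGTH
from tc6.model import TaikoConformer6
from tc6.infer import preprocess_audio, run_inference, decode_onsets, write_tja

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Hypothetical checkpoint repo id, for illustration only.
model = TaikoConformer6.from_pretrained("JacobLinCool/taiko-conformer-6").to(device)

mel = preprocess_audio("song.wav")  # (N_MELS, T_mel)
don, ka, drum = run_inference(
    model,
    mel,
    torch.tensor(6.0),  # target notes per second (illustrative)
    torch.tensor(3.0),  # difficulty id, e.g. oni
    torch.tensor(9.0),  # level
    device,
)
onsets = decode_onsets(don, ka, drum, hop_sec=HOP_LENGTH / SAMPLE_RATE)
write_tja(onsets, out_path="song.tja", audio="song.wav")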
tc6/loss.py
ADDED

@@ -0,0 +1,65 @@
import torch
import torch.nn as nn


class TaikoEnergyLoss(nn.Module):
    def __init__(self, reduction="mean"):
        super().__init__()
        # Use 'none' reduction to get element-wise losses, then manually apply masking and reduction
        self.mse_loss = nn.MSELoss(reduction="none")
        self.reduction = reduction

    def forward(self, outputs, batch):
        """
        Calculates the MSE loss for energy-based predictions.

        Args:
            outputs (dict): Model output, containing 'presence' tensor.
                            outputs['presence'] shape: (B, T, 3) for don, ka, drumroll energies.
            batch (dict): Batch data from collate_fn, containing true labels and lengths.
                          batch['don_labels'], batch['ka_labels'], batch['drumroll_labels'] shape: (B, T)
                          batch['lengths'] shape: (B,) - valid sequence lengths for time dimension T.
        Returns:
            torch.Tensor: The calculated loss.
        """
        pred_energies = outputs["presence"]  # (B, T, 3)

        true_don = batch["don_labels"]  # (B, T)
        true_ka = batch["ka_labels"]  # (B, T)
        true_drumroll = batch["drumroll_labels"]  # (B, T)

        # Stack true labels to match the structure of pred_energies (B, T, 3)
        true_energies = torch.stack([true_don, true_ka, true_drumroll], dim=2)

        B, T, _ = pred_energies.shape

        # Create a mask based on batch['lengths'] to ignore padded parts of sequences.
        # batch['lengths'] gives the actual length of each sequence in the batch.
        # mask shape: (B, T)
        mask_2d = torch.arange(T, device=pred_energies.device).expand(B, T) < batch[
            "lengths"
        ].unsqueeze(1)
        # Expand mask to (B, T, 1) to broadcast across the 3 energy channels
        mask_3d = mask_2d.unsqueeze(2)

        # Calculate element-wise MSE loss
        loss_elementwise = self.mse_loss(pred_energies, true_energies)  # (B, T, 3)

        # Apply the mask to the loss
        masked_loss = loss_elementwise * mask_3d

        if self.reduction == "mean":
            # Sum the loss over all valid (unmasked) elements and divide by the number of valid elements
            total_loss = masked_loss.sum()
            num_valid_elements = mask_3d.sum()  # Total number of unmasked float values
            if num_valid_elements > 0:
                return total_loss / num_valid_elements
            else:
                # Avoid division by zero if there are no valid elements (e.g., empty batch or all lengths are 0)
                return torch.tensor(
                    0.0, device=pred_energies.device, requires_grad=True
                )
        elif self.reduction == "sum":
            return masked_loss.sum()
        else:  # 'none' or any other case
            return masked_loss
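A quick self-contained check of the masking behavior, a minimal sketch with random tensors (shapes only, the values are meaningless): frames at or beyond each sample's length contribute nothing to the mean, so corrupting the padded region must leave the loss unchanged.

import torch
from tc6.loss import TaikoEnergyLoss

B, T = 2, 10
outputs = {"presence": torch.rand(B, T, 3)}
batch = {
    "don_labels": torch.rand(B, T),
    "ka_labels": torch.rand(B, T),
    "drumroll_labels": torch.rand(B, T),
    "lengths": torch.tensor([10, 4]),  # second sample: frames 4..9 are padding
}
criterion = TaikoEnergyLoss(reduction="mean")
loss = criterion(outputs, batch)

# Corrupt only the masked-out padding; the loss must not move.
outputs["presence"][1, 4:] = 123.0
assert torch.isclose(loss, criterion(outputs, batch))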
tc6/model.py
ADDED

@@ -0,0 +1,166 @@
import torch
import torch.nn as nn
from torchaudio.models import Conformer
from huggingface_hub import PyTorchModelHubMixin
from .config import (
    N_MELS,
    CNN_CH,
    N_HEADS,
    D_MODEL,
    FF_DIM,
    N_LAYERS,
    DROPOUT,
    DEPTHWISE_CONV_KERNEL_SIZE,
    HIDDEN_DIM,
    DEVICE,
)


class TaikoConformer6(nn.Module, PyTorchModelHubMixin):
    def __init__(self):
        super().__init__()
        # 1) CNN frontend: frequency-only pooling
        self.cnn = nn.Sequential(
            nn.Conv2d(1, CNN_CH, 3, stride=(2, 1), padding=1),
            nn.BatchNorm2d(CNN_CH),
            nn.GELU(),
            nn.Dropout2d(DROPOUT),
            nn.Conv2d(CNN_CH, CNN_CH, 3, stride=(2, 1), padding=1),
            nn.BatchNorm2d(CNN_CH),
            nn.GELU(),
            nn.Dropout2d(DROPOUT),
        )
        feat_dim = CNN_CH * (N_MELS // 4)

        # 2) Linear projection to model dimension
        self.proj = nn.Linear(feat_dim, D_MODEL)

        # 3) FiLM conditioning for notes_per_second, difficulty, and level
        self.film_nps = nn.Linear(1, 2 * D_MODEL)
        self.film_difficulty = nn.Linear(
            1, 2 * D_MODEL
        )  # Assuming difficulty is a single scalar
        self.film_level = nn.Linear(1, 2 * D_MODEL)  # Assuming level is a single scalar

        # 4) Conformer encoder
        self.encoder = Conformer(
            input_dim=D_MODEL,
            num_heads=N_HEADS,
            ffn_dim=FF_DIM,
            num_layers=N_LAYERS,
            depthwise_conv_kernel_size=DEPTHWISE_CONV_KERNEL_SIZE,
            dropout=DROPOUT,
            use_group_norm=False,
            convolution_first=False,
        )

        # 5) Presence regressor head
        self.presence_regressor = nn.Sequential(
            nn.Dropout(DROPOUT),
            nn.Linear(D_MODEL, HIDDEN_DIM),
            nn.GELU(),
            nn.Dropout(DROPOUT),
            nn.Linear(HIDDEN_DIM, 3),  # Don, Ka, DrumRoll energy
            nn.Sigmoid(),  # Output between 0 and 1
        )

        # 6) Initialize weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(
        self,
        mel: torch.Tensor,
        lengths: torch.Tensor,
        notes_per_second: torch.Tensor,
        difficulty: torch.Tensor,
        level: torch.Tensor,
    ):
        """
        Args:
            mel: (B, 1, N_MELS, T_mel)
            lengths: (B,) lengths after CNN
            notes_per_second: (B,) stream of control values
            difficulty: (B,) difficulty values
            level: (B,) level values
        Returns:
            Dict with:
                'presence': (B, T_cnn_out, 3)
                'lengths': lengths
        """
        # CNN frontend
        x = self.cnn(mel)  # (B, C, F, T)
        B, C, F, T = x.size()
        x = x.permute(0, 3, 1, 2).reshape(B, T, C * F)

        # Project to model dimension
        x = self.proj(x)  # (B, T, D_MODEL)

        # FiLM conditioning
        nps = notes_per_second.unsqueeze(-1).float()  # (B, 1)
        gamma_beta_nps = self.film_nps(nps)  # (B, 2*D_MODEL)
        gamma_nps, beta_nps = gamma_beta_nps.chunk(2, dim=-1)
        x = gamma_nps.unsqueeze(1) * x + beta_nps.unsqueeze(1)

        diff = difficulty.unsqueeze(-1).float()  # (B, 1)
        gamma_beta_diff = self.film_difficulty(diff)  # (B, 2*D_MODEL)
        gamma_diff, beta_diff = gamma_beta_diff.chunk(2, dim=-1)
        x = gamma_diff.unsqueeze(1) * x + beta_diff.unsqueeze(1)

        lvl = level.unsqueeze(-1).float()  # (B, 1)
        gamma_beta_lvl = self.film_level(lvl)  # (B, 2*D_MODEL)
        gamma_lvl, beta_lvl = gamma_beta_lvl.chunk(2, dim=-1)
        x = gamma_lvl.unsqueeze(1) * x + beta_lvl.unsqueeze(1)

        # Conformer encoder
        x, _ = self.encoder(x, lengths=lengths)

        # Presence prediction
        presence = self.presence_regressor(x)
        return {"presence": presence, "lengths": lengths}


if __name__ == "__main__":
    model = TaikoConformer6().to(device=DEVICE)
    print(model)

    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"{name}: {param.numel():,}")

    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {params / 1e6:.2f}M")

    batch_size = 4
    mel_time_steps = 1024
    input_mel = torch.randn(batch_size, 1, N_MELS, mel_time_steps).to(DEVICE)

    conformer_lengths = torch.tensor(
        [mel_time_steps] * batch_size, dtype=torch.long
    ).to(DEVICE)

    notes_per_second_input = torch.tensor([10.0] * batch_size, dtype=torch.float32).to(
        DEVICE
    )
    difficulty_input = torch.tensor([1.0] * batch_size, dtype=torch.float32).to(
        DEVICE
    )  # Example difficulty
    level_input = torch.tensor([5.0] * batch_size, dtype=torch.float32).to(
        DEVICE
    )  # Example level

    output = model(
        input_mel,
        conformer_lengths,
        notes_per_second_input,
        difficulty_input,
        level_input,
    )
    print("Output shapes:")
    for key, value in output.items():
        print(f"{key}: {value.shape}")
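The three FiLM layers above apply the standard feature-wise affine modulation gamma * x + beta, with gamma and beta predicted from a single scalar control; using one layer per control lets NPS, difficulty, and level each rescale and shift all D_MODEL channels independently. A minimal standalone sketch of the mechanism on toy shapes (D = 8 is arbitrary, for illustration):

import torch
import torch.nn as nn

D = 8  # toy model dimension
film = nn.Linear(1, 2 * D)  # predicts gamma and beta from one scalar

x = torch.randn(2, 5, D)          # (batch, time, channels)
ctrl = torch.tensor([4.0, 12.0])  # e.g. notes-per-second per sample
gamma, beta = film(ctrl.unsqueeze(-1)).chunk(2, dim=-1)  # each (2, D)
x = gamma.unsqueeze(1) * x + beta.unsqueeze(1)           # broadcast over time
print(x.shape)  # torch.Size([2, 5, 8])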
tc6/preprocess.py
ADDED

@@ -0,0 +1,258 @@
import math
import torch
import torch.nn as nn
import torchaudio
from torchaudio.transforms import FrequencyMasking
from tja import parse_tja, PyParsingMode
from .config import N_TYPES, SAMPLE_RATE, N_MELS, HOP_LENGTH, TIME_SUB
from .model import TaikoConformer6


mel_transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=SAMPLE_RATE,
    n_mels=N_MELS,
    hop_length=HOP_LENGTH,
    n_fft=2048,
)


freq_mask = FrequencyMasking(freq_mask_param=15)


def preprocess(example, difficulty="oni"):
    wav_tensor = example["audio"]["array"]
    sr = example["audio"]["sampling_rate"]
    # 1) load & resample
    if sr != SAMPLE_RATE:
        wav_tensor = torchaudio.functional.resample(wav_tensor, sr, SAMPLE_RATE)
    # normalize audio
    wav_tensor = wav_tensor / (wav_tensor.abs().max() + 1e-8)
    # add random Gaussian noise
    if torch.rand(1).item() < 0.5:
        wav_tensor = wav_tensor + 0.005 * torch.randn_like(wav_tensor)
    # 2) mel: (1, N_MELS, T)
    mel = mel_transform(wav_tensor).unsqueeze(0)
    # apply SpecAugment
    mel = freq_mask(mel)
    _, _, T = mel.shape
    # 3) build label sequence of length ceil(T / TIME_SUB)
    T_sub = math.ceil(T / TIME_SUB)

    # Initialize energy-based labels for Don, Ka, Drumroll
    don_labels = torch.zeros(T_sub, dtype=torch.float32)
    ka_labels = torch.zeros(T_sub, dtype=torch.float32)
    drumroll_labels = torch.zeros(T_sub, dtype=torch.float32)

    # Define exponential decay tail parameters
    tail_length = 40  # number of frames for decay tail
    decay_rate = 8.0  # decay rate parameter, adjust as needed
    tail_kernel = torch.exp(
        -torch.arange(0, tail_length, dtype=torch.float32) / decay_rate
    )

    fps = SAMPLE_RATE / HOP_LENGTH
    num_valid_notes = 0
    for onset in example[difficulty]:
        typ, t_start, t_end, *_ = onset

        # Assuming N_TYPES in config is appropriately set (e.g., 7 or more)
        if typ < 1 or typ > N_TYPES:  # Filter out invalid types
            continue

        num_valid_notes += 1

        exact_frame_start = t_start.item() * fps

        # Types 1 and 3 are Don, types 2 and 4 are Ka
        if typ == 1 or typ == 3 or typ == 2 or typ == 4:
            exact_hit_time_sub = exact_frame_start / TIME_SUB

            current_labels = don_labels if (typ == 1 or typ == 3) else ka_labels

            start_points_info = []
            rounded_hit_time_sub = round(exact_hit_time_sub)

            if (
                abs(exact_hit_time_sub - rounded_hit_time_sub) < 1e-6
            ):  # Tolerance for float precision
                idx_single = int(rounded_hit_time_sub)
                if 0 <= idx_single < T_sub:
                    start_points_info.append({"idx": idx_single, "weight": 1.0})
            else:
                idx_floor = math.floor(exact_hit_time_sub)
                idx_ceil = idx_floor + 1

                frac = exact_hit_time_sub - idx_floor
                weight_ceil = frac
                weight_floor = 1.0 - frac

                if weight_floor > 1e-6 and 0 <= idx_floor < T_sub:
                    start_points_info.append({"idx": idx_floor, "weight": weight_floor})
                if weight_ceil > 1e-6 and 0 <= idx_ceil < T_sub:
                    start_points_info.append({"idx": idx_ceil, "weight": weight_ceil})

            for point_info in start_points_info:
                start_idx = point_info["idx"]
                weight = point_info["weight"]
                for k_idx, kernel_val in enumerate(tail_kernel):
                    target_idx = start_idx + k_idx
                    if 0 <= target_idx < T_sub:
                        current_labels[target_idx] = max(
                            current_labels[target_idx].item(),
                            weight * kernel_val.item(),
                        )

        # Types 5, 6, 7 are Drumroll
        elif typ >= 5 and typ <= 7:
            exact_frame_end = t_end.item() * fps
            exact_start_time_sub = exact_frame_start / TIME_SUB
            exact_end_time_sub = exact_frame_end / TIME_SUB

            # Drumroll body
            body_loop_start_idx = math.floor(exact_start_time_sub)
            body_loop_end_idx = math.ceil(exact_end_time_sub)

            for dr_idx in range(body_loop_start_idx, body_loop_end_idx):
                if 0 <= dr_idx < T_sub:
                    drumroll_labels[dr_idx] = 1.0

            # Drumroll tail (starts from exact_end_time_sub)
            tail_start_points_info = []
            rounded_end_time_sub = round(exact_end_time_sub)
            if abs(exact_end_time_sub - rounded_end_time_sub) < 1e-6:
                idx_single_tail = int(rounded_end_time_sub)
                if 0 <= idx_single_tail < T_sub:
                    tail_start_points_info.append(
                        {"idx": idx_single_tail, "weight": 1.0}
                    )
            else:
                idx_floor_tail = math.floor(exact_end_time_sub)
                idx_ceil_tail = idx_floor_tail + 1

                frac_tail = exact_end_time_sub - idx_floor_tail
                weight_ceil_tail = frac_tail
                weight_floor_tail = 1.0 - frac_tail

                if weight_floor_tail > 1e-6 and 0 <= idx_floor_tail < T_sub:
                    tail_start_points_info.append(
                        {"idx": idx_floor_tail, "weight": weight_floor_tail}
                    )
                if weight_ceil_tail > 1e-6 and 0 <= idx_ceil_tail < T_sub:
                    tail_start_points_info.append(
                        {"idx": idx_ceil_tail, "weight": weight_ceil_tail}
                    )

            for point_info in tail_start_points_info:
                start_idx = point_info["idx"]
                weight = point_info["weight"]
                for k_idx, kernel_val in enumerate(tail_kernel):
                    target_idx = start_idx + k_idx
                    if 0 <= target_idx < T_sub:
                        drumroll_labels[target_idx] = max(
                            drumroll_labels[target_idx].item(),
                            weight * kernel_val.item(),
                        )

    duration_seconds = wav_tensor.shape[-1] / SAMPLE_RATE
    nps = num_valid_notes / duration_seconds if duration_seconds > 0 else 0.0

    parsed = parse_tja(example["tja"], mode=PyParsingMode.Full)
    chart = next(
        (chart for chart in parsed.charts if chart.course.lower() == difficulty), None
    )
    difficulty_id = (
        0
        if difficulty == "easy"
        else (
            1
            if difficulty == "normal"
            else 2 if difficulty == "hard" else 3 if difficulty == "oni" else 4
        )  # Assuming 4 for edit/ura
    )
    level = chart.level if chart else 0

    # --- CNN shape inference and label padding/truncation ---
    # Simulate the CNN to get the output time length (T_cnn)
    dummy_model = TaikoConformer6()
    with torch.no_grad():
        cnn_out = dummy_model.cnn(mel.unsqueeze(0))  # (1, C, F, T_cnn)
        _, _, _, T_cnn = cnn_out.shape

    # Pad or truncate labels to T_cnn
    def pad_or_truncate(label, out_len):
        if label.shape[0] < out_len:
            pad = torch.zeros(out_len - label.shape[0], dtype=label.dtype)
            return torch.cat([label, pad], dim=0)
        else:
            return label[:out_len]

    don_labels = pad_or_truncate(don_labels, T_cnn)
    ka_labels = pad_or_truncate(ka_labels, T_cnn)
    drumroll_labels = pad_or_truncate(drumroll_labels, T_cnn)

    # For conformer input lengths: based on original mel shape (before CNN)
    conformer_input_length = min(math.ceil(T / TIME_SUB), T_cnn)

    print(
        f"Processed {num_valid_notes} notes in {duration_seconds:.2f} seconds, NPS: {nps:.2f}, Difficulty: {difficulty_id}, Level: {level}"
    )

    return {
        "mel": mel,  # (1, N_MELS, T)
        "don_labels": don_labels,  # (T_cnn,)
        "ka_labels": ka_labels,  # (T_cnn,)
        "drumroll_labels": drumroll_labels,  # (T_cnn,)
        "nps": torch.tensor(nps, dtype=torch.float32),
        "difficulty": torch.tensor(difficulty_id, dtype=torch.long),
        "level": torch.tensor(level, dtype=torch.long),
        "duration_seconds": torch.tensor(duration_seconds, dtype=torch.float32),
        "length": torch.tensor(conformer_input_length, dtype=torch.long),  # for conformer
    }


def collate_fn(batch):
    mels_list = [b["mel"].squeeze(0).transpose(0, 1) for b in batch]  # (T, N_MELS)
    don_labels_list = [b["don_labels"] for b in batch]
    ka_labels_list = [b["ka_labels"] for b in batch]
    drumroll_labels_list = [b["drumroll_labels"] for b in batch]
    nps_list = [b["nps"] for b in batch]
    difficulty_list = [b["difficulty"] for b in batch]
    level_list = [b["level"] for b in batch]
    durations_list = [b["duration_seconds"] for b in batch]
    lengths_list = [b["length"] for b in batch]

    # Pad mels
    padded_mels = nn.utils.rnn.pad_sequence(
        mels_list, batch_first=True
    )  # (B, T_max, N_MELS)
    reshaped_mels = padded_mels.transpose(1, 2).unsqueeze(1)
    T_max = padded_mels.shape[1]

    # Pad labels to T_max
    def pad_label(label, out_len):
        if label.shape[0] < out_len:
            pad = torch.zeros(out_len - label.shape[0], dtype=label.dtype)
            return torch.cat([label, pad], dim=0)
        else:
            return label[:out_len]

    don_labels = torch.stack([pad_label(l, T_max) for l in don_labels_list])
    ka_labels = torch.stack([pad_label(l, T_max) for l in ka_labels_list])
    drumroll_labels = torch.stack([pad_label(l, T_max) for l in drumroll_labels_list])
    lengths = torch.tensor(
        [min(l.item(), T_max) for l in lengths_list], dtype=torch.long
    )

    return {
        "mel": reshaped_mels,
        "don_labels": don_labels,
        "ka_labels": ka_labels,
        "drumroll_labels": drumroll_labels,
        "lengths": lengths,  # for conformer
        "nps": torch.stack(nps_list),
        "difficulty": torch.stack(difficulty_list),
        "level": torch.stack(level_list),
        "durations": torch.stack(durations_list),
    }
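The soft labels above place an exponentially decaying tail after each hit, with sub-frame timing handled by splitting the onset weight linearly between the two neighboring frames and merging overlaps with max(). A small sketch of the values this produces, using the same tail_length and decay_rate as the file (the fractional frame 10.25 is an arbitrary example):

import torch

tail_length, decay_rate = 40, 8.0
tail_kernel = torch.exp(-torch.arange(0, tail_length, dtype=torch.float32) / decay_rate)
print(tail_kernel[:5])  # tensor([1.0000, 0.8825, 0.7788, 0.6873, 0.6065])

# A hit at fractional frame 10.25 splits into frame 10 (weight 0.75)
# and frame 11 (weight 0.25); each start writes a scaled copy of the
# kernel, and overlapping contributions keep the maximum.
labels = torch.zeros(60)
for start, w in [(10, 0.75), (11, 0.25)]:
    for k, v in enumerate(tail_kernel):
        labels[start + k] = max(labels[start + k].item(), w * v.item())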
tc6/train.py
ADDED
|
@@ -0,0 +1,336 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from accelerate.utils import set_seed
|
| 2 |
+
|
| 3 |
+
set_seed(1024)
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
import torch
|
| 8 |
+
from torch.utils.data import DataLoader
|
| 9 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
from datasets import concatenate_datasets
|
| 12 |
+
import matplotlib.pyplot as plt
|
| 13 |
+
import numpy as np
|
| 14 |
+
from .config import (
|
| 15 |
+
BATCH_SIZE,
|
| 16 |
+
DEVICE,
|
| 17 |
+
EPOCHS,
|
| 18 |
+
LR,
|
| 19 |
+
GRAD_ACCUM_STEPS,
|
| 20 |
+
HOP_LENGTH,
|
| 21 |
+
SAMPLE_RATE,
|
| 22 |
+
)
|
| 23 |
+
from .model import TaikoConformer6
|
| 24 |
+
from .dataset import ds
|
| 25 |
+
from .preprocess import preprocess, collate_fn
|
| 26 |
+
from .loss import TaikoEnergyLoss
|
| 27 |
+
from huggingface_hub import upload_folder
|
| 28 |
+
|
| 29 |
+
|
# --- Helper function to log energy plots ---
def log_energy_plots_to_tensorboard(
    writer,
    tag_prefix,
    epoch,
    pred_don,
    pred_ka,
    pred_drumroll,
    true_don,
    true_ka,
    true_drumroll,
    valid_length,  # Actual valid length of the sequence (before padding)
    hop_sec,
):
    """
    Logs a plot of predicted vs. true energies for one sample to TensorBoard.
    Inputs are 1D tensors for a single sample; only the first valid_length
    frames are plotted.
    """
    # Ensure data is on CPU and converted to numpy, and select only the valid part
    pred_don = pred_don[:valid_length].detach().cpu().numpy()
    pred_ka = pred_ka[:valid_length].detach().cpu().numpy()
    pred_drumroll = pred_drumroll[:valid_length].detach().cpu().numpy()
    true_don = true_don[:valid_length].cpu().numpy()
    true_ka = true_ka[:valid_length].cpu().numpy()
    true_drumroll = true_drumroll[:valid_length].cpu().numpy()

    time_axis = np.arange(valid_length) * hop_sec

    fig, axs = plt.subplots(3, 1, figsize=(15, 10), sharex=True)
    fig.suptitle(f"{tag_prefix} - Epoch {epoch}", fontsize=16)

    axs[0].plot(time_axis, true_don, label="True Don", color="blue", linestyle="--")
    axs[0].plot(time_axis, pred_don, label="Pred Don", color="lightblue", alpha=0.8)
    axs[0].set_ylabel("Don Energy")
    axs[0].legend()
    axs[0].grid(True)

    axs[1].plot(time_axis, true_ka, label="True Ka", color="red", linestyle="--")
    axs[1].plot(time_axis, pred_ka, label="Pred Ka", color="lightcoral", alpha=0.8)
    axs[1].set_ylabel("Ka Energy")
    axs[1].legend()
    axs[1].grid(True)

    axs[2].plot(
        time_axis, true_drumroll, label="True Drumroll", color="green", linestyle="--"
    )
    axs[2].plot(
        time_axis, pred_drumroll, label="Pred Drumroll", color="lightgreen", alpha=0.8
    )
    axs[2].set_ylabel("Drumroll Energy")
    axs[2].set_xlabel("Time (s)")
    axs[2].legend()
    axs[2].grid(True)

    plt.tight_layout(rect=[0, 0, 1, 0.96])  # Adjust layout to make space for suptitle
    writer.add_figure(f"{tag_prefix}/Energy_Comparison", fig, epoch)
    plt.close(fig)

def main():
    global ds

    # Calculate hop seconds for model output frames.
    # This assumes the model output time dimension corresponds to the mel spectrogram time dimension.
    output_frame_hop_sec = HOP_LENGTH / SAMPLE_RATE

    best_val_loss = float("inf")
    patience = 10  # Increased patience a bit
    pat_count = 0

    ds_oni = ds.map(
        preprocess,
        remove_columns=ds.column_names,
        fn_kwargs={"difficulty": "oni"},
        writer_batch_size=10,
    )
    ds_hard = ds.map(
        preprocess,
        remove_columns=ds.column_names,
        fn_kwargs={"difficulty": "hard"},
        writer_batch_size=10,
    )
    ds_normal = ds.map(
        preprocess,
        remove_columns=ds.column_names,
        fn_kwargs={"difficulty": "normal"},
        writer_batch_size=10,
    )
    ds = concatenate_datasets([ds_oni, ds_hard, ds_normal])

    ds_train_test = ds.train_test_split(test_size=0.1, seed=42)
    # ds_train_test.push_to_hub("JacobLinCool/taiko-conformer-6-ds")
    train_loader = DataLoader(
        ds_train_test["train"],
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=16,
        persistent_workers=True,
        prefetch_factor=4,
    )
    val_loader = DataLoader(
        ds_train_test["test"],
        batch_size=BATCH_SIZE,
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=16,
        persistent_workers=True,
        prefetch_factor=4,
    )

    model = TaikoConformer6().to(DEVICE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

    criterion = TaikoEnergyLoss(reduction="mean").to(DEVICE)

    # Adjust scheduler steps for gradient accumulation
    num_optimizer_steps_per_epoch = math.ceil(len(train_loader) / GRAD_ACCUM_STEPS)
    total_optimizer_steps = EPOCHS * num_optimizer_steps_per_epoch

    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=LR, total_steps=total_optimizer_steps
    )

    writer = SummaryWriter()

    for epoch in range(1, EPOCHS + 1):
        model.train()
        total_epoch_loss = 0.0
        optimizer.zero_grad()

        for idx, batch in enumerate(tqdm(train_loader, desc=f"Train Epoch {epoch}")):
            mel = batch["mel"].to(DEVICE)
            # Unpack new energy-based labels
            don_labels = batch["don_labels"].to(DEVICE)
            ka_labels = batch["ka_labels"].to(DEVICE)
            drumroll_labels = batch["drumroll_labels"].to(DEVICE)
            lengths = batch["lengths"].to(
                DEVICE
            )  # These are for the Conformer model output
            nps = batch["nps"].to(DEVICE)
            difficulty = batch["difficulty"].to(DEVICE)  # Add difficulty
            level = batch["level"].to(DEVICE)  # Add level

            output_dict = model(
                mel, lengths, nps, difficulty, level
            )  # Pass difficulty and level
            # output_dict["presence"] is now (B, T_out, 3) for don, ka, drumroll energies
            pred_energies_batch = output_dict["presence"]  # (B, T_out, 3)

            loss_input_batch = {
                "don_labels": don_labels,
                "ka_labels": ka_labels,
                "drumroll_labels": drumroll_labels,
                "lengths": lengths,  # Pass lengths for masking within the loss function
            }
            loss = criterion(output_dict, loss_input_batch)

            (loss / GRAD_ACCUM_STEPS).backward()

            if (idx + 1) % GRAD_ACCUM_STEPS == 0 or (idx + 1) == len(train_loader):
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            total_epoch_loss += loss.item()

            # Log plot for the first sample of the first batch in each training epoch
            if idx == 0:
                first_sample_pred_don = pred_energies_batch[0, :, 0]
                first_sample_pred_ka = pred_energies_batch[0, :, 1]
                first_sample_pred_drumroll = pred_energies_batch[0, :, 2]

                first_sample_true_don = don_labels[0, :]
                first_sample_true_ka = ka_labels[0, :]
                first_sample_true_drumroll = drumroll_labels[0, :]

                first_sample_length = lengths[
                    0
                ].item()  # Get the valid length of the first sample

                log_energy_plots_to_tensorboard(
                    writer,
                    "Train/Sample_0",
                    epoch,
                    first_sample_pred_don,
                    first_sample_pred_ka,
                    first_sample_pred_drumroll,
                    first_sample_true_don,
                    first_sample_true_ka,
                    first_sample_true_drumroll,
                    first_sample_length,
                    output_frame_hop_sec,
                )

        avg_train_loss = total_epoch_loss / len(train_loader)
        writer.add_scalar("Loss/Train_Avg", avg_train_loss, epoch)

        # Validation
        model.eval()
        total_val_loss = 0.0
        # Removed storage for classification logits/labels and confusion matrix components

        with torch.no_grad():
            for val_idx, batch in enumerate(
                tqdm(val_loader, desc=f"Val Epoch {epoch}")
            ):
                mel = batch["mel"].to(DEVICE)
                don_labels = batch["don_labels"].to(DEVICE)
                ka_labels = batch["ka_labels"].to(DEVICE)
                drumroll_labels = batch["drumroll_labels"].to(DEVICE)
                lengths = batch["lengths"].to(DEVICE)
                nps = batch["nps"].to(DEVICE)  # Ground truth NPS from batch
                difficulty = batch["difficulty"].to(DEVICE)  # Add difficulty
                level = batch["level"].to(DEVICE)  # Add level

                output_dict = model(
                    mel, lengths, nps, difficulty, level
                )  # Pass difficulty and level
                pred_energies_val_batch = output_dict["presence"]  # (B, T_out, 3)

                val_loss_input_batch = {
                    "don_labels": don_labels,
                    "ka_labels": ka_labels,
                    "drumroll_labels": drumroll_labels,
                    "lengths": lengths,
                }
                val_loss = criterion(output_dict, val_loss_input_batch)
                total_val_loss += val_loss.item()

                # Log plot for the first sample of the first batch in each validation epoch
                if val_idx == 0:
                    first_val_sample_pred_don = pred_energies_val_batch[0, :, 0]
                    first_val_sample_pred_ka = pred_energies_val_batch[0, :, 1]
                    first_val_sample_pred_drumroll = pred_energies_val_batch[0, :, 2]

                    first_val_sample_true_don = don_labels[0, :]
                    first_val_sample_true_ka = ka_labels[0, :]
                    first_val_sample_true_drumroll = drumroll_labels[0, :]

                    first_val_sample_length = lengths[0].item()

                    log_energy_plots_to_tensorboard(
                        writer,
                        "Eval/Sample_0",
                        epoch,
                        first_val_sample_pred_don,
                        first_val_sample_pred_ka,
                        first_val_sample_pred_drumroll,
                        first_val_sample_true_don,
                        first_val_sample_true_ka,
                        first_val_sample_true_drumroll,
                        first_val_sample_length,
                        output_frame_hop_sec,
                    )

                # Log ground truth NPS for reference during validation if needed
                # writer.add_scalar("NPS/GT_Val_Batch_Avg", nps.mean().item(), epoch * len(val_loader) + idx)

        avg_val_loss = total_val_loss / len(val_loader)
        writer.add_scalar("Loss/Val_Avg", avg_val_loss, epoch)

        # Log learning rate
        current_lr = optimizer.param_groups[0]["lr"]
        writer.add_scalar("LR/learning_rate", current_lr, epoch)

        # Log ground truth NPS from the last validation batch (or mean over epoch)
        if "nps" in batch:  # Check if nps is in the last batch
            writer.add_scalar(
                "NPS/GT_Val_LastBatch_Avg", batch["nps"].mean().item(), epoch
            )

        print(
            f"Epoch {epoch:02d} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | LR: {current_lr:.2e}"
        )

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            pat_count = 0
            torch.save(model.state_dict(), "best_model.pt")  # Changed model save name
            print(f"Saved new best model to best_model.pt at epoch {epoch}")
        else:
            pat_count += 1
            if pat_count >= patience:
                print("Early stopping!")
                break
    writer.close()

    model_id = "JacobLinCool/taiko-conformer-6"
    try:
        model.push_to_hub(model_id, commit_message="Upload trained model")
        upload_folder(
            repo_id=model_id,
            folder_path="runs",
            path_in_repo=".",
            commit_message="Upload training logs",
            ignore_patterns=["*.txt", "*.json", "*.csv"],
        )
        print(f"Model and logs uploaded to {model_id}")
    except Exception as e:
        print(f"Error uploading to Hugging Face Hub: {e}")
        print("Make sure you have the correct permissions and try again.")


if __name__ == "__main__":
    main()
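As an editorial aside (not part of the committed files): a minimal sketch of reloading the checkpoint that this script saves, assuming best_model.pt sits in the working directory.

import torch
from tc6.model import TaikoConformer6
from tc6.config import DEVICE

# Recreate the architecture, load the saved weights, and switch to eval mode
model = TaikoConformer6().to(DEVICE)
model.load_state_dict(torch.load("best_model.pt", map_location=DEVICE))
model.eval()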
tc7/__init__.py
ADDED
File without changes

tc7/config.py
ADDED
@@ -0,0 +1,27 @@
import torch

# ─── 1) CONFIG ─────────────────────────────────────────────────────
SAMPLE_RATE = 22050
N_MELS = 80
HOP_LENGTH = 256
TIME_SUB = 1
CNN_CH = 256
N_HEADS = 8
D_MODEL = 512
FF_DIM = 1024
N_LAYERS = 6
DEPTHWISE_CONV_KERNEL_SIZE = 31
DROPOUT = 0.1
HIDDEN_DIM = 64
N_TYPES = 7
BATCH_SIZE = 2
GRAD_ACCUM_STEPS = 8
LR = 3e-4
EPOCHS = 200
NPS_PENALTY_WEIGHT_ALPHA = 0.3
NPS_PENALTY_WEIGHT_BETA = 1.0
DEVICE = (
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)
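A quick illustrative check (not part of the commit) of the frame-rate and batching arithmetic these constants imply:

# Names below come from the config above
fps = SAMPLE_RATE / HOP_LENGTH                   # 22050 / 256 ≈ 86.13 model frames per second
hop_sec = HOP_LENGTH / SAMPLE_RATE               # ≈ 0.0116 s of audio per frame
effective_batch = BATCH_SIZE * GRAD_ACCUM_STEPS  # 2 * 8 = 16 sequences per optimizer step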
tc7/dataset.py
ADDED
@@ -0,0 +1,21 @@
from datasets import load_dataset, concatenate_datasets

ds1 = load_dataset("JacobLinCool/taiko-2023-1.1", split="train")
ds2 = load_dataset("JacobLinCool/taiko-2023-1.2", split="train")
ds3 = load_dataset("JacobLinCool/taiko-2023-1.3", split="train")
ds4 = load_dataset("JacobLinCool/taiko-2023-1.4", split="train")
ds5 = load_dataset("JacobLinCool/taiko-2023-1.5", split="train")
ds6 = load_dataset("JacobLinCool/taiko-2023-1.6", split="train")
ds7 = load_dataset("JacobLinCool/taiko-2023-1.7", split="train")
ds = concatenate_datasets([ds1, ds2, ds3, ds4, ds5, ds6, ds7]).with_format("torch")

good = list(range(len(ds)))
good.remove(1079)  # 1079 has a file problem
ds = ds.select(good)

# for local test
# ds = (
#     load_dataset("JacobLinCool/taiko-2023-1.6", split="train")
#     .with_format("torch")
#     .select(range(10))
# )
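An illustrative peek (not part of the commit) at the fields that tc7/preprocess.py reads from each example; exact tensor shapes depend on the clip:

sample = ds[0]
print(sample["audio"]["sampling_rate"])  # source sample rate of the clip
print(sample["audio"]["array"].shape)    # raw waveform tensor
print(len(sample["oni"]))                # onset tuples for the Oni course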
tc7/infer.py
ADDED
@@ -0,0 +1,354 @@
import time
import torch
import torchaudio
import matplotlib.pyplot as plt
import numpy as np
from .config import SAMPLE_RATE, N_MELS, HOP_LENGTH
import torch.profiler


# --- PREPROCESSING (match training) ---
def preprocess_audio(audio_path):
    wav, sr = torchaudio.load(audio_path)
    wav = wav.mean(dim=0)  # mono
    if sr != SAMPLE_RATE:
        wav = torchaudio.functional.resample(wav, sr, SAMPLE_RATE)
    wav = wav / (wav.abs().max() + 1e-8)  # Normalize audio

    mel_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=SAMPLE_RATE,
        n_mels=N_MELS,
        hop_length=HOP_LENGTH,
        n_fft=2048,
    )
    mel = mel_transform(wav)
    return mel  # mel is (N_MELS, T_mel)


# --- INFERENCE ---
def run_inference(model, mel_input, nps_input, difficulty_input, level_input, device):
    model.eval()
    with torch.no_grad():
        mel = mel_input.to(device).unsqueeze(0)  # (1, N_MELS, T_mel)
        nps = nps_input.to(device).unsqueeze(0)  # (1,)
        difficulty = difficulty_input.to(device).unsqueeze(0)  # (1,)
        level = level_input.to(device).unsqueeze(0)  # (1,)

        mel_cnn_input = mel.unsqueeze(1)  # (1, 1, N_MELS, T_mel)

        conformer_lengths = torch.tensor(
            [mel_cnn_input.shape[-1]], dtype=torch.long, device=device
        )

        with torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                *(
                    [torch.profiler.ProfilerActivity.CUDA]
                    if device.type == "cuda"
                    else []
                ),
            ],
            record_shapes=True,
            profile_memory=True,
            with_stack=False,
            with_flops=True,
        ) as prof:
            out_dict = model(mel_cnn_input, conformer_lengths, nps, difficulty, level)
        print(
            prof.key_averages().table(
                sort_by=(
                    "self_cuda_memory_usage"
                    if device.type == "cuda"
                    else "self_cpu_time_total"
                ),
                row_limit=20,
            )
        )

        energies = out_dict["presence"].squeeze(0).cpu().numpy()

        don_energy = energies[:, 0]
        ka_energy = energies[:, 1]
        drumroll_energy = energies[:, 2]

        return don_energy, ka_energy, drumroll_energy

# --- DECODE TO ONSETS ---
def decode_onsets(
    don_energy,
    ka_energy,
    drumroll_energy,
    hop_sec,
    threshold=0.5,
    min_distance_frames=3,
):
    results = []
    T_out = len(don_energy)
    last_onset_frame = -min_distance_frames

    for i in range(1, T_out - 1):  # Iterate considering neighbors for peak detection
        if i < last_onset_frame + min_distance_frames:
            continue

        e_don, e_ka, e_drum = don_energy[i], ka_energy[i], drumroll_energy[i]
        energies_at_i = {
            1: e_don,
            2: e_ka,
            5: e_drum,
        }  # Type mapping: 1:Don, 2:Ka, 5:Drumroll

        # Find which energy is max and whether it is a peak above threshold.
        # Sort by energy value descending to prioritize higher energy in case
        # of ties for the peak condition.
        sorted_types_by_energy = sorted(
            energies_at_i.keys(), key=lambda x: energies_at_i[x], reverse=True
        )

        for onset_type in sorted_types_by_energy:
            if onset_type == 1:
                current_energy_series = don_energy
            elif onset_type == 2:
                current_energy_series = ka_energy
            else:  # onset_type == 5
                current_energy_series = drumroll_energy

            energy_val = current_energy_series[i]

            if (
                energy_val > threshold
                and energy_val > current_energy_series[i - 1]
                and energy_val > current_energy_series[i + 1]
            ):
                # Iterating types in descending energy order and breaking after
                # the first detection keeps at most one onset type per frame.
                results.append((i * hop_sec, onset_type))
                last_onset_frame = i
                break  # Only one onset type per frame

    return results

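Putting the pieces above together; a usage sketch, where the Hub repo id, input file name, and control values are placeholders rather than anything this commit pins down:

import torch
from tc7.model import TaikoConformer7
from tc7.config import SAMPLE_RATE, HOP_LENGTH

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# from_pretrained comes from PyTorchModelHubMixin; the repo id is a placeholder
model = TaikoConformer7.from_pretrained("JacobLinCool/taiko-conformer-7").to(device)

mel = preprocess_audio("song.ogg")  # (N_MELS, T_mel)
nps = torch.tensor(6.0)             # desired note density (control input)
difficulty = torch.tensor(3.0)      # 3 = oni in preprocess.py's mapping
level = torch.tensor(9.0)
don, ka, roll = run_inference(model, mel, nps, difficulty, level, device)

hop_sec = HOP_LENGTH / SAMPLE_RATE  # seconds per output frame
onsets = decode_onsets(don, ka, roll, hop_sec, threshold=0.5)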
# --- VISUALIZATION ---
def plot_results(
    mel_spectrogram,
    don_energy,
    ka_energy,
    drumroll_energy,
    onsets,
    hop_sec,
    out_path=None,
):
    # mel_spectrogram is (N_MELS, T_mel)
    T_mel = mel_spectrogram.shape[1]
    T_out = len(don_energy)  # Length of energy arrays (model output time dimension)

    # Time axes. The CNN frontend does not stride in time, so the model output
    # has one frame per mel frame (T_out == T_mel), each spanning
    # HOP_LENGTH / SAMPLE_RATE seconds; the `hop_sec` passed by the caller is
    # assumed to match this output frame rate. (TIME_SUB only affects label
    # generation in preprocess.py, not the feature rate here.)
    time_axis_mel = np.arange(T_mel) * (HOP_LENGTH / SAMPLE_RATE)
    time_axis_energies = np.arange(T_out) * hop_sec

    fig, ax1 = plt.subplots(figsize=(100, 10))

    # Plot Mel Spectrogram on ax1
    mel_db = torchaudio.functional.amplitude_to_DB(
        mel_spectrogram, multiplier=10.0, amin=1e-10, db_multiplier=0.0
    )
    img = ax1.imshow(
        mel_db.numpy(),
        aspect="auto",
        origin="lower",
        cmap="magma",
        extent=[time_axis_mel[0], time_axis_mel[-1], 0, N_MELS],
    )
    ax1.set_title("Mel Spectrogram with Predicted Energies and Onsets")
    ax1.set_xlabel("Time (s)")
    ax1.set_ylabel("Mel Bin")
    fig.colorbar(img, ax=ax1, format="%+2.0f dB")

    # Create a second y-axis for energies
    ax2 = ax1.twinx()
    ax2.plot(time_axis_energies, don_energy, label="Don Energy", color="red")
    ax2.plot(time_axis_energies, ka_energy, label="Ka Energy", color="blue")
    ax2.plot(
        time_axis_energies, drumroll_energy, label="Drumroll Energy", color="green"
    )
    ax2.set_ylabel("Energy")
    ax2.set_ylim(0, 1.2)  # Assuming energies are somewhat normalized or bounded

    # Overlay onsets from decode_onsets (t is already in seconds)
    labeled_types = set()
    # Group drumrolls into segments (reuse logic from write_tja)
    drumroll_times = [t_sec for t_sec, typ in onsets if typ == 5]
    drumroll_times.sort()
    drumroll_segments = []
    if drumroll_times:
        seg_start = drumroll_times[0]
        prev = drumroll_times[0]
        for t in drumroll_times[1:]:
            if t - prev <= hop_sec * 6:  # up to 5-frame gap
                prev = t
            else:
                drumroll_segments.append((seg_start, prev))
                seg_start = t
                prev = t
        drumroll_segments.append((seg_start, prev))
    # Plot Don/Ka onsets as vertical lines
    for t_sec, typ in onsets:
        if typ == 5:
            continue  # skip drumroll onsets
        color_map = {1: "darkred", 2: "darkblue"}
        label_map = {1: "Don Onset", 2: "Ka Onset"}
        line_color = color_map.get(typ, "black")
        line_label = label_map.get(typ, f"Type {typ} Onset")
        if typ not in labeled_types:
            ax1.axvline(
                t_sec, color=line_color, linestyle="--", alpha=0.9, label=line_label
            )
            labeled_types.add(typ)
        else:
            ax1.axvline(t_sec, color=line_color, linestyle="--", alpha=0.9)
    # Plot drumroll segments as shaded regions
    for seg_start, seg_end in drumroll_segments:
        ax1.axvspan(
            seg_start,
            seg_end + hop_sec,
            color="green",
            alpha=0.2,
            label="Drumroll Segment" if "drumroll" not in labeled_types else None,
        )
        labeled_types.add("drumroll")

    # Combine legends from both axes
    lines, labels = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax2.legend(lines + lines2, labels + labels2, loc="upper right")

    fig.tight_layout()

    # Save to file if a path is provided, otherwise return the figure object
    if out_path:
        plt.savefig(out_path)
        print(f"Saved plot to {out_path}")
        plt.close(fig)
        return out_path
    else:
        return fig

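For reference, the slot grid used by write_tja below: at the default bpm = 160 and quantize = 96, one 4/4 measure lasts 1.5 s, so each chart slot spans about 15.6 ms, roughly 1.3 model frames (illustrative arithmetic, not part of the commit):

bpm, quantize = 160, 96
sec_per_measure = 60 / bpm * 4         # 1.5 s per 4/4 measure
slot_sec = sec_per_measure / quantize  # 0.015625 s per chart slot
print(sec_per_measure, slot_sec)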
def write_tja(onsets, out_path=None, bpm=160, quantize=96, audio="audio.wav"):
    # TJA types: 0:no note, 1:Don, 2:Ka, 3:BigDon, 4:BigKa, 5:DrumrollStart, 8:DrumrollEnd
    # Model output types: 1:Don, 2:Ka, 5:Drumroll (interpreted as start/single)
    sec_per_beat = 60 / bpm
    beats_per_measure = 4  # Assuming 4/4 time signature
    sec_per_measure = sec_per_beat * beats_per_measure
    # Step 1: Map onsets to (measure_idx, slot, typ)
    slot_events = []
    for t, typ in onsets:
        measure_idx = int(t // sec_per_measure)
        t_in_measure = t % sec_per_measure
        slot = int(round(t_in_measure / sec_per_measure * quantize))
        if slot >= quantize:
            slot = quantize - 1
        slot_events.append((measure_idx, slot, typ))
    # Step 2: Build measure/slot grid
    if slot_events:
        max_measure_idx = max(m for m, _, _ in slot_events)
    else:
        max_measure_idx = -1
    measures = {i: [0] * quantize for i in range(max_measure_idx + 1)}
    # Step 3: Place Don/Ka, collect drumrolls
    drumroll_slots = set()
    for m, s, typ in slot_events:
        if typ in [1, 2]:
            measures[m][s] = typ
        elif typ == 5:
            drumroll_slots.add((m, s))
    # Step 4: Process drumrolls into contiguous regions, mark 5 (start) and 8 (end)
    # Flatten all slots to a sorted list of (measure, slot)
    drumroll_list = sorted(drumroll_slots)
    # Group into contiguous regions (allowing a gap of 5 slots)
    grouped = []
    group = []
    for ms in drumroll_list:
        if not group:
            group = [ms]
        else:
            last_m, last_s = group[-1]
            m, s = ms
            # Calculate slot distance, considering measure wrap
            if m == last_m:
                slot_dist = s - last_s
            elif m == last_m + 1 and last_s <= quantize - 1:
                slot_dist = (quantize - 1 - last_s) + s + 1
            else:
                slot_dist = None
            # Allow gap of up to 5 slots (slot_dist <= 6)
            if slot_dist is not None and 1 <= slot_dist <= 6:
                group.append(ms)
            else:
                grouped.append(group)
                group = [ms]
    if group:
        grouped.append(group)
    # Mark 5 (start) and 8 (end) for each group
    for region in grouped:
        if len(region) == 1:
            m, s = region[0]
            measures[m][s] = 5
            # Place 8 in next slot (or next measure if at end)
            if s < quantize - 1:
                measures[m][s + 1] = 8
            elif m < max_measure_idx:
                measures[m + 1][0] = 8
        else:
            m_start, s_start = region[0]
            m_end, s_end = region[-1]
            measures[m_start][s_start] = 5
            measures[m_end][s_end] = 8
            # Middle slots stay 0 (already the default)
    # Step 5: Generate TJA content
    tja_content = []
    tja_content.append(f"TITLE:{audio} (TC7, {time.strftime('%Y-%m-%d %H:%M:%S')})")
    tja_content.append(f"BPM:{bpm}")
    tja_content.append(f"WAVE:{audio}")
    tja_content.append("OFFSET:0")
    tja_content.append("COURSE:Oni\nLEVEL:9\n")
    tja_content.append("#START")
    for i in range(max_measure_idx + 1):
        notes = measures.get(i, [0] * quantize)
        line = "".join(str(n) for n in notes)
        tja_content.append(line + ",")
    tja_content.append("#END")

    tja_string = "\n".join(tja_content)

    # If out_path is provided, also write to file
    if out_path:
        with open(out_path, "w", encoding="utf-8") as f:
            f.write(tja_string)
        print(f"TJA chart saved to {out_path}")

    return tja_string
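Continuing the usage sketch from above, the decoded onsets can then be rendered to a chart and a diagnostic plot (file names are placeholders):

tja_string = write_tja(onsets, out_path="chart.tja", bpm=160, audio="song.ogg")
plot_results(mel, don, ka, roll, onsets, hop_sec, out_path="chart.png")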
tc7/loss.py
ADDED
@@ -0,0 +1,94 @@
import torch
import torch.nn as nn


class TaikoLoss(nn.Module):
    def __init__(
        self,
        reduction="mean",
        nps_penalty_weight_alpha=0.3,
        nps_penalty_weight_beta=1.0,
    ):
        super().__init__()
        self.mse_loss = nn.MSELoss(reduction="none")
        self.reduction = reduction
        self.nps_penalty_weight_alpha = nps_penalty_weight_alpha
        self.nps_penalty_weight_beta = nps_penalty_weight_beta

    def forward(self, outputs, batch):
        """
        Calculates the MSE loss for energy-based predictions, with a two-level penalty
        based on sliding NPS values:
        - a heavier penalty where sliding_nps is 0;
        - a continuous penalty where sliding_nps > 0.

        Args:
            outputs (dict): Model output, containing a 'presence' tensor.
                outputs['presence'] shape: (B, T, 3) for don, ka, drumroll energies.
            batch (dict): Batch data from collate_fn, containing true labels, lengths,
                and sliding_nps_labels.
                batch['sliding_nps_labels'] shape: (B, T)
        Returns:
            torch.Tensor: The calculated loss.
        """
        pred_energies = outputs["presence"]  # (B, T, 3)
        true_don = batch["don_labels"]  # (B, T)
        true_ka = batch["ka_labels"]  # (B, T)
        true_drumroll = batch["drumroll_labels"]  # (B, T)
        true_energies = torch.stack([true_don, true_ka, true_drumroll], dim=2).to(
            pred_energies.device
        )  # (B, T, 3)

        B, T, _ = pred_energies.shape

        # Create a mask based on batch['lengths'] to ignore padded parts of sequences.
        # batch['lengths'] gives the actual length of each sequence in the batch.
        # mask shape: (B, T)
        mask_2d = torch.arange(T, device=pred_energies.device).expand(B, T) < batch[
            "lengths"
        ].to(pred_energies.device).unsqueeze(1)
        # Expand mask to (B, T, 1) to broadcast across the 3 energy channels
        mask_3d = mask_2d.unsqueeze(2)  # (B, T, 1)

        # Calculate element-wise MSE loss
        mse_loss_elementwise = self.mse_loss(pred_energies, true_energies)  # (B, T, 3)

        # Calculate two-level sliding NPS penalty
        sliding_nps = batch["sliding_nps_labels"].to(pred_energies.device)  # (B, T)

        penalty_coefficients = torch.zeros_like(sliding_nps)  # (B, T)

        is_zero_nps = sliding_nps == 0.0
        is_not_zero_nps = ~is_zero_nps

        # Apply heavy penalty where sliding_nps is 0
        penalty_coefficients[is_zero_nps] = self.nps_penalty_weight_beta

        # Apply continuous penalty where sliding_nps > 0
        penalty_coefficients[is_not_zero_nps] = self.nps_penalty_weight_alpha * (
            1 - sliding_nps[is_not_zero_nps]
        )

        # Apply penalty factor to the MSE loss
        loss_elementwise = mse_loss_elementwise * (
            1 + penalty_coefficients.unsqueeze(2)
        )

        # Apply the mask to the combined loss
        masked_loss = loss_elementwise * mask_3d

        if self.reduction == "mean":
            # Sum the loss over all valid (unmasked) elements and divide by their count
            total_loss = masked_loss.sum()
            num_valid_elements = mask_3d.sum()  # Number of valid (B, T) positions (the mask has a singleton channel dim)
            if num_valid_elements > 0:
                return total_loss / num_valid_elements
            else:
                # Avoid division by zero if there are no valid elements (e.g., empty batch or all lengths are 0)
                return torch.tensor(
                    0.0, device=pred_energies.device, requires_grad=True
                )
        elif self.reduction == "sum":
            return masked_loss.sum()
        else:  # 'none' or any other case
            return masked_loss
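A worked example of the penalty weighting above, using the defaults from tc7/config.py (alpha = 0.3, beta = 1.0):

alpha, beta = 0.3, 1.0
for s in (0.0, 0.5, 1.0):  # normalized sliding-NPS values
    coeff = beta if s == 0.0 else alpha * (1 - s)
    print(s, 1 + coeff)  # MSE multiplier: 2.0x at s=0, 1.15x at s=0.5, 1.0x at s=1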
tc7/model.py
ADDED
@@ -0,0 +1,166 @@
import torch
import torch.nn as nn
from torchaudio.models import Conformer
from huggingface_hub import PyTorchModelHubMixin
from .config import (
    N_MELS,
    CNN_CH,
    N_HEADS,
    D_MODEL,
    FF_DIM,
    N_LAYERS,
    DROPOUT,
    DEPTHWISE_CONV_KERNEL_SIZE,
    HIDDEN_DIM,
    DEVICE,
)


class TaikoConformer7(nn.Module, PyTorchModelHubMixin):
    def __init__(self):
        super().__init__()
        # 1) CNN frontend: frequency-only pooling
        self.cnn = nn.Sequential(
            nn.Conv2d(1, CNN_CH, 3, stride=(2, 1), padding=1),
            nn.BatchNorm2d(CNN_CH),
            nn.GELU(),
            nn.Dropout2d(DROPOUT),
            nn.Conv2d(CNN_CH, CNN_CH, 3, stride=(2, 1), padding=1),
            nn.BatchNorm2d(CNN_CH),
            nn.GELU(),
            nn.Dropout2d(DROPOUT),
        )
        feat_dim = CNN_CH * (N_MELS // 4)

        # 2) Linear projection to model dimension
        self.proj = nn.Linear(feat_dim, D_MODEL)

        # 3) FiLM conditioning for notes_per_second, difficulty, and level
        self.film_nps = nn.Linear(1, 2 * D_MODEL)
        self.film_difficulty = nn.Linear(
            1, 2 * D_MODEL
        )  # Assuming difficulty is a single scalar
        self.film_level = nn.Linear(1, 2 * D_MODEL)  # Assuming level is a single scalar

        # 4) Conformer encoder
        self.encoder = Conformer(
            input_dim=D_MODEL,
            num_heads=N_HEADS,
            ffn_dim=FF_DIM,
            num_layers=N_LAYERS,
            depthwise_conv_kernel_size=DEPTHWISE_CONV_KERNEL_SIZE,
            dropout=DROPOUT,
            use_group_norm=False,
            convolution_first=False,
        )

        # 5) Presence regressor head
        self.presence_regressor = nn.Sequential(
            nn.Dropout(DROPOUT),
            nn.Linear(D_MODEL, HIDDEN_DIM),
            nn.GELU(),
            nn.Dropout(DROPOUT),
            nn.Linear(HIDDEN_DIM, 3),  # Don, Ka, DrumRoll energy
            nn.Sigmoid(),  # Output between 0 and 1
        )

        # 6) Initialize weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(
        self,
        mel: torch.Tensor,
        lengths: torch.Tensor,
        notes_per_second: torch.Tensor,
        difficulty: torch.Tensor,
        level: torch.Tensor,
    ):
        """
        Args:
            mel: (B, 1, N_MELS, T_mel)
            lengths: (B,) sequence lengths after the CNN frontend
            notes_per_second: (B,) note-density control values
            difficulty: (B,) difficulty values
            level: (B,) level values
        Returns:
            Dict with:
                'presence': (B, T_cnn_out, 3)  # Corrected from 4 to 3
                'lengths': lengths
        """
        # CNN frontend
        x = self.cnn(mel)  # (B, C, F, T)
        B, C, F, T = x.size()
        x = x.permute(0, 3, 1, 2).reshape(B, T, C * F)

        # Project to model dimension
        x = self.proj(x)  # (B, T, D_MODEL)

        # FiLM conditioning
        nps = notes_per_second.unsqueeze(-1).float()  # (B, 1)
        gamma_beta_nps = self.film_nps(nps)  # (B, 2*D_MODEL)
        gamma_nps, beta_nps = gamma_beta_nps.chunk(2, dim=-1)
        x = gamma_nps.unsqueeze(1) * x + beta_nps.unsqueeze(1)

        diff = difficulty.unsqueeze(-1).float()  # (B, 1)
        gamma_beta_diff = self.film_difficulty(diff)  # (B, 2*D_MODEL)
        gamma_diff, beta_diff = gamma_beta_diff.chunk(2, dim=-1)
        x = gamma_diff.unsqueeze(1) * x + beta_diff.unsqueeze(1)

        lvl = level.unsqueeze(-1).float()  # (B, 1)
        gamma_beta_lvl = self.film_level(lvl)  # (B, 2*D_MODEL)
        gamma_lvl, beta_lvl = gamma_beta_lvl.chunk(2, dim=-1)
        x = gamma_lvl.unsqueeze(1) * x + beta_lvl.unsqueeze(1)

        # Conformer encoder
        x, _ = self.encoder(x, lengths=lengths)

        # Presence prediction
        presence = self.presence_regressor(x)
        return {"presence": presence, "lengths": lengths}


if __name__ == "__main__":
    model = TaikoConformer7().to(device=DEVICE)
    print(model)

    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"{name}: {param.numel():,}")

    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {params / 1e6:.2f}M")

    batch_size = 4
    mel_time_steps = 1024
    input_mel = torch.randn(batch_size, 1, N_MELS, mel_time_steps).to(DEVICE)

    conformer_lengths = torch.tensor(
        [mel_time_steps] * batch_size, dtype=torch.long
    ).to(DEVICE)

    notes_per_second_input = torch.tensor([10.0] * batch_size, dtype=torch.float32).to(
        DEVICE
    )
    difficulty_input = torch.tensor([1.0] * batch_size, dtype=torch.float32).to(
        DEVICE
    )  # Example difficulty
    level_input = torch.tensor([5.0] * batch_size, dtype=torch.float32).to(
        DEVICE
    )  # Example level

    output = model(
        input_mel,
        conformer_lengths,
        notes_per_second_input,
        difficulty_input,
        level_input,
    )
    print("Output shapes:")
    for key, value in output.items():
        print(f"{key}: {value.shape}")
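Shape arithmetic behind feat_dim in the CNN frontend above: the two stride-(2, 1) convolutions halve the frequency axis twice and leave the time axis untouched, so:

N_MELS, CNN_CH = 80, 256
freq_after_cnn = N_MELS // 2 // 2    # 80 -> 40 -> 20 mel bins
feat_dim = CNN_CH * freq_after_cnn   # 256 * 20 = 5120, the input width of self.proj
print(freq_after_cnn, feat_dim)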
tc7/preprocess.py
ADDED
@@ -0,0 +1,400 @@
import math
import torch
import torch.nn as nn
import torchaudio
from torchaudio.transforms import FrequencyMasking
from tja import parse_tja, PyParsingMode
from .config import N_TYPES, SAMPLE_RATE, N_MELS, HOP_LENGTH, TIME_SUB
from .model import TaikoConformer7


mel_transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=SAMPLE_RATE,
    n_mels=N_MELS,
    hop_length=HOP_LENGTH,
    n_fft=2048,
)


freq_mask = FrequencyMasking(freq_mask_param=15)

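For a sense of scale (illustrative, not part of the commit): a 10-second clip at the configured rates yields roughly 862 mel frames from the transform above.

import torch
wav = torch.randn(SAMPLE_RATE * 10)    # 10 s of (random) audio at SAMPLE_RATE
mel = mel_transform(wav).unsqueeze(0)  # (1, N_MELS, ~862); T ≈ len(wav) / HOP_LENGTH + 1
print(mel.shape)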
| 22 |
+
def preprocess(example, difficulty="oni"):
|
| 23 |
+
wav_tensor = example["audio"]["array"]
|
| 24 |
+
sr = example["audio"]["sampling_rate"]
|
| 25 |
+
# 1) load & resample
|
| 26 |
+
if sr != SAMPLE_RATE:
|
| 27 |
+
wav_tensor = torchaudio.functional.resample(wav_tensor, sr, SAMPLE_RATE)
|
| 28 |
+
# normalize audio
|
| 29 |
+
wav_tensor = wav_tensor / (wav_tensor.abs().max() + 1e-8)
|
| 30 |
+
# add random Gaussian noise
|
| 31 |
+
if torch.rand(1).item() < 0.5:
|
| 32 |
+
wav_tensor = wav_tensor + 0.005 * torch.randn_like(wav_tensor)
|
| 33 |
+
# 2) mel: (1, N_MELS, T)
|
| 34 |
+
mel = mel_transform(wav_tensor).unsqueeze(0)
|
| 35 |
+
# apply SpecAugment
|
| 36 |
+
mel = freq_mask(mel)
|
| 37 |
+
_, _, T = mel.shape
|
| 38 |
+
# 3) build label sequence of length ceil(T / TIME_SUB)
|
| 39 |
+
T_sub = math.ceil(T / TIME_SUB)
|
| 40 |
+
|
| 41 |
+
# Initialize energy-based labels for Don, Ka, Drumroll
|
| 42 |
+
don_labels = torch.zeros(T_sub, dtype=torch.float32)
|
| 43 |
+
ka_labels = torch.zeros(T_sub, dtype=torch.float32)
|
| 44 |
+
drumroll_labels = torch.zeros(T_sub, dtype=torch.float32)
|
| 45 |
+
sliding_nps_labels = torch.zeros(
|
| 46 |
+
T_sub, dtype=torch.float32
|
| 47 |
+
) # New label for sliding NPS
|
| 48 |
+
|
| 49 |
+
# Define exponential decay tail parameters
|
| 50 |
+
tail_length = 40 # number of frames for decay tail
|
| 51 |
+
decay_rate = 8.0 # decay rate parameter, adjust as needed
|
| 52 |
+
tail_kernel = torch.exp(
|
| 53 |
+
-torch.arange(0, tail_length, dtype=torch.float32) / decay_rate
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
fps = SAMPLE_RATE / HOP_LENGTH
|
| 57 |
+
num_valid_notes = 0
|
| 58 |
+
for onset in example[difficulty]:
|
| 59 |
+
typ, t_start, t_end, *_ = onset
|
| 60 |
+
|
| 61 |
+
# Assuming N_TYPES in config is appropriately set (e.g., 7 or more)
|
| 62 |
+
if typ < 1 or typ > N_TYPES: # Filter out invalid types
|
| 63 |
+
continue
|
| 64 |
+
|
| 65 |
+
num_valid_notes += 1
|
| 66 |
+
|
| 67 |
+
exact_frame_start = t_start.item() * fps
|
| 68 |
+
|
| 69 |
+
# Type 1 and 3 are Don, Type 2 and 4 are Ka
|
| 70 |
+
if typ == 1 or typ == 3 or typ == 2 or typ == 4:
|
| 71 |
+
exact_hit_time_sub = exact_frame_start / TIME_SUB
|
| 72 |
+
|
| 73 |
+
current_labels = don_labels if (typ == 1 or typ == 3) else ka_labels
|
| 74 |
+
|
| 75 |
+
start_points_info = []
|
| 76 |
+
rounded_hit_time_sub = round(exact_hit_time_sub)
|
| 77 |
+
|
| 78 |
+
if (
|
| 79 |
+
abs(exact_hit_time_sub - rounded_hit_time_sub) < 1e-6
|
| 80 |
+
): # Tolerance for float precision
|
| 81 |
+
idx_single = int(rounded_hit_time_sub)
|
| 82 |
+
if 0 <= idx_single < T_sub:
|
| 83 |
+
start_points_info.append({"idx": idx_single, "weight": 1.0})
|
| 84 |
+
else:
|
| 85 |
+
idx_floor = math.floor(exact_hit_time_sub)
|
| 86 |
+
idx_ceil = idx_floor + 1
|
| 87 |
+
|
| 88 |
+
frac = exact_hit_time_sub - idx_floor
|
| 89 |
+
weight_ceil = frac
|
| 90 |
+
weight_floor = 1.0 - frac
|
| 91 |
+
|
| 92 |
+
if weight_floor > 1e-6 and 0 <= idx_floor < T_sub:
|
| 93 |
+
start_points_info.append({"idx": idx_floor, "weight": weight_floor})
|
| 94 |
+
if weight_ceil > 1e-6 and 0 <= idx_ceil < T_sub:
|
| 95 |
+
start_points_info.append({"idx": idx_ceil, "weight": weight_ceil})
|
| 96 |
+
|
| 97 |
+
for point_info in start_points_info:
|
| 98 |
+
start_idx = point_info["idx"]
|
| 99 |
+
weight = point_info["weight"]
|
| 100 |
+
for k_idx, kernel_val in enumerate(tail_kernel):
|
| 101 |
+
target_idx = start_idx + k_idx
|
| 102 |
+
if 0 <= target_idx < T_sub:
|
| 103 |
+
current_labels[target_idx] = max(
|
| 104 |
+
current_labels[target_idx].item(),
|
| 105 |
+
weight * kernel_val.item(),
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
# Type 5, 6, 7 are Drumroll
|
| 109 |
+
elif typ >= 5 and typ <= 7:
|
| 110 |
+
exact_frame_end = t_end.item() * fps
|
| 111 |
+
exact_start_time_sub = exact_frame_start / TIME_SUB
|
| 112 |
+
exact_end_time_sub = exact_frame_end / TIME_SUB
|
| 113 |
+
|
| 114 |
+
# Improved drumroll body
|
| 115 |
+
body_loop_start_idx = math.floor(exact_start_time_sub)
|
| 116 |
+
body_loop_end_idx = math.ceil(exact_end_time_sub)
|
| 117 |
+
|
| 118 |
+
for dr_idx in range(body_loop_start_idx, body_loop_end_idx):
|
| 119 |
+
if 0 <= dr_idx < T_sub:
|
| 120 |
+
drumroll_labels[dr_idx] = 1.0
|
| 121 |
+
|
| 122 |
+
# Improved drumroll tail (starts from exact_end_time_sub)
|
| 123 |
+
tail_start_points_info = []
|
| 124 |
+
rounded_end_time_sub = round(exact_end_time_sub)
|
| 125 |
+
if abs(exact_end_time_sub - rounded_end_time_sub) < 1e-6:
|
| 126 |
+
idx_single_tail = int(rounded_end_time_sub)
|
| 127 |
+
if 0 <= idx_single_tail < T_sub:
|
| 128 |
+
tail_start_points_info.append(
|
| 129 |
+
{"idx": idx_single_tail, "weight": 1.0}
|
| 130 |
+
)
|
| 131 |
+
else:
|
| 132 |
+
idx_floor_tail = math.floor(exact_end_time_sub)
|
| 133 |
+
idx_ceil_tail = idx_floor_tail + 1
|
| 134 |
+
|
| 135 |
+
frac_tail = exact_end_time_sub - idx_floor_tail
|
| 136 |
+
weight_ceil_tail = frac_tail
|
| 137 |
+
weight_floor_tail = 1.0 - frac_tail
|
| 138 |
+
|
| 139 |
+
if weight_floor_tail > 1e-6 and 0 <= idx_floor_tail < T_sub:
|
| 140 |
+
tail_start_points_info.append(
|
| 141 |
+
{"idx": idx_floor_tail, "weight": weight_floor_tail}
|
| 142 |
+
)
|
| 143 |
+
if weight_ceil_tail > 1e-6 and 0 <= idx_ceil_tail < T_sub:
|
| 144 |
+
tail_start_points_info.append(
|
| 145 |
+
{"idx": idx_ceil_tail, "weight": weight_ceil_tail}
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
for point_info in tail_start_points_info:
|
| 149 |
+
start_idx = point_info["idx"]
|
| 150 |
+
weight = point_info["weight"]
|
| 151 |
+
for k_idx, kernel_val in enumerate(tail_kernel):
|
| 152 |
+
target_idx = start_idx + k_idx
|
| 153 |
+
if 0 <= target_idx < T_sub:
|
| 154 |
+
drumroll_labels[target_idx] = max(
|
| 155 |
+
drumroll_labels[target_idx].item(),
|
| 156 |
+
weight * kernel_val.item(),
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
# Calculate sliding window NPS
|
| 160 |
+
note_events = (
|
| 161 |
+
[]
|
| 162 |
+
) # Store tuples of (time_sec, type_is_drumroll_start_or_end, duration_if_drumroll)
|
| 163 |
+
for onset in example[difficulty]:
|
| 164 |
+
typ, t_start_tensor, t_end_tensor, *_ = onset
|
| 165 |
+
t_start = t_start_tensor.item()
|
| 166 |
+
t_end = t_end_tensor.item()
|
| 167 |
+
|
| 168 |
+
if typ in [1, 2, 3, 4]: # Don or Ka
|
| 169 |
+
note_events.append(
|
| 170 |
+
(t_start, False, 0)
|
| 171 |
+
) # False indicates not a drumroll event, duration 0
|
| 172 |
+
elif typ >= 5 and typ <= 7: # Drumroll
|
| 173 |
+
note_events.append(
|
| 174 |
+
(t_start, True, t_end - t_start)
|
| 175 |
+
) # True indicates drumroll start, store duration
|
| 176 |
+
# We don't explicitly need a drumroll end event for this calculation method
|
| 177 |
+
|
| 178 |
+
note_events.sort(key=lambda x: x[0]) # Sort by time
|
| 179 |
+
|
| 180 |
+
window_duration_seconds = 0.5
|
| 181 |
+
# drumroll_nps_rate = 10.0 # Removed: Will use adaptive rate
|
| 182 |
+
|
| 183 |
+
# Step 1: Calculate base_sliding_nps_labels (Don/Ka only)
|
| 184 |
+
base_don_ka_sliding_nps = torch.zeros(T_sub, dtype=torch.float32)
|
| 185 |
+
time_step_duration_sec = TIME_SUB / fps # Duration of one T_sub segment
|
| 186 |
+
|
| 187 |
+
for k_idx in range(T_sub):
|
| 188 |
+
k_window_end_sec = ((k_idx + 1) * TIME_SUB) / fps
|
| 189 |
+
k_window_start_sec = k_window_end_sec - window_duration_seconds
|
| 190 |
+
|
| 191 |
+
current_don_ka_count = 0.0
|
| 192 |
+
for event_t, is_drumroll, _ in note_events:
|
| 193 |
+
if not is_drumroll: # Don or Ka hit
|
| 194 |
+
if k_window_start_sec <= event_t < k_window_end_sec:
|
| 195 |
+
current_don_ka_count += 1
|
| 196 |
+
base_don_ka_sliding_nps[k_idx] = current_don_ka_count / window_duration_seconds
|
| 197 |
+
|
| 198 |
+
# Step 2: Calculate adaptive_drumroll_rates_for_all_events
|
| 199 |
+
adaptive_drumroll_rates_for_all_events = []
|
| 200 |
+
for event_t, is_drumroll, drumroll_dur in note_events:
|
| 201 |
+
if is_drumroll:
|
| 202 |
+
drumroll_start_sec = event_t
|
| 203 |
+
drumroll_end_sec = event_t + drumroll_dur
|
| 204 |
+
|
| 205 |
+
slice_start_idx = math.floor(drumroll_start_sec / time_step_duration_sec)
|
| 206 |
+
slice_end_idx = math.ceil(drumroll_end_sec / time_step_duration_sec)
|
| 207 |
+
|
| 208 |
+
slice_start_idx = max(0, slice_start_idx)
|
| 209 |
+
slice_end_idx = min(T_sub, slice_end_idx)
|
| 210 |
+
|
| 211 |
+
max_nps_in_drumroll_period = 0.0
|
| 212 |
+
if slice_start_idx < slice_end_idx:
|
| 213 |
+
relevant_base_nps_values = base_don_ka_sliding_nps[
|
| 214 |
+
slice_start_idx:slice_end_idx
|
| 215 |
+
]
|
| 216 |
+
if relevant_base_nps_values.numel() > 0:
|
| 217 |
+
max_nps_in_drumroll_period = torch.max(
|
| 218 |
+
relevant_base_nps_values
|
| 219 |
+
).item()
|
| 220 |
+
|
| 221 |
+
rate = max(5.0, max_nps_in_drumroll_period)
|
| 222 |
+
adaptive_drumroll_rates_for_all_events.append(rate)
|
| 223 |
+
else:
|
| 224 |
+
adaptive_drumroll_rates_for_all_events.append(0.0) # Placeholder
|
| 225 |
+
|
| 226 |
+
# Step 3: Calculate final sliding_nps_labels using adaptive rates
|
| 227 |
+
# sliding_nps_labels is already initialized with zeros earlier in the function.
|
| 228 |
+
for k_idx in range(T_sub):
|
| 229 |
+
k_window_end_sec = ((k_idx + 1) * TIME_SUB) / fps
|
| 230 |
+
k_window_start_sec = k_window_end_sec - window_duration_seconds
|
| 231 |
+
|
| 232 |
+
current_window_total_nps_contribution = 0.0
|
| 233 |
+
for event_idx, (event_t, is_drumroll, drumroll_dur) in enumerate(note_events):
|
| 234 |
+
if is_drumroll:
|
| 235 |
+
drumroll_start_sec = event_t
|
| 236 |
+
drumroll_end_sec = event_t + drumroll_dur
|
| 237 |
+
|
| 238 |
+
overlap_start = max(k_window_start_sec, drumroll_start_sec)
|
| 239 |
+
overlap_end = min(k_window_end_sec, drumroll_end_sec)
|
| 240 |
+
|
| 241 |
+
if overlap_end > overlap_start:
|
| 242 |
+
overlap_duration = overlap_end - overlap_start
|
| 243 |
+
current_adaptive_rate = adaptive_drumroll_rates_for_all_events[
|
| 244 |
+
event_idx
|
| 245 |
+
]
|
| 246 |
+
current_window_total_nps_contribution += (
|
| 247 |
+
overlap_duration * current_adaptive_rate
|
| 248 |
+
)
|
| 249 |
+
else: # Don or Ka hit
|
| 250 |
+
if k_window_start_sec <= event_t < k_window_end_sec:
|
| 251 |
+
current_window_total_nps_contribution += (
|
| 252 |
+
1 # Each hit contributes 1 to the count
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
sliding_nps_labels[k_idx] = (
|
| 256 |
+
current_window_total_nps_contribution / window_duration_seconds
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
# Normalize sliding_nps_labels to 0-1 range
|
| 260 |
+
if T_sub > 0: # Ensure there are elements to normalize
|
| 261 |
+
min_nps_val = torch.min(sliding_nps_labels)
|
| 262 |
+
max_nps_val = torch.max(sliding_nps_labels)
|
| 263 |
+
denominator = max_nps_val - min_nps_val
|
| 264 |
+
if denominator > 1e-6: # Use a small epsilon for float comparison
|
| 265 |
+
sliding_nps_labels = (sliding_nps_labels - min_nps_val) / denominator
|
| 266 |
+
else:
|
| 267 |
+
# If all values are (nearly) the same
|
| 268 |
+
if max_nps_val > 1e-6: # If the constant value is positive
|
| 269 |
+
sliding_nps_labels = torch.ones_like(sliding_nps_labels)
|
| 270 |
+
else: # If the constant value is zero (or very close to it)
|
| 271 |
+
sliding_nps_labels = torch.zeros_like(sliding_nps_labels)
|
| 272 |
+
|
| 273 |
    duration_seconds = wav_tensor.shape[-1] / SAMPLE_RATE
    nps = num_valid_notes / duration_seconds if duration_seconds > 0 else 0.0

    parsed = parse_tja(example["tja"], mode=PyParsingMode.Full)
    chart = next(
        (chart for chart in parsed.charts if chart.course.lower() == difficulty), None
    )
    # Map the course name to an integer id: easy=0, normal=1, hard=2, oni=3,
    # and 4 for anything else (assumed to be edit/ura).
    difficulty_id = (
        0
        if difficulty == "easy"
        else (
            1
            if difficulty == "normal"
            else 2 if difficulty == "hard" else 3 if difficulty == "oni" else 4
        )
    )
    level = chart.level if chart else 0

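    # An equivalent lookup-based sketch (not used above, shown for clarity):
    #   difficulty_id = {"easy": 0, "normal": 1, "hard": 2, "oni": 3}.get(difficulty, 4)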
    # --- CNN shape inference and label padding/truncation ---
    # Simulate the CNN to get the output time length (T_cnn)
    dummy_model = TaikoConformer7()
    with torch.no_grad():
        cnn_out = dummy_model.cnn(mel.unsqueeze(0))  # (1, C, F, T_cnn)
        _, _, _, T_cnn = cnn_out.shape

    # Pad or truncate labels to T_cnn
    def pad_or_truncate(label, out_len):
        if label.shape[0] < out_len:
            pad = torch.zeros(out_len - label.shape[0], dtype=label.dtype)
            return torch.cat([label, pad], dim=0)
        else:
            return label[:out_len]

    don_labels = pad_or_truncate(don_labels, T_cnn)
    ka_labels = pad_or_truncate(ka_labels, T_cnn)
    drumroll_labels = pad_or_truncate(drumroll_labels, T_cnn)
    sliding_nps_labels = pad_or_truncate(sliding_nps_labels, T_cnn)  # Pad new label

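    # Note: this instantiates a full TaikoConformer7 per example just to probe
    # the CNN output length. A cheaper sketch (assuming a module-level cache is
    # acceptable in this codebase) would build the probe once:
    #
    #   _SHAPE_PROBE = TaikoConformer7().cnn
    #   with torch.no_grad():
    #       T_cnn = _SHAPE_PROBE(mel.unsqueeze(0)).shape[-1]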
    # For conformer input lengths: this should be T_cnn
    conformer_sequence_length = T_cnn  # The actual sequence length after the CNN

    print(
        f"Processed {num_valid_notes} notes in {duration_seconds:.2f} seconds, NPS: {nps:.2f}, Difficulty: {difficulty_id}, Level: {level}"
    )

    return {
        "mel": mel,  # (1, N_MELS, T)
        "don_labels": don_labels,  # (T_cnn,)
        "ka_labels": ka_labels,  # (T_cnn,)
        "drumroll_labels": drumroll_labels,  # (T_cnn,)
        "sliding_nps_labels": sliding_nps_labels,  # New label (T_cnn,)
        "nps": torch.tensor(nps, dtype=torch.float32),
        "difficulty": torch.tensor(difficulty_id, dtype=torch.long),
        "level": torch.tensor(level, dtype=torch.long),
        "duration_seconds": torch.tensor(duration_seconds, dtype=torch.float32),
        "length": torch.tensor(
            conformer_sequence_length, dtype=torch.long
        ),  # T_cnn, used for the conformer and loss masking
    }


def collate_fn(batch):
    mels_list = [b["mel"].squeeze(0).transpose(0, 1) for b in batch]  # (T, N_MELS)
    don_labels_list = [b["don_labels"] for b in batch]
    ka_labels_list = [b["ka_labels"] for b in batch]
    drumroll_labels_list = [b["drumroll_labels"] for b in batch]
    sliding_nps_labels_list = [b["sliding_nps_labels"] for b in batch]  # New label list
    nps_list = [b["nps"] for b in batch]
    difficulty_list = [b["difficulty"] for b in batch]
    level_list = [b["level"] for b in batch]
    durations_list = [b["duration_seconds"] for b in batch]
    lengths_list = [b["length"] for b in batch]  # These are T_cnn_i for each example

    # Pad mels to the longest mel length in the batch
    padded_mels = nn.utils.rnn.pad_sequence(
        mels_list, batch_first=True
    )  # (B, T_max_mel, N_MELS)
    reshaped_mels = padded_mels.transpose(1, 2).unsqueeze(1)  # (B, 1, N_MELS, T_max_mel)
    # T_max_mel_batch = padded_mels.shape[1]  # Max mel length in batch, no longer used for label padding

    # Determine the max label sequence length (max T_cnn in batch); 0 for an empty batch
    max_label_len = max((l.item() for l in lengths_list), default=0)

    # Pad labels to max_label_len (max T_cnn in batch)
    def pad_label_to_max_len(label_tensor, target_len):
        current_len = label_tensor.shape[0]
        if current_len < target_len:
            padding_size = target_len - current_len
            # Ensure padding is created on the same device as the label_tensor
            padding = torch.zeros(
                padding_size, dtype=label_tensor.dtype, device=label_tensor.device
            )
            return torch.cat((label_tensor, padding), dim=0)
        elif current_len > target_len:
            # Should ideally not happen if lengths_list is correct
            return label_tensor[:target_len]
        return label_tensor

    don_labels = torch.stack(
        [pad_label_to_max_len(l, max_label_len) for l in don_labels_list]
    )
    ka_labels = torch.stack(
        [pad_label_to_max_len(l, max_label_len) for l in ka_labels_list]
    )
    drumroll_labels = torch.stack(
        [pad_label_to_max_len(l, max_label_len) for l in drumroll_labels_list]
    )
    sliding_nps_labels = torch.stack(
        [pad_label_to_max_len(l, max_label_len) for l in sliding_nps_labels_list]
    )  # Pad new labels

    actual_lengths = torch.tensor([l.item() for l in lengths_list], dtype=torch.long)

    return {
        "mel": reshaped_mels,
        "don_labels": don_labels,
        "ka_labels": ka_labels,
        "drumroll_labels": drumroll_labels,
        "sliding_nps_labels": sliding_nps_labels,  # New batched labels
        "lengths": actual_lengths,  # T_cnn_i per item, for the conformer and loss masking
        "nps": torch.stack(nps_list),
        "difficulty": torch.stack(difficulty_list),
        "level": torch.stack(level_list),
        "durations": torch.stack(durations_list),
    }
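A minimal smoke test for collate_fn (a sketch; shapes are hypothetical, assuming N_MELS = 80 and two examples with T_cnn of 5 and 8):

    import torch

    def fake_example(t_cnn):
        # All values are dummies; only shapes and dtypes matter here
        return {
            "mel": torch.randn(1, 80, t_cnn * 4),
            "don_labels": torch.zeros(t_cnn),
            "ka_labels": torch.zeros(t_cnn),
            "drumroll_labels": torch.zeros(t_cnn),
            "sliding_nps_labels": torch.zeros(t_cnn),
            "nps": torch.tensor(1.0),
            "difficulty": torch.tensor(3),
            "level": torch.tensor(8),
            "duration_seconds": torch.tensor(10.0),
            "length": torch.tensor(t_cnn),
        }

    batch = collate_fn([fake_example(5), fake_example(8)])
    print(batch["mel"].shape)         # (2, 1, 80, 32): mels padded to the longest T
    print(batch["don_labels"].shape)  # (2, 8): labels padded to the longest T_cnn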
tc7/train.py
ADDED
@@ -0,0 +1,300 @@
from accelerate.utils import set_seed

set_seed(1024)

import math
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from datasets import concatenate_datasets
import matplotlib.pyplot as plt
import numpy as np
from .config import (
    BATCH_SIZE,
    DEVICE,
    EPOCHS,
    LR,
    GRAD_ACCUM_STEPS,
    HOP_LENGTH,
    NPS_PENALTY_WEIGHT_ALPHA,
    NPS_PENALTY_WEIGHT_BETA,
    SAMPLE_RATE,
)
from .model import TaikoConformer7
from .dataset import ds
from .preprocess import preprocess, collate_fn
from .loss import TaikoLoss
from huggingface_hub import upload_folder


def log_energy_plots_to_tensorboard(
    writer,
    tag_prefix,
    epoch,
    pred_don,
    pred_ka,
    pred_drumroll,
    true_don,
    true_ka,
    true_drumroll,
    valid_length,
    hop_sec,
):
    """
    Logs a plot of predicted vs. true energies for one sample to TensorBoard.
    Inputs are 1D per-sample tensors; only the first valid_length frames are plotted.
    """
    pred_don = pred_don[:valid_length].detach().cpu().numpy()
    pred_ka = pred_ka[:valid_length].detach().cpu().numpy()
    pred_drumroll = pred_drumroll[:valid_length].detach().cpu().numpy()
    true_don = true_don[:valid_length].cpu().numpy()
    true_ka = true_ka[:valid_length].cpu().numpy()
    true_drumroll = true_drumroll[:valid_length].cpu().numpy()

    time_axis = np.arange(valid_length) * hop_sec

    fig, axs = plt.subplots(3, 1, figsize=(15, 10), sharex=True)
    fig.suptitle(f"{tag_prefix} - Epoch {epoch}", fontsize=16)

    axs[0].plot(time_axis, true_don, label="True Don", color="blue", linestyle="--")
    axs[0].plot(time_axis, pred_don, label="Pred Don", color="lightblue", alpha=0.8)
    axs[0].set_ylabel("Don Energy")
    axs[0].legend()
    axs[0].grid(True)

    axs[1].plot(time_axis, true_ka, label="True Ka", color="red", linestyle="--")
    axs[1].plot(time_axis, pred_ka, label="Pred Ka", color="lightcoral", alpha=0.8)
    axs[1].set_ylabel("Ka Energy")
    axs[1].legend()
    axs[1].grid(True)

    axs[2].plot(
        time_axis, true_drumroll, label="True Drumroll", color="green", linestyle="--"
    )
    axs[2].plot(
        time_axis, pred_drumroll, label="Pred Drumroll", color="lightgreen", alpha=0.8
    )
    axs[2].set_ylabel("Drumroll Energy")
    axs[2].set_xlabel("Time (s)")
    axs[2].legend()
    axs[2].grid(True)

    plt.tight_layout(rect=[0, 0, 1, 0.96])
    writer.add_figure(f"{tag_prefix}/Energy_Comparison", fig, epoch)
    plt.close(fig)


def main():
    global ds

    output_frame_hop_sec = HOP_LENGTH / SAMPLE_RATE

    best_val_loss = float("inf")
    patience = 10
    pat_count = 0

    # Preprocess each difficulty separately, then train on the concatenation
    ds_oni = ds.map(
        preprocess,
        remove_columns=ds.column_names,
        fn_kwargs={"difficulty": "oni"},
        writer_batch_size=10,
    )
    ds_hard = ds.map(
        preprocess,
        remove_columns=ds.column_names,
        fn_kwargs={"difficulty": "hard"},
        writer_batch_size=10,
    )
    ds_normal = ds.map(
        preprocess,
        remove_columns=ds.column_names,
        fn_kwargs={"difficulty": "normal"},
        writer_batch_size=10,
    )
    ds = concatenate_datasets([ds_oni, ds_hard, ds_normal])

    ds_train_test = ds.train_test_split(test_size=0.1, seed=42)
    train_loader = DataLoader(
        ds_train_test["train"],
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=8,
        persistent_workers=True,
        prefetch_factor=4,
    )
    val_loader = DataLoader(
        ds_train_test["test"],
        batch_size=BATCH_SIZE,
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=8,
        persistent_workers=True,
        prefetch_factor=4,
    )

+
model = TaikoConformer7().to(DEVICE)
|
| 138 |
+
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
|
| 139 |
+
|
| 140 |
+
criterion = TaikoLoss(
|
| 141 |
+
reduction="mean",
|
| 142 |
+
nps_penalty_weight_alpha=NPS_PENALTY_WEIGHT_ALPHA,
|
| 143 |
+
nps_penalty_weight_beta=NPS_PENALTY_WEIGHT_BETA,
|
| 144 |
+
).to(DEVICE)
|
| 145 |
+
|
| 146 |
+
num_optimizer_steps_per_epoch = math.ceil(len(train_loader) / GRAD_ACCUM_STEPS)
|
| 147 |
+
total_optimizer_steps = EPOCHS * num_optimizer_steps_per_epoch
|
| 148 |
+
|
| 149 |
+
scheduler = torch.optim.lr_scheduler.OneCycleLR(
|
| 150 |
+
optimizer, max_lr=LR, total_steps=total_optimizer_steps
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
writer = SummaryWriter()
|
| 154 |
+
|
| 155 |
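    # Example (hypothetical sizes): with 1,000 batches per epoch,
    # GRAD_ACCUM_STEPS = 4 and EPOCHS = 10, the scheduler is configured for
    # ceil(1000 / 4) * 10 = 2,500 total optimizer steps.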
    for epoch in range(1, EPOCHS + 1):
        model.train()
        total_epoch_loss = 0.0
        optimizer.zero_grad()

        for idx, batch in enumerate(tqdm(train_loader, desc=f"Train Epoch {epoch}")):
            mel = batch["mel"].to(DEVICE)
            lengths = batch["lengths"].to(DEVICE)
            nps = batch["nps"].to(DEVICE)
            difficulty = batch["difficulty"].to(DEVICE)
            level = batch["level"].to(DEVICE)

            outputs = model(mel, lengths, nps, difficulty, level)
            loss = criterion(outputs, batch)

            total_epoch_loss += loss.item()

            # Scale the loss so gradients average over the accumulation window
            loss = loss / GRAD_ACCUM_STEPS
            loss.backward()

            if (idx + 1) % GRAD_ACCUM_STEPS == 0 or (idx + 1) == len(train_loader):
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            writer.add_scalar(
                "Loss/Train_Step",
                loss.item() * GRAD_ACCUM_STEPS,
                epoch * len(train_loader) + idx,
            )
            writer.add_scalar(
                "LR", scheduler.get_last_lr()[0], epoch * len(train_loader) + idx
            )

            # Plot the first sample of the first few batches each epoch
            if idx < 3 and mel.size(0) > 0:
                pred_don = outputs["presence"][0, :, 0]
                pred_ka = outputs["presence"][0, :, 1]
                pred_drumroll = outputs["presence"][0, :, 2]
                true_don = batch["don_labels"][0]
                true_ka = batch["ka_labels"][0]
                true_drumroll = batch["drumroll_labels"][0]
                valid_length = batch["lengths"][0].item()

                log_energy_plots_to_tensorboard(
                    writer,
                    f"Train_Sample_Batch_{idx}_Sample_0",
                    epoch,
                    pred_don,
                    pred_ka,
                    pred_drumroll,
                    true_don,
                    true_ka,
                    true_drumroll,
                    valid_length,
                    output_frame_hop_sec,
                )

        avg_train_loss = total_epoch_loss / len(train_loader)
        writer.add_scalar("Loss/Train_Avg", avg_train_loss, epoch)

        model.eval()
        total_val_loss = 0.0

        with torch.no_grad():
            for idx, batch in enumerate(tqdm(val_loader, desc=f"Val Epoch {epoch}")):
                mel = batch["mel"].to(DEVICE)
                lengths = batch["lengths"].to(DEVICE)
                nps = batch["nps"].to(DEVICE)
                difficulty = batch["difficulty"].to(DEVICE)
                level = batch["level"].to(DEVICE)

                outputs = model(mel, lengths, nps, difficulty, level)
                loss = criterion(outputs, batch)
                total_val_loss += loss.item()

                if idx < 3 and mel.size(0) > 0:
                    pred_don = outputs["presence"][0, :, 0]
                    pred_ka = outputs["presence"][0, :, 1]
                    pred_drumroll = outputs["presence"][0, :, 2]
                    true_don = batch["don_labels"][0]
                    true_ka = batch["ka_labels"][0]
                    true_drumroll = batch["drumroll_labels"][0]
                    valid_length = batch["lengths"][0].item()

                    log_energy_plots_to_tensorboard(
                        writer,
                        f"Val_Sample_Batch_{idx}_Sample_0",
                        epoch,
                        pred_don,
                        pred_ka,
                        pred_drumroll,
                        true_don,
                        true_ka,
                        true_drumroll,
                        valid_length,
                        output_frame_hop_sec,
                    )

        avg_val_loss = total_val_loss / len(val_loader)
        writer.add_scalar("Loss/Val_Avg", avg_val_loss, epoch)

        current_lr = optimizer.param_groups[0]["lr"]
        writer.add_scalar("LR/learning_rate", current_lr, epoch)

        # "batch" here is the last validation batch of the epoch
        if "nps" in batch:
            writer.add_scalar(
                "NPS/GT_Val_LastBatch_Avg", batch["nps"].mean().item(), epoch
            )

        print(
            f"Epoch {epoch:02d} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | LR: {current_lr:.2e}"
        )

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            pat_count = 0
            torch.save(model.state_dict(), "best_model.pt")
            print(f"Saved new best model to best_model.pt at epoch {epoch}")
        else:
            pat_count += 1
            if pat_count >= patience:
                print("Early stopping!")
                break

    writer.close()
    model_id = "JacobLinCool/taiko-conformer-7"
    try:
        model.push_to_hub(
            model_id, commit_message=f"Epoch {epoch}, Val Loss: {avg_val_loss:.4f}"
        )
        upload_folder(
            repo_id=model_id,
            folder_path="runs",
            path_in_repo="runs",
            commit_message="Upload TensorBoard logs",
        )
    except Exception as e:
        print(f"Error uploading model or logs: {e}")
        print("Make sure you have the correct permissions and try again.")

if __name__ == "__main__":
    main()
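Note: because train.py uses package-relative imports (from .config import ...), it must be run as a module from the repository root, e.g. python -m tc7.train.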