from accelerate.utils import set_seed

set_seed(1024)

import math
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from datasets import concatenate_datasets
import matplotlib.pyplot as plt
import numpy as np
from huggingface_hub import upload_folder

from .config import (
    BATCH_SIZE,
    DEVICE,
    EPOCHS,
    LR,
    GRAD_ACCUM_STEPS,
    HOP_LENGTH,
    NPS_PENALTY_WEIGHT_ALPHA,
    NPS_PENALTY_WEIGHT_BETA,
    SAMPLE_RATE,
)
from .model import TaikoConformer7
from .dataset import ds
from .preprocess import preprocess, collate_fn
from .loss import TaikoLoss
def log_energy_plots_to_tensorboard(
    writer,
    tag_prefix,
    epoch,
    pred_don,
    pred_ka,
    pred_drumroll,
    true_don,
    true_ka,
    true_drumroll,
    valid_length,
    hop_sec,
):
| """ | |
| Logs a plot of predicted vs. true energies for one sample to TensorBoard. | |
| Energies should be 1D numpy arrays for the single sample, up to valid_length. | |
| """ | |
    pred_don = pred_don[:valid_length].detach().cpu().numpy()
    pred_ka = pred_ka[:valid_length].detach().cpu().numpy()
    pred_drumroll = pred_drumroll[:valid_length].detach().cpu().numpy()
    true_don = true_don[:valid_length].cpu().numpy()
    true_ka = true_ka[:valid_length].cpu().numpy()
    true_drumroll = true_drumroll[:valid_length].cpu().numpy()

    time_axis = np.arange(valid_length) * hop_sec

    fig, axs = plt.subplots(3, 1, figsize=(15, 10), sharex=True)
    fig.suptitle(f"{tag_prefix} - Epoch {epoch}", fontsize=16)

    axs[0].plot(time_axis, true_don, label="True Don", color="blue", linestyle="--")
    axs[0].plot(time_axis, pred_don, label="Pred Don", color="lightblue", alpha=0.8)
    axs[0].set_ylabel("Don Energy")
    axs[0].legend()
    axs[0].grid(True)

    axs[1].plot(time_axis, true_ka, label="True Ka", color="red", linestyle="--")
    axs[1].plot(time_axis, pred_ka, label="Pred Ka", color="lightcoral", alpha=0.8)
    axs[1].set_ylabel("Ka Energy")
    axs[1].legend()
    axs[1].grid(True)

    axs[2].plot(
        time_axis, true_drumroll, label="True Drumroll", color="green", linestyle="--"
    )
    axs[2].plot(
        time_axis, pred_drumroll, label="Pred Drumroll", color="lightgreen", alpha=0.8
    )
    axs[2].set_ylabel("Drumroll Energy")
    axs[2].set_xlabel("Time (s)")
    axs[2].legend()
    axs[2].grid(True)

    plt.tight_layout(rect=[0, 0, 1, 0.96])
    writer.add_figure(f"{tag_prefix}/Energy_Comparison", fig, epoch)
    plt.close(fig)
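
# End-to-end training pipeline: preprocess the dataset at several difficulties,
# train TaikoConformer7 with gradient accumulation and a OneCycleLR schedule,
# log losses and sample energy plots to TensorBoard, keep the best checkpoint,
# and push the result to the Hugging Face Hub.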
def main():
    global ds

    output_frame_hop_sec = HOP_LENGTH / SAMPLE_RATE

    best_val_loss = float("inf")
    patience = 10
    pat_count = 0
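    # Preprocess the charts at three difficulties and train on their
    # concatenation so the model sees oni, hard, and normal patterns.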
    ds_oni = ds.map(
        preprocess,
        remove_columns=ds.column_names,
        fn_kwargs={"difficulty": "oni"},
        writer_batch_size=10,
    )
    ds_hard = ds.map(
        preprocess,
        remove_columns=ds.column_names,
        fn_kwargs={"difficulty": "hard"},
        writer_batch_size=10,
    )
    ds_normal = ds.map(
        preprocess,
        remove_columns=ds.column_names,
        fn_kwargs={"difficulty": "normal"},
        writer_batch_size=10,
    )
    ds = concatenate_datasets([ds_oni, ds_hard, ds_normal])
    ds_train_test = ds.train_test_split(test_size=0.1, seed=42)
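    # Persistent worker processes with prefetching keep data loading off the
    # training loop's critical path.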
    train_loader = DataLoader(
        ds_train_test["train"],
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=8,
        persistent_workers=True,
        prefetch_factor=4,
    )
    val_loader = DataLoader(
        ds_train_test["test"],
        batch_size=BATCH_SIZE,
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=8,
        persistent_workers=True,
        prefetch_factor=4,
    )

    model = TaikoConformer7().to(DEVICE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
    criterion = TaikoLoss(
        reduction="mean",
        nps_penalty_weight_alpha=NPS_PENALTY_WEIGHT_ALPHA,
        nps_penalty_weight_beta=NPS_PENALTY_WEIGHT_BETA,
    ).to(DEVICE)
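    # OneCycleLR is stepped once per optimizer update, so its total step count
    # must be computed from optimizer steps, not micro-batches.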
    num_optimizer_steps_per_epoch = math.ceil(len(train_loader) / GRAD_ACCUM_STEPS)
    total_optimizer_steps = EPOCHS * num_optimizer_steps_per_epoch
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=LR, total_steps=total_optimizer_steps
    )
    writer = SummaryWriter()
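    # Training loop: per-epoch validation, best-checkpoint saving, early stopping.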
    for epoch in range(1, EPOCHS + 1):
        model.train()
        total_epoch_loss = 0.0
        optimizer.zero_grad()
        for idx, batch in enumerate(tqdm(train_loader, desc=f"Train Epoch {epoch}")):
            mel = batch["mel"].to(DEVICE)
            lengths = batch["lengths"].to(DEVICE)
            nps = batch["nps"].to(DEVICE)
            difficulty = batch["difficulty"].to(DEVICE)
            level = batch["level"].to(DEVICE)

            outputs = model(mel, lengths, nps, difficulty, level)
            loss = criterion(outputs, batch)
            total_epoch_loss += loss.item()
            loss = loss / GRAD_ACCUM_STEPS
            loss.backward()
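            # Step the optimizer every GRAD_ACCUM_STEPS micro-batches, and once
            # more at the end of the epoch so leftover gradients are not dropped.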
            if (idx + 1) % GRAD_ACCUM_STEPS == 0 or (idx + 1) == len(train_loader):
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
            # Zero-based global step (epoch starts at 1) so epoch 1 logs from step 0.
            global_step = (epoch - 1) * len(train_loader) + idx
            writer.add_scalar(
                "Loss/Train_Step", loss.item() * GRAD_ACCUM_STEPS, global_step
            )
            writer.add_scalar("LR", scheduler.get_last_lr()[0], global_step)
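            # Plot predicted vs. ground-truth energies for the first sample of
            # the first few training batches.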
            if idx < 3 and mel.size(0) > 0:
                pred_don = outputs["presence"][0, :, 0]
                pred_ka = outputs["presence"][0, :, 1]
                pred_drumroll = outputs["presence"][0, :, 2]
                true_don = batch["don_labels"][0]
                true_ka = batch["ka_labels"][0]
                true_drumroll = batch["drumroll_labels"][0]
                valid_length = batch["lengths"][0].item()
                log_energy_plots_to_tensorboard(
                    writer,
                    f"Train_Sample_Batch_{idx}_Sample_0",
                    epoch,
                    pred_don,
                    pred_ka,
                    pred_drumroll,
                    true_don,
                    true_ka,
                    true_drumroll,
                    valid_length,
                    output_frame_hop_sec,
                )

        avg_train_loss = total_epoch_loss / len(train_loader)
        writer.add_scalar("Loss/Train_Avg", avg_train_loss, epoch)
        model.eval()
        total_val_loss = 0.0
        with torch.no_grad():
            for idx, batch in enumerate(tqdm(val_loader, desc=f"Val Epoch {epoch}")):
                mel = batch["mel"].to(DEVICE)
                lengths = batch["lengths"].to(DEVICE)
                nps = batch["nps"].to(DEVICE)
                difficulty = batch["difficulty"].to(DEVICE)
                level = batch["level"].to(DEVICE)

                outputs = model(mel, lengths, nps, difficulty, level)
                loss = criterion(outputs, batch)
                total_val_loss += loss.item()
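                # Same energy plots for the first few validation batches.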
                if idx < 3 and mel.size(0) > 0:
                    pred_don = outputs["presence"][0, :, 0]
                    pred_ka = outputs["presence"][0, :, 1]
                    pred_drumroll = outputs["presence"][0, :, 2]
                    true_don = batch["don_labels"][0]
                    true_ka = batch["ka_labels"][0]
                    true_drumroll = batch["drumroll_labels"][0]
                    valid_length = batch["lengths"][0].item()
                    log_energy_plots_to_tensorboard(
                        writer,
                        f"Val_Sample_Batch_{idx}_Sample_0",
                        epoch,
                        pred_don,
                        pred_ka,
                        pred_drumroll,
                        true_don,
                        true_ka,
                        true_drumroll,
                        valid_length,
                        output_frame_hop_sec,
                    )

        avg_val_loss = total_val_loss / len(val_loader)
        writer.add_scalar("Loss/Val_Avg", avg_val_loss, epoch)

        current_lr = optimizer.param_groups[0]["lr"]
        writer.add_scalar("LR/learning_rate", current_lr, epoch)

        # `batch` still refers to the last validation batch after the loop.
        if "nps" in batch:
            writer.add_scalar(
                "NPS/GT_Val_LastBatch_Avg", batch["nps"].mean().item(), epoch
            )
        print(
            f"Epoch {epoch:02d} | Train Loss: {avg_train_loss:.4f} | "
            f"Val Loss: {avg_val_loss:.4f} | LR: {current_lr:.2e}"
        )
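        # Save the best checkpoint; stop after `patience` epochs without improvement.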
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            pat_count = 0
            torch.save(model.state_dict(), "best_model.pt")
            print(f"Saved new best model to best_model.pt at epoch {epoch}")
        else:
            pat_count += 1
            if pat_count >= patience:
                print("Early stopping!")
                break

    writer.close()
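    # Push the final model and TensorBoard logs to the Hub. push_to_hub assumes
    # TaikoConformer7 mixes in huggingface_hub.PyTorchModelHubMixin.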
| model_id = "JacobLinCool/taiko-conformer-7" | |
| try: | |
| model.push_to_hub( | |
| model_id, commit_message=f"Epoch {epoch}, Val Loss: {avg_val_loss:.4f}" | |
| ) | |
| upload_folder( | |
| repo_id=model_id, | |
| folder_path="runs", | |
| path_in_repo="runs", | |
| commit_message="Upload TensorBoard logs", | |
| ) | |
| except Exception as e: | |
| print(f"Error uploading model or logs: {e}") | |
| print("Make sure you have the correct permissions and try again.") | |
| if __name__ == "__main__": | |
| main() | |