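# Training script for TaikoConformer5: learns per-frame don / ka / drumroll
# energies from mel spectrograms, logs losses and sample energy plots to
# TensorBoard, and uploads the trained model and logs to the Hugging Face Hub.
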
from accelerate.utils import set_seed

# Seed python/numpy/torch RNGs before the remaining imports so that any
# module-level randomness downstream is reproducible.
set_seed(1024)

import math

import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import matplotlib.pyplot as plt
import numpy as np
from datasets import concatenate_datasets
from huggingface_hub import upload_folder
from tqdm import tqdm

from .config import (
    BATCH_SIZE,
    DEVICE,
    EPOCHS,
    LR,
    GRAD_ACCUM_STEPS,
    HOP_LENGTH,
    SAMPLE_RATE,
)
from .dataset import ds
from .loss import TaikoEnergyLoss
from .model import TaikoConformer5
from .preprocess import preprocess, collate_fn


# --- Helper function to log energy plots ---
def log_energy_plots_to_tensorboard(
    writer,
    tag_prefix,
    epoch,
    pred_don,
    pred_ka,
    pred_drumroll,
    true_don,
    true_ka,
    true_drumroll,
    valid_length,  # Actual valid length of the sequence (before padding)
    hop_sec,
):
    """
    Logs a plot of predicted vs. true energies for one sample to TensorBoard.
    Energies should be 1D torch tensors for a single sample; only the first
    valid_length frames (the unpadded portion) are plotted.
    """
    # Ensure data is on CPU and converted to numpy, and select only the valid part
    pred_don = pred_don[:valid_length].detach().cpu().numpy()
    pred_ka = pred_ka[:valid_length].detach().cpu().numpy()
    pred_drumroll = pred_drumroll[:valid_length].detach().cpu().numpy()
    true_don = true_don[:valid_length].cpu().numpy()
    true_ka = true_ka[:valid_length].cpu().numpy()
    true_drumroll = true_drumroll[:valid_length].cpu().numpy()

    time_axis = np.arange(valid_length) * hop_sec

    fig, axs = plt.subplots(3, 1, figsize=(15, 10), sharex=True)
    fig.suptitle(f"{tag_prefix} - Epoch {epoch}", fontsize=16)

    axs[0].plot(time_axis, true_don, label="True Don", color="blue", linestyle="--")
    axs[0].plot(time_axis, pred_don, label="Pred Don", color="lightblue", alpha=0.8)
    axs[0].set_ylabel("Don Energy")
    axs[0].legend()
    axs[0].grid(True)

    axs[1].plot(time_axis, true_ka, label="True Ka", color="red", linestyle="--")
    axs[1].plot(time_axis, pred_ka, label="Pred Ka", color="lightcoral", alpha=0.8)
    axs[1].set_ylabel("Ka Energy")
    axs[1].legend()
    axs[1].grid(True)

    axs[2].plot(
        time_axis, true_drumroll, label="True Drumroll", color="green", linestyle="--"
    )
    axs[2].plot(
        time_axis, pred_drumroll, label="Pred Drumroll", color="lightgreen", alpha=0.8
    )
    axs[2].set_ylabel("Drumroll Energy")
    axs[2].set_xlabel("Time (s)")
    axs[2].legend()
    axs[2].grid(True)

    plt.tight_layout(rect=[0, 0, 1, 0.96])  # Leave room for the suptitle
    writer.add_figure(f"{tag_prefix}/Energy_Comparison", fig, epoch)
    plt.close(fig)


def main():
    global ds

    # Seconds per model output frame; this assumes the model's output time
    # dimension corresponds to the mel-spectrogram time dimension.
    output_frame_hop_sec = HOP_LENGTH / SAMPLE_RATE
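    # For example, HOP_LENGTH = 512 at SAMPLE_RATE = 22050 (illustrative
    # values, not necessarily this project's config) gives
    # 512 / 22050 ≈ 23.2 ms per output frame.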

    best_val_loss = float("inf")
    patience = 10  # early-stopping patience, in epochs
    pat_count = 0
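
    # Preprocess each difficulty separately, then train on the union of the
    # three so a single model sees oni, hard, and normal charts.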
    ds_oni = ds.map(
        preprocess,
        remove_columns=ds.column_names,
        fn_kwargs={"difficulty": "oni"},
        writer_batch_size=10,
    )
    ds_hard = ds.map(
        preprocess,
        remove_columns=ds.column_names,
        fn_kwargs={"difficulty": "hard"},
        writer_batch_size=10,
    )
    ds_normal = ds.map(
        preprocess,
        remove_columns=ds.column_names,
        fn_kwargs={"difficulty": "normal"},
        writer_batch_size=10,
    )
    ds = concatenate_datasets([ds_oni, ds_hard, ds_normal])

    ds_train_test = ds.train_test_split(test_size=0.1, seed=42)
    train_loader = DataLoader(
        ds_train_test["train"],
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=2,
    )
    val_loader = DataLoader(
        ds_train_test["test"],
        batch_size=BATCH_SIZE,
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=2,
    )

    model = TaikoConformer5().to(DEVICE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
    criterion = TaikoEnergyLoss(reduction="mean").to(DEVICE)

    # The scheduler steps once per optimizer step, not once per batch, so
    # account for gradient accumulation when sizing the schedule.
    num_optimizer_steps_per_epoch = math.ceil(len(train_loader) / GRAD_ACCUM_STEPS)
    total_optimizer_steps = EPOCHS * num_optimizer_steps_per_epoch
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=LR, total_steps=total_optimizer_steps
    )
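    # Worked example (illustrative numbers): 1000 train batches with
    # GRAD_ACCUM_STEPS = 4 gives ceil(1000 / 4) = 250 optimizer steps per
    # epoch, so OneCycleLR is sized to EPOCHS * 250 total steps.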

    writer = SummaryWriter()

    for epoch in range(1, EPOCHS + 1):
        model.train()
        total_epoch_loss = 0.0
        optimizer.zero_grad()

        for idx, batch in enumerate(tqdm(train_loader, desc=f"Train Epoch {epoch}")):
            mel = batch["mel"].to(DEVICE)
            # Energy-based labels for don / ka / drumroll
            don_labels = batch["don_labels"].to(DEVICE)
            ka_labels = batch["ka_labels"].to(DEVICE)
            drumroll_labels = batch["drumroll_labels"].to(DEVICE)
            lengths = batch["lengths"].to(DEVICE)  # valid lengths of the Conformer output
            nps = batch["nps"].to(DEVICE)

            output_dict = model(mel, lengths, nps)
            # output_dict["presence"] holds (B, T_out, 3) don/ka/drumroll energies
            pred_energies_batch = output_dict["presence"]

            loss_input_batch = {
                "don_labels": don_labels,
                "ka_labels": ka_labels,
                "drumroll_labels": drumroll_labels,
                "lengths": lengths,  # used for masking inside the loss
            }
            loss = criterion(output_dict, loss_input_batch)
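            # Dividing by GRAD_ACCUM_STEPS keeps the accumulated gradient an
            # average over the micro-batches, so the effective update matches
            # one step on a batch GRAD_ACCUM_STEPS times larger.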
            (loss / GRAD_ACCUM_STEPS).backward()

            if (idx + 1) % GRAD_ACCUM_STEPS == 0 or (idx + 1) == len(train_loader):
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            total_epoch_loss += loss.item()

            # Log a plot for the first sample of the first batch each training epoch
            if idx == 0:
                first_sample_pred_don = pred_energies_batch[0, :, 0]
                first_sample_pred_ka = pred_energies_batch[0, :, 1]
                first_sample_pred_drumroll = pred_energies_batch[0, :, 2]
                first_sample_true_don = don_labels[0, :]
                first_sample_true_ka = ka_labels[0, :]
                first_sample_true_drumroll = drumroll_labels[0, :]
                first_sample_length = lengths[0].item()  # valid (unpadded) length
                log_energy_plots_to_tensorboard(
                    writer,
                    "Train/Sample_0",
                    epoch,
                    first_sample_pred_don,
                    first_sample_pred_ka,
                    first_sample_pred_drumroll,
                    first_sample_true_don,
                    first_sample_true_ka,
                    first_sample_true_drumroll,
                    first_sample_length,
                    output_frame_hop_sec,
                )

        avg_train_loss = total_epoch_loss / len(train_loader)
        writer.add_scalar("Loss/Train_Avg", avg_train_loss, epoch)

        # Validation
        model.eval()
        total_val_loss = 0.0

        with torch.no_grad():
            for val_idx, batch in enumerate(
                tqdm(val_loader, desc=f"Val Epoch {epoch}")
            ):
                mel = batch["mel"].to(DEVICE)
                don_labels = batch["don_labels"].to(DEVICE)
                ka_labels = batch["ka_labels"].to(DEVICE)
                drumroll_labels = batch["drumroll_labels"].to(DEVICE)
                lengths = batch["lengths"].to(DEVICE)
                nps = batch["nps"].to(DEVICE)  # ground-truth NPS from the batch

                output_dict = model(mel, lengths, nps)
                pred_energies_val_batch = output_dict["presence"]  # (B, T_out, 3)

                val_loss_input_batch = {
                    "don_labels": don_labels,
                    "ka_labels": ka_labels,
                    "drumroll_labels": drumroll_labels,
                    "lengths": lengths,
                }
                val_loss = criterion(output_dict, val_loss_input_batch)
                total_val_loss += val_loss.item()

                # Log a plot for the first sample of the first batch each validation epoch
                if val_idx == 0:
                    first_val_sample_pred_don = pred_energies_val_batch[0, :, 0]
                    first_val_sample_pred_ka = pred_energies_val_batch[0, :, 1]
                    first_val_sample_pred_drumroll = pred_energies_val_batch[0, :, 2]
                    first_val_sample_true_don = don_labels[0, :]
                    first_val_sample_true_ka = ka_labels[0, :]
                    first_val_sample_true_drumroll = drumroll_labels[0, :]
                    first_val_sample_length = lengths[0].item()
                    log_energy_plots_to_tensorboard(
                        writer,
                        "Eval/Sample_0",
                        epoch,
                        first_val_sample_pred_don,
                        first_val_sample_pred_ka,
                        first_val_sample_pred_drumroll,
                        first_val_sample_true_don,
                        first_val_sample_true_ka,
                        first_val_sample_true_drumroll,
                        first_val_sample_length,
                        output_frame_hop_sec,
                    )

                # Ground-truth NPS could also be logged per batch if needed:
                # writer.add_scalar("NPS/GT_Val_Batch_Avg", nps.mean().item(), epoch * len(val_loader) + val_idx)

        avg_val_loss = total_val_loss / len(val_loader)
        writer.add_scalar("Loss/Val_Avg", avg_val_loss, epoch)

        # Log the learning rate
        current_lr = optimizer.param_groups[0]["lr"]
        writer.add_scalar("LR/learning_rate", current_lr, epoch)

        # Log ground-truth NPS from the last validation batch
        if "nps" in batch:
            writer.add_scalar(
                "NPS/GT_Val_LastBatch_Avg", batch["nps"].mean().item(), epoch
            )

        print(
            f"Epoch {epoch:02d} | Train Loss: {avg_train_loss:.4f} | "
            f"Val Loss: {avg_val_loss:.4f} | LR: {current_lr:.2e}"
        )

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            pat_count = 0
            torch.save(model.state_dict(), "best_model.pt")
            print(f"Saved new best model to best_model.pt at epoch {epoch}")
        else:
            pat_count += 1
            if pat_count >= patience:
                print("Early stopping!")
                break

    writer.close()

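    # Note: this pushes the model's final state, which after early stopping is
    # not necessarily the best checkpoint. To publish the best one instead,
    # restore it first, e.g.:
    #   model.load_state_dict(torch.load("best_model.pt"))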
    model_id = "JacobLinCool/taiko-conformer-5"
    try:
        model.push_to_hub(model_id, commit_message="Upload trained model")
        upload_folder(
            repo_id=model_id,
            folder_path="runs",
            path_in_repo=".",
            commit_message="Upload training logs",
            ignore_patterns=["*.txt", "*.json", "*.csv"],
        )
        print(f"Model and logs uploaded to {model_id}")
    except Exception as e:
        print(f"Error uploading to Hugging Face Hub: {e}")
        print("Make sure you have the correct permissions and try again.")


if __name__ == "__main__":
    main()