import torch
import torch.nn as nn


class TaikoEnergyLoss(nn.Module):
    def __init__(self, reduction="mean"):
        super().__init__()
        # Use 'none' reduction to get element-wise losses, then manually apply masking and reduction
        self.mse_loss = nn.MSELoss(reduction="none")
        self.reduction = reduction

    def forward(self, outputs, batch):
        """
        Calculates the MSE loss for energy-based predictions.

        Args:
            outputs (dict): Model output, containing the 'presence' tensor.
                outputs['presence'] shape: (B, T, 3) for don, ka, drumroll energies.
            batch (dict): Batch data from collate_fn, containing true labels and lengths.
                batch['don_labels'], batch['ka_labels'], batch['drumroll_labels'] shape: (B, T)
                batch['lengths'] shape: (B,) - valid sequence lengths along the time dimension T.

        Returns:
            torch.Tensor: The calculated loss.
        """
        pred_energies = outputs["presence"]  # (B, T, 3)
        true_don = batch["don_labels"]  # (B, T)
        true_ka = batch["ka_labels"]  # (B, T)
        true_drumroll = batch["drumroll_labels"]  # (B, T)

        # Stack true labels to match the structure of pred_energies (B, T, 3)
        true_energies = torch.stack([true_don, true_ka, true_drumroll], dim=2)

        B, T, _ = pred_energies.shape

        # Build a mask from batch['lengths'] to ignore padded parts of sequences.
        # batch['lengths'] gives the actual length of each sequence in the batch.
        # mask_2d shape: (B, T)
        mask_2d = torch.arange(T, device=pred_energies.device).expand(B, T) < batch[
            "lengths"
        ].unsqueeze(1)

        # Expand mask to (B, T, 1) so it broadcasts across the 3 energy channels
        mask_3d = mask_2d.unsqueeze(2)

        # Element-wise MSE loss
        loss_elementwise = self.mse_loss(pred_energies, true_energies)  # (B, T, 3)

        # Zero out loss contributions from padded positions
        masked_loss = loss_elementwise * mask_3d

        if self.reduction == "mean":
            # Sum the loss over all valid (unmasked) elements and divide by the number of valid elements.
            total_loss = masked_loss.sum()
            # mask_3d contributes one entry per valid time step, so multiply by the
            # number of energy channels to count every unmasked float in masked_loss.
            num_valid_elements = mask_3d.sum() * pred_energies.shape[-1]
            if num_valid_elements > 0:
                return total_loss / num_valid_elements
            else:
                # Avoid division by zero when there are no valid elements
                # (e.g., an empty batch or all lengths are 0)
                return torch.tensor(0.0, device=pred_energies.device, requires_grad=True)
        elif self.reduction == "sum":
            return masked_loss.sum()
        else:  # 'none' or any other value
            return masked_loss
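

# --- Usage sketch (illustrative only, not part of the original module) ---
# A minimal example of driving TaikoEnergyLoss with dummy tensors. The dict keys
# match the docstring above; the batch size, sequence length, lengths values,
# and random data are assumptions chosen purely for demonstration.
if __name__ == "__main__":
    B, T = 2, 6  # illustrative batch size and padded time dimension
    criterion = TaikoEnergyLoss(reduction="mean")

    outputs = {"presence": torch.rand(B, T, 3)}  # predicted don/ka/drumroll energies
    batch = {
        "don_labels": torch.rand(B, T),
        "ka_labels": torch.rand(B, T),
        "drumroll_labels": torch.rand(B, T),
        "lengths": torch.tensor([6, 4]),  # second sequence has 2 padded frames
    }

    loss = criterion(outputs, batch)
    print(loss.item())  # scalar: mean MSE over valid (unmasked) elements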