"""Waveform preprocessing helpers: mono downmix, resampling, duration clamping, padding."""

import numpy as np
import torch
import torchaudio
from torch.nn.functional import pad

from .config import TARGET_SR, MAX_SECONDS


def ensure_mono(wav: torch.Tensor) -> torch.Tensor:
    """Downmix a ``[channels, T]`` waveform to a single channel.

    Args:
        wav: 2-D tensor laid out as ``[channels, T]``.

    Returns:
        ``wav`` itself when it already has exactly one channel; otherwise the
        mean over the channel axis, shaped ``[1, T]``.

    Raises:
        ValueError: If ``wav`` is not 2-D.
    """
    if wav.dim() != 2:
        raise ValueError("Expected [channels, T]")
    return wav if wav.size(0) == 1 else wav.mean(0, keepdim=True)


def resample_if_needed(wav: torch.Tensor, sr: int, target_sr: int = TARGET_SR) -> torch.Tensor:
    """Resample ``wav`` from ``sr`` to ``target_sr`` unless the rates already match.

    Args:
        wav: Waveform tensor passed to ``torchaudio.functional.resample``.
        sr: Current sample rate of ``wav``.
        target_sr: Desired sample rate; defaults to ``TARGET_SR`` from the
            package config.

    Returns:
        ``wav`` unchanged when ``sr == target_sr``, otherwise the resampled
        tensor.
    """
    return wav if sr == target_sr else torchaudio.functional.resample(wav, sr, target_sr)


def clamp_duration(wav: torch.Tensor, sr: int, max_seconds: float = MAX_SECONDS) -> torch.Tensor:
    """Truncate ``wav`` along its second axis to at most ``max_seconds`` of audio.

    Args:
        wav: 2-D waveform; the second axis is treated as time.
        sr: Sample rate of ``wav``.
        max_seconds: Maximum duration to keep; a non-positive value disables
            clamping. Defaults to ``MAX_SECONDS`` from the package config.

    Returns:
        ``wav`` unchanged if it is already short enough (or clamping is
        disabled), otherwise a slice of its first ``max_len`` samples.
    """
    if max_seconds <= 0:
        return wav
    # ``sr * max_seconds`` can land a hair below an integer because of binary
    # floating point (e.g. 100 * 0.29 == 28.999999999999996), which plain
    # ``int`` would truncate to one sample too few. Rounding to 6 decimals
    # first removes that noise while keeping floor semantics for genuinely
    # fractional sample counts.
    max_len = int(round(sr * max_seconds, 6))
    return wav[:, :max_len] if wav.size(1) > max_len else wav


def add_convenient_samples(wav: torch.Tensor, hop_length: int = 256) -> torch.Tensor:
    """Right-pad ``wav`` with ``hop_length`` zeros when its frame count is even.

    The frame count is ``wav.shape[1] // hop_length``. Padding by one full hop
    increases it by exactly one, so the returned tensor always has an odd
    frame count.
    NOTE(review): why an odd frame count is needed is not visible here —
    presumably a downstream model constraint; confirm.

    Args:
        wav: Waveform whose second axis is time. Padding is applied to the
            last axis, so this assumes a 2-D ``[channels, T]`` layout.
        hop_length: Samples per frame. Defaults to 256, the value previously
            hard-coded here.

    Returns:
        ``wav`` unchanged if its frame count is already odd, otherwise a new
        tensor zero-padded on the right by ``hop_length`` samples.

    Raises:
        ValueError: If ``hop_length`` is not positive.
    """
    if hop_length <= 0:
        raise ValueError("hop_length must be positive")
    frames = wav.shape[1] // hop_length
    if frames % 2 == 0:
        return pad(wav, (0, hop_length))
    return wav