Upload 11 files
- app.py +4 -6
- argshield.py +2 -3
- config.py +33 -30
- engine.py +239 -194
- models.py +323 -359
app.py
CHANGED

@@ -88,7 +88,7 @@ def process_audio_files(zip_file, model_name, layer, alpha):
     model_defaults = {
         "wavlm": 24, "wav2vec2": 24, "hubert": 24,
         "wavlm_base": 12, "wav2vec2_base": 12, "hubert_base": 12,
-        "wav2vec2_xlsr": 24, "ast": 12
+        "wav2vec2_xlsr": 24
     }
     layer_final = layer if layer is not None else model_defaults.get(model_name, 12)
 
@@ -184,7 +184,6 @@ def create_interface():
    | `wav2vec2_base` | Wav2Vec2 Base | 12 | Faster, good quality |
    | `hubert_base` | HuBERT Base | 12 | |
    | `wav2vec2_xlsr` | Wav2Vec2 XLSR-53 | 24 | Multilingual |
-   | `ast` | Audio Spectrogram Transformer | 12 | Music |
 
    ## Parameters
 
@@ -233,7 +232,7 @@ def create_interface():
                 model_dropdown = gr.Dropdown(
                     choices=["raw", "wavlm", "wav2vec2", "hubert",
                              "wavlm_base", "wav2vec2_base", "hubert_base",
-                             "wav2vec2_xlsr", "ast"],
+                             "wav2vec2_xlsr"],
                     value="wav2vec2_base",
                     label="Select embedding model"
                 )
@@ -265,8 +264,7 @@ def create_interface():
            "wav2vec2_xlsr": {"maximum": 24, "value": 24, "interactive": True},
            "wavlm_base": {"maximum": 12, "value": 12, "interactive": True},
            "wav2vec2_base": {"maximum": 12, "value": 12, "interactive": True},
-           "hubert_base": {"maximum": 12, "value": 12, "interactive": True},
-           "ast": {"maximum": 12, "value": 12, "interactive": True}
+           "hubert_base": {"maximum": 12, "value": 12, "interactive": True}
        }
 
        config = model_configs.get(model_name, {"maximum": 12, "value": 12, "interactive": True})
@@ -308,4 +306,4 @@ def create_interface():
 
 if __name__ == "__main__":
     demo = create_interface()
-    demo.launch()
+    demo.launch()
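The `ast` entry also disappears from the slider-reconfiguration map. The pattern app.py relies on, a dropdown change handler returning a component update, can be sketched in isolation. This is hypothetical wiring (MODEL_CONFIGS, update_layer_slider, and the Blocks layout are illustrative), not the app's exact code; it assumes Gradio's Blocks API:

import gradio as gr

MODEL_CONFIGS = {
    "wav2vec2_xlsr": {"maximum": 24, "value": 24, "interactive": True},
    "wavlm_base": {"maximum": 12, "value": 12, "interactive": True},
    "wav2vec2_base": {"maximum": 12, "value": 12, "interactive": True},
    "hubert_base": {"maximum": 12, "value": 12, "interactive": True},
}

def update_layer_slider(model_name):
    # Unknown models fall back to a 12-layer config, mirroring
    # model_configs.get(...) in app.py.
    cfg = MODEL_CONFIGS.get(model_name, {"maximum": 12, "value": 12, "interactive": True})
    return gr.update(**cfg)

with gr.Blocks() as demo:
    model = gr.Dropdown(choices=list(MODEL_CONFIGS), value="wav2vec2_base",
                        label="Select embedding model")
    layer = gr.Slider(minimum=1, maximum=12, value=12, step=1, label="Layer")
    # Each dropdown change re-bounds the slider to the chosen model's depth.
    model.change(update_layer_slider, inputs=model, outputs=layer)

# demo.launch()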
argshield.py
CHANGED

@@ -17,7 +17,6 @@ MODEL_DEFAULT_LAYER = {
     "wav2vec2_base": 12,
     "hubert_base": 12,
     "wav2vec2_xlsr": 24,
-    "ast": 12,
 }
 
 def _read_manifest_json(path: Path):
@@ -88,7 +87,7 @@ def _parse_args():
         required=True,
         help=("Embedding model. Choices: "
               "raw, wavlm, wav2vec2, hubert, wavlm_base, wav2vec2_base, "
-              "hubert_base, wav2vec2_xlsr, ast"),
+              "hubert_base, wav2vec2_xlsr"),
     )
     parser.add_argument(
         "--layer",
@@ -141,4 +140,4 @@ def _validate_gpus(max_gpus_opt):
         raise SystemExit("--max-gpus must be an integer >= 0.")
     if mg < 0:
         raise SystemExit("--max-gpus must be >= 0.")
-    return mg
+    return mg
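The `--max-gpus` contract enforced by `_validate_gpus` is easy to exercise in isolation. A minimal sketch follows; the parser wiring is illustrative, and the int-coercion step is inferred from the error message rather than visible in the hunk:

import argparse

def _validate_gpus(max_gpus_opt):
    # Coerce to int (inferred step) and reject negatives, as in argshield.py.
    try:
        mg = int(max_gpus_opt)
    except (TypeError, ValueError):
        raise SystemExit("--max-gpus must be an integer >= 0.")
    if mg < 0:
        raise SystemExit("--max-gpus must be >= 0.")
    return mg

parser = argparse.ArgumentParser()
parser.add_argument("--model", required=True,
                    help=("Embedding model. Choices: "
                          "raw, wavlm, wav2vec2, hubert, wavlm_base, wav2vec2_base, "
                          "hubert_base, wav2vec2_xlsr"))
parser.add_argument("--max-gpus", default=0)
args = parser.parse_args(["--model", "wav2vec2_base", "--max-gpus", "2"])
print(_validate_gpus(args.max_gpus))  # -> 2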
config.py
CHANGED

@@ -1,30 +1,33 @@
 import os
 import torch
 
 import warnings
 warnings.filterwarnings(
     "ignore",
     category=UserWarning,
     message=r"^expandable_segments not supported on this platform"
 )
 
 SR = 16_000
 RESULTS_ROOT = "results"
 BATCH_SIZE = 2
 ENERGY_WIN_MS = 20
 ENERGY_HOP_MS = 20
 SILENCE_RATIO = 0.1
 EPS = 1e-4
 COV_TOL = 1e-6
 
 DEFAULT_LAYER = 2
 DEFAULT_ADD_CI = True
 DEFAULT_DELTA_CI = 0.05
 DEFAULT_ALPHA = 1.0
 
 os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True,garbage_collection_threshold:0.6"
 os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
 
 torch.backends.cudnn.benchmark = True
 torch.backends.cudnn.deterministic = False
 torch.backends.cudnn.enabled = True
+
+if torch.cuda.is_available():
+    torch.cuda.set_per_process_memory_fraction(0.8)
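The new tail of config.py soft-caps this process's CUDA allocations. What the guard buys, in standalone form (device 0 assumed; the printout is illustrative):

import torch

if torch.cuda.is_available():
    # Cap this process at ~80% of the GPU so co-resident processes keep
    # headroom; allocations beyond the cap raise OOM instead of squeezing others.
    torch.cuda.set_per_process_memory_fraction(0.8, device=0)
    total = torch.cuda.get_device_properties(0).total_memory
    print(f"soft cap ~{0.8 * total / 2**30:.1f} GiB of {total / 2**30:.1f} GiB")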
engine.py
CHANGED

@@ -4,6 +4,7 @@ from concurrent.futures import ThreadPoolExecutor
 from datetime import datetime
 import librosa
 import pandas as pd
+import numpy as np
 from audio import (
     assign_outputs_to_refs_by_corr,
     loudness_normalize,
@@ -73,7 +74,7 @@ def compute_mapss_measures(
 
     if algos is None:
        algos_to_run = sorted(
-            {algo for …
+            {algo for algo in canon_mix[0].systems.keys()} if canon_mix and canon_mix[0].systems else []
        )
    else:
        algos_to_run = list(algos)
@@ -132,6 +133,7 @@ def compute_mapss_measures(
    win = int(ENERGY_WIN_MS * SR / 1000)
    hop = int(ENERGY_HOP_MS * SR / 1000)
    voiced_mask_mix = []
+    total_frames_per_mix = []  # Store total frames for each mixture
 
    for i, mix in enumerate(mixture_entries):
        if verbose:
@@ -142,6 +144,7 @@ def compute_mapss_measures(
            refs_for_mix = [all_refs[e["id"]].cuda() for e in mix]
            mask = make_union_voiced_mask(refs_for_mix, win, hop)
            voiced_mask_mix.append(mask.cpu())
+            total_frames_per_mix.append(mask.shape[0])
            # Explicitly delete GPU tensors
            for ref in refs_for_mix:
                del ref
@@ -150,21 +153,33 @@ def compute_mapss_measures(
            refs_for_mix = [all_refs[e["id"]].cpu() for e in mix]
            mask = make_union_voiced_mask(refs_for_mix, win, hop)
            voiced_mask_mix.append(mask.cpu())
+            total_frames_per_mix.append(mask.shape[0])
 
    ordered_speakers = [e["id"] for e in flat_entries]
 
-    for …
-        …
-        os.makedirs(algo_dir, exist_ok=True)
-        …
-        for e in mix:
+    # Initialize storage for all mixtures and algorithms
+    all_mixture_results = {}  # mixture_id -> {algo -> {model -> data}}
+
+    for mix_idx, (mix_canon, mix_entries) in enumerate(zip(canon_mix, mixture_entries)):
+        mixture_id = mix_canon.mixture_id
+        all_mixture_results[mixture_id] = {}
+
+        # Get total frames for this mixture
+        total_frames = total_frames_per_mix[mix_idx]
+
+        # Get speakers for this mixture
+        mixture_speakers = [e["id"] for e in mix_entries]
+
+        for algo_idx, algo in enumerate(algos_to_run):
+            if verbose:
+                print(f"\nProcessing Mixture {mixture_id}, Algorithm {algo_idx + 1}/{len(algos_to_run)}: {algo}")
+
+            # Remove the old algo_dir creation here - we don't need these empty folders anymore
+
+            all_outs = {}
+            missing = []
+
+            for e in mix_entries:
                assigned_path = e.get("outs", {}).get(algo)
                if assigned_path is None:
                    missing.append((e["mixture"], e["id"]))
@@ -173,49 +188,49 @@ def compute_mapss_measures(
                wav, _ = librosa.load(str(assigned_path), sr=SR)
                all_outs[e["id"]] = torch.from_numpy(loudness_normalize(wav))
 
-            …
-            if not all_outs:
-                if verbose:
-                    warnings.warn(f"[{algo}] No outputs provided. Skipping algorithm.")
-                continue
-
-            ps_ts = {m: {s: [] for s in ordered_speakers} for m in models}
-            pm_ts = {m: {s: [] for s in ordered_speakers} for m in models}
-            ps_bias_ts = {m: {s: [] for s in ordered_speakers} for m in models}
-            ps_prob_ts = {m: {s: [] for s in ordered_speakers} for m in models}
-            pm_bias_ts = {m: {s: [] for s in ordered_speakers} for m in models}
-            pm_prob_ts = {m: {s: [] for s in ordered_speakers} for m in models}
-
-            for model_idx, mname in enumerate(models):
-                if verbose:
-                    print(f" Processing Model {model_idx + 1}/{len(models)}: {mname}")
-
-                …
-                speakers_this_mix = [e for e in …
+            if missing:
+                msg = f"[{algo}] missing outputs for {len(missing)} speaker(s) in mixture {mixture_id}"
+                if on_missing == "error":
+                    raise FileNotFoundError(msg)
+                else:
+                    if verbose:
+                        warnings.warn(msg + " Skipping those speakers.")
+
+            if not all_outs:
+                if verbose:
+                    warnings.warn(f"[{algo}] No outputs for mixture {mixture_id}. Skipping.")
+                continue
+
+            # Initialize storage for this algorithm
+            if algo not in all_mixture_results[mixture_id]:
+                all_mixture_results[mixture_id][algo] = {}
+
+            # Initialize frame-wise storage with NaN for all frames
+            ps_frames = {m: {s: [np.nan] * total_frames for s in mixture_speakers} for m in models}
+            pm_frames = {m: {s: [np.nan] * total_frames for s in mixture_speakers} for m in models}
+            ps_bias_frames = {m: {s: [np.nan] * total_frames for s in mixture_speakers} for m in models}
+            ps_prob_frames = {m: {s: [np.nan] * total_frames for s in mixture_speakers} for m in models}
+            pm_bias_frames = {m: {s: [np.nan] * total_frames for s in mixture_speakers} for m in models}
+            pm_prob_frames = {m: {s: [np.nan] * total_frames for s in mixture_speakers} for m in models}
+
+            for model_idx, mname in enumerate(models):
+                if verbose:
+                    print(f" Processing Model {model_idx + 1}/{len(models)}: {mname}")
+
+                for metric_type in ["PS", "PM"]:
+                    clear_gpu_memory()
+                    gc.collect()
+
+                    model_wrapper, layer_eff = load_model(mname, layer, max_gpus)
+                    get_gpu_memory_info(verbose)
+
+                    # Process only this mixture
+                    speakers_this_mix = [e for e in mix_entries if e["id"] in all_outs]
                    if not speakers_this_mix:
                        continue
 
                    if verbose:
-                        print(
-                            f"Processing mixture {k + 1}/{len(mixture_entries)} for {metric_type}"
-                        )
+                        print(f" Processing {metric_type} for mixture {mixture_id}")
 
                    all_signals_mix = []
                    all_masks_mix = []
@@ -240,7 +255,7 @@ def compute_mapss_measures(
                        sigs = [all_refs[s].numpy(), all_outs[s].numpy()] + dists
                        lbls = ["ref", "out"] + [f"d{i}" for i in range(len(dists))]
 
-                        masks = [voiced_mask_mix[k]] * len(sigs)
+                        masks = [voiced_mask_mix[mix_idx]] * len(sigs)
                        all_signals_mix.extend(sigs)
                        all_masks_mix.extend(masks)
                        all_labels_mix.extend([f"{s}-{l}" for l in lbls])
@@ -269,12 +284,124 @@ def compute_mapss_measures(
 
                    if embeddings_list:
                        embeddings = torch.cat(embeddings_list, dim=0)
-                        …
+                        E, L, D = embeddings.shape
+
+                        if L == 0:
+                            if verbose:
+                                print(
+                                    f" WARNING: mixture {mixture_id} produced 0 frames after masking; skipping.")
+                            continue
+
+                        # Get valid frame indices for this mixture
+                        mask = voiced_mask_mix[mix_idx]
+                        valid_frame_indices = torch.where(mask)[0].tolist()
+
+                        if verbose:
+                            print(f" Computing {metric_type} scores for {mname}...")
+
+                        # Process frames with their stored embeddings and labels
+                        with ThreadPoolExecutor(
+                            max_workers=min(2, ngpu if ngpu > 0 else 1)
+                        ) as executor:
+
+                            def process_frame(f, frame_idx, embeddings_mix, labels_mix):
+                                try:
+                                    frame_emb = embeddings_mix[:, f, :].detach().cpu().numpy()
+
+                                    if add_ci:
+                                        coords_d, coords_c, eigvals, k_sub_gauss = (
+                                            gpu_distributor.execute_on_gpu(
+                                                diffusion_map_torch,
+                                                frame_emb,
+                                                labels_mix,
+                                                alpha=alpha,
+                                                eig_solver="full",
+                                                return_eigs=True,
+                                                return_complement=True,
+                                                return_cval=add_ci,
+                                            )
+                                        )
+                                    else:
+                                        coords_d = gpu_distributor.execute_on_gpu(
+                                            diffusion_map_torch,
+                                            frame_emb,
+                                            labels_mix,
+                                            alpha=alpha,
+                                            eig_solver="full",
+                                            return_eigs=False,
+                                            return_complement=False,
+                                            return_cval=False,
+                                        )
+                                        coords_c = None
+                                        eigvals = None
+                                        k_sub_gauss = 1
+
+                                    if metric_type == "PS":
+                                        score = compute_ps(
+                                            coords_d, labels_mix, max_gpus
+                                        )
+                                        bias = prob = None
+                                        if add_ci:
+                                            bias, prob = ps_ci_components_full(
+                                                coords_d,
+                                                coords_c,
+                                                eigvals,
+                                                labels_mix,
+                                                delta=DEFAULT_DELTA_CI,
+                                            )
+                                        return frame_idx, "PS", score, bias, prob
+                                    else:
+                                        score = compute_pm(
+                                            coords_d, labels_mix, "gamma", max_gpus
+                                        )
+                                        bias = prob = None
+                                        if add_ci:
+                                            bias, prob = pm_ci_components_full(
+                                                coords_d,
+                                                coords_c,
+                                                eigvals,
+                                                labels_mix,
+                                                delta=DEFAULT_DELTA_CI,
+                                                K=k_sub_gauss,
+                                            )
+                                        return frame_idx, "PM", score, bias, prob
+
+                                except Exception as ex:
+                                    if verbose:
+                                        print(f" ERROR frame {frame_idx}: {ex}")
+                                    return None
+
+                            futures = [
+                                executor.submit(process_frame, f, valid_frame_indices[f], embeddings,
+                                                all_labels_mix)
+                                for f in range(L)
+                            ]
+
+                            for fut in futures:
+                                result = fut.result()
+                                if result is None:
+                                    continue
+
+                                frame_idx, metric, score, bias, prob = result
+
+                                if metric == "PS":
+                                    for sp in score:
+                                        if sp in mixture_speakers:
+                                            ps_frames[mname][sp][frame_idx] = score[sp]
+                                            if add_ci and bias is not None:
+                                                ps_bias_frames[mname][sp][frame_idx] = bias[sp]
+                                                ps_prob_frames[mname][sp][frame_idx] = prob[sp]
+                                else:
+                                    for sp in score:
+                                        if sp in mixture_speakers:
+                                            pm_frames[mname][sp][frame_idx] = score[sp]
+                                            if add_ci and bias is not None:
+                                                pm_bias_frames[mname][sp][frame_idx] = bias[sp]
+                                                pm_prob_frames[mname][sp][frame_idx] = prob[sp]
 
                except Exception as ex:
                    if verbose:
-                        print(f" ERROR processing mixture {k + 1}: {ex}")
+                        print(f" ERROR processing mixture {mixture_id}: {ex}")
                    continue
                finally:
                    # Always clean up after processing a mixture
@@ -284,155 +411,73 @@ def compute_mapss_measures(
                    clear_gpu_memory()
                    gc.collect()
 
-
-            print(f" Computing {metric_type} scores for {mname}...")
-
-            # Process mixtures with their stored embeddings and labels
-            with ThreadPoolExecutor(
-                max_workers=min(2, ngpu if ngpu > 0 else 1)
-            ) as executor:
-                for k in range(len(mixture_entries)):
-                    if k not in embs_by_mix:
-                        continue
-
-                    E, L, D = embs_by_mix[k].shape
-                    if L == 0:
-                        if verbose:
-                            print(f" WARNING: mixture {k + 1} produced 0 frames after masking; skipping.")
-                        continue
-
-                    # Get the labels for this mixture
-                    labels_for_mix = labels_by_mix[k]
-
-                    def process_frame(f, embeddings_mix, labels_mix):
-                        try:
-                            frame_emb = embeddings_mix[:, f, :].detach().cpu().numpy()
-
-                            if add_ci:
-                                coords_d, coords_c, eigvals, k_sub_gauss = (
-                                    gpu_distributor.execute_on_gpu(
-                                        diffusion_map_torch,
-                                        frame_emb,
-                                        labels_mix,
-                                        alpha=alpha,
-                                        eig_solver="full",
-                                        return_eigs=True,
-                                        return_complement=True,
-                                        return_cval=add_ci,
-                                    )
-                                )
-                            else:
-                                coords_d = gpu_distributor.execute_on_gpu(
-                                    diffusion_map_torch,
-                                    frame_emb,
-                                    labels_mix,
-                                    alpha=alpha,
-                                    eig_solver="full",
-                                    return_eigs=False,
-                                    return_complement=False,
-                                    return_cval=False,
-                                )
-                                coords_c = None
-                                eigvals = None
-                                k_sub_gauss = 1
-
-                            if metric_type == "PS":
-                                score = compute_ps(
-                                    coords_d, labels_mix, max_gpus
-                                )
-                                bias = prob = None
-                                if add_ci:
-                                    bias, prob = ps_ci_components_full(
-                                        coords_d,
-                                        coords_c,
-                                        eigvals,
-                                        labels_mix,
-                                        delta=DEFAULT_DELTA_CI,
-                                    )
-                                return f, "PS", score, bias, prob
-                            else:
-                                score = compute_pm(
-                                    coords_d, labels_mix, "gamma", max_gpus
-                                )
-                                bias = prob = None
-                                if add_ci:
-                                    bias, prob = pm_ci_components_full(
-                                        coords_d,
-                                        coords_c,
-                                        eigvals,
-                                        labels_mix,
-                                        delta=DEFAULT_DELTA_CI,
-                                        K=k_sub_gauss,
-                                    )
-                                return f, "PM", score, bias, prob
-
-                        except Exception as ex:
-                            if verbose:
-                                print(f" ERROR frame {f + 1}: {ex}")
-                            return None
-
-                    futures = [
-                        executor.submit(process_frame, f, embs_by_mix[k], labels_for_mix)
-                        for f in range(L)
-                    ]
-                    for fut in futures:
-                        result = fut.result()
-                        if result is None:
-                            continue
-
-                        f, metric, score, bias, prob = result
-
-                        if metric == "PS":
-                            for sp in score:
-                                ps_ts[mname][sp].append(score[sp])
-                                if add_ci and bias is not None:
-                                    ps_bias_ts[mname][sp].append(bias[sp])
-                                    ps_prob_ts[mname][sp].append(prob[sp])
-                        else:
-                            for sp in score:
-                                pm_ts[mname][sp].append(score[sp])
-                                if add_ci and bias is not None:
-                                    pm_bias_ts[mname][sp].append(bias[sp])
-                                    pm_prob_ts[mname][sp].append(prob[sp])
-
-            # Clean up after processing all mixtures for this metric
-            del embs_by_mix, labels_by_mix
+                    del model_wrapper
                    clear_gpu_memory()
                    gc.collect()
 
-            …
+                # Store results for this mixture and algorithm
+                all_mixture_results[mixture_id][algo][mname] = {
+                    'ps_frames': ps_frames[mname],
+                    'pm_frames': pm_frames[mname],
+                    'ps_bias_frames': ps_bias_frames[mname] if add_ci else None,
+                    'ps_prob_frames': ps_prob_frames[mname] if add_ci else None,
+                    'pm_bias_frames': pm_bias_frames[mname] if add_ci else None,
+                    'pm_prob_frames': pm_prob_frames[mname] if add_ci else None,
+                    'total_frames': total_frames
+                }
+
+        # Save results for this mixture after processing all algorithms
        if verbose:
-            print(f" Saving results for {…
-            …
-            for …
-            …
+            print(f" Saving results for mixture {mixture_id}...")
+
+        # Create timestamps in milliseconds - using lowercase hop
+        timestamps_ms = [i * hop * 1000 / SR for i in range(total_frames)]
+
+        for model in models:
+            # Prepare PS data
+            ps_data = {'timestamp_ms': timestamps_ms}
+            pm_data = {'timestamp_ms': timestamps_ms}
+            ci_data = {'timestamp_ms': timestamps_ms} if add_ci else None
+
+            # Combine data from all algorithms for this mixture
+            for algo in algos_to_run:
+                if algo not in all_mixture_results[mixture_id]:
+                    continue
+                if model not in all_mixture_results[mixture_id][algo]:
+                    continue
+
+                model_data = all_mixture_results[mixture_id][algo][model]
+
+                # Add PS data
+                for speaker in mixture_speakers:
+                    col_name = f"{algo}_{speaker}"
+                    ps_data[col_name] = model_data['ps_frames'][speaker]
+                    pm_data[col_name] = model_data['pm_frames'][speaker]
+
+                    if add_ci and ci_data is not None:
+                        ci_data[f"{algo}_{speaker}_ps_bias"] = model_data['ps_bias_frames'][speaker]
+                        ci_data[f"{algo}_{speaker}_ps_prob"] = model_data['ps_prob_frames'][speaker]
+                        ci_data[f"{algo}_{speaker}_pm_bias"] = model_data['pm_bias_frames'][speaker]
+                        ci_data[f"{algo}_{speaker}_pm_prob"] = model_data['pm_prob_frames'][speaker]
+
+            # Save CSV files for this mixture
+            mixture_dir = os.path.join(exp_root, mixture_id)
+            os.makedirs(mixture_dir, exist_ok=True)
+
+            pd.DataFrame(ps_data).to_csv(
+                os.path.join(mixture_dir, f"ps_scores_{model}.csv"),
+                index=False
+            )
+
+            pd.DataFrame(pm_data).to_csv(
+                os.path.join(mixture_dir, f"pm_scores_{model}.csv"),
+                index=False
+            )
+
+            if add_ci and ci_data is not None:
+                pd.DataFrame(ci_data).to_csv(
+                    os.path.join(mixture_dir, f"ci_{model}.csv"),
+                    index=False
                )
 
        del all_outs
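The core behavioral change above: scores are still computed only on union-voiced frames, but the per-mixture CSVs now carry one row per hop over the full timeline, so each metric column is pre-filled with NaN and written back through absolute frame indices (valid_frame_indices). A runnable sketch of that alignment trick, with illustrative names and a constant standing in for compute_ps:

import numpy as np
import pandas as pd
import torch

SR, HOP = 16_000, 320                        # 20 ms hop at 16 kHz, as in config.py
voiced = torch.tensor([0, 1, 1, 0, 1], dtype=torch.bool)
total_frames = voiced.shape[0]

scores = [np.nan] * total_frames             # one slot per frame; NaN = silent/unscored
valid_frame_indices = torch.where(voiced)[0].tolist()
for f, frame_idx in enumerate(valid_frame_indices):   # f indexes the packed embeddings
    scores[frame_idx] = 0.9 + 0.01 * f                 # stand-in for a per-frame score

df = pd.DataFrame({
    "timestamp_ms": [i * HOP * 1000 / SR for i in range(total_frames)],
    "algoA_spk1": scores,                              # column naming mirrors f"{algo}_{speaker}"
})
print(df)   # NaN rows line up exactly with frames outside the voiced mask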
models.py
CHANGED

@@ -1,360 +1,324 @@
 import queue
 import threading
 import gc
 
 import torch
 import torch.nn.functional as F
 from transformers import (
     HubertModel,
     Wav2Vec2FeatureExtractor,
     Wav2Vec2Model,
     WavLMModel,
 )
-…
+
+from config import BATCH_SIZE, ENERGY_HOP_MS, ENERGY_WIN_MS, SR
+from utils import get_gpu_count
+
+
+class BalancedDualGPUModel:
+
+    def __init__(self, model_name, layer, max_gpus=None):
+        self.layer = layer
+        self.models = []
+        self.extractors = []
+        self.devices = []
+        ngpu = get_gpu_count(max_gpus)
+
+        for gpu_id in range(min(ngpu, 2)):
+            device = f"cuda:{gpu_id}"
+            self.devices.append(device)
+            ckpt, cls, _ = get_model_config(layer)[model_name]
+            extractor = Wav2Vec2FeatureExtractor.from_pretrained(ckpt)
+
+            attn_impl = "eager" if cls is WavLMModel else "sdpa"
+            model = cls.from_pretrained(
+                ckpt,
+                output_hidden_states=True,
+                use_safetensors=True,
+                torch_dtype=torch.float16,
+                low_cpu_mem_usage=True,
+                attn_implementation=attn_impl
+            )
+            model.eval()
+            model = model.to(device)
+
+            for param in model.parameters():
+                param.requires_grad = False
+
+            self.extractors.append(extractor)
+            self.models.append(model)
+
+        self.gpu_queues = [queue.Queue() for _ in range(len(self.devices))]
+        self.result_queue = queue.Queue()
+        self.workers = []
+        for i in range(len(self.devices)):
+            worker = threading.Thread(target=self._gpu_worker, args=(i,))
+            worker.daemon = True
+            worker.start()
+            self.workers.append(worker)
+
+    def _gpu_worker(self, gpu_id):
+        device = self.devices[gpu_id]
+        model = self.models[gpu_id]
+        extractor = self.extractors[gpu_id]
+        while True:
+            task = self.gpu_queues[gpu_id].get()
+            if task is None:
+                break
+            signals, masks, use_mlm, task_id = task
+            try:
+                inputs = extractor(
+                    signals, sampling_rate=SR, return_tensors="pt", padding=True
+                )
+                input_values = inputs.input_values.to(device, non_blocking=True)
+
+                torch.cuda.empty_cache()
+
+                orig_mode = model.training
+                model.train() if use_mlm else model.eval()
+                with torch.no_grad():
+                    with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
+                        hs = model(
+                            input_values, output_hidden_states=True
+                        ).hidden_states[self.layer]
+                model.train(orig_mode)
+
+                B, T, D = hs.shape
+                keep = []
+                for b in range(B):
+                    mask_b = masks[b].float().unsqueeze(0).unsqueeze(0).to(device)
+                    mask_t = F.interpolate(mask_b, size=T, mode="nearest")[0, 0].bool()
+                    keep.append(hs[b][mask_t].cpu())
+
+                # Aggressive cleanup
+                del hs, input_values, inputs
+                torch.cuda.empty_cache()
+
+                if keep:
+                    L_max = max(x.shape[0] for x in keep)
+                    keep_padded = [
+                        F.pad(x, (0, 0, 0, L_max - x.shape[0])) for x in keep
+                    ]
+                    result = torch.stack(keep_padded, dim=0)
+                else:
+                    result = torch.empty(0, 0, 0)
+                self.result_queue.put((task_id, result))
+            except Exception as e:
+                self.result_queue.put((task_id, e))
+            finally:
+                # Always clear cache after processing
+                torch.cuda.empty_cache()
+
+    def process_batch(self, signals, masks, use_mlm=False):
+        if not signals:
+            return torch.empty(0, 0, 0)
+        batch_size = len(signals)
+        split = (batch_size + len(self.devices) - 1) // len(self.devices)
+        results = {}
+        task_id = 0
+        for i in range(0, batch_size, split):
+            end = min(i + split, batch_size)
+            gpu_id = (i // split) % len(self.devices)
+            self.gpu_queues[gpu_id].put(
+                (signals[i:end], masks[i:end], use_mlm, task_id)
+            )
+            task_id += 1
+        for _ in range(task_id):
+            tid, result = self.result_queue.get()
+            if isinstance(result, Exception):
+                raise result
+            results[tid] = result
+        parts = [results[i] for i in range(task_id) if results[i].numel() > 0]
+        return torch.cat(parts, dim=0) if parts else torch.empty(0, 0, 0)
+
+    def cleanup(self):
+        """Explicit cleanup method"""
+        for q in self.gpu_queues:
+            q.put(None)
+        for w in self.workers:
+            w.join(timeout=5.0)
+        for model in self.models:
+            del model
+        for extractor in self.extractors:
+            del extractor
+        self.models.clear()
+        self.extractors.clear()
+        torch.cuda.empty_cache()
+        gc.collect()
+
+    def __del__(self):
+        self.cleanup()
+
+
+# NO CACHE - we need to clean up models properly between runs
+def get_model_config(layer):
+    return {
+        "raw": (None, None, None),
+        "wavlm": ("microsoft/wavlm-large", WavLMModel, layer),
+        "wav2vec2": ("facebook/wav2vec2-large-lv60", Wav2Vec2Model, layer),
+        "hubert": ("facebook/hubert-large-ll60k", HubertModel, layer),
+        "wavlm_base": ("microsoft/wavlm-base", WavLMModel, layer),
+        "wav2vec2_base": ("facebook/wav2vec2-base", Wav2Vec2Model, layer),
+        "hubert_base": ("facebook/hubert-base-ls960", HubertModel, layer),
+        "wav2vec2_xlsr": ("facebook/wav2vec2-large-xlsr-53", Wav2Vec2Model, layer),
+    }
+
+
+# Store loaded models globally to properly manage them
+_loaded_models = {}
+
+
+def load_model(name, layer, max_gpus=None):
+    global _loaded_models
+
+    # Clean up any previously loaded models first
+    if _loaded_models:
+        for key, model_data in _loaded_models.items():
+            if isinstance(model_data, tuple) and len(model_data) == 2:
+                if isinstance(model_data[0], BalancedDualGPUModel):
+                    model_data[0].cleanup()
+                elif isinstance(model_data[0], tuple):
+                    # Single GPU model
+                    _, model = model_data[0]
+                    del model
+        _loaded_models.clear()
+        torch.cuda.empty_cache()
+        gc.collect()
+
+    if name.lower() in {"raw", "waveform"}:
+        return "raw", layer
+
+    ngpu = get_gpu_count(max_gpus)
+    if ngpu > 1:
+        model = BalancedDualGPUModel(name, layer, max_gpus)
+        _loaded_models[name] = (model, layer)
+        return model, layer
+    else:
+        ckpt, cls, layer_eff = get_model_config(layer)[name]
+        extractor = Wav2Vec2FeatureExtractor.from_pretrained(ckpt)
+
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        attn_impl = "eager" if cls is WavLMModel else "sdpa"
+        model = cls.from_pretrained(
+            ckpt,
+            output_hidden_states=True,
+            use_safetensors=True,
+            torch_dtype=torch.float16,
+            low_cpu_mem_usage=True,
+            attn_implementation=attn_impl
+        )
+        model.eval()
+        model = model.to(device)
+
+        for param in model.parameters():
+            param.requires_grad = False
+
+        model_tuple = ((extractor, model), layer_eff)
+        _loaded_models[name] = model_tuple
+        return (extractor, model), layer_eff
+
+
+def cleanup_all_models():
+    """Call this at the end of each experiment to ensure complete cleanup"""
+    global _loaded_models
+    if _loaded_models:
+        for key, model_data in _loaded_models.items():
+            if isinstance(model_data, tuple) and len(model_data) == 2:
+                if isinstance(model_data[0], BalancedDualGPUModel):
+                    model_data[0].cleanup()
+                elif isinstance(model_data[0], tuple):
+                    # Single GPU model
+                    _, model = model_data[0]
+                    del model
+        _loaded_models.clear()
+        torch.cuda.empty_cache()
+        gc.collect()
+
+
+def embed_batch_raw(signals, masks_audio):
+    win = int(ENERGY_WIN_MS * SR / 1000)
+    hop = int(ENERGY_HOP_MS * SR / 1000)
+    reps, L_max = [], 0
+    for sig_np, mask_np in zip(signals, masks_audio):
+        x = torch.as_tensor(sig_np[:-1], dtype=torch.float32)
+        frames = x.unfold(0, win, hop)
+        mask = torch.as_tensor(mask_np[: len(frames)], dtype=torch.bool)
+        keep = frames[mask] if mask.any() else frames[:1]
+        reps.append(keep)
+        L_max = max(L_max, keep.size(0))
+    reps = [F.pad(r, (0, 0, 0, L_max - r.size(0))) for r in reps]
+    return torch.stack(reps, dim=0)
+
+
+def embed_batch_single_gpu(
+    signals, masks_audio, extractor, model, layer, use_mlm=False
+):
+    if not signals:
+        return torch.empty(0, 0, 0)
+    device = next(model.parameters()).device
+
+    max_batch = 2
+    all_keeps = []
+
+    for i in range(0, len(signals), max_batch):
+        batch_signals = signals[i:i + max_batch]
+        batch_masks = masks_audio[i:i + max_batch]
+
+        inputs = extractor(batch_signals, sampling_rate=SR, return_tensors="pt", padding=True)
+        input_values = inputs.input_values.to(device, non_blocking=True)
+
+        orig_mode = model.training
+        model.train() if use_mlm else model.eval()
+
+        with torch.no_grad():
+            with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
+                hs = model(input_values, output_hidden_states=True).hidden_states[layer]
+        model.train(orig_mode)
+
+        B, T, D = hs.shape
+        for b in range(B):
+            mask_b = batch_masks[b].float().unsqueeze(0).unsqueeze(0).to(device)
+            mask_t = F.interpolate(mask_b, size=T, mode="nearest")[0, 0].bool()
+            all_keeps.append(hs[b][mask_t].cpu())
+
+        # Aggressive cleanup
+        del hs, input_values, inputs
+        torch.cuda.empty_cache()
+
+    if all_keeps:
         L_max = max(x.shape[0] for x in all_keeps)
         keep_padded = [F.pad(x, (0, 0, 0, L_max - x.shape[0])) for x in all_keeps]
         result = torch.stack(keep_padded, dim=0)
         # Clean up intermediate lists
         del all_keeps, keep_padded
         return result
     else:
         return torch.empty(0, 0, 0)
 
 
 def embed_batch(signals, masks_audio, model_wrapper, layer, use_mlm=False):
     if model_wrapper == "raw":
         return embed_batch_raw(signals, masks_audio)
     if isinstance(model_wrapper, BalancedDualGPUModel):
         all_embeddings = []
         batch_size = min(BATCH_SIZE, 2)
         for i in range(0, len(signals), batch_size):
             batch_emb = model_wrapper.process_batch(
                 signals[i: i + batch_size], masks_audio[i: i + batch_size], use_mlm
             )
             if batch_emb.numel() > 0:
                 all_embeddings.append(batch_emb)
             # Clear cache after each batch
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
+            torch.cuda.empty_cache()
 
         if all_embeddings:
             result = torch.cat(all_embeddings, dim=0)
             del all_embeddings
             return result
         else:
             return torch.empty(0, 0, 0)
     else:
         extractor, model = model_wrapper
         return embed_batch_single_gpu(
             signals, masks_audio, extractor, model, layer, use_mlm=use_mlm
         )
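BalancedDualGPUModel's dispatch is a plain queue-per-device fan-out with a shared result queue: exceptions ride the same channel and are re-raised by the caller, and a None sentinel shuts workers down, as in cleanup(). A stripped-down, CPU-only sketch of that pattern (doubling stands in for the forward pass):

import queue
import threading

def worker(task_q, result_q):
    while True:
        task = task_q.get()
        if task is None:                           # sentinel: shut the worker down
            break
        payload, task_id = task
        try:
            result_q.put((task_id, payload * 2))   # stand-in for model inference
        except Exception as e:
            result_q.put((task_id, e))             # errors travel the result path

task_qs = [queue.Queue() for _ in range(2)]
result_q = queue.Queue()
threads = [threading.Thread(target=worker, args=(q, result_q), daemon=True)
           for q in task_qs]
for t in threads:
    t.start()

for task_id, x in enumerate([1, 2, 3, 4]):
    task_qs[task_id % 2].put((x, task_id))         # round-robin across "devices"
out = {}
for _ in range(4):
    tid, res = result_q.get()
    if isinstance(res, Exception):
        raise res
    out[tid] = res
print([out[i] for i in range(4)])                  # [2, 4, 6, 8], in submit order

for q in task_qs:
    q.put(None)
for t in threads:
    t.join()

The intended module lifecycle can be exercised through the checkpoint-free "raw" path; this hedged usage sketch assumes models.py is importable alongside config.py and utils.py, and the random signal is a stand-in:

import numpy as np
import torch
from models import load_model, embed_batch, cleanup_all_models

model_wrapper, layer_eff = load_model("raw", layer=0)

sig = np.random.randn(16_000).astype(np.float32)   # 1 s of audio at 16 kHz
n_frames = (len(sig) - 1 - 320) // 320 + 1         # frames from unfold(win=320, hop=320)
mask = torch.ones(n_frames, dtype=torch.bool)      # keep every frame

emb = embed_batch([sig], [mask], model_wrapper, layer_eff)
print(emb.shape)                                   # torch.Size([1, 49, 320]): raw windows as features

cleanup_all_models()                               # evict cached models between runs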