AB739 commited on
Commit
9a3271e
·
verified ·
1 Parent(s): 017385e

Update tasks/audio.py

Browse files
Files changed (1) hide show
  1. tasks/audio.py +10 -7
tasks/audio.py CHANGED
@@ -66,22 +66,25 @@ async def evaluate_audio(request: AudioEvaluationRequest):
66
  _waveform = _resampler(_waveform)
67
  return _waveform
68
 
69
- def preprocess_audio(sample):
70
- waveform = torch.tensor(sample['audio']['array'], dtype=torch.float32).unsqueeze(0)
71
- resized_waveform = resize_audio(waveform, target_length=72000)
72
- return amplitude_to_db(mel_transform(resampler(resized_waveform)))
73
 
74
- waveforms = [preprocess_audio(sample) for sample in test_dataset]
75
- labels = torch.tensor(true_labels)
 
 
76
 
77
  waveforms = torch.stack(waveforms)
 
78
 
79
  test_loader = DataLoader(
80
  TensorDataset(waveforms, labels),
81
  batch_size=128,
82
  shuffle=False,
83
  pin_memory=True,
84
- num_workers=4
85
  )
86
 
87
  scripted_model = torch.jit.load("./optimized_qat_blazeface_model.pt", map_location=torch.device('cpu'))
 
66
  _waveform = _resampler(_waveform)
67
  return _waveform
68
 
69
+ resized_waveforms = [
70
+ resize_audio(torch.tensor(sample['audio']['array'], dtype=torch.float32).unsqueeze(0), target_length=72000)
71
+ for sample in test_dataset
72
+ ]
73
 
74
+ waveforms, labels = [], []
75
+ for waveform, label in zip(resized_waveforms, true_labels):
76
+ waveforms.append(amplitude_to_db(mel_transform(resampler(waveform))))
77
+ labels.append(label)
78
 
79
  waveforms = torch.stack(waveforms)
80
+ labels = torch.tensor(labels)
81
 
82
  test_loader = DataLoader(
83
  TensorDataset(waveforms, labels),
84
  batch_size=128,
85
  shuffle=False,
86
  pin_memory=True,
87
+ num_workers=4
88
  )
89
 
90
  scripted_model = torch.jit.load("./optimized_qat_blazeface_model.pt", map_location=torch.device('cpu'))