AB739 committed on
Commit 017385e · verified · 1 Parent(s): 578ee05

Update tasks/audio.py

Files changed (1)
  1. tasks/audio.py +8 -12
tasks/audio.py CHANGED
@@ -66,25 +66,22 @@ async def evaluate_audio(request: AudioEvaluationRequest):
         _waveform = _resampler(_waveform)
         return _waveform
 
-    resized_waveforms = [
-        resize_audio(torch.tensor(sample['audio']['array'], dtype=torch.float32).unsqueeze(0), target_length=72000)
-        for sample in test_dataset
-    ]
+    def preprocess_audio(sample):
+        waveform = torch.tensor(sample['audio']['array'], dtype=torch.float32).unsqueeze(0)
+        resized_waveform = resize_audio(waveform, target_length=72000)
+        return amplitude_to_db(mel_transform(resampler(resized_waveform)))
 
-    waveforms, labels = [], []
-    for waveform, label in zip(resized_waveforms, true_labels):
-        waveforms.append(amplitude_to_db(mel_transform(resampler(waveform))))
-        labels.append(label)
+    waveforms = [preprocess_audio(sample) for sample in test_dataset]
+    labels = torch.tensor(true_labels)
 
     waveforms = torch.stack(waveforms)
-    labels = torch.tensor(labels)
 
     test_loader = DataLoader(
         TensorDataset(waveforms, labels),
         batch_size=128,
         shuffle=False,
         pin_memory=True,
-        num_workers=2
+        num_workers=4
     )
 
     scripted_model = torch.jit.load("./optimized_qat_blazeface_model.pt", map_location=torch.device('cpu'))
@@ -101,8 +98,7 @@ async def evaluate_audio(request: AudioEvaluationRequest):
     predictions = []
     with torch.no_grad():
         #with autocast():
-        #with torch.amp.autocast(device_type='cpu'):
-        with torch.autocast(device_type='cpu'):
+        with torch.amp.autocast(device_type='cpu'):
            for data, target in test_loader:
                outputs = scripted_model(data)
                _, predicted = torch.max(outputs, 1)
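For reference, here is a minimal, self-contained sketch of what the new preprocess_audio helper does (pad or trim to a fixed length, resample, compute a mel-spectrogram, convert to dB). The torchaudio transforms and the sample-rate / n_mels values below are illustrative assumptions; the real resize_audio, resampler, mel_transform and amplitude_to_db are defined elsewhere in tasks/audio.py.

```python
# Illustrative stand-ins for the helpers used in the diff above; the real
# resize_audio, resampler, mel_transform and amplitude_to_db live elsewhere
# in tasks/audio.py, and the sample rates / n_mels here are assumptions.
import torch
import torchaudio

ORIG_SR, TARGET_SR = 12000, 16000   # assumed dataset / model sample rates
TARGET_LEN = 72000                  # fixed clip length used in the diff

resampler = torchaudio.transforms.Resample(orig_freq=ORIG_SR, new_freq=TARGET_SR)
mel_transform = torchaudio.transforms.MelSpectrogram(sample_rate=TARGET_SR, n_mels=64)
amplitude_to_db = torchaudio.transforms.AmplitudeToDB()

def resize_audio(waveform: torch.Tensor, target_length: int) -> torch.Tensor:
    """Zero-pad or truncate so every clip has exactly target_length samples."""
    length = waveform.shape[-1]
    if length < target_length:
        return torch.nn.functional.pad(waveform, (0, target_length - length))
    return waveform[..., :target_length]

def preprocess_audio(sample: dict) -> torch.Tensor:
    """Mirror of the helper added in this commit: raw array -> dB mel-spectrogram."""
    waveform = torch.tensor(sample['audio']['array'], dtype=torch.float32).unsqueeze(0)
    resized_waveform = resize_audio(waveform, target_length=TARGET_LEN)
    return amplitude_to_db(mel_transform(resampler(resized_waveform)))
```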
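And a hedged sketch of the inference path the second hunk touches: batching the stacked spectrograms with a DataLoader and running the TorchScript model under CPU autocast. The batch size, num_workers value and model path are taken from the diff; how predictions are collected at the end is an assumption, and test_dataset / true_labels are assumed to come from the request handling earlier in evaluate_audio. Note that torch.autocast and torch.amp.autocast refer to the same context manager in recent PyTorch releases; the commit simply settles on the torch.amp spelling.

```python
# Continuing the sketch above: preprocess_audio produces one spectrogram
# tensor per sample; test_dataset and true_labels are assumed inputs.
import torch
from torch.utils.data import DataLoader, TensorDataset

waveforms = torch.stack([preprocess_audio(sample) for sample in test_dataset])
labels = torch.tensor(true_labels)

test_loader = DataLoader(
    TensorDataset(waveforms, labels),
    batch_size=128,
    shuffle=False,
    pin_memory=True,
    num_workers=4,  # bumped from 2 in this commit
)

# Quantization-aware-trained model exported with TorchScript (path from the diff).
scripted_model = torch.jit.load(
    "./optimized_qat_blazeface_model.pt", map_location=torch.device('cpu')
)
scripted_model.eval()

predictions = []
with torch.no_grad():
    # CPU autocast runs eligible ops in bfloat16 by default.
    with torch.amp.autocast(device_type='cpu'):
        for data, target in test_loader:
            outputs = scripted_model(data)
            _, predicted = torch.max(outputs, 1)
            predictions.extend(predicted.tolist())
```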