Spaces:
Sleeping
Sleeping
| from fastapi import APIRouter | |
| from datetime import datetime | |
| from datasets import load_dataset | |
| from sklearn.metrics import accuracy_score | |
| import os | |
| import torch | |
| from torch.utils.data import DataLoader, TensorDataset | |
| from torchaudio import transforms | |
| #from torchvision import models | |
| import onnxruntime as ort # Add ONNX Runtime | |
| from .utils.evaluation import AudioEvaluationRequest | |
| from .utils.emissions import tracker, clean_emissions_data, get_space_info | |
| from dotenv import load_dotenv | |
# Load variables from a local .env file into the process environment
# (HF_TOKEN is read from the environment further down in this file).
load_dotenv()

# Router object for this module's endpoints.
router = APIRouter()

# Model description reported in the evaluation results.
DESCRIPTION = "Tiny_DNN"
# Route path reported in the evaluation results under "api_route".
ROUTE = "/audio"

# Cap PyTorch's CPU thread pools: 4 intra-op threads, 2 inter-op threads.
torch.set_num_threads(4)
torch.set_num_interop_threads(2)
# NOTE(review): no `@router.<method>(ROUTE, ...)` decorator is visible on this
# handler even though `router` and `ROUTE` are defined above — confirm the
# route is registered somewhere, otherwise this endpoint is unreachable.
async def evaluate_audio(request: AudioEvaluationRequest):
    """Evaluate the ONNX audio classifier on a held-out split.

    The "train" split of ``request.dataset_name`` is divided with
    ``train_test_split`` (using ``request.test_size`` and
    ``request.test_seed``); every clip of the resulting test part is
    converted to a dB-scaled mel spectrogram and classified with the ONNX
    model stored at ``./output_model.onnx``. Only the inference loop is
    wrapped in the emissions tracker task.

    Args:
        request: Evaluation request providing ``dataset_name``,
            ``test_size`` and ``test_seed``.

    Returns:
        dict: Submission record containing the space identity, an ISO
        timestamp, the model description, the accuracy, the tracker's
        ``energy_consumed`` and ``emissions`` values multiplied by 1000
        (keys ``energy_consumed_wh`` and ``emissions_gco2eq``), the cleaned
        emissions data, the API route and the dataset configuration.
    """
    username, space_url = get_space_info()

    # Load the dataset and carve out the evaluation split.
    dataset = load_dataset(request.dataset_name, token=os.getenv("HF_TOKEN"))
    train_test = dataset["train"].train_test_split(
        test_size=request.test_size,
        seed=request.test_seed,
    )
    test_dataset = train_test["test"]
    true_labels = test_dataset["label"]

    # Feature pipeline: 12 kHz -> 16 kHz, 64-band mel spectrogram, dB scale.
    # NOTE(review): assumes the clips are sampled at 12 kHz — confirm against
    # the dataset's own sampling rate.
    resampler = transforms.Resample(orig_freq=12000, new_freq=16000)
    mel_transform = transforms.MelSpectrogram(sample_rate=16000, n_mels=64)
    amplitude_to_db = transforms.AmplitudeToDB()

    # One Resample module per distinct clip length, so the resampling kernel
    # is not rebuilt for every clip of the same length.
    length_resamplers = {}

    def resize_audio(_waveform, target_length):
        """Stretch or squeeze ``_waveform`` to ``target_length`` samples.

        The clip length is passed as the "source rate" and the target length
        as the "destination rate", so Resample rescales the time axis by
        exactly ``target_length / num_frames``.
        """
        num_frames = _waveform.shape[-1]
        if num_frames == target_length:
            return _waveform
        key = (num_frames, target_length)
        if key not in length_resamplers:
            length_resamplers[key] = transforms.Resample(
                orig_freq=num_frames, new_freq=target_length
            )
        return length_resamplers[key](_waveform)

    def to_features(sample):
        """Turn one dataset sample into a (1, n_mels, frames) dB spectrogram."""
        waveform = torch.tensor(
            sample["audio"]["array"], dtype=torch.float32
        ).unsqueeze(0)
        # 72000 samples for every clip so the spectrograms can be stacked.
        waveform = resize_audio(waveform, target_length=72000)
        return amplitude_to_db(mel_transform(resampler(waveform)))

    waveforms = torch.stack([to_features(sample) for sample in test_dataset])
    labels = torch.tensor(list(true_labels))

    # NOTE(review): pin_memory / num_workers are kept from the original; the
    # batches are only converted to NumPy for ONNX Runtime below, so they may
    # be unnecessary — confirm before removing.
    test_loader = DataLoader(
        TensorDataset(waveforms, labels),
        batch_size=128,
        shuffle=False,
        pin_memory=True,
        num_workers=4,
    )

    # Load the ONNX model with the same thread limits as PyTorch above.
    onnx_model_path = "./output_model.onnx"
    session_options = ort.SessionOptions()
    session_options.intra_op_num_threads = 4
    session_options.inter_op_num_threads = 2
    ort_session = ort.InferenceSession(onnx_model_path, session_options)
    # Ask the model for its input name instead of hard-coding 'input', so a
    # model exported with a different input name still runs.
    input_name = ort_session.get_inputs()[0].name

    # Start tracking emissions for the inference phase only.
    tracker.start()
    tracker.start_task("inference")

    predictions = []
    for batch, _ in test_loader:
        ort_outputs = ort_session.run(None, {input_name: batch.numpy()})
        # Assuming the first output has shape [batch_size, num_classes].
        predictions.extend(ort_outputs[0].argmax(axis=1).tolist())

    # Stop tracking emissions.
    emissions_data = tracker.stop_task()

    accuracy = accuracy_score(true_labels, predictions)

    return {
        "username": username,
        "space_url": space_url,
        "submission_timestamp": datetime.now().isoformat(),
        "model_description": DESCRIPTION,
        "accuracy": float(accuracy),
        "energy_consumed_wh": emissions_data.energy_consumed * 1000,
        "emissions_gco2eq": emissions_data.emissions * 1000,
        "emissions_data": clean_emissions_data(emissions_data),
        "api_route": ROUTE,
        "dataset_config": {
            "dataset_name": request.dataset_name,
            "test_size": request.test_size,
            "test_seed": request.test_seed,
        },
    }