"""FastAPI route that scores an ONNX audio classifier and reports its emissions."""

import os
from datetime import datetime

from fastapi import APIRouter
from datasets import load_dataset
from sklearn.metrics import accuracy_score
import torch
from torch.utils.data import DataLoader, TensorDataset
from torchaudio import transforms
import onnxruntime as ort  # ONNX Runtime executes the exported model
from dotenv import load_dotenv

from .utils.evaluation import AudioEvaluationRequest
from .utils.emissions import tracker, clean_emissions_data, get_space_info

load_dotenv()

router = APIRouter()

DESCRIPTION = "Tiny_DNN"
ROUTE = "/audio"

# Class ids of the task. Kept for reference only: it is not applied at
# runtime because the dataset's "label" column is used as-is.
LABEL_MAPPING = {
    "chainsaw": 0,
    "environment": 1,
}

# Preprocessing parameters.
_INPUT_SAMPLE_RATE = 12000  # rate the dataset audio is assumed to have -- TODO confirm
_MODEL_SAMPLE_RATE = 16000
_N_MELS = 64
_TARGET_NUM_FRAMES = 72000  # every clip is brought to this length before resampling

# Inference parameters.
_BATCH_SIZE = 128
_ONNX_MODEL_PATH = "./output_model.onnx"
_INTRA_OP_THREADS = 4
_INTER_OP_THREADS = 2

torch.set_num_threads(_INTRA_OP_THREADS)
torch.set_num_interop_threads(_INTER_OP_THREADS)


def _resize_audio(waveform: torch.Tensor, target_length: int) -> torch.Tensor:
    """Bring *waveform* to ``target_length`` frames along its last dimension.

    A waveform that already has ``target_length`` frames is returned
    unchanged. Otherwise it is passed through a resampler configured with the
    two frame counts as its "frequencies", which is intended to stretch or
    squeeze the signal to the requested length rather than pad or crop it.
    """
    num_frames = waveform.shape[-1]
    if num_frames == target_length:
        return waveform
    length_resampler = transforms.Resample(
        orig_freq=num_frames, new_freq=target_length
    )
    return length_resampler(waveform)


def _extract_features(dataset) -> torch.Tensor:
    """Return the stacked dB-scaled mel spectrograms of every sample in *dataset*.

    Each sample is read from ``sample["audio"]["array"]``, given a leading
    channel dimension, resized to ``_TARGET_NUM_FRAMES``, resampled from
    ``_INPUT_SAMPLE_RATE`` to ``_MODEL_SAMPLE_RATE`` and converted to a mel
    spectrogram in decibels.

    Raises:
        RuntimeError: from ``torch.stack`` if *dataset* yields no samples.
    """
    resampler = transforms.Resample(
        orig_freq=_INPUT_SAMPLE_RATE, new_freq=_MODEL_SAMPLE_RATE
    )
    mel_transform = transforms.MelSpectrogram(
        sample_rate=_MODEL_SAMPLE_RATE, n_mels=_N_MELS
    )
    amplitude_to_db = transforms.AmplitudeToDB()

    # Single pass: each clip is resized and converted immediately, so the
    # resized raw waveforms of the whole split are never held at once.
    features = []
    for sample in dataset:
        waveform = torch.tensor(
            sample["audio"]["array"], dtype=torch.float32
        ).unsqueeze(0)
        waveform = _resize_audio(waveform, target_length=_TARGET_NUM_FRAMES)
        features.append(amplitude_to_db(mel_transform(resampler(waveform))))
    return torch.stack(features)


# NOTE: no return annotation on purpose -- FastAPI may interpret one as a
# response model, which could change how the payload is serialized.
@router.post(ROUTE, tags=["Audio Task"], description=DESCRIPTION)
async def evaluate_audio(request: AudioEvaluationRequest):
    """Evaluate the ONNX model on a held-out split and report accuracy and emissions.

    The dataset named by ``request.dataset_name`` is loaded (authenticated
    with the ``HF_TOKEN`` environment variable), its ``"train"`` split is
    divided using ``request.test_size`` / ``request.test_seed``, and the
    resulting test part is preprocessed and classified with the model stored
    at ``./output_model.onnx``. Only the inference loop runs between
    ``tracker.start_task("inference")`` and ``tracker.stop_task()``.

    Args:
        request: Evaluation settings; ``dataset_name``, ``test_size`` and
            ``test_seed`` are read.

    Returns:
        A dict with the submitter info, timestamp, model description,
        accuracy, energy/emissions figures, cleaned emissions data, the API
        route and the dataset configuration used.
    """
    # Get space info
    username, space_url = get_space_info()

    # Load and prepare the dataset
    dataset = load_dataset(request.dataset_name, token=os.getenv("HF_TOKEN"))
    train_test = dataset["train"].train_test_split(
        test_size=request.test_size, seed=request.test_seed
    )
    test_dataset = train_test["test"]
    true_labels = test_dataset["label"]

    features = _extract_features(test_dataset)

    # The features already live in memory and are handed to ONNX Runtime as
    # NumPy arrays, so worker processes and pinned memory (useful only for
    # host-to-GPU copies) would add overhead without benefit; the DataLoader
    # is used purely for ordered batching.
    test_loader = DataLoader(
        TensorDataset(features),
        batch_size=_BATCH_SIZE,
        shuffle=False,
    )

    # Load ONNX model
    session_options = ort.SessionOptions()
    session_options.intra_op_num_threads = _INTRA_OP_THREADS
    session_options.inter_op_num_threads = _INTER_OP_THREADS
    ort_session = ort.InferenceSession(_ONNX_MODEL_PATH, session_options)

    # Start tracking emissions
    tracker.start()
    tracker.start_task("inference")

    # ONNX inference
    predictions = []
    for (batch,) in test_loader:
        ort_outputs = ort_session.run(None, {'input': batch.numpy()})
        # Assuming output shape is [batch_size, num_classes]
        predictions.extend(ort_outputs[0].argmax(axis=1).tolist())

    # Stop tracking emissions
    emissions_data = tracker.stop_task()

    # Calculate accuracy
    accuracy = accuracy_score(true_labels, predictions)

    # Prepare results dictionary
    results = {
        "username": username,
        "space_url": space_url,
        "submission_timestamp": datetime.now().isoformat(),
        "model_description": DESCRIPTION,
        "accuracy": float(accuracy),
        "energy_consumed_wh": emissions_data.energy_consumed * 1000,
        "emissions_gco2eq": emissions_data.emissions * 1000,
        "emissions_data": clean_emissions_data(emissions_data),
        "api_route": ROUTE,
        "dataset_config": {
            "dataset_name": request.dataset_name,
            "test_size": request.test_size,
            "test_seed": request.test_seed,
        },
    }

    return results