import os
import json
import random
import argparse
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler, Subset
from huggingface_hub import upload_folder
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
from collections import Counter
from transformers.integrations import TensorBoardCallback
from transformers import (
    Wav2Vec2FeatureExtractor, HubertConfig, HubertForSequenceClassification,
    Trainer, TrainingArguments,
    EarlyStoppingCallback
)

| MODEL = "ntu-spml/distilhubert" # modelo base | |
| FEATURE_EXTRACTOR = Wav2Vec2FeatureExtractor.from_pretrained(MODEL) # feature extractor del modelo base | |
| seed = 123 | |
| MAX_DURATION = 1.00 # Máxima duración de los audios | |
| SAMPLING_RATE = FEATURE_EXTRACTOR.sampling_rate # 16kHz | |
| token = os.getenv("HF_TOKEN") | |
| config_file = "models_config.json" | |
| batch_size = 1024 # TODO: repasar si sigue siendo necesario | |
| num_workers = 12 # Núcleos de la CPU | |
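# Dataset laid out as one subdirectory per class. The filter_white_noise flag is
# stored here, but the filtering itself is not applied in this script.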
class AudioDataset(Dataset):
    def __init__(self, dataset_path, label2id, filter_white_noise, undersample_normal):
        self.dataset_path = dataset_path
        self.label2id = label2id
        self.file_paths = []
        self.filter_white_noise = filter_white_noise
        self.labels = []
        for label_dir, label_id in self.label2id.items():
            label_path = os.path.join(self.dataset_path, label_dir)
            if os.path.isdir(label_path):
                for file_name in os.listdir(label_path):
                    audio_path = os.path.join(label_path, file_name)
                    self.file_paths.append(audio_path)
                    self.labels.append(label_id)
        if undersample_normal and self.label2id:
            self.undersample_normal_class()

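    # Undersample the '1s_normal' class down to the size of the largest other class.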
    def undersample_normal_class(self):
        normal_label = self.label2id.get('1s_normal')
        if normal_label is None:  # nothing to undersample if the class is absent
            return
        label_counts = Counter(self.labels)
        other_counts = [count for label, count in label_counts.items() if label != normal_label]
        if other_counts:  # ensure there are other counts before taking max
            target_count = min(max(other_counts), label_counts[normal_label])  # cannot keep more than exist
            normal_indices = [i for i, label in enumerate(self.labels) if label == normal_label]
            keep_indices = set(random.sample(normal_indices, target_count))  # set for O(1) membership checks
            new_file_paths = []
            new_labels = []
            for i, (path, label) in enumerate(zip(self.file_paths, self.labels)):
                if label != normal_label or i in keep_indices:
                    new_file_paths.append(path)
                    new_labels.append(label)
            self.file_paths = new_file_paths
            self.labels = new_labels

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        audio_path = self.file_paths[idx]
        label = self.labels[idx]
        input_values = self.preprocess_audio(audio_path)
        return {
            "input_values": input_values,
            "labels": torch.tensor(label)
        }

    def preprocess_audio(self, audio_path):
        waveform, sample_rate = torchaudio.load(
            audio_path,
            normalize=True,
        )
        if sample_rate != SAMPLING_RATE:  # resample if not 16 kHz
            resampler = torchaudio.transforms.Resample(sample_rate, SAMPLING_RATE)
            waveform = resampler(waveform)
        if waveform.shape[0] > 1:  # if stereo, convert to mono
            waveform = waveform.mean(dim=0, keepdim=True)
        waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-6)  # TODO: try removing, since it is already done; without the 1e-6 the accuracy is terrible!!
        max_length = int(SAMPLING_RATE * MAX_DURATION)
        if waveform.shape[1] > max_length:
            waveform = waveform[:, :max_length]  # truncate
        else:
            waveform = torch.nn.functional.pad(waveform, (0, max_length - waveform.shape[1]))  # pad
        inputs = FEATURE_EXTRACTOR(
            waveform.squeeze(),
            sampling_rate=SAMPLING_RATE,  # set explicitly, just in case
            return_tensors="pt",
        )
        return inputs.input_values.squeeze()

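# Heuristic white-noise detector (near-zero mean, very low std); pairs with the
# dataset's filter_white_noise flag, though nothing calls it in this script.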
def is_white_noise(audio):
    mean = torch.mean(audio)
    std = torch.std(audio)
    return torch.abs(mean) < 0.001 and std < 0.01

def seed_everything():  # TODO: check whether anything else is needed
    random.seed(seed)  # Python's RNG drives the shuffling and sampling below
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # torch.backends.cudnn.deterministic = True  # for reproducibility
    # torch.backends.cudnn.benchmark = False  # for reproducibility

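# Assign one integer id per class subdirectory of the dataset root.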
def build_label_mappings(dataset_path):
    label2id = {}
    id2label = {}
    label_id = 0
    for label_dir in sorted(os.listdir(dataset_path)):  # sorted so label ids are deterministic across filesystems
        if os.path.isdir(os.path.join(dataset_path, label_dir)):
            label2id[label_dir] = label_id
            id2label[label_id] = label_dir
            label_id += 1
    return label2id, id2label

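# Inverse-frequency weight per sample, so WeightedRandomSampler draws rare classes more often.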
def compute_class_weights(labels):
    class_counts = Counter(labels)
    total_samples = len(labels)
    class_weights = {cls: total_samples / count for cls, count in class_counts.items()}
    return [class_weights[label] for label in labels]

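# Shuffled train/test split; the training loader balances classes with the weighted sampler.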
def create_dataloader(dataset_path, filter_white_noise, undersample_normal, test_size=0.2, shuffle=True, pin_memory=True):
    label2id, id2label = build_label_mappings(dataset_path)
    dataset = AudioDataset(dataset_path, label2id, filter_white_noise, undersample_normal)
    dataset_size = len(dataset)
    indices = list(range(dataset_size))
    random.shuffle(indices)
    split_idx = int(dataset_size * (1 - test_size))
    train_indices = indices[:split_idx]
    test_indices = indices[split_idx:]
    train_dataset = Subset(dataset, train_indices)
    test_dataset = Subset(dataset, test_indices)
    labels = [dataset.labels[i] for i in train_indices]
    class_weights = compute_class_weights(labels)
    sampler = WeightedRandomSampler(
        weights=class_weights,
        num_samples=len(train_dataset),
        replacement=True
    )
    train_dataloader = DataLoader(
        train_dataset, batch_size=batch_size, sampler=sampler, num_workers=num_workers, pin_memory=pin_memory
    )
    test_dataloader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory
    )
    return train_dataloader, test_dataloader, id2label

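# Load a HuBERT sequence-classification model with a head sized to num_labels.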
def load_model(model_path, id2label, num_labels):
    config = HubertConfig.from_pretrained(
        pretrained_model_name_or_path=model_path,
        num_labels=num_labels,
        id2label=id2label,
        finetuning_task="audio-classification"
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = HubertForSequenceClassification.from_pretrained(
        pretrained_model_name_or_path=model_path,
        config=config,
        torch_dtype=torch.float32,  # TODO: check whether float32 is needed and whether float16 would work
    )
    model.to(device)
    return model

def train_params(dataset_path, filter_white_noise, undersample_normal):
    train_dataloader, test_dataloader, id2label = create_dataloader(dataset_path, filter_white_noise, undersample_normal)
    model = load_model(MODEL, id2label, num_labels=len(id2label))
    return model, train_dataloader, test_dataloader, id2label

def predict_params(dataset_path, model_path, filter_white_noise, undersample_normal):
    _, _, id2label = create_dataloader(dataset_path, filter_white_noise, undersample_normal)
    model = load_model(model_path, id2label, num_labels=len(id2label))
    return model, id2label

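# Metrics for Trainer.evaluate(): accuracy plus weighted precision/recall/F1 and the confusion matrix.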
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')
    acc = accuracy_score(labels, preds)
    cm = confusion_matrix(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall,
        'confusion_matrix': cm.tolist()
    }

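# End-to-end training: fit, evaluate, save locally, and push to the Hub.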
def main(training_args, output_dir, dataset_path, filter_white_noise, undersample_normal):
    seed_everything()
    model, train_dataloader, test_dataloader, id2label = train_params(dataset_path, filter_white_noise, undersample_normal)
    early_stopping_callback = EarlyStoppingCallback(
        early_stopping_patience=5,
        early_stopping_threshold=0.001
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        compute_metrics=compute_metrics,
        train_dataset=train_dataloader.dataset,  # Trainer builds its own dataloaders, so the weighted sampler above is bypassed here
        eval_dataset=test_dataloader.dataset,
        callbacks=[TensorBoardCallback, early_stopping_callback]
    )
    torch.cuda.empty_cache()  # free GPU memory
    trainer.train()  # resume_from_checkpoint to continue training
    os.makedirs(output_dir, exist_ok=True)
    trainer.save_model(output_dir)  # save the model locally
    eval_results = trainer.evaluate()
    print(f"Evaluation results: {eval_results}")
    trainer.push_to_hub(token=token)  # push the model to the personal profile
    upload_folder(repo_id=f"A-POR-LOS-8000/{output_dir}", folder_path=output_dir, token=token)  # also upload the local folder to the organization
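
    # Sanity check after training: classify a few held-out clips with the fine-tuned model.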
    def predict(audio_path):
        waveform, sample_rate = torchaudio.load(audio_path, normalize=True)
        if sample_rate != SAMPLING_RATE:
            resampler = torchaudio.transforms.Resample(sample_rate, SAMPLING_RATE)
            waveform = resampler(waveform)
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-6)
        max_length = int(SAMPLING_RATE * MAX_DURATION)
        if waveform.shape[1] > max_length:
            waveform = waveform[:, :max_length]
        else:
            waveform = torch.nn.functional.pad(waveform, (0, max_length - waveform.shape[1]))
        inputs = FEATURE_EXTRACTOR(
            waveform.squeeze(),
            sampling_rate=SAMPLING_RATE,
            return_tensors="pt",
        )
        with torch.no_grad():
            logits = model(inputs.input_values.to(model.device)).logits
        predicted_class_id = logits.argmax().item()
        predicted_label = id2label[predicted_class_id]
        return predicted_label, logits

    # draw samples from the test split only (the Subset's indices point into the full dataset)
    test_subset = test_dataloader.dataset
    test_files = [test_subset.dataset.file_paths[i] for i in test_subset.indices]
    test_samples = random.sample(test_files, min(15, len(test_files)))
    for sample in test_samples:
        predicted_label, logits = predict(sample)
        print(f"File: {sample}")
        print(f"Predicted label: {predicted_label}")
        print(f"Logits: {logits}")
        print("---")

def load_config(model_name):
    with open(config_file, 'r') as f:
        config = json.load(f)
    model_config = config[model_name]
    training_args = TrainingArguments(**model_config["training_args"])
    model_config["training_args"] = training_args
    return model_config

| if __name__ == "__main__": | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument( | |
| "--n", choices=["mon", "class"], | |
| required=True, help="Elegir qué modelo entrenar" | |
| ) | |
| args = parser.parse_args() | |
| config = load_config(args.n) | |
| training_args = config["training_args"] | |
| output_dir = config["output_dir"] | |
| dataset_path = config["dataset_path"] | |
| if args.n == "mon": | |
| filter_white_noise = False | |
| undersample_normal = False | |
| elif args.n == "class": | |
| filter_white_noise = True | |
| undersample_normal = True | |
| main(training_args, output_dir, dataset_path, filter_white_noise, undersample_normal) | |