Spaces:
Runtime error
| from datasets import load_dataset | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments | |
| import torch | |
| import numpy as np | |
# Label columns, in the order of the model's output logits. The same names are
# read from the dataset rows to build the multi-label targets.
# NOTE(review): the dataset id only mentions "love" — confirm that every one of
# these columns exists in each split before training.
emotions = ["Anger", "Love", "Fear", "Joy", "Sadness", "Surprise"]
# Hub checkpoint shared by the tokenizer below and the classification model.
model_name = "distilbert-base-uncased"

# Fetch the GitHub emotion dataset from the Hub (all splits).
dataset = load_dataset("imranraad/github-emotion-love")
# Tokenizer matching the chosen checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_name)
def tokenize(batch):
    """Tokenize the ``modified_comment`` text column of a batched ``map`` call.

    Every example is truncated or padded to exactly 128 tokens, so all rows
    end up with the same length.
    """
    return tokenizer(
        batch["modified_comment"],
        max_length=128,
        truncation=True,
        padding="max_length",
    )


# Add the tokenizer's output columns to every split of the dataset.
dataset = dataset.map(tokenize, batched=True)
# Convert labels to list of floats for multi-label
def format_labels(batch, label_columns=None):
    """Add a ``labels`` column holding one float vector per example.

    Meant for ``Dataset.map(..., batched=True)``: ``batch`` maps each column
    name to the list of that column's values for the batch.

    Args:
        batch: Mapping of column name -> list of per-example values. It is
            modified in place (a ``labels`` key is added) and returned.
        label_columns: Names of the label columns, in output order. Defaults
            to the module-level ``emotions`` list.

    Returns:
        ``batch``, with ``batch["labels"][i]`` equal to
        ``[float(batch[c][i]) for c in label_columns]``.

    Raises:
        ValueError: If ``label_columns`` is empty.
        KeyError: If a label column is missing from ``batch``.
    """
    columns = emotions if label_columns is None else label_columns
    if not columns:
        raise ValueError("format_labels needs at least one label column")
    # Multi-label training uses a sigmoid/BCE-with-logits loss, which requires
    # float targets; raw 0/1 integers would be collated into an int64 tensor
    # and the loss would fail with a dtype error.
    per_column = [batch[name] for name in columns]
    batch["labels"] = [[float(value) for value in row] for row in zip(*per_column)]
    return batch
# Attach the per-row emotion vector built by format_labels as a "labels" column.
dataset = dataset.map(format_labels, batched=True)

# Pretrained checkpoint with a classification head of one output per emotion.
# problem_type selects transformers' multi-label loss (BCEWithLogitsLoss), which
# scores each emotion independently instead of a single softmax over them.
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    problem_type="multi_label_classification",
    num_labels=len(emotions),
)
# Training arguments.
# transformers renamed `evaluation_strategy` to `eval_strategy` (4.41) and later
# dropped the old keyword, so passing it to a recent release raises TypeError at
# startup. Use whichever keyword the installed TrainingArguments accepts.
import inspect

_eval_strategy_kw = (
    "eval_strategy"
    if "eval_strategy" in inspect.signature(TrainingArguments).parameters
    else "evaluation_strategy"
)
training_args = TrainingArguments(
    output_dir="./model",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir="./logs",
    # Evaluate and checkpoint once per epoch so the two schedules line up.
    save_strategy="epoch",
    **{_eval_strategy_kw: "epoch"},
)
# Metrics
def compute_metrics(pred):
    """Return the label-wise accuracy of multi-label predictions.

    ``pred`` unpacks into ``(logits, labels)``; the Trainer passes an
    ``EvalPrediction``. Each logit goes through a sigmoid and is thresholded at
    0.5. The score is the fraction of individual (example, label) cells whose
    0/1 prediction equals the target, so a row with only some labels right
    still counts partially.

    Returns:
        ``{"accuracy": float}``.
    """
    raw_logits, targets = pred
    probabilities = torch.sigmoid(torch.tensor(raw_logits))
    predicted = probabilities.gt(0.5).float()
    matches = predicted.eq(torch.tensor(targets)).float()
    return {"accuracy": matches.mean().item()}
# Trainer.
# `Trainer(tokenizer=...)` is deprecated in favour of `processing_class` in
# recent transformers releases and is slated for removal; pass the tokenizer
# under whichever keyword the installed Trainer accepts.
import inspect

_tokenizer_kw = (
    "processing_class"
    if "processing_class" in inspect.signature(Trainer).parameters
    else "tokenizer"
)
# NOTE(review): assumes the dataset provides both "train" and "test" splits.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    compute_metrics=compute_metrics,
    **{_tokenizer_kw: tokenizer},
)
trainer.train()
# Save the final model to ./model.
trainer.save_model("./model")