Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| import torch | |
| import pandas as pd | |
# NOTE: Gradio components only render when created inside a `gr.Blocks()`
# context. A module-level `gr.Markdown("<style>...</style>")` is therefore
# never shown and its CSS never reaches the page. Custom CSS has to be passed
# to `gr.Blocks(css=...)` instead.

# Load model and tokenizer from the Hugging Face Hub (network I/O at import).
model_name = "ale-dp/xlm-roberta-email-classifier"
# use_fast=False requests the slow (pure-Python) tokenizer implementation
# rather than the Rust-backed "fast" one.
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
# Category names indexed by classifier output position: entry i names the
# class whose score is at index i of the model's logits.
label_map = dict(enumerate([
    'Billing and Payments',
    'Customer Service',
    'General Inquiry',
    'Human Resources',
    'IT Support',
    'Product Support',
    'Returns and Exchanges',
    'Sales and Pre-Sales',
    'Service Outages and Maintenance',
    'Technical Support',
]))
# Prediction function
def classify_email_with_probs(text):
    """Classify an email and return the top category plus a confidence table.

    Args:
        text: Raw email text (the contents of the UI textbox).

    Returns:
        A ``(top_label, df)`` tuple:

        - ``top_label`` is the highest-scoring category name.
        - ``df`` is a DataFrame with columns ``"Category"`` and
          ``"Confidence (%)"``, holding percentages rounded to two decimals.
          Rows are sorted from most to least likely; ties keep class-index
          order.

        For empty, whitespace-only or ``None`` input there is nothing to
        classify, so ``top_label`` is ``None`` and ``df`` is empty.
    """
    columns = ["Category", "Confidence (%)"]

    # Guard: running the model on an empty prompt would still produce a
    # confident-looking but meaningless prediction.
    if not text or not text.strip():
        return None, pd.DataFrame(columns=columns)

    # A single sequence needs no padding; truncation keeps long emails within
    # the model's maximum input length.
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        logits = model(**inputs).logits[0]
    probs = torch.nn.functional.softmax(logits, dim=-1).tolist()

    # Fall back to a generic name instead of raising KeyError if the model
    # emits more classes than label_map lists.
    rows = [
        (label_map.get(idx, f"LABEL_{idx}"), round(prob * 100, 2))
        for idx, prob in enumerate(probs)
    ]
    # list.sort is stable, so tied scores keep their class-index order.
    rows.sort(key=lambda row: row[1], reverse=True)

    df = pd.DataFrame(rows, columns=columns)
    return df.iloc[0]["Category"], df
# Sample emails, shown as one-click buttons that fill the input box.
examples = [
    "Hello, I recently purchased a pair of headphones from your online store (Order #48392) and unfortunately, they arrived damaged. The left earcup is completely detached and the sound is distorted. I’d like to request a return or exchange. Please let me know the steps I need to follow and whether I need to ship the item back first. Thank you for your assistance.",
    "Dear Customer Support Team,\n\nI hope this message reaches you well. I am reaching out to request detailed billing details and payment options for a QuickBooks Online subscription. Specifically, I am interested in understanding the available plans, their pricing structures, and any tailored options for institutional clients within the financial services industry.",
    "Hello, I’m reaching out on behalf of a mid-sized retail company interested in your cloud-based inventory solution. We’re currently evaluating vendors and would appreciate a demo of your platform, along with pricing tiers for teams of 50+ users. Please let me know your availability this week for a call.",
    "Currently facing sporadic connectivity difficulties with the cloud-native SaaS system. The suspected reason appears to be linked to orchestration resource distribution within Kubernetes-managed microservices. After restarting the affected services and examining deployment logs, the issue continues. Further investigation and escalation are required to resolve this matter swiftly."
]

# Custom CSS must be handed to gr.Blocks(css=...). A <style> tag emitted via a
# component created outside the Blocks context is never rendered.
# NOTE(review): `elem_classes` on gr.Button is applied to the <button> element
# itself in recent Gradio versions, so `.center-btn` is targeted directly as
# well as via the original descendant selector. Confirm the visual result.
_CENTER_BTN_CSS = """
.center-btn,
.center-btn button {
    margin-left: auto;
    margin-right: auto;
    display: block;
}
"""

# Gradio UI
with gr.Blocks(css=_CENTER_BTN_CSS) as demo:
    gr.Markdown("## 📬 Email Ticket Classifier")
    gr.Markdown("Classify emails into support categories using XLM-RoBERTa. See top prediction and full confidence breakdown.")
    email_input = gr.Textbox(
        lines=12,
        label="Email Text",
        placeholder="Paste your email here...",
        elem_id="email_input",
    )
    with gr.Row():
        submit_btn = gr.Button("Classify", variant="primary", elem_classes="center-btn")
    gr.Markdown("<br><br>")
    gr.Markdown("### Examples:")
    with gr.Column():
        for example in examples:
            # Bind the current example as a default argument. A plain closure
            # would late-bind, and every button would insert the last example.
            gr.Button(example).click(fn=lambda x=example: x, outputs=email_input)
    top_label = gr.Label(label="Predicted Category")
    prob_table = gr.Dataframe(
        headers=["Category", "Confidence (%)"],
        label="Confidence Breakdown",
        datatype=["str", "number"],
        # One row per category the classifier can predict.
        row_count=len(label_map),
    )
    submit_btn.click(fn=classify_email_with_probs, inputs=email_input, outputs=[top_label, prob_table])

# Start the server only when run as a script, not when the module is imported.
if __name__ == "__main__":
    demo.launch()