File size: 3,999 Bytes
bb22a22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import pandas as pd


# Custom CSS meant to centre the button that carries the "center-btn" class
# (auto left/right margins on a block-level <button>).
# NOTE(review): this component is created at module top level, before the
# `with gr.Blocks()` context opened later in this file. Components are
# normally only attached to the page when created inside a Blocks context,
# so this <style> block presumably never reaches the browser — confirm, and
# if so supply the CSS through the app itself instead.
# NOTE(review): the selector targets a <button> *inside* an element with
# class "center-btn"; verify that this matches the markup Gradio emits for
# a Button created with elem_classes="center-btn".
gr.Markdown(
    """
    <style>
        .center-btn button {
            margin-left: auto;
            margin-right: auto;
            display: block;
        }
    </style>
    """
)


# Load model and tokenizer
# Both objects are created once, when this module is executed, and are then
# reused by every call to classify_email_with_probs. `model_name` is the
# identifier handed to both from_pretrained calls so the tokenizer and the
# classifier weights come from the same checkpoint.
model_name = "ale-dp/xlm-roberta-email-classifier"
# use_fast=False requests the Python ("slow") tokenizer implementation.
# NOTE(review): the reason for avoiding the fast tokenizer is not recorded
# here — confirm it is still required.
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

# Label map
# Maps each class index produced by the model to its human-readable category
# name. The names are listed in class-index order, so the position of a name
# in the tuple is the integer key it receives.
label_map = dict(enumerate((
    'Billing and Payments',
    'Customer Service',
    'General Inquiry',
    'Human Resources',
    'IT Support',
    'Product Support',
    'Returns and Exchanges',
    'Sales and Pre-Sales',
    'Service Outages and Maintenance',
    'Technical Support',
)))

# Prediction function
def classify_email_with_probs(text):
    """Classify an email and report the confidence for every category.

    Args:
        text: Raw email text as typed or pasted by the user.

    Returns:
        A ``(top_label, df)`` tuple. ``top_label`` is the name of the most
        probable category and ``df`` is a DataFrame with the columns
        ``"Category"`` and ``"Confidence (%)"``, ordered from the highest to
        the lowest confidence (percentages rounded to two decimals).
        When ``text`` is ``None``, empty or whitespace-only the model is not
        run: ``top_label`` is ``None`` and ``df`` has the same two columns
        but no rows.
    """
    columns = ["Category", "Confidence (%)"]

    # A blank submission carries no content to classify; running it through
    # the model would still produce a "prediction", which is misleading in
    # the UI. Return an empty result instead.
    if text is None or (isinstance(text, str) and not text.strip()):
        return None, pd.DataFrame(columns=columns)

    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():  # inference only, no gradients needed
        logits = model(**inputs).logits[0]  # first (only) sequence in the batch

    # Convert to plain Python floats once instead of indexing the tensor
    # element by element.
    probs = torch.nn.functional.softmax(logits, dim=0).tolist()

    rows = [
        (label_map[index], round(prob * 100, 2))
        for index, prob in enumerate(probs)
    ]
    # Stable sort: categories with equal rounded confidence keep index order.
    rows.sort(key=lambda row: row[1], reverse=True)

    df = pd.DataFrame(rows, columns=columns)
    top_label = rows[0][0]
    return top_label, df

# Sample emails
# Ready-made inputs for the demo: the UI code later in this file iterates
# over this list and creates one button per entry that copies the text into
# the email box. The strings are used verbatim both as the button caption and
# as the text inserted, so edit them with that in mind.
# NOTE(review): the samples read like a return request, a billing inquiry,
# a sales/demo request and a technical incident — presumably chosen to
# exercise different categories; not verified against model output.
examples = [
    "Hello, I recently purchased a pair of headphones from your online store (Order #48392) and unfortunately, they arrived damaged. The left earcup is completely detached and the sound is distorted. I’d like to request a return or exchange. Please let me know the steps I need to follow and whether I need to ship the item back first. Thank you for your assistance.",
    "Dear Customer Support Team,\n\nI hope this message reaches you well. I am reaching out to request detailed billing details and payment options for a QuickBooks Online subscription. Specifically, I am interested in understanding the available plans, their pricing structures, and any tailored options for institutional clients within the financial services industry.",
    "Hello, I’m reaching out on behalf of a mid-sized retail company interested in your cloud-based inventory solution. We’re currently evaluating vendors and would appreciate a demo of your platform, along with pricing tiers for teams of 50+ users. Please let me know your availability this week for a call.",
    "Currently facing sporadic connectivity difficulties with the cloud-native SaaS system. The suspected reason appears to be linked to orchestration resource distribution within Kubernetes-managed microservices. After restarting the affected services and examining deployment logs, the issue continues. Further investigation and escalation are required to resolve this matter swiftly."
]

# Gradio UI
# Components are declared top to bottom inside the Blocks context: heading,
# input box, submit button, example buttons, then the two output widgets.
with gr.Blocks() as demo:
    gr.Markdown("## 📬 Email Ticket Classifier")
    gr.Markdown("Classify emails into support categories using XLM-RoBERTa. See top prediction and full confidence breakdown.")

    # Free-text input holding the email to classify.
    email_input = gr.Textbox(
        lines=12,
        label="Email Text",
        placeholder="Paste your email here...",
        elem_id="email_input"
    )

    with gr.Row():
        # NOTE(review): "center-btn" is a CSS hook; no CSS is passed to this
        # Blocks instance, so confirm the class is actually styled.
        submit_btn = gr.Button("Classify", variant="primary", elem_classes="center-btn")

    # Vertical spacing before the examples section.
    gr.Markdown("<br><br>")

    gr.Markdown("### Examples:")

    with gr.Column():
        for example in examples:
            # One button per sample email; the full text is the caption.
            # `x=example` binds the current sample as a default argument so
            # each button returns its own text rather than the last value of
            # the loop variable. The handler takes no inputs and writes the
            # sample into the email box.
            gr.Button(example).click(fn=lambda x=example: x, outputs=email_input)

    
    # Outputs: the winning category and the per-category confidence table.
    top_label = gr.Label(label="Predicted Category")
    prob_table = gr.Dataframe(
        headers=["Category", "Confidence (%)"],
        label="Confidence Breakdown",
        datatype=["str", "number"],
        row_count=10
    )

    # Clicking "Classify" sends the textbox content to the classifier and
    # routes its two return values to the label and the table, in that order.
    submit_btn.click(fn=classify_email_with_probs, inputs=email_input, outputs=[top_label, prob_table])

# Starts the app when this file is executed (runs at module top level).
demo.launch()