import matplotlib.pyplot as plt
import pandas as pd
from utils.model_loader import load_zero_shot
from utils.helpers import fig_to_html, df_to_html_table
def classification_handler(text_input, scenario="Sentiment", multi_label=False, custom_labels=""):
"""Show zero-shot classification capabilities."""
output_html = []
# Add result area container
output_html.append('
')
output_html.append('')
output_html.append("""
Zero-shot classification can categorize text into arbitrary classes without having been specifically trained on those categories.
""")
# Model info
output_html.append("""
Model Used:
- facebook/bart-large-mnli - BART model fine-tuned on MultiNLI dataset
- Capabilities - Can classify text into any user-defined categories
- Performance - Best performance on distinct, well-defined categories
""")
# Classification scenarios
scenarios = {
"Sentiment": ["positive", "negative", "neutral"],
"Emotion": ["joy", "sadness", "anger", "fear", "surprise"],
"Writing Style": ["formal", "informal", "technical", "creative", "persuasive"],
"Intent": ["inform", "persuade", "entertain", "instruct"],
"Content Type": ["news", "opinion", "review", "instruction", "narrative"],
"Audience Level": ["beginner", "intermediate", "advanced", "expert"],
"Custom": []
}
try:
# Get labels based on scenario
if scenario == "Custom":
labels = [label.strip() for label in custom_labels.split("\n") if label.strip()]
if not labels:
output_html.append("""
No Custom Categories
Please enter at least one custom category.
""")
output_html.append('
') # Close result-area div
return '\n'.join(output_html)
else:
labels = scenarios[scenario]
# Update multi-label default for certain categories
if scenario in ["Emotion", "Intent", "Content Type"] and not multi_label:
multi_label = True
# Load model
classifier = load_zero_shot()
# Classification process
result = classifier(text_input, labels, multi_label=multi_label)
# Display results
output_html.append('')
# Create DataFrame
class_df = pd.DataFrame({
'Category': result['labels'],
'Confidence': result['scores']
})
# Visualization
fig = plt.figure(figsize=(10, 6))
bars = plt.barh(class_df['Category'], class_df['Confidence'], color='#1976D2')
# Add percentage labels
for i, bar in enumerate(bars):
plt.text(bar.get_width() + 0.01, bar.get_y() + bar.get_height()/2,
f"{bar.get_width():.1%}", va='center')
plt.xlim(0, 1.1)
plt.xlabel('Confidence Score')
plt.title(f'{scenario} Classification')
plt.tight_layout()
# Layout with vertical stacking - Chart first
output_html.append('')
output_html.append('
')
output_html.append('
Detailed Results
')
output_html.append(df_to_html_table(class_df))
output_html.append('')
# Top result
output_html.append('
')
top_class = class_df.iloc[0]['Category']
top_score = class_df.iloc[0]['Confidence']
output_html.append(f"""
Primary Classification
{top_class}
Confidence: {top_score:.1%}
""")
output_html.append('
') # Close result column
output_html.append('
') # Close row
# Multiple categories (if multi-label)
if multi_label:
# Get all categories with significant confidence
significant_classes = class_df[class_df['Confidence'] > 0.5]
if len(significant_classes) > 1:
output_html.append(f"""