
KNN Explainability with SHAP

A comprehensive Jupyter notebook demonstrating how to use SHAP (SHapley Additive exPlanations) to interpret K-Nearest Neighbors (KNN) model predictions with detailed visualizations.

Overview

This project provides a complete walkthrough of:

  • Training a K-Nearest Neighbors classifier on the Breast Cancer Wisconsin dataset
  • Using SHAP to explain model predictions at both global and local levels
  • Creating comprehensive visualizations to understand feature importance and model behavior
  • Interactive exploration of individual predictions

Features

Model Training

  • Optimal K value selection through cross-validation (sketched below)
  • StandardScaler preprocessing for KNN optimization
  • Comprehensive model evaluation metrics
  • ROC curves and confusion matrices
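
A minimal sketch of the K search, assuming 5-fold cross-validation over K = 1..30 (the notebook's actual range and scoring may differ; X_train_scaled and y_train come from the preprocessing step):

import numpy as np
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier

# Score each candidate K with 5-fold cross-validation and keep the best
k_values = list(range(1, 31))
cv_scores = [
    cross_val_score(KNeighborsClassifier(n_neighbors=k),
                    X_train_scaled, y_train, cv=5).mean()
    for k in k_values
]
best_k = k_values[int(np.argmax(cv_scores))]
print(f"Best K: {best_k} (mean CV accuracy {max(cv_scores):.3f})")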

SHAP Explainability

  • Summary Plots: Global feature importance with value distributions
  • Bar Plots: Mean absolute feature importance rankings
  • Waterfall Plots: Step-by-step breakdown of individual predictions
  • Force Plots: Visual representation of feature contributions
  • Dependence Plots: Feature-value relationships and interactions
  • Decision Plots: Cumulative prediction paths for multiple samples

Interactive Analysis

  • Custom explain_sample() function for detailed prediction exploration
  • Comparison of correct vs. incorrect predictions
  • GPU memory management utilities
  • Comprehensive reporting and summaries

Prerequisites

Required Packages

torch
shap
scikit-learn
matplotlib
seaborn
pandas
numpy
plotly
ipywidgets
tqdm

Environment

  • Python 3.7+
  • Jupyter Notebook or JupyterLab
  • VS Code with Jupyter extension (recommended)
  • GPU support optional (CUDA-enabled PyTorch for faster computation)

Installation

  1. Clone this repository:
git clone <repository-url>
cd <repository-directory>
  2. Install required packages:
pip install torch shap scikit-learn matplotlib seaborn pandas numpy plotly ipywidgets tqdm
  3. Launch Jupyter:
jupyter notebook
  4. Open KNN_SHAP_Explainability.ipynb

Usage

Basic Usage

Run the notebook cells sequentially from top to bottom. The notebook is structured in logical sections (a condensed sketch of the core pipeline follows the list):

  1. Environment Setup: GPU verification and package installation
  2. Data Loading: Load and explore the Breast Cancer Wisconsin dataset
  3. Preprocessing: Feature scaling and train-test split
  4. Model Training: KNN optimization and evaluation
  5. SHAP Analysis: Compute SHAP values for test samples
  6. Visualizations: Generate all SHAP plots and explanations
  7. Interactive Exploration: Use custom functions to explore predictions
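
For orientation, here is a condensed sketch of that pipeline. Using KernelExplainer with a k-means-summarized background is one standard, model-agnostic way to explain a KNN classifier; the notebook's exact parameters (test split, K, background size, nsamples) may differ:

import shap
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler

# Steps 2-3: load the data, split, and scale (KNN is distance-based)
data = load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(
    data.data, data.target, test_size=0.2, random_state=42, stratify=data.target)
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

# Step 4: train the classifier (5 is a placeholder; use the cross-validated K)
model = KNeighborsClassifier(n_neighbors=5).fit(X_train_scaled, y_train)

# Step 5: model-agnostic SHAP via KernelExplainer over a summarized background
background = shap.kmeans(X_train_scaled, 50)
explainer = shap.KernelExplainer(model.predict_proba, background)
X_test_sample = X_test_scaled[:50]
shap_values = explainer.shap_values(X_test_sample, nsamples=100)

Summarizing the background with k-means keeps KernelExplainer tractable, since its cost grows with the background size.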

Interactive Exploration

After running all cells, use the explain_sample() function to explore any prediction:

# Explain the first sample
explain_sample(0)

# Explain a high-confidence correct prediction
explain_sample(5)

# Explain a misclassified sample
explain_sample(42)
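
explain_sample() is defined in the notebook; a simplified sketch of what such a function might contain (the real version prints more detail; variable names follow the pipeline sketch above):

def explain_sample(i):
    """Sketch: report prediction details and a force plot for test sample i."""
    pred = model.predict(X_test_sample[i:i + 1])[0]
    proba = model.predict_proba(X_test_sample[i:i + 1])[0]
    print(f"Sample {i}: predicted={data.target_names[pred]}, "
          f"actual={data.target_names[y_test[i]]}, confidence={proba.max():.2f}")
    # Index [1] assumes the legacy per-class list returned by KernelExplainer
    shap.force_plot(explainer.expected_value[1], shap_values[1][i],
                    X_test_sample[i], feature_names=list(data.feature_names),
                    matplotlib=True)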

GPU Memory Management

If running on GPU, monitor and manage memory:

# Check current GPU memory usage
print_gpu_memory()

# Clear GPU cache
clear_gpu_memory()

# Get optimal device (CPU or GPU)
device = get_optimal_device()
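
These helpers are defined in the notebook; plausible implementations using standard PyTorch calls could look like this:

import torch

def get_optimal_device():
    """Return the CUDA device when available, otherwise the CPU."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

def print_gpu_memory():
    """Report allocated and reserved GPU memory for the current device."""
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1024 ** 2
        reserved = torch.cuda.memory_reserved() / 1024 ** 2
        print(f"GPU memory: {allocated:.1f} MiB allocated, {reserved:.1f} MiB reserved")
    else:
        print("No GPU available")

def clear_gpu_memory():
    """Release PyTorch's cached GPU memory back to the driver."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()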

Dataset

The notebook uses the Breast Cancer Wisconsin (Diagnostic) dataset from scikit-learn:

  • Samples: 569
  • Features: 30 (mean, standard error, and worst values of 10 real-valued features)
  • Classes: 2 (Malignant, Benign)
  • Task: Binary classification

Features include radius, texture, perimeter, area, smoothness, compactness, concavity, concave points, symmetry, and fractal dimension.
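
Loading it as pandas objects (the pattern the Customization section below reuses) takes a few lines:

import pandas as pd
from sklearn.datasets import load_breast_cancer

data = load_breast_cancer()
X = pd.DataFrame(data.data, columns=data.feature_names)  # shape (569, 30)
y = pd.Series(data.target, name="target")  # 0 = malignant, 1 = benign
print(X.shape, y.value_counts().to_dict())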

Model Performance

The KNN model achieves (a minimal evaluation sketch follows the list):

  • High accuracy on both training and test sets
  • Optimal K value determined through cross-validation
  • ROC AUC score > 0.95 (typical)
  • Interpretable predictions through SHAP analysis
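
A minimal sketch of checking those metrics on the held-out test set (variable names follow the pipeline sketch above):

from sklearn.metrics import accuracy_score, confusion_matrix, roc_auc_score

y_pred = model.predict(X_test_scaled)
y_prob = model.predict_proba(X_test_scaled)[:, 1]  # probability of class 1 (benign)

print(f"Test accuracy: {accuracy_score(y_test, y_pred):.3f}")
print(f"ROC AUC:       {roc_auc_score(y_test, y_prob):.3f}")
print(confusion_matrix(y_test, y_pred))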

SHAP Visualization Guide

Summary Plot (Beeswarm)

  • Shows global feature importance across all samples
  • Color indicates feature value (red=high, blue=low)
  • Horizontal position shows impact on prediction

Waterfall Plot

  • Explains individual predictions step-by-step
  • Starts from base value (expected prediction)
  • Each bar shows a feature's contribution
  • Red pushes toward malignant, blue toward benign

Dependence Plot

  • Reveals non-linear feature relationships
  • Shows feature interactions through color
  • Identifies threshold effects

Decision Plot

  • Visualizes prediction paths for multiple samples
  • Shows cumulative effect of features
  • Helps identify prediction patterns
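
The plots above map onto SHAP's plotting API roughly as follows. The indexing assumes the legacy per-class list returned by KernelExplainer (class 1 = benign), and "worst radius" is just an example feature:

# Explain the class-1 ("benign") output
sv = shap_values[1]
base = explainer.expected_value[1]
names = list(data.feature_names)

shap.summary_plot(sv, X_test_sample, feature_names=names)                   # beeswarm
shap.summary_plot(sv, X_test_sample, feature_names=names, plot_type="bar")  # bar ranking
shap.force_plot(base, sv[0], X_test_sample[0], feature_names=names, matplotlib=True)
shap.dependence_plot("worst radius", sv, X_test_sample, feature_names=names)
shap.decision_plot(base, sv[:20], feature_names=names)
# Waterfall for one sample (recent SHAP versions take an Explanation object)
shap.waterfall_plot(shap.Explanation(values=sv[0], base_values=base,
                                     data=X_test_sample[0], feature_names=names))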

Key Insights

  1. Feature Importance: SHAP identifies the most critical features for cancer diagnosis
  2. Non-linearity: Dependence plots reveal complex feature-value relationships
  3. Interactions: Color gradients show which features interact
  4. Individual Explanations: Each prediction can be fully explained and understood
  5. Model Trust: Transparent explanations increase confidence in model decisions

Customization

Using Different Datasets

Replace the data loading section with your own dataset:

# Load your dataset
X = pd.DataFrame(your_data)
y = pd.Series(your_labels)

# Continue with the rest of the notebook

Adjusting SHAP Computation

Modify SHAP parameters for speed/accuracy tradeoff:

# Faster computation (less accurate)
shap_values = explainer.shap_values(X_test_sample, nsamples=50)

# More accurate (slower)
shap_values = explainer.shap_values(X_test_sample, nsamples=200)

# Smaller background dataset (faster)
background = shap.kmeans(X_train_scaled, 50)

Trying Other Algorithms

The SHAP approach works with any model:

from sklearn.ensemble import RandomForestClassifier

# Train Random Forest instead of KNN
model = RandomForestClassifier(n_estimators=100, random_state=42)
model.fit(X_train_scaled, y_train)

# Use TreeExplainer for faster computation on tree-based models
explainer = shap.TreeExplainer(model)

Performance Tips

  1. GPU Acceleration: Use GPU for faster PyTorch operations
  2. Background Size: Reduce background dataset size for faster SHAP computation
  3. Sample Size: Start with fewer samples (e.g., 50) for quick testing
  4. nsamples Parameter: Lower values speed up computation but reduce accuracy
  5. Memory Management: Clear GPU cache between major computations

Troubleshooting

Common Issues

GPU not detected:

  • Check CUDA installation
  • Verify PyTorch GPU support: torch.cuda.is_available()
  • Notebook will fall back to CPU automatically

SHAP computation too slow:

  • Reduce background dataset size
  • Decrease number of test samples
  • Lower nsamples parameter

Memory errors:

  • Process fewer samples at once (see the chunking sketch below)
  • Clear GPU cache with clear_gpu_memory()
  • Reduce background dataset size
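
A sketch of the chunking approach mentioned above (the chunk count and the legacy list-of-arrays format are assumptions):

import numpy as np

# Compute SHAP values in five chunks to cap peak memory usage
chunks = [explainer.shap_values(chunk, nsamples=100)
          for chunk in np.array_split(X_test_sample, 5)]
# Stitch the per-class arrays back together
shap_values = [np.vstack([c[cls] for c in chunks]) for cls in range(2)]
clear_gpu_memory()  # release cached GPU memory between major computations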

Visualization issues:

  • Ensure matplotlib backend is compatible
  • Update SHAP to latest version
  • Restart kernel if plots don't render

Contributing

Contributions are welcome! Please feel free to submit pull requests or open issues for bugs, questions, or new features.

License

This project is licensed under the MIT License - see the LICENSE file for details.

Acknowledgments

  • SHAP: Scott Lundberg et al. for the SHAP library
  • scikit-learn: For the Breast Cancer Wisconsin dataset and ML tools
  • PyTorch: For GPU acceleration capabilities
  • Community: All contributors to the open-source ML/AI ecosystem

Contact

For questions or feedback, please open an issue in the repository.


Note: This notebook is designed for educational purposes and demonstrates best practices for ML model interpretability using SHAP.
