KNN Explainability with SHAP
A comprehensive Jupyter notebook demonstrating how to use SHAP (SHapley Additive exPlanations) to interpret K-Nearest Neighbors (KNN) model predictions with detailed visualizations.
Overview
This project provides a complete walkthrough of:
- Training a K-Nearest Neighbors classifier on the Breast Cancer Wisconsin dataset
- Using SHAP to explain model predictions at both global and local levels
- Creating comprehensive visualizations to understand feature importance and model behavior
- Interactive exploration of individual predictions
Features
Model Training
- Optimal K value selection through cross-validation
- StandardScaler preprocessing for KNN optimization
- Comprehensive model evaluation metrics
- ROC curves and confusion matrices
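The K selection is a standard cross-validation sweep. A minimal sketch, assuming the scaled training split (X_train_scaled, y_train) from the preprocessing cells; variable names are illustrative, not the notebook's exact code:
import numpy as np
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier
# Score odd K values with 5-fold cross-validation on the scaled training set
k_values = list(range(1, 22, 2))
cv_scores = [
    cross_val_score(KNeighborsClassifier(n_neighbors=k),
                    X_train_scaled, y_train, cv=5).mean()
    for k in k_values
]
# Refit on the full training set with the best-scoring K
best_k = k_values[int(np.argmax(cv_scores))]
model = KNeighborsClassifier(n_neighbors=best_k).fit(X_train_scaled, y_train)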
SHAP Explainability
- Summary Plots: Global feature importance with value distributions
- Bar Plots: Mean absolute feature importance rankings
- Waterfall Plots: Step-by-step breakdown of individual predictions
- Force Plots: Visual representation of feature contributions
- Dependence Plots: Feature-value relationships and interactions
- Decision Plots: Cumulative prediction paths for multiple samples
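KNN exposes no tree or linear structure for the specialized explainers, so these plots rely on the model-agnostic KernelExplainer. A minimal setup sketch, assuming the fitted model and scaled splits from earlier cells:
import shap
# Summarize the training data into 50 weighted centroids; a small background
# set keeps KernelExplainer tractable
background = shap.kmeans(X_train_scaled, 50)
# Explain the KNN model's predicted probabilities
explainer = shap.KernelExplainer(model.predict_proba, background)
X_test_sample = X_test_scaled[:100]   # subset of the scaled test set, for speed
shap_values = explainer.shap_values(X_test_sample)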
Interactive Analysis
- Custom explain_sample() function for detailed prediction exploration
- Comparison of correct vs. incorrect predictions
- GPU memory management utilities
- Comprehensive reporting and summaries
Prerequisites
Required Packages
torch
shap
scikit-learn
matplotlib
seaborn
pandas
numpy
plotly
ipywidgets
tqdm
Environment
- Python 3.7+
- Jupyter Notebook or JupyterLab
- VS Code with Jupyter extension (recommended)
- GPU support optional (CUDA-enabled PyTorch for faster computation)
Installation
- Clone this repository:
git clone <repository-url>
cd <repository-directory>
- Install required packages:
pip install torch shap scikit-learn matplotlib seaborn pandas numpy plotly ipywidgets tqdm
- Launch Jupyter:
jupyter notebook
- Open KNN_SHAP_Explainability.ipynb
Usage
Basic Usage
Run the notebook cells sequentially from top to bottom. The notebook is structured in logical sections:
- Environment Setup: GPU verification and package installation
- Data Loading: Load and explore the Breast Cancer Wisconsin dataset
- Preprocessing: Feature scaling and train-test split
- Model Training: KNN optimization and evaluation
- SHAP Analysis: Compute SHAP values for test samples
- Visualizations: Generate all SHAP plots and explanations
- Interactive Exploration: Use custom functions to explore predictions
Interactive Exploration
After running all cells, use the explain_sample() function to explore any prediction:
# Explain the first sample
explain_sample(0)
# Explain a high-confidence correct prediction
explain_sample(5)
# Explain a misclassified sample
explain_sample(42)
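The actual explain_sample() lives in a notebook cell; a rough, hypothetical sketch of what such a helper might contain, assuming the explainer, shap_values, and numpy X_test_sample from the setup above, a pandas y_test, and the older list-per-class SHAP return format:
def explain_sample(i):
    # Report the prediction for sample i, then draw its force plot
    pred = model.predict(X_test_sample[i:i + 1])[0]
    print(f"Sample {i}: predicted class = {pred}, actual class = {y_test.iloc[i]}")
    shap.force_plot(
        explainer.expected_value[1],   # base value for the positive (benign) class
        shap_values[1][i],             # per-feature contributions for sample i
        X_test_sample[i],              # the sample's feature values
        matplotlib=True,
    )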
GPU Memory Management
If running on GPU, monitor and manage memory:
# Check current GPU memory usage
print_gpu_memory()
# Clear GPU cache
clear_gpu_memory()
# Get optimal device (CPU or GPU)
device = get_optimal_device()
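These helpers are defined in the notebook; one plausible implementation using only standard PyTorch calls (the bodies here are illustrative, not the notebook's exact code):
import torch

def print_gpu_memory():
    # Report allocated and reserved CUDA memory in MiB
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1024**2
        reserved = torch.cuda.memory_reserved() / 1024**2
        print(f"GPU memory: {allocated:.1f} MiB allocated, {reserved:.1f} MiB reserved")
    else:
        print("No GPU available")

def clear_gpu_memory():
    # Release cached blocks back to the CUDA driver
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

def get_optimal_device():
    # Prefer CUDA when present, otherwise fall back to CPU
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")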
Dataset
The notebook uses the Breast Cancer Wisconsin (Diagnostic) dataset from scikit-learn:
- Samples: 569
- Features: 30 (mean, standard error, and worst values of 10 real-valued features)
- Classes: 2 (Malignant, Benign)
- Task: Binary classification
Features include radius, texture, perimeter, area, smoothness, compactness, concavity, concave points, symmetry, and fractal dimension.
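Loading it is a single scikit-learn call; for example:
from sklearn.datasets import load_breast_cancer
# Load the 569 x 30 feature table and binary labels as pandas objects
data = load_breast_cancer(as_frame=True)
X, y = data.data, data.target
print(data.target_names)   # ['malignant' 'benign']; target 0 = malignant, 1 = benign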
Model Performance
The KNN model achieves:
- High accuracy on both training and test sets
- Optimal K value determined through cross-validation
- ROC AUC score > 0.95 (typical)
- Interpretable predictions through SHAP analysis
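The metrics come straight from scikit-learn; for example, the test-set ROC AUC, assuming the fitted model and scaled test split from earlier cells:
from sklearn.metrics import roc_auc_score
# Probability of the positive (benign) class for each test sample
probs = model.predict_proba(X_test_scaled)[:, 1]
print(f"ROC AUC: {roc_auc_score(y_test, probs):.3f}")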
SHAP Visualization Guide
Summary Plot (Beeswarm)
- Shows global feature importance across all samples
- Color indicates feature value (red=high, blue=low)
- Horizontal position shows impact on prediction
Waterfall Plot
- Explains individual predictions step-by-step
- Starts from base value (expected prediction)
- Each bar shows a feature's contribution
- Red pushes toward malignant, blue toward benign
Dependence Plot
- Reveals non-linear feature relationships
- Shows feature interactions through color
- Identifies threshold effects
Decision Plot
- Visualizes prediction paths for multiple samples
- Shows cumulative effect of features
- Helps identify prediction patterns
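Each plot above is a single SHAP call. Typical invocations, assuming shap_values and explainer from the analysis cells, an X_test_sample that carries feature names (wrap the scaled array with pd.DataFrame if needed), and the older list-per-class SHAP API:
# Beeswarm summary, then the same importances as a bar chart
shap.summary_plot(shap_values[1], X_test_sample)
shap.summary_plot(shap_values[1], X_test_sample, plot_type="bar")
# Feature-value relationship, colored by the strongest interacting feature
shap.dependence_plot("worst radius", shap_values[1], X_test_sample)
# Cumulative prediction paths for all explained samples
shap.decision_plot(explainer.expected_value[1], shap_values[1], X_test_sample)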
Key Insights
- Feature Importance: SHAP identifies the most critical features for cancer diagnosis
- Non-linearity: Dependence plots reveal complex feature-value relationships
- Interactions: Color gradients show which features interact
- Individual Explanations: Each prediction can be fully explained and understood
- Model Trust: Transparent explanations increase confidence in model decisions
Customization
Using Different Datasets
Replace the data loading section with your own dataset:
# Load your dataset
X = pd.DataFrame(your_data)
y = pd.Series(your_labels)
# Continue with the rest of the notebook
Adjusting SHAP Computation
Modify SHAP parameters for speed/accuracy tradeoff:
# Faster computation (less accurate)
shap_values = explainer.shap_values(X_test_sample, nsamples=50)
# More accurate (slower)
shap_values = explainer.shap_values(X_test_sample, nsamples=200)
# Smaller background dataset (faster)
background = shap.kmeans(X_train_scaled, 50)
Trying Other Algorithms
The SHAP approach works with any model:
from sklearn.ensemble import RandomForestClassifier
# Train Random Forest instead of KNN
model = RandomForestClassifier(n_estimators=100, random_state=42)
model.fit(X_train_scaled, y_train)
# Use TreeExplainer for faster computation on tree-based models
explainer = shap.TreeExplainer(model)
Performance Tips
- GPU Acceleration: Use GPU for faster PyTorch operations
- Background Size: Reduce background dataset size for faster SHAP computation
- Sample Size: Start with fewer samples (e.g., 50) for quick testing
- nsamples Parameter: Lower values speed up computation but reduce accuracy
- Memory Management: Clear GPU cache between major computations
Troubleshooting
Common Issues
GPU not detected:
- Check CUDA installation
- Verify PyTorch GPU support with torch.cuda.is_available()
- The notebook will fall back to CPU automatically
SHAP computation too slow:
- Reduce background dataset size
- Decrease number of test samples
- Lower nsamples parameter
Memory errors:
- Process fewer samples at once
- Clear GPU cache with clear_gpu_memory()
- Reduce background dataset size
Visualization issues:
- Ensure matplotlib backend is compatible
- Update SHAP to latest version
- Restart kernel if plots don't render
Contributing
Contributions are welcome! Please feel free to submit pull requests or open issues for bugs, questions, or new features.
License
This project is licensed under the MIT License - see the LICENSE file for details.
Acknowledgments
- SHAP: Scott Lundberg et al. for the SHAP library
- scikit-learn: For the Breast Cancer Wisconsin dataset and ML tools
- PyTorch: For GPU acceleration capabilities
- Community: All contributors to the open-source ML/AI ecosystem
References
- Lundberg, S. M., & Lee, S. I. (2017). A unified approach to interpreting model predictions. NeurIPS.
- SHAP Documentation: https://shap.readthedocs.io/
- scikit-learn Documentation: https://scikit-learn.org/
- Breast Cancer Wisconsin Dataset: https://archive.ics.uci.edu/ml/datasets/Breast+Cancer+Wisconsin+(Diagnostic)
Contact
For questions or feedback, please open an issue in the repository.
Note: This notebook is designed for educational purposes and demonstrates best practices for ML model interpretability using SHAP.