| """ | |
| Gradio Chatbot Interface for CGT-LLM-Beta RAG System | |
| This application provides a web interface for the RAG chatbot, allowing users to: | |
| - Select different LLM models from a dropdown | |
| - Choose education level for personalized answers (Middle School, High School, Professional, Improved) | |
| - View answers with Flesch-Kincaid grade level scores | |
| - See source documents and similarity scores for every answer | |
| Usage: | |
| python app.py | |
| IMPORTANT: Before using, update the MODEL_MAP dictionary with correct HuggingFace paths | |
| for models that currently have placeholder paths (Llama-4-Scout, MediPhi, Phi-4-reasoning). | |
| For Hugging Face Spaces: | |
| - Ensure vector database is built (run bot.py with indexing first) | |
| - Model will be loaded on startup | |
| - Access via the Gradio interface | |
| """ | |
import gradio as gr
import argparse
import sys
import os
from typing import Tuple, Optional, List
import logging
import textstat
import torch
# Set up logging first (before any logger usage)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Import from bot.py - wrap in try/except to handle import errors gracefully
try:
    from bot import RAGBot, parse_args, Chunk
    BOT_AVAILABLE = True
except ImportError as e:
    logger.error(f"Failed to import bot module: {e}")
    BOT_AVAILABLE = False

    # Create dummy classes so the module can still load
    class RAGBot:
        pass

    class Chunk:
        pass

    def parse_args():
        return None
# For Hugging Face Inference API
try:
    from huggingface_hub import InferenceClient
    HF_INFERENCE_AVAILABLE = True
except ImportError:
    HF_INFERENCE_AVAILABLE = False
    logger.warning("huggingface_hub not available, InferenceClient will not work")
# Model mapping: short name -> full HuggingFace path
MODEL_MAP = {
    "Llama-3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct",
    "Mistral-7B-Instruct-v0.2": "mistralai/Mistral-7B-Instruct-v0.2",
    "Llama-4-Scout-17B-16E-Instruct": "meta-llama/Llama-4-Scout-17B-16E-Instruct",
    "MediPhi-Instruct": "microsoft/MediPhi-Instruct",
    "MediPhi": "microsoft/MediPhi",
    "Phi-4-reasoning": "microsoft/Phi-4-reasoning",
}
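# Note: some of these repositories (e.g., the meta-llama models) are gated on the
# Hugging Face Hub, so loading them locally or calling them via the Inference API
# may require an HF_TOKEN that has been granted access.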
# Education level mapping
EDUCATION_LEVELS = {
    "Middle School": "middle_school",
    "High School": "high_school",
    "College": "college",
    "Doctoral": "doctoral"
}
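# The dropdown labels above map to the target_level strings expected by
# enhance_readability() (e.g., "Middle School" -> "middle_school").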
# Example questions from the results CSV (hardcoded for easy access)
EXAMPLE_QUESTIONS = [
    "Can a BRCA2 variant skip a generation?",
    "Can a PMS2 variant skip a generation?",
    "Can an EPCAM/MSH2 variant skip a generation?",
    "Can an MLH1 variant skip a generation?",
    "Can an MSH2 variant skip a generation?",
    "Can an MSH6 variant skip a generation?",
    "Can I pass this MSH2 variant to my kids?",
    "Can only women carry a BRCA inherited mutation?",
    "Does GINA cover life or disability insurance?",
    "Does having a BRCA1 mutation mean I will definitely have cancer?",
    "Does having a BRCA2 mutation mean I will definitely have cancer?",
    "Does having a PMS2 mutation mean I will definitely have cancer?",
    "Does having an EPCAM/MSH2 mutation mean I will definitely have cancer?",
    "Does having an MLH1 mutation mean I will definitely have cancer?",
    "Does having an MSH2 mutation mean I will definitely have cancer?",
    "Does having an MSH6 mutation mean I will definitely have cancer?",
    "Does this BRCA1 genetic variant affect my cancer treatment?",
    "Does this BRCA2 genetic variant affect my cancer treatment?",
    "Does this EPCAM/MSH2 genetic variant affect my cancer treatment?",
    "Does this MLH1 genetic variant affect my cancer treatment?",
    "Does this MSH2 genetic variant affect my cancer treatment?",
    "Does this MSH6 genetic variant affect my cancer treatment?",
    "Does this PMS2 genetic variant affect my cancer treatment?",
    "How can I cope with this diagnosis?",
    "How can I get my kids tested?",
    "How can I help others with my condition?",
    "How might my genetic test results change over time?",
    "I don't talk to my family/parents/sister/brother. How can I share this with them?",
    "I have a BRCA pathogenic variant and I want to have children, what are my options?",
    "Is genetic testing for my family members covered by insurance?",
    "Is new research being done on my condition?",
    "Is this BRCA1 variant something I inherited?",
    "Is this BRCA2 variant something I inherited?",
    "Is this EPCAM/MSH2 variant something I inherited?",
    "Is this MLH1 variant something I inherited?",
    "Is this MSH2 variant something I inherited?",
    "Is this MSH6 variant something I inherited?",
    "Is this PMS2 variant something I inherited?",
    "My relative doesn't have insurance. What should they do?",
    "People who test positive for a genetic mutation are they at risk of losing their health insurance?",
    "Should I contact my male and female relatives?",
    "Should my family members get tested?",
    "What are the Risks and Benefits of Risk-Reducing Surgeries for Lynch Syndrome?",
    "What are the recommendations for my family members if I have a BRCA1 mutation?",
    "What are the recommendations for my family members if I have a BRCA2 mutation?",
    "What are the recommendations for my family members if I have a PMS2 mutation?",
    "What are the recommendations for my family members if I have an EPCAM/MSH2 mutation?",
    "What are the recommendations for my family members if I have an MLH1 mutation?",
    "What are the recommendations for my family members if I have an MSH2 mutation?",
    "What are the recommendations for my family members if I have an MSH6 mutation?",
    "What are the surveillance and preventions I can take to reduce my risk of cancer or detecting cancer early if I have a BRCA mutation?",
    "What are the surveillance and preventions I can take to reduce my risk of cancer or detecting cancer early if I have an EPCAM/MSH2 mutation?",
    "What are the surveillance and preventions I can take to reduce my risk of cancer or detecting cancer early if I have an MSH2 mutation?",
    "What does a BRCA1 genetic variant mean for me?",
    "What does a BRCA2 genetic variant mean for me?",
    "What does a PMS2 genetic variant mean for me?",
    "What does an EPCAM/MSH2 genetic variant mean for me?",
    "What does an MLH1 genetic variant mean for me?",
    "What does an MSH2 genetic variant mean for me?",
    "What does an MSH6 genetic variant mean for me?",
    "What if I feel overwhelmed?",
    "What if I want to have children and have a hereditary cancer gene? What are my reproductive options?",
    "What if a family member doesn't want to get tested?",
    "What is Lynch Syndrome?",
    "What is my cancer risk if I have BRCA1 Hereditary Breast and Ovarian Cancer syndrome?",
    "What is my cancer risk if I have BRCA2 Hereditary Breast and Ovarian Cancer syndrome?",
    "What is my cancer risk if I have MLH1 Lynch syndrome?",
    "What is my cancer risk if I have MSH2 or EPCAM-associated Lynch syndrome?",
    "What is my cancer risk if I have MSH6 Lynch syndrome?",
    "What is my cancer risk if I have PMS2 Lynch syndrome?",
    "What other resources are available to help me?",
    "What screening tests do you recommend for BRCA1 carriers?",
    "What screening tests do you recommend for BRCA2 carriers?",
    "What screening tests do you recommend for EPCAM/MSH2 carriers?",
    "What screening tests do you recommend for MLH1 carriers?",
    "What screening tests do you recommend for MSH2 carriers?",
    "What screening tests do you recommend for MSH6 carriers?",
    "What screening tests do you recommend for PMS2 carriers?",
    "What steps can I take to manage my cancer risk if I have Lynch syndrome?",
    "What types of cancers am I at risk for with a BRCA1 mutation?",
    "What types of cancers am I at risk for with a BRCA2 mutation?",
    "What types of cancers am I at risk for with a PMS2 mutation?",
    "What types of cancers am I at risk for with an EPCAM/MSH2 mutation?",
    "What types of cancers am I at risk for with an MLH1 mutation?",
    "What types of cancers am I at risk for with an MSH2 mutation?",
    "What types of cancers am I at risk for with an MSH6 mutation?",
    "Where can I find a genetic counselor?",
    "Which of my relatives are at risk?",
    "Who are my first-degree relatives?",
    "Who do my family members call to have genetic testing?",
    "Why do some families with Lynch syndrome have more cases of cancer than others?",
    "Why should I share my BRCA1 genetic results with family?",
    "Why should I share my BRCA2 genetic results with family?",
    "Why should I share my EPCAM/MSH2 genetic results with family?",
    "Why should I share my MLH1 genetic results with family?",
    "Why should I share my MSH2 genetic results with family?",
    "Why should I share my MSH6 genetic results with family?",
    "Why should I share my PMS2 genetic results with family?",
    "Why would my relatives want to know if they have this? What can they do about it?",
    "Will my insurance cover testing for my parents/brother/sister?",
    "Will this affect my health insurance?",
]
class InferenceAPIBot:
    """Wrapper that uses the Hugging Face Inference API instead of loading models locally"""

    def __init__(self, bot: RAGBot, hf_token: Optional[str] = None):
        """Initialize with a RAGBot (for the vector DB) and an optional HF token for the Inference API"""
        self.bot = bot  # Use the bot for the vector DB and prompt formatting
        # Initialize the client - a token is optional for public models
        if hf_token:
            try:
                self.client = InferenceClient(api_key=hf_token)
                logger.info("Using Inference API with provided token")
            except Exception as e:
                logger.error(f"Failed to create InferenceClient with token: {e}")
                raise
        else:
            # Try without a token (works for public models)
            try:
                self.client = InferenceClient()
                logger.info("Using Inference API without token (public models)")
            except Exception as e:
                logger.error(f"Failed to create InferenceClient without token: {e}")
                raise
        self.current_model = bot.args.model
        # Don't set args as an attribute - it is exposed via the args property (bot.args) below
        logger.info(f"InferenceAPIBot initialized with model: {self.current_model}")
        # Sanity-check the client without making an actual API call during init
        try:
            logger.info("Testing Inference API connection...")
            if not hasattr(self.client, 'text_generation'):
                logger.warning("InferenceClient may not support text_generation method")
        except Exception as e:
            logger.warning(f"Could not verify InferenceClient: {e}")

    @property
    def args(self):
        """Access args from the wrapped bot"""
        return self.bot.args

    def generate_answer(self, prompt: str, **kwargs) -> str:
        """Generate an answer using the Inference API"""
        try:
            max_tokens = kwargs.get('max_new_tokens', 512)
            temperature = kwargs.get('temperature', 0.2)
            top_p = kwargs.get('top_p', 0.9)
            # Use the text_generation API directly (more reliable and widely supported)
            logger.info(f"Calling Inference API for model: {self.current_model}")
            response = self.client.text_generation(
                prompt,
                model=self.current_model,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                return_full_text=False,
            )
            logger.info(f"Inference API response received (length: {len(response) if response else 0})")
            return response
        except Exception as e:
            logger.error(f"Error calling Inference API: {e}", exc_info=True)
            import traceback
            logger.error(f"Traceback: {traceback.format_exc()}")
            return f"Error generating answer: {str(e)}. Please check the logs for details."
    def enhance_readability(self, answer: str, target_level: str = "middle_school") -> Tuple[str, float]:
        """Enhance readability using the Inference API"""
        try:
            # Define prompts for the different reading levels (same as bot.py)
            if target_level == "middle_school":
                level_description = "middle school reading level (ages 12-14, 6th-8th grade)"
                instructions = """
- Use simpler medical terms or explain them
- Medium-length sentences
- Clear, structured explanations
- Keep important medical information accessible"""
            elif target_level == "high_school":
                level_description = "high school reading level (ages 15-18, 9th-12th grade)"
                instructions = """
- Use appropriate medical terminology with context
- Varied sentence length
- Comprehensive yet accessible explanations
- Maintain technical accuracy while ensuring clarity"""
            elif target_level == "college":
                level_description = "college reading level (undergraduate level, ages 18-22)"
                instructions = """
- Use standard medical terminology with brief explanations
- Professional and clear writing style
- Include relevant clinical context
- Maintain scientific accuracy and precision
- Appropriate for undergraduate students in health sciences"""
            elif target_level == "doctoral":
                level_description = "doctoral/professional reading level (graduate level, medical professionals)"
                instructions = """
- Use advanced medical and scientific terminology
- Include detailed clinical and research context
- Reference specific mechanisms, pathways, and evidence
- Provide comprehensive technical explanations
- Appropriate for medical professionals, researchers, and graduate students
- Include nuanced discussions of clinical implications and research findings"""
            else:
                raise ValueError(f"Unknown target_level: {target_level}")

            # Build the system and user messages
            system_message = f"""You are a helpful medical assistant who specializes in explaining complex medical information at appropriate reading levels. Rewrite the following medical answer for {level_description}:
{instructions}
- Keep the same important information but adapt the complexity
- Provide context for technical terms
- Ensure the answer is informative yet understandable"""
            user_message = f"Please rewrite this medical answer for {level_description}:\n\n{answer}"
            # Chat-format messages (the text_generation call below uses a combined prompt instead)
            messages = [
                {"role": "system", "content": system_message},
                {"role": "user", "content": user_message}
            ]

            # Call the Inference API using text_generation (more reliable than the chat API)
            max_tokens = 512 if target_level in ["college", "doctoral"] else 384
            temperature = 0.4 if target_level in ["college", "doctoral"] else 0.3
            # Combine the system and user messages into a single prompt for text generation
            combined_prompt = f"{system_message}\n\n{user_message}"
            logger.info(f"Enhancing readability for {target_level} level")
            enhanced_answer = self.client.text_generation(
                combined_prompt,
                model=self.current_model,
                max_new_tokens=max_tokens,
                temperature=temperature,
                return_full_text=False,
            )

            # Clean the answer (same as bot.py)
            cleaned = self.bot._clean_readability_answer(enhanced_answer, target_level)

            # Calculate the Flesch-Kincaid grade level
            try:
                flesch_score = textstat.flesch_kincaid_grade(cleaned)
            except Exception:
                flesch_score = 0.0
            return cleaned, flesch_score
        except Exception as e:
            logger.error(f"Error enhancing readability: {e}", exc_info=True)
            return answer, 0.0

    # Delegate the remaining methods and attributes to the wrapped bot
    def format_prompt(self, context_chunks: List[Chunk], question: str) -> str:
        return self.bot.format_prompt(context_chunks, question)

    def retrieve_with_scores(self, query: str, k: int) -> Tuple[List[Chunk], List[float]]:
        return self.bot.retrieve_with_scores(query, k)

    def _categorize_question(self, question: str) -> str:
        return self.bot._categorize_question(question)

    @property
    def vector_retriever(self):
        return self.bot.vector_retriever
class GradioRAGInterface:
    """Wrapper class to integrate RAGBot with Gradio"""

    def __init__(self, initial_bot: RAGBot, use_inference_api: bool = False):
        # Check whether we should use the Inference API (on Spaces)
        if use_inference_api and HF_INFERENCE_AVAILABLE:
            # Try to get a token, but it's optional for public models
            # On Spaces, HF_TOKEN is automatically available when set as a secret
            hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")
            try:
                self.bot = InferenceAPIBot(initial_bot, hf_token)
                self.use_inference_api = True
                if hf_token:
                    logger.info("Using Hugging Face Inference API with token")
                else:
                    logger.info("Using Hugging Face Inference API without token (public models)")
            except Exception as e:
                logger.error(f"Failed to initialize Inference API: {e}")
                # On Spaces we MUST use the Inference API, but don't raise - let the demo show an error
                if IS_SPACES:
                    logger.error("Cannot use local models on Spaces. Please configure HF_TOKEN.")
                    # Store the error message but don't raise - it will be surfaced in the UI
                    self.inference_error = str(e)
                    # Still set bot to initial_bot so the interface can be created
                    self.bot = initial_bot
                    self.use_inference_api = False
                else:
                    logger.warning("Falling back to local model")
                    self.bot = initial_bot
                    self.use_inference_api = False
        else:
            self.bot = initial_bot
            self.use_inference_api = False

        # Get the current model from bot args (not a direct attribute)
        self.current_model = self.bot.args.model if hasattr(self.bot, 'args') else getattr(self.bot, 'current_model', None)
        if self.current_model is None and hasattr(self.bot, 'bot'):
            # If using InferenceAPIBot, get it from the wrapped bot
            self.current_model = self.bot.bot.args.model
        self.data_dir = initial_bot.args.data_dir
        logger.info("GradioRAGInterface initialized")

    def _find_file_path(self, filename: str) -> str:
        """Find the full file path for a given filename"""
        from pathlib import Path
        data_path = Path(self.data_dir)
        if not data_path.exists():
            return ""
        # Search for the file recursively and return the first match
        for file_path in data_path.rglob(filename):
            return str(file_path)
        return ""
    def reload_model(self, model_short_name: str) -> str:
        """Reload the model when the user selects a different one"""
        if model_short_name not in MODEL_MAP:
            return f"Error: Unknown model '{model_short_name}'"
        new_model_path = MODEL_MAP[model_short_name]
        # If it's the same model, there is no need to reload
        if new_model_path == self.current_model:
            return f"Model already loaded: {model_short_name}"
        try:
            logger.info(f"Switching model from {self.current_model} to {new_model_path}")
            if self.use_inference_api:
                # For the Inference API, just update the model name
                self.bot.current_model = new_model_path
                self.current_model = new_model_path
                return f"✓ Model switched to: {model_short_name} (using Inference API)"
            else:
                # For a local model, reload it
                # Update args
                self.bot.args.model = new_model_path
                # Clear the old model from memory
                if hasattr(self.bot, 'model') and self.bot.model is not None:
                    del self.bot.model
                    del self.bot.tokenizer
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                # Load the new model
                self.bot._load_model()
                self.current_model = new_model_path
                return f"✓ Model loaded: {model_short_name}"
        except Exception as e:
            logger.error(f"Error reloading model: {e}", exc_info=True)
            return f"✗ Error loading model: {str(e)}"
    def process_question(
        self,
        question: str,
        model_name: str,
        education_level: str,
        k: int,
        temperature: float,
        max_tokens: int
    ) -> Tuple[str, str, str, str, str]:
        """
        Process a single question and return formatted results

        Returns:
            Tuple of (answer, flesch_score, sources, similarity_scores, question_category)
        """
        import time

        if not question or not question.strip():
            return "Please enter a question.", "N/A", "", "", ""

        # Check whether we're on Spaces but not using the Inference API
        if IS_SPACES and not self.use_inference_api:
            error_msg = """⚠️ **Configuration Error**

This Space is not configured to use the Hugging Face Inference API.

**To fix this:**
1. Go to your Space settings: https://huggingface.co/spaces/alrahrooh/cgt-llm-chatbot-v2/settings
2. Add a secret named `HF_TOKEN` with your Hugging Face token
3. Get your token from: https://huggingface.co/settings/tokens
4. Restart the Space

**Note:** The Inference API is required on Spaces because we cannot load models locally."""
            return error_msg, "N/A", "", "", ""

        try:
            start_time = time.time()
            logger.info(f"Processing question: {question[:50]}...")

            # Reload the model if it changed (this can take 1-3 minutes)
            if model_name in MODEL_MAP:
                model_path = MODEL_MAP[model_name]
                if model_path != self.current_model:
                    logger.info(f"Model changed, reloading from {self.current_model} to {model_path}")
                    reload_status = self.reload_model(model_name)
                    if reload_status.startswith("✗"):
                        return f"Error: {reload_status}", "N/A", "", "", ""
                    logger.info(f"Model reloaded in {time.time() - start_time:.1f}s")

            # Update bot args for this query
            self.bot.args.k = k
            self.bot.args.temperature = temperature
            # Limit max_tokens for faster generation in Gradio
            self.bot.args.max_new_tokens = min(max_tokens, 512)  # Cap at 512 for faster responses

            # Categorize the question
            logger.info("Categorizing question...")
            question_group = self.bot._categorize_question(question)

            # Retrieve relevant chunks with similarity scores
            logger.info("Retrieving relevant documents...")
            retrieve_start = time.time()
            context_chunks, similarity_scores = self.bot.retrieve_with_scores(question, k)
            logger.info(f"Retrieved {len(context_chunks)} chunks in {time.time() - retrieve_start:.2f}s")

            if not context_chunks:
                return (
                    "I don't have enough information to answer this question. Please try rephrasing or asking about a different topic.",
                    "N/A",
                    "No sources found",
                    "No matches found",
                    question_group
                )

            # Format the similarity scores
            similarity_scores_str = ", ".join([f"{score:.3f}" for score in similarity_scores])

            # Format the sources with chunk text and file paths
            sources_list = []
            for i, (chunk, score) in enumerate(zip(context_chunks, similarity_scores)):
                # Try to find the file path
                file_path = self._find_file_path(chunk.filename)
                source_info = f"""
{'='*80}
SOURCE {i+1} | Similarity: {score:.3f}
{'='*80}
📄 File: {chunk.filename}
📍 Path: {file_path if file_path else 'File path not found (search in Data Resources directory)'}
📊 Chunk: {chunk.chunk_id + 1}/{chunk.total_chunks} (Position: {chunk.start_pos}-{chunk.end_pos})
📝 Full Chunk Text:
{chunk.text}
"""
                sources_list.append(source_info)
            sources = "\n".join(sources_list)

            # Generation kwargs
            gen_kwargs = {
                'max_new_tokens': min(max_tokens, 512),  # Cap for faster responses
                'temperature': temperature,
                'top_p': self.bot.args.top_p,
                'repetition_penalty': self.bot.args.repetition_penalty
            }

            # Generate the answer based on the education level
            answer = ""
            flesch_score = 0.0

            # Generate the original answer first (needed for all enhancement levels)
            logger.info("Generating original answer...")
            gen_start = time.time()
            prompt = self.bot.format_prompt(context_chunks, question)
            original_answer = self.bot.generate_answer(prompt, **gen_kwargs)
            logger.info(f"Original answer generated in {time.time() - gen_start:.1f}s")

            # Enhance based on the selected education level
            logger.info(f"Enhancing answer for {education_level} level...")
            enhance_start = time.time()
            if education_level == "middle_school":
                # Simplify to middle school level
                answer, flesch_score = self.bot.enhance_readability(original_answer, target_level="middle_school")
            elif education_level == "high_school":
                # Simplify to high school level
                answer, flesch_score = self.bot.enhance_readability(original_answer, target_level="high_school")
            elif education_level == "college":
                # Enhance to college level
                answer, flesch_score = self.bot.enhance_readability(original_answer, target_level="college")
            elif education_level == "doctoral":
                # Enhance to doctoral/professional level
                answer, flesch_score = self.bot.enhance_readability(original_answer, target_level="doctoral")
            else:
                answer = "Invalid education level selected."
                flesch_score = 0.0
            logger.info(f"Answer enhanced in {time.time() - enhance_start:.1f}s")

            total_time = time.time() - start_time
            logger.info(f"Total processing time: {total_time:.1f}s")

            # Clean the answer - remove special tokens and formatting
            import re
            cleaned_answer = answer
            # Remove special tokens (case-insensitive)
            special_tokens = [
                "<|end|>",
                "<|endoftext|>",
                "<|end_of_text|>",
                "<|eot_id|>",
                "<|start_header_id|>",
                "<|end_header_id|>",
                "<|assistant|>",
            ]
            for token in special_tokens:
                cleaned_answer = re.sub(re.escape(token), '', cleaned_answer, flags=re.IGNORECASE)
            # Remove any remaining special token patterns like <|...|>
            cleaned_answer = re.sub(r'<\|[^|]+\|>', '', cleaned_answer)
            # Remove any markdown-style headers that might have been added
            cleaned_answer = re.sub(r'^\*\*.*?\*\*.*?\n', '', cleaned_answer, flags=re.MULTILINE)
            # Clean up extra whitespace and newlines
            cleaned_answer = re.sub(r'\n\s*\n\s*\n+', '\n\n', cleaned_answer)  # Collapse multiple newlines to double
            cleaned_answer = re.sub(r'^\s+|\s+$', '', cleaned_answer, flags=re.MULTILINE)  # Trim each line
            cleaned_answer = cleaned_answer.strip()

            # Return just the clean answer (no headers or metadata)
            return (
                cleaned_answer,
                f"{flesch_score:.1f}",
                sources,
                similarity_scores_str,
                question_group  # Question category as the 5th return value
            )
        except Exception as e:
            logger.error(f"Error processing question: {e}", exc_info=True)
            return (
                f"An error occurred while processing your question: {str(e)}",
                "N/A",
                "",
                "",
                "Error"
            )
def create_interface(initial_bot: RAGBot, use_inference_api: Optional[bool] = None) -> gr.Blocks:
    """Create and configure the Gradio interface"""
    # Default to the Inference API on Spaces and a local model otherwise
    if use_inference_api is None:
        use_inference_api = os.getenv("SPACE_ID") is not None or os.getenv("SYSTEM") == "spaces"

    try:
        interface = GradioRAGInterface(initial_bot, use_inference_api=use_inference_api)
    except Exception as e:
        logger.error(f"Failed to create GradioRAGInterface: {e}")
        # Create a minimal interface that shows the error
        with gr.Blocks(title="CGT-LLM-Beta RAG Chatbot") as demo:
            gr.Markdown(f"""
# ⚠️ Initialization Error

Failed to initialize the chatbot interface.

**Error:** {str(e)}

Please check the logs for more details.
""")
        return demo

    # Get the initial model name from the bot
    initial_model_short = None
    for short_name, full_path in MODEL_MAP.items():
        if full_path == initial_bot.args.model:
            initial_model_short = short_name
            break
    if initial_model_short is None:
        initial_model_short = list(MODEL_MAP.keys())[0]

    # Create the Gradio interface with error handling
    # CRITICAL: all components and event handlers must be defined INSIDE the gr.Blocks() context
    try:
        with gr.Blocks(title="CGT-LLM-Beta RAG Chatbot") as demo:
            gr.Markdown("""
# 🧬 CGT-LLM-Beta: Genetic Counseling RAG Chatbot

Ask questions about genetic counseling, cascade genetic testing, hereditary cancer syndromes, and related topics.
The chatbot uses a Retrieval-Augmented Generation (RAG) system to provide evidence-based answers from medical literature.
""")

            with gr.Row():
                with gr.Column(scale=2):
                    question_input = gr.Textbox(
                        label="Your Question",
                        placeholder="e.g., What is Lynch Syndrome? What screening is recommended for BRCA1 carriers?",
                        lines=3
                    )
                    with gr.Row():
                        model_dropdown = gr.Dropdown(
                            choices=list(MODEL_MAP.keys()),
                            value=initial_model_short,
                            label="Select Model",
                            info="Choose which LLM model to use for generating answers"
                        )
                        education_dropdown = gr.Dropdown(
                            choices=list(EDUCATION_LEVELS.keys()),
                            value=list(EDUCATION_LEVELS.keys())[0],
                            label="Education Level",
                            info="Select your education level for personalized answers"
                        )
                    with gr.Accordion("Advanced Settings", open=False):
                        k_slider = gr.Slider(
                            minimum=1,
                            maximum=10,
                            value=5,
                            step=1,
                            label="Number of document chunks to retrieve (k)"
                        )
                        temperature_slider = gr.Slider(
                            minimum=0.1,
                            maximum=1.0,
                            value=0.2,
                            step=0.1,
                            label="Temperature (lower = more focused)"
                        )
                        max_tokens_slider = gr.Slider(
                            minimum=128,
                            maximum=1024,
                            value=512,
                            step=128,
                            label="Max Tokens (lower = faster responses)"
                        )
                    submit_btn = gr.Button("Ask Question", variant="primary", size="lg")

                with gr.Column(scale=3):
                    answer_output = gr.Textbox(
                        label="Answer",
                        lines=20,
                        interactive=False,
                        elem_classes=["answer-box"]
                    )
                    with gr.Row():
                        flesch_output = gr.Textbox(
                            label="Flesch-Kincaid Grade Level",
                            value="N/A",
                            interactive=False,
                            scale=1
                        )
                        similarity_output = gr.Textbox(
                            label="Similarity Scores",
                            value="",
                            interactive=False,
                            scale=1
                        )
                        category_output = gr.Textbox(
                            label="Question Category",
                            value="",
                            interactive=False,
                            scale=1
                        )
                    sources_output = gr.Textbox(
                        label="Source Documents (with Chunk Text)",
                        lines=15,
                        interactive=False,
                        info="Shows the retrieved document chunks with full text. File paths are shown for easy access."
                    )

            # Example questions - all questions from the results CSV (scrollable dropdown)
            gr.Markdown("### 💡 Example Questions")
            gr.Markdown(f"Select a question below to use it in the chatbot ({len(EXAMPLE_QUESTIONS)} questions - scrollable dropdown):")
            # Use a Dropdown, which is naturally scrollable with many options
            example_questions_dropdown = gr.Dropdown(
                choices=EXAMPLE_QUESTIONS,
                label="Example Questions",
                value=None,
                info="Open the dropdown and scroll through all questions. Select one to use it.",
                interactive=True,
                container=True,
                scale=1
            )

            # Update the question input when the dropdown selection changes
            def update_question_from_dropdown(selected_question):
                return selected_question if selected_question else ""

            example_questions_dropdown.change(
                fn=update_question_from_dropdown,
                inputs=example_questions_dropdown,
                outputs=question_input
            )

            # Footer
            gr.Markdown("""
---
**Note:** This chatbot provides informational answers based on medical literature.
It is not a substitute for professional medical advice, diagnosis, or treatment.
Always consult with qualified healthcare providers for medical decisions.
""")

            # Connect the submit button
            def process_with_education_level(question, model, education, k, temp, max_tok):
                education_key = EDUCATION_LEVELS[education]
                return interface.process_question(question, model, education_key, k, temp, max_tok)

            submit_btn.click(
                fn=process_with_education_level,
                inputs=[
                    question_input,
                    model_dropdown,
                    education_dropdown,
                    k_slider,
                    temperature_slider,
                    max_tokens_slider
                ],
                outputs=[
                    answer_output,
                    flesch_output,
                    sources_output,
                    similarity_output,
                    category_output
                ]
            )

            # Also allow the Enter key to submit
            question_input.submit(
                fn=process_with_education_level,
                inputs=[
                    question_input,
                    model_dropdown,
                    education_dropdown,
                    k_slider,
                    temperature_slider,
                    max_tokens_slider
                ],
                outputs=[
                    answer_output,
                    flesch_output,
                    sources_output,
                    similarity_output,
                    category_output
                ]
            )
    except Exception as interface_error:
        logger.error(f"Error setting up Gradio interface components: {interface_error}", exc_info=True)
        import traceback
        error_trace = traceback.format_exc()
        # Create a minimal working demo
        with gr.Blocks(title="CGT-LLM-Beta RAG Chatbot") as demo:
            gr.Markdown(f"""
# ⚠️ Interface Setup Error

An error occurred while setting up the interface components.

**Error:** {str(interface_error)}

**Traceback:**
```
{error_trace[:1000]}...
```

Please check the logs for more details.
""")
        return demo

    logger.info("Gradio interface created successfully")
    return demo
def main():
    """Main function to launch the Gradio app"""
    # Parse arguments with defaults suitable for Gradio
    parser = argparse.ArgumentParser(description="Gradio Interface for CGT-LLM-Beta RAG Chatbot")

    # Model and database settings
    parser.add_argument('--model', type=str, default='meta-llama/Llama-3.2-3B-Instruct',
                        help='HuggingFace model name')
    parser.add_argument('--vector-db-dir', default='./chroma_db',
                        help='Directory for ChromaDB persistence')
    parser.add_argument('--data-dir', default='./Data Resources',
                        help='Directory containing documents (for indexing if needed)')

    # Generation parameters
    parser.add_argument('--max-new-tokens', type=int, default=1024,
                        help='Maximum new tokens to generate')
    parser.add_argument('--temperature', type=float, default=0.2,
                        help='Generation temperature')
    parser.add_argument('--top-p', type=float, default=0.9,
                        help='Top-p sampling parameter')
    parser.add_argument('--repetition-penalty', type=float, default=1.1,
                        help='Repetition penalty')

    # Retrieval parameters
    parser.add_argument('--k', type=int, default=5,
                        help='Number of chunks to retrieve per question')

    # Other settings
    parser.add_argument('--skip-indexing', action='store_true',
                        help='Skip document indexing (use existing vector DB)')
    parser.add_argument('--verbose', action='store_true',
                        help='Enable verbose logging')
    parser.add_argument('--share', action='store_true',
                        help='Create a public Gradio share link')
    parser.add_argument('--server-name', type=str, default='127.0.0.1',
                        help='Server name (0.0.0.0 for public access)')
    parser.add_argument('--server-port', type=int, default=7860,
                        help='Server port')

    args = parser.parse_args()

    # Set the logging level
    if args.verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    logger.info("Initializing RAGBot for Gradio interface...")
    logger.info(f"Model: {args.model}")
    logger.info(f"Vector DB: {args.vector_db_dir}")

    try:
        # Initialize the bot
        bot = RAGBot(args)

        # Check whether the vector database exists and has documents
        collection_stats = bot.vector_retriever.get_collection_stats()
        if collection_stats.get('total_chunks', 0) == 0:
            logger.warning("Vector database is empty. You may need to run indexing first:")
            logger.warning("  python bot.py --data-dir './Data Resources' --vector-db-dir './chroma_db'")
            logger.warning("Continuing anyway - the chatbot will work but may not find relevant documents.")

        # Create and launch the Gradio interface
        demo = create_interface(bot)

        # For local use, launch it
        # (On Spaces, the demo is already created at module level)
        logger.info(f"Launching Gradio interface on http://{args.server_name}:{args.server_port}")
        demo.launch(
            server_name=args.server_name,
            server_port=args.server_port,
            share=args.share
        )
    except KeyboardInterrupt:
        logger.info("Interrupted by user")
        sys.exit(0)
    except Exception as e:
        logger.error(f"Error launching Gradio app: {e}", exc_info=True)
        sys.exit(1)
# For Hugging Face Spaces: create the demo at module level.
# Spaces imports this module and looks for a Gradio Blocks/Interface object,
# e.g. demo = gr.Interface(...) or demo = gr.Blocks(...).
# DO NOT call demo.launch() here - Spaces handles that automatically.

# Check whether we're on Spaces (be permissive - check multiple env vars)
IS_SPACES = (
    os.getenv("SPACE_ID") is not None or
    os.getenv("SYSTEM") == "spaces" or
    os.getenv("HF_SPACE_ID") is not None
)

# Initialize the demo variable FIRST, before any try/except,
# so it always exists even if initialization fails
demo = None

def _create_demo():
    """Create the demo - separated into a function for better error handling"""
    try:
        logger.info("=" * 80)
        logger.info("Starting demo creation...")
        logger.info(f"IS_SPACES: {IS_SPACES}")
        logger.info(f"BOT_AVAILABLE: {BOT_AVAILABLE}")

        if not BOT_AVAILABLE:
            raise ImportError("bot module is not available - cannot create demo")

        # Initialize with default args
        parser = argparse.ArgumentParser()
        parser.add_argument('--model', type=str, default='meta-llama/Llama-3.2-3B-Instruct')
        parser.add_argument('--vector-db-dir', default='./chroma_db')
        parser.add_argument('--data-dir', default='./Data Resources')
        parser.add_argument('--max-new-tokens', type=int, default=1024)
        parser.add_argument('--temperature', type=float, default=0.2)
        parser.add_argument('--top-p', type=float, default=0.9)
        parser.add_argument('--repetition-penalty', type=float, default=1.1)
        parser.add_argument('--k', type=int, default=5)
        parser.add_argument('--skip-indexing', action='store_true', default=True)
        parser.add_argument('--verbose', action='store_true', default=False)
        parser.add_argument('--share', action='store_true', default=False)
        parser.add_argument('--server-name', type=str, default='0.0.0.0')
        parser.add_argument('--server-port', type=int, default=7860)
        parser.add_argument('--seed', type=int, default=42)
        args = parser.parse_args([])  # Parse empty args to use the defaults
        args.skip_model_loading = IS_SPACES  # Skip local model loading on Spaces, use the Inference API

        logger.info("Creating RAGBot...")
        # Create the bot - handle initialization errors gracefully
        bot = RAGBot(args)
        if bot.vector_retriever is None:
            raise Exception("Vector database not available")

        # Check whether the vector database has documents
        collection_stats = bot.vector_retriever.get_collection_stats()
        if collection_stats.get('total_chunks', 0) == 0:
            logger.warning("Vector database is empty. The chatbot may not find relevant documents.")
            logger.warning("This is OK for initial deployment - documents can be indexed later.")

        logger.info("Creating interface...")
        # Create the demo interface (following the HF docs pattern)
        demo = create_interface(bot, use_inference_api=IS_SPACES)
        logger.info(f"Demo created successfully: {type(demo)}")
        return demo
    except Exception as bot_error:
        logger.error(f"Error initializing: {bot_error}", exc_info=True)
        import traceback
        error_trace = traceback.format_exc()
        logger.error(f"Full traceback: {error_trace}")
        # Create a demo that shows the error but still allows the interface to load
        with gr.Blocks(title="CGT-LLM-Beta RAG Chatbot") as error_demo:
            gr.Markdown(f"""
# ⚠️ Initialization Error

The chatbot encountered an error during initialization:

**Error:** {str(bot_error)}

**Possible causes:**
- Missing vector database (chroma_db directory)
- Missing dependencies
- Configuration issues
- Inference API initialization failed

**For Spaces:**
- Make sure HF_TOKEN is set as a secret
- Check the logs tab for detailed error messages

**Error Details:**
```
{error_trace[:1000]}...
```
""")
        logger.info(f"Error demo created: {type(error_demo)}")
        return error_demo
# Create the demo at module level so Spaces can always find it when importing the module
try:
    if IS_SPACES:
        logger.info("Creating demo directly at module level for Spaces...")
    else:
        logger.info("Creating demo for local execution...")
    demo = _create_demo()
    # Ensure the demo is definitely set and valid at module level
    if demo is None or not isinstance(demo, (gr.Blocks, gr.Interface)):
        raise ValueError(f"Demo creation returned invalid result: {type(demo)}")
    logger.info("Demo creation completed successfully")
except Exception as e:
    logger.error(f"CRITICAL: Error creating demo: {e}", exc_info=True)
    import traceback
    error_trace = traceback.format_exc()
    logger.error(f"Full traceback: {error_trace}")
    # Create a fallback error demo so Spaces doesn't show a blank page
    with gr.Blocks(title="CGT-LLM-Beta RAG Chatbot") as demo:
        gr.Markdown(f"""
# Error Initializing Chatbot

A critical error occurred while initializing the chatbot.

**Error:** {str(e)}

**Traceback:**
```
{error_trace[:1500]}...
```

Please check the logs for more details.
""")
    logger.info(f"Fallback error demo created: {type(demo)}")

# Final verification - ensure the demo exists and is a valid Gradio object
# (Spaces requires a module-level variable named 'demo')
if demo is None:
    logger.error("CRITICAL: Demo variable is None! Creating fallback demo.")
    with gr.Blocks(title="CGT-LLM-Beta RAG Chatbot") as demo:
        gr.Markdown("# Error: Demo was not created properly\n\nPlease check the logs for details.")
elif not isinstance(demo, (gr.Blocks, gr.Interface)):
    # Capture the invalid type before the with-block rebinds 'demo'
    invalid_type = type(demo)
    logger.error(f"CRITICAL: Demo is not a valid Gradio object: {invalid_type}")
    with gr.Blocks(title="CGT-LLM-Beta RAG Chatbot") as demo:
        gr.Markdown(f"# Error: Invalid demo type\n\nDemo type: {invalid_type}\n\nPlease check the logs for details.")
else:
    logger.info(f"✅ Final demo check passed: demo type={type(demo)}")

# Explicit print so the demo variable's presence is visible in the Spaces logs
print(f"DEMO_VARIABLE_SET: {type(demo)}")
# Final safety check - Spaces looks for a module-level variable named 'demo',
# so if it is still None or invalid, create a minimal emergency fallback
if demo is None or not isinstance(demo, (gr.Blocks, gr.Interface)):
    logger.error("CRITICAL: Demo is invalid, creating emergency fallback")
    with gr.Blocks(title="CGT-LLM-Beta RAG Chatbot") as demo:
        gr.Markdown("""
# CGT-LLM-Beta RAG Chatbot

The application encountered an error during initialization.
Please check the logs for details.
""")

# For Spaces: explicitly verify and expose the demo at module level
if IS_SPACES:
    logger.info("=" * 80)
    logger.info("SPACES MODE: Final demo verification")
    logger.info(f"Demo type: {type(demo)}")
    logger.info(f"Demo is None: {demo is None}")
    logger.info(f"Demo is valid: {isinstance(demo, (gr.Blocks, gr.Interface))}")
    logger.info("=" * 80)
    if isinstance(demo, (gr.Blocks, gr.Interface)):
        __all__ = ['demo']  # Explicitly export demo
        logger.info("Demo is ready for Spaces")
        # Also print to stdout so readiness is clearly visible in the Spaces logs
        print("=" * 80)
        print(f"DEMO_READY: {type(demo)}")
        print(f"DEMO_VALID: {isinstance(demo, (gr.Blocks, gr.Interface))}")
        print("=" * 80)
    else:
        logger.error("CRITICAL: Demo is not valid even after all checks!")
# For local execution only (not on Spaces)
if __name__ == "__main__":
    if not IS_SPACES:
        main()