alx-d committed
Commit 23e5fe5 · verified · Parent: 2ace709

Upload folder using huggingface_hub
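A commit with this message is typically produced with huggingface_hub's folder-upload API. The sketch below shows the usual call; the repo_id and repo_type values are illustrative assumptions, not taken from this commit.

    from huggingface_hub import HfApi

    api = HfApi()  # authenticates via HF_TOKEN or a cached `huggingface-cli login`
    api.upload_folder(
        folder_path=".",                            # local folder containing advanced_rag.py
        repo_id="alx-d/advanced-rag",               # hypothetical repo id
        repo_type="space",                          # assumed: the file defines a Gradio app
        commit_message="Upload folder using huggingface_hub",
    )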

Files changed (1)
  1. advanced_rag.py +1102 -1081
advanced_rag.py CHANGED
@@ -1,1081 +1,1102 @@
1
- import os
2
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
3
- import datetime
4
- import functools
5
- import traceback
6
- from typing import List, Optional, Any, Dict
7
-
8
- import torch
9
- import transformers
10
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
11
- from langchain_community.llms import HuggingFacePipeline
12
-
13
- # Other LangChain and community imports
14
- from langchain_community.document_loaders import OnlinePDFLoader
15
- from langchain.text_splitter import RecursiveCharacterTextSplitter
16
- from langchain_community.vectorstores import FAISS
17
- from langchain.embeddings import HuggingFaceEmbeddings
18
- from langchain_community.retrievers import BM25Retriever
19
- from langchain.retrievers import EnsembleRetriever
20
- from langchain.prompts import ChatPromptTemplate
21
- from langchain.schema import StrOutputParser, Document
22
- from langchain_core.runnables import RunnableParallel, RunnableLambda
23
- from transformers.quantizers.auto import AutoQuantizationConfig
24
- import gradio as gr
25
- import requests
26
- from pydantic import PrivateAttr
27
- import pydantic
28
-
29
- from langchain.llms.base import LLM
30
- from typing import Any, Optional, List
31
- import typing
32
- import time
33
-
34
- print("Pydantic Version: ")
35
- print(pydantic.__version__)
36
- # Add Mistral imports with fallback handling
37
-
38
- try:
39
- from mistralai import Mistral
40
- MISTRAL_AVAILABLE = True
41
- debug_print = lambda msg: print(f"[{datetime.datetime.now().isoformat()}] {msg}")
42
- debug_print("Loaded latest Mistral client library")
43
- except ImportError:
44
- MISTRAL_AVAILABLE = False
45
- debug_print = lambda msg: print(f"[{datetime.datetime.now().isoformat()}] {msg}")
46
- debug_print("Mistral client library not found. Install with: pip install mistralai")
47
-
48
- def debug_print(message: str):
49
- print(f"[{datetime.datetime.now().isoformat()}] {message}", flush=True)
50
-
51
- def word_count(text: str) -> int:
52
- return len(text.split())
53
-
54
- # Initialize a tokenizer for token counting (using gpt2 as a generic fallback)
55
- def initialize_tokenizer():
56
- try:
57
- return AutoTokenizer.from_pretrained("gpt2")
58
- except Exception as e:
59
- debug_print("Failed to initialize tokenizer: " + str(e))
60
- return None
61
-
62
- global_tokenizer = initialize_tokenizer()
63
-
64
- def count_tokens(text: str) -> int:
65
- if global_tokenizer:
66
- try:
67
- return len(global_tokenizer.encode(text))
68
- except Exception as e:
69
- return len(text.split())
70
- return len(text.split())
71
-
72
-
73
- # Add these imports at the top of your file
74
- import uuid
75
- import threading
76
- import queue
77
- from typing import Dict, Any, Tuple, Optional
78
- import time
79
-
80
- # Global storage for jobs and results
81
- jobs = {} # Stores job status and results
82
- results_queue = queue.Queue() # Thread-safe queue for completed jobs
83
- processing_lock = threading.Lock() # Prevent simultaneous processing of the same job
84
-
85
- # Add these missing async processing functions
86
-
87
- def process_in_background(job_id, function, args):
88
- """Process a function in the background and store results"""
89
- try:
90
- debug_print(f"Processing job {job_id} in background")
91
- result = function(*args)
92
- results_queue.put((job_id, result))
93
- debug_print(f"Job {job_id} completed and added to results queue")
94
- except Exception as e:
95
- debug_print(f"Error in background job {job_id}: {str(e)}")
96
- error_result = (f"Error processing job: {str(e)}", "", "", "")
97
- results_queue.put((job_id, error_result))
98
-
99
- def load_pdfs_async(file_links, model_choice, prompt_template, bm25_weight, temperature, top_p):
100
- """Asynchronous version of load_pdfs_updated to prevent timeouts"""
101
- if not file_links:
102
- return "Please enter non-empty URLs", "", "Model used: N/A"
103
-
104
- job_id = str(uuid.uuid4())
105
- debug_print(f"Starting async job {job_id} for file loading")
106
-
107
- # Start background thread
108
- threading.Thread(
109
- target=process_in_background,
110
- args=(job_id, load_pdfs_updated, [file_links, model_choice, prompt_template, bm25_weight, temperature, top_p])
111
- ).start()
112
-
113
- jobs[job_id] = {
114
- "status": "processing",
115
- "type": "load_files",
116
- "start_time": time.time(),
117
- "query": f"Loading files: {file_links.split()[0]}..." if file_links else "No files"
118
- }
119
-
120
- return (
121
- f"Files submitted and processing in the background (Job ID: {job_id}).\n\n"
122
- f"Use 'Check Job Status' tab with this ID to get results.",
123
- f"Job ID: {job_id}",
124
- f"Model requested: {model_choice}"
125
- )
126
-
127
- def submit_query_async(query, model_choice=None):
128
- """Asynchronous version of submit_query_updated to prevent timeouts"""
129
- if not query:
130
- return "Please enter a non-empty query", "", "Input tokens: 0", "Output tokens: 0"
131
-
132
- job_id = str(uuid.uuid4())
133
- debug_print(f"Starting async job {job_id} for query: {query}")
134
-
135
- # Update model if specified
136
- if model_choice and rag_chain and rag_chain.llm_choice != model_choice:
137
- debug_print(f"Updating model to {model_choice} for this query")
138
- rag_chain.update_llm_pipeline(model_choice, rag_chain.temperature, rag_chain.top_p,
139
- rag_chain.prompt_template, rag_chain.bm25_weight)
140
-
141
- # Start background thread
142
- threading.Thread(
143
- target=process_in_background,
144
- args=(job_id, submit_query_updated, [query])
145
- ).start()
146
-
147
- jobs[job_id] = {
148
- "status": "processing",
149
- "type": "query",
150
- "start_time": time.time(),
151
- "query": query,
152
- "model": rag_chain.llm_choice if hasattr(rag_chain, 'llm_choice') else "Unknown"
153
- }
154
-
155
- return (
156
- f"Query submitted and processing in the background (Job ID: {job_id}).\n\n"
157
- f"Use 'Check Job Status' tab with this ID to get results.",
158
- f"Job ID: {job_id}",
159
- f"Input tokens: {count_tokens(query)}",
160
- "Output tokens: pending"
161
- )
162
-
163
- # Function to display all jobs as a clickable list
164
- def get_job_list():
165
- job_list_md = "### Submitted Jobs\n\n"
166
-
167
- if not jobs:
168
- return "No jobs found. Submit a query or load files to create jobs."
169
-
170
- # Sort jobs by start time (newest first)
171
- sorted_jobs = sorted(
172
- [(job_id, job_info) for job_id, job_info in jobs.items()],
173
- key=lambda x: x[1].get("start_time", 0),
174
- reverse=True
175
- )
176
-
177
- for job_id, job_info in sorted_jobs:
178
- status = job_info.get("status", "unknown")
179
- job_type = job_info.get("type", "unknown")
180
- query = job_info.get("query", "")
181
- start_time = job_info.get("start_time", 0)
182
- time_str = datetime.datetime.fromtimestamp(start_time).strftime("%Y-%m-%d %H:%M:%S")
183
-
184
- # Create a shortened query preview
185
- query_preview = query[:30] + "..." if query and len(query) > 30 else query or "N/A"
186
-
187
- # Create clickable links using Markdown
188
- if job_type == "query":
189
- job_list_md += f"- [{job_id}](javascript:void) - {time_str} - {status} - Query: {query_preview}\n"
190
- else:
191
- job_list_md += f"- [{job_id}](javascript:void) - {time_str} - {status} - File Load Job\n"
192
-
193
- return job_list_md
194
-
195
- # Function to handle job list clicks
196
- def job_selected(job_id):
197
- if job_id in jobs:
198
- return job_id, jobs[job_id].get("query", "No query for this job")
199
- return job_id, "Job not found"
200
-
201
- # Function to refresh the job list
202
- def refresh_job_list():
203
- return get_job_list()
204
-
205
- # Function to sync model dropdown boxes
206
- def sync_model_dropdown(value):
207
- return value
208
-
209
- # Function to check job status
210
- def check_job_status(job_id):
211
- if not job_id:
212
- return "Please enter a job ID", "", "", "", ""
213
-
214
- # Process any completed jobs in the queue
215
- try:
216
- while not results_queue.empty():
217
- completed_id, result = results_queue.get_nowait()
218
- if completed_id in jobs:
219
- jobs[completed_id]["status"] = "completed"
220
- jobs[completed_id]["result"] = result
221
- jobs[completed_id]["end_time"] = time.time()
222
- debug_print(f"Job {completed_id} completed and stored in jobs dictionary")
223
- except queue.Empty:
224
- pass
225
-
226
- # Check if the requested job exists
227
- if job_id not in jobs:
228
- return "Job not found. Please check the ID and try again.", "", "", "", ""
229
-
230
- job = jobs[job_id]
231
- job_query = job.get("query", "No query available for this job")
232
-
233
- # If job is still processing
234
- if job["status"] == "processing":
235
- elapsed_time = time.time() - job["start_time"]
236
- job_type = job.get("type", "unknown")
237
-
238
- if job_type == "load_files":
239
- return (
240
- f"Files are still being processed (elapsed: {elapsed_time:.1f}s).\n\n"
241
- f"Try checking again in a few seconds.",
242
- f"Job ID: {job_id}",
243
- f"Status: Processing",
244
- "",
245
- job_query
246
- )
247
- else: # query job
248
- return (
249
- f"Query is still being processed (elapsed: {elapsed_time:.1f}s).\n\n"
250
- f"Try checking again in a few seconds.",
251
- f"Job ID: {job_id}",
252
- f"Input tokens: {count_tokens(job.get('query', ''))}",
253
- "Output tokens: pending",
254
- job_query
255
- )
256
-
257
- # If job is completed
258
- if job["status"] == "completed":
259
- result = job["result"]
260
- processing_time = job["end_time"] - job["start_time"]
261
-
262
- if job.get("type") == "load_files":
263
- return (
264
- f"{result[0]}\n\nProcessing time: {processing_time:.1f}s",
265
- result[1],
266
- result[2],
267
- "",
268
- job_query
269
- )
270
- else: # query job
271
- return (
272
- f"{result[0]}\n\nProcessing time: {processing_time:.1f}s",
273
- result[1],
274
- result[2],
275
- result[3],
276
- job_query
277
- )
278
-
279
- # Fallback for unknown status
280
- return f"Job status: {job['status']}", "", "", "", job_query
281
-
282
- # Function to clean up old jobs
283
- def cleanup_old_jobs():
284
- current_time = time.time()
285
- to_delete = []
286
-
287
- for job_id, job in jobs.items():
288
- # Keep completed jobs for 1 hour, processing jobs for 2 hours
289
- if job["status"] == "completed" and (current_time - job.get("end_time", 0)) > 3600:
290
- to_delete.append(job_id)
291
- elif job["status"] == "processing" and (current_time - job.get("start_time", 0)) > 7200:
292
- to_delete.append(job_id)
293
-
294
- for job_id in to_delete:
295
- del jobs[job_id]
296
-
297
- debug_print(f"Cleaned up {len(to_delete)} old jobs. {len(jobs)} jobs remaining.")
298
- return f"Cleaned up {len(to_delete)} old jobs", "", ""
299
-
300
- # Improve the truncate_prompt function to be more aggressive with limiting context
301
- def truncate_prompt(prompt: str, max_tokens: int = 4096) -> str:
302
- """Truncate prompt to fit within token limit, preserving the most recent/relevant parts."""
303
- if not prompt:
304
- return ""
305
-
306
- if global_tokenizer:
307
- try:
308
- tokens = global_tokenizer.encode(prompt)
309
- if len(tokens) > max_tokens:
310
- # For prompts, we often want to keep the beginning instructions and the end context
311
- # So we'll keep the first 20% and the last 80% of the max tokens
312
- beginning_tokens = int(max_tokens * 0.2)
313
- ending_tokens = max_tokens - beginning_tokens
314
-
315
- new_tokens = tokens[:beginning_tokens] + tokens[-(ending_tokens):]
316
- return global_tokenizer.decode(new_tokens)
317
- except Exception as e:
318
- debug_print(f"Truncation error: {str(e)}")
319
-
320
- # Fallback to word-based truncation
321
- words = prompt.split()
322
- if len(words) > max_tokens:
323
- beginning_words = int(max_tokens * 0.2)
324
- ending_words = max_tokens - beginning_words
325
-
326
- return " ".join(words[:beginning_words] + words[-(ending_words):])
327
-
328
- return prompt
329
-
330
-
331
-
332
-
333
- default_prompt = """\
334
- {conversation_history}
335
- Use the following context to provide a detailed technical answer to the user's question.
336
- Do not include an introduction like "Based on the provided documents, ...". Just answer the question.
337
- If you don't know the answer, please respond with "I don't know".
338
-
339
- Context:
340
- {context}
341
-
342
- User's question:
343
- {question}
344
- """
345
-
346
- def load_txt_from_url(url: str) -> Document:
347
- response = requests.get(url)
348
- if response.status_code == 200:
349
- text = response.text.strip()
350
- if not text:
351
- raise ValueError(f"TXT file at {url} is empty.")
352
- return Document(page_content=text, metadata={"source": url})
353
- else:
354
- raise Exception(f"Failed to load {url} with status {response.status_code}")
355
-
356
- class ElevatedRagChain:
357
- def __init__(self, llm_choice: str = "Meta-Llama-3", prompt_template: str = default_prompt,
358
- bm25_weight: float = 0.6, temperature: float = 0.5, top_p: float = 0.95) -> None:
359
- debug_print(f"Initializing ElevatedRagChain with model: {llm_choice}")
360
- self.embed_func = HuggingFaceEmbeddings(
361
- model_name="sentence-transformers/all-MiniLM-L6-v2",
362
- model_kwargs={"device": "cpu"}
363
- )
364
- self.bm25_weight = bm25_weight
365
- self.faiss_weight = 1.0 - bm25_weight
366
- self.top_k = 5
367
- self.llm_choice = llm_choice
368
- self.temperature = temperature
369
- self.top_p = top_p
370
- self.prompt_template = prompt_template
371
- self.context = ""
372
- self.conversation_history: List[Dict[str, str]] = []
373
- self.raw_data = None
374
- self.split_data = None
375
- self.elevated_rag_chain = None
376
-
377
- # Instance method to capture context and conversation history
378
- def capture_context(self, result):
379
- self.context = "\n".join([str(doc) for doc in result["context"]])
380
- result["context"] = self.context
381
- history_text = (
382
- "\n".join([f"Q: {conv['query']}\nA: {conv['response']}" for conv in self.conversation_history])
383
- if self.conversation_history else ""
384
- )
385
- result["conversation_history"] = history_text
386
- return result
387
-
388
- # Instance method to extract question from input data
389
- def extract_question(self, input_data):
390
- return input_data["question"]
391
-
392
- # Improve error handling in the ElevatedRagChain class
393
- def create_llm_pipeline(self):
394
- from langchain.llms.base import LLM # Import LLM here so it's always defined
395
- normalized = self.llm_choice.lower()
396
- try:
397
- if "remote" in normalized:
398
- debug_print("Creating remote Meta-Llama-3 pipeline via Hugging Face Inference API...")
399
- from huggingface_hub import InferenceClient
400
- repo_id = "meta-llama/Meta-Llama-3-8B-Instruct"
401
- hf_api_token = os.environ.get("HF_API_TOKEN")
402
- if not hf_api_token:
403
- raise ValueError("Please set the HF_API_TOKEN environment variable to use remote inference.")
404
-
405
- client = InferenceClient(token=hf_api_token, timeout=120)
406
-
407
- # We no longer use wait_for_model because it's unsupported
408
- def remote_generate(prompt: str) -> str:
409
- max_retries = 3
410
- backoff = 2 # start with 2 seconds
411
- for attempt in range(max_retries):
412
- try:
413
- debug_print(f"Remote generation attempt {attempt+1}")
414
- response = client.text_generation(
415
- prompt,
416
- model=repo_id,
417
- temperature=self.temperature,
418
- top_p=self.top_p,
419
- max_new_tokens=512 # Reduced token count for speed
420
- )
421
- return response
422
- except Exception as e:
423
- debug_print(f"Attempt {attempt+1} failed with error: {e}")
424
- if attempt == max_retries - 1:
425
- raise
426
- time.sleep(backoff)
427
- backoff *= 2 # exponential backoff
428
- return "Failed to generate response after multiple attempts."
429
-
430
- class RemoteLLM(LLM):
431
- @property
432
- def _llm_type(self) -> str:
433
- return "remote_llm"
434
-
435
- def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
436
- return remote_generate(prompt)
437
-
438
- @property
439
- def _identifying_params(self) -> dict:
440
- return {"model": repo_id}
441
-
442
- debug_print("Remote Meta-Llama-3 pipeline created successfully.")
443
- return RemoteLLM()
444
-
445
- elif "mistral-api" in normalized:
446
- debug_print("Creating Mistral API pipeline...")
447
- mistral_api_key = os.environ.get("MISTRAL_API_KEY")
448
- if not mistral_api_key:
449
- raise ValueError("Please set the MISTRAL_API_KEY environment variable to use Mistral API.")
450
- try:
451
- from mistralai import Mistral
452
- debug_print("Mistral library imported successfully")
453
- except ImportError:
454
- debug_print("Mistral client library not installed. Falling back to Llama pipeline.")
455
- normalized = "llama"
456
- if normalized != "llama":
457
- # from pydantic import PrivateAttr
458
- # from langchain.llms.base import LLM
459
- # from typing import Any, Optional, List
460
- # import typing
461
-
462
- class MistralLLM(LLM):
463
- temperature: float = 0.7
464
- top_p: float = 0.95
465
- _client: Any = PrivateAttr(default=None)
466
-
467
- def __init__(self, api_key: str, temperature: float = 0.7, top_p: float = 0.95, **kwargs: Any):
468
- try:
469
- super().__init__(**kwargs)
470
- # Bypass Pydantic's __setattr__ to assign to _client
471
- object.__setattr__(self, '_client', Mistral(api_key=api_key))
472
- self.temperature = temperature
473
- self.top_p = top_p
474
- except Exception as e:
475
- debug_print(f"Init Mistral failed with error: {e}")
476
-
477
- @property
478
- def _llm_type(self) -> str:
479
- return "mistral_llm"
480
- def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
481
- try:
482
- debug_print("Calling Mistral API...")
483
- response = self._client.chat.complete(
484
- model="mistral-small-latest",
485
- messages=[{"role": "user", "content": prompt}],
486
- temperature=self.temperature,
487
- top_p=self.top_p
488
- )
489
- return response.choices[0].message.content
490
- except Exception as e:
491
- debug_print(f"Mistral API error: {str(e)}")
492
- return f"Error generating response: {str(e)}"
493
- @property
494
- def _identifying_params(self) -> dict:
495
- return {"model": "mistral-small-latest"}
496
- debug_print("Creating Mistral LLM instance")
497
- mistral_llm = MistralLLM(api_key=mistral_api_key, temperature=self.temperature, top_p=self.top_p)
498
- debug_print("Mistral API pipeline created successfully.")
499
- return mistral_llm
500
-
501
- else:
502
- # Default case - using a fallback model (or Llama)
503
- debug_print("Using local/fallback model pipeline")
504
- model_id = "facebook/opt-350m" # Use a smaller model as fallback
505
- pipe = pipeline(
506
- "text-generation",
507
- model=model_id,
508
- device=-1, # CPU
509
- max_length=1024
510
- )
511
-
512
- class LocalLLM(LLM):
513
- @property
514
- def _llm_type(self) -> str:
515
- return "local_llm"
516
- def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
517
- # For this fallback, truncate prompt if it exceeds limits
518
- reserved_gen = 128
519
- max_total = 1024
520
- max_prompt_tokens = max_total - reserved_gen
521
- truncated_prompt = truncate_prompt(prompt, max_tokens=max_prompt_tokens)
522
- generated = pipe(truncated_prompt, max_new_tokens=reserved_gen)[0]["generated_text"]
523
- return generated
524
- @property
525
- def _identifying_params(self) -> dict:
526
- return {"model": model_id, "max_length": 1024}
527
-
528
- debug_print("Local fallback pipeline created.")
529
- return LocalLLM()
530
-
531
- except Exception as e:
532
- debug_print(f"Error creating LLM pipeline: {str(e)}")
533
- # Return a dummy LLM that explains the error
534
- class ErrorLLM(LLM):
535
- @property
536
- def _llm_type(self) -> str:
537
- return "error_llm"
538
-
539
- def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
540
- return f"Error initializing LLM: \n\nPlease check your environment variables and try again."
541
-
542
- @property
543
- def _identifying_params(self) -> dict:
544
- return {"model": "error"}
545
-
546
- return ErrorLLM()
547
-
548
-
549
- def update_llm_pipeline(self, new_model_choice: str, temperature: float, top_p: float, prompt_template: str, bm25_weight: float):
550
- debug_print(f"Updating chain with new model: {new_model_choice}")
551
- self.llm_choice = new_model_choice
552
- self.temperature = temperature
553
- self.top_p = top_p
554
- self.prompt_template = prompt_template
555
- self.bm25_weight = bm25_weight
556
- self.faiss_weight = 1.0 - bm25_weight
557
- self.llm = self.create_llm_pipeline()
558
- def format_response(response: str) -> str:
559
- input_tokens = count_tokens(self.context + self.prompt_template)
560
- output_tokens = count_tokens(response)
561
- formatted = f"### Response\n\n{response}\n\n---\n"
562
- formatted += f"- **Input tokens:** {input_tokens}\n"
563
- formatted += f"- **Output tokens:** {output_tokens}\n"
564
- formatted += f"- **Generated using:** {self.llm_choice}\n"
565
- formatted += f"\n**Conversation History:** {len(self.conversation_history)} conversation(s) considered.\n"
566
- return formatted
567
- base_runnable = RunnableParallel({
568
- "context": RunnableLambda(self.extract_question) | self.ensemble_retriever,
569
- "question": RunnableLambda(self.extract_question)
570
- }) | self.capture_context
571
- self.elevated_rag_chain = base_runnable | self.rag_prompt | self.llm | format_response
572
- debug_print("Chain updated successfully with new LLM pipeline.")
573
-
574
- def add_pdfs_to_vectore_store(self, file_links: List[str]) -> None:
575
- debug_print(f"Processing files using {self.llm_choice}")
576
- self.raw_data = []
577
- for link in file_links:
578
- if link.lower().endswith(".pdf"):
579
- debug_print(f"Loading PDF: {link}")
580
- loaded_docs = OnlinePDFLoader(link).load()
581
- if loaded_docs:
582
- self.raw_data.append(loaded_docs[0])
583
- else:
584
- debug_print(f"No content found in PDF: {link}")
585
- elif link.lower().endswith(".txt") or link.lower().endswith(".utf-8"):
586
- debug_print(f"Loading TXT: {link}")
587
- try:
588
- self.raw_data.append(load_txt_from_url(link))
589
- except Exception as e:
590
- debug_print(f"Error loading TXT file {link}: {e}")
591
- else:
592
- debug_print(f"File type not supported for URL: {link}")
593
- if not self.raw_data:
594
- raise ValueError("No files were successfully loaded. Please check the URLs and file formats.")
595
- debug_print("Files loaded successfully.")
596
- debug_print("Starting text splitting...")
597
- self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=100)
598
- self.split_data = self.text_splitter.split_documents(self.raw_data)
599
- if not self.split_data:
600
- raise ValueError("Text splitting resulted in no chunks. Check the file contents.")
601
- debug_print(f"Text splitting completed. Number of chunks: {len(self.split_data)}")
602
- debug_print("Creating BM25 retriever...")
603
- self.bm25_retriever = BM25Retriever.from_documents(self.split_data)
604
- self.bm25_retriever.k = self.top_k
605
- debug_print("BM25 retriever created.")
606
- debug_print("Embedding chunks and creating FAISS vector store...")
607
- self.vector_store = FAISS.from_documents(self.split_data, self.embed_func)
608
- self.faiss_retriever = self.vector_store.as_retriever(search_kwargs={"k": self.top_k})
609
- debug_print("FAISS vector store created successfully.")
610
- self.ensemble_retriever = EnsembleRetriever(
611
- retrievers=[self.bm25_retriever, self.faiss_retriever],
612
- weights=[self.bm25_weight, self.faiss_weight]
613
- )
614
-
615
- base_runnable = RunnableParallel({
616
- "context": RunnableLambda(self.extract_question) | self.ensemble_retriever,
617
- "question": RunnableLambda(self.extract_question)
618
- }) | self.capture_context
619
-
620
- # Ensure the prompt template is set
621
- self.rag_prompt = ChatPromptTemplate.from_template(self.prompt_template)
622
- if self.rag_prompt is None:
623
- raise ValueError("Prompt template could not be created from the given template.")
624
- prompt_runnable = RunnableLambda(lambda vars: self.rag_prompt.format(**vars))
625
-
626
- self.str_output_parser = StrOutputParser()
627
- debug_print("Selecting LLM pipeline based on choice: " + self.llm_choice)
628
- self.llm = self.create_llm_pipeline()
629
- if self.llm is None:
630
- raise ValueError("LLM pipeline creation failed.")
631
-
632
- def format_response(response: str) -> str:
633
- input_tokens = count_tokens(self.context + self.prompt_template)
634
- output_tokens = count_tokens(response)
635
- formatted = f"### Response\n\n{response}\n\n---\n"
636
- formatted += f"- **Input tokens:** {input_tokens}\n"
637
- formatted += f"- **Output tokens:** {output_tokens}\n"
638
- formatted += f"- **Generated using:** {self.llm_choice}\n"
639
- formatted += f"\n**Conversation History:** {len(self.conversation_history)} conversation(s) considered.\n"
640
- return formatted
641
-
642
- self.elevated_rag_chain = base_runnable | prompt_runnable | self.llm | format_response
643
- debug_print("Elevated RAG chain successfully built and ready to use.")
644
-
645
-
646
-
647
- def get_current_context(self) -> str:
648
- base_context = "\n".join([str(doc) for doc in self.split_data[:3]]) if self.split_data else "No context available."
649
- history_summary = "\n\n---\n**Recent Conversations (last 3):**\n"
650
- recent = self.conversation_history[-3:]
651
- if recent:
652
- for i, conv in enumerate(recent, 1):
653
- history_summary += f"**Conversation {i}:**\n- Query: {conv['query']}\n- Response: {conv['response']}\n"
654
- else:
655
- history_summary += "No conversation history."
656
- return base_context + history_summary
657
-
658
- # ----------------------------
659
- # Gradio Interface Functions
660
- # ----------------------------
661
- global rag_chain
662
- rag_chain = ElevatedRagChain()
663
-
664
- def load_pdfs_updated(file_links, model_choice, prompt_template, bm25_weight, temperature, top_p):
665
- debug_print("Inside load_pdfs function.")
666
- if not file_links:
667
- debug_print("Please enter non-empty URLs")
668
- return "Please enter non-empty URLs", "Word count: N/A", "Model used: N/A", "Context: N/A"
669
- try:
670
- links = [link.strip() for link in file_links.split("\n") if link.strip()]
671
- global rag_chain
672
- if rag_chain.raw_data:
673
- rag_chain.update_llm_pipeline(model_choice, temperature, top_p, prompt_template, bm25_weight)
674
- context_display = rag_chain.get_current_context()
675
- response_msg = f"Files already loaded. Chain updated with model: {model_choice}"
676
- return (
677
- response_msg,
678
- f"Word count: {word_count(rag_chain.context)}",
679
- f"Model used: {rag_chain.llm_choice}",
680
- f"Context:\n{context_display}"
681
- )
682
- else:
683
- rag_chain = ElevatedRagChain(
684
- llm_choice=model_choice,
685
- prompt_template=prompt_template,
686
- bm25_weight=bm25_weight,
687
- temperature=temperature,
688
- top_p=top_p
689
- )
690
- rag_chain.add_pdfs_to_vectore_store(links)
691
- context_display = rag_chain.get_current_context()
692
- response_msg = f"Files loaded successfully. Using model: {model_choice}"
693
- return (
694
- response_msg,
695
- f"Word count: {word_count(rag_chain.context)}",
696
- f"Model used: {rag_chain.llm_choice}",
697
- f"Context:\n{context_display}"
698
- )
699
- except Exception as e:
700
- error_msg = traceback.format_exc()
701
- debug_print("Could not load files. Error: " + error_msg)
702
- return (
703
- "Error loading files: " + str(e),
704
- f"Word count: {word_count('')}",
705
- f"Model used: {rag_chain.llm_choice}",
706
- "Context: N/A"
707
- )
708
-
709
- def update_model(new_model: str):
710
- global rag_chain
711
- if rag_chain and rag_chain.raw_data:
712
- rag_chain.update_llm_pipeline(new_model, rag_chain.temperature, rag_chain.top_p,
713
- rag_chain.prompt_template, rag_chain.bm25_weight)
714
- debug_print(f"Model updated to {rag_chain.llm_choice}")
715
- return f"Model updated to: {rag_chain.llm_choice}"
716
- else:
717
- return "No files loaded; please load files first."
718
-
719
-
720
- # Update submit_query_updated to better handle context limitation
721
- def submit_query_updated(query):
722
- debug_print(f"Processing query: {query}")
723
- if not query:
724
- debug_print("Empty query received")
725
- return "Please enter a non-empty query", "", "Input tokens: 0", "Output tokens: 0"
726
-
727
- if not hasattr(rag_chain, 'elevated_rag_chain') or not rag_chain.raw_data:
728
- debug_print("RAG chain not initialized")
729
- return "Please load files first.", "", "Input tokens: 0", "Output tokens: 0"
730
-
731
- try:
732
- # Determine max context size based on model
733
- model_name = rag_chain.llm_choice.lower()
734
- max_context_tokens = 32000 if "mistral" in model_name else 4096
735
-
736
- # Reserve 20% of tokens for the question and response generation
737
- reserved_tokens = int(max_context_tokens * 0.2)
738
- max_context_tokens -= reserved_tokens
739
-
740
- # Collect conversation history (last 2 only to save tokens)
741
- if rag_chain.conversation_history:
742
- recent_history = rag_chain.conversation_history[-2:]
743
- history_text = "\n".join([f"Q: {conv['query']}\nA: {conv['response'][:300]}..."
744
- for conv in recent_history])
745
- else:
746
- history_text = ""
747
-
748
- # Get history token count
749
- history_tokens = count_tokens(history_text)
750
-
751
- # Adjust context tokens based on history size
752
- context_tokens = max_context_tokens - history_tokens
753
-
754
- # Ensure we have some minimum context
755
- context_tokens = max(context_tokens, 1000)
756
-
757
- # Truncate context if needed
758
- context = truncate_prompt(rag_chain.context, max_tokens=context_tokens)
759
-
760
- debug_print(f"Using model: {model_name}, context tokens: {count_tokens(context)}, history tokens: {history_tokens}")
761
-
762
- prompt_variables = {
763
- "conversation_history": history_text,
764
- "context": context,
765
- "question": query
766
- }
767
-
768
- debug_print("Invoking RAG chain")
769
- response = rag_chain.elevated_rag_chain.invoke({"question": query})
770
-
771
- # Store only a reasonable amount of the response in history
772
- trimmed_response = response[:1000] + ("..." if len(response) > 1000 else "")
773
- rag_chain.conversation_history.append({"query": query, "response": trimmed_response})
774
-
775
- input_token_count = count_tokens(query)
776
- output_token_count = count_tokens(response)
777
-
778
- debug_print(f"Query processed successfully. Output tokens: {output_token_count}")
779
-
780
- return (
781
- response,
782
- rag_chain.get_current_context(),
783
- f"Input tokens: {input_token_count}",
784
- f"Output tokens: {output_token_count}"
785
- )
786
- except Exception as e:
787
- error_msg = traceback.format_exc()
788
- debug_print(f"LLM error: {error_msg}")
789
- return (
790
- f"Query error: {str(e)}\n\nTry using a smaller document or simplifying your query.",
791
- "",
792
- "Input tokens: 0",
793
- "Output tokens: 0"
794
- )
795
-
796
- def reset_app_updated():
797
- global rag_chain
798
- rag_chain = ElevatedRagChain()
799
- debug_print("App reset successfully.")
800
- return (
801
- "App reset successfully. You can now load new files",
802
- "",
803
- "Model used: Not selected"
804
- )
805
-
806
- # ----------------------------
807
- # Gradio Interface Setup
808
- # ----------------------------
809
- custom_css = """
810
- textarea {
811
- overflow-y: scroll !important;
812
- max-height: 200px;
813
- }
814
- """
815
-
816
- # Update the Gradio interface to include job status checking
817
- with gr.Blocks(css=custom_css, js="""
818
- document.addEventListener('DOMContentLoaded', function() {
819
- // Add event listener for job list clicks
820
- const jobListInterval = setInterval(() => {
821
- const jobLinks = document.querySelectorAll('.job-list-container a');
822
- if (jobLinks.length > 0) {
823
- jobLinks.forEach(link => {
824
- link.addEventListener('click', function(e) {
825
- e.preventDefault();
826
- const jobId = this.textContent.split(' ')[0];
827
- // Find the job ID input textbox and set its value
828
- const jobIdInput = document.querySelector('.job-id-input input');
829
- if (jobIdInput) {
830
- jobIdInput.value = jobId;
831
- // Trigger the input event to update Gradio's state
832
- jobIdInput.dispatchEvent(new Event('input', { bubbles: true }));
833
- }
834
- });
835
- });
836
- clearInterval(jobListInterval);
837
- }
838
- }, 500);
839
- });
840
- """) as app:
841
- gr.Markdown('''# PhiRAG - Async Version
842
- **PhiRAG** Query Your Data with Advanced RAG Techniques
843
-
844
- **Model Selection & Parameters:** Choose from the following options:
845
- - 🇺🇸 Remote Meta-Llama-3 - has context windows of 8000 tokens
846
- - 🇪🇺 Mistral-API - has context windows of 32000 tokens
847
-
848
- **🔥 Randomness (Temperature):** Adjusts output predictability.
849
- - Example: 0.2 makes the output very deterministic (less creative), while 0.8 introduces more variety and spontaneity.
850
-
851
- **🎯 Word Variety (Top‑p):** Limits word choices to a set probability percentage.
852
- - Example: 0.5 restricts output to the most likely 50% of token choices for a focused answer; 0.95 allows almost all possibilities for more diverse responses.
853
-
854
- **⚖️ BM25 Weight:** Adjust Lexical vs Semantics.
855
- - Example: A value of 0.8 puts more emphasis on exact keyword (lexical) matching, while 0.3 shifts emphasis toward semantic similarity.
856
-
857
- **✏️ Prompt Template:** Edit as desired.
858
-
859
- **🔗 File URLs:** Enter one URL per line (.pdf or .txt).\
860
- - Example: Provide one URL per line, such as
861
- https://www.gutenberg.org/ebooks/8438.txt.utf-8
862
-
863
- **🔍 Query:** Enter your query below.
864
-
865
- **⚠️ IMPORTANT: This app now uses asynchronous processing to avoid timeout issues**
866
- - When you load files or submit a query, you'll receive a Job ID
867
- - Use the "Check Job Status" tab to monitor and retrieve your results
868
- ''')
869
-
870
- with gr.Tabs() as tabs:
871
- with gr.TabItem("Setup & Load Files"):
872
- with gr.Row():
873
- with gr.Column():
874
- model_dropdown = gr.Dropdown(
875
- choices=["🇺🇸 Remote Meta-Llama-3", "🇪🇺 Mistral-API"],
876
- value="🇺🇸 Remote Meta-Llama-3",
877
- label="Select Model"
878
- )
879
- temperature_slider = gr.Slider(
880
- minimum=0.1, maximum=1.0, value=0.5, step=0.1,
881
- label="Randomness (Temperature)"
882
- )
883
- top_p_slider = gr.Slider(
884
- minimum=0.1, maximum=0.99, value=0.95, step=0.05,
885
- label="Word Variety (Top-p)"
886
- )
887
- with gr.Column():
888
- pdf_input = gr.Textbox(
889
- label="Enter your file URLs (one per line)",
890
- placeholder="Enter one URL per line (.pdf or .txt)",
891
- lines=4
892
- )
893
- prompt_input = gr.Textbox(
894
- label="Custom Prompt Template",
895
- placeholder="Enter your custom prompt template here",
896
- lines=8,
897
- value=default_prompt
898
- )
899
- with gr.Column():
900
- bm25_weight_slider = gr.Slider(
901
- minimum=0.0, maximum=1.0, value=0.6, step=0.1,
902
- label="Lexical vs Semantics (BM25 Weight)"
903
- )
904
- load_button = gr.Button("Load Files (Async)")
905
- load_status = gr.Markdown("Status: Waiting for files")
906
-
907
- with gr.Row():
908
- load_response = gr.Textbox(
909
- label="Load Response",
910
- placeholder="Response will appear here",
911
- lines=4
912
- )
913
- load_context = gr.Textbox(
914
- label="Context Info",
915
- placeholder="Context info will appear here",
916
- lines=4
917
- )
918
-
919
- with gr.Row():
920
- model_output = gr.Markdown("**Current Model**: Not selected")
921
-
922
- with gr.TabItem("Submit Query"):
923
- with gr.Row():
924
- # Add this line to define the query_model_dropdown
925
- query_model_dropdown = gr.Dropdown(
926
- choices=["🇺🇸 Remote Meta-Llama-3", "🇪🇺 Mistral-API"],
927
- value="🇺🇸 Remote Meta-Llama-3",
928
- label="Query Model"
929
- )
930
-
931
- query_input = gr.Textbox(
932
- label="Enter your query here",
933
- placeholder="Type your query",
934
- lines=4
935
- )
936
- submit_button = gr.Button("Submit Query (Async)")
937
-
938
- with gr.Row():
939
- query_response = gr.Textbox(
940
- label="Query Response",
941
- placeholder="Response will appear here (formatted as Markdown)",
942
- lines=6
943
- )
944
- query_context = gr.Textbox(
945
- label="Context Information",
946
- placeholder="Retrieved context and conversation history will appear here",
947
- lines=6
948
- )
949
-
950
- with gr.Row():
951
- input_tokens = gr.Markdown("Input tokens: 0")
952
- output_tokens = gr.Markdown("Output tokens: 0")
953
-
954
- with gr.TabItem("Check Job Status"):
955
- with gr.Row():
956
- with gr.Column(scale=1):
957
- job_list = gr.Markdown(
958
- value="No jobs yet",
959
- label="Job List (Click to select)"
960
- )
961
- refresh_button = gr.Button("Refresh Job List")
962
-
963
- with gr.Column(scale=2):
964
- job_id_input = gr.Textbox(
965
- label="Job ID",
966
- placeholder="Job ID will appear here when selected from the list",
967
- lines=1
968
- )
969
- job_query_display = gr.Textbox(
970
- label="Job Query",
971
- placeholder="The query associated with this job will appear here",
972
- lines=2,
973
- interactive=False
974
- )
975
- check_button = gr.Button("Check Status")
976
- cleanup_button = gr.Button("Cleanup Old Jobs")
977
-
978
- with gr.Row():
979
- status_response = gr.Textbox(
980
- label="Job Result",
981
- placeholder="Job result will appear here",
982
- lines=6
983
- )
984
- status_context = gr.Textbox(
985
- label="Context Information",
986
- placeholder="Context information will appear here",
987
- lines=6
988
- )
989
-
990
- with gr.Row():
991
- status_tokens1 = gr.Markdown("")
992
- status_tokens2 = gr.Markdown("")
993
-
994
- with gr.TabItem("App Management"):
995
- with gr.Row():
996
- reset_button = gr.Button("Reset App")
997
-
998
- with gr.Row():
999
- reset_response = gr.Textbox(
1000
- label="Reset Response",
1001
- placeholder="Reset confirmation will appear here",
1002
- lines=2
1003
- )
1004
- reset_context = gr.Textbox(
1005
- label="",
1006
- placeholder="",
1007
- lines=2,
1008
- visible=False
1009
- )
1010
-
1011
- with gr.Row():
1012
- reset_model = gr.Markdown("")
1013
-
1014
- # Connect the buttons to their respective functions
1015
- load_button.click(
1016
- load_pdfs_async,
1017
- inputs=[pdf_input, model_dropdown, prompt_input, bm25_weight_slider, temperature_slider, top_p_slider],
1018
- outputs=[load_response, load_context, model_output]
1019
- )
1020
-
1021
- # Also sync in the other direction
1022
- query_model_dropdown.change(
1023
- fn=sync_model_dropdown,
1024
- inputs=query_model_dropdown,
1025
- outputs=model_dropdown
1026
- )
1027
-
1028
- submit_button.click(
1029
- submit_query_async,
1030
- inputs=[query_input, query_model_dropdown],
1031
- outputs=[query_response, query_context, input_tokens, output_tokens]
1032
- )
1033
-
1034
- check_button.click(
1035
- check_job_status,
1036
- inputs=[job_id_input],
1037
- outputs=[status_response, status_context, status_tokens1, status_tokens2, job_query_display]
1038
- )
1039
-
1040
- refresh_button.click(
1041
- refresh_job_list,
1042
- inputs=[],
1043
- outputs=[job_list]
1044
- )
1045
-
1046
- # Connect the job list selection event (this is handled by JavaScript)
1047
- job_id_input.change(
1048
- job_selected,
1049
- inputs=[job_id_input],
1050
- outputs=[job_id_input, job_query_display]
1051
- )
1052
-
1053
- cleanup_button.click(
1054
- cleanup_old_jobs,
1055
- inputs=[],
1056
- outputs=[status_response, status_context, status_tokens1]
1057
- )
1058
-
1059
- reset_button.click(
1060
- reset_app_updated,
1061
- inputs=[],
1062
- outputs=[reset_response, reset_context, reset_model]
1063
- )
1064
-
1065
-
1066
- model_dropdown.change(
1067
- fn=sync_model_dropdown,
1068
- inputs=model_dropdown,
1069
- outputs=query_model_dropdown
1070
- )
1071
-
1072
- # Add an event to refresh the job list on page load
1073
- app.load(
1074
- fn=refresh_job_list,
1075
- inputs=None,
1076
- outputs=job_list
1077
- )
1078
-
1079
- if __name__ == "__main__":
1080
- debug_print("Launching Gradio interface.")
1081
- app.launch(share=False)
1
+ import os
2
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
3
+ import datetime
4
+ import functools
5
+ import traceback
6
+ from typing import List, Optional, Any, Dict
7
+
8
+ import torch
9
+ import transformers
10
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
11
+ from langchain_community.llms import HuggingFacePipeline
12
+
13
+ # Other LangChain and community imports
14
+ from langchain_community.document_loaders import OnlinePDFLoader
15
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
16
+ from langchain_community.vectorstores import FAISS
17
+ from langchain.embeddings import HuggingFaceEmbeddings
18
+ from langchain_community.retrievers import BM25Retriever
19
+ from langchain.retrievers import EnsembleRetriever
20
+ from langchain.prompts import ChatPromptTemplate
21
+ from langchain.schema import StrOutputParser, Document
22
+ from langchain_core.runnables import RunnableParallel, RunnableLambda
23
+ from transformers.quantizers.auto import AutoQuantizationConfig
24
+ import gradio as gr
25
+ import requests
26
+ from pydantic import PrivateAttr
27
+ import pydantic
28
+
29
+ from langchain.llms.base import LLM
30
+ from typing import Any, Optional, List
31
+ import typing
32
+ import time
33
+
34
+ print("Pydantic Version: ")
35
+ print(pydantic.__version__)
36
+ # Add Mistral imports with fallback handling
37
+
38
+ try:
39
+ from mistralai import Mistral
40
+ MISTRAL_AVAILABLE = True
41
+ debug_print = lambda msg: print(f"[{datetime.datetime.now().isoformat()}] {msg}")
42
+ debug_print("Loaded latest Mistral client library")
43
+ except ImportError:
44
+ MISTRAL_AVAILABLE = False
45
+ debug_print = lambda msg: print(f"[{datetime.datetime.now().isoformat()}] {msg}")
46
+ debug_print("Mistral client library not found. Install with: pip install mistralai")
47
+
48
+ def debug_print(message: str):
49
+ print(f"[{datetime.datetime.now().isoformat()}] {message}", flush=True)
50
+
51
+ def word_count(text: str) -> int:
52
+ return len(text.split())
53
+
54
+ # Initialize a tokenizer for token counting (using gpt2 as a generic fallback)
55
+ def initialize_tokenizer():
56
+ try:
57
+ return AutoTokenizer.from_pretrained("gpt2")
58
+ except Exception as e:
59
+ debug_print("Failed to initialize tokenizer: " + str(e))
60
+ return None
61
+
62
+ global_tokenizer = initialize_tokenizer()
63
+
64
+ def count_tokens(text: str) -> int:
65
+ if global_tokenizer:
66
+ try:
67
+ return len(global_tokenizer.encode(text))
68
+ except Exception as e:
69
+ return len(text.split())
70
+ return len(text.split())
71
+
72
+
73
+ # Add these imports at the top of your file
74
+ import uuid
75
+ import threading
76
+ import queue
77
+ from typing import Dict, Any, Tuple, Optional
78
+ import time
79
+
80
+ # Global storage for jobs and results
81
+ jobs = {} # Stores job status and results
82
+ results_queue = queue.Queue() # Thread-safe queue for completed jobs
83
+ processing_lock = threading.Lock() # Prevent simultaneous processing of the same job
84
+
85
+ # Add a global variable to store the last job ID
86
+ last_job_id = None
87
+
88
+ # Add these missing async processing functions
89
+
90
+ def process_in_background(job_id, function, args):
91
+ """Process a function in the background and store results"""
92
+ try:
93
+ debug_print(f"Processing job {job_id} in background")
94
+ result = function(*args)
95
+ results_queue.put((job_id, result))
96
+ debug_print(f"Job {job_id} completed and added to results queue")
97
+ except Exception as e:
98
+ debug_print(f"Error in background job {job_id}: {str(e)}")
99
+ error_result = (f"Error processing job: {str(e)}", "", "", "")
100
+ results_queue.put((job_id, error_result))
101
+
102
+ def load_pdfs_async(file_links, model_choice, prompt_template, bm25_weight, temperature, top_p):
103
+ """Asynchronous version of load_pdfs_updated to prevent timeouts"""
104
+ global last_job_id
105
+ if not file_links:
106
+ return "Please enter non-empty URLs", "", "Model used: N/A", "", "", get_job_list()
107
+
108
+ job_id = str(uuid.uuid4())
109
+ debug_print(f"Starting async job {job_id} for file loading")
110
+
111
+ # Start background thread
112
+ threading.Thread(
113
+ target=process_in_background,
114
+ args=(job_id, load_pdfs_updated, [file_links, model_choice, prompt_template, bm25_weight, temperature, top_p])
115
+ ).start()
116
+
117
+ job_query = f"Loading files: {file_links.split()[0]}..." if file_links else "No files"
118
+ jobs[job_id] = {
119
+ "status": "processing",
120
+ "type": "load_files",
121
+ "start_time": time.time(),
122
+ "query": job_query
123
+ }
124
+
125
+ last_job_id = job_id
126
+
127
+ return (
128
+ f"Files submitted and processing in the background (Job ID: {job_id}).\n\n"
129
+ f"Use 'Check Job Status' tab with this ID to get results.",
130
+ f"Job ID: {job_id}",
131
+ f"Model requested: {model_choice}",
132
+ job_id, # Return job_id to update the job_id_input component
133
+ job_query, # Return job_query to update the job_query_display component
134
+ get_job_list() # Return updated job list
135
+ )
136
+
137
+ def submit_query_async(query, model_choice=None):
138
+ """Asynchronous version of submit_query_updated to prevent timeouts"""
139
+ global last_job_id
140
+ if not query:
141
+ return "Please enter a non-empty query", "", "Input tokens: 0", "Output tokens: 0", "", "", get_job_list()
142
+
143
+ job_id = str(uuid.uuid4())
144
+ debug_print(f"Starting async job {job_id} for query: {query}")
145
+
146
+ # Update model if specified
147
+ if model_choice and rag_chain and rag_chain.llm_choice != model_choice:
148
+ debug_print(f"Updating model to {model_choice} for this query")
149
+ rag_chain.update_llm_pipeline(model_choice, rag_chain.temperature, rag_chain.top_p,
150
+ rag_chain.prompt_template, rag_chain.bm25_weight)
151
+
152
+ # Start background thread
153
+ threading.Thread(
154
+ target=process_in_background,
155
+ args=(job_id, submit_query_updated, [query])
156
+ ).start()
157
+
158
+ jobs[job_id] = {
159
+ "status": "processing",
160
+ "type": "query",
161
+ "start_time": time.time(),
162
+ "query": query,
163
+ "model": rag_chain.llm_choice if hasattr(rag_chain, 'llm_choice') else "Unknown"
164
+ }
165
+
166
+ last_job_id = job_id
167
+
168
+ return (
169
+ f"Query submitted and processing in the background (Job ID: {job_id}).\n\n"
170
+ f"Use 'Check Job Status' tab with this ID to get results.",
171
+ f"Job ID: {job_id}",
172
+ f"Input tokens: {count_tokens(query)}",
173
+ "Output tokens: pending",
174
+ job_id, # Return job_id to update the job_id_input component
175
+ query, # Return query to update the job_query_display component
176
+ get_job_list() # Return updated job list
177
+ )
178
+
179
+ def update_ui_with_last_job_id():
180
+ # This function doesn't need to do anything anymore
181
+ # We'll update the UI directly in the functions that call this
182
+ pass
183
+
184
+ # Function to display all jobs as a clickable list
185
+ def get_job_list():
186
+ job_list_md = "### Submitted Jobs\n\n"
187
+
188
+ if not jobs:
189
+ return "No jobs found. Submit a query or load files to create jobs."
190
+
191
+ # Sort jobs by start time (newest first)
192
+ sorted_jobs = sorted(
193
+ [(job_id, job_info) for job_id, job_info in jobs.items()],
194
+ key=lambda x: x[1].get("start_time", 0),
195
+ reverse=True
196
+ )
197
+
198
+ for job_id, job_info in sorted_jobs:
199
+ status = job_info.get("status", "unknown")
200
+ job_type = job_info.get("type", "unknown")
201
+ query = job_info.get("query", "")
202
+ start_time = job_info.get("start_time", 0)
203
+ time_str = datetime.datetime.fromtimestamp(start_time).strftime("%Y-%m-%d %H:%M:%S")
204
+
205
+ # Create a shortened query preview
206
+ query_preview = query[:30] + "..." if query and len(query) > 30 else query or "N/A"
207
+
208
+ # Create clickable links using Markdown
209
+ if job_type == "query":
210
+ job_list_md += f"- [{job_id}](javascript:void) - {time_str} - {status} - Query: {query_preview}\n"
211
+ else:
212
+ job_list_md += f"- [{job_id}](javascript:void) - {time_str} - {status} - File Load Job\n"
213
+
214
+ return job_list_md
215
+
216
+ # Function to handle job list clicks
217
+ def job_selected(job_id):
218
+ if job_id in jobs:
219
+ return job_id, jobs[job_id].get("query", "No query for this job")
220
+ return job_id, "Job not found"
221
+
222
+ # Function to refresh the job list
223
+ def refresh_job_list():
224
+ return get_job_list()
225
+
226
+ # Function to sync model dropdown boxes
227
+ def sync_model_dropdown(value):
228
+ return value
229
+
230
+ # Function to check job status
231
+ def check_job_status(job_id):
232
+ if not job_id:
233
+ return "Please enter a job ID", "", "", "", ""
234
+
235
+ # Process any completed jobs in the queue
236
+ try:
237
+ while not results_queue.empty():
238
+ completed_id, result = results_queue.get_nowait()
239
+ if completed_id in jobs:
240
+ jobs[completed_id]["status"] = "completed"
241
+ jobs[completed_id]["result"] = result
242
+ jobs[completed_id]["end_time"] = time.time()
243
+ debug_print(f"Job {completed_id} completed and stored in jobs dictionary")
244
+ except queue.Empty:
245
+ pass
246
+
247
+ # Check if the requested job exists
248
+ if job_id not in jobs:
249
+ return "Job not found. Please check the ID and try again.", "", "", "", ""
250
+
251
+ job = jobs[job_id]
252
+ job_query = job.get("query", "No query available for this job")
253
+
254
+ # If job is still processing
255
+ if job["status"] == "processing":
256
+ elapsed_time = time.time() - job["start_time"]
257
+ job_type = job.get("type", "unknown")
258
+
259
+ if job_type == "load_files":
260
+ return (
261
+ f"Files are still being processed (elapsed: {elapsed_time:.1f}s).\n\n"
262
+ f"Try checking again in a few seconds.",
263
+ f"Job ID: {job_id}",
264
+ f"Status: Processing",
265
+ "",
266
+ job_query
267
+ )
268
+ else: # query job
269
+ return (
270
+ f"Query is still being processed (elapsed: {elapsed_time:.1f}s).\n\n"
271
+ f"Try checking again in a few seconds.",
272
+ f"Job ID: {job_id}",
273
+ f"Input tokens: {count_tokens(job.get('query', ''))}",
274
+ "Output tokens: pending",
275
+ job_query
276
+ )
277
+
278
+ # If job is completed
279
+ if job["status"] == "completed":
280
+ result = job["result"]
281
+ processing_time = job["end_time"] - job["start_time"]
282
+
283
+ if job.get("type") == "load_files":
284
+ return (
285
+ f"{result[0]}\n\nProcessing time: {processing_time:.1f}s",
286
+ result[1],
287
+ result[2],
288
+ "",
289
+ job_query
290
+ )
291
+ else: # query job
292
+ return (
293
+ f"{result[0]}\n\nProcessing time: {processing_time:.1f}s",
294
+ result[1],
295
+ result[2],
296
+ result[3],
297
+ job_query
298
+ )
299
+
300
+ # Fallback for unknown status
301
+ return f"Job status: {job['status']}", "", "", "", job_query
302
+
303
+ # Function to clean up old jobs
304
+ def cleanup_old_jobs():
305
+ current_time = time.time()
306
+ to_delete = []
307
+
308
+ for job_id, job in jobs.items():
309
+ # Keep completed jobs for 1 hour, processing jobs for 2 hours
310
+ if job["status"] == "completed" and (current_time - job.get("end_time", 0)) > 3600:
311
+ to_delete.append(job_id)
312
+ elif job["status"] == "processing" and (current_time - job.get("start_time", 0)) > 7200:
313
+ to_delete.append(job_id)
314
+
315
+ for job_id in to_delete:
316
+ del jobs[job_id]
317
+
318
+ debug_print(f"Cleaned up {len(to_delete)} old jobs. {len(jobs)} jobs remaining.")
319
+ return f"Cleaned up {len(to_delete)} old jobs", "", ""
320
+
321
+ # Improve the truncate_prompt function to be more aggressive with limiting context
322
+ def truncate_prompt(prompt: str, max_tokens: int = 4096) -> str:
323
+ """Truncate prompt to fit within token limit, preserving the most recent/relevant parts."""
324
+ if not prompt:
325
+ return ""
326
+
327
+ if global_tokenizer:
328
+ try:
329
+ tokens = global_tokenizer.encode(prompt)
330
+ if len(tokens) > max_tokens:
331
+ # For prompts, we often want to keep the beginning instructions and the end context
332
+ # So we'll keep the first 20% and the last 80% of the max tokens
333
+ beginning_tokens = int(max_tokens * 0.2)
334
+ ending_tokens = max_tokens - beginning_tokens
335
+
336
+ new_tokens = tokens[:beginning_tokens] + tokens[-(ending_tokens):]
337
+ return global_tokenizer.decode(new_tokens)
338
+ except Exception as e:
339
+ debug_print(f"Truncation error: {str(e)}")
340
+
341
+ # Fallback to word-based truncation
342
+ words = prompt.split()
343
+ if len(words) > max_tokens:
344
+ beginning_words = int(max_tokens * 0.2)
345
+ ending_words = max_tokens - beginning_words
346
+
347
+ return " ".join(words[:beginning_words] + words[-(ending_words):])
348
+
349
+ return prompt
350
+
351
+
352
+
353
+
354
+ default_prompt = """\
355
+ {conversation_history}
356
+ Use the following context to provide a detailed technical answer to the user's question.
357
+ Do not include an introduction like "Based on the provided documents, ...". Just answer the question.
358
+ If you don't know the answer, please respond with "I don't know".
359
+
360
+ Context:
361
+ {context}
362
+
363
+ User's question:
364
+ {question}
365
+ """
366
+
367
+ def load_txt_from_url(url: str) -> Document:
368
+ response = requests.get(url)
369
+ if response.status_code == 200:
370
+ text = response.text.strip()
371
+ if not text:
372
+ raise ValueError(f"TXT file at {url} is empty.")
373
+ return Document(page_content=text, metadata={"source": url})
374
+ else:
375
+ raise Exception(f"Failed to load {url} with status {response.status_code}")
376
+
377
+ class ElevatedRagChain:
378
+ def __init__(self, llm_choice: str = "Meta-Llama-3", prompt_template: str = default_prompt,
379
+ bm25_weight: float = 0.6, temperature: float = 0.5, top_p: float = 0.95) -> None:
380
+ debug_print(f"Initializing ElevatedRagChain with model: {llm_choice}")
381
+ self.embed_func = HuggingFaceEmbeddings(
382
+ model_name="sentence-transformers/all-MiniLM-L6-v2",
383
+ model_kwargs={"device": "cpu"}
384
+ )
385
+ self.bm25_weight = bm25_weight
386
+ self.faiss_weight = 1.0 - bm25_weight
387
+ self.top_k = 5
388
+ self.llm_choice = llm_choice
389
+ self.temperature = temperature
390
+ self.top_p = top_p
391
+ self.prompt_template = prompt_template
392
+ self.context = ""
393
+ self.conversation_history: List[Dict[str, str]] = []
394
+ self.raw_data = None
395
+ self.split_data = None
396
+ self.elevated_rag_chain = None
397
+
398
+ # Instance method to capture context and conversation history
399
+ def capture_context(self, result):
400
+ self.context = "\n".join([str(doc) for doc in result["context"]])
401
+ result["context"] = self.context
402
+ history_text = (
403
+ "\n".join([f"Q: {conv['query']}\nA: {conv['response']}" for conv in self.conversation_history])
404
+ if self.conversation_history else ""
405
+ )
406
+ result["conversation_history"] = history_text
407
+ return result
408
+
409
+ # Instance method to extract question from input data
410
+ def extract_question(self, input_data):
411
+ return input_data["question"]
412
+
413
+ # Improve error handling in the ElevatedRagChain class
414
+ def create_llm_pipeline(self):
415
+ from langchain.llms.base import LLM # Import LLM here so it's always defined
416
+ normalized = self.llm_choice.lower()
417
+ try:
418
+ if "remote" in normalized:
419
+ debug_print("Creating remote Meta-Llama-3 pipeline via Hugging Face Inference API...")
420
+ from huggingface_hub import InferenceClient
421
+ repo_id = "meta-llama/Meta-Llama-3-8B-Instruct"
422
+ hf_api_token = os.environ.get("HF_API_TOKEN")
423
+ if not hf_api_token:
424
+ raise ValueError("Please set the HF_API_TOKEN environment variable to use remote inference.")
425
+
426
+ client = InferenceClient(token=hf_api_token, timeout=120)
427
+
428
+ # We no longer use wait_for_model because it's unsupported
429
+ def remote_generate(prompt: str) -> str:
430
+ max_retries = 3
431
+ backoff = 2 # start with 2 seconds
432
+ for attempt in range(max_retries):
433
+ try:
434
+ debug_print(f"Remote generation attempt {attempt+1}")
435
+ response = client.text_generation(
436
+ prompt,
437
+ model=repo_id,
438
+ temperature=self.temperature,
439
+ top_p=self.top_p,
440
+ max_new_tokens=512 # Reduced token count for speed
441
+ )
442
+ return response
443
+ except Exception as e:
444
+ debug_print(f"Attempt {attempt+1} failed with error: {e}")
445
+ if attempt == max_retries - 1:
446
+ raise
447
+ time.sleep(backoff)
448
+ backoff *= 2 # exponential backoff
449
+ return "Failed to generate response after multiple attempts."
450
+
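+ # remote_generate retries the Inference API up to 3 times with exponential backoff
+ # (2s, then 4s between attempts) and re-raises the last error on the final attempt;
+ # the RemoteLLM wrapper below adapts it to LangChain's LLM interface.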
451
+ class RemoteLLM(LLM):
452
+ @property
453
+ def _llm_type(self) -> str:
454
+ return "remote_llm"
455
+
456
+ def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
457
+ return remote_generate(prompt)
458
+
459
+ @property
460
+ def _identifying_params(self) -> dict:
461
+ return {"model": repo_id}
462
+
463
+ debug_print("Remote Meta-Llama-3 pipeline created successfully.")
464
+ return RemoteLLM()
465
+
466
+ elif "mistral-api" in normalized:
467
+ debug_print("Creating Mistral API pipeline...")
468
+ mistral_api_key = os.environ.get("MISTRAL_API_KEY")
469
+ if not mistral_api_key:
470
+ raise ValueError("Please set the MISTRAL_API_KEY environment variable to use Mistral API.")
471
+ try:
472
+ from mistralai import Mistral
473
+ debug_print("Mistral library imported successfully")
474
+ except ImportError:
475
+ debug_print("Mistral client library not installed. Falling back to Llama pipeline.")
476
+ normalized = "llama"
477
+ if normalized != "llama":
478
+ # from pydantic import PrivateAttr
479
+ # from langchain.llms.base import LLM
480
+ # from typing import Any, Optional, List
481
+ # import typing
482
+
483
+ class MistralLLM(LLM):
484
+ temperature: float = 0.7
485
+ top_p: float = 0.95
486
+ _client: Any = PrivateAttr(default=None)
487
+
488
+ def __init__(self, api_key: str, temperature: float = 0.7, top_p: float = 0.95, **kwargs: Any):
489
+ try:
490
+ super().__init__(**kwargs)
491
+ # Bypass Pydantic's __setattr__ to assign to _client
492
+ object.__setattr__(self, '_client', Mistral(api_key=api_key))
493
+ self.temperature = temperature
494
+ self.top_p = top_p
495
+ except Exception as e:
496
+ debug_print(f"Init Mistral failed with error: {e}")
497
+
498
+ @property
499
+ def _llm_type(self) -> str:
500
+ return "mistral_llm"
501
+ def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
502
+ try:
503
+ debug_print("Calling Mistral API...")
504
+ response = self._client.chat.complete(
505
+ model="mistral-small-latest",
506
+ messages=[{"role": "user", "content": prompt}],
507
+ temperature=self.temperature,
508
+ top_p=self.top_p
509
+ )
510
+ return response.choices[0].message.content
511
+ except Exception as e:
512
+ debug_print(f"Mistral API error: {str(e)}")
513
+ return f"Error generating response: {str(e)}"
514
+ @property
515
+ def _identifying_params(self) -> dict:
516
+ return {"model": "mistral-small-latest"}
517
+ debug_print("Creating Mistral LLM instance")
518
+ mistral_llm = MistralLLM(api_key=mistral_api_key, temperature=self.temperature, top_p=self.top_p)
519
+ debug_print("Mistral API pipeline created successfully.")
520
+ return mistral_llm
521
+
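+ # Implementation note: MistralLLM keeps the Mistral client in a pydantic PrivateAttr
+ # and assigns it via object.__setattr__, so the (non-pydantic) client object stays
+ # out of the model's validated fields.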
522
+ else:
523
+ # Default case: build a small local fallback pipeline (facebook/opt-350m on CPU)
524
+ debug_print("Using local/fallback model pipeline")
525
+ model_id = "facebook/opt-350m" # Use a smaller model as fallback
526
+ pipe = pipeline(
527
+ "text-generation",
528
+ model=model_id,
529
+ device=-1, # CPU
530
+ max_length=1024
531
+ )
532
+
533
+ class LocalLLM(LLM):
534
+ @property
535
+ def _llm_type(self) -> str:
536
+ return "local_llm"
537
+ def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
538
+ # For this fallback, truncate prompt if it exceeds limits
539
+ reserved_gen = 128
540
+ max_total = 1024
541
+ max_prompt_tokens = max_total - reserved_gen
542
+ truncated_prompt = truncate_prompt(prompt, max_tokens=max_prompt_tokens)
543
+ generated = pipe(truncated_prompt, max_new_tokens=reserved_gen)[0]["generated_text"]
544
+ return generated
545
+ @property
546
+ def _identifying_params(self) -> dict:
547
+ return {"model": model_id, "max_length": 1024}
548
+
549
+ debug_print("Local fallback pipeline created.")
550
+ return LocalLLM()
551
+
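+ # Budget for the local fallback above: with max_total = 1024 and reserved_gen = 128,
+ # the prompt is truncated to at most 896 tokens so prompt plus generated text stays
+ # within the pipeline's 1024-token max_length.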
552
+ except Exception as e:
553
+ debug_print(f"Error creating LLM pipeline: {str(e)}")
554
+ # Return a dummy LLM that explains the error
555
+ class ErrorLLM(LLM):
556
+ @property
557
+ def _llm_type(self) -> str:
558
+ return "error_llm"
559
+
560
+ def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
561
+ return f"Error initializing LLM: \n\nPlease check your environment variables and try again."
562
+
563
+ @property
564
+ def _identifying_params(self) -> dict:
565
+ return {"model": "error"}
566
+
567
+ return ErrorLLM()
568
+
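+ # Summary of create_llm_pipeline: "remote" choices use the Hugging Face Inference API,
+ # "mistral-api" choices use the Mistral chat API, anything else falls back to a small
+ # local facebook/opt-350m pipeline, and any setup failure returns an ErrorLLM sentinel
+ # whose responses explain that initialization failed.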
569
+
570
+ def update_llm_pipeline(self, new_model_choice: str, temperature: float, top_p: float, prompt_template: str, bm25_weight: float):
571
+ debug_print(f"Updating chain with new model: {new_model_choice}")
572
+ self.llm_choice = new_model_choice
573
+ self.temperature = temperature
574
+ self.top_p = top_p
575
+ self.prompt_template = prompt_template
576
+ self.bm25_weight = bm25_weight
577
+ self.faiss_weight = 1.0 - bm25_weight
578
+ self.llm = self.create_llm_pipeline()
579
+ def format_response(response: str) -> str:
580
+ input_tokens = count_tokens(self.context + self.prompt_template)
581
+ output_tokens = count_tokens(response)
582
+ formatted = f"### Response\n\n{response}\n\n---\n"
583
+ formatted += f"- **Input tokens:** {input_tokens}\n"
584
+ formatted += f"- **Output tokens:** {output_tokens}\n"
585
+ formatted += f"- **Generated using:** {self.llm_choice}\n"
586
+ formatted += f"\n**Conversation History:** {len(self.conversation_history)} conversation(s) considered.\n"
587
+ return formatted
588
+ base_runnable = RunnableParallel({
589
+ "context": RunnableLambda(self.extract_question) | self.ensemble_retriever,
590
+ "question": RunnableLambda(self.extract_question)
591
+ }) | self.capture_context
592
+ self.elevated_rag_chain = base_runnable | self.rag_prompt | self.llm | format_response
593
+ debug_print("Chain updated successfully with new LLM pipeline.")
594
+
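+ # Chain shape used in update_llm_pipeline and add_pdfs_to_vectore_store: a
+ # RunnableParallel first fans the incoming question out to the ensemble retriever
+ # ("context") and passes the question through unchanged ("question"); capture_context
+ # then flattens the retrieved documents into a string before the prompt, the LLM and
+ # format_response run in sequence.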
595
+ def add_pdfs_to_vectore_store(self, file_links: List[str]) -> None:
596
+ debug_print(f"Processing files using {self.llm_choice}")
597
+ self.raw_data = []
598
+ for link in file_links:
599
+ if link.lower().endswith(".pdf"):
600
+ debug_print(f"Loading PDF: {link}")
601
+ loaded_docs = OnlinePDFLoader(link).load()
602
+ if loaded_docs:
603
+ self.raw_data.append(loaded_docs[0])
604
+ else:
605
+ debug_print(f"No content found in PDF: {link}")
606
+ elif link.lower().endswith(".txt") or link.lower().endswith(".utf-8"):
607
+ debug_print(f"Loading TXT: {link}")
608
+ try:
609
+ self.raw_data.append(load_txt_from_url(link))
610
+ except Exception as e:
611
+ debug_print(f"Error loading TXT file {link}: {e}")
612
+ else:
613
+ debug_print(f"File type not supported for URL: {link}")
614
+ if not self.raw_data:
615
+ raise ValueError("No files were successfully loaded. Please check the URLs and file formats.")
616
+ debug_print("Files loaded successfully.")
617
+ debug_print("Starting text splitting...")
618
+ self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=100)
619
+ self.split_data = self.text_splitter.split_documents(self.raw_data)
620
+ if not self.split_data:
621
+ raise ValueError("Text splitting resulted in no chunks. Check the file contents.")
622
+ debug_print(f"Text splitting completed. Number of chunks: {len(self.split_data)}")
623
+ debug_print("Creating BM25 retriever...")
624
+ self.bm25_retriever = BM25Retriever.from_documents(self.split_data)
625
+ self.bm25_retriever.k = self.top_k
626
+ debug_print("BM25 retriever created.")
627
+ debug_print("Embedding chunks and creating FAISS vector store...")
628
+ self.vector_store = FAISS.from_documents(self.split_data, self.embed_func)
629
+ self.faiss_retriever = self.vector_store.as_retriever(search_kwargs={"k": self.top_k})
630
+ debug_print("FAISS vector store created successfully.")
631
+ self.ensemble_retriever = EnsembleRetriever(
632
+ retrievers=[self.bm25_retriever, self.faiss_retriever],
633
+ weights=[self.bm25_weight, self.faiss_weight]
634
+ )
635
+
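+ # Hybrid retrieval: BM25 (exact keyword matching) and FAISS (dense embeddings from
+ # all-MiniLM-L6-v2) each return their top_k chunks, and the EnsembleRetriever merges
+ # them with weights [bm25_weight, 1 - bm25_weight] as configured in the UI slider.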
636
+ base_runnable = RunnableParallel({
637
+ "context": RunnableLambda(self.extract_question) | self.ensemble_retriever,
638
+ "question": RunnableLambda(self.extract_question)
639
+ }) | self.capture_context
640
+
641
+ # Ensure the prompt template is set
642
+ self.rag_prompt = ChatPromptTemplate.from_template(self.prompt_template)
643
+ if self.rag_prompt is None:
644
+ raise ValueError("Prompt template could not be created from the given template.")
645
+ prompt_runnable = RunnableLambda(lambda vars: self.rag_prompt.format(**vars))
646
+
647
+ self.str_output_parser = StrOutputParser()
648
+ debug_print("Selecting LLM pipeline based on choice: " + self.llm_choice)
649
+ self.llm = self.create_llm_pipeline()
650
+ if self.llm is None:
651
+ raise ValueError("LLM pipeline creation failed.")
652
+
653
+ def format_response(response: str) -> str:
654
+ input_tokens = count_tokens(self.context + self.prompt_template)
655
+ output_tokens = count_tokens(response)
656
+ formatted = f"### Response\n\n{response}\n\n---\n"
657
+ formatted += f"- **Input tokens:** {input_tokens}\n"
658
+ formatted += f"- **Output tokens:** {output_tokens}\n"
659
+ formatted += f"- **Generated using:** {self.llm_choice}\n"
660
+ formatted += f"\n**Conversation History:** {len(self.conversation_history)} conversation(s) considered.\n"
661
+ return formatted
662
+
663
+ self.elevated_rag_chain = base_runnable | prompt_runnable | self.llm | format_response
664
+ debug_print("Elevated RAG chain successfully built and ready to use.")
665
+
666
+
667
+
668
+ def get_current_context(self) -> str:
669
+ base_context = "\n".join([str(doc) for doc in self.split_data[:3]]) if self.split_data else "No context available."
670
+ history_summary = "\n\n---\n**Recent Conversations (last 3):**\n"
671
+ recent = self.conversation_history[-3:]
672
+ if recent:
673
+ for i, conv in enumerate(recent, 1):
674
+ history_summary += f"**Conversation {i}:**\n- Query: {conv['query']}\n- Response: {conv['response']}\n"
675
+ else:
676
+ history_summary += "No conversation history."
677
+ return base_context + history_summary
678
+
679
+ # ----------------------------
680
+ # Gradio Interface Functions
681
+ # ----------------------------
682
+ global rag_chain
683
+ rag_chain = ElevatedRagChain()
684
+
685
+ def load_pdfs_updated(file_links, model_choice, prompt_template, bm25_weight, temperature, top_p):
686
+ debug_print("Inside load_pdfs function.")
687
+ if not file_links:
688
+ debug_print("Please enter non-empty URLs")
689
+ return "Please enter non-empty URLs", "Word count: N/A", "Model used: N/A", "Context: N/A"
690
+ try:
691
+ links = [link.strip() for link in file_links.split("\n") if link.strip()]
692
+ global rag_chain
693
+ if rag_chain.raw_data:
694
+ rag_chain.update_llm_pipeline(model_choice, temperature, top_p, prompt_template, bm25_weight)
695
+ context_display = rag_chain.get_current_context()
696
+ response_msg = f"Files already loaded. Chain updated with model: {model_choice}"
697
+ return (
698
+ response_msg,
699
+ f"Word count: {word_count(rag_chain.context)}",
700
+ f"Model used: {rag_chain.llm_choice}",
701
+ f"Context:\n{context_display}"
702
+ )
703
+ else:
704
+ rag_chain = ElevatedRagChain(
705
+ llm_choice=model_choice,
706
+ prompt_template=prompt_template,
707
+ bm25_weight=bm25_weight,
708
+ temperature=temperature,
709
+ top_p=top_p
710
+ )
711
+ rag_chain.add_pdfs_to_vectore_store(links)
712
+ context_display = rag_chain.get_current_context()
713
+ response_msg = f"Files loaded successfully. Using model: {model_choice}"
714
+ return (
715
+ response_msg,
716
+ f"Word count: {word_count(rag_chain.context)}",
717
+ f"Model used: {rag_chain.llm_choice}",
718
+ f"Context:\n{context_display}"
719
+ )
720
+ except Exception as e:
721
+ error_msg = traceback.format_exc()
722
+ debug_print("Could not load files. Error: " + error_msg)
723
+ return (
724
+ "Error loading files: " + str(e),
725
+ f"Word count: {word_count('')}",
726
+ f"Model used: {rag_chain.llm_choice}",
727
+ "Context: N/A"
728
+ )
729
+
730
+ def update_model(new_model: str):
731
+ global rag_chain
732
+ if rag_chain and rag_chain.raw_data:
733
+ rag_chain.update_llm_pipeline(new_model, rag_chain.temperature, rag_chain.top_p,
734
+ rag_chain.prompt_template, rag_chain.bm25_weight)
735
+ debug_print(f"Model updated to {rag_chain.llm_choice}")
736
+ return f"Model updated to: {rag_chain.llm_choice}"
737
+ else:
738
+ return "No files loaded; please load files first."
739
+
740
+
741
+ # Update submit_query_updated to better handle context limitation
742
+ def submit_query_updated(query):
743
+ debug_print(f"Processing query: {query}")
744
+ if not query:
745
+ debug_print("Empty query received")
746
+ return "Please enter a non-empty query", "", "Input tokens: 0", "Output tokens: 0"
747
+
748
+ if getattr(rag_chain, 'elevated_rag_chain', None) is None or not rag_chain.raw_data:
749
+ debug_print("RAG chain not initialized")
750
+ return "Please load files first.", "", "Input tokens: 0", "Output tokens: 0"
751
+
752
+ try:
753
+ # Determine max context size based on model
754
+ model_name = rag_chain.llm_choice.lower()
755
+ max_context_tokens = 32000 if "mistral" in model_name else 4096
756
+
757
+ # Reserve 20% of tokens for the question and response generation
758
+ reserved_tokens = int(max_context_tokens * 0.2)
759
+ max_context_tokens -= reserved_tokens
760
+
761
+ # Collect conversation history (last 2 only to save tokens)
762
+ if rag_chain.conversation_history:
763
+ recent_history = rag_chain.conversation_history[-2:]
764
+ history_text = "\n".join([f"Q: {conv['query']}\nA: {conv['response'][:300]}..."
765
+ for conv in recent_history])
766
+ else:
767
+ history_text = ""
768
+
769
+ # Get history token count
770
+ history_tokens = count_tokens(history_text)
771
+
772
+ # Adjust context tokens based on history size
773
+ context_tokens = max_context_tokens - history_tokens
774
+
775
+ # Ensure we have some minimum context
776
+ context_tokens = max(context_tokens, 1000)
777
+
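+ # Example of the budget above (illustrative): for a Mistral model the cap is 32000
+ # tokens; 20% (6400) is reserved for the question and the generated answer, leaving
+ # 25600 for context, from which the conversation-history tokens are subtracted, with
+ # a floor of 1000 context tokens. Non-Mistral models use a conservative 4096 cap.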
778
+ # Truncate context if needed
779
+ context = truncate_prompt(rag_chain.context, max_tokens=context_tokens)
780
+
781
+ debug_print(f"Using model: {model_name}, context tokens: {count_tokens(context)}, history tokens: {history_tokens}")
782
+
783
+ prompt_variables = {
784
+ "conversation_history": history_text,
785
+ "context": context,
786
+ "question": query
787
+ }
788
+
789
+ debug_print("Invoking RAG chain")
790
+ response = rag_chain.elevated_rag_chain.invoke({"question": query})
791
+
792
+ # Store only a reasonable amount of the response in history
793
+ trimmed_response = response[:1000] + ("..." if len(response) > 1000 else "")
794
+ rag_chain.conversation_history.append({"query": query, "response": trimmed_response})
795
+
796
+ input_token_count = count_tokens(query)
797
+ output_token_count = count_tokens(response)
798
+
799
+ debug_print(f"Query processed successfully. Output tokens: {output_token_count}")
800
+
801
+ return (
802
+ response,
803
+ rag_chain.get_current_context(),
804
+ f"Input tokens: {input_token_count}",
805
+ f"Output tokens: {output_token_count}"
806
+ )
807
+ except Exception as e:
808
+ error_msg = traceback.format_exc()
809
+ debug_print(f"LLM error: {error_msg}")
810
+ return (
811
+ f"Query error: {str(e)}\n\nTry using a smaller document or simplifying your query.",
812
+ "",
813
+ "Input tokens: 0",
814
+ "Output tokens: 0"
815
+ )
816
+
817
+ def reset_app_updated():
818
+ global rag_chain
819
+ rag_chain = ElevatedRagChain()
820
+ debug_print("App reset successfully.")
821
+ return (
822
+ "App reset successfully. You can now load new files",
823
+ "",
824
+ "Model used: Not selected"
825
+ )
826
+
827
+ # ----------------------------
828
+ # Gradio Interface Setup
829
+ # ----------------------------
830
+ custom_css = """
831
+ textarea {
832
+ overflow-y: scroll !important;
833
+ max-height: 200px;
834
+ }
835
+ """
836
+
837
+ # Update the Gradio interface to include job status checking
838
+ with gr.Blocks(css=custom_css, js="""
839
+ document.addEventListener('DOMContentLoaded', function() {
840
+ // Add event listener for job list clicks
841
+ const jobListInterval = setInterval(() => {
842
+ const jobLinks = document.querySelectorAll('.job-list-container a');
843
+ if (jobLinks.length > 0) {
844
+ jobLinks.forEach(link => {
845
+ link.addEventListener('click', function(e) {
846
+ e.preventDefault();
847
+ const jobId = this.textContent.split(' ')[0];
848
+ // Find the job ID input textbox and set its value
849
+ const jobIdInput = document.querySelector('.job-id-input input');
850
+ if (jobIdInput) {
851
+ jobIdInput.value = jobId;
852
+ // Trigger the input event to update Gradio's state
853
+ jobIdInput.dispatchEvent(new Event('input', { bubbles: true }));
854
+ }
855
+ });
856
+ });
857
+ clearInterval(jobListInterval);
858
+ }
859
+ }, 500);
860
+ });
861
+ """) as app:
862
+ gr.Markdown('''# PhiRAG - Async Version
863
+ **PhiRAG** lets you query your own documents with advanced RAG techniques.
864
+
865
+ **Model Selection & Parameters:** Choose from the following options:
866
+ - 🇺🇸 Remote Meta-Llama-3 - context window of 8,000 tokens
867
+ - 🇪🇺 Mistral-API - context window of 32,000 tokens
868
+
869
+ **🔥 Randomness (Temperature):** Adjusts output predictability.
870
+ - Example: 0.2 makes the output very deterministic (less creative), while 0.8 introduces more variety and spontaneity.
871
+
872
+ **🎯 Word Variety (Top-p):** Limits sampling to the smallest set of tokens whose cumulative probability reaches p (nucleus sampling).
873
+ - Example: 0.5 restricts output to the most likely 50% of token choices for a focused answer; 0.95 allows almost all possibilities for more diverse responses.
874
+
875
+ **⚖️ BM25 Weight:** Balances lexical (keyword) matching against semantic similarity.
876
+ - Example: A value of 0.8 puts more emphasis on exact keyword (lexical) matching, while 0.3 shifts emphasis toward semantic similarity.
877
+
878
+ **✏️ Prompt Template:** Edit as desired.
879
+
880
+ **🔗 File URLs:** Enter one URL per line (.pdf or .txt).
881
+ - Example:
882
+ https://www.gutenberg.org/ebooks/8438.txt.utf-8
883
+
884
+ **🔍 Query:** Enter your query below.
885
+
886
+ **⚠️ IMPORTANT: This app now uses asynchronous processing to avoid timeout issues**
887
+ - When you load files or submit a query, you'll receive a Job ID
888
+ - Use the "Check Job Status" tab to monitor and retrieve your results
889
+ ''')
890
+
891
+ with gr.Tabs() as tabs:
892
+ with gr.TabItem("Setup & Load Files"):
893
+ with gr.Row():
894
+ with gr.Column():
895
+ model_dropdown = gr.Dropdown(
896
+ choices=["🇺🇸 Remote Meta-Llama-3", "🇪🇺 Mistral-API"],
897
+ value="🇺🇸 Remote Meta-Llama-3",
898
+ label="Select Model"
899
+ )
900
+ temperature_slider = gr.Slider(
901
+ minimum=0.1, maximum=1.0, value=0.5, step=0.1,
902
+ label="Randomness (Temperature)"
903
+ )
904
+ top_p_slider = gr.Slider(
905
+ minimum=0.1, maximum=0.99, value=0.95, step=0.05,
906
+ label="Word Variety (Top-p)"
907
+ )
908
+ with gr.Column():
909
+ pdf_input = gr.Textbox(
910
+ label="Enter your file URLs (one per line)",
911
+ placeholder="Enter one URL per line (.pdf or .txt)",
912
+ lines=4
913
+ )
914
+ prompt_input = gr.Textbox(
915
+ label="Custom Prompt Template",
916
+ placeholder="Enter your custom prompt template here",
917
+ lines=8,
918
+ value=default_prompt
919
+ )
920
+ with gr.Column():
921
+ bm25_weight_slider = gr.Slider(
922
+ minimum=0.0, maximum=1.0, value=0.6, step=0.1,
923
+ label="Lexical vs Semantics (BM25 Weight)"
924
+ )
925
+ load_button = gr.Button("Load Files (Async)")
926
+ load_status = gr.Markdown("Status: Waiting for files")
927
+
928
+ with gr.Row():
929
+ load_response = gr.Textbox(
930
+ label="Load Response",
931
+ placeholder="Response will appear here",
932
+ lines=4
933
+ )
934
+ load_context = gr.Textbox(
935
+ label="Context Info",
936
+ placeholder="Context info will appear here",
937
+ lines=4
938
+ )
939
+
940
+ with gr.Row():
941
+ model_output = gr.Markdown("**Current Model**: Not selected")
942
+
943
+ with gr.TabItem("Submit Query"):
944
+ with gr.Row():
945
+ # Add this line to define the query_model_dropdown
946
+ query_model_dropdown = gr.Dropdown(
947
+ choices=["🇺🇸 Remote Meta-Llama-3", "🇪🇺 Mistral-API"],
948
+ value="🇺🇸 Remote Meta-Llama-3",
949
+ label="Query Model"
950
+ )
951
+
952
+ query_input = gr.Textbox(
953
+ label="Enter your query here",
954
+ placeholder="Type your query",
955
+ lines=4
956
+ )
957
+ submit_button = gr.Button("Submit Query (Async)")
958
+
959
+ with gr.Row():
960
+ query_response = gr.Textbox(
961
+ label="Query Response",
962
+ placeholder="Response will appear here (formatted as Markdown)",
963
+ lines=6
964
+ )
965
+ query_context = gr.Textbox(
966
+ label="Context Information",
967
+ placeholder="Retrieved context and conversation history will appear here",
968
+ lines=6
969
+ )
970
+
971
+ with gr.Row():
972
+ input_tokens = gr.Markdown("Input tokens: 0")
973
+ output_tokens = gr.Markdown("Output tokens: 0")
974
+
975
+ with gr.TabItem("Check Job Status"):
976
+ with gr.Row():
977
+ with gr.Column(scale=1):
978
+ job_list = gr.Markdown(
979
+ value="No jobs yet",
980
+ label="Job List (Click to select)"
981
+ )
982
+ refresh_button = gr.Button("Refresh Job List")
983
+
984
+ with gr.Column(scale=2):
985
+ job_id_input = gr.Textbox(
986
+ label="Job ID",
987
+ placeholder="Job ID will appear here when selected from the list",
988
+ lines=1
989
+ )
990
+ job_query_display = gr.Textbox(
991
+ label="Job Query",
992
+ placeholder="The query associated with this job will appear here",
993
+ lines=2,
994
+ interactive=False
995
+ )
996
+ check_button = gr.Button("Check Status")
997
+ cleanup_button = gr.Button("Cleanup Old Jobs")
998
+
999
+ with gr.Row():
1000
+ status_response = gr.Textbox(
1001
+ label="Job Result",
1002
+ placeholder="Job result will appear here",
1003
+ lines=6
1004
+ )
1005
+ status_context = gr.Textbox(
1006
+ label="Context Information",
1007
+ placeholder="Context information will appear here",
1008
+ lines=6
1009
+ )
1010
+
1011
+ with gr.Row():
1012
+ status_tokens1 = gr.Markdown("")
1013
+ status_tokens2 = gr.Markdown("")
1014
+
1015
+ with gr.TabItem("App Management"):
1016
+ with gr.Row():
1017
+ reset_button = gr.Button("Reset App")
1018
+
1019
+ with gr.Row():
1020
+ reset_response = gr.Textbox(
1021
+ label="Reset Response",
1022
+ placeholder="Reset confirmation will appear here",
1023
+ lines=2
1024
+ )
1025
+ reset_context = gr.Textbox(
1026
+ label="",
1027
+ placeholder="",
1028
+ lines=2,
1029
+ visible=False
1030
+ )
1031
+
1032
+ with gr.Row():
1033
+ reset_model = gr.Markdown("")
1034
+
1035
+ # Connect the buttons to their respective functions
1036
+ load_button.click(
1037
+ load_pdfs_async,
1038
+ inputs=[pdf_input, model_dropdown, prompt_input, bm25_weight_slider, temperature_slider, top_p_slider],
1039
+ outputs=[load_response, load_context, model_output, job_id_input, job_query_display, job_list]
1040
+ )
1041
+
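+ # Note: the async handlers wired above (load_pdfs_async, submit_query_async) are
+ # expected to return a job ID immediately; results are then retrieved from the
+ # "Check Job Status" tab via check_job_status, as described in the header text.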
1042
+ # Keep the two model dropdowns in sync (query tab -> setup tab)
1043
+ query_model_dropdown.change(
1044
+ fn=sync_model_dropdown,
1045
+ inputs=query_model_dropdown,
1046
+ outputs=model_dropdown
1047
+ )
1048
+
1049
+ submit_button.click(
1050
+ submit_query_async,
1051
+ inputs=[query_input, query_model_dropdown],
1052
+ outputs=[query_response, query_context, input_tokens, output_tokens, job_id_input, job_query_display, job_list]
1053
+ )
1054
+
1055
+ check_button.click(
1056
+ check_job_status,
1057
+ inputs=[job_id_input],
1058
+ outputs=[status_response, status_context, status_tokens1, status_tokens2, job_query_display]
1059
+ )
1060
+
1061
+ refresh_button.click(
1062
+ refresh_job_list,
1063
+ inputs=[],
1064
+ outputs=[job_list]
1065
+ )
1066
+
1067
+ # Connect the job list selection event (this is handled by JavaScript)
1068
+ job_id_input.change(
1069
+ job_selected,
1070
+ inputs=[job_id_input],
1071
+ outputs=[job_id_input, job_query_display]
1072
+ )
1073
+
1074
+ cleanup_button.click(
1075
+ cleanup_old_jobs,
1076
+ inputs=[],
1077
+ outputs=[status_response, status_context, status_tokens1]
1078
+ )
1079
+
1080
+ reset_button.click(
1081
+ reset_app_updated,
1082
+ inputs=[],
1083
+ outputs=[reset_response, reset_context, reset_model]
1084
+ )
1085
+
1086
+
1087
+ model_dropdown.change(
1088
+ fn=sync_model_dropdown,
1089
+ inputs=model_dropdown,
1090
+ outputs=query_model_dropdown
1091
+ )
1092
+
1093
+ # Add an event to refresh the job list on page load
1094
+ app.load(
1095
+ fn=refresh_job_list,
1096
+ inputs=None,
1097
+ outputs=job_list
1098
+ )
1099
+
1100
+ if __name__ == "__main__":
1101
+ debug_print("Launching Gradio interface.")
1102
+ app.launch(share=False)