Spaces:
Runtime error
Runtime error
import logging
import os
import time
from typing import Optional

import torch
from dotenv import load_dotenv
from fastapi import Depends, FastAPI, HTTPException
from huggingface_hub import login
from pydantic import BaseModel
from transformers import AutoModel, AutoTokenizer

from services.openai_service import generate_rag_response
from services.qdrant_searcher import QdrantSearcher
from utils.auth import token_required
from utils.auth_x import x_api_key_auth
# Load environment variables from a local .env file (no-op if the file is absent).
load_dotenv()

# Configure root logging BEFORE the first logging call.  The module-level
# logging.info()/error() helpers implicitly call basicConfig() with the default
# WARNING level when the root logger has no handler; after that, a later
# basicConfig(level=logging.INFO) is silently ignored and all INFO messages
# from this module are dropped.
logging.basicConfig(level=logging.INFO)

# Initialize FastAPI application
app = FastAPI()

# Set the cache directory for Hugging Face.
# NOTE: huggingface_hub / transformers read HF_HOME when they are first imported
# (at the top of this file), so this assignment may not redirect their default
# cache; pass cache_dir explicitly to loaders where the location matters.
os.environ["HF_HOME"] = "/tmp/huggingface_cache"
hf_home_dir = os.environ["HF_HOME"]
# exist_ok=True avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(hf_home_dir, exist_ok=True)

# Qdrant collection queried by every search endpoint.
collection_name = os.getenv("QDRANT_COLLECTION_NAME")
if not collection_name:
    logging.warning(
        "QDRANT_COLLECTION_NAME is not set; search endpoints will pass None as the collection name."
    )
logging.info("Collection name: %s", collection_name)
# Authenticate with the Hugging Face Hub using the token from the
# HUGGINGFACE_HUB_TOKEN environment variable; startup aborts without it.
huggingface_token = os.getenv("HUGGINGFACE_HUB_TOKEN")
if not huggingface_token:
    raise ValueError(
        "Hugging Face token is not set. Please set the HUGGINGFACE_HUB_TOKEN environment variable."
    )
try:
    login(token=huggingface_token, add_to_git_credential=True)
except Exception as e:
    # This runs at import time, not inside a request handler, so an
    # HTTPException has no meaning here; fail startup with a RuntimeError
    # chained to the original cause (traceback is logged as well).
    logging.exception(f"Failed to log into Hugging Face Hub: {e}")
    raise RuntimeError("Failed to log into Hugging Face Hub.") from e
logging.info("Successfully logged into Hugging Face Hub.")
# Qdrant connection settings -- both values are mandatory, so refuse to start
# when either one is missing or empty.
qdrant_url = os.getenv("QDRANT_URL")
access_token = os.getenv("QDRANT_ACCESS_TOKEN")
if not all((qdrant_url, access_token)):
    raise ValueError(
        "Qdrant URL or Access Token is not set. "
        "Please set the QDRANT_URL and QDRANT_ACCESS_TOKEN environment variables."
    )
# Load the embedding tokenizer/model, then initialize the Qdrant searcher.
# trust_remote_code=True lets transformers execute Python code shipped in the
# model repository.  NOTE(review): consider pinning `revision=` so that remote
# code cannot change underneath this deployment.
EMBEDDING_MODEL_ID = "nomic-ai/nomic-embed-text-v1.5"
try:
    # HF_HOME is assigned only after transformers was imported, so it may not
    # steer the library's default cache; pass the intended directory explicitly.
    cache_folder = os.path.join(hf_home_dir, "transformers_cache")
    tokenizer = AutoTokenizer.from_pretrained(
        EMBEDDING_MODEL_ID, trust_remote_code=True, cache_dir=cache_folder
    )
    model = AutoModel.from_pretrained(
        EMBEDDING_MODEL_ID, trust_remote_code=True, cache_dir=cache_folder
    )
    logging.info("Successfully loaded the model and tokenizer with transformers.")
    # Initialize the Qdrant searcher after the model is successfully loaded
    # (module-level assignment, so it is already a global of this module).
    searcher = QdrantSearcher(qdrant_url=qdrant_url, access_token=access_token)
except Exception as e:
    # Import-time failure: raise a plain RuntimeError (chained to the cause)
    # rather than an HTTPException, which only makes sense inside a request.
    logging.exception(f"Failed to load the model or initialize searcher: {e}")
    raise RuntimeError(
        "Failed to load the custom model or initialize searcher."
    ) from e
| # Function to embed text using the model | |
def embed_text(text):
    """Embed ``text`` with the module-level ``tokenizer``/``model`` via mean pooling.

    Args:
        text: A string, or a list of strings (tokenized as one padded batch).

    Returns:
        A NumPy array with one pooled vector per input row; a single string
        yields one row.  Assumes ``outputs.last_hidden_state`` is
        ``(batch, seq_len, hidden)`` as for standard transformers encoders --
        TODO confirm for the remote nomic model code.

    NOTE(review): the nomic-embed model card recommends task prefixes such as
    "search_query: "; none are added here.  Changing that would require the
    stored document vectors to be re-embedded consistently.
    """
    inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
    # Inference only: skip building the autograd graph (less memory, faster).
    with torch.no_grad():
        outputs = model(**inputs)
    hidden = outputs.last_hidden_state
    attention_mask = inputs.get("attention_mask")
    if attention_mask is None:
        embeddings = hidden.mean(dim=1)
    else:
        # Masked mean pooling: padding positions (mask == 0) must not dilute
        # shorter texts in a padded batch.  For a single unpadded string the
        # mask is all ones and this equals the plain mean over tokens.
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
        embeddings = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
    return embeddings.cpu().numpy()
| # Define the request body models | |
class SearchDocumentsRequest(BaseModel):
    """Request body for the JWT-authenticated document search."""

    # Free-text query that is embedded and searched.
    query: str
    # Passed to the searcher as its result limit.
    limit: int = 3
    # Optional file restriction; Optional[...] so an explicit JSON null is
    # accepted (a bare `str = None` rejects null under pydantic v2).
    file_id: Optional[str] = None
class GenerateRAGRequest(BaseModel):
    """Request body for the JWT-authenticated RAG answer generation."""

    # Query used both for retrieval and as the question for generation.
    search_query: str
    # Optional file restriction; Optional[...] so an explicit JSON null is
    # accepted (a bare `str = None` rejects null under pydantic v2).
    file_id: Optional[str] = None
class XApiKeyRequest(BaseModel):
    """Request body for the x-api-key authenticated search/RAG endpoints.

    With key auth there is no JWT, so the caller identifies the organization
    and user explicitly in the body.
    """

    organization_id: str
    user_id: str
    search_query: str
    # Optional file restriction; Optional[...] so an explicit JSON null is
    # accepted (a bare `str = None` rejects null under pydantic v2).
    file_id: Optional[str] = None
async def root():
    """Return a static greeting that points clients at the API routes.

    NOTE(review): no `@app.<verb>(...)` route decorator is visible above this
    function in this source; as shown FastAPI does not route to it -- confirm.
    """
    greeting = (
        "Welcome to the Search and RAG API!, go to relevant address for API request"
    )
    payload = {"message": greeting}
    return payload
| # Define the search documents endpoint | |
async def search_documents(
    body: SearchDocumentsRequest, credentials: tuple = Depends(token_required)
):
    """Search the Qdrant collection for documents similar to ``body.query``.

    ``credentials`` is produced by ``token_required`` and unpacked as
    ``(customer_id, user_id)``; both must be truthy.

    Returns:
        ``(hits, time_taken)`` -- the searcher's hits and elapsed seconds.

    Raises:
        HTTPException(401): the token lacks customer_id or user_id.
        HTTPException(500): the searcher reported an error, or anything else failed.

    NOTE(review): no `@app.<verb>(...)` route decorator is visible above this
    function in this source -- confirm how it is registered.  embed_text and
    the searcher call are synchronous and run on the event loop here.
    """
    customer_id, user_id = credentials
    start_time = time.time()
    if not customer_id or not user_id:
        logging.error("Failed to extract customer_id or user_id from the JWT token.")
        raise HTTPException(
            status_code=401, detail="Invalid token: missing customer_id or user_id"
        )
    logging.info("Received request to search documents")
    try:
        logging.info("Starting document search")
        # Encode the query using the custom embedding function
        query_embedding = embed_text(body.query)
        logging.info(f"search query {body.query}")
        logging.info("Performing search using the precomputed embeddings")
        # Keep the original call shape: file_id is only passed when provided.
        file_kwargs = {"file_id": body.file_id} if body.file_id else {}
        hits, error = searcher.search_documents(
            collection_name, query_embedding, user_id, body.limit, **file_kwargs
        )
        if error:
            logging.error(f"Search documents error: {error}")
            raise HTTPException(status_code=500, detail=error)
        time_taken = time.time() - start_time
        return hits, time_taken
    except HTTPException:
        # Already carries the intended status/detail; re-wrapping it below would
        # replace the detail with str(exc) (e.g. "500: <error>").
        raise
    except Exception as e:
        logging.exception(f"Unexpected error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
| # Define the generate RAG response endpoint | |
async def generate_rag_response_api(
    body: GenerateRAGRequest, credentials: tuple = Depends(token_required)
):
    """Retrieve documents for ``body.search_query`` and generate a RAG answer.

    ``credentials`` is produced by ``token_required`` and unpacked as
    ``(customer_id, user_id)``; both must be truthy.  Search, generation and
    total timings are logged.

    Returns:
        ``{"response": response}`` from ``generate_rag_response``.

    Raises:
        HTTPException(401): the token lacks customer_id or user_id.
        HTTPException(500): search or generation reported an error, or anything
            else failed.

    NOTE(review): no `@app.<verb>(...)` route decorator is visible above this
    function in this source -- confirm how it is registered.
    """
    customer_id, user_id = credentials
    start_time = time.time()
    if not customer_id or not user_id:
        logging.error("Failed to extract customer_id or user_id from the JWT token.")
        raise HTTPException(
            status_code=401, detail="Invalid token: missing customer_id or user_id"
        )
    logging.info("Received request to generate RAG response")
    try:
        search_time = time.time()
        logging.info("Starting document search")
        # Encode the query using the custom embedding function
        query_embedding = embed_text(body.search_query)
        logging.info(f"search query {body.search_query}")
        # Keep the original call shape: file_id is only passed when provided.
        file_kwargs = {"file_id": body.file_id} if body.file_id else {}
        hits, error = searcher.search_documents(
            collection_name, query_embedding, user_id, **file_kwargs
        )
        if error:
            logging.error(f"Search documents error: {error}")
            raise HTTPException(status_code=500, detail=error)
        logging.info("Generating RAG response")
        search_time_taken = time.time() - search_time

        # Generate the RAG response using the retrieved documents
        rag_start_time = time.time()
        response, error = generate_rag_response(hits, body.search_query)
        rag_time_taken = time.time() - rag_start_time
        total_time = time.time() - start_time
        logging.info(
            f"Search time: {search_time_taken}, RAG time: {rag_time_taken}, Total time: {total_time}"
        )
        if error:
            logging.error(f"Generate RAG response error: {error}")
            raise HTTPException(status_code=500, detail=error)
        return {"response": response}
    except HTTPException:
        # Propagate our own HTTP errors unchanged instead of re-wrapping them.
        raise
    except Exception as e:
        logging.exception(f"Unexpected error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
async def search_documents_x_api_key(
    body: XApiKeyRequest, authorized: bool = Depends(x_api_key_auth)
):
    """Search documents for ``body.user_id`` using x-api-key authentication.

    ``authorized`` comes from ``x_api_key_auth``; a falsy value is rejected.
    The search uses a fixed limit of 3 and ``body.file_id`` (possibly None).

    Returns:
        The searcher's ``hits``.

    Raises:
        HTTPException(401): ``authorized`` is falsy.
        HTTPException(500): the searcher reported an error, or anything else failed.

    NOTE(review): organization_id is only logged and does not scope the search;
    no route decorator is visible above this function -- confirm both.
    """
    if not authorized:
        raise HTTPException(status_code=401, detail="Unauthorized")
    start_time = time.time()
    organization_id = body.organization_id
    user_id = body.user_id
    file_id = body.file_id
    logging.info(f"search query {body.search_query}")
    logging.info(f"organization_id: {organization_id}, user_id: {user_id}")
    logging.info("Received request to search documents with x-api-key auth")
    try:
        logging.info("Starting document search")
        # Encode the query using the custom embedding function
        query_embedding = embed_text(body.search_query)
        hits, error = searcher.search_documents(
            collection_name, query_embedding, user_id, limit=3, file_id=file_id
        )
        if error:
            logging.error(f"Search documents error: {error}")
            raise HTTPException(status_code=500, detail=error)
        logging.info(f"Document search completed with {len(hits)} hits")
        logging.info(f"Time taken: {time.time() - start_time}")
        return hits
    except HTTPException:
        # Propagate our own HTTP errors unchanged instead of re-wrapping them.
        raise
    except Exception as e:
        logging.exception(f"Unexpected error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
async def generate_rag_response_x_api_key(
    body: XApiKeyRequest, authorized: bool = Depends(x_api_key_auth)
):
    """Retrieve documents and generate a RAG answer using x-api-key authentication.

    ``authorized`` comes from ``x_api_key_auth`` (assumed to validate the key);
    a falsy value is rejected.  Retrieval is scoped to ``body.user_id`` and
    ``body.file_id`` (possibly None).

    Returns:
        ``{"response": response}`` from ``generate_rag_response``.

    Raises:
        HTTPException(401): ``authorized`` is falsy.
        HTTPException(500): search or generation reported an error, or anything
            else failed.

    NOTE(review): organization_id is only logged and does not scope the search;
    no route decorator is visible above this function -- confirm both.
    """
    if not authorized:
        raise HTTPException(status_code=401, detail="Unauthorized")
    start_time = time.time()
    organization_id = body.organization_id
    user_id = body.user_id
    file_id = body.file_id
    logging.info(f"search query {body.search_query}")
    logging.info(f"organization_id: {organization_id}, user_id: {user_id}")
    logging.info("Received request to generate RAG response with x-api-key auth")
    try:
        logging.info("Starting document search")
        # Encode the query using the custom embedding function
        query_embedding = embed_text(body.search_query)
        hits, error = searcher.search_documents(
            collection_name, query_embedding, user_id, file_id=file_id
        )
        if error:
            logging.error(f"Search documents error: {error}")
            raise HTTPException(status_code=500, detail=error)
        logging.info("Generating RAG response")
        # Generate the RAG response using the retrieved documents
        response, error = generate_rag_response(hits, body.search_query)
        if error:
            logging.error(f"Generate RAG response error: {error}")
            raise HTTPException(status_code=500, detail=error)
        logging.info(f"Time taken: {time.time() - start_time}")
        return {"response": response}
    except HTTPException:
        # Propagate our own HTTP errors unchanged instead of re-wrapping them.
        raise
    except Exception as e:
        logging.exception(f"Unexpected error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # Direct-execution entry point: serve the app with uvicorn on all
    # interfaces, port 8000.
    import uvicorn

    uvicorn.run(app, port=8000, host="0.0.0.0")