import os
from datetime import datetime, timezone, timedelta
from typing import Type, Dict, List, Optional

from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorCollection
from pydantic import BaseModel

from src.utils.logger import logger, get_date_time

# Process-wide Motor client; the connection string is read from the
# MONGO_CONNECTION_STR environment variable at import time.
client: AsyncIOMotorClient = AsyncIOMotorClient(os.getenv("MONGO_CONNECTION_STR"))
# database = client["custom_gpt"]
database = client["prompt_editor"]


class MongoCRUD:
    """Async CRUD helper binding a Motor collection to a pydantic model.

    Documents are validated through ``model`` on the way in and on the way
    out. When ``ttl_seconds`` is given, every create/update stamps an
    ``expire_at`` field and a TTL index on that field is created lazily.
    """

    # Bookkeeping fields that _order_fields moves to the end, in this order.
    _TRAILING_FIELDS = ("created_at", "updated_at", "expire_at")

    def __init__(
        self,
        collection: AsyncIOMotorCollection,
        model: Type[BaseModel],
        ttl_seconds: Optional[int] = None,
    ):
        """Store the collection, the validation model and the optional TTL.

        Args:
            collection: Motor collection all operations run against.
            model: pydantic model used to validate documents.
            ttl_seconds: Lifetime of a document in seconds, counted from its
                last create/update; ``None`` disables expiry handling.
        """
        self.collection = collection
        self.model = model
        self.ttl_seconds = ttl_seconds
        # Set once the TTL index has been requested by this instance.
        self._index_created = False

    async def _ensure_ttl_index(self):
        """Create the TTL index on ``expire_at`` once per instance, if TTL is enabled."""
        if self.ttl_seconds is not None and not self._index_created:
            # expireAfterSeconds=0: expiry is driven by the timestamp stored
            # in expire_at itself rather than by a fixed offset.
            await self.collection.create_index("expire_at", expireAfterSeconds=0)
            self._index_created = True

    def _order_fields(self, doc: Dict) -> Dict:
        """Return a copy of ``doc`` with the bookkeeping fields moved to the end.

        If ``doc`` carries a non-None ``id``, it is additionally mirrored into
        ``_id`` as an ``ObjectId`` (the ``id`` key itself is left in place).

        Raises:
            bson.errors.InvalidId: If ``doc["id"]`` is not a valid ObjectId value.
        """
        ordered_doc = {
            key: value
            for key, value in doc.items()
            if key not in self._TRAILING_FIELDS
        }
        # ObjectId(None) would mint a brand-new random id, which is never what
        # an explicit ``id=None`` means (and breaks $set on the immutable _id).
        if doc.get("id") is not None:
            ordered_doc["_id"] = ObjectId(doc["id"])
        for field in self._TRAILING_FIELDS:
            if field in doc:
                ordered_doc[field] = doc[field]
        return ordered_doc

    def _serialize(self, doc: Dict) -> Dict:
        """Validate a raw MongoDB document and return it with a string ``_id``.

        The ``_id`` is stringified before validation so that every read path
        feeds the model the same shape. The input document is not modified.
        """
        doc_id = str(doc["_id"])
        validated = self.model(**{**doc, "_id": doc_id}).model_dump(exclude={"id"})
        return {"_id": doc_id, **self._order_fields(validated)}

    async def create(self, data: Dict) -> str:
        """Insert a new document, optionally using a user-specified ID.

        ``created_at`` / ``updated_at`` (and ``expire_at`` when TTL is enabled)
        are stamped on a copy of ``data``; the caller's dict is left untouched.

        Args:
            data: Field values for the new document; must validate against
                the model.

        Returns:
            The inserted document's ``_id`` as a string.
        """
        await self._ensure_ttl_index()
        # NOTE(review): tzinfo is stripped, so the stored value is naive;
        # confirm get_date_time() yields UTC, which is what TTL expiry assumes.
        now = get_date_time().replace(tzinfo=None)
        payload = {**data, "created_at": now, "updated_at": now}
        if self.ttl_seconds is not None:
            payload["expire_at"] = now + timedelta(seconds=self.ttl_seconds)
        document = self.model(**payload).model_dump(exclude_unset=True)
        result = await self.collection.insert_one(self._order_fields(document))
        return str(result.inserted_id)

    async def read(
        self, query: Dict, sort: Optional[List[tuple]] = None
    ) -> List[Dict]:
        """Return all documents matching ``query``, validated through the model.

        Args:
            query: MongoDB query filter.
            sort: Optional ``[(field_name, direction)]`` sort specification.

        Returns:
            Validated documents, each with ``_id`` as a string.
        """
        cursor = self.collection.find(query)
        if sort:
            cursor = cursor.sort(sort)
        return [self._serialize(doc) async for doc in cursor]

    async def read_one(self, query: Dict) -> Optional[Dict]:
        """Return the first document matching ``query``, or ``None`` if there is none."""
        doc = await self.collection.find_one(query)
        if not doc:
            return None
        return self._serialize(doc)

    async def update(self, query: Dict, data: Dict, upsert: bool = False) -> int:
        """Update every document matching ``query``.

        If any top-level key of ``data`` starts with ``$``, ``data`` is sent
        to MongoDB unchanged as an update document. Otherwise it is treated
        as plain field values: it is validated through the model, stamped
        with ``updated_at`` (and ``expire_at`` when TTL is enabled) and
        wrapped in ``$set``. The caller's dict is never modified.

        Args:
            query: MongoDB query filter selecting the documents to update.
            data: Either an operator update document or plain field values.
            upsert: Passed through to ``update_many``.

        Returns:
            ``modified_count`` reported by ``update_many``.
        """
        await self._ensure_ttl_index()
        if any(key.startswith("$") for key in data):
            update_data = data
        else:
            now = get_date_time().replace(tzinfo=None)
            payload = {**data, "updated_at": now}
            if self.ttl_seconds is not None:
                payload["expire_at"] = now + timedelta(seconds=self.ttl_seconds)
            validated = self.model(**payload).model_dump(exclude_unset=True)
            update_data = {"$set": self._order_fields(validated)}
        result = await self.collection.update_many(query, update_data, upsert=upsert)
        return result.modified_count

    async def delete(self, query: Dict) -> int:
        """Delete all documents matching ``query`` and return how many were removed."""
        result = await self.collection.delete_many(query)
        return result.deleted_count

    async def delete_one(self, query: Dict) -> int:
        """Delete at most one document matching ``query`` and return the deleted count."""
        result = await self.collection.delete_one(query)
        return result.deleted_count

    async def find_by_id(self, id: str) -> Optional[Dict]:
        """Return the document whose ``_id`` equals ``ObjectId(id)``, or ``None``.

        Raises:
            bson.errors.InvalidId: If ``id`` is not a valid ObjectId string.
        """
        return await self.read_one({"_id": ObjectId(id)})

    async def find_all(self) -> List[Dict]:
        """Return every document in the collection."""
        return await self.read({})

    async def find_many(
        self,
        filter: Dict,
        skip: int = 0,
        limit: int = 0,
        sort: Optional[List[tuple]] = None,
    ) -> List[Dict]:
        """
        Find documents based on filter with pagination support.

        Args:
            filter: MongoDB query filter
            skip: Number of documents to skip
            limit: Maximum number of documents to return (0 means no limit)
            sort: Optional sorting parameters [(field_name, direction)]
                  where direction is 1 for ascending, -1 for descending

        Returns:
            List of documents matching the filter. A document that fails
            model validation is logged and returned with its stored fields
            unvalidated instead of being dropped.
        """
        cursor = self.collection.find(filter)
        if skip > 0:
            cursor = cursor.skip(skip)
        if limit > 0:
            cursor = cursor.limit(limit)
        if sort:
            cursor = cursor.sort(sort)

        docs = []
        async for doc in cursor:
            doc_id = str(doc["_id"])
            try:
                docs.append(self._serialize(doc))
            except Exception as e:
                # Deliberate best-effort: one bad record must not hide the
                # rest of the page, so fall back to the stored data.
                logger.error(f"Error validating document {doc_id}: {str(e)}")
                docs.append(
                    {"_id": doc_id, **{k: v for k, v in doc.items() if k != "_id"}}
                )
        return docs


# Kept below the class definition as in the original layout; presumably this
# avoids a circular import with the model modules - TODO confirm before moving.
from src.apis.models.user_models import User  # noqa: E402
from src.apis.models.prompt_models import Prompt  # noqa: E402

UserCRUD = MongoCRUD(database["users"], User)
PromptCRUD = MongoCRUD(database["prompt_templates"], Prompt)