| """ | |
| MAKER Agent - Chat Interface | |
| ============================= | |
| Reliable AI Agent with Web Search & File Upload | |
| Based on: https://arxiv.org/abs/2511.09030 | |
| """ | |
| import gradio as gr | |
| import asyncio | |
| import json | |
| import re | |
| import base64 | |
| from collections import Counter | |
| from dataclasses import dataclass, field | |
| from typing import Any, Callable, Optional | |
| from pathlib import Path | |

# ============================================================================
# MAKER Core (Embedded)
# ============================================================================
@dataclass
class VotingConfig:
    k: int = 3                       # required lead of the top answer over the runner-up
    max_samples: int = 30
    temperature_first: float = 0.0
    temperature_rest: float = 0.1
    parallel_samples: int = 3


@dataclass
class RedFlagConfig:
    max_response_chars: int = 3000
    min_response_length: int = 5
    banned_patterns: list = field(default_factory=list)
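
# Illustrative config (hypothetical patterns): red-flag refusal boilerplate so it
# can never win a vote; _check_red_flags() matches patterns case-insensitively.
#   RedFlagConfig(banned_patterns=[r"as an ai language model", r"i cannot help"])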


class LLMClient:
    """Universal LLM client (OpenAI-compatible APIs plus Anthropic)."""

    def __init__(self, provider: str, api_key: str, model: Optional[str] = None):
        self.provider = provider.lower()
        self.api_key = api_key
        self.model = model
        self._client = None
        self._setup_client()

    def _setup_client(self):
        if self.provider == "openai":
            from openai import AsyncOpenAI
            self._client = AsyncOpenAI(api_key=self.api_key)
            self.model = self.model or "gpt-4o-mini"
        elif self.provider == "anthropic":
            from anthropic import AsyncAnthropic
            self._client = AsyncAnthropic(api_key=self.api_key)
            self.model = self.model or "claude-sonnet-4-20250514"
        elif self.provider == "groq":
            from openai import AsyncOpenAI
            self._client = AsyncOpenAI(api_key=self.api_key, base_url="https://api.groq.com/openai/v1")
            self.model = self.model or "llama-3.3-70b-versatile"
        elif self.provider == "together":
            from openai import AsyncOpenAI
            self._client = AsyncOpenAI(api_key=self.api_key, base_url="https://api.together.xyz/v1")
            self.model = self.model or "meta-llama/Llama-3.3-70B-Instruct-Turbo"
        elif self.provider == "openrouter":
            from openai import AsyncOpenAI
            self._client = AsyncOpenAI(api_key=self.api_key, base_url="https://openrouter.ai/api/v1")
            self.model = self.model or "openai/gpt-4o-mini"
        else:
            raise ValueError(f"Unknown provider: {self.provider!r}")

    async def generate(self, prompt: str, temperature: float = 0.0, max_tokens: int = 2000) -> str:
        if self.provider == "anthropic":
            r = await self._client.messages.create(
                model=self.model, max_tokens=max_tokens,
                temperature=temperature,
                messages=[{"role": "user", "content": prompt}]
            )
            return r.content[0].text
        else:
            r = await self._client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": prompt}],
                temperature=temperature, max_tokens=max_tokens
            )
            return r.choices[0].message.content


class WebSearch:
    """Web search using DuckDuckGo (free, no API key)."""

    @staticmethod
    async def search(query: str, num_results: int = 5) -> list:
        try:
            from duckduckgo_search import DDGS
            results = []
            with DDGS() as ddgs:
                for r in ddgs.text(query, max_results=num_results):
                    results.append({
                        "title": r.get("title", ""),
                        "url": r.get("href", ""),
                        "snippet": r.get("body", "")
                    })
            return results
        except Exception as e:
            return [{"title": "Error", "url": "", "snippet": str(e)}]


class FileHandler:
    """Handle file uploads."""

    @staticmethod
    async def load_file(file_path: str) -> dict:
        path = Path(file_path)
        ext = path.suffix.lower()
        try:
            if ext in {'.txt', '.md', '.json', '.py', '.js', '.html', '.css', '.csv'}:
                content = path.read_text(encoding='utf-8', errors='replace')
                return {"type": "text", "name": path.name, "content": content[:50000]}
            elif ext == '.pdf':
                try:
                    import pymupdf
                    doc = pymupdf.open(str(path))
                    text = "\n\n".join([page.get_text() for page in doc])
                    doc.close()
                    return {"type": "pdf", "name": path.name, "content": text[:50000]}
                except ImportError:
                    return {"type": "error", "name": path.name, "content": "PDF requires: pip install pymupdf"}
            elif ext == '.docx':
                try:
                    from docx import Document
                    doc = Document(str(path))
                    text = "\n\n".join([p.text for p in doc.paragraphs])
                    return {"type": "docx", "name": path.name, "content": text[:50000]}
                except ImportError:
                    return {"type": "error", "name": path.name, "content": "DOCX requires: pip install python-docx"}
            elif ext in {'.png', '.jpg', '.jpeg', '.gif', '.webp'}:
                content = path.read_bytes()
                b64 = base64.b64encode(content).decode('utf-8')
                return {"type": "image", "name": path.name, "base64": b64}
            else:
                content = path.read_text(encoding='utf-8', errors='replace')
                return {"type": "text", "name": path.name, "content": content[:50000]}
        except Exception as e:
            return {"type": "error", "name": path.name, "content": str(e)}


class MAKERAgent:
    """MAKER Framework Agent: first-to-ahead-by-K voting with red-flagging."""

    def __init__(self, llm: LLMClient, voting: Optional[VotingConfig] = None, red_flags: Optional[RedFlagConfig] = None):
        self.llm = llm
        self.voting = voting or VotingConfig()
        self.red_flags = red_flags or RedFlagConfig()
        self.stats = {"samples": 0, "red_flags": 0, "tool_calls": 0}

    def _check_red_flags(self, response: str) -> bool:
        """Return True if a response should be discarded instead of voted on."""
        if len(response) > self.red_flags.max_response_chars:
            return True
        if len(response) < self.red_flags.min_response_length:
            return True
        for pattern in self.red_flags.banned_patterns:
            if re.search(pattern, response, re.IGNORECASE):
                return True
        return False

    def _normalize_response(self, response: str) -> str:
        """Normalize response for voting comparison."""
        return response.strip().lower()
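    # e.g. _normalize_response("  The Answer is 4.\n") -> "the answer is 4.",
    # so trivially different renderings of the same answer pool their votes.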

    async def execute(self, prompt: str, use_search: bool = False,
                      file_context: Optional[str] = None, progress_callback: Optional[Callable] = None) -> dict:
        # Build the full prompt
        full_prompt = "You are a helpful assistant. Respond naturally and conversationally.\n\n"
        if file_context:
            full_prompt += f"The user has provided the following files for context:\n{file_context}\n\n"
        full_prompt += f"User: {prompt}\n\nAssistant:"

        # Handle web search if enabled
        search_results = None
        if use_search:
            if progress_callback:
                progress_callback(0.1, "Searching the web...")
            search_results = await WebSearch.search(prompt)
            self.stats["tool_calls"] += 1
            if search_results and search_results[0].get("title") != "Error":
                search_text = "\n".join([f"- {r['title']}: {r['snippet']}" for r in search_results[:5]])
                full_prompt = "You are a helpful assistant with access to web search results.\n\n"
                if file_context:
                    full_prompt += f"Files provided:\n{file_context}\n\n"
                full_prompt += f"Web search results for '{prompt}':\n{search_text}\n\n"
                full_prompt += f"User question: {prompt}\n\nProvide a helpful response based on the search results. Assistant:"

        if progress_callback:
            progress_callback(0.2, "Getting response...")

        # Voting loop: keep sampling until the top answer leads the runner-up by k
        votes: Counter = Counter()
        responses_map = {}
        samples, flagged = 0, 0

        # First sample (greedy by default)
        response = await self.llm.generate(full_prompt, temperature=self.voting.temperature_first)
        samples += 1
        self.stats["samples"] += 1
        if not self._check_red_flags(response):
            key = self._normalize_response(response)
            votes[key] += 1
            responses_map[key] = response
        else:
            flagged += 1
            self.stats["red_flags"] += 1

        # Continue voting until we have a winner or hit the sample budget
        round_num = 1
        while samples < self.voting.max_samples:
            if votes:
                top = votes.most_common(2)
                top_count = top[0][1]
                second_count = top[1][1] if len(top) > 1 else 0
                if top_count - second_count >= self.voting.k:
                    break
            round_num += 1
            if progress_callback:
                progress_callback(0.2 + 0.7 * (samples / self.voting.max_samples), f"Voting round {round_num}...")
            # Draw a batch of extra samples between margin checks (sequentially)
            for _ in range(self.voting.parallel_samples):
                if samples >= self.voting.max_samples:
                    break
                response = await self.llm.generate(full_prompt, temperature=self.voting.temperature_rest)
                samples += 1
                self.stats["samples"] += 1
                if not self._check_red_flags(response):
                    key = self._normalize_response(response)
                    votes[key] += 1
                    if key not in responses_map:
                        responses_map[key] = response
                else:
                    flagged += 1
                    self.stats["red_flags"] += 1

        if progress_callback:
            progress_callback(1.0, "Done!")

        if votes:
            top_key, top_count = votes.most_common(1)[0]
            return {
                "success": True,
                "response": responses_map[top_key],
                "votes": top_count,
                "total_samples": samples,
                "red_flagged": flagged,
                "search_results": search_results
            }
        return {
            "success": False,
            "response": "I couldn't generate a reliable response. Please try again.",
            "votes": 0,
            "total_samples": samples,
            "red_flagged": flagged,
            "search_results": search_results
        }
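
# Minimal programmatic usage sketch (hypothetical key; the UI below does the
# same thing via the Settings tab):
#   llm = LLMClient("groq", api_key="gsk_...")
#   agent = MAKERAgent(llm, VotingConfig(k=3))
#   result = asyncio.run(agent.execute("What is the capital of France?"))
#   print(result["response"], result["votes"], result["total_samples"])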


# ============================================================================
# Global State
# ============================================================================
current_agent = None
loaded_files = {}


# ============================================================================
# Functions
# ============================================================================
def setup_agent(provider, api_key, model, k_votes):
    global current_agent
    if not api_key:
        return "❌ Please enter your API key", gr.update(interactive=False)
    try:
        llm = LLMClient(provider, api_key, model if model else None)
        current_agent = MAKERAgent(llm, VotingConfig(k=int(k_votes)))
        return f"✅ Connected to {provider} ({llm.model})", gr.update(interactive=True)
    except Exception as e:
        return f"❌ Error: {e}", gr.update(interactive=False)


def _file_path(f):
    """Gradio file inputs may be plain path strings or tempfile-like objects."""
    return f if isinstance(f, str) else f.name


def process_files(files):
    global loaded_files
    loaded_files = {}
    if not files:
        return "No files attached"
    names = []
    for f in files:
        info = asyncio.run(FileHandler.load_file(_file_path(f)))
        loaded_files[info['name']] = info
        names.append(info['name'])
    return f"📎 {', '.join(names)}"


async def chat_async(message, history, use_search, files, progress=gr.Progress()):
    global current_agent, loaded_files
    if not current_agent:
        return history + [[message, "⚠️ Please set up your API key first in the Settings tab."]]

    # Process any new files
    if files:
        for f in files:
            info = await FileHandler.load_file(_file_path(f))
            loaded_files[info['name']] = info

    # Build file context (text-like files only; images and errors are skipped)
    file_context = None
    if loaded_files:
        parts = []
        for name, info in loaded_files.items():
            if info["type"] != "image" and info["type"] != "error":
                parts.append(f"=== {name} ===\n{info.get('content', '')[:10000]}")
        if parts:
            file_context = "\n\n".join(parts)

    def update_progress(pct, msg):
        progress(pct, desc=msg)

    try:
        result = await current_agent.execute(
            message,
            use_search=use_search,
            file_context=file_context,
            progress_callback=update_progress
        )
        response = result["response"]
        # Add subtle stats footer
        stats = f"\n\n---\n*{result['votes']} votes, {result['total_samples']} samples*"
        return history + [[message, response + stats]]
    except Exception as e:
        return history + [[message, f"❌ Error: {str(e)}"]]


def chat(message, history, use_search, files):
    return asyncio.run(chat_async(message, history, use_search, files))


def clear_chat():
    global loaded_files
    loaded_files = {}
    return [], None, "No files attached"


# ============================================================================
# UI
# ============================================================================
with gr.Blocks(title="MAKER Agent") as demo:
    # Header
    gr.HTML("""
    <div style="text-align: center; padding: 20px 0 10px 0;">
        <h1 style="font-size: 2rem; margin: 0;">🧠 MAKER Agent</h1>
        <p style="color: #666; margin: 5px 0;">Reliable AI with Voting • <a href="https://arxiv.org/abs/2511.09030" target="_blank">Paper</a></p>
    </div>
    """)

    with gr.Tabs():
        # Chat Tab
        with gr.Tab("💬 Chat"):
            chatbot = gr.Chatbot(height=450)
            with gr.Row():
                with gr.Column(scale=12):
                    msg = gr.Textbox(
                        placeholder="Ask anything...",
                        show_label=False,
                        lines=2,
                    )
                with gr.Column(scale=1, min_width=80):
                    send_btn = gr.Button("Send", variant="primary", interactive=False)
            with gr.Row():
                with gr.Column(scale=4):
                    file_upload = gr.File(
                        file_count="multiple",
                        file_types=[".pdf", ".docx", ".txt", ".md", ".json", ".csv"],
                        show_label=False,
                    )
                with gr.Column(scale=2):
                    file_status = gr.Markdown("No files attached")
                with gr.Column(scale=2):
                    use_search = gr.Checkbox(
                        label="🔍 Web Search",
                        value=False,
                        info="Search DuckDuckGo"
                    )
                with gr.Column(scale=1):
                    clear_btn = gr.Button("🗑️ Clear")

            # Event handlers
            file_upload.change(process_files, file_upload, file_status)
            msg.submit(chat, [msg, chatbot, use_search, file_upload], chatbot).then(
                lambda: "", None, msg
            )
            send_btn.click(chat, [msg, chatbot, use_search, file_upload], chatbot).then(
                lambda: "", None, msg
            )
            clear_btn.click(clear_chat, None, [chatbot, file_upload, file_status])

        # Settings Tab
        with gr.Tab("⚙️ Settings"):
            gr.Markdown("### Connect to an LLM Provider")
            with gr.Row():
                with gr.Column():
                    provider = gr.Dropdown(
                        ["groq", "openai", "anthropic", "together", "openrouter"],
                        value="groq",
                        label="Provider",
                        info="Groq is free & fast!"
                    )
                    api_key = gr.Textbox(
                        label="API Key",
                        type="password",
                        placeholder="Paste your API key here..."
                    )
                    model = gr.Textbox(
                        label="Model (optional)",
                        placeholder="Leave blank for default"
                    )
                with gr.Column():
                    k_votes = gr.Slider(
                        1, 7, value=3, step=1,
                        label="Reliability (K votes)",
                        info="Higher = more reliable, slower"
                    )
                    gr.Markdown("""
### Get API Keys
**Groq** (recommended - free & fast): [console.groq.com](https://console.groq.com)

**OpenAI**: [platform.openai.com/api-keys](https://platform.openai.com/api-keys)

**Anthropic**: [console.anthropic.com](https://console.anthropic.com)
""")
            connect_btn = gr.Button("🔌 Connect", variant="primary")
            status = gr.Markdown("🔑 Enter your API key and click Connect")
            connect_btn.click(
                setup_agent,
                [provider, api_key, model, k_votes],
                [status, send_btn]
            )

        # About Tab
        with gr.Tab("ℹ️ About"):
            gr.Markdown("""
## How MAKER Works

This agent uses the **MAKER Framework** to make responses reliable:

1. **Multiple Samples** - Generates several responses for each question
2. **Voting** - Responses "vote"; the winner must lead the runner-up by K votes (sketched below)
3. **Red-Flagging** - Suspicious outputs are automatically discarded
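
The stopping rule, roughly (a sketch mirroring the `MAKERAgent.execute` loop above, not a second implementation):

```python
from collections import Counter

votes = Counter(["4", "4", "4", "4", "5"])  # normalized answers so far
top = votes.most_common(2)
margin = top[0][1] - (top[1][1] if len(top) > 1 else 0)
accept = margin >= 3                        # with K=3: 4 vs 1, "4" wins
```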

### Why This Matters

Instead of hoping the model gets it right on one try, MAKER uses repeated sampling and a vote margin to make errors statistically unlikely. The paper reports **1 million steps with zero errors** using this approach.

### Features
- 🔍 **Web Search** - Free DuckDuckGo search (no API key needed)
- 📄 **File Upload** - PDF, DOCX, TXT, MD, JSON, CSV
- ⚡ **Multiple Providers** - Groq, OpenAI, Anthropic, and more

### Links
- 📄 [Research Paper](https://arxiv.org/abs/2511.09030)
- 🎥 [Video Explanation](https://youtube.com/watch?v=TJ-vWGCosdQ)
""")

    # Footer
    gr.HTML("""
    <div style="text-align: center; color: #888; padding: 15px; font-size: 0.85rem;">
        MAKER Framework • <a href="https://arxiv.org/abs/2511.09030" style="color: #888;">arxiv.org/abs/2511.09030</a>
    </div>
    """)


if __name__ == "__main__":
    demo.launch()