import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch

# Model and tokenizer loading
model_name = "cognitivecomputations/dolphin-2.5-mixtral-8x7b"

# Load the tokenizer. A raw tiktoken Encoding is not a drop-in replacement for
# a Hugging Face tokenizer (the pipeline needs eos_token_id, padding, and
# special-token handling), so fail loudly instead of falling back to one.
try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
except Exception as e:
    raise RuntimeError(f"AutoTokenizer loading failed for {model_name}") from e

# Load the model with float16 precision and automatic device mapping
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",       # Shard the model across available GPUs
    low_cpu_mem_usage=True,  # Reduce peak CPU memory while loading weights
)

# Text-generation pipeline, created once at startup. The model is already
# device-mapped above, so no device arguments are passed here.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)

# Strip chat-template control tokens that may leak into the output
def clean_text(text):
    for token in ("<|im_start|>system", "<|im_start|>user",
                  "<|im_start|>assistant", "<|im_end|>"):
        text = text.replace(token, "")
    return text.strip()

# Build a ChatML prompt from the system and user messages and generate a reply
def generate_text(system_message, user_message, max_length,
                  temperature, top_p, top_k, repetition_penalty):
    formatted_prompt = f"""<|im_start|>system
{system_message}<|im_end|>
<|im_start|>user
{user_message}<|im_end|>
<|im_start|>assistant
"""

    outputs = pipe(
        formatted_prompt,
        max_new_tokens=max_length,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        pad_token_id=tokenizer.eos_token_id,
        return_full_text=False,  # Return only the completion, not the prompt
    )
    return clean_text(outputs[0]["generated_text"])

# Gradio interface styling
css = """
.gradio-container { max-width: 900px !important; }
.message-box { border-radius: 8px; padding: 12px; margin-bottom: 12px; }
.system-box { background-color: #f0f7ff; }
.user-box { background-color: #f5f5f5; }
.assistant-box { background-color: #f0fff0; }
.param-box { background-color: #fff8f0; border-radius: 8px; padding: 12px; margin-bottom: 12px; }
button:hover { background-color: #3a7f7f; transition: background-color 0.3s ease; }
"""

# Default system message (the model's persona/instructions)
system_message_default = (
    "You are Dolphin, an uncensored and unbiased AI assistant. "
    "You always comply with the user's request."
)
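# Quick smoke test of generate_text outside the UI. A sketch only: the
# question and sampling values below are illustrative, not taken from the
# slider defaults. Uncomment to run once the model has loaded.
#
#   reply = generate_text(
#       system_message_default,
#       "Explain mixture-of-experts in one sentence.",
#       max_length=128, temperature=0.7, top_p=0.95,
#       top_k=50, repetition_penalty=1.1,
#   )
#   print(reply)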
# Gradio Blocks layout and wiring
with gr.Blocks(title="Dolphin-2.5-Mixtral-8x7b Chat", css=css) as demo:
    gr.Markdown(
        """# 🐬 Dolphin-2.5-Mixtral-8x7b Chat Interface
Chat with the Dolphin-2.5-Mixtral-8x7b model from Hugging Face."""
    )

    with gr.Row():
        with gr.Column(scale=2):
            with gr.Group():
                gr.Markdown("### System Message (AI's Personality/Instructions)")
                system_message = gr.Textbox(
                    value=system_message_default,
                    label="System Message",
                    lines=3,
                    elem_classes=["message-box", "system-box"],
                )
            with gr.Group():
                gr.Markdown("### Your Message")
                user_message = gr.Textbox(
                    placeholder="Type your message here...",
                    label="User Message",
                    lines=5,
                    elem_classes=["message-box", "user-box"],
                )
            with gr.Group(elem_classes=["param-box"]):
                gr.Markdown("### Generation Parameters")
                max_length = gr.Slider(128, 2048, value=512, step=32, label="Max New Tokens")
                temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
                top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")
                top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")
                repetition_penalty = gr.Slider(1.0, 2.0, value=1.1, step=0.1, label="Repetition Penalty")
            with gr.Row():
                submit_btn = gr.Button("Generate Response", variant="primary")
                clear_btn = gr.Button("Clear All")
        with gr.Column(scale=3):
            with gr.Group():
                gr.Markdown("### Assistant Response")
                assistant_response = gr.Textbox(
                    label="Response",
                    lines=10,
                    interactive=False,
                    elem_classes=["message-box", "assistant-box"],
                )
            with gr.Group():
                gr.Markdown("### Conversation History")
                chat_history = gr.Chatbot(
                    label="Chat History",
                    height=400,
                    elem_classes=["message-box"],
                )

    # Per-session copy of the system message, kept in sync with the textbox
    system_message_state = gr.State(system_message_default)
    system_message.change(
        lambda message: message,
        inputs=[system_message],
        outputs=[system_message_state],
    )

    # Append the latest exchange to the chat history and clear the input box
    def update_history(history, user_msg, response):
        return history + [(user_msg, response)], ""

    # Generate on button click, then record the exchange in the history
    submit_btn.click(
        fn=generate_text,
        inputs=[system_message, user_message, max_length, temperature,
                top_p, top_k, repetition_penalty],
        outputs=assistant_response,
    ).then(
        update_history,
        inputs=[chat_history, user_message, assistant_response],
        outputs=[chat_history, user_message],
    )

    # Pressing Enter in the message box behaves like the button
    user_message.submit(
        fn=generate_text,
        inputs=[system_message, user_message, max_length, temperature,
                top_p, top_k, repetition_penalty],
        outputs=assistant_response,
    ).then(
        update_history,
        inputs=[chat_history, user_message, assistant_response],
        outputs=[chat_history, user_message],
    )

    # Reset every control to its default value
    clear_btn.click(
        lambda: [system_message_default, "", "", 512, 0.7, 0.95, 50, 1.1, []],
        outputs=[system_message, user_message, assistant_response, max_length,
                 temperature, top_p, top_k, repetition_penalty, chat_history],
    )

if __name__ == "__main__":
    demo.launch()
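# Optional launch variant. A sketch assuming a recent Gradio release:
# queue() serializes concurrent generations so GPU memory is not
# oversubscribed, and share=True creates a temporary public link.
#
#   demo.queue().launch(share=True)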