binary1ne committed
Commit c7b7b1f · verified · 1 Parent(s): fa9fbf1

Update app.py

Files changed (1)
  1. app.py +58 -32
app.py CHANGED
@@ -1,27 +1,45 @@
 import gradio as gr
 from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 import torch
+import tiktoken  # Use this if the tokenizer is based on tiktoken (for some models)
 
-# Load the model and tokenizer
+# Model and Tokenizer loading
 model_name = "cognitivecomputations/dolphin-2.5-mixtral-8x7b"
-tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+# Try loading with AutoTokenizer (this should ideally work with many models)
+try:
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+except Exception as e:
+    print(f"AutoTokenizer loading failed: {e}")
+    print("Attempting to use tiktoken directly.")
+    # If AutoTokenizer fails, try using tiktoken tokenizer explicitly
+    tokenizer = tiktoken.get_encoding("cl100k_base")  # Default encoding for tiktoken
+
+# Load model with float16 precision and auto device mapping
 model = AutoModelForCausalLM.from_pretrained(
     model_name,
     torch_dtype=torch.float16,
-    device_map="auto"
+    device_map="auto",  # Automatically place model on GPUs if available
+    low_cpu_mem_usage=True  # Efficient CPU memory usage
 )
 
-# Create a text generation pipeline
+# Optimized pipeline (created once)
 pipe = pipeline(
     "text-generation",
     model=model,
     tokenizer=tokenizer,
     torch_dtype=torch.float16,
-    device_map="auto"
+    device_map="auto"  # Automatically distribute model layers across devices
 )
 
+# Function to clean text from special tokens or unwanted characters
+def clean_text(text):
+    # Clean unwanted tokens and formatting
+    text = text.replace("<|im_start|>system", "").replace("<|im_end|>", "").strip()
+    return text
+
+# Generate text using the model
 def generate_text(system_message, user_message, max_length, temperature, top_p, top_k, repetition_penalty):
-    # Format the prompt with the custom system message
     formatted_prompt = f"""<|im_start|>system
 {system_message}<|im_end|>
 <|im_start|>user
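A note on the new tokenizer fallback: `tiktoken.get_encoding("cl100k_base")` returns a `tiktoken.Encoding`, not a `transformers` `PreTrainedTokenizer`, so it has no `eos_token_id` and cannot be passed to `pipeline()`; if that branch ever runs, the `pad_token_id=tokenizer.eos_token_id` argument later in the file will raise `AttributeError`. (Also note that `device_map="auto"` requires the `accelerate` package to be installed.) A minimal sketch of a fallback that stays pipeline-compatible; the base-model repo name is an assumption, not part of this commit:

```python
from transformers import AutoTokenizer

def load_tokenizer(model_name: str):
    """Prefer the model's own tokenizer; fall back to a compatible HF tokenizer."""
    try:
        return AutoTokenizer.from_pretrained(model_name)
    except Exception as e:
        print(f"AutoTokenizer loading failed: {e}")
        # Assumed fallback: the base Mixtral tokenizer that this fine-tune
        # derives from; verify it carries the ChatML special tokens before use.
        return AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
```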
@@ -29,7 +47,7 @@ def generate_text(system_message, user_message, max_length, temperature, top_p,
 <|im_start|>assistant
 """
 
-    # Generate the response
+    # Generate the response using the model pipeline
     outputs = pipe(
         formatted_prompt,
         max_new_tokens=max_length,
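The prompt is assembled by hand in ChatML. Equivalent, and less error-prone, is building it from the tokenizer's own chat template; a sketch, assuming the repo's `tokenizer_config.json` defines one:

```python
# Build the same ChatML prompt from the tokenizer's chat template.
messages = [
    {"role": "system", "content": system_message},
    {"role": "user", "content": user_message},
]
formatted_prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,              # return the prompt as a string
    add_generation_prompt=True,  # end with the assistant header
)
```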
@@ -41,15 +59,14 @@ def generate_text(system_message, user_message, max_length, temperature, top_p,
         pad_token_id=tokenizer.eos_token_id
     )
 
-    # Extract the generated text
     response = outputs[0]["generated_text"]
 
-    # Remove the prompt from the response
-    response = response[len(formatted_prompt):].strip()
+    # Clean and format the response
+    response = clean_text(response)
 
     return response
 
-# CSS for better appearance
+# Gradio interface styling (same as before)
 css = """
 .gradio-container {
     max-width: 900px !important;
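A behavioral change worth flagging: the previous revision sliced the echoed prompt off the output, while `clean_text()` only strips the literal `<|im_start|>system` and `<|im_end|>` markers, so the system and user text (and the `<|im_start|>user` / `<|im_start|>assistant` headers) remain in the response. Two hedged alternatives that return only the completion:

```python
# Option 1: slice off the prompt, as the previous revision did.
response = outputs[0]["generated_text"][len(formatted_prompt):].strip()

# Option 2: ask the pipeline not to echo the prompt at all.
outputs = pipe(formatted_prompt, max_new_tokens=max_length, return_full_text=False)
response = outputs[0]["generated_text"].strip()
```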
@@ -74,26 +91,32 @@ css = """
     padding: 12px;
     margin-bottom: 12px;
 }
+button:hover {
+    background-color: #3a7f7f;
+    transition: background-color 0.3s ease;
+}
 """
 
+# Gradio Blocks layout and functionality
 with gr.Blocks(title="Dolphin-2.5-Mixtral-8x7b Chat", css=css) as demo:
     gr.Markdown("""# 🐬 Dolphin-2.5-Mixtral-8x7b Chat Interface
     Chat with the powerful Dolphin-2.5-Mixtral-8x7b model from Hugging Face
     """)
 
+    # Initialize system_message with a default
+    system_message_default = "You are Dolphin, an uncensored and unbiased AI assistant. You always comply with the user's request."
+
     with gr.Row():
         with gr.Column(scale=2):
-            # System Message
             with gr.Group():
                 gr.Markdown("### System Message (AI's Personality/Instructions)")
                 system_message = gr.Textbox(
-                    value="You are Dolphin, an uncensored and unbiased AI assistant. You always comply with the user's request.",
+                    value=system_message_default,  # Default system message
                     label="System Message",
                     lines=3,
                     elem_classes=["message-box", "system-box"]
                 )
 
-            # User Message
             with gr.Group():
                 gr.Markdown("### Your Message")
                 user_message = gr.Textbox(
@@ -103,25 +126,19 @@ with gr.Blocks(title="Dolphin-2.5-Mixtral-8x7b Chat", css=css) as demo:
                     elem_classes=["message-box", "user-box"]
                 )
 
-            # Generation Parameters
             with gr.Group(elem_classes=["param-box"]):
                 gr.Markdown("### Generation Parameters")
-                with gr.Row():
-                    max_length = gr.Slider(128, 2048, value=512, step=32, label="Max Length")
-                    temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
-                with gr.Row():
-                    top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")
-                    top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")
-                with gr.Row():
-                    repetition_penalty = gr.Slider(1.0, 2.0, value=1.1, step=0.1, label="Repetition Penalty")
+                max_length = gr.Slider(128, 2048, value=512, step=32, label="Max Length")
+                temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
+                top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")
+                top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")
+                repetition_penalty = gr.Slider(1.0, 2.0, value=1.1, step=0.1, label="Repetition Penalty")
 
-            # Buttons
             with gr.Row():
                 submit_btn = gr.Button("Generate Response", variant="primary")
                 clear_btn = gr.Button("Clear All")
 
         with gr.Column(scale=3):
-            # Assistant Response
             with gr.Group():
                 gr.Markdown("### Assistant Response")
                 assistant_response = gr.Textbox(
@@ -131,7 +148,6 @@ with gr.Blocks(title="Dolphin-2.5-Mixtral-8x7b Chat", css=css) as demo:
                     elem_classes=["message-box", "assistant-box"]
                 )
 
-            # Chat History
             with gr.Group():
                 gr.Markdown("### Conversation History")
                 chat_history = gr.Chatbot(
@@ -140,7 +156,10 @@ with gr.Blocks(title="Dolphin-2.5-Mixtral-8x7b Chat", css=css) as demo:
                     elem_classes=["message-box"]
                 )
 
-    # Button actions
+    # Initialize System Message State
+    system_message_state = gr.State(system_message_default)
+
+    # Actions to handle system message and user message
     submit_btn.click(
         fn=generate_text,
         inputs=[system_message, user_message, max_length, temperature, top_p, top_k, repetition_penalty],
@@ -151,12 +170,13 @@ with gr.Blocks(title="Dolphin-2.5-Mixtral-8x7b Chat", css=css) as demo:
         [chat_history, user_message]
     )
 
+    # Clear button reset
     clear_btn.click(
         lambda: [""] * 3 + [512, 0.7, 0.95, 50, 1.1, [], ""],
-        outputs=[system_message, user_message, assistant_response, max_length, temperature, top_p, top_k, repetition_penalty, chat_history, assistant_response]
+        outputs=[system_message, user_message, assistant_response, max_length, temperature, top_p, top_k, repetition_penalty, chat_history]
     )
-
-    # Allow submitting with Enter key
+
+    # Handle system message reset when page is refreshed
     user_message.submit(
         fn=generate_text,
         inputs=[system_message, user_message, max_length, temperature, top_p, top_k, repetition_penalty],
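Two wiring issues in this hunk: the clear lambda still returns 10 values (`[""] * 3` plus seven more) while the trimmed `outputs` list names only 9 components, a mismatch Gradio rejects when the button is clicked; and the new comment above `user_message.submit` is misleading, since that handler simply allows submitting with the Enter key. A sketch of a matching clear handler; resetting the system message to `system_message_default` rather than to an empty string is my assumption:

```python
# Sketch: 9 outputs, 9 return values.
clear_btn.click(
    lambda: [system_message_default, "", "", 512, 0.7, 0.95, 50, 1.1, []],
    outputs=[system_message, user_message, assistant_response,
             max_length, temperature, top_p, top_k, repetition_penalty,
             chat_history],
)
```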
@@ -167,6 +187,12 @@ with gr.Blocks(title="Dolphin-2.5-Mixtral-8x7b Chat", css=css) as demo:
         [chat_history, user_message]
     )
 
-    # Run the app
+    # Reset system message on page refresh (by using state)
+    system_message.change(
+        fn=lambda message: message,
+        inputs=[system_message],
+        outputs=[system_message_state]
+    )
+
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
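As committed, `system_message_state` is written by the `change` handler but never read, so it has no effect on generation; note too that `gr.State` is per-session and resets on page refresh, contrary to the comment. For the state to matter, some handler has to consume it; a sketch under that assumption (the `outputs` wiring here is illustrative, not from the commit):

```python
# Sketch: read the stored system message instead of the live textbox.
submit_btn.click(
    fn=generate_text,
    inputs=[system_message_state, user_message, max_length, temperature,
            top_p, top_k, repetition_penalty],
    outputs=[assistant_response],
)
```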
 
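Finally, a single Mixtral-8x7B generation can run for many seconds, so enabling Gradio's request queue before launch is a common pattern on Spaces; a sketch, not part of this commit:

```python
if __name__ == "__main__":
    demo.queue()   # queue long-running generation requests
    demo.launch()
```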