Uppalapati committed
Commit 36bb471 · verified · 1 Parent(s): acb3a08

Update app.py

Files changed (1): app.py (+76 -83)
app.py CHANGED
@@ -2,9 +2,9 @@ import os
 import torch
 import numpy as np
 from transformers import AutoTokenizer, AutoModel
-from flask import Flask, request, jsonify
+import gradio as gr
 import logging
-import spaces  # HuggingFace Spaces GPU decorator
+import spaces
 
 # Try to import flash attention (optional)
 try:
@@ -19,8 +19,6 @@ except ImportError:
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-app = Flask(__name__)
-
 # Qwen3-Embedding-4B model for retrieval
 MODEL_NAME = "Qwen/Qwen3-Embedding-4B"
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@@ -31,7 +29,7 @@ tokenizer = None
 model = None
 
 def initialize_model():
-    """Initialize model (runs on CPU in main process)"""
+    """Initialize model"""
     global tokenizer, model
 
     if tokenizer is None:
@@ -44,7 +42,6 @@ def initialize_model():
     if model is None:
         logger.info(f"Loading {MODEL_NAME} on {DEVICE}")
 
-        # Configure model loading with optional flash attention
         model_kwargs = {
             "trust_remote_code": True,
             "torch_dtype": torch.float16 if DEVICE == "cuda" else torch.float32
@@ -59,12 +56,15 @@ def initialize_model():
     model.eval()
     logger.info("✅ Model loaded successfully")
 
-# CRITICAL: This must be a TOP-LEVEL function with @spaces.GPU decorator
 @spaces.GPU
-def encode_texts_gpu(texts, batch_size=16):
+def encode_texts_gpu(texts_str, batch_size=16):
     """
     Encode texts to embeddings using Qwen3-Embedding-4B
-    This function MUST be at module level for ZeroGPU detection
+    Args:
+        texts_str: Either a single text string or multiple texts separated by '|||'
+        batch_size: Batch size for encoding
+    Returns:
+        JSON string with embeddings
     """
     global tokenizer, model
 
@@ -72,8 +72,11 @@ def encode_texts_gpu(texts, batch_size=16):
     if model is None or tokenizer is None:
         initialize_model()
 
-    if isinstance(texts, str):
-        texts = [texts]
+    # Parse input - support both single text and multiple texts
+    if '|||' in texts_str:
+        texts = [t.strip() for t in texts_str.split('|||')]
+    else:
+        texts = [texts_str]
 
     embeddings = []
 
@@ -93,7 +96,6 @@ def encode_texts_gpu(texts, batch_size=16):
 
         with torch.no_grad():
             outputs = model(**inputs)
-            # Use EOS token embedding for Qwen3
             eos_token_id = tokenizer.eos_token_id
             sequence_lengths = (inputs['input_ids'] == eos_token_id).long().argmax(-1) - 1
 
@@ -103,82 +105,73 @@ def encode_texts_gpu(texts, batch_size=16):
             batch_embeddings.append(embedding)
 
         batch_embeddings = np.array(batch_embeddings)
-
-        # Normalize embeddings
         batch_embeddings = batch_embeddings / np.linalg.norm(batch_embeddings, axis=1, keepdims=True)
 
         embeddings.extend(batch_embeddings)
 
-    return embeddings
-
-@app.route("/", methods=["GET"])
-def health_check():
-    return jsonify({
-        "status": "healthy",
+    # Format output
+    import json
+    result = {
+        "embeddings": [emb.tolist() for emb in embeddings],
         "model": MODEL_NAME,
-        "device": DEVICE,
-        "embedding_dim": EMBEDDING_DIM,
-        "max_context": 32768
-    })
-
-@app.route("/embed", methods=["POST"])
-def embed_texts():
-    """Embed texts and return embeddings"""
-    try:
-        data = request.get_json()
-
-        if not data or "texts" not in data:
-            return jsonify({"error": "Missing 'texts' field"}), 400
-
-        texts = data["texts"]
-        if not isinstance(texts, list):
-            texts = [texts]
-
-        logger.info(f"Embedding {len(texts)} texts")
-
-        # Call the GPU-decorated function
-        embeddings = encode_texts_gpu(texts)
-
-        return jsonify({
-            "embeddings": [embedding.tolist() for embedding in embeddings],
-            "model": MODEL_NAME,
-            "dimension": len(embeddings[0]) if embeddings else 0,
-            "count": len(embeddings)
-        })
-
-    except Exception as e:
-        logger.error(f"Embedding error: {str(e)}")
-        return jsonify({"error": str(e)}), 500
+        "dimension": len(embeddings[0]) if embeddings else 0,
+        "count": len(embeddings)
+    }
+
+    return json.dumps(result, indent=2)
 
-@app.route("/embed_single", methods=["POST"])
-def embed_single():
-    """Embed single text (convenience endpoint)"""
-    try:
-        data = request.get_json()
-
-        if not data or "text" not in data:
-            return jsonify({"error": "Missing 'text' field"}), 400
-
-        text = data["text"]
-        logger.info(f"Embedding single text: {text[:100]}...")
-
-        # Call the GPU-decorated function
-        embeddings = encode_texts_gpu([text])
-
-        return jsonify({
-            "embedding": embeddings[0].tolist(),
-            "model": MODEL_NAME,
-            "dimension": len(embeddings[0]),
-            "text_length": len(text)
-        })
-
-    except Exception as e:
-        logger.error(f"Single embedding error: {str(e)}")
-        return jsonify({"error": str(e)}), 500
+# Create Gradio interface
+with gr.Blocks(title="Qwen3-Embedding-4B API") as demo:
+    gr.Markdown("""
+    # Qwen3-Embedding-4B Embedding Service
+
+    This service generates embeddings using Qwen3-Embedding-4B (2560 dimensions).
+
+    **Usage:**
+    - Single text: Enter your text directly
+    - Multiple texts: Separate texts with `|||` (e.g., `text1|||text2|||text3`)
+    """)
+
+    with gr.Row():
+        with gr.Column():
+            text_input = gr.Textbox(
+                label="Text Input",
+                placeholder="Enter text or multiple texts separated by '|||'",
+                lines=5
+            )
+            batch_size_input = gr.Slider(
+                minimum=1,
+                maximum=64,
+                value=16,
+                step=1,
+                label="Batch Size"
+            )
+            submit_btn = gr.Button("Generate Embeddings", variant="primary")
+
+        with gr.Column():
+            output = gr.JSON(label="Embeddings Output")
+
+    submit_btn.click(
+        fn=encode_texts_gpu,
+        inputs=[text_input, batch_size_input],
+        outputs=output
+    )
+
+    gr.Markdown("""
+    ### API Usage
+    You can also call this Space via API:
+    ```
+    from gradio_client import Client
+
+    client = Client("YOUR_USERNAME/YOUR_SPACE_NAME")
+    result = client.predict(
+        texts_str="Your text here",
+        batch_size=16,
+        api_name="/predict"
+    )
+    print(result)
+    ```
+    """)
 
 if __name__ == "__main__":
-    logger.info("🚀 Starting embedding service...")
-    logger.info("⚡ Model will load on first GPU request (ZeroGPU lazy loading)")
-
-    port = int(os.environ.get("PORT", 7860))
-    app.run(host="0.0.0.0", port=port)
+    demo.launch(server_name="0.0.0.0", server_port=7860)
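
For quick verification of the new interface, a minimal client-side sketch (assumptions: the Space id below is a placeholder, and since the `click` handler sets no explicit `api_name`, recent Gradio releases typically expose the endpoint under the function name `/encode_texts_gpu` rather than `/predict`; confirm with `client.view_api()`):

```python
# Sketch of calling the updated Space from a client. Assumptions:
# the Space id is a placeholder and the endpoint name may differ
# ("/encode_texts_gpu" vs. "/predict"); check client.view_api().
import json
from gradio_client import Client

client = Client("YOUR_USERNAME/YOUR_SPACE_NAME")  # placeholder Space id

# Multiple texts travel as one string, joined with the '|||' separator
# that encode_texts_gpu splits on.
texts = ["first document", "second document", "third document"]
raw = client.predict(
    texts_str="|||".join(texts),
    batch_size=16,
    api_name="/encode_texts_gpu",  # assumption; may be "/predict"
)

# encode_texts_gpu returns a JSON *string*; depending on the Gradio
# version the gr.JSON output may arrive already parsed, so handle both.
result = json.loads(raw) if isinstance(raw, str) else raw
print(result["count"], result["dimension"])  # e.g. 3 2560
```

Because the handler L2-normalizes each embedding before returning it, cosine similarity between any two returned vectors reduces to a plain dot product.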