GrandMasterPomidor committed on
Commit 54a6bd2 · verified · 1 Parent(s): e37062a

Update inference endpoint handler (20250927-165107)

Files changed (3):
  1. README.md +25 -0
  2. handler.py +418 -0
  3. requirements.txt +10 -0
README.md ADDED
@@ -0,0 +1,25 @@
+ # Qwen Omni Hugging Face Inference Endpoint Handler
+
+ This directory contains a reusable custom handler for deploying Qwen3 Omni models
+ (via the Hugging Face Inference Endpoints service). The handler mirrors the
+ multi-modal interaction blueprint from the official Qwen audio/visual dialogue
+ cookbook and supports text, image, and audio turns in a single payload.
+
+ ## Files
+
+ * `handler.py` – entry point loaded by the Inference Endpoint runtime.
+ * `requirements.txt` – Python dependencies installed before the handler is imported.
+
+ ## Usage
+
+ 1. Upload the contents of this directory (`handler.py`, `requirements.txt`) to a
+    Hugging Face model repository that you control (defaults to
+    `GrandMasterPomidor/qwen-omni-endpoint-handler` via the provided Makefile).
+ 2. Provision a custom Inference Endpoint that references that repository and the
+    Qwen Omni model weights you wish to serve. Set environment variables such as
+    `MODEL_ID` to point at your chosen checkpoint (e.g. `Qwen/Qwen2.5-Omni-Mini`).
+ 3. Send JSON payloads to the endpoint as documented in the header docstring of
+    `handler.py` (see the example request below).
+
+ Refer to the accompanying `Makefile` for convenience targets to package and
+ push these assets.
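
As an illustration of step 3, the sketch below sends one multi-modal turn to a deployed endpoint with `requests`. The endpoint URL, token variable, and image URL are placeholders; the payload shape follows the schema documented in `handler.py`, whose return value includes `generated_text`.

```python
# Minimal request sketch for a deployed endpoint.
# ENDPOINT_URL, HF_TOKEN, and the image URL are placeholders: substitute the
# URL shown in the Inference Endpoints UI and a token that can access it.
import os

import requests

ENDPOINT_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"
HF_TOKEN = os.environ["HF_TOKEN"]

payload = {
    "inputs": {
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Describe the picture"},
                    {"type": "image", "image_url": "https://example.com/photo.jpg"},
                ],
            }
        ]
    },
    "parameters": {"max_new_tokens": 128, "temperature": 0.7},
}

response = requests.post(
    ENDPOINT_URL,
    headers={
        "Authorization": f"Bearer {HF_TOKEN}",
        "Content-Type": "application/json",
    },
    json=payload,
    timeout=120,
)
response.raise_for_status()
print(response.json()["generated_text"])
```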
handler.py ADDED
@@ -0,0 +1,418 @@
+ """Custom Hugging Face Inference Endpoints handler for Qwen Omni models.
+
+ This handler is designed for multi-modal dialogue with the Qwen3 Omni models,
+ following the audio/visual dialogue cookbook in the Qwen repository. It loads
+ an Omni chat model, accepts mixed text, image, and audio content, and returns
+ an assistant reply that can be fed into subsequent turns.
+
+ Expected request payload structure (JSON):
+     {
+         "inputs": {
+             "messages": [
+                 {
+                     "role": "user",
+                     "content": [
+                         {"type": "text", "text": "Describe the picture"},
+                         {"type": "image", "image_url": "https://.../photo.jpg"},
+                         {"type": "audio", "audio_url": "https://.../clip.wav"}
+                     ]
+                 }
+             ]
+         },
+         "parameters": {
+             "max_new_tokens": 256,
+             "temperature": 0.7,
+             "top_p": 0.9
+         }
+     }
+
+ Supported content variants:
+ * Text: provide "text" or "value".
+ * Image: provide one of "image" (base64 string with optional data URI),
+   "image_url" (HTTP(S) URL), or "image_path" (path within the repository).
+ * Audio: provide either
+     - "audio"/"array" with float samples plus "sampling_rate" (Hz), or
+     - base64 data under "audio"/"audio_b64", or
+     - remote/local path via "audio_url"/"audio_path".
+
+ Environment variables:
+ * MODEL_ID (defaults to Qwen/Qwen3-Omni-30B-A3B-Instruct) – Hugging Face model repo.
+ * DEVICE (defaults to cuda if available else cpu) – Inference device override.
+ * DEVICE_MAP (defaults to auto when GPU available) – Passed to from_pretrained.
+ * TORCH_DTYPE (defaults to bfloat16 on GPU, float32 on CPU) – torch dtype name.
+ * MAX_NEW_TOKENS, TEMPERATURE, TOP_P, TOP_K, DO_SAMPLE – override defaults.
+
+ Returned payload:
+     {
+         "generated_text": "...assistant reply...",
+         "messages": [...messages augmented with assistant turn...],
+         "generation_kwargs": {...actual generation settings used...}
+     }
+ """
+
+ from __future__ import annotations
+
+ import base64
+ import io
+ import json
+ import os
+ from dataclasses import dataclass
+ from typing import Any, Dict, Iterable, List, Optional, Tuple
+
+ import numpy as np
+ import torch
+ from PIL import Image
+ from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
+
+ try:
+     import requests
+ except ImportError:  # pragma: no cover - requests is available on endpoints but guard just in case
+     requests = None  # type: ignore
+
+
+ @dataclass
+ class AudioPayload:
+     """Container for audio samples consumed by the Omni processor."""
+
+     array: np.ndarray
+     sampling_rate: int
+
+     def as_processor_input(self) -> Dict[str, Any]:
+         return {
+             "array": self.array.astype(np.float32),
+             "sampling_rate": int(self.sampling_rate),
+         }
+
+
+ class EndpointHandler:
+     """Hugging Face custom handler compatible with multi-modal Qwen Omni models."""
+
+     def __init__(self, path: str = "") -> None:
+         model_id = os.getenv(
+             "MODEL_ID") or path or "Qwen/Qwen3-Omni-30B-A3B-Instruct"
+         device_hint = os.getenv("DEVICE")
+         self.device = device_hint or (
+             "cuda" if torch.cuda.is_available() else "cpu")
+         dtype_name = os.getenv(
+             "TORCH_DTYPE",
+             "bfloat16" if self.device.startswith("cuda") else "float32",
+         )
+         torch_dtype = getattr(torch, dtype_name, None)
+         if torch_dtype is None:
+             raise ValueError(f"Unsupported TORCH_DTYPE value: {dtype_name}")
+
+         model_kwargs: Dict[str, Any] = {
+             "trust_remote_code": True,
+             "torch_dtype": torch_dtype,
+         }
+
+         device_map_env = os.getenv("DEVICE_MAP")
+         if device_map_env:
+             model_kwargs["device_map"] = device_map_env
+         elif self.device != "cpu":
+             model_kwargs["device_map"] = "auto"
+
+         self.model = AutoModelForCausalLM.from_pretrained(
+             model_id, **model_kwargs)
+         if model_kwargs.get("device_map") is None:
+             self.model.to(self.device)
+
+         self.processor = AutoProcessor.from_pretrained(
+             model_id, trust_remote_code=True)
+
+         try:
+             generation_config = GenerationConfig.from_pretrained(model_id)
+         except Exception:  # pragma: no cover - not all repos ship a config
+             generation_config = self.model.generation_config
+         self.base_generation_kwargs = self._extract_generation_kwargs(
+             generation_config)
+
+     # ---------------------------------------------------------------------
+     # Public API
+     # ---------------------------------------------------------------------
+     def __call__(self, data: Dict[str, Any], *args: Any, **kwargs: Any) -> Dict[str, Any]:
+         if not data:
+             raise ValueError("Empty payload received by handler")
+
+         payload = data.get("inputs") if isinstance(data, dict) else data
+         parameters = data.get("parameters", {}) if isinstance(
+             data, dict) else {}
+
+         messages = self._normalize_messages(payload)
+         processed_messages, images, audios = self._prepare_messages(messages)
+
+         chat_template = self.processor.apply_chat_template(
+             processed_messages,
+             add_generation_prompt=True,
+             tokenize=False,
+         )
+
+         model_inputs = self.processor(
+             text=chat_template,
+             images=[img for img in images] if images else None,
+             audios=[aud.as_processor_input()
+                     for aud in audios] if audios else None,
+             return_tensors="pt",
+         )
+
+         if hasattr(model_inputs, "to"):
+             model_inputs = model_inputs.to(self.model.device if hasattr(
+                 self.model, "device") else self.device)
+         else:
+             model_inputs = {
+                 k: v.to(self.model.device if hasattr(
+                     self.model, "device") else self.device)
+                 for k, v in model_inputs.items()
+             }
+
+         generation_kwargs = {**self.base_generation_kwargs, **parameters}
+         generation_kwargs.setdefault("return_dict_in_generate", True)
+         generation_kwargs.setdefault("output_scores", False)
+
+         with torch.inference_mode():
+             outputs = self.model.generate(**model_inputs, **generation_kwargs)
+
+         sequences = outputs.sequences if hasattr(
+             outputs, "sequences") else outputs
+         input_length = model_inputs["input_ids"].shape[-1]
+         generated_ids = sequences[:, input_length:]
+         generated_text = self.processor.batch_decode(
+             generated_ids,
+             skip_special_tokens=True,
+             clean_up_tokenization_spaces=True,
+         )[0].strip()
+
+         augmented_messages = list(messages) + [
+             {
+                 "role": "assistant",
+                 "content": [
+                     {
+                         "type": "text",
+                         "text": generated_text,
+                     }
+                 ],
+             }
+         ]
+
+         return {
+             "generated_text": generated_text,
+             "messages": augmented_messages,
+             "generation_kwargs": generation_kwargs,
+         }
+
+     # ------------------------------------------------------------------
+     # Helpers
+     # ------------------------------------------------------------------
+     @staticmethod
+     def _extract_generation_kwargs(config: GenerationConfig) -> Dict[str, Any]:
+         defaults = {
+             "max_new_tokens": getattr(config, "max_new_tokens", None) or 512,  # GenerationConfig stores None by default
+             "temperature": getattr(config, "temperature", 0.7),
+             "top_p": getattr(config, "top_p", 0.9),
+             "top_k": getattr(config, "top_k", None),
+             "do_sample": getattr(config, "do_sample", True),
+         }
+
+         env_overrides = {
+             "max_new_tokens": os.getenv("MAX_NEW_TOKENS"),
+             "temperature": os.getenv("TEMPERATURE"),
+             "top_p": os.getenv("TOP_P"),
+             "top_k": os.getenv("TOP_K"),
+             "do_sample": os.getenv("DO_SAMPLE"),
+         }
+
+         for key, value in env_overrides.items():
+             if value is None:
+                 continue
+             if key == "do_sample":
+                 defaults[key] = value.lower() == "true"
+             elif key == "max_new_tokens" or key == "top_k":
+                 defaults[key] = int(value)
+             else:
+                 defaults[key] = float(value)
+         return {k: v for k, v in defaults.items() if v is not None}
+
+     @staticmethod
+     def _normalize_messages(payload: Any) -> List[Dict[str, Any]]:
+         if isinstance(payload, str):
+             return [
+                 {
+                     "role": "user",
+                     "content": [{"type": "text", "text": payload}],
+                 }
+             ]
+         if isinstance(payload, dict) and "messages" in payload:
+             return payload["messages"]
+         if isinstance(payload, dict):
+             text_value = payload.get("prompt") or payload.get("text")
+             if text_value:
+                 return [
+                     {
+                         "role": payload.get("role", "user"),
+                         "content": [{"type": "text", "text": text_value}],
+                     }
+                 ]
+         raise ValueError(
+             "Unsupported input format. Provide `inputs.messages` or a raw text prompt.")
+
+     def _prepare_messages(
+         self, messages: Iterable[Dict[str, Any]]
+     ) -> Tuple[List[Dict[str, Any]], List[Image.Image], List[AudioPayload]]:
+         processed_messages: List[Dict[str, Any]] = []
+         images: List[Image.Image] = []
+         audios: List[AudioPayload] = []
+
+         for message in messages:
+             role = message.get("role", "user")
+             raw_content = message.get("content")
+             if raw_content is None:
+                 raise ValueError(f"Message without content: {message}")
+
+             if isinstance(raw_content, str):
+                 raw_content = [{"type": "text", "text": raw_content}]
+
+             new_parts: List[Dict[str, Any]] = []
+             for part in raw_content:
+                 part_type = part.get("type", "text")
+
+                 if part_type == "text":
+                     text = part.get("text") or part.get("value")
+                     if text is None:
+                         raise ValueError(f"Missing text value in part: {part}")
+                     new_parts.append({"type": "text", "text": text})
+
+                 elif part_type == "image":
+                     image = self._load_image(part)
+                     images.append(image)
+                     new_parts.append({"type": "image", "image": image})
+
+                 elif part_type == "audio":
+                     audio_payload = self._load_audio(part)
+                     audios.append(audio_payload)
+                     new_parts.append(
+                         {"type": "audio", "audio": audio_payload.as_processor_input()})
+
+                 else:
+                     raise ValueError(f"Unsupported content type: {part_type}")
+
+             processed_messages.append({"role": role, "content": new_parts})
+
+         return processed_messages, images, audios
+
+     # ------------------------------------------------------------------
+     # Loaders
+     # ------------------------------------------------------------------
+     def _load_image(self, part: Dict[str, Any]) -> Image.Image:
+         if "image" in part and isinstance(part["image"], Image.Image):
+             return part["image"]
+         if "image" in part and isinstance(part["image"], str):
+             return self._decode_image_string(part["image"])
+         if "image_b64" in part:
+             return self._decode_image_string(part["image_b64"])
+         if "image_path" in part:
+             return Image.open(part["image_path"]).convert("RGB")
+         if "image_url" in part:
+             data = self._fetch_remote(part["image_url"])
+             return Image.open(io.BytesIO(data)).convert("RGB")
+         raise ValueError(f"Cannot resolve image content from part: {part}")
+
+     def _load_audio(self, part: Dict[str, Any]) -> AudioPayload:
+         if "audio" in part and isinstance(part["audio"], dict) and "array" in part["audio"]:
+             array = np.asarray(part["audio"]["array"], dtype=np.float32)
+             sampling_rate = int(part["audio"].get(
+                 "sampling_rate", part.get("sampling_rate", 16000)))
+             return AudioPayload(array=array, sampling_rate=sampling_rate)
+
+         if "array" in part:
+             array = np.asarray(part["array"], dtype=np.float32)
+             sampling_rate = int(part.get("sampling_rate", 16000))
+             return AudioPayload(array=array, sampling_rate=sampling_rate)
+
+         audio_bytes: Optional[bytes] = None
+         if "audio" in part and isinstance(part["audio"], str):
+             audio_bytes = self._maybe_read_bytes(part["audio"])
+         elif "audio_b64" in part:
+             audio_bytes = base64.b64decode(part["audio_b64"])
+         elif "audio_path" in part:
+             with open(part["audio_path"], "rb") as handle:
+                 audio_bytes = handle.read()
+         elif "audio_url" in part:
+             audio_bytes = self._fetch_remote(part["audio_url"])
+
+         if audio_bytes is None:
+             raise ValueError(f"Cannot resolve audio content from part: {part}")
+
+         array, sampling_rate = self._decode_audio(audio_bytes)
+         return AudioPayload(array=array, sampling_rate=sampling_rate)
+
+     @staticmethod
+     def _decode_image_string(raw: str) -> Image.Image:
+         if raw.startswith("data:"):
+             raw = raw.split(",", 1)[1]
+         image_bytes = base64.b64decode(raw)
+         return Image.open(io.BytesIO(image_bytes)).convert("RGB")
+
+     @staticmethod
+     def _maybe_read_bytes(value: str) -> bytes:
+         if os.path.exists(value):
+             with open(value, "rb") as handle:
+                 return handle.read()
+         try:
+             if value.startswith("data:"):
+                 value = value.split(",", 1)[1]
+             return base64.b64decode(value)
+         except Exception as exc:
+             raise ValueError(
+                 "Provide either a file path or base64-encoded audio for 'audio'.") from exc
+
+     @staticmethod
+     def _decode_audio(raw_bytes: bytes) -> Tuple[np.ndarray, int]:
+         # Try python-soundfile first, fall back to torchaudio if available.
+         try:
+             import soundfile as sf
+
+             array, sampling_rate = sf.read(io.BytesIO(raw_bytes))
+             if array.ndim > 1:
+                 array = np.mean(array, axis=1)
+             return array.astype(np.float32), int(sampling_rate)
+         except Exception:
+             pass
+
+         try:
+             import torchaudio
+
+             waveform, sampling_rate = torchaudio.load(io.BytesIO(raw_bytes))
+             array = waveform.mean(dim=0).numpy()
+             return array.astype(np.float32), int(sampling_rate)
+         except Exception as exc:
+             raise RuntimeError(
+                 "Unable to decode audio bytes. Install 'soundfile' or 'torchaudio' in requirements."
+             ) from exc
+
+     @staticmethod
+     def _fetch_remote(url: str) -> bytes:
+         if requests is None:
+             raise RuntimeError(
+                 "requests is required to download remote resources")
+         response = requests.get(url, timeout=10)
+         response.raise_for_status()
+         return response.content
+
+
+ if __name__ == "__main__":  # pragma: no cover - simple smoke test entry point
+     handler = EndpointHandler()
+     demo_payload = {
+         "inputs": {
+             "messages": [
+                 {
+                     "role": "user",
+                     "content": [
+                         {"type": "text", "text": "Describe the image"},
+                     ],
+                 }
+             ]
+         },
+         "parameters": {"max_new_tokens": 64},
+     }
+     response = handler(demo_payload)
+     print(json.dumps(response, indent=2))
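
The handler docstring lists several ways to inline media instead of passing URLs. The sketch below base64-encodes local files into those content variants; the file names are placeholders, and the resulting dictionary can be POSTed to the endpoint or handed directly to a locally constructed `EndpointHandler` as a smoke test.

```python
# Sketch: inline image/audio content parts as base64 strings.
# "sample.jpg" and "sample.wav" are placeholder local files; the field names
# ("image", "audio") match the variants documented in the handler docstring.
import base64
import json


def encode_file(path: str) -> str:
    """Return the file contents as a base64 string."""
    with open(path, "rb") as fh:
        return base64.b64encode(fh.read()).decode("utf-8")


payload = {
    "inputs": {
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What do you see and hear?"},
                    {"type": "image", "image": encode_file("sample.jpg")},
                    {"type": "audio", "audio": encode_file("sample.wav")},
                ],
            }
        ]
    },
    "parameters": {"max_new_tokens": 256},
}

# The same dictionary works as the JSON body of an endpoint request or as the
# argument to EndpointHandler(...) when testing locally.
print(json.dumps(payload)[:120], "...")
```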
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ # Dependencies for the Qwen Omni custom inference handler
+ transformers>=4.43.0
+ accelerate>=0.33.0
+ torch>=2.2.0
+ sentencepiece
+ numpy>=1.24
+ pillow>=10.0
+ requests>=2.31
+ soundfile>=0.12
+ torchaudio>=2.2
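
The environment variables described in the handler docstring can also be exercised locally before deploying. A minimal sketch, assuming the dependencies above are installed and the script runs from this repository so that `handler.py` is importable; when deploying, the same variables would typically be set in the endpoint configuration instead.

```python
# Sketch: override generation defaults via the environment variables the
# handler reads (MODEL_ID, MAX_NEW_TOKENS, TEMPERATURE, DO_SAMPLE, ...).
# The values shown are illustrative; set them before constructing the handler.
import os

os.environ["MODEL_ID"] = "Qwen/Qwen3-Omni-30B-A3B-Instruct"
os.environ["MAX_NEW_TOKENS"] = "128"
os.environ["TEMPERATURE"] = "0.6"
os.environ["DO_SAMPLE"] = "true"

from handler import EndpointHandler  # noqa: E402 - import after env setup

handler = EndpointHandler()  # downloads the model weights on first use
print(handler.base_generation_kwargs)
```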