.gitignore DELETED
@@ -1 +0,0 @@
1
- .DS_Store
 
 
README.md CHANGED
@@ -5,530 +5,379 @@ license: cc-by-nc-4.0
5
  tags:
6
  - multimodal
7
  - multilingual
 
8
  - vlm
9
- - vision-language
10
- - qwen3
11
- - siglip2
12
  language:
13
  - en
14
- - zh
15
- - ar
16
- - pt
17
- - ru
18
- - tr
19
- - de
20
- - es
21
- - fr
22
- - it
23
- - ja
24
- - ko
25
- - vi
26
- - th
27
- - id
28
- - hi
29
- - bn
30
- - nl
31
- - pl
32
- - sv
33
- - fi
34
- - da
35
- - "no"
36
- - cs
37
- - el
38
- - he
39
- - uk
40
- - ro
41
- - hu
42
  - multilingual
43
- base_model:
44
- - Qwen/Qwen3-1.7B-Base
45
- - google/siglip2-so400m-patch14-384
46
  inference: false
47
  ---
 
48
 
49
  <p align="center">
50
- <img src="https://huggingface.co/datasets/jinaai/documentation-images/resolve/main/logo.webp" alt="Jina AI: Your Search Foundation, Supercharged!" width="150px">
51
  </p>
52
 
53
- # jina-vlm: Small Multilingual Vision Language Model
 
 
54
 
55
- [Blog](https://jina.ai/news/jina-vlm-small-multilingual-vision-language-model/) | API | AWS | Azure | GCP | [Arxiv](https://arxiv.org/abs/2512.04032)
56
 
57
- `jina-vlm` is a 2.4B parameter vision-language model that achieves state-of-the-art multilingual visual question answering among open 2B-scale VLMs. The model couples a SigLIP2 vision encoder with a Qwen3 language backbone through an attention-pooling connector that enables token-efficient processing of arbitrary-resolution images. Training data comprises approximately 5M multimodal samples and 12B text tokens across 29 languages, with roughly half in English and the remainder spanning high- and moderate-resource languages.
58
 
59
- ![jina-vlm architecture](./assets/jvlm_architecture.png)
60
 
61
- Built on [Qwen3-1.7B-Base](https://huggingface.co/Qwen/Qwen3-1.7B-Base) with [SigLIP2-So400M](https://huggingface.co/google/siglip2-so400m-patch14-384), it processes images via overlapping tiling with attention-based token pooling that reduces visual tokens by 4x while preserving spatial information. The model achieves the highest average score (72.3) across eight VQA benchmarks while leading on multilingual multimodal understanding (MMMB: 78.8, Multilingual MMBench: 74.3).
62
 
63
- | Model | Params | VQA Avg | MMMB | MM-Bench | RealWorld QA |
64
- |-------|--------|---------|------|----------|--------------|
65
- | **jina-vlm** | 2.4B | **72.3** | **78.8** | **74.3** | **68.2** |
66
- | Qwen2-VL-2B | 2.2B | 66.4 | 71.3 | 69.4 | 62.9 |
67
- | Qwen3-VL-2B | 2.2B | 71.6 | 75.0 | 72.3 | 63.9 |
68
- | InternVL3-2B | 2.2B | 69.2 | 73.6 | 71.9 | 64.3 |
69
- | InternVL3.5-2B | 2.2B | 71.6 | 74.6 | 70.9 | 62.0 |
70
 
 
71
 
72
- ## Via Jina API
73
 
74
- We provide an OpenAI-compatible API at `https://api-beta-vlm.jina.ai`. All requests require a Jina API key in the Authorization header, get your API key at [jina.ai](https://jina.ai).
75
 
76
 
77
- ### Image from URL
78
 
79
- | Format | Example |
80
- |--------|---------|
81
- | HTTP/HTTPS URL | `https://example.com/image.jpg` |
82
- | Base64 data URI | `...` |
83
 
84
- ```bash
85
- curl https://api-beta-vlm.jina.ai/v1/chat/completions \
86
- -H "Content-Type: application/json" \
87
- -H "Authorization: Bearer $JINA_API_KEY" \
88
- -d '{
89
- "model": "jina-vlm",
90
- "messages": [{
91
- "role": "user",
92
- "content": [
93
- {"type": "text", "text": "Describe this image"},
94
- {"type": "image_url", "image_url": {"url": "https://example.com/photo.jpg"}}
95
- ]
96
- }]
97
- }'
98
- ```
99
 
100
 
101
- ### Local image (base64)
102
 
103
- ```bash
104
- curl https://api-beta-vlm.jina.ai/v1/chat/completions \
105
- -H "Content-Type: application/json" \
106
- -H "Authorization: Bearer $JINA_API_KEY" \
107
- -d '{
108
- "model": "jina-vlm",
109
- "messages": [{
110
- "role": "user",
111
- "content": [
112
- {"type": "text", "text": "What is in this image?"},
113
- {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,'$(base64 -i image.jpg)'"}}
114
- ]
115
- }]
116
- }'
117
- ```
118
 
 
119
 
120
- ### Text-only query
 
 
 
 
 
 
 
 
 
121
 
122
- ```bash
123
- curl https://api-beta-vlm.jina.ai/v1/chat/completions \
124
- -H "Content-Type: application/json" \
125
- -H "Authorization: Bearer $JINA_API_KEY" \
126
- -d '{
127
- "model": "jina-vlm",
128
- "messages": [{"role": "user", "content": "What is the capital of France?"}]
129
- }'
130
- ```
131
 
132
- ### Streaming response
133
 
134
- Add `"stream": true` to receive tokens as they're generated:
135
 
136
- ```bash
137
- curl https://api-beta-vlm.jina.ai/v1/chat/completions \
138
- -H "Content-Type: application/json" \
139
- -H "Authorization: Bearer $JINA_API_KEY" \
140
- -d '{
141
- "model": "jina-vlm",
142
- "stream": true,
143
- "messages": [{"role": "user", "content": "Write a haiku about coding"}]
144
- }'
145
- ```
146
 
147
- When the service is cold starting, you'll receive:
148
 
149
- ```json
150
- {
151
- "error": {
152
- "message": "Model is loading, please retry in 30-60 seconds. Cold start takes ~30s after the service scales up.",
153
- "code": 503
154
- }
155
- }
156
- ```
157
 
158
- Simply retry your request after waiting.
159
 
 
160
 
161
- ## Local Installation
 
 
 
 
 
 
162
 
163
- ```bash
164
- uv sync
165
- ```
166
 
167
- For CUDA users with FlashAttention2 support:
168
- ```bash
169
- uv sync --extra flash-attn
170
- ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
 
172
  ### Using the CLI
173
 
174
- You can directly chat with `jina-vlm` using the `infer.py` CLI:
 
 
 
 
 
 
 
 
 
 
 
 
175
 
176
  ```bash
177
  # Single image
178
- python infer.py -i image.jpg -p "What's in this image?"
179
 
180
- # Streaming output
181
- python infer.py -i image.jpg -p "Describe this image" --stream
182
 
183
- # Multiple images
184
- python infer.py -i img1.jpg -i img2.jpg -p "Compare these images"
185
 
186
- # Text-only
187
- python infer.py -p "What is the capital of France?"
188
- ```
189
 
190
- **Options:**
191
- - `-m, --model`: Model path. Auto-detects local repo (if `config.json` exists) or falls back to `jinaai/jina-vlm` from HuggingFace.
192
- - `-i, --image`: Image path, URL, or glob pattern (can specify multiple times).
193
- - `-p, --prompt`: Text prompt (can specify multiple times).
194
- - `--max-crops`: Maximum crops (default: 12).
195
- - `--max-tokens`: Maximum output tokens (default: 1024).
196
- - `--max-pixels`: Max pixels per image, larger images are resized preserving aspect ratio.
197
- - `--stream`: Enable streaming output.
198
 
199
- **Example:**
 
 
200
 
201
- ```bash
202
- python infer.py -i assets/the_persistence_of_memory.jpg -p "Describe this picture"
203
- ```
204
 
205
- <table>
206
- <tr>
207
- <td width="40%"><b>Input</b></td>
208
- <td width="60%"><b>Output</b></td>
209
- </tr>
210
- <tr>
211
- <td><img src="./assets/the_persistence_of_memory.jpg" width="100%"></td>
212
- <td>
213
 
214
- ```
215
- * Conversation 1/1
216
- ├── 🖼️Images: ['the_persistence_of_memory.jpg']
217
- ├── 📜Prompt: Describe this picture
218
- └── 🧠Response: This image is a surreal painting
219
- by Salvador Dalí, titled "The Persistence of
220
- Memory." It features a dreamlike landscape with
221
- a variety of melting clocks and other objects.
222
- The central focus is a melting clock with a blue
223
- face and yellow hands, which is hanging from a
224
- branch...
225
-
226
- Token usage: 1753 tokens (4.3%)
227
- Generated in 8.68s | 20.04 tok/s
228
- ```
229
 
230
- </td>
231
- </tr>
232
- </table>
233
 
234
- ### Using Transformers
 
 
 
235
 
236
- ```python
237
- import torch
238
- from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
239
 
240
- processor = AutoProcessor.from_pretrained(
241
- 'jinaai/jina-vlm', use_fast=False, trust_remote_code=True
242
- )
243
- model = AutoModelForCausalLM.from_pretrained(
244
- 'jinaai/jina-vlm',
245
- device_map='auto',
246
- trust_remote_code=True
247
- )
248
 
249
- image = 'https://picsum.photos/800/600'
250
- conversation = [
251
- {
252
- 'role': 'user',
253
- 'content': [
254
- {'type': 'image', 'image': image},
255
- {'type': 'text', 'text': 'Describe this image'},
256
- ],
257
- }
258
- ]
259
 
260
- text = processor.apply_chat_template(conversation, add_generation_prompt=True)
261
- inputs = processor(text=[text], images=[image], padding='longest', return_tensors='pt')
262
- inputs = {k: v.to(model.device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
263
 
264
- output = model.generate(
265
- **inputs,
266
- generation_config=GenerationConfig(max_new_tokens=512, do_sample=False),
267
- return_dict_in_generate=True,
268
- use_model_defaults=True,
269
- )
270
 
271
- response = processor.tokenizer.decode(
272
- output.sequences[0][inputs['input_ids'].shape[-1]:],
273
- skip_special_tokens=True
274
- )
275
- print(response)
 
 
 
 
 
276
  ```
277
 
278
- <details>
279
- <summary>Multi-image inference</summary>
280
 
281
  ```python
282
- images = ['https://picsum.photos/id/1/800/600', 'https://picsum.photos/id/2/800/600']
283
- conversation = [
284
- {
285
- 'role': 'user',
286
- 'content': [
287
- {'type': 'image', 'image': images[0]},
288
- {'type': 'image', 'image': images[1]},
289
- {'type': 'text', 'text': 'What is the difference between these images?'},
290
- ],
291
- }
292
- ]
293
- text = processor.apply_chat_template(conversation, add_generation_prompt=True)
294
- inputs = processor(text=[text], images=images, padding='longest', return_tensors='pt')
295
- inputs = {k: v.to(model.device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
296
-
297
- output = model.generate(
298
- **inputs,
299
- generation_config=GenerationConfig(max_new_tokens=512, do_sample=False),
300
- return_dict_in_generate=True,
301
- use_model_defaults=True,
302
- )
303
- response = processor.tokenizer.decode(
304
- output.sequences[0][inputs['input_ids'].shape[-1]:],
305
- skip_special_tokens=True
306
  )
307
- print(response)
308
- ```
309
 
310
- </details>
 
 
 
 
 
 
311
 
312
- <details>
313
- <summary>Text-only inference</summary>
314
 
315
- ```python
316
- conversation = [
 
 
 
 
 
317
  {
318
- 'role': 'user',
319
- 'content': [
320
- {'type': 'text', 'text': 'Explain quantum computing in simple terms'},
 
 
 
 
321
  ],
322
  }
323
  ]
324
- text = processor.apply_chat_template(conversation, add_generation_prompt=True)
325
- inputs = processor(text=[text], padding='longest', return_tensors='pt')
326
- inputs = {k: v.to(model.device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
327
-
328
- output = model.generate(
329
- **inputs,
330
- generation_config=GenerationConfig(max_new_tokens=512, do_sample=False),
331
- return_dict_in_generate=True,
332
- use_model_defaults=True,
333
- )
334
- response = processor.tokenizer.decode(
335
- output.sequences[0][inputs['input_ids'].shape[-1]:],
336
- skip_special_tokens=True
337
- )
338
- print(response)
339
- ```
340
-
341
- </details>
342
 
343
- <details>
344
- <summary>Batch inference</summary>
345
-
346
- ```python
347
- import torch
348
- from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
349
-
350
- processor = AutoProcessor.from_pretrained(
351
- 'jinaai/jina-vlm', use_fast=False, trust_remote_code=True
352
  )
353
- model = AutoModelForCausalLM.from_pretrained(
354
- 'jinaai/jina-vlm',
355
- device_map='auto',
356
- torch_dtype=torch.bfloat16,
357
- attn_implementation='flash_attention_2',
358
- trust_remote_code=True
 
359
  )
 
360
 
361
- images = [
362
- 'https://picsum.photos/id/22/800/600',
363
- 'https://picsum.photos/id/49/800/600'
364
- ]
365
- conversations = [
366
- [
367
- {
368
- 'role': 'user',
369
- 'content': [
370
- {'type': 'image', 'image': images[0]},
371
- {'type': 'text', 'text': 'What is the man doing in this image?'},
372
- ],
373
- }
374
- ],
375
- [
376
- {
377
- 'role': 'user',
378
- 'content': [
379
- {'type': 'image', 'image': images[1]},
380
- {'type': 'text', 'text': 'What country\'s flag is in this image?'},
381
- ],
382
- }
383
- ],
384
  ]
385
-
386
- texts = processor.apply_chat_template(conversations, add_generation_prompt=True)
387
- inputs = processor(text=texts, images=images, padding='longest', return_tensors='pt')
388
- inputs = {k: v.to(model.device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
389
-
390
- output = model.generate(
391
- **inputs,
392
- generation_config=GenerationConfig(max_new_tokens=512, do_sample=False),
393
- return_dict_in_generate=True,
394
- use_model_defaults=True,
395
  )
396
-
397
- for idx in range(len(output.sequences)):
398
- gen_ids = output.sequences[idx][inputs['input_ids'].shape[-1]:]
399
- response = processor.tokenizer.decode(gen_ids, skip_special_tokens=True)
400
- print(f"Response {idx+1}: {response}")
401
  ```
402
 
 
 
403
  </details>
404
 
405
  <details>
406
- <summary>Batch inference with mixed examples</summary>
407
-
408
- ```python
409
- import torch
410
- from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
411
-
412
- processor = AutoProcessor.from_pretrained(
413
- 'jinaai/jina-vlm', use_fast=False, trust_remote_code=True
414
- )
415
- model = AutoModelForCausalLM.from_pretrained(
416
- 'jinaai/jina-vlm',
417
- device_map='auto',
418
- torch_dtype=torch.bfloat16,
419
- attn_implementation='flash_attention_2',
420
- trust_remote_code=True
421
- )
422
 
423
- images = [
424
- ['https://picsum.photos/id/22/800/600'],
425
- ['https://picsum.photos/id/49/800/600'],
426
- ['https://picsum.photos/id/0/800/600', 'https://picsum.photos/id/2/800/600'],
427
- [],
428
- ]
429
- conversations = [
430
- [
431
- {
432
- 'role': 'user',
433
- 'content': [
434
- {'type': 'image', 'image': images[0][0]},
435
- {'type': 'text', 'text': 'What is the man doing in this image?'},
436
- ],
437
- }
438
- ],
439
- [
440
- {
441
- 'role': 'user',
442
- 'content': [
443
- {'type': 'image', 'image': images[1][0]},
444
- {'type': 'text', 'text': 'What country\'s flag is in this image?'},
445
- ],
446
- }
447
- ],
448
- [
449
- {
450
- 'role': 'user',
451
- 'content': [
452
- {'type': 'image', 'image': images[2][0]},
453
- {'type': 'image', 'image': images[2][1]},
454
- {'type': 'text', 'text': 'What is the difference between these two images?'},
455
- ],
456
- }
457
- ],
458
- [
459
- {
460
- 'role': 'user',
461
- 'content': [
462
- {'type': 'text', 'text': 'Describe the concept of polymorphism in Computer Science'},
463
- ],
464
- }
465
- ],
466
- ]
467
 
468
- texts = processor.apply_chat_template(conversations, add_generation_prompt=True)
469
- inputs = processor(text=texts, images=images, padding='longest', return_tensors='pt')
470
- inputs = {k: v.to(model.device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
471
 
472
- output = model.generate(
473
- **inputs,
474
- generation_config=GenerationConfig(max_new_tokens=512, do_sample=False),
475
- return_dict_in_generate=True,
476
- use_model_defaults=True,
477
- )
478
 
479
- for idx in range(len(output.sequences)):
480
- gen_ids = output.sequences[idx][inputs['input_ids'].shape[-1]:]
481
- response = processor.tokenizer.decode(gen_ids, skip_special_tokens=True)
482
- print(f"Response {idx+1}: {response}")
483
- ```
484
 
485
- </details>
486
 
487
- ## Evaluation
488
 
489
- ### Multilingual Understanding
490
 
491
- | Model | MMMB ar | MMMB cn | MMMB en | MMMB avg | MMBench avg | Overall |
492
- |-------|---------|---------|---------|----------|-------------|---------|
493
- | **jina-vlm** | **76.9** | **80.0** | **82.0** | **78.8** | **74.3** | **59.6** |
494
- | Qwen2-VL-2B | 68.3 | 74.2 | 78.3 | 71.3 | 69.4 | 53.8 |
495
- | Qwen3-VL-2B | 72.7 | 75.7 | 80.7 | 75.0 | 72.3 | 58.2 |
496
- | InternVL3-2B | 68.6 | 78.3 | 81.9 | 73.6 | 71.9 | 57.4 |
497
- | InternVL3.5-2B | 68.5 | 77.7 | 80.2 | 74.6 | 70.9 | 58.0 |
498
 
499
- ### General VQA Tasks
500
 
501
- | Model | AI2D | ChartQA | TextVQA | DocVQA | InfoVQA | OCRBench | SEED-2+ | CharXiv | Avg |
502
- |-------|------|---------|---------|--------|---------|----------|---------|---------|-----|
503
- | **jina-vlm** | **82.0** | **81.9** | **83.2** | 90.6 | 71.6 | 778 | 67.2 | **32.3**/63.5 | **72.3** |
504
- | Qwen2-VL-2B | 74.7 | 73.5 | 79.7 | 89.2 | 64.0 | 809 | 62.4 | 23.3/55.0 | 66.4 |
505
- | Qwen3-VL-2B | 76.9 | 77.2 | 79.5 | **92.3** | **71.9** | **858** | 67.3 | 28.8/62.3 | 71.6 |
506
- | InternVL3-2B | 78.6 | 80.2 | 77.0 | 87.4 | 67.1 | 835 | 64.6 | 28.3/54.7 | 69.2 |
507
- | InternVL3.5-2B | 78.8 | 80.7 | 76.5 | 88.5 | 69.3 | 836 | **68.0** | 31.6/**65.0** | 71.6 |
508
 
509
- ### Text-Only Performance
510
 
511
- | Model | MMLU | MMLU-Pro | GSM-8K | ARC-C | HellaSwag |
512
- |-------|------|----------|--------|-------|-----------|
513
- | **jina-vlm** | 56.1 | **30.3** | 71.3 | **77.3** | **59.4** |
514
- | Qwen3-1.7B | **62.6** | 46.4 | **75.3** | 73.4 | 59.0 |
515
 
516
  ## Citation
517
 
518
- If you find `jina-vlm` useful in your research, please cite our [technical report](https://arxiv.org/abs/2512.04032):
519
-
520
- ```bibtex
521
- @misc{koukounas2025jinavlm,
522
- title={Jina-VLM: Small Multilingual Vision Language Model},
523
- author={Andreas Koukounas and Georgios Mastrapas and Florian Hönicke and Sedigheh Eslami and Guillaume Roncari and Scott Martens and Han Xiao},
524
- year={2025},
525
- eprint={2512.04032},
526
- archivePrefix={arXiv},
527
- primaryClass={cs.CL},
528
- url={https://arxiv.org/abs/2512.04032},
529
- }
530
- ```
531
 
532
- ## License
533
-
534
- `jina-vlm` is licensed under CC BY-NC 4.0. For commercial usage inquiries, feel free to [contact us](https://jina.ai/contact-sales/).
 
 
5
  tags:
6
  - multimodal
7
  - multilingual
8
+ - vllm
9
  - vlm
10
+ - mllm
 
 
11
  language:
12
  - en
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  - multilingual
 
 
 
14
  inference: false
15
  ---
16
+ <br><br>
17
 
18
  <p align="center">
19
+ <img src="https://raw.githubusercontent.com/jina-ai/.github/refs/heads/main/profile/1.png">
20
  </p>
21
 
22
+ <p align="center">
23
+ <b>By <a href="https://jina.ai/"><b>Jina AI</b></a></b>
24
+ </p>
25
 
26
+ TODO: Update title when ready
27
 
28
+ # Jina VLM v1: Lightweight Vision Language Alignment
29
 
30
+ [GGUF]() | [Blog]() | [Technical Report]()
31
 
32
+ A small 🔍
33
 
34
+ Yet Mighty 🔥
 
 
 
 
 
 
35
 
36
+ Multimodal 👁️
37
 
38
+ and Multilingual 🌐
39
 
40
+ Vision-Language Model 🧠
41
 
42
 
43
+ ## Overview
44
 
45
+ TODO: Update overview when ready
 
 
 
46
 
47
+ We introduce `jina-vlm-v1`, a compact vision-language model with a focus on downstream embedding performance, computational efficiency, pure text performance and multilingual support. We explore the alignment of an encoder-only vision model with a decoder-only language model, with an emphasis on representation learning under a resource-constrained setting. Our approach employs a straightforward two-stage training strategy with fully unlocked model weights. Images are converted into fixed-size crops via overlapped cropping to enable high-resolution and any-resolution understanding. The crops are then split into patches and embedded to visual features by the vision encoder. The visual features are pooled, projected and injected into a small language model as visual tokens. We openly release `jina-vlm-v1` to facilitate further research in this domain.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
 
50
+ ## Model Info
51
 
52
+ <p align="center">
53
+ <img src="./assets/jvlm_architecture.png">
54
+ </p>
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
+ Summary of features:
57
 
58
+ | Feature | Jina VLM v1 |
59
+ |-------------------------|----------------------------------------------------------------------------|
60
+ | Type | VLM - Vision Language Model |
61
+ | Modalities | Texts, Images |
62
+ | Base Text Decoder | [Qwen3-1.7B-Base](https://huggingface.co/Qwen/Qwen3-1.7B-Base) |
63
+ | Base Vision Encoder | [SigLIP2 So400M](https://huggingface.co/google/siglip2-so400m-patch14-384) |
64
+ | Parameters | 2.4B |
65
+ | Max Sequence Length | 32768 |
66
+ | Single-Vector Dimension | 2048 |
67
+ | Attention Mechanisms | FlashAttention2, SDPA, Eager |
68
 
69
+ TODO: Add ArXiv link when ready
 
 
 
 
 
 
 
 
70
 
71
+ Check out our [technical report of jina-vlm-v1]() for more details on model architecture, training and evaluation.
72
 
 
73
 
74
+ ## Evaluation
 
 
 
 
 
 
 
 
 
75
 
76
+ ### General VQA Tasks
77
 
78
+ | Model Name | AI2D | ChartQA (test avg) | TextVQA (val) | DocVQA (val) | InfoVQA (val) | OCR Bench | SEED-2 Plus | CharXiv (RQ/DQ) | Overall |
79
+ |:--------------------------------------------------------------------|:--------:|:------------------:|:-------------:|:------------:|:-------------:|:---------:|:-----------:|:---------------:|:--------:|
80
+ | [`jina-vlm-v1`](https://huggingface.co/jinaai/jina-vlm-v1) | **82.0** | **81.9** | **83.2** | 90.6 | 71.6 | 778 | 67.2 | **32.3** / 63.5 | **72.3** |
81
+ | [`Qwen2-VL-2B`](https://huggingface.co/Qwen/Qwen2-VL-2B) | 74.7 | 73.5 | 79.7 | 89.2* | 64.0* | 809 | 62.4 | 23.3 / 55.0* | 66.4 |
82
+ | [`Qwen3-VL-2B`](https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct) | 76.9 | 77.2 | 79.5 | **92.3*** | **71.9*** | **858** | 67.3* | 28.8 / 62.3 | 71.6 |
83
+ | [`InternVL3-2B`](https://huggingface.co/OpenGVLab/InternVL3-2B) | 78.6 | 80.2 | 77.0 | 87.4* | 67.1* | 835 | 64.6 | 28.3 / 54.7 | 69.2 |
84
+ | [`InternVL3.5-2B`](https://huggingface.co/OpenGVLab/InternVL3_5-2B) | 78.8 | 80.7 | 76.5 | 88.5* | 69.3* | 836 | **68.0** | 31.6 / **65.0** | 71.6 |
 
85
 
86
+ Comparison of general visual question answering performance. Other model results are from their respective papers, except those marked with * which are computed using [VLMEvalKit](https://github.com/open-compass/VLMEvalKit). All scores represent accuracy (%) except OCRBench which uses a 0-1000 scale, normalized to 0-100 for Overall calculation.
87
 
88
+ ### Multimodal Comprehension and Real-World Understanding
89
 
90
+ | Model | MME (sum) | MMB v1.1 (EN) | MMStar | Overall (MM) | RealWorld QA | MME-RW (EN) | R-Bench (dis) | Overall (RW) |
91
+ |:--------------------------------------------------------------------|:---------:|:-------------:|:------:|:------------:|:------------:|:-----------:|:-------------:|:------------:|
92
+ | [`jina-vlm-v1`](https://huggingface.co/jinaai/jina-vlm-v1) | 1965.8 | 75.8 | 56.2 | 67.4 | **68.2** | 50.7 | 66.7 | 61.9 |
93
+ | [`Qwen2-VL-2B`](https://huggingface.co/Qwen/Qwen2-VL-2B) | 1872.0 | 72.2 | 48.0 | 62.4 | 62.9 | 38.7* | 63.2 | 55.0* |
94
+ | [`Qwen3-VL-2B`](https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct) | 2000.8* | 77.8 | 58.3 | 69.2 | 63.9 | **57.9*** | 67.3* | **63.0** |
95
+ | [`InternVL3-2B`](https://huggingface.co/OpenGVLab/InternVL3-2B) | **2221.2** | **78.6** | 60.7 | **72.9** | 64.3 | 53.8 | **67.5** | 61.9 |
96
+ | [`InternVL3.5-2B`](https://huggingface.co/OpenGVLab/InternVL3_5-2B) | 2123.3 | 76.6 | **62.7** | 71.7 | 62.0 | 49.7 | 62.4 | 58.0 |
97
 
98
+ Comparison of generic multimodal understanding and real-world understanding performance. Other model results are from their respective papers, except those marked with * which are computed using [VLMEvalKit](https://github.com/open-compass/VLMEvalKit). All scores represent accuracy (%) except MME which uses a 0-2800 scale, normalized to 0-100 for Overall calculation.
 
 
99
 
100
+ ### Multi-Image Reasoning and Hallucination
101
+
102
+ | Model | BLINK (val) | Muir Bench | MMT (val) | Overall (MI) | HallBench (avg) | POPE (avg) | Overall (Hall) |
103
+ |:--------------------------------------------------------------------|:-----------:|:----------:|:---------:|:------------:|:---------------:|:----------:|:--------------:|
104
+ | [`jina-vlm-v1`](https://huggingface.co/jinaai/jina-vlm-v1) | 50.1 | 34.7 | 57.2 | 47.3 | 39.1 | **90.3** | 64.7 |
105
+ | [`Qwen2-VL-2B`](https://huggingface.co/Qwen/Qwen2-VL-2B) | 44.4 | 25.5* | 55.1 | 41.7 | 41.7 | 87.9* | 64.8 |
106
+ | [`Qwen3-VL-2B`](https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct) | **53.8** | **47.4** | **60.0*** | **53.7** | 44.5 | 88.9* | 66.7 |
107
+ | [`InternVL3-2B`](https://huggingface.co/OpenGVLab/InternVL3-2B) | 50.3 | 38.8 | 59.5 | 49.5 | 42.5 | 89.6 | 66.1 |
108
+ | [`InternVL3.5-2B`](https://huggingface.co/OpenGVLab/InternVL3_5-2B) | 51.3 | 44.0 | 58.5 | 51.3 | **48.6** | 87.2 | **67.9** |
109
+
110
+ Comparison of multi-image and hallucination performance. Other model results are from their respective papers, except those marked with * which are computed using [VLMEvalKit](https://github.com/open-compass/VLMEvalKit). All scores represent accuracy (%).
111
+
112
+ ### Multimodal Reasoning and Mathematics
113
+
114
+ | Model | MMMU | MathVista | MathVision | MathVerse (Vision Only) | WeMath | LogicVista | Overall |
115
+ |:--------------------------------------------------------------------|:--------:|:-------------------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------:|:--------:|
116
+ | [`jina-vlm-v1`](https://huggingface.co/jinaai/jina-vlm-v1) | 45.6 | 59.5 | 19.2 | 23.9 | 17.1 | 33.3 | 33.1 |
117
+ | [`Qwen2-VL-2B`](https://huggingface.co/Qwen/Qwen2-VL-2B) | 41.1 | 43.0 | 12.4 | 17.3* | 10.9* | 27.3* | 25.3 |
118
+ | [`Qwen3-VL-2B`](https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct) | 53.4 | 61.3 | 31.6 | 22.7* | 28.0* | 35.4* | 38.7 |
119
+ | [`InternVL3-2B`](https://huggingface.co/OpenGVLab/InternVL3-2B) | 48.6 | 57.0 | 21.7 | 25.3 | 22.4 | 36.9 | 35.3 |
120
+ | [`InternVL3.5-2B`](https://huggingface.co/OpenGVLab/InternVL3_5-2B) | **59.0** | **71.8** / 61.5† | **42.8** / 26.5† | **53.4** / 35.3† | **48.5** / 19.1† | **47.7** / 41.4† | **50.7** |
121
+
122
+ Comparison of multimodal reasoning and mathematical problem-solving performance. Other model results are from their respective papers, except those marked with * which are computed using [VLMEvalKit](https://github.com/open-compass/VLMEvalKit). † indicates scores for [`InternVL3.5-2B`](https://huggingface.co/OpenGVLab/InternVL3_5-2B) without thinking mode, evaluated using [VLMEvalKit](https://github.com/open-compass/VLMEvalKit). All scores represent accuracy (%).
123
+
124
+ ### Text-Only Performance
125
+
126
+ | Model | MMLU | MMLU-Pro | GSM-8K | ARC-C | HellaSwag |
127
+ |:-----------------------------------------------------------|:--------:|:--------:|:--------:|:--------:|:---------:|
128
+ | [`jina-vlm-v1`](https://huggingface.co/jinaai/jina-vlm-v1) | 56.1 | **30.3** | 69.6 | **76.0** | **59.4** |
129
+ | [`Qwen3-1.7B`](https://huggingface.co/Qwen/Qwen3-1.7B) | **62.6** | | **75.3** | | 59.0 |
130
+
131
+ Comparison of text-only benchmarks. Results are collected using our evaluation code. All scores represent accuracy (%).
132
+
133
+ ### Multimodal Multilingual Understanding
134
+
135
+ | Model Name | MMMB ar | MMMB cn | MMMB en | MMMB pt | MMMB ru | MMMB tr | MMMB avg | MMBench ar | MMBench cn | MMBench en | MMBench pt | MMBench ru | MMBench tr | MMBench avg | MTVQA | Overall |
136
+ |:--------------------------------------------------------------------|:--------:|:--------:|:--------:|:--------:|:--------:|:--------:|:--------:|:----------:|:----------:|:----------:|:----------:|:----------:|:----------:|:-----------:|:--------:|:--------:|
137
+ | [`jina-vlm-v1`](https://huggingface.co/jinaai/jina-vlm-v1) | **76.9** | **80.0** | **82.0** | **79.2** | **79.2** | **75.5** | **78.8** | **70.0** | 75.9 | 78.8 | 74.7 | 75.3 | **71.1** | **74.3** | 25.6 | **59.6** |
138
+ | [`Qwen2-VL-2B`](https://huggingface.co/Qwen/Qwen2-VL-2B) | 68.3 | 74.2 | 78.3 | 72.6 | 72.8 | 61.8 | 71.3 | 66.7 | 67.0 | 71.1 | 72.1 | 69.9 | 69.3 | 69.4 | 20.6 | 53.8 |
139
+ | [`Qwen3-VL-2B`](https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct) | 72.7* | 75.7* | 80.7* | 75.0* | 75.9* | 68.5* | 75.0* | 66.2* | 75.7* | 77.8* | 71.4* | **75.9*** | 67.0* | 72.3* | 27.3* | 58.2 |
140
+ | [`InternVL3-2B`](https://huggingface.co/OpenGVLab/InternVL3-2B) | 68.6 | 78.3 | 81.9 | 75.4 | 74.6 | 62.9 | 73.6 | 66.4 | **77.8** | **81.3** | **75.9** | 70.7 | 59.5 | 71.9 | 26.7 | 57.4 |
141
+ | [`InternVL3.5-2B`](https://huggingface.co/OpenGVLab/InternVL3_5-2B) | 68.5 | 77.7 | 80.2 | 75.9 | 76.3 | 69.1 | 74.6 | 63.7 | 75.9 | 78.4 | 73.7 | 71.4 | 62.0 | 70.9 | **28.5** | 58.0 |
142
+
143
+ Comparison of multilingual multimodal understanding performance. Other model results are from their respective papers, except those marked with * which are computed using [VLMEvalKit](https://github.com/open-compass/VLMEvalKit). All scores represent accuracy (%).
144
+
145
+ ### Embedding Performance
146
+
147
+ | Task / Metric | [`Qwen3-VL-2B`](https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct) | [`Qwen2.5-VL-3B`](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct) | [`InternVL3.5-2B`](https://huggingface.co/OpenGVLab/InternVL3_5-2B) | [`Qwen2-VL-2B`](https://huggingface.co/Qwen/Qwen2-VL-2B) | [`jina-vlm-v1`](https://huggingface.co/jinaai/jina-vlm-v1) |
148
+ |:-------------------------------------------|:-----------------------------------------------------------------:|:---------------------------------------------------------------------:|:-------------------------------------------------------------------:|:--------------------------------------------------------:|:----------------------------------------------------------:|
149
+ | Flickr30kT2I Retrieval (NDCG@10) | **86.9** | 83.8 | 84.6 | 85.8 | 86.0 |
150
+ | JinaVDR DocVQA Retrieval (NDCG@5) | **83.1** | 81.1 | 78.2 | 73.6 | 76.9 |
151
+ | JinaVDR InfoVQA Retrieval (NDCG@5) | 88.1 | 87.6 | 87.3 | **88.3** | 84.9 |
152
+ | Nano DBPedia Retrieval (NDCG@10) | 52.4 | 53.3 | 51.1 | 51.7 | **54.0** |
153
+ | Nano FEVER Retrieval (NDCG@10) | 78.3 | **83.2** | 72.8 | 75.1 | 76.3 |
154
+ | Nano FiQA2018 Retrieval (NDCG@10) | 40.4 | 45.0 | 40.3 | **45.7** | 35.8 |
155
+ | Nano HotpotQA Retrieval (NDCG@10) | 69.5 | **72.1** | 65.5 | 69.9 | 70.1 |
156
+ | Nano MS MARCO Retrieval (NDCG@10) | 48.0 | 48.7 | **49.5** | 47.5 | 45.7 |
157
+ | Nano NFCorpus Retrieval (NDCG@10) | 31.7 | **34.4** | 34.0 | 30.7 | 32.9 |
158
+ | Nano NQ Retrieval (NDCG@10) | 49.3 | **51.4** | 48.2 | 48.8 | 48.2 |
159
+ | Nano SCIDOCS Retrieval (NDCG@10) | **41.6** | 40.7 | 39.1 | 39.0 | 38.7 |
160
+ | Nano SciFact Retrieval (NDCG@10) | 73.0 | **78.0** | 73.2 | 70.6 | 77.2 |
161
+ | STS12 (Spearman) | 67.3 | 65.1 | 67.4 | 68.3 | **69.3** |
162
+ | SciFact (NDCG@10) | 69.7 | **71.2** | 68.2 | 66.0 | 68.5 |
163
+ | Vidore ArXivQA Retrieval (NDCG@5) | 74.4 | **80.2** | 74.8 | 75.8 | 74.4 |
164
+ | **Average** | 63.6 | **65.1** | 62.3 | 62.5 | 62.6 |
165
+
166
+ Pair training for single-vector embeddings. Higher is better. Averages are macro-averages across all tasks.
167
+
168
+
169
+ ## Usage
170
+
171
+ ### Requirements
172
+
173
+ The following Python packages are required:
174
+
175
+ - `torch>=2.9.0`
176
+ - `torchvision>=0.24.0`
177
+ - `transformers>=4.57.0`
178
+ - `pillow>=12.0.0`
179
+ - `einops>=0.8.1`
180
+
181
+ Optional but recommended packages:
182
+
183
+ - **flash-attention**: Installing [flash-attention](https://github.com/Dao-AILab/flash-attention) is recommended for improved inference speed and efficiency, but not mandatory.
184
 
185
  ### Using the CLI
186
 
187
+ You can directly chat with `jina-vlm-v1` using the `test_jvlm.py` CLI.
188
+
189
+ **Options:**
190
+ - `-m, --model`: Model path (default: `'.'`). Set this to `'jinaai/jina-vlm-v1'` if you are running this script outside this repo.
191
+ - `-i, --image`: Image path, URL, or glob pattern (can specify multiple times, default: `[]`).
192
+ - `-p, --prompt`: Text prompt (can specify multiple times, default: `'Describe the image for me in 100 words'` or `'Describe the images for me in 100 words'` if multiple images are provided).
193
+ - `--max-crops`: Maximum crops (default: `12`).
194
+ - `--max-tokens`: Maximum output tokens (default: `1024`).
195
+ - `--max-pixels`: Max pixels per image, larger images are resized and the aspect ratio is preserved (default: `None`).
196
+ - `--stream`: Enable streaming (default: `False`).
197
+ - `--image-labels`: Enable ordinal text labels after each image (default: `False` -> no image labels for multi-image).
198
+ - `--prompt-first`: Place prompt before images instead of after (default: `False` -> prompt after images).
199
+ - `--map`: Map mode - apply single prompt to multiple images OR multiple prompts to single image (default: `False` -> no mapping).
200
 
201
  ```bash
202
  # Single image
203
+ python test_jvlm.py -i photo.jpg -p "What's in this image?"
204
 
205
+ # Single image with streaming
206
+ python test_jvlm.py -i photo.jpg -p "What's in this image?" --stream
207
 
208
+ # Remote image URL
209
+ python test_jvlm.py -i https://example.com/image.jpg -p "Describe this image"
210
 
211
+ # Multiple images (local and remote)
212
+ python test_jvlm.py -i img1.jpg -i https://example.com/img2.jpg -i img3.jpg -p "Compare these images"
 
213
 
214
+ # Text only input
215
+ python test_jvlm.py -p "How many planets are in our solar system?"
 
 
 
 
 
 
216
 
217
+ # Glob pattern support (quote patterns to prevent shell expansion)
218
+ python test_jvlm.py -i "*.jpg" -p "Describe these images"
219
+ python test_jvlm.py -i "photos/*.png" -i "images/*.jpg" -p "What do you see in these images?"
220
 
221
+ # Custom max crops, max pixels and max output tokens
222
+ # Reducing max crops and max pixels speeds up inference and lowers mem consumption on large images
223
+ python test_jvlm.py -i photo.jpg -p "Describe this picture in detail" --max-crops 8 --max-pixels 500000 --max-tokens 2048
224
 
225
+ # Prompt position control
226
+ python test_jvlm.py -i photo.jpg -p "What's in this image?" --prompt-first
 
 
 
 
 
 
227
 
228
+ # Map mode: apply one prompt to multiple images
229
+ python test_jvlm.py --map -i "*.jpg" -p "What is this?"
 
 
 
 
 
 
 
 
 
 
 
 
 
230
 
231
+ # Map mode: apply multiple prompts to one image
232
+ python test_jvlm.py --map -i photo_of_a_dog.jpg -p "What breed?" -p "What color?" -p "Happy or sad?"
 
233
 
234
+ # Batch inference
235
+ # When an equal number of images and prompts (>1) is provided, we assume it is batched inference
236
+ # Generation will run in a batch if streaming is disabled, otherwise sequentially
237
+ python test_jvlm.py -i photo1.jpg -p "What is shown in this image?" -i photo2.jpg -p "Describe this image"
238
 
239
+ # Similarly for no images and multiple prompts
240
+ python test_jvlm.py -p "What is a neural network?" -p "Describe the concept of polymorphism in Computer Science"
241
+ ```
242
 
243
+ Example input:
244
+ ```bash
245
+ python test_jvlm.py -m jinaai/jina-vlm-v1 -i assets/the_persistence_of_memory.jpg -p "Describe this picture"
246
+ ```
247
+ <p align="center">
248
+ <img src="./assets/the_persistence_of_memory.jpg">
249
+ </p>
 
250
 
251
+ Example output:
252
+ ```
253
+ * Conversation 1/1
254
+ ├── 🖼️Images: ['assets/the_persistence_of_memory.jpg']
255
+ ├── 📜Prompt: Describe this picture
256
+ ├── 💬Chat: User: <|image|>Describe this picture Assistant:
257
+ └── 🧠Response: This image is a surrealistic painting by Salvador Dalí, titled "The Persistence of Memory." The painting is characterized by its dreamlike and distorted elements, which are hallmarks of Dalí's style. The central focus of the painting is a melting clock, which is a key symbol in the artwork. The clock is depicted in a state of fluidity, with its hands and numbers melting and flowing as if it is made of wax.
 
 
 
258
 
259
+ In the foreground, there is a wooden table with a branch extending from it. The branch holds a second clock, which is also melting and dripping. To the left of the table, there is a small, round, orange object that appears to be a pocket watch or a small container.
 
 
260
 
261
+ The background of the painting features a landscape with a calm sea and a rocky cliff. The sky is painted in shades of blue and yellow, suggesting either a sunrise or sunset. The overall color palette of the painting is muted, with earthy tones dominating the scene.
 
 
 
 
 
262
 
263
+ The painting is a prime example of Dalí's use of surrealism, which involves the depiction of bizarre and dreamlike scenes. The melting clocks and distorted forms are typical of Dalí's work, which often explores themes of time, memory, and the subconscious mind. The painting is a testament to Dalí's innovative and imaginative approach to art.
264
+ Token usage report:
265
+ Input Context Window Layout (max: 40960 tokens):
266
+ ├── Total: 1753 tokens (4.3%)
267
+ ├── Image 1 → 1744 tokens (4.3%)
268
+ └── Text: 9 tokens (0.0%)
269
+
270
+ Generated 1 responses in 33.078s
271
+ 0.03 res/s 8.16 tok/s
272
+ Done ✅
273
  ```
274
 
275
+ ### Using Transformers 🤗
 
276
 
277
  ```python
278
+ from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
279
+ from qwen_vl_utils import process_vision_info
280
+
281
+ # default: Load the model on the available device(s)
282
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
283
+ "Qwen/Qwen2.5-VL-3B-Instruct", torch_dtype="auto", device_map="auto"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
284
  )
 
 
285
 
286
+ # We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
287
+ # model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
288
+ # "Qwen/Qwen2.5-VL-3B-Instruct",
289
+ # torch_dtype=torch.bfloat16,
290
+ # attn_implementation="flash_attention_2",
291
+ # device_map="auto",
292
+ # )
293
 
294
+ # default processer
295
+ processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
296
 
297
+ # The default range for the number of visual tokens per image in the model is 4-16384.
298
+ # You can set min_pixels and max_pixels according to your needs, such as a token range of 256-1280, to balance performance and cost.
299
+ # min_pixels = 256*28*28
300
+ # max_pixels = 1280*28*28
301
+ # processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)
302
+
303
+ messages = [
304
  {
305
+ "role": "user",
306
+ "content": [
307
+ {
308
+ "type": "image",
309
+ "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
310
+ },
311
+ {"type": "text", "text": "Describe this image."},
312
  ],
313
  }
314
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
315
 
316
+ # Preparation for inference
317
+ text = processor.apply_chat_template(
318
+ messages, tokenize=False, add_generation_prompt=True
 
 
 
 
 
 
319
  )
320
+ image_inputs, video_inputs = process_vision_info(messages)
321
+ inputs = processor(
322
+ text=[text],
323
+ images=image_inputs,
324
+ videos=video_inputs,
325
+ padding=True,
326
+ return_tensors="pt",
327
  )
328
+ inputs = inputs.to("cuda")
329
 
330
+ # Inference: Generation of the output
331
+ generated_ids = model.generate(**inputs, max_new_tokens=128)
332
+ generated_ids_trimmed = [
333
+ out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
334
  ]
335
+ output_text = processor.batch_decode(
336
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
 
 
 
 
 
 
 
 
337
  )
338
+ print(output_text)
 
 
 
 
339
  ```
340
 
341
+ <details>
342
+ <summary>Batch inference</summary>
343
  </details>
344
 
345
  <details>
346
+ <summary>Multi-image inference</summary>
347
+ </details>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
348
 
349
+ <details>
350
+ <summary>Text-only inference</summary>
351
+ </details>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
352
 
353
+ <details>
354
+ <summary>Mixed-batch inference</summary>
355
+ </details>
356
 
357
+ <details>
358
+ <summary>Feature extraction</summary>
359
+ </details>
 
 
 
360
 
361
+ ### Using vLLM
 
 
 
 
362
 
363
+ Coming soon!
364
 
 
365
 
366
+ ## License
367
 
368
+ The models is licensed under CC-BY-NC 4.0. For commercial usage inquiries, feel free to [contact us](https://jina.ai/contact-sales/).
 
 
 
 
 
 
369
 
 
370
 
371
+ ## Contact
 
 
 
 
 
 
372
 
373
+ Join our [Discord community](https://discord.jina.ai) and chat with other community members about ideas.
374
 
 
 
 
 
375
 
376
  ## Citation
377
 
378
+ TODO: Add citation when ready
 
 
 
 
 
 
 
 
 
 
 
 
379
 
380
+ If you find `jina-vlm-v1` useful in your research, please cite the following paper:
381
+ ```
382
+ TBD
383
+ ```
assets/jvlm_architecture.png CHANGED

Git LFS Details

  • SHA256: 1d33806662487fa930aae7ffd1335833156a73758436b38b4abc3aff62691e66
  • Pointer size: 131 Bytes
  • Size of remote file: 654 kB

Git LFS Details

  • SHA256: 8941f6788e95e12904ac301bff2f37089a1b2421e2c44c4cffa1743a62a3915e
  • Pointer size: 131 Bytes
  • Size of remote file: 248 kB
blocks_jvlm.py CHANGED
@@ -11,7 +11,6 @@ import torch
11
  import torch.backends.cuda
12
  import torch.nn as nn
13
  import torch.nn.functional as f
14
- from torch.nn.attention import SDPBackend, sdpa_kernel
15
  from transformers import PretrainedConfig
16
  from transformers.activations import ACT2FN
17
  from transformers.cache_utils import Cache
@@ -325,11 +324,10 @@ modeling_rope_utils.py
325
 
326
 
327
  def inv_freq_to_device(rope_forward):
328
- """Sometimes the inv_freq is calculated on the wrong device, or ends up in lower
329
- precision than float32.
330
-
331
- This wrapper ensures that inv_freq is always on the right device and in float32
332
- precision.
333
  """
334
 
335
  @wraps(rope_forward)
@@ -355,6 +353,7 @@ class RotaryEmbedding(nn.Module):
355
  theta: float,
356
  head_dim: int,
357
  hidden_size: int,
 
358
  partial_rotary_factor: float,
359
  device: Optional[torch.device] = None,
360
  scaling: Optional[Dict[str, Any]] = None,
@@ -367,6 +366,7 @@ class RotaryEmbedding(nn.Module):
367
  setattr(self.config, 'rope_theta', theta)
368
  setattr(self.config, 'partial_rotary_factor', partial_rotary_factor)
369
  setattr(self.config, 'head_dim', head_dim)
 
370
  setattr(self.config, 'hidden_size', hidden_size)
371
  setattr(self.config, 'rope_scaling', scaling or {})
372
 
@@ -377,7 +377,9 @@ class RotaryEmbedding(nn.Module):
377
  self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
378
  device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
379
  seqlen = config.max_position_embeddings or config.max_sequence_length
380
- invfreq, self.attention_scaling = self.rope_init_fn(self.config, device, seqlen)
 
 
381
  self.rope_init_device = device
382
  self.register_buffer('inv_freq', invfreq, persistent=False)
383
  self.original_inv_freq = self.inv_freq
@@ -615,9 +617,11 @@ def _create_causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
615
  def _ensure_finite(
616
  x: torch.Tensor, check_neg_inf: bool = True, check_pos_inf: bool = False
617
  ):
618
- """Modify ``x`` in place to replace ``float("-inf")`` with the minimum value of the
 
619
  dtype when ``check_neg_inf`` is ``True`` and replace ``float("inf")`` with the
620
- maximum value of the dtype when ``check_pos_inf`` is ``True``"""
 
621
  if check_neg_inf:
622
  x.masked_fill_(x == float('-inf'), torch.finfo(x.dtype).min)
623
  if check_pos_inf:
@@ -637,12 +641,14 @@ def resolve_causal_mask(
637
  # shape: (batch_size, 1, 1, seq_len)
638
  if len(attention_mask.shape) == 2:
639
  attention_mask = attention_mask[:, : past_length + seq_len]
640
- attention_mask = attention_mask.to(dtype=torch.float).view(batch_size, -1)[
641
- :, None, None, :
642
- ]
643
  else:
644
  attention_mask = attention_mask.unsqueeze(1).to(dtype=torch.float)
645
- attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min
 
 
646
 
647
  # Merge attention mask with causal mask (attention bias)
648
  # NOTE: We need to initialize the attn bias in order for attn to
@@ -654,7 +660,9 @@ def resolve_causal_mask(
654
  or past_key_values is not None
655
  ):
656
  if causal_mask is None:
657
- causal_mask = _create_causal_mask(past_length + seq_len, device)
 
 
658
  elif causal_mask.dtype in (torch.int8, torch.bool):
659
  causal_mask = causal_mask.to(dtype=torch.float)
660
  causal_mask.masked_fill_(
@@ -737,9 +745,7 @@ def rotate_half(x: torch.Tensor):
737
 
738
 
739
  def apply_rotary_positional_embeddings(
740
- x: torch.Tensor,
741
- cos: torch.Tensor,
742
- sin: torch.Tensor,
743
  ) -> torch.Tensor:
744
  return (x * cos + rotate_half(x) * sin).to(x.dtype)
745
 
@@ -884,6 +890,7 @@ class MHSDPA(nn.Module):
884
  attn_mask: Optional[torch.Tensor] = None,
885
  is_causal: Optional[bool] = None,
886
  ) -> Tuple[Callable, Optional[torch.Tensor], Optional[bool]]:
 
887
  if 'flash' in attn_implementation and self.fp32_attn:
888
  raise ValueError('Flash attention does not support fp32 attention')
889
  if self.sliding_window != -1 and 'flash' not in attn_implementation:
@@ -1064,7 +1071,9 @@ class FFN(nn.Module):
1064
  if self.gated_activation:
1065
  intermediate_size = 2 * self.intermediate_size
1066
 
1067
- self.up = nn.Linear(self.hidden_size, intermediate_size, bias=self.use_bias)
 
 
1068
  self.down = nn.Linear(
1069
  self.intermediate_size, self.output_size, bias=self.use_bias
1070
  )
@@ -1236,14 +1245,6 @@ class VisionLanguageConnector(GradientCheckpointingLayer):
1236
  assert config.attn_pooling_config is not None
1237
  if config.pooling_type == ImagePooling2DType.attention_2wide:
1238
  pooling_input_size *= 2
1239
-
1240
- # Flash Attention can cause Inf grads in the attention pooling layer
1241
- # because of very large batch sizes. Setting this to sdpa does not cost us
1242
- # much since sequence lengths in the case of attention pooling are very
1243
- # small
1244
- attn_implementation = attn_implementation or 'eager'
1245
- if attn_implementation.startswith('flash'):
1246
- attn_implementation = 'sdpa'
1247
  self.pooling = MHSDPA(
1248
  config.attn_pooling_config,
1249
  hidden_size=pooling_input_size,
@@ -1289,12 +1290,10 @@ class VisionLanguageConnector(GradientCheckpointingLayer):
1289
  image_features: torch.Tensor,
1290
  image_masks: Optional[torch.Tensor] = None,
1291
  attn_implementation: Optional[str] = None,
1292
- **kwargs: Unpack[FlashAttentionKwargs],
1293
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
1294
  # image_features:
1295
  # (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim)
1296
  bs, ncrops = image_features.shape[:2]
1297
- ogtype = image_features.dtype
1298
 
1299
  if self.padding_embed_type is not None:
1300
  assert image_masks is not None
@@ -1323,7 +1322,6 @@ class VisionLanguageConnector(GradientCheckpointingLayer):
1323
  partial_pad, -1
1324
  )
1325
 
1326
- image_features = image_features.to(dtype=ogtype)
1327
  image_features = self.feature_dropout(image_features)
1328
  image_features = image_features.reshape((bs, ncrops) + self.n_patches + (-1,))
1329
  pad_h = self.n_patches[0] % self.pooling_h
@@ -1345,31 +1343,11 @@ class VisionLanguageConnector(GradientCheckpointingLayer):
1345
  dh=self.pooling_h,
1346
  dw=self.pooling_w,
1347
  )
1348
- image_features = image_features.contiguous()
1349
  if self.pooling_type == ImagePooling2DType.attention_meanq:
1350
  query = image_features.mean(-2, keepdim=True)
1351
- # Flash Attention can cause Inf grads in the attention pooling layer
1352
- # because of very large batch sizes. Setting this to sdpa does not cost
1353
- # us much since sequence lengths in the case of attention pooling are
1354
- # very small
1355
- attn_implementation = attn_implementation or 'eager'
1356
- if attn_implementation.startswith('flash'):
1357
- attn_implementation = 'sdpa'
1358
- if attn_implementation == 'sdpa':
1359
- with sdpa_kernel(backends=[SDPBackend.MATH]):
1360
- image_features, _ = self.pooling(
1361
- xq=query,
1362
- xk=image_features,
1363
- attn_implementation='sdpa',
1364
- **kwargs,
1365
- )
1366
- else:
1367
- image_features, _ = self.pooling(
1368
- xq=query,
1369
- xk=image_features,
1370
- attn_implementation=attn_implementation,
1371
- **kwargs,
1372
- )
1373
  elif self.pooling_type not in {
1374
  ImagePooling2DType.none,
1375
  ImagePooling2DType.stack,
 
11
  import torch.backends.cuda
12
  import torch.nn as nn
13
  import torch.nn.functional as f
 
14
  from transformers import PretrainedConfig
15
  from transformers.activations import ACT2FN
16
  from transformers.cache_utils import Cache
 
324
 
325
 
326
  def inv_freq_to_device(rope_forward):
327
+ """
328
+ Sometimes the inv_freq is calculated on the wrong device, or ends up in lower
329
+ precision than float32. This wrapper ensures that inv_freq is always on the right
330
+ device and in float32 precision.
 
331
  """
332
 
333
  @wraps(rope_forward)
 
353
  theta: float,
354
  head_dim: int,
355
  hidden_size: int,
356
+ n_heads: int,
357
  partial_rotary_factor: float,
358
  device: Optional[torch.device] = None,
359
  scaling: Optional[Dict[str, Any]] = None,
 
366
  setattr(self.config, 'rope_theta', theta)
367
  setattr(self.config, 'partial_rotary_factor', partial_rotary_factor)
368
  setattr(self.config, 'head_dim', head_dim)
369
+ setattr(self.config, 'num_attention_heads', n_heads)
370
  setattr(self.config, 'hidden_size', hidden_size)
371
  setattr(self.config, 'rope_scaling', scaling or {})
372
 
 
377
  self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
378
  device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
379
  seqlen = config.max_position_embeddings or config.max_sequence_length
380
+ invfreq, self.attention_scaling = self.rope_init_fn(
381
+ self.config, device, seqlen
382
+ )
383
  self.rope_init_device = device
384
  self.register_buffer('inv_freq', invfreq, persistent=False)
385
  self.original_inv_freq = self.inv_freq
 
617
  def _ensure_finite(
618
  x: torch.Tensor, check_neg_inf: bool = True, check_pos_inf: bool = False
619
  ):
620
+ """
621
+ Modify ``x`` in place to replace ``float("-inf")`` with the minimum value of the
622
  dtype when ``check_neg_inf`` is ``True`` and replace ``float("inf")`` with the
623
+ maximum value of the dtype when ``check_pos_inf`` is ``True``
624
+ """
625
  if check_neg_inf:
626
  x.masked_fill_(x == float('-inf'), torch.finfo(x.dtype).min)
627
  if check_pos_inf:
 
641
  # shape: (batch_size, 1, 1, seq_len)
642
  if len(attention_mask.shape) == 2:
643
  attention_mask = attention_mask[:, : past_length + seq_len]
644
+ attention_mask = attention_mask.to(dtype=torch.float).view(
645
+ batch_size, -1
646
+ )[:, None, None, :]
647
  else:
648
  attention_mask = attention_mask.unsqueeze(1).to(dtype=torch.float)
649
+ attention_mask = (1.0 - attention_mask) * torch.finfo(
650
+ attention_mask.dtype
651
+ ).min
652
 
653
  # Merge attention mask with causal mask (attention bias)
654
  # NOTE: We need to initialize the attn bias in order for attn to
 
660
  or past_key_values is not None
661
  ):
662
  if causal_mask is None:
663
+ causal_mask = _create_causal_mask(
664
+ past_length + seq_len, device
665
+ )
666
  elif causal_mask.dtype in (torch.int8, torch.bool):
667
  causal_mask = causal_mask.to(dtype=torch.float)
668
  causal_mask.masked_fill_(
 
745
 
746
 
747
  def apply_rotary_positional_embeddings(
748
+ x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
 
 
749
  ) -> torch.Tensor:
750
  return (x * cos + rotate_half(x) * sin).to(x.dtype)
751
 
 
890
  attn_mask: Optional[torch.Tensor] = None,
891
  is_causal: Optional[bool] = None,
892
  ) -> Tuple[Callable, Optional[torch.Tensor], Optional[bool]]:
893
+
894
  if 'flash' in attn_implementation and self.fp32_attn:
895
  raise ValueError('Flash attention does not support fp32 attention')
896
  if self.sliding_window != -1 and 'flash' not in attn_implementation:
 
1071
  if self.gated_activation:
1072
  intermediate_size = 2 * self.intermediate_size
1073
 
1074
+ self.up = nn.Linear(
1075
+ self.hidden_size, intermediate_size, bias=self.use_bias
1076
+ )
1077
  self.down = nn.Linear(
1078
  self.intermediate_size, self.output_size, bias=self.use_bias
1079
  )
 
1245
  assert config.attn_pooling_config is not None
1246
  if config.pooling_type == ImagePooling2DType.attention_2wide:
1247
  pooling_input_size *= 2
 
 
 
 
 
 
 
 
1248
  self.pooling = MHSDPA(
1249
  config.attn_pooling_config,
1250
  hidden_size=pooling_input_size,
 
1290
  image_features: torch.Tensor,
1291
  image_masks: Optional[torch.Tensor] = None,
1292
  attn_implementation: Optional[str] = None,
 
1293
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
1294
  # image_features:
1295
  # (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim)
1296
  bs, ncrops = image_features.shape[:2]
 
1297
 
1298
  if self.padding_embed_type is not None:
1299
  assert image_masks is not None
 
1322
  partial_pad, -1
1323
  )
1324
 
 
1325
  image_features = self.feature_dropout(image_features)
1326
  image_features = image_features.reshape((bs, ncrops) + self.n_patches + (-1,))
1327
  pad_h = self.n_patches[0] % self.pooling_h
 
1343
  dh=self.pooling_h,
1344
  dw=self.pooling_w,
1345
  )
 
1346
  if self.pooling_type == ImagePooling2DType.attention_meanq:
1347
  query = image_features.mean(-2, keepdim=True)
1348
+ image_features, _ = self.pooling(
1349
+ xq=query, xk=image_features, attn_implementation=attn_implementation
1350
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1351
  elif self.pooling_type not in {
1352
  ImagePooling2DType.none,
1353
  ImagePooling2DType.stack,
config.json CHANGED
@@ -4,7 +4,6 @@
4
  ],
5
  "auto_map": {
6
  "AutoConfig": "configuration_jvlm.JinaVLMConfig",
7
- "AutoModel": "modeling_jvlm.JinaVLM",
8
  "AutoModelForCausalLM": "modeling_jvlm.JinaVLMForConditionalGeneration"
9
  },
10
  "bos_token_id": 151643,
@@ -215,4 +214,4 @@
215
  "spatial_merge_size": 2
216
  }
217
  }
218
- }
 
4
  ],
5
  "auto_map": {
6
  "AutoConfig": "configuration_jvlm.JinaVLMConfig",
 
7
  "AutoModelForCausalLM": "modeling_jvlm.JinaVLMForConditionalGeneration"
8
  },
9
  "bos_token_id": 151643,
 
214
  "spatial_merge_size": 2
215
  }
216
  }
217
+ }
configuration_jvlm.py CHANGED
@@ -530,11 +530,6 @@ class JinaVLMTextConfig(PretrainedConfigWithDataclasses):
530
  self.rope_theta = rope_theta
531
  self.rope_scaling = rope_scaling
532
 
533
- # Needed for vLLM
534
- @property
535
- def num_attention_heads(self) -> int:
536
- return self.block_config.attn_config.n_heads
537
-
538
 
539
  class JinaVLMConfig(PretrainedConfig):
540
  """JinaVLM configuration.
@@ -550,8 +545,7 @@ class JinaVLMConfig(PretrainedConfig):
550
 
551
  model_type = 'jvlm'
552
  sub_configs = {
553
- 'vision_config': JinaVLMVisionConfig,
554
- 'text_config': JinaVLMTextConfig,
555
  }
556
 
557
  def __init__(
 
530
  self.rope_theta = rope_theta
531
  self.rope_scaling = rope_scaling
532
 
 
 
 
 
 
533
 
534
  class JinaVLMConfig(PretrainedConfig):
535
  """JinaVLM configuration.
 
545
 
546
  model_type = 'jvlm'
547
  sub_configs = {
548
+ 'vision_config': JinaVLMVisionConfig, 'text_config': JinaVLMTextConfig
 
549
  }
550
 
551
  def __init__(
image_processing_jvlm.py CHANGED
@@ -437,17 +437,6 @@ class JinaVLMImageProcessor(BaseImageProcessor):
437
 
438
  """ Base cropping via resizing """
439
 
440
- def base_get_n_image_patches(
441
- self,
442
- height: int,
443
- width: int,
444
- max_crops: int,
445
- ) -> int:
446
- raise NotImplementedError(
447
- 'Function `get_n_image_patches` is not implemented for cropping method '
448
- f'{CroppingMethod.RESIZE}'
449
- )
450
-
451
  def base_resize_cropping(self, image: np.ndarray):
452
  resized, mask = self.resize_image(image, list(self.base_input_size))
453
  resized = self.normalize_image(resized)
@@ -508,117 +497,6 @@ class JinaVLMImageProcessor(BaseImageProcessor):
508
 
509
  return candidate_tilings[ix]
510
 
511
- @staticmethod
512
- def _molmo_get_patches_from_tiling(
513
- num_tiles,
514
- pooling_size,
515
- crop_patches,
516
- crop_window_patches,
517
- left_margin,
518
- right_margin,
519
- ) -> np.int32:
520
- if num_tiles > 1:
521
- left_crop_window_patches = (
522
- (crop_window_patches + left_margin + pooling_size - 1)
523
- // pooling_size
524
- * pooling_size
525
- )
526
- middle_crop_window_patches = (
527
- (crop_window_patches + pooling_size - 1) // pooling_size * pooling_size
528
- )
529
- right_crop_window_patches = (
530
- (crop_window_patches + right_margin + pooling_size - 1)
531
- // pooling_size
532
- * pooling_size
533
- )
534
- return (
535
- left_crop_window_patches
536
- + (num_tiles - 2) * middle_crop_window_patches
537
- + right_crop_window_patches
538
- )
539
- else:
540
- single_crop_window_patches = (
541
- (crop_patches + pooling_size - 1) // pooling_size * pooling_size
542
- )
543
- return single_crop_window_patches
544
-
545
- def molmo_get_n_image_patches(
546
- self,
547
- height: int,
548
- width: int,
549
- max_crops: int,
550
- ) -> int:
551
- # Discard this many patches from the (left/top, right/bottom) of crops
552
- left_margin, right_margin = self.overlap_margins
553
- # Required for compatibility with image pooling
554
- assert left_margin % self.pooling_w == 0 and right_margin % self.pooling_w == 0
555
- assert left_margin % self.pooling_h == 0 and right_margin % self.pooling_h == 0
556
- # pixels removed per dim
557
- total_margin_pixels = self.patch_size * (right_margin + left_margin)
558
- # patches per crop dim
559
- crop_patches = self.base_input_size[0] // self.patch_size
560
-
561
- # usable patches
562
- crop_window_patches = crop_patches - (right_margin + left_margin)
563
- crop_window_size = crop_window_patches * self.patch_size
564
-
565
- # We assume hxw pooling, but can allow padding the right/bottom with extra
566
- # patches if the number of patches per side is not divisible by h/w
567
- assert (
568
- crop_patches + self.pooling_h - 1
569
- ) // self.pooling_h == self.token_length_h
570
- assert (
571
- crop_patches + self.pooling_w - 1
572
- ) // self.pooling_w == self.token_length_w
573
-
574
- # Decide how to tile the image, to account for the overlap margins we
575
- # compute the tiling as if we had an image without the margins and were
576
- # using a crop size without the margins
577
- tiling = self._molmo_select_tiling(
578
- height - total_margin_pixels,
579
- width - total_margin_pixels,
580
- crop_window_size,
581
- max_crops,
582
- )
583
-
584
- # Now build the output tokens
585
- h = self._molmo_get_patches_from_tiling(
586
- tiling[0],
587
- self.pooling_h,
588
- crop_patches,
589
- crop_window_patches,
590
- left_margin,
591
- right_margin,
592
- )
593
- w = self._molmo_get_patches_from_tiling(
594
- tiling[1],
595
- self.pooling_w,
596
- crop_patches,
597
- crop_window_patches,
598
- left_margin,
599
- right_margin,
600
- )
601
- # for each row of patches, add a patch token per patch
602
- n_tokens = w.item() // self.pooling_w
603
- if self.use_column_tokens:
604
- # after each row, one column token is added
605
- n_tokens += 1
606
- # replicate each row of patch tokens by number of rows, i.e.
607
- # proportional to image height
608
- n_tokens *= h.item() // self.pooling_h
609
- # add start and end image tokens
610
- n_tokens += 2
611
-
612
- # Global image goes first, so the order of patches in previous crops gets
613
- # increased
614
- n_thumbnail_tokens = self.token_length_w
615
- if self.use_column_tokens:
616
- n_thumbnail_tokens += 1
617
- n_thumbnail_tokens *= self.token_length_h
618
- n_thumbnail_tokens += 2
619
-
620
- return n_tokens + n_thumbnail_tokens
621
-
622
  def molmo_overlap_and_resize_cropping(self, image: np.ndarray):
623
  # Discard this many patches from the (left/top, right/bottom) of crops
624
  left_margin, right_margin = self.overlap_margins
@@ -747,23 +625,37 @@ class JinaVLMImageProcessor(BaseImageProcessor):
747
  # new order into sparse structure of `patch_ordering` to fix it
748
  patch_ordering[valid] = patch_ordering_rh[patch_ordering_rh >= 0]
749
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
750
  # Now build the output tokens
751
- h = self._molmo_get_patches_from_tiling(
752
- tiling[0],
753
- self.pooling_h,
754
- crop_patches,
755
- crop_window_patches,
756
- left_margin,
757
- right_margin,
758
- )
759
- w = self._molmo_get_patches_from_tiling(
760
- tiling[1],
761
- self.pooling_w,
762
- crop_patches,
763
- crop_window_patches,
764
- left_margin,
765
- right_margin,
766
- )
767
  # for each row of patches, add a patch token per patch
768
  per_row = np.full((w // self.pooling_w,), self.patch_token_id, dtype=np.int32)
769
  if self.use_column_tokens:
@@ -918,14 +810,6 @@ class JinaVLMImageProcessor(BaseImageProcessor):
918
 
919
  return slices, image_masks, patch_ordering_arr, best_grid
920
 
921
- def minicpm_get_n_image_patches(
922
- self, height: int, width: int, max_crops: int, with_thumbnail: bool = False
923
- ) -> int:
924
- raise NotImplementedError(
925
- 'Function `get_n_image_patches` is not implemented for cropping method '
926
- f'{CroppingMethod.ADAPTIVE_SLICING}'
927
- )
928
-
929
  def minicpm_adaptive_slicing(self, image: np.ndarray, with_thumbnail: bool = True):
930
  scale_resolution = self.base_input_size[0]
931
  refine_image, image_mask, best_grid = self._minicpm_refine_image_for_slicing(
@@ -1062,12 +946,23 @@ class JinaVLMImageProcessor(BaseImageProcessor):
1062
  self.start_token_id = start_token_id
1063
  self.end_token_id = end_token_id
1064
 
1065
- def _resolve_images_kwargs(
1066
- self, **kwargs: Unpack[JinaVLMImagesKwargs]
1067
- ) -> JinaVLMImagesKwargs:
1068
- max_crops = self.max_crops
 
 
 
 
 
 
 
 
 
 
1069
  if 'max_crops' in kwargs and kwargs['max_crops'] is not None:
1070
  max_crops = kwargs['max_crops']
 
1071
 
1072
  min_pixels = self.min_pixels
1073
  if 'min_pixels' in kwargs and kwargs['min_pixels'] is not None:
@@ -1089,93 +984,14 @@ class JinaVLMImageProcessor(BaseImageProcessor):
1089
  size = {'shortest_edge': min_pixels, 'longest_edge': max_pixels}
1090
  else:
1091
  size = {**self.size}
1092
- min_pixels = size['shortest_edge']
1093
- max_pixels = size['longest_edge']
1094
  do_resize = self.do_resize
1095
  if 'do_resize' in kwargs and kwargs['do_resize'] is not None:
1096
  do_resize = kwargs['do_resize']
 
1097
  do_convert_rgb = self.do_convert_rgb
1098
  if 'do_convert_rgb' in kwargs and kwargs['do_convert_rgb'] is not None:
1099
  do_convert_rgb = kwargs['do_convert_rgb']
1100
- input_data_format = None
1101
- if 'input_data_format' in kwargs:
1102
- input_data_format = kwargs['input_data_format']
1103
-
1104
- return JinaVLMImagesKwargs(
1105
- do_convert_rgb=do_convert_rgb,
1106
- do_resize=do_resize,
1107
- min_pixels=min_pixels,
1108
- max_pixels=max_pixels,
1109
- size=size,
1110
- max_crops=max_crops,
1111
- input_data_format=input_data_format,
1112
- )
1113
-
1114
- def get_n_image_patches(
1115
- self,
1116
- height: int,
1117
- width: int,
1118
- **kwargs: Unpack[JinaVLMImagesKwargs],
1119
- ) -> int:
1120
- """A utility that returns number of image patches for a given image size.
1121
-
1122
- Args:
1123
- height (`int`):
1124
- Height of the input image.
1125
- width (`int`):
1126
- Width of the input image.
1127
- **kwargs (`dict`, *optional*)
1128
- Any kwargs to override defaults of the image processor.
1129
- Returns:
1130
- `int`: Number of image patches
1131
- """
1132
- if self.cropping_method != CroppingMethod.OVERLAP_AND_RESIZE:
1133
- raise NotImplementedError(
1134
- 'Function is only implemented for cropping method '
1135
- f'{CroppingMethod.OVERLAP_AND_RESIZE}'
1136
- )
1137
- kwargs = self._resolve_images_kwargs(**kwargs)
1138
- do_resize = kwargs['do_resize']
1139
- size = kwargs['size']
1140
- max_crops = kwargs['max_crops']
1141
- if do_resize:
1142
- height, width = smart_resize(
1143
- height,
1144
- width,
1145
- factor=self.patch_size,
1146
- min_pixels=size['shortest_edge'],
1147
- max_pixels=size['longest_edge'],
1148
- )
1149
-
1150
- if self.cropping_method == CroppingMethod.RESIZE:
1151
- return self.base_get_n_image_patches(height, width, max_crops)
1152
- elif self.cropping_method == CroppingMethod.OVERLAP_AND_RESIZE:
1153
- return self.molmo_get_n_image_patches(height, width, max_crops)
1154
- elif self.cropping_method == CroppingMethod.ADAPTIVE_SLICING:
1155
- return self.minicpm_get_n_image_patches(height, width, max_crops)
1156
- return self.minicpm_get_n_image_patches(
1157
- height, width, max_crops, with_thumbnail=True
1158
- )
1159
-
1160
- def preprocess(
1161
- self,
1162
- images: ImageInput,
1163
- **kwargs: Unpack[JinaVLMImagesKwargs],
1164
- ) -> Dict[str, List[np.ndarray]]:
1165
- """Preprocess an image or batch of images."""
1166
- if images is None or len(images) == 0:
1167
- return {
1168
- 'image_crops': [],
1169
- 'image_tokens': [],
1170
- 'image_input_idx': [],
1171
- 'image_padding_mask': [],
1172
- }
1173
- kwargs = self._resolve_images_kwargs(**kwargs)
1174
- do_convert_rgb = kwargs['do_convert_rgb']
1175
- do_resize = kwargs['do_resize']
1176
- input_data_format = kwargs['input_data_format']
1177
- size = kwargs['size']
1178
- self.max_crops = kwargs['max_crops']
1179
 
1180
  # noinspection PyTypeChecker
1181
  images = self.fetch_images(images)
@@ -1185,11 +1001,16 @@ class JinaVLMImageProcessor(BaseImageProcessor):
1185
  'Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray '
1186
  'or torch.Tensor'
1187
  )
 
1188
  if do_convert_rgb:
1189
  images = [convert_to_rgb(image) for image in images]
1190
 
1191
  # All transformations expect numpy arrays
1192
  images = [to_numpy_array(image) for image in images]
 
 
 
 
1193
  if input_data_format is None:
1194
  # We assume that all images have the same channel dimension format.
1195
  input_data_format = infer_channel_dimension_format(images[0])
 
437
 
438
  """ Base cropping via resizing """
439
 
 
 
 
 
 
 
 
 
 
 
 
440
  def base_resize_cropping(self, image: np.ndarray):
441
  resized, mask = self.resize_image(image, list(self.base_input_size))
442
  resized = self.normalize_image(resized)
 
497
 
498
  return candidate_tilings[ix]
499
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
500
  def molmo_overlap_and_resize_cropping(self, image: np.ndarray):
501
  # Discard this many patches from the (left/top, right/bottom) of crops
502
  left_margin, right_margin = self.overlap_margins
 
625
  # new order into sparse structure of `patch_ordering` to fix it
626
  patch_ordering[valid] = patch_ordering_rh[patch_ordering_rh >= 0]
627
 
628
+ def get_num_patches(num_tiles, pooling_size) -> int:
629
+ if num_tiles > 1:
630
+ left_crop_window_patches = (
631
+ (crop_window_patches + left_margin + pooling_size - 1)
632
+ // pooling_size
633
+ * pooling_size
634
+ )
635
+ middle_crop_window_patches = (
636
+ (crop_window_patches + pooling_size - 1)
637
+ // pooling_size
638
+ * pooling_size
639
+ )
640
+ right_crop_window_patches = (
641
+ (crop_window_patches + right_margin + pooling_size - 1)
642
+ // pooling_size
643
+ * pooling_size
644
+ )
645
+ return (
646
+ left_crop_window_patches
647
+ + (num_tiles - 2) * middle_crop_window_patches
648
+ + right_crop_window_patches
649
+ )
650
+ else:
651
+ single_crop_window_patches = (
652
+ (crop_patches + pooling_size - 1) // pooling_size * pooling_size
653
+ )
654
+ return single_crop_window_patches
655
+
656
  # Now build the output tokens
657
+ h = get_num_patches(tiling[0], self.pooling_h)
658
+ w = get_num_patches(tiling[1], self.pooling_w)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
659
  # for each row of patches, add a patch token per patch
660
  per_row = np.full((w // self.pooling_w,), self.patch_token_id, dtype=np.int32)
661
  if self.use_column_tokens:
 
810
 
811
  return slices, image_masks, patch_ordering_arr, best_grid
812
 
 
 
 
 
 
 
 
 
813
  def minicpm_adaptive_slicing(self, image: np.ndarray, with_thumbnail: bool = True):
814
  scale_resolution = self.base_input_size[0]
815
  refine_image, image_mask, best_grid = self._minicpm_refine_image_for_slicing(
 
946
  self.start_token_id = start_token_id
947
  self.end_token_id = end_token_id
948
 
949
+ def preprocess(
950
+ self,
951
+ images: ImageInput,
952
+ **kwargs: Unpack[JinaVLMImagesKwargs],
953
+ ) -> Dict[str, List[np.ndarray]]:
954
+ """Preprocess an image or batch of images."""
955
+ if images is None or len(images) == 0:
956
+ return {
957
+ 'image_crops': [],
958
+ 'image_tokens': [],
959
+ 'image_input_idx': [],
960
+ 'image_padding_mask': [],
961
+ }
962
+
963
  if 'max_crops' in kwargs and kwargs['max_crops'] is not None:
964
  max_crops = kwargs['max_crops']
965
+ self.max_crops = max_crops
966
 
967
  min_pixels = self.min_pixels
968
  if 'min_pixels' in kwargs and kwargs['min_pixels'] is not None:
 
984
  size = {'shortest_edge': min_pixels, 'longest_edge': max_pixels}
985
  else:
986
  size = {**self.size}
987
+
 
988
  do_resize = self.do_resize
989
  if 'do_resize' in kwargs and kwargs['do_resize'] is not None:
990
  do_resize = kwargs['do_resize']
991
+
992
  do_convert_rgb = self.do_convert_rgb
993
  if 'do_convert_rgb' in kwargs and kwargs['do_convert_rgb'] is not None:
994
  do_convert_rgb = kwargs['do_convert_rgb']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
995
 
996
  # noinspection PyTypeChecker
997
  images = self.fetch_images(images)
 
1001
  'Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray '
1002
  'or torch.Tensor'
1003
  )
1004
+
1005
  if do_convert_rgb:
1006
  images = [convert_to_rgb(image) for image in images]
1007
 
1008
  # All transformations expect numpy arrays
1009
  images = [to_numpy_array(image) for image in images]
1010
+
1011
+ input_data_format = None
1012
+ if 'input_data_format' in kwargs:
1013
+ input_data_format = kwargs['input_data_format']
1014
  if input_data_format is None:
1015
  # We assume that all images have the same channel dimension format.
1016
  input_data_format = infer_channel_dimension_format(images[0])
modeling_jvlm.py CHANGED
@@ -27,13 +27,14 @@ from .blocks_jvlm import (
27
  TransformerBlock,
28
  VisionLanguageConnector,
29
  build_layer_norm,
30
- resolve_causal_mask,
31
  )
32
  from .configuration_jvlm import JinaVLMConfig, JinaVLMTextConfig, JinaVLMVisionConfig
33
 
34
 
35
  class JinaPreTrainedModel(PreTrainedModel):
36
  config: JinaVLMConfig
 
37
  base_model_prefix = 'model'
38
  supports_gradient_checkpointing = True
39
  _supports_flash_attn = True
@@ -50,6 +51,8 @@ class JinaPreTrainedModel(PreTrainedModel):
50
 
51
  class JinaVLMVisionModel(JinaPreTrainedModel):
52
  config: JinaVLMVisionConfig
 
 
53
 
54
  def __init__(self, config: JinaVLMVisionConfig, *args, **kwargs):
55
  super().__init__(config, *args, **kwargs)
@@ -183,11 +186,7 @@ class JinaVLMVisionModel(JinaPreTrainedModel):
183
  pos = pos_emb[None, :, :].to(x.dtype)
184
  return x + pos
185
 
186
- def get_visual_features(
187
- self,
188
- images: torch.Tensor,
189
- **kwargs: Unpack[FlashAttentionKwargs],
190
- ) -> BaseModelOutput:
191
  x, shape = self.patch_embed(images)
192
  if self.cls_embed is not None:
193
  cls = self.cls_embed.view(1, 1, -1).expand(x.shape[0], -1, -1).to(x.dtype)
@@ -202,11 +201,7 @@ class JinaVLMVisionModel(JinaPreTrainedModel):
202
  hidden_states = []
203
  attentions = []
204
  for layer in self.layers:
205
- x, attn = layer(
206
- x,
207
- attn_implementation=self.config._attn_implementation,
208
- **kwargs,
209
- )
210
  hidden_states.append(x)
211
  attentions.append(attn)
212
  x = self.post_lnorm(x)
@@ -219,15 +214,12 @@ class JinaVLMVisionModel(JinaPreTrainedModel):
219
  )
220
 
221
  def forward(
222
- self,
223
- images: torch.Tensor,
224
- image_masks: torch.Tensor,
225
- **kwargs: Unpack[FlashAttentionKwargs],
226
  ) -> BaseModelOutput:
227
  b, t, n, d = images.shape
228
  mask = ~torch.all(images.view(b * t, n, d) == -1, dim=(1, 2), keepdim=True)
229
  images = images.view(b * t, n, d)
230
- out = self.get_visual_features(images, **kwargs)
231
  image_features = out.hidden_states
232
 
233
  features = []
@@ -238,13 +230,14 @@ class JinaVLMVisionModel(JinaPreTrainedModel):
238
  features.append(feats)
239
  image_features = torch.cat(features, dim=-1)
240
  image_features = image_features * mask
241
- image_features = image_features.view(b, t, n, -1).contiguous()
 
242
  image_features = self.vl_connector(
243
  image_features,
244
  image_masks,
245
  attn_implementation=self.config._attn_implementation,
246
- **kwargs,
247
  )
 
248
  return BaseModelOutput(
249
  last_hidden_state=image_features,
250
  hidden_states=out.hidden_states,
@@ -253,7 +246,11 @@ class JinaVLMVisionModel(JinaPreTrainedModel):
253
 
254
 
255
  class JinaVLMTextModel(JinaPreTrainedModel):
 
 
256
  config: JinaVLMTextConfig
 
 
257
 
258
  def __init__(self, config: JinaVLMTextConfig, *args, **kwargs):
259
  super().__init__(config, *args, **kwargs)
@@ -300,6 +297,7 @@ class JinaVLMTextModel(JinaPreTrainedModel):
300
  theta=self.config.rope_theta,
301
  head_dim=self.config.block_config.attn_config.head_dim,
302
  hidden_size=self.config.hidden_size,
 
303
  partial_rotary_factor=self.config.partial_rotary_factor,
304
  scaling=self.config.rope_scaling,
305
  )
@@ -390,7 +388,6 @@ class JinaVLMTextModel(JinaPreTrainedModel):
390
  batch_idx = torch.arange(bs, device=x.device)
391
  batch_idx = torch.tile(batch_idx[:, None], [1, image_features.shape[1]])
392
  image_features = image_features.to(x.device)
393
- x = x.clone() # Clone x to avoid in-place operation on leaf tensor
394
  x[batch_idx[valid], image_input_idx[valid]] += image_features[valid]
395
 
396
  if not self.rope:
@@ -446,7 +443,7 @@ class JinaVLMTextModel(JinaPreTrainedModel):
446
 
447
 
448
  class JinaVLM(JinaPreTrainedModel):
449
- config: JinaVLMConfig
450
 
451
  def __init__(self, config: JinaVLMConfig):
452
  super().__init__(config)
@@ -495,7 +492,7 @@ class JinaVLM(JinaPreTrainedModel):
495
  ) -> BaseModelOutputWithPast:
496
  image_features = None
497
  if images is not None and images.shape[1] > 0:
498
- image_out = self.vision_model(images, image_masks, **kwargs)
499
  image_features = image_out.last_hidden_state
500
  return self.language_model(
501
  input_ids=input_ids,
@@ -514,10 +511,10 @@ class JinaVLM(JinaPreTrainedModel):
514
 
515
 
516
  class JinaVLMForConditionalGeneration(JinaPreTrainedModel, GenerationMixin):
517
- _tied_weights_keys = {
518
- 'lm_head.weight': 'model.language_model.embedding.embedding.weight'
519
- }
520
  accepts_loss_kwargs = False
 
521
  config: JinaVLMConfig
522
 
523
  def __init__(self, config: JinaVLMConfig):
 
27
  TransformerBlock,
28
  VisionLanguageConnector,
29
  build_layer_norm,
30
+ resolve_causal_mask
31
  )
32
  from .configuration_jvlm import JinaVLMConfig, JinaVLMTextConfig, JinaVLMVisionConfig
33
 
34
 
35
  class JinaPreTrainedModel(PreTrainedModel):
36
  config: JinaVLMConfig
37
+ config_class = JinaVLMConfig
38
  base_model_prefix = 'model'
39
  supports_gradient_checkpointing = True
40
  _supports_flash_attn = True
 
51
 
52
  class JinaVLMVisionModel(JinaPreTrainedModel):
53
  config: JinaVLMVisionConfig
54
+ config_class = JinaVLMVisionConfig
55
+ base_model_prefix = ''
56
 
57
  def __init__(self, config: JinaVLMVisionConfig, *args, **kwargs):
58
  super().__init__(config, *args, **kwargs)
 
186
  pos = pos_emb[None, :, :].to(x.dtype)
187
  return x + pos
188
 
189
+ def get_visual_features(self, images: torch.Tensor) -> BaseModelOutput:
 
 
 
 
190
  x, shape = self.patch_embed(images)
191
  if self.cls_embed is not None:
192
  cls = self.cls_embed.view(1, 1, -1).expand(x.shape[0], -1, -1).to(x.dtype)
 
201
  hidden_states = []
202
  attentions = []
203
  for layer in self.layers:
204
+ x, attn = layer(x, attn_implementation=self.config._attn_implementation)
 
 
 
 
205
  hidden_states.append(x)
206
  attentions.append(attn)
207
  x = self.post_lnorm(x)
 
214
  )
215
 
216
  def forward(
217
+ self, images: torch.Tensor, image_masks: torch.Tensor
 
 
 
218
  ) -> BaseModelOutput:
219
  b, t, n, d = images.shape
220
  mask = ~torch.all(images.view(b * t, n, d) == -1, dim=(1, 2), keepdim=True)
221
  images = images.view(b * t, n, d)
222
+ out = self.get_visual_features(images)
223
  image_features = out.hidden_states
224
 
225
  features = []
 
230
  features.append(feats)
231
  image_features = torch.cat(features, dim=-1)
232
  image_features = image_features * mask
233
+ image_features = image_features.view(b, t, n, -1)
234
+
235
  image_features = self.vl_connector(
236
  image_features,
237
  image_masks,
238
  attn_implementation=self.config._attn_implementation,
 
239
  )
240
+
241
  return BaseModelOutput(
242
  last_hidden_state=image_features,
243
  hidden_states=out.hidden_states,
 
246
 
247
 
248
  class JinaVLMTextModel(JinaPreTrainedModel):
249
+ """Decoder-only language model."""
250
+
251
  config: JinaVLMTextConfig
252
+ config_class = JinaVLMTextConfig
253
+ base_model_prefix = ''
254
 
255
  def __init__(self, config: JinaVLMTextConfig, *args, **kwargs):
256
  super().__init__(config, *args, **kwargs)
 
297
  theta=self.config.rope_theta,
298
  head_dim=self.config.block_config.attn_config.head_dim,
299
  hidden_size=self.config.hidden_size,
300
+ n_heads=self.config.block_config.attn_config.n_heads,
301
  partial_rotary_factor=self.config.partial_rotary_factor,
302
  scaling=self.config.rope_scaling,
303
  )
 
388
  batch_idx = torch.arange(bs, device=x.device)
389
  batch_idx = torch.tile(batch_idx[:, None], [1, image_features.shape[1]])
390
  image_features = image_features.to(x.device)
 
391
  x[batch_idx[valid], image_input_idx[valid]] += image_features[valid]
392
 
393
  if not self.rope:
 
443
 
444
 
445
  class JinaVLM(JinaPreTrainedModel):
446
+ base_model_prefix = ''
447
 
448
  def __init__(self, config: JinaVLMConfig):
449
  super().__init__(config)
 
492
  ) -> BaseModelOutputWithPast:
493
  image_features = None
494
  if images is not None and images.shape[1] > 0:
495
+ image_out = self.vision_model(images, image_masks)
496
  image_features = image_out.last_hidden_state
497
  return self.language_model(
498
  input_ids=input_ids,
 
511
 
512
 
513
  class JinaVLMForConditionalGeneration(JinaPreTrainedModel, GenerationMixin):
514
+ _checkpoint_conversion_mapping = {}
515
+ _tied_weights_keys = ['lm_head.weight']
 
516
  accepts_loss_kwargs = False
517
+ base_model_prefix = 'model'
518
  config: JinaVLMConfig
519
 
520
  def __init__(self, config: JinaVLMConfig):
processing_jvlm.py CHANGED
@@ -10,14 +10,11 @@ from transformers.image_utils import ImageInput
10
  from transformers.processing_utils import (
11
  AllKwargsForChatTemplate,
12
  CommonKwargs,
13
- MultiModalData,
14
  ProcessorMixin,
15
  Unpack,
16
  )
17
  from transformers.tokenization_utils_base import (
18
- PaddingStrategy,
19
- PreTokenizedInput,
20
- TextInput,
21
  )
22
 
23
  from .image_processing_jvlm import JinaVLMImageProcessor, JinaVLMImagesKwargs
@@ -41,8 +38,8 @@ class JinaVLMTextKwargs(TypedDict, total=False):
41
  is_split_into_words: Optional[bool]
42
 
43
 
44
- class JinaVLMProcessingKwargs(JinaVLMTextKwargs, JinaVLMImagesKwargs, CommonKwargs):
45
- return_labels: Optional[bool]
46
 
47
 
48
  class JinaVLMProcessor(ProcessorMixin):
@@ -174,8 +171,8 @@ class JinaVLMProcessor(ProcessorMixin):
174
  def _collate(
175
  self,
176
  batch: Dict[str, List[Optional[np.ndarray]]],
177
- text_max_sequence_length: Optional[int] = None,
178
- image_max_sequence_length: Optional[int] = None,
179
  padding: Union[
180
  PaddingStrategy.MAX_LENGTH, PaddingStrategy.LONGEST
181
  ] = PaddingStrategy.MAX_LENGTH,
@@ -188,10 +185,10 @@ class JinaVLMProcessor(ProcessorMixin):
188
  _padding_side = 'right'
189
  if key in self.TEXT_KEYS:
190
  _padding_side = padding_side
191
- max_len = text_max_sequence_length
192
  dtype = np.int64
193
  elif key in self.IMAGE_KEYS:
194
- max_len = image_max_sequence_length
195
  dtype = np.int64
196
  if key == 'images':
197
  dtype = np.float32
@@ -217,22 +214,22 @@ class JinaVLMProcessor(ProcessorMixin):
217
  shift = input_ids_padlens[:, np.newaxis, np.newaxis]
218
  shift = np.repeat(shift, n_image_tokens, axis=2)
219
  shift = np.repeat(shift, n_crops, axis=1)
220
- image_input_idx[image_input_idx < 0] = -text_max_sequence_length
221
  image_input_idx = image_input_idx + shift
222
  out['image_input_idx'] = image_input_idx
223
 
224
- if text_max_sequence_length is not None:
225
  image_input_idx = out.get('image_input_idx', [])
226
  n = len(image_input_idx)
227
  for i in range(n):
228
  arr = image_input_idx[i]
229
  if arr.ndim > 0 and arr.size > 0:
230
  n_image_tokens = arr.max()
231
- if n_image_tokens > text_max_sequence_length - 3:
232
  raise RuntimeError(
233
  'Image tokens truncation at sequence boundary. Max '
234
- f'sequence length ({text_max_sequence_length}) is too '
235
- 'small to fit the generated image tokens '
236
  f'({n_image_tokens}). Consider increasing the max '
237
  'sequence length or tweaking the image processing '
238
  'parameters (`max_crops`, `max_pixels`) to reduce the '
@@ -262,7 +259,6 @@ class JinaVLMProcessor(ProcessorMixin):
262
  image_tokens: List[np.ndarray],
263
  image_input_idx: List[np.ndarray],
264
  image_padding_mask: List[np.ndarray],
265
- return_labels: bool = False,
266
  add_empty_image_features: bool = False,
267
  ):
268
  """Interleave images and text tokens into multi-modal features for the model."""
@@ -286,9 +282,8 @@ class JinaVLMProcessor(ProcessorMixin):
286
  data = {
287
  'input_ids': input_ids,
288
  'position_ids': position_ids,
 
289
  }
290
- if return_labels:
291
- data['labels'] = target_tokens
292
  if add_empty_image_features:
293
  # Add size-zero image features, this can be useful to make sure all
294
  # devices get an image input when the image ViT is FSDP wrapped
@@ -372,16 +367,14 @@ class JinaVLMProcessor(ProcessorMixin):
372
  image_input_idx < 0, image_input_idx, image_input_idx + 1
373
  )
374
  position_ids = np.arange(len(input_ids), dtype=np.int64)
375
- data = {
376
  'input_ids': input_ids,
377
  'position_ids': position_ids,
378
  'images': images,
379
  'image_input_idx': image_input_idx,
380
  'image_masks': image_masks,
 
381
  }
382
- if return_labels:
383
- data['labels'] = target_tokens
384
- return data
385
 
386
  def __call__(
387
  self,
@@ -389,7 +382,7 @@ class JinaVLMProcessor(ProcessorMixin):
389
  text: Union[
390
  None, TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]
391
  ] = None,
392
- **kwargs: Unpack[JinaVLMProcessingKwargs],
393
  ) -> BatchFeature:
394
  """Main method to prepare for the model one or several sequences(s) and
395
  image(s). This method forwards the `text` and `kwargs` arguments to the
@@ -432,7 +425,6 @@ class JinaVLMProcessor(ProcessorMixin):
432
  raise ValueError('Processor requires text input.')
433
 
434
  return_tensors = kwargs.pop('return_tensors', None)
435
- return_labels = kwargs.pop('return_labels', False)
436
  padding = kwargs.pop('padding', PaddingStrategy.LONGEST)
437
  padding_side = kwargs.pop('padding_side', 'left')
438
  max_length = kwargs.pop('max_length', None)
@@ -461,7 +453,6 @@ class JinaVLMProcessor(ProcessorMixin):
461
  )
462
  token_ids = text_inputs['input_ids']
463
  batch_size = token_ids.shape[0]
464
- images = images or [[] for _ in range(batch_size)]
465
 
466
  if batch_size == 1:
467
  if isinstance(images, list):
@@ -492,11 +483,9 @@ class JinaVLMProcessor(ProcessorMixin):
492
  )
493
 
494
  outputs = defaultdict(list)
495
- n_images = []
496
  for idx in range(batch_size):
497
  _token_ids = token_ids[idx]
498
  _images = images[idx]
499
- n_images.append(len(_images))
500
  image_inputs = self.image_processor(_images, **images_kwargs)
501
  image_crops = image_inputs['image_crops']
502
  image_tokens = image_inputs['image_tokens']
@@ -509,48 +498,19 @@ class JinaVLMProcessor(ProcessorMixin):
509
  image_input_idx,
510
  image_padding_mask if image_padding_mask is not None else [],
511
  add_empty_image_features=(batch_size > 1),
512
- return_labels=return_labels,
513
  )
514
  for k, v in output.items():
515
  outputs[k].append(v)
516
 
517
  if padding != PaddingStrategy.DO_NOT_PAD:
518
- text_max_sequence_length = max_length or self.max_sequence_length
519
- max_crops = max_crops or self.max_crops
520
- max_n_images = max(n_images)
521
- image_max_sequence_length = (max_crops + 1) * max_n_images
522
  outputs = self._collate(
523
  outputs,
524
- text_max_sequence_length=text_max_sequence_length,
525
- image_max_sequence_length=image_max_sequence_length,
526
  padding=padding,
527
  padding_side=padding_side,
528
  )
529
  return BatchFeature(data=outputs, tensor_type=return_tensors)
530
 
531
- def _get_num_multimodal_tokens(
532
- self,
533
- image_sizes: Optional[List[List[int]]] = None,
534
- **kwargs: Unpack[JinaVLMImagesKwargs],
535
- ) -> MultiModalData:
536
- """Computes the number of placeholder tokens needed for multimodal inputs with
537
- the given sizes.
538
-
539
- Args:
540
- image_sizes (`list[list[int]]`, *optional*):
541
- The input sizes formatted as (height, width) per each image.
542
- Returns:
543
- `MultiModalData`: A `MultiModalData` object holding number of tokens per
544
- each of the provided input modalities, along with other useful data.
545
- """
546
- data = {}
547
- if image_sizes is not None:
548
- n_patches = [
549
- self.image_processor.get_n_image_patches(h, w, **kwargs)
550
- for h, w in image_sizes
551
- ]
552
- data.update({'num_image_tokens': n_patches, 'num_image_patches': n_patches})
553
- return MultiModalData(**data)
554
-
555
 
556
  JinaVLMProcessor.register_for_auto_class()
 
10
  from transformers.processing_utils import (
11
  AllKwargsForChatTemplate,
12
  CommonKwargs,
 
13
  ProcessorMixin,
14
  Unpack,
15
  )
16
  from transformers.tokenization_utils_base import (
17
+ PaddingStrategy, PreTokenizedInput, TextInput,
 
 
18
  )
19
 
20
  from .image_processing_jvlm import JinaVLMImageProcessor, JinaVLMImagesKwargs
 
38
  is_split_into_words: Optional[bool]
39
 
40
 
41
+ class JinaVLProcessingKwargs(JinaVLMTextKwargs, JinaVLMImagesKwargs, CommonKwargs):
42
+ pass
43
 
44
 
45
  class JinaVLMProcessor(ProcessorMixin):
 
171
  def _collate(
172
  self,
173
  batch: Dict[str, List[Optional[np.ndarray]]],
174
+ max_sequence_length: Optional[int] = None,
175
+ max_crops: Optional[int] = None,
176
  padding: Union[
177
  PaddingStrategy.MAX_LENGTH, PaddingStrategy.LONGEST
178
  ] = PaddingStrategy.MAX_LENGTH,
 
185
  _padding_side = 'right'
186
  if key in self.TEXT_KEYS:
187
  _padding_side = padding_side
188
+ max_len = max_sequence_length
189
  dtype = np.int64
190
  elif key in self.IMAGE_KEYS:
191
+ max_len = max_crops
192
  dtype = np.int64
193
  if key == 'images':
194
  dtype = np.float32
 
214
  shift = input_ids_padlens[:, np.newaxis, np.newaxis]
215
  shift = np.repeat(shift, n_image_tokens, axis=2)
216
  shift = np.repeat(shift, n_crops, axis=1)
217
+ image_input_idx[image_input_idx < 0] = -max_sequence_length
218
  image_input_idx = image_input_idx + shift
219
  out['image_input_idx'] = image_input_idx
220
 
221
+ if max_sequence_length is not None:
222
  image_input_idx = out.get('image_input_idx', [])
223
  n = len(image_input_idx)
224
  for i in range(n):
225
  arr = image_input_idx[i]
226
  if arr.ndim > 0 and arr.size > 0:
227
  n_image_tokens = arr.max()
228
+ if n_image_tokens > max_sequence_length - 3:
229
  raise RuntimeError(
230
  'Image tokens truncation at sequence boundary. Max '
231
+ f'sequence length ({max_sequence_length}) is too small '
232
+ 'to fit the generated image tokens '
233
  f'({n_image_tokens}). Consider increasing the max '
234
  'sequence length or tweaking the image processing '
235
  'parameters (`max_crops`, `max_pixels`) to reduce the '
 
259
  image_tokens: List[np.ndarray],
260
  image_input_idx: List[np.ndarray],
261
  image_padding_mask: List[np.ndarray],
 
262
  add_empty_image_features: bool = False,
263
  ):
264
  """Interleave images and text tokens into multi-modal features for the model."""
 
282
  data = {
283
  'input_ids': input_ids,
284
  'position_ids': position_ids,
285
+ 'labels': target_tokens,
286
  }
 
 
287
  if add_empty_image_features:
288
  # Add size-zero image features, this can be useful to make sure all
289
  # devices get an image input when the image ViT is FSDP wrapped
 
367
  image_input_idx < 0, image_input_idx, image_input_idx + 1
368
  )
369
  position_ids = np.arange(len(input_ids), dtype=np.int64)
370
+ return {
371
  'input_ids': input_ids,
372
  'position_ids': position_ids,
373
  'images': images,
374
  'image_input_idx': image_input_idx,
375
  'image_masks': image_masks,
376
+ 'labels': target_tokens,
377
  }
 
 
 
378
 
379
  def __call__(
380
  self,
 
382
  text: Union[
383
  None, TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]
384
  ] = None,
385
+ **kwargs: Unpack[JinaVLProcessingKwargs],
386
  ) -> BatchFeature:
387
  """Main method to prepare for the model one or several sequences(s) and
388
  image(s). This method forwards the `text` and `kwargs` arguments to the
 
425
  raise ValueError('Processor requires text input.')
426
 
427
  return_tensors = kwargs.pop('return_tensors', None)
 
428
  padding = kwargs.pop('padding', PaddingStrategy.LONGEST)
429
  padding_side = kwargs.pop('padding_side', 'left')
430
  max_length = kwargs.pop('max_length', None)
 
453
  )
454
  token_ids = text_inputs['input_ids']
455
  batch_size = token_ids.shape[0]
 
456
 
457
  if batch_size == 1:
458
  if isinstance(images, list):
 
483
  )
484
 
485
  outputs = defaultdict(list)
 
486
  for idx in range(batch_size):
487
  _token_ids = token_ids[idx]
488
  _images = images[idx]
 
489
  image_inputs = self.image_processor(_images, **images_kwargs)
490
  image_crops = image_inputs['image_crops']
491
  image_tokens = image_inputs['image_tokens']
 
498
  image_input_idx,
499
  image_padding_mask if image_padding_mask is not None else [],
500
  add_empty_image_features=(batch_size > 1),
 
501
  )
502
  for k, v in output.items():
503
  outputs[k].append(v)
504
 
505
  if padding != PaddingStrategy.DO_NOT_PAD:
 
 
 
 
506
  outputs = self._collate(
507
  outputs,
508
+ max_sequence_length=max_length or self.max_sequence_length,
509
+ max_crops=max_crops or self.max_crops,
510
  padding=padding,
511
  padding_side=padding_side,
512
  )
513
  return BatchFeature(data=outputs, tensor_type=return_tensors)
514
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
515
 
516
  JinaVLMProcessor.register_for_auto_class()
pyproject.toml DELETED
@@ -1,18 +0,0 @@
1
- [project]
2
- name = "jina-vlm"
3
- version = "1.0.0"
4
- description = "Jina VLM v1: Lightweight Vision Language Alignment"
5
- readme = "README.md"
6
- license = "CC-BY-NC-4.0"
7
- requires-python = ">=3.10"
8
- dependencies = [
9
- "torch>=2.9.0",
10
- "torchvision>=0.24.0",
11
- "transformers>=4.57.0",
12
- "pillow>=12.0.0",
13
- "einops>=0.8.1",
14
- "accelerate>=1.0.0",
15
- ]
16
-
17
- [project.optional-dependencies]
18
- flash-attn = ["flash-attn>=2.0.0"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
infer.py → test_jvlm.py RENAMED
@@ -11,10 +11,7 @@ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
11
 
12
  import torch
13
  from transformers import (
14
- AutoModelForCausalLM,
15
- AutoProcessor,
16
- GenerationConfig,
17
- TextStreamer,
18
  )
19
  from transformers.utils import is_flash_attn_2_available
20
 
@@ -63,8 +60,7 @@ def _build_conversations(
63
  try:
64
  result = urlparse(_path)
65
  return result.scheme in ('http', 'https')
66
- except Exception as e:
67
- _ = str(e)
68
  return False
69
 
70
  images = images or []
@@ -87,9 +83,8 @@ def _build_conversations(
87
  images = [TEST_IMAGE]
88
  n_images = len(images)
89
  prompts = (
90
- ['Describe the image in 100 words']
91
- if n_images == 1 or map_mode
92
- else ['Describe the images in 100 words']
93
  )
94
  n_prompts = len(prompts)
95
 
@@ -124,16 +119,8 @@ def _build_conversations(
124
  allimages = []
125
  allprompts = []
126
  ordinals = [
127
- 'first',
128
- 'second',
129
- 'third',
130
- 'fourth',
131
- 'fifth',
132
- 'sixth',
133
- 'seventh',
134
- 'eighth',
135
- 'ninth',
136
- 'tenth',
137
  ]
138
  for images, prompt in examples:
139
  content = []
@@ -143,17 +130,15 @@ def _build_conversations(
143
  content.append({'type': 'text', 'text': prompt})
144
  if len(images) > 1 and image_labels:
145
  for idx, img in enumerate(images):
146
- ordinal = ordinals[idx] if idx < len(ordinals) else f'{idx + 1}th'
147
  image = images[idx]
148
  descriptor = f'url: {image}'
149
  if os.path.isfile(image):
150
  descriptor = f'filename: {os.path.basename(image)}'
151
- content.append(
152
- {
153
- 'type': 'text',
154
- 'text': f'(this is the {ordinal} image, {descriptor})',
155
- }
156
- )
157
  content.append({'type': 'image', 'image': img})
158
  else:
159
  content.extend([{'type': 'image', 'image': image} for image in images])
@@ -204,7 +189,9 @@ def _token_usage_report(
204
  tokens_per_image_list = []
205
 
206
  # Find all img_start and img_end positions in input_ids
207
- start_positions = (input_ids == image_start_id).nonzero(as_tuple=True)[0].tolist()
 
 
208
  end_positions = (input_ids == image_end_id).nonzero(as_tuple=True)[0].tolist()
209
 
210
  if len(start_positions) > 0 and len(end_positions) > 0:
@@ -224,8 +211,9 @@ def _token_usage_report(
224
  # Get the start and end indices for this image
225
  start_idx_begin = idx * n_starts_per_image
226
  end_idx_end = (idx + 1) * n_starts_per_image
227
- if start_idx_begin < len(start_positions) and end_idx_end <= len(
228
- end_positions
 
229
  ):
230
  # First start position and last end position define the image span
231
  first_start = start_positions[start_idx_begin]
@@ -245,10 +233,10 @@ def _token_usage_report(
245
 
246
  for idx in range(n_images):
247
  n_tokens = tokens_per_image_list[idx] if idx < len(tokens_per_image_list) else 0
248
- pct = n_tokens / max_sequence_length * 100
249
  report.append(f'├── Image {idx + 1} → {n_tokens} tokens ({pct:.1f}%)')
250
 
251
- text_pct = text_token_count / max_sequence_length * 100
252
  report.append(f'└── Text: {text_token_count} tokens ({text_pct:.1f}%)')
253
 
254
  return '\n'.join(report)
@@ -256,17 +244,16 @@ def _token_usage_report(
256
 
257
  def test_jvlm():
258
  parser = argparse.ArgumentParser(
259
- description='jina-vlm vision-language model inference.'
260
  )
261
- default_model = '.' if os.path.exists('./config.json') else 'jinaai/jina-vlm'
262
  parser.add_argument(
263
  '-m',
264
  '--model',
265
- default=default_model,
266
  help=(
267
- 'Model path. Auto-detects local repo (if config.json exists) or '
268
- 'falls back to "jinaai/jina-vlm" from HuggingFace.'
269
- ),
270
  )
271
  parser.add_argument(
272
  '-i',
@@ -340,7 +327,7 @@ def test_jvlm():
340
  args = parser.parse_args()
341
 
342
  print()
343
- print('Welcome to the jinaai/jina-vlm playground ✨')
344
  print('Use this script to test our model!')
345
  print('- Jina AI')
346
  print()
@@ -352,9 +339,7 @@ def test_jvlm():
352
  print(f'Using dtype: {dtype}')
353
  print('Model path: ', args.model)
354
  processor = AutoProcessor.from_pretrained(
355
- args.model,
356
- trust_remote_code=True,
357
- use_fast=False,
358
  )
359
  model = AutoModelForCausalLM.from_pretrained(
360
  args.model,
@@ -371,13 +356,13 @@ def test_jvlm():
371
  print('Done ✅')
372
  print()
373
 
374
- print("--- Let's create some conversations ...")
375
  conversations, images, prompts = _build_conversations(
376
  args.image,
377
  args.prompt,
378
  map_mode=args.map,
379
  prompt_first=args.prompt_first,
380
- image_labels=args.image_labels,
381
  )
382
  n_conversations = len(conversations)
383
  print(f'Built {n_conversations} conversations 🚀')
@@ -449,28 +434,25 @@ def test_jvlm():
449
  print(f'├── 🖼️Images: {images[idx]}')
450
  print(f'├── 📜Prompt: {prompts[idx]}')
451
  print(f'├── 💬Chat:{texts[idx]}')
452
- print('└── 🧠Response:', end='')
453
  ith_inputs = {k: v[idx].unsqueeze(0) for k, v in device_inputs.items()}
454
  with (
455
  timer,
456
  torch.no_grad(),
457
- torch.autocast(
458
- device.type, enabled=(device.type != 'mps'), dtype=dtype
459
- ),
460
  ):
461
  output = model.generate(
462
  **ith_inputs,
463
  streamer=streamer,
464
  generation_config=GenerationConfig(
465
- max_new_tokens=args.max_tokens,
466
- do_sample=False,
467
  ),
468
  return_dict_in_generate=True,
469
  use_model_defaults=True,
470
  )
471
  generation_time += timer.time
472
 
473
- out = output.sequences[0][len(input_prompts[idx].tolist()) :]
474
  generated_tokens += len(out)
475
  print('Token usage report:')
476
  print(token_usage_reports[idx])
@@ -488,8 +470,7 @@ def test_jvlm():
488
  output = model.generate(
489
  **device_inputs,
490
  generation_config=GenerationConfig(
491
- max_new_tokens=args.max_tokens,
492
- do_sample=False,
493
  ),
494
  return_dict_in_generate=True,
495
  use_model_defaults=True,
@@ -497,7 +478,7 @@ def test_jvlm():
497
  generation_time = timer.time
498
 
499
  for idx in range(n_conversations):
500
- out = output.sequences[idx][len(input_prompts[idx].tolist()) :]
501
  generated_tokens += len(out)
502
  response = processor.tokenizer.decode(out, skip_special_tokens=True)
503
  print(f'* Conversation {idx + 1}/{n_conversations}')
 
11
 
12
  import torch
13
  from transformers import (
14
+ AutoModelForCausalLM, AutoProcessor, GenerationConfig, TextStreamer
 
 
 
15
  )
16
  from transformers.utils import is_flash_attn_2_available
17
 
 
60
  try:
61
  result = urlparse(_path)
62
  return result.scheme in ('http', 'https')
63
+ except:
 
64
  return False
65
 
66
  images = images or []
 
83
  images = [TEST_IMAGE]
84
  n_images = len(images)
85
  prompts = (
86
+ ['Describe the image in 100 words'] if n_images == 1 or map_mode else
87
+ ['Describe the images in 100 words']
 
88
  )
89
  n_prompts = len(prompts)
90
 
 
119
  allimages = []
120
  allprompts = []
121
  ordinals = [
122
+ 'first', 'second', 'third', 'fourth', 'fifth',
123
+ 'sixth', 'seventh', 'eighth', 'ninth', 'tenth',
 
 
 
 
 
 
 
 
124
  ]
125
  for images, prompt in examples:
126
  content = []
 
130
  content.append({'type': 'text', 'text': prompt})
131
  if len(images) > 1 and image_labels:
132
  for idx, img in enumerate(images):
133
+ ordinal = ordinals[idx] if idx < len(ordinals) else f'{idx+1}th'
134
  image = images[idx]
135
  descriptor = f'url: {image}'
136
  if os.path.isfile(image):
137
  descriptor = f'filename: {os.path.basename(image)}'
138
+ content.append({
139
+ 'type': 'text',
140
+ 'text': f'(this is the {ordinal} image, {descriptor})',
141
+ })
 
 
142
  content.append({'type': 'image', 'image': img})
143
  else:
144
  content.extend([{'type': 'image', 'image': image} for image in images])
 
189
  tokens_per_image_list = []
190
 
191
  # Find all img_start and img_end positions in input_ids
192
+ start_positions = (input_ids == image_start_id).nonzero(
193
+ as_tuple=True
194
+ )[0].tolist()
195
  end_positions = (input_ids == image_end_id).nonzero(as_tuple=True)[0].tolist()
196
 
197
  if len(start_positions) > 0 and len(end_positions) > 0:
 
211
  # Get the start and end indices for this image
212
  start_idx_begin = idx * n_starts_per_image
213
  end_idx_end = (idx + 1) * n_starts_per_image
214
+ if (
215
+ start_idx_begin < len(start_positions) and
216
+ end_idx_end <= len(end_positions)
217
  ):
218
  # First start position and last end position define the image span
219
  first_start = start_positions[start_idx_begin]
 
233
 
234
  for idx in range(n_images):
235
  n_tokens = tokens_per_image_list[idx] if idx < len(tokens_per_image_list) else 0
236
+ pct = (n_tokens / max_sequence_length * 100)
237
  report.append(f'├── Image {idx + 1} → {n_tokens} tokens ({pct:.1f}%)')
238
 
239
+ text_pct = (text_token_count / max_sequence_length * 100)
240
  report.append(f'└── Text: {text_token_count} tokens ({text_pct:.1f}%)')
241
 
242
  return '\n'.join(report)
 
244
 
245
  def test_jvlm():
246
  parser = argparse.ArgumentParser(
247
+ description='jina-vlm-v1 vision-language model inference.'
248
  )
 
249
  parser.add_argument(
250
  '-m',
251
  '--model',
252
+ default='.',
253
  help=(
254
+ 'Model path (default: `"."`). Set this to `"jinaai/jina-vlm-v1"` if you '
255
+ 'are running this script outside this repo.'
256
+ )
257
  )
258
  parser.add_argument(
259
  '-i',
 
327
  args = parser.parse_args()
328
 
329
  print()
330
+ print('Welcome to the jinaai/jina-vlm-v1 playground ✨')
331
  print('Use this script to test our model!')
332
  print('- Jina AI')
333
  print()
 
339
  print(f'Using dtype: {dtype}')
340
  print('Model path: ', args.model)
341
  processor = AutoProcessor.from_pretrained(
342
+ args.model, trust_remote_code=True, use_fast=False,
 
 
343
  )
344
  model = AutoModelForCausalLM.from_pretrained(
345
  args.model,
 
356
  print('Done ✅')
357
  print()
358
 
359
+ print('--- Let\'s create some conversations ...')
360
  conversations, images, prompts = _build_conversations(
361
  args.image,
362
  args.prompt,
363
  map_mode=args.map,
364
  prompt_first=args.prompt_first,
365
+ image_labels=args.image_labels
366
  )
367
  n_conversations = len(conversations)
368
  print(f'Built {n_conversations} conversations 🚀')
 
434
  print(f'├── 🖼️Images: {images[idx]}')
435
  print(f'├── 📜Prompt: {prompts[idx]}')
436
  print(f'├── 💬Chat:{texts[idx]}')
437
+ print(f'└── 🧠Response:', end='')
438
  ith_inputs = {k: v[idx].unsqueeze(0) for k, v in device_inputs.items()}
439
  with (
440
  timer,
441
  torch.no_grad(),
442
+ torch.autocast(device.type, enabled=(device.type != 'mps'), dtype=dtype)
 
 
443
  ):
444
  output = model.generate(
445
  **ith_inputs,
446
  streamer=streamer,
447
  generation_config=GenerationConfig(
448
+ max_new_tokens=args.max_tokens, do_sample=False,
 
449
  ),
450
  return_dict_in_generate=True,
451
  use_model_defaults=True,
452
  )
453
  generation_time += timer.time
454
 
455
+ out = output.sequences[0][len(input_prompts[idx].tolist()):]
456
  generated_tokens += len(out)
457
  print('Token usage report:')
458
  print(token_usage_reports[idx])
 
470
  output = model.generate(
471
  **device_inputs,
472
  generation_config=GenerationConfig(
473
+ max_new_tokens=args.max_tokens, do_sample=False,
 
474
  ),
475
  return_dict_in_generate=True,
476
  use_model_defaults=True,
 
478
  generation_time = timer.time
479
 
480
  for idx in range(n_conversations):
481
+ out = output.sequences[idx][len(input_prompts[idx].tolist()):]
482
  generated_tokens += len(out)
483
  response = processor.tokenizer.decode(out, skip_special_tokens=True)
484
  print(f'* Conversation {idx + 1}/{n_conversations}')