|
|
import time |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
import torch |
|
|
from transformers import AutoProcessor, AutoModel |
|
|
import onnxruntime as ort |
|
|
|
|
|
|
|
|
# Fix the state of NumPy's global RNG so the np.random draws made later in
# this script produce the same values on every run.
np.random.seed(seed=42)
|
|
|
|
|
def load_model(model_path):
    """
    Load an ONNX model and report how long the load took.

    The elapsed time is printed to stdout with two decimal places.

    Args:
        model_path (str): Path to the ONNX model file.

    Returns:
        ort.InferenceSession: Loaded ONNX model session.
    """
    # perf_counter() is a monotonic, high-resolution clock intended for
    # measuring intervals; time.time() can jump if the system clock is adjusted.
    start_time = time.perf_counter()
    session = ort.InferenceSession(model_path)
    elapsed_time = time.perf_counter() - start_time
    print(f"Model loading time: {elapsed_time:.2f} seconds")
    return session
|
|
|
|
|
def model_inference(sess, inputs, *, output_index=1, embed_dim=1152):
    """
    Run every input through an ONNX session one at a time and time the pass.

    Each element of ``inputs`` is wrapped in a leading axis of length one and
    fed to the session's first declared input. One entry of the list returned
    by ``sess.run`` is kept per element, and the collected results are
    reshaped into rows of ``embed_dim`` values. The elapsed time (including
    the final reshape) is printed to stdout.

    Args:
        sess (ort.InferenceSession): ONNX model session. Only
            ``get_inputs()`` and ``run()`` are used.
        inputs (list or np.ndarray): Input data, iterated element by element.
        output_index (int): Position in the output list of ``sess.run`` to
            keep. Defaults to 1, the value that used to be hard-coded.
        embed_dim (int): Width of each row of the returned array. Defaults to
            1152, the value that used to be hard-coded.

    Returns:
        np.ndarray: Inference results with shape ``(-1, embed_dim)``.
    """
    input_name = sess.get_inputs()[0].name

    # perf_counter() is monotonic and meant for interval timing.
    start_time = time.perf_counter()
    outputs = [
        # np.array([sample]) adds the size-one leading axis for this element.
        sess.run(None, {input_name: np.array([sample])})[output_index]
        for sample in inputs
    ]
    outputs = np.array(outputs).reshape(-1, embed_dim)
    elapsed_time = time.perf_counter() - start_time

    print(f"Model inference time: {elapsed_time:.2f} seconds")
    return outputs
|
|
|
|
|
|
|
|
|
|
|
# Candidate captions that will be scored against the image.
texts = ["a photo of 2 cats", "a photo of 2 dogs"]

# The processor configuration is read from the local ./tokenizer directory.
processor = AutoProcessor.from_pretrained("./tokenizer")
image = Image.open("./000000039769.jpg")

# One call prepares both the captions and the image; the resulting entries
# are converted with .numpy() before being handed to the ONNX sessions.
inputs = processor(
    text=texts,
    images=image,
    padding="max_length",
    return_tensors="pt",
)
|
|
|
|
|
|
|
|
# Load the vision and text ONNX sessions and report the combined load time.
# perf_counter() is monotonic, so the measured interval cannot be distorted
# by system clock adjustments the way a time.time() difference can.
start_time = time.perf_counter()
vision_sess = load_model("./onnx/siglip-so400m-patch14-384_vision.onnx")
text_sess = load_model("./onnx/siglip-so400m-patch14-384_text.onnx")
print(f"Total model loading time: {time.perf_counter() - start_time:.2f} seconds")
|
|
|
|
|
|
|
|
# Run the image and the captions through their respective sessions and
# report the combined inference time. perf_counter() is monotonic and is the
# clock intended for interval measurement.
start_time = time.perf_counter()
vision_outputs = model_inference(vision_sess, inputs['pixel_values'].numpy())
text_outputs = model_inference(text_sess, inputs['input_ids'].numpy())
print(f"Total inference time: {time.perf_counter() - start_time:.2f} seconds")
|
|
|
|
|
|
|
|
# Wrap the ONNX outputs as torch tensors and divide every row by its L2 norm,
# so the matrix product below is a cosine similarity per text/image pair.
image_embeds = torch.tensor(vision_outputs)
text_embeds = torch.tensor(text_outputs)

image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

# NOTE(review): the scale and bias are drawn from NumPy's global RNG rather
# than read from a checkpoint, so the probabilities derived from these logits
# look like placeholders — confirm whether trained values should be loaded.
logit_scale = np.random.randn(1)[0]
logit_bias = np.random.randn(1)[0]

# Both operands are converted to ndarrays explicitly, so the product does not
# depend on how NumPy implicitly converts (and re-wraps) torch.Tensor inputs.
logits_per_text = (
    np.matmul(text_embeds.numpy(), image_embeds.t().numpy()) * np.exp(logit_scale)
    + logit_bias
)

# Same scores, indexed image-first instead of text-first.
logits_per_image = logits_per_text.T
|
|
|
|
|
def sigmoid(x):
    """
    Compute the logistic sigmoid ``1 / (1 + exp(-x))`` element-wise.

    The value is evaluated as ``exp(-log(1 + exp(-x)))`` using
    ``np.logaddexp``, so large negative inputs no longer overflow inside
    ``exp`` (the direct formula yields the same limit of 0 but emits an
    overflow RuntimeWarning on the way).

    Args:
        x: Scalar or array of real values.

    Returns:
        Sigmoid of ``x``, with every value in the closed interval [0, 1].
    """
    # logaddexp(0, -x) == log(exp(0) + exp(-x)) == log(1 + exp(-x)),
    # computed without forming exp(-x) directly.
    return np.exp(-np.logaddexp(0, -x))
|
|
|
|
|
# Each image/text pair gets its own independent probability via the sigmoid;
# only the entry for image 0 and the first caption is reported.
probs = sigmoid(logits_per_image)
first_pair_prob = probs[0][0]
print(f"{first_pair_prob:.1%} that image 0 is '{texts[0]}'")
|
|
|