# Source: siglip-so400m-patch14-384/python/inference_onnx.py
# Uploaded by qqc1989 (commit f752476, "Upload 11 files").
import time
from PIL import Image
import numpy as np
import torch
from transformers import AutoProcessor, AutoModel
import onnxruntime as ort
# Seed NumPy's legacy global RNG so that the placeholder logit parameters
# drawn later with np.random.randn are reproducible from run to run.
np.random.seed(seed=42)
def load_model(model_path, providers=None):
    """
    Load an ONNX model and measure loading time.

    Args:
        model_path (str): Path to the ONNX model file.
        providers (list, optional): Execution providers forwarded to
            ``ort.InferenceSession``. ``None`` (the default) leaves the
            choice to onnxruntime, exactly as if the argument were omitted.

    Returns:
        ort.InferenceSession: Loaded ONNX model session.
    """
    # perf_counter is monotonic and high-resolution, so the measured
    # interval cannot be skewed by wall-clock adjustments (unlike time.time).
    start_time = time.perf_counter()
    session = ort.InferenceSession(model_path, providers=providers)
    elapsed_time = time.perf_counter() - start_time
    print(f"Model loading time: {elapsed_time:.2f} seconds")
    return session
def model_inference(sess, inputs, output_index=1, embed_dim=1152):
    """
    Run inference sample by sample with the given ONNX session and measure
    the total inference time.

    Each element of ``inputs`` is fed on its own, with a leading batch axis
    of size 1 added, to the session's first input.

    Args:
        sess (ort.InferenceSession): ONNX model session.
        inputs (list or np.ndarray): Input samples, iterated along the first axis.
        output_index (int): Which session output to collect for each sample.
            Defaults to 1, the output this script uses for both encoders
            (presumably the pooled embedding -- confirm against the exported graph).
        embed_dim (int): Size of the last axis of the returned array.

    Returns:
        np.ndarray: The collected outputs reshaped to ``(-1, embed_dim)``.
    """
    input_name = sess.get_inputs()[0].name
    # Request only the output we keep, instead of fetching every graph output
    # and discarding all but one for each sample.
    output_name = sess.get_outputs()[output_index].name
    start_time = time.perf_counter()
    outputs = [
        sess.run([output_name], {input_name: np.array([sample])})[0]
        for sample in inputs
    ]
    outputs = np.array(outputs).reshape(-1, embed_dim)
    elapsed_time = time.perf_counter() - start_time
    print(f"Model inference time: {elapsed_time:.2f} seconds")
    return outputs
# Load processor and prepare inputs
processor = AutoProcessor.from_pretrained("./tokenizer")
texts = ["a photo of 2 cats", "a photo of 2 dogs"]
# Image.open is lazy: keep the file open only while the processor reads the
# pixels, then let the context manager close the handle.
with Image.open("./000000039769.jpg") as image:
    inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt")
# Load vision and text models (perf_counter: monotonic clock for intervals)
start_time = time.perf_counter()
vision_sess = load_model("./onnx/siglip-so400m-patch14-384_vision.onnx")
text_sess = load_model("./onnx/siglip-so400m-patch14-384_text.onnx")
print(f"Total model loading time: {time.perf_counter() - start_time:.2f} seconds")
# Run inference on both models
start_time = time.perf_counter()
vision_outputs = model_inference(vision_sess, inputs['pixel_values'].numpy())
text_outputs = model_inference(text_sess, inputs['input_ids'].numpy())
print(f"Total inference time: {time.perf_counter() - start_time:.2f} seconds")
# Post-processing
image_embeds = torch.tensor(vision_outputs)
text_embeds = torch.tensor(text_outputs)
# L2-normalize each embedding so the dot products below are cosine similarities
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
# Compute similarity logits
# NOTE(review): logit_scale / logit_bias are random placeholders drawn from the
# seeded global RNG, not trained SigLIP parameters, so the resulting
# probabilities are not meaningful until the trained values are substituted.
logit_scale = np.random.randn(1)[0]  # Replace with trained value if available
logit_bias = np.random.randn(1)[0]
# Do the arithmetic in torch and convert once so the logits are a plain
# ndarray; np.matmul on a torch.Tensor goes through torch's __array_wrap__
# and can hand back a Tensor instead.
logits_per_text = (
    (text_embeds @ image_embeds.t()) * float(np.exp(logit_scale))
    + float(logit_bias)
).numpy()
logits_per_image = logits_per_text.T
def sigmoid(x):
    """
    Element-wise logistic function 1 / (1 + exp(-x)), computed stably.

    The naive form overflows in ``np.exp(-x)`` for large negative ``x``. Here
    ``exp`` is only ever applied to ``-|x|`` (always <= 0), and the two
    algebraically equal forms are selected by sign, so nothing overflows.

    Args:
        x: Scalar or array-like of logits. Converted with ``np.asarray``.

    Returns:
        np.ndarray of the same shape (and floating dtype) as ``x``, or a NumPy
        scalar when ``x`` is a scalar.
    """
    x = np.asarray(x)
    z = np.exp(-np.abs(x))
    # x >= 0: 1 / (1 + e^-x);  x < 0: e^x / (1 + e^x)  (same value, no overflow)
    out = np.where(x >= 0, 1.0 / (1.0 + z), z / (1.0 + z))
    return out[()] if out.ndim == 0 else out
# Apply the element-wise sigmoid so each image/text logit becomes its own probability.
probs = sigmoid(logits_per_image)
print(f"{probs[0, 0]:.1%} that image 0 is '{texts[0]}'")