JacobLinCool committed (verified)
Commit 6510c49 · Parent: 9ecb18c

Update app.py

Files changed (1): app.py (+43 -21)
app.py CHANGED
```diff
@@ -16,6 +16,7 @@ import os
 
 DEVICE = Accelerator().device
 MODEL_NAME = "qihoo360/fg-clip2-so400m"
+BATCH_SIZE = 64
 
 
 model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, trust_remote_code=True).to(
```
```diff
@@ -69,21 +70,28 @@ def generate_image_embeddings(zip_file):
     if len(images) == 0:
         return None, "❌ No valid images found in the zip file"
 
-    # Generate embeddings
+    # Generate embeddings with batching
     embeddings = []
     print(f"Generating embeddings for {len(images)} images...")
     with torch.no_grad():
-        for i, image in enumerate(images):
-            print(f"Processing image {i + 1}/{len(images)}")
+        for i in range(0, len(images), BATCH_SIZE):
+            batch = images[i : i + BATCH_SIZE]
+            print(
+                f"Processing batch {i // BATCH_SIZE + 1}/{(len(images) + BATCH_SIZE - 1) // BATCH_SIZE} ({len(batch)} images)"
+            )
+
+            # Use the same max_num_patches for all images in batch
+            max_patches = max(determine_max_value(img) for img in batch)
+
             image_input = image_processor(
-                images=image,
-                max_num_patches=determine_max_value(image),
+                images=batch,
+                max_num_patches=max_patches,
                 return_tensors="pt",
             ).to(DEVICE)
-            image_feature = model.get_image_features(**image_input)
+            image_features = model.get_image_features(**image_input)
 
-            # Normalize the embedding
-            normalized_features = image_feature / image_feature.norm(
+            # Normalize the embeddings
+            normalized_features = image_features / image_features.norm(
                 dim=-1, keepdim=True
             )
             embeddings.append(normalized_features.cpu().numpy())
```
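The batching pattern this hunk introduces (and that the video hunk below repeats) is standard: chunk the input list, pick a single `max_num_patches` for the whole chunk so the processor can stack the images into one tensor, and report progress with a ceiling-division batch count. A minimal sketch of the pattern in isolation, with `encode_batch` as a hypothetical stand-in for the processor-plus-model call:

```python
import math
from typing import Callable, List

import numpy as np

BATCH_SIZE = 64  # the constant the commit adds at module level


def embed_in_batches(
    items: List[object],
    encode_batch: Callable[[List[object]], np.ndarray],
) -> np.ndarray:
    """Encode items in fixed-size chunks and stack the per-batch results.

    encode_batch is a hypothetical stand-in for the image_processor +
    model.get_image_features call in app.py; it should return an array
    of shape (len(batch), dim).
    """
    # math.ceil(n / B) equals the (n + B - 1) // B expression in the diff
    num_batches = math.ceil(len(items) / BATCH_SIZE)
    chunks = []
    for b in range(num_batches):
        batch = items[b * BATCH_SIZE : (b + 1) * BATCH_SIZE]
        print(f"Processing batch {b + 1}/{num_batches} ({len(batch)} items)")
        chunks.append(encode_batch(batch))
    return np.concatenate(chunks, axis=0)
```

Taking `max(determine_max_value(img) for img in batch)` keeps every image in the chunk within a valid patch budget; presumably smaller images are simply processed at a larger budget than strictly necessary. Sorting inputs by size before chunking would tighten that, but the commit keeps the simpler uniform choice.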
```diff
@@ -147,12 +155,12 @@ def extract_frames(video_path: str, fps: int = 4):
 
 
 @spaces.GPU
-def generate_video_embeddings(video_file, fps):
+def generate_video_embeddings(video_path, fps):
     """
     Generate embeddings from video frames.
 
     Args:
-        video_file: Uploaded video file
+        video_path: Path to video file (str)
         fps: Frames per second to extract
 
     Returns:
```
```diff
@@ -160,28 +168,35 @@ def generate_video_embeddings(video_file, fps):
     """
     try:
         # Extract frames
-        print(f"Extracting frames from video: {video_file.name} at {fps} fps")
-        frames = extract_frames(video_file.name, fps)
+        print(f"Extracting frames from video: {video_path} at {fps} fps")
+        frames = extract_frames(video_path, fps)
         print(f"Extracted {len(frames)} frames from video")
 
         if len(frames) == 0:
             return None, "❌ No frames could be extracted from the video"
 
-        # Generate embeddings
+        # Generate embeddings with batching
         embeddings = []
         print(f"Generating embeddings for {len(frames)} frames...")
         with torch.no_grad():
-            for i, frame in enumerate(frames):
-                print(f"Processing frame {i + 1}/{len(frames)}")
+            for i in range(0, len(frames), BATCH_SIZE):
+                batch = frames[i : i + BATCH_SIZE]
+                print(
+                    f"Processing batch {i // BATCH_SIZE + 1}/{(len(frames) + BATCH_SIZE - 1) // BATCH_SIZE} ({len(batch)} frames)"
+                )
+
+                # Use the same max_num_patches for all frames in batch
+                max_patches = max(determine_max_value(frame) for frame in batch)
+
                 image_input = image_processor(
-                    images=frame,
-                    max_num_patches=determine_max_value(frame),
+                    images=batch,
+                    max_num_patches=max_patches,
                     return_tensors="pt",
                 ).to(DEVICE)
-                image_feature = model.get_image_features(**image_input)
+                image_features = model.get_image_features(**image_input)
 
-                # Normalize the embedding
-                normalized_features = image_feature / image_feature.norm(
+                # Normalize the embeddings
+                normalized_features = image_features / image_features.norm(
                     dim=-1, keepdim=True
                 )
                 embeddings.append(normalized_features.cpu().numpy())
```
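Both loops L2-normalize the feature batch by hand with `x / x.norm(dim=-1, keepdim=True)`. For reference, `torch.nn.functional.normalize` computes the same thing (up to a tiny epsilon guard against zero vectors); a quick self-contained check, with the embedding width made up for the example:

```python
import torch
import torch.nn.functional as F

features = torch.randn(4, 1152)  # dummy batch; width chosen arbitrarily for the example

manual = features / features.norm(dim=-1, keepdim=True)  # the diff's formulation
builtin = F.normalize(features, p=2, dim=-1)              # library equivalent

assert torch.allclose(manual, builtin, atol=1e-6)
```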
```diff
@@ -250,8 +265,15 @@ with gr.Blocks(title="Video & Image Embedding Generator") as demo:
         vid_output = gr.JSON(label="Embeddings (JSON)")
         vid_status = gr.Textbox(label="Status", lines=3)
 
+        def handle_video_upload(video_file, fps):
+            if video_file is None:
+                return None, "❌ Please upload a video file"
+            return generate_video_embeddings(
+                video_file.name if hasattr(video_file, "name") else video_file, fps
+            )
+
         vid_submit_btn.click(
-            fn=generate_video_embeddings,
+            fn=handle_video_upload,
             inputs=[video_input, fps_input],
             outputs=[vid_output, vid_status],
         )
```
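The `handle_video_upload` wrapper exists because, depending on Gradio version and component configuration, `gr.Video` can pass the callback either a plain filepath string or a file-like object exposing `.name`; the `hasattr` check normalizes both to a path before calling `generate_video_embeddings`. A small standalone demonstration of that normalization (helper name and test values are illustrative, not from app.py):

```python
import tempfile


def to_video_path(video_file):
    """Return a filesystem path whether given a str or an object with .name."""
    return video_file.name if hasattr(video_file, "name") else video_file


# Plain string path, as newer Gradio versions typically pass
assert to_video_path("/tmp/clip.mp4") == "/tmp/clip.mp4"

# File-like object exposing .name, e.g. a tempfile wrapper
with tempfile.NamedTemporaryFile(suffix=".mp4") as f:
    assert to_video_path(f) == f.name
```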
 