Muhammadidrees committed
Commit 9109d83 · verified · 1 Parent(s): 2635688

Create app.py

Files changed (1)
  1. app.py +952 -0
app.py ADDED
@@ -0,0 +1,952 @@
1
+ import spaces
2
+ import subprocess
3
+ import gradio as gr
4
+
5
+ import os, sys
6
+ from glob import glob
7
+ from datetime import datetime
8
+ import math
9
+ import random
10
+ import librosa
11
+ import numpy as np
12
+ import uuid
13
+ import shutil
14
+ from tqdm import tqdm
15
+
16
+ import importlib, site, sys
17
+ from huggingface_hub import hf_hub_download, snapshot_download
18
+
19
+ # Re-discover all .pth/.egg-link files
20
+ for sitedir in site.getsitepackages():
21
+ site.addsitedir(sitedir)
22
+
23
+ # Clear caches so importlib will pick up new modules
24
+ importlib.invalidate_caches()
25
+
26
+ def sh(cmd): subprocess.check_call(cmd, shell=True)
27
+
28
+ flash_attention_installed = False
29
+
30
+ try:
31
+ print("Attempting to download and install FlashAttention wheel...")
32
+ flash_attention_wheel = hf_hub_download(
33
+ repo_id="alexnasa/flash-attn-3",
34
+ repo_type="model",
35
+ filename="128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl",
36
+ )
37
+
38
+ sh(f"pip install {flash_attention_wheel}")
39
+
40
+ # tell Python to re-scan site-packages now that the egg-link exists
41
+ site.addsitedir(site.getsitepackages()[0]); importlib.invalidate_caches()  # importlib and site are already imported above
42
+
43
+ flash_attention_installed = True
44
+ print("FlashAttention installed successfully.")
45
+
46
+ except Exception as e:
47
+ print(f"⚠️ Could not install FlashAttention: {e}")
48
+ print("Continuing without FlashAttention...")
49
+
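+ # Note: the wheel above is prebuilt for a specific platform (cp39-abi3, linux_x86_64);
+ # if the download or install fails, flash_attention_installed simply stays False and
+ # the app continues without FlashAttention.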
50
+ import torch
51
+ print(f"Torch version: {torch.__version__}")
52
+ # print(f"FlashAttention available: {flash_attention_installed}")
53
+
54
+
55
+
56
+ import torch.nn as nn
57
+ from tqdm import tqdm
58
+ from functools import partial
59
+ from omegaconf import OmegaConf
60
+ from argparse import Namespace
61
+ from gradio_extendedimage import extendedimage
62
+
63
+ import torchaudio
64
+
65
+ # load the saved inference config (args_config.yaml) and expose it as a Namespace
66
+ _args_cfg = OmegaConf.load("args_config.yaml")
67
+ args = Namespace(**OmegaConf.to_container(_args_cfg, resolve=True))
68
+
69
+ from OmniAvatar.utils.args_config import set_global_args
70
+
71
+ set_global_args(args)
72
+ # args = parse_args()
73
+
74
+ from OmniAvatar.utils.io_utils import load_state_dict
75
+ from peft import LoraConfig, inject_adapter_in_model
76
+ from OmniAvatar.models.model_manager import ModelManager
77
+ from OmniAvatar.schedulers.flow_match import FlowMatchScheduler
78
+ from OmniAvatar.wan_video import WanVideoPipeline
79
+ from OmniAvatar.utils.io_utils import save_video_as_grid_and_mp4
80
+ import torchvision.transforms as TT
81
+ from transformers import Wav2Vec2FeatureExtractor
82
+ import torchvision.transforms as transforms
83
+ import torch.nn.functional as F
84
+ from OmniAvatar.utils.audio_preprocess import add_silence_to_audio_ffmpeg
85
+
86
+ from diffusers import FluxKontextPipeline
87
+ from diffusers.utils import load_image
88
+
89
+ from PIL import Image
90
+
91
+
92
+ os.environ["PROCESSED_RESULTS"] = f"{os.getcwd()}/processed_results"
93
+
94
+
95
+ flux_pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16)
96
+ flux_pipe.load_lora_weights("alexnasa/Claymation-Kontext-Dev-Lora")
97
+ flux_pipe.to("cuda")
98
+ flux_inference = 10
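+ # FLUX.1 Kontext with a claymation LoRA stays resident on the GPU; it is run once per
+ # new reference image (see infer below) to produce the clay-style frame that the
+ # avatar pipeline animates. flux_inference is the number of Flux denoising steps.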
99
+
100
+ def tensor_to_pil(tensor):
101
+ """
102
+ Args:
103
+ tensor: torch.Tensor with shape like
104
+ (1, C, H, W), (1, C, 1, H, W), (C, H, W), etc.
105
+ values in [-1, 1], on any device.
106
+ Returns:
107
+ A PIL.Image in RGB mode.
108
+ """
109
+ # 1) Remove batch dim if it exists
110
+ if tensor.dim() > 3 and tensor.shape[0] == 1:
111
+ tensor = tensor[0]
112
+
113
+ # 2) Squeeze out any other singleton dims (e.g. that extra frame axis)
114
+ tensor = tensor.squeeze()
115
+
116
+ # Now we should have exactly 3 dims: (C, H, W)
117
+ if tensor.dim() != 3:
118
+ raise ValueError(f"Expected 3 dims after squeeze, got {tensor.dim()}")
119
+
120
+ # 3) Move to CPU float32
121
+ tensor = tensor.cpu().float()
122
+
123
+ # 4) Undo normalization from [-1,1] -> [0,1]
124
+ tensor = (tensor + 1.0) / 2.0
125
+
126
+ # 5) Clamp to [0,1]
127
+ tensor = torch.clamp(tensor, 0.0, 1.0)
128
+
129
+ # 6) To NumPy H×W×C in [0,255]
130
+ np_img = (tensor.permute(1, 2, 0).numpy() * 255.0).round().astype("uint8")
131
+
132
+ # 7) Build PIL Image
133
+ return Image.fromarray(np_img)
134
+
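+ # Illustrative example: tensor_to_pil(torch.zeros(1, 3, 1, 64, 64)) returns a 64x64
+ # mid-gray RGB PIL image, since a value of 0.0 in [-1, 1] maps to pixel value 128.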
135
+
136
+ def set_seed(seed: int = 42):
137
+ random.seed(seed)
138
+ np.random.seed(seed)
139
+ torch.manual_seed(seed)
140
+ torch.cuda.manual_seed(seed) # seed the current GPU
141
+ torch.cuda.manual_seed_all(seed) # seed all GPUs
142
+
143
+ def read_from_file(p):
144
+ with open(p, "r") as fin:
145
+ for l in fin:
146
+ yield l.strip()
147
+
148
+ def match_size(image_size, h, w):
149
+ ratio_ = 9999
150
+ size_ = 9999
151
+ select_size = None
152
+ for image_s in image_size:
153
+ ratio_tmp = abs(image_s[0] / image_s[1] - h / w)
154
+ size_tmp = abs(max(image_s) - max(w, h))
155
+ if ratio_tmp < ratio_:
156
+ ratio_ = ratio_tmp
157
+ size_ = size_tmp
158
+ select_size = image_s
159
+ if ratio_ == ratio_tmp:
160
+ if size_ == size_tmp:
161
+ select_size = image_s
162
+ return select_size
163
+
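+ # Illustrative example for match_size: with candidates [[720, 400], [720, 720], [400, 720]]
+ # (the 720-buckets used elsewhere in this file) and an input of h=1280, w=720,
+ # h/w ≈ 1.78 is closest to 720/400 = 1.8, so [720, 400] is selected.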
164
+ def resize_pad(image, ori_size, tgt_size):
165
+ h, w = ori_size
166
+ scale_ratio = max(tgt_size[0] / h, tgt_size[1] / w)
167
+ scale_h = int(h * scale_ratio)
168
+ scale_w = int(w * scale_ratio)
169
+
170
+ image = transforms.Resize(size=[scale_h, scale_w])(image)
171
+
172
+ padding_h = tgt_size[0] - scale_h
173
+ padding_w = tgt_size[1] - scale_w
174
+ pad_top = padding_h // 2
175
+ pad_bottom = padding_h - pad_top
176
+ pad_left = padding_w // 2
177
+ pad_right = padding_w - pad_left
178
+
179
+ image = F.pad(image, (pad_left, pad_right, pad_top, pad_bottom), mode='constant', value=0)
180
+ return image
181
+
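+ # Note on resize_pad: scale_ratio uses max(), so the resized image always covers the
+ # target; padding_h/padding_w are then zero or negative (up to rounding), and F.pad
+ # with negative values crops, so the result is effectively a center-crop to tgt_size.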
182
+ class WanInferencePipeline(nn.Module):
183
+ def __init__(self, args):
184
+ super().__init__()
185
+ self.args = args
186
+ self.device = torch.device(f"cuda")
187
+ self.dtype = torch.bfloat16
188
+ self.pipe = self.load_model()
189
+ chained_transforms = []
190
+ chained_transforms.append(TT.ToTensor())
191
+ self.transform = TT.Compose(chained_transforms)
192
+
193
+ if self.args.use_audio:
194
+ from OmniAvatar.models.wav2vec import Wav2VecModel
195
+ self.wav_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
196
+ self.args.wav2vec_path
197
+ )
198
+ self.audio_encoder = Wav2VecModel.from_pretrained(self.args.wav2vec_path, local_files_only=True).to(device=self.device, dtype=self.dtype)
199
+ self.audio_encoder.feature_extractor._freeze_parameters()
200
+
201
+
202
+ def load_model(self):
203
+ ckpt_path = f'{self.args.exp_path}/pytorch_model.pt'
204
+ assert os.path.exists(ckpt_path), f"pytorch_model.pt not found in {self.args.exp_path}"
205
+ if self.args.train_architecture == 'lora':
206
+ self.args.pretrained_lora_path = pretrained_lora_path = ckpt_path
207
+ else:
208
+ resume_path = ckpt_path
209
+
210
+ self.step = 0
211
+
212
+ # Load models
213
+ model_manager = ModelManager(device="cuda", infer=True)
214
+
215
+ model_manager.load_models(
216
+ [
217
+ self.args.dit_path.split(","),
218
+ self.args.vae_path,
219
+ self.args.text_encoder_path
220
+ ],
221
+ torch_dtype=self.dtype,
222
+ device='cuda',
223
+ )
224
+
225
+ pipe = WanVideoPipeline.from_model_manager(model_manager,
226
+ torch_dtype=self.dtype,
227
+ device="cuda",
228
+ use_usp=False,
229
+ infer=True)
230
+
231
+ if self.args.train_architecture == "lora":
232
+ print(f'Use LoRA: lora rank: {self.args.lora_rank}, lora alpha: {self.args.lora_alpha}')
233
+ self.add_lora_to_model(
234
+ pipe.denoising_model(),
235
+ lora_rank=self.args.lora_rank,
236
+ lora_alpha=self.args.lora_alpha,
237
+ lora_target_modules=self.args.lora_target_modules,
238
+ init_lora_weights=self.args.init_lora_weights,
239
+ pretrained_lora_path=pretrained_lora_path,
240
+ )
241
+ print(next(pipe.denoising_model().parameters()).device)
242
+ else:
243
+ missing_keys, unexpected_keys = pipe.denoising_model().load_state_dict(load_state_dict(resume_path), strict=True)
244
+ print(f"load from {resume_path}, {len(missing_keys)} missing keys, {len(unexpected_keys)} unexpected keys")
245
+ pipe.requires_grad_(False)
246
+ pipe.eval()
247
+ # pipe.enable_vram_management(num_persistent_param_in_dit=args.num_persistent_param_in_dit)
248
+ return pipe
249
+
250
+ def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", pretrained_lora_path=None, state_dict_converter=None):
251
+ # Add LoRA to UNet
252
+
253
+ self.lora_alpha = lora_alpha
254
+ if init_lora_weights == "kaiming":
255
+ init_lora_weights = True
256
+
257
+ lora_config = LoraConfig(
258
+ r=lora_rank,
259
+ lora_alpha=lora_alpha,
260
+ init_lora_weights=init_lora_weights,
261
+ target_modules=lora_target_modules.split(","),
262
+ )
263
+ model = inject_adapter_in_model(lora_config, model)
264
+
265
+ # Lora pretrained lora weights
266
+ if pretrained_lora_path is not None:
267
+ state_dict = load_state_dict(pretrained_lora_path, torch_dtype=self.dtype)
268
+ if state_dict_converter is not None:
269
+ state_dict = state_dict_converter(state_dict)
270
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
271
+ all_keys = [i for i, _ in model.named_parameters()]
272
+ num_updated_keys = len(all_keys) - len(missing_keys)
273
+ num_unexpected_keys = len(unexpected_keys)
274
+
275
+ print(f"{num_updated_keys} parameters are loaded from {pretrained_lora_path}. {num_unexpected_keys} parameters are unexpected.")
276
+
277
+ def get_times(self, prompt,
278
+ image_path=None,
279
+ audio_path=None,
280
+ seq_len=101, # not used while audio_path is not None
281
+ height=720,
282
+ width=720,
283
+ overlap_frame=None,
284
+ num_steps=None,
285
+ negative_prompt=None,
286
+ guidance_scale=None,
287
+ audio_scale=None):
288
+
289
+ overlap_frame = overlap_frame if overlap_frame is not None else self.args.overlap_frame
290
+ num_steps = num_steps if num_steps is not None else self.args.num_steps
291
+ negative_prompt = negative_prompt if negative_prompt is not None else self.args.negative_prompt
292
+ guidance_scale = guidance_scale if guidance_scale is not None else self.args.guidance_scale
293
+ audio_scale = audio_scale if audio_scale is not None else self.args.audio_scale
294
+
295
+ if image_path is not None:
296
+ image = Image.open(image_path).convert("RGB")
297
+
298
+ image = self.transform(image).unsqueeze(0).to(dtype=self.dtype)
299
+
300
+ _, _, h, w = image.shape
301
+ select_size = match_size(getattr(self.args, f'image_sizes_{self.args.max_hw}'), h, w)
302
+ image = resize_pad(image, (h, w), select_size)
303
+ image = image * 2.0 - 1.0
304
+ image = image[:, :, None]
305
+
306
+ else:
307
+ image = None
308
+ select_size = [height, width]
309
+ num = self.args.max_tokens * 16 * 16 * 4
310
+ den = select_size[0] * select_size[1]
311
+ L0 = num // den
312
+ diff = (L0 - 1) % 4
313
+ L = L0 - diff
314
+ if L < 1:
315
+ L = 1
316
+ T = (L + 3) // 4
317
+
318
+
319
+ if self.args.random_prefix_frames:
320
+ fixed_frame = overlap_frame
321
+ assert fixed_frame % 4 == 1
322
+ else:
323
+ fixed_frame = 1
324
+ prefix_lat_frame = (3 + fixed_frame) // 4
325
+ first_fixed_frame = 1
326
+
327
+
328
+ audio, sr = librosa.load(audio_path, sr= self.args.sample_rate)
329
+
330
+ input_values = np.squeeze(
331
+ self.wav_feature_extractor(audio, sampling_rate=16000).input_values
332
+ )
333
+ input_values = torch.from_numpy(input_values).float().to(dtype=self.dtype)
334
+ audio_len = math.ceil(len(input_values) / self.args.sample_rate * self.args.fps)
335
+
336
+ if audio_len < L - first_fixed_frame:
337
+ audio_len = audio_len + ((L - first_fixed_frame) - audio_len % (L - first_fixed_frame))
338
+ elif (audio_len - (L - first_fixed_frame)) % (L - fixed_frame) != 0:
339
+ audio_len = audio_len + ((L - fixed_frame) - (audio_len - (L - first_fixed_frame)) % (L - fixed_frame))
340
+
341
+ seq_len = audio_len
342
+
343
+ times = (seq_len - L + first_fixed_frame) // (L-fixed_frame) + 1
344
+ if times * (L-fixed_frame) + fixed_frame < seq_len:
345
+ times += 1
346
+
347
+ return times
348
+
349
+ @torch.no_grad()
350
+ def forward(self, prompt,
351
+ image_path=None,
352
+ audio_path=None,
353
+ seq_len=101, # not used while audio_path is not None
354
+ height=720,
355
+ width=720,
356
+ overlap_frame=None,
357
+ num_steps=None,
358
+ negative_prompt=None,
359
+ guidance_scale=None,
360
+ audio_scale=None):
361
+ overlap_frame = overlap_frame if overlap_frame is not None else self.args.overlap_frame
362
+ num_steps = num_steps if num_steps is not None else self.args.num_steps
363
+ negative_prompt = negative_prompt if negative_prompt is not None else self.args.negative_prompt
364
+ guidance_scale = guidance_scale if guidance_scale is not None else self.args.guidance_scale
365
+ audio_scale = audio_scale if audio_scale is not None else self.args.audio_scale
366
+
367
+ if image_path is not None:
368
+ image = Image.open(image_path).convert("RGB")
369
+
370
+ image = self.transform(image).unsqueeze(0).to(self.device, dtype=self.dtype)
371
+
372
+ _, _, h, w = image.shape
373
+ select_size = match_size(getattr(self.args, f'image_sizes_{self.args.max_hw}'), h, w)
374
+ image = resize_pad(image, (h, w), select_size)
375
+ image = image * 2.0 - 1.0
376
+ image = image[:, :, None]
377
+
378
+ else:
379
+ image = None
380
+ select_size = [height, width]
381
+ # L = int(self.args.max_tokens * 16 * 16 * 4 / select_size[0] / select_size[1])
382
+ # L = L // 4 * 4 + 1 if L % 4 != 0 else L - 3 # video frames
383
+ # T = (L + 3) // 4 # latent frames
384
+
385
+ # step 1: numerator and denominator as ints
386
+ num = args.max_tokens * 16 * 16 * 4
387
+ den = select_size[0] * select_size[1]
388
+
389
+ # step 2: integer division
390
+ L0 = num // den # exact floor division, no float in sight
391
+
392
+ # step 3: make it ≡ 1 mod 4
393
+ # if L0 % 4 == 1, keep L0;
394
+ # otherwise subtract the difference so that (L0 - diff) % 4 == 1,
395
+ # but ensure the result stays positive.
396
+ diff = (L0 - 1) % 4
397
+ L = L0 - diff
398
+ if L < 1:
399
+ L = 1 # or whatever your minimal frame count is
400
+
401
+ # step 4: latent frames
402
+ T = (L + 3) // 4
403
+
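+ # Worked example (illustrative): if num // den were 30, diff = (30 - 1) % 4 = 1,
+ # so L = 29 video frames (29 % 4 == 1) and T = (29 + 3) // 4 = 8 latent frames.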
404
+
405
+ if self.args.i2v:
406
+ if self.args.random_prefix_frames:
407
+ fixed_frame = overlap_frame
408
+ assert fixed_frame % 4 == 1
409
+ else:
410
+ fixed_frame = 1
411
+ prefix_lat_frame = (3 + fixed_frame) // 4
412
+ first_fixed_frame = 1
413
+ else:
414
+ fixed_frame = 0
415
+ prefix_lat_frame = 0
416
+ first_fixed_frame = 0
417
+
418
+
419
+ if audio_path is not None and self.args.use_audio:
420
+ audio, sr = librosa.load(audio_path, sr=self.args.sample_rate)
421
+ input_values = np.squeeze(
422
+ self.wav_feature_extractor(audio, sampling_rate=16000).input_values
423
+ )
424
+ input_values = torch.from_numpy(input_values).float().to(device=self.device, dtype=self.dtype)
425
+ ori_audio_len = audio_len = math.ceil(len(input_values) / self.args.sample_rate * self.args.fps)
426
+ input_values = input_values.unsqueeze(0)
427
+ # padding audio
428
+ if audio_len < L - first_fixed_frame:
429
+ audio_len = audio_len + ((L - first_fixed_frame) - audio_len % (L - first_fixed_frame))
430
+ elif (audio_len - (L - first_fixed_frame)) % (L - fixed_frame) != 0:
431
+ audio_len = audio_len + ((L - fixed_frame) - (audio_len - (L - first_fixed_frame)) % (L - fixed_frame))
432
+ input_values = F.pad(input_values, (0, audio_len * int(self.args.sample_rate / self.args.fps) - input_values.shape[1]), mode='constant', value=0)
433
+ with torch.no_grad():
434
+ hidden_states = self.audio_encoder(input_values, seq_len=audio_len, output_hidden_states=True)
435
+ audio_embeddings = hidden_states.last_hidden_state
436
+ for mid_hidden_states in hidden_states.hidden_states:
437
+ audio_embeddings = torch.cat((audio_embeddings, mid_hidden_states), -1)
438
+ seq_len = audio_len
439
+ audio_embeddings = audio_embeddings.squeeze(0)
440
+ audio_prefix = torch.zeros_like(audio_embeddings[:first_fixed_frame])
441
+ else:
442
+ audio_embeddings = None
443
+
444
+ # loop
445
+ times = (seq_len - L + first_fixed_frame) // (L-fixed_frame) + 1
446
+ if times * (L-fixed_frame) + fixed_frame < seq_len:
447
+ times += 1
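+ # `times` is the number of sliding-window passes: the first pass keeps
+ # first_fixed_frame fixed frames, each later pass re-uses the last fixed_frame frames
+ # of the previous chunk as a prefix, and only the non-overlapping tail is appended.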
448
+ video = []
449
+ image_emb = {}
450
+ img_lat = None
451
+ if self.args.i2v:
452
+ self.pipe.load_models_to_device(['vae'])
453
+ img_lat = self.pipe.encode_video(image.to(dtype=self.dtype)).to(self.device, dtype=self.dtype)
454
+
455
+ msk = torch.zeros_like(img_lat.repeat(1, 1, T, 1, 1)[:,:1], dtype=self.dtype)
456
+ image_cat = img_lat.repeat(1, 1, T, 1, 1)
457
+ msk[:, :, 1:] = 1
458
+ image_emb["y"] = torch.cat([image_cat, msk], dim=1)
459
+
460
+ total_iterations = times * num_steps
461
+
462
+ with tqdm(total=total_iterations) as pbar:
463
+ for t in range(times):
464
+ print(f"[{t+1}/{times}]")
465
+ audio_emb = {}
466
+ if t == 0:
467
+ overlap = first_fixed_frame
468
+ else:
469
+ overlap = fixed_frame
470
+ image_emb["y"][:, -1:, :prefix_lat_frame] = 0 # 第一次推理是mask只有1,往后都是mask overlap
471
+ prefix_overlap = (3 + overlap) // 4
472
+ if audio_embeddings is not None:
473
+ if t == 0:
474
+ audio_tensor = audio_embeddings[
475
+ :min(L - overlap, audio_embeddings.shape[0])
476
+ ]
477
+ else:
478
+ audio_start = L - first_fixed_frame + (t - 1) * (L - overlap)
479
+ audio_tensor = audio_embeddings[
480
+ audio_start: min(audio_start + L - overlap, audio_embeddings.shape[0])
481
+ ]
482
+
483
+ audio_tensor = torch.cat([audio_prefix, audio_tensor], dim=0)
484
+ audio_prefix = audio_tensor[-fixed_frame:]
485
+ audio_tensor = audio_tensor.unsqueeze(0).to(device=self.device, dtype=self.dtype)
486
+ audio_emb["audio_emb"] = audio_tensor
487
+ else:
488
+ audio_prefix = None
489
+ if image is not None and img_lat is None:
490
+ self.pipe.load_models_to_device(['vae'])
491
+ img_lat = self.pipe.encode_video(image.to(dtype=self.dtype)).to(self.device, dtype=self.dtype)
492
+ assert img_lat.shape[2] == prefix_overlap
493
+ img_lat = torch.cat([img_lat, torch.zeros_like(img_lat[:, :, :1].repeat(1, 1, T - prefix_overlap, 1, 1), dtype=self.dtype)], dim=2)
494
+ frames, _, latents = self.pipe.log_video(img_lat, prompt, prefix_overlap, image_emb, audio_emb,
495
+ negative_prompt, num_inference_steps=num_steps,
496
+ cfg_scale=guidance_scale, audio_cfg_scale=audio_scale if audio_scale is not None else guidance_scale,
497
+ return_latent=True,
498
+ tea_cache_l1_thresh=self.args.tea_cache_l1_thresh,tea_cache_model_id="Wan2.1-T2V-14B", progress_bar_cmd=pbar)
499
+
500
+ torch.cuda.empty_cache()
501
+ img_lat = None
502
+ image = (frames[:, -fixed_frame:].clip(0, 1) * 2.0 - 1.0).permute(0, 2, 1, 3, 4).contiguous()
503
+
504
+ if t == 0:
505
+ video.append(frames)
506
+ else:
507
+ video.append(frames[:, overlap:])
508
+ video = torch.cat(video, dim=1)
509
+ video = video[:, :ori_audio_len + 1]
510
+
511
+ return video
512
+
513
+
514
+ snapshot_download(repo_id="Wan-AI/Wan2.1-T2V-1.3B", local_dir="./pretrained_models/Wan2.1-T2V-1.3B")
515
+ snapshot_download(repo_id="facebook/wav2vec2-base-960h", local_dir="./pretrained_models/wav2vec2-base-960h")
516
+ snapshot_download(repo_id="OmniAvatar/OmniAvatar-1.3B", local_dir="./pretrained_models/OmniAvatar-1.3B")
517
+
518
+ import tempfile
519
+
520
+
521
+ set_seed(args.seed)
522
+ seq_len = args.seq_len
523
+ inferpipe = WanInferencePipeline(args)
524
+
525
+
526
+ ADAPTIVE_PROMPT_TEMPLATES = [
527
+ "A claymation video of a person speaking and moving their head accordingly but without moving their hands.",
528
+ "A claymation video of a person speaking and sometimes looking directly to the camera and moving their eyes and pupils and head accordingly and turning and looking at the camera and looking away from the camera but with subtle hands movement that complements their speech.",
529
+ "A claymation video of a person speaking and sometimes looking directly to the camera and moving their eyes and pupils and head accordingly and turning and looking at the camera and looking away from the camera based on their movements with dynamic and rhythmic and subtle hand gestures that complement their speech and don't disrupt things if they are holding something with their hands. Their hands are clearly visible, independent, and unobstructed. Their facial expressions are expressive and full of emotion, enhancing the delivery. The camera remains steady, capturing sharp, clear movements and a focused, engaging presence."
530
+ ]
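+ # Index 2 (the most detailed template) is the default prompt in the UI and the one
+ # used by infer_example below.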
531
+
532
+ def slider_value_change(image_path, audio_path, text, num_steps, session_state):
533
+ return update_generate_button(image_path, audio_path, text, num_steps, session_state), text
534
+
535
+
536
+ def update_generate_button(image_path, audio_path, text, num_steps, session_state):
537
+
538
+ if image_path is None or audio_path is None:
539
+ return gr.update(value="⌚ Zero GPU Required: --")
540
+
541
+ duration_s = get_duration(image_path, audio_path, text, num_steps, session_state, None)
542
+ duration_m = duration_s / 60
543
+
544
+ return gr.update(value=f"⌚ Zero GPU Required: ~{duration_s}.0s ({duration_m:.1f} mins)")
545
+
546
+ def get_duration(image_path, audio_path, text, num_steps, session_id, progress):
547
+
548
+ if image_path is None:
549
+ gr.Info("Step1: Please Provide an Image or Choose from Image Samples")
550
+ print("Step1: Please Provide an Image or Choose from Image Samples")
551
+
552
+ return 0
553
+
554
+ if audio_path is None:
555
+ gr.Info("Step2: Please Provide an Audio or Choose from Audio Samples")
556
+ print("Step2: Please Provide an Audio or Choose from Audio Samples")
557
+
558
+ return 0
559
+
560
+
561
+ audio_chunks = inferpipe.get_times(
562
+ prompt=text,
563
+ image_path=image_path,
564
+ audio_path=audio_path,
565
+ seq_len=args.seq_len,
566
+ num_steps=num_steps
567
+ )
568
+
569
+ if session_id is None:
570
+ session_id = uuid.uuid4().hex
571
+
572
+ output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
573
+
574
+ dirpath = os.path.dirname(image_path)
575
+ basename = os.path.basename(image_path)
576
+ name, ext = os.path.splitext(basename)
577
+
578
+ new_basename = f"clay_{name}{ext}"
579
+ clay_image_path = os.path.join(dirpath, new_basename)
580
+
581
+ if os.path.exists(clay_image_path):
582
+ claymation = 0
583
+ else:
584
+ claymation = flux_inference * 2
585
+
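+ # Heuristic ZeroGPU budget: ~4 s per denoising step (~20 s for the final step) per
+ # audio chunk, plus ~25 s of warmup, plus flux_inference * 2 seconds if the
+ # claymation image still has to be generated.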
586
+ warmup_s = 25
587
+ last_step_s = 20
588
+ duration_s = (4 * (num_steps - 1) + last_step_s)
589
+
590
+ if audio_chunks > 1:
591
+ duration_s = (duration_s * audio_chunks)
592
+
593
+ duration_s = duration_s + warmup_s + claymation
594
+
595
+ print(f'{session_id}: {audio_chunks} audio chunk(s) at {num_steps} steps -> allocating ~{duration_s}s of GPU time')
596
+
597
+ return int(duration_s)
598
+
599
+ def preprocess_img(input_image_path, raw_image_path, session_id = None):
600
+
601
+ if session_id is None:
602
+ session_id = uuid.uuid4().hex
603
+
604
+ if input_image_path is None:
605
+ return None, None
606
+
607
+ if raw_image_path == '':
608
+ raw_image_path = input_image_path
609
+
610
+ image = Image.open(raw_image_path).convert("RGB")
611
+
612
+ img_id = uuid.uuid4().hex
613
+
614
+ image = inferpipe.transform(image).unsqueeze(0).to(dtype=inferpipe.dtype)
615
+
616
+ _, _, h, w = image.shape
617
+ select_size = match_size(getattr(args, f'image_sizes_{args.max_hw}'), h, w)
618
+ image = resize_pad(image, (h, w), select_size)
619
+ image = image * 2.0 - 1.0
620
+ image = image[:, :, None]
621
+
622
+ output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
623
+
624
+ img_dir = output_dir + '/image'
625
+ os.makedirs(img_dir, exist_ok=True)
626
+ input_img_path = os.path.join(img_dir, f"img_{img_id}.jpg")
627
+
628
+ image = tensor_to_pil(image)
629
+ image.save(input_img_path)
630
+
631
+ return input_img_path, raw_image_path
632
+
633
+ def infer_example(image_path, audio_path, num_steps, raw_image_path, session_id = None, progress=gr.Progress(track_tqdm=True),):
634
+
635
+ current_image_size = args.image_sizes_720
636
+ args.image_sizes_720 = [[720, 400]]
637
+ text = ADAPTIVE_PROMPT_TEMPLATES[2]
638
+
639
+ result = infer(image_path, audio_path, text, num_steps, session_id, progress)
640
+
641
+ args.image_sizes_720 = current_image_size
642
+
643
+ return result
644
+
645
+ @spaces.GPU(duration=get_duration)
646
+ def infer(image_path, audio_path, text, num_steps, session_id = None, progress=gr.Progress(track_tqdm=True),):
647
+
648
+ if image_path is None:
649
+
650
+ return None
651
+
652
+ if audio_path is None:
653
+
654
+ return None
655
+
656
+ if session_id is None:
657
+ session_id = uuid.uuid4().hex
658
+
659
+ output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
660
+
661
+ # Decompose the path
662
+ dirpath = os.path.dirname(image_path)
663
+ basename = os.path.basename(image_path) # e.g. "photo.png"
664
+ name, ext = os.path.splitext(basename) # name="photo", ext=".png"
665
+
666
+ # Rebuild with "clay_" prefix
667
+ new_basename = f"clay_{name}{ext}" # "clay_photo.png"
668
+ clay_image_path = os.path.join(dirpath, new_basename)
669
+
670
+ # If the output file already exists, skip inference
671
+ if os.path.exists(clay_image_path):
672
+
673
+ print("using existing image")
674
+
675
+ else:
676
+
677
+ flux_prompt = "in style of omniavatar-claymation"
678
+ raw_image = load_image(image_path)
679
+ w, h = raw_image.size
680
+
681
+ clay_image = flux_pipe(image=raw_image, width=w, height=h, prompt=flux_prompt, negative_prompt=args.negative_prompt, num_inference_steps=flux_inference, true_cfg_scale=2.5).images[0]
682
+ clay_image.save(clay_image_path)
683
+
684
+
685
+ audio_dir = output_dir + '/audio'
686
+ os.makedirs(audio_dir, exist_ok=True)
687
+ if args.silence_duration_s > 0:
688
+ input_audio_path = os.path.join(audio_dir, f"audio_input.wav")
689
+ else:
690
+ input_audio_path = audio_path
691
+ prompt_dir = output_dir + '/prompt'
692
+ os.makedirs(prompt_dir, exist_ok=True)
693
+
694
+ if args.silence_duration_s > 0:
695
+ add_silence_to_audio_ffmpeg(audio_path, input_audio_path, args.silence_duration_s)
696
+
697
+ tmp2_audio_path = os.path.join(audio_dir, f"audio_out.wav")
698
+ prompt_path = os.path.join(prompt_dir, f"prompt.txt")
699
+
700
+ video = inferpipe(
701
+ prompt=text,
702
+ image_path=clay_image_path,
703
+ audio_path=input_audio_path,
704
+ seq_len=args.seq_len,
705
+ num_steps=num_steps
706
+ )
707
+
708
+ torch.cuda.empty_cache()
709
+
710
+ add_silence_to_audio_ffmpeg(audio_path, tmp2_audio_path, 1.0 / args.fps + args.silence_duration_s)
711
+ video_paths = save_video_as_grid_and_mp4(video,
712
+ output_dir,
713
+ args.fps,
714
+ prompt=text,
715
+ prompt_path = prompt_path,
716
+ audio_path=tmp2_audio_path if args.use_audio else None,
717
+ prefix=f'result')
718
+
719
+ return video_paths[0]
720
+
721
+ def apply_image(request):
722
+ print('image applied')
723
+ return request, None
724
+
725
+ def apply_audio(request):
726
+ print('audio applied')
727
+ return request
728
+
729
+ def cleanup(request: gr.Request):
730
+
731
+ sid = request.session_hash
732
+ if sid:
733
+ d1 = os.path.join(os.environ["PROCESSED_RESULTS"], sid)
734
+ shutil.rmtree(d1, ignore_errors=True)
735
+
736
+ def start_session(request: gr.Request):
737
+
738
+ return request.session_hash
739
+
740
+ def orientation_changed(session_id, evt: gr.EventData):
741
+
742
+ detail = getattr(evt, "data", None) or getattr(evt, "_data", {}) or {}
743
+
744
+ if detail.get('value') == "9:16":
745
+ args.image_sizes_720 = [[720, 400]]
746
+ elif detail.get('value') == "1:1":
747
+ args.image_sizes_720 = [[720, 720]]
748
+ elif detail.get('value') == "16:9":
749
+ args.image_sizes_720 = [[400, 720]]
750
+
751
+ print(f'{session_id} orientation set to {args.image_sizes_720}')
752
+
753
+ def clear_raw_image():
754
+ return ''
755
+
756
+ def preprocess_audio_first_nseconds_librosa(audio_path, limit_in_seconds, session_id=None):
757
+
758
+ if not audio_path:
759
+ return None
760
+
761
+ # Robust duration check (librosa changed arg name across versions)
762
+ try:
763
+ dur = librosa.get_duration(path=audio_path)
764
+ except TypeError:
765
+ dur = librosa.get_duration(filename=audio_path)
766
+
767
+ # Small tolerance so clips just under the limit are not re-encoded
768
+ if dur < float(limit_in_seconds) - 1e-3:
769
+ return audio_path
770
+
771
+ if session_id is None:
772
+ session_id = uuid.uuid4().hex
773
+
774
+ # Where we'll store per-session processed audio
775
+ output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
776
+ audio_dir = os.path.join(output_dir, "audio")
777
+ os.makedirs(audio_dir, exist_ok=True)
778
+
779
+ trimmed_path = os.path.join(audio_dir, f"audio_input_{limit_in_seconds}s.wav")
780
+ sr = getattr(args, "sample_rate", 16000)
781
+
782
+ y, _ = librosa.load(audio_path, sr=sr, mono=True, duration=float(limit_in_seconds))
783
+
784
+ # Save as 16-bit PCM mono WAV
785
+ waveform = torch.from_numpy(y).unsqueeze(0) # [1, num_samples]
786
+ torchaudio.save(
787
+ trimmed_path,
788
+ waveform,
789
+ sr,
790
+ encoding="PCM_S",
791
+ bits_per_sample=16,
792
+ format="wav",
793
+ )
794
+
795
+ return trimmed_path
796
+
797
+
798
+ css = """
799
+ #col-container {
800
+ margin: 0 auto;
801
+ max-width: 1560px;
802
+ }
803
+ /* editable vs locked, reusing theme variables that adapt to dark/light */
804
+ .stateful textarea:not(:disabled):not([readonly]) {
805
+ color: var(--color-text) !important; /* accent in both modes */
806
+ }
807
+ .stateful textarea:disabled,
808
+ .stateful textarea[readonly]{
809
+ color: var(--body-text-color-subdued) !important; /* subdued in both modes */
810
+ }
811
+ """
812
+
813
+ with gr.Blocks(css=css) as demo:
814
+
815
+ session_state = gr.State()
816
+ demo.load(start_session, outputs=[session_state])
817
+
818
+
819
+ with gr.Column(elem_id="col-container"):
820
+ gr.HTML(
821
+ """
822
+ <div style="text-align: center;">
823
+ <div style="display: flex; justify-content: center;">
824
+ <img src="https://huggingface.co/spaces/alexnasa/OmniAvatar-Clay-Fast/resolve/main/assets/logo-omniavatar.png" alt="Logo">
825
+ </div>
826
+ </div>
827
+ <div style="text-align: center;">
828
+ <p style="font-size:16px; display: inline; margin: 0;">
829
+ <strong>OmniAvatar</strong> – Efficient Audio-Driven Avatar Video Generation with Adaptive Body Animation
830
+ </p>
831
+ <a href="https://huggingface.co/OmniAvatar/OmniAvatar-1.3B" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
832
+ [model]
833
+ </a>
834
+ </div>
835
+
836
+ <div style="text-align: center;">
837
+ <strong>HF Space by:</strong>
838
+ <a href="https://twitter.com/alexandernasa/" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
839
+ <img src="https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=Follow Me" alt="GitHub Repo">
840
+ </a>
841
+ </div>
842
+ <div style="text-align: center;">
843
+ <p style="font-size:16px; display: inline; margin: 0;">
844
+ If you are looking for realism, please try the other HF Space:
845
+ </p>
846
+ <a href="https://huggingface.co/spaces/alexnasa/OmniAvatar" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
847
+ <img src="https://img.shields.io/badge/🤗-HF Demo-yellow.svg">
848
+ </a>
849
+ </div>
850
+
851
+ """
852
+ )
853
+
854
+ with gr.Row():
855
+
856
+ with gr.Column(scale=1):
857
+
858
+ image_input = extendedimage(label="Reference Image", type="filepath", height=512)
859
+ audio_input = gr.Audio(label="Input Audio", type="filepath")
860
+ gr.Markdown("*Change the duration limit in Advanced Settings*")
861
+
862
+
863
+ with gr.Column(scale=1):
864
+
865
+ output_video = gr.Video(label="Avatar", height=512)
866
+ num_steps = gr.Slider(8, 50, value=8, step=1, label="Steps")
867
+ time_required = gr.Text(value="⌚ Zero GPU Required: --", show_label=False)
868
+ infer_btn = gr.Button("🗿 Clay Me", variant="primary")
869
+ with gr.Accordion("Advanced Settings", open=False):
870
+ raw_img_text = gr.Text(show_label=False, label="", value='', visible=False)
871
+ limit_in_seconds = gr.Slider(5, 180, value=5, step=5, label="Duration")
872
+ text_input = gr.Textbox(label="Prompt", lines=6, value= ADAPTIVE_PROMPT_TEMPLATES[2])
873
+
874
+ with gr.Column(scale=1):
875
+
876
+ cached_examples = gr.Examples(
877
+ examples=[
878
+
879
+ [
880
+ "examples/images/female-002.png",
881
+ "examples/audios/lion.wav",
882
+ 10,
883
+ ''
884
+ ],
885
+
886
+ [
887
+ "examples/images/female-003.png",
888
+ "examples/audios/fox.wav",
889
+ 10,
890
+ ''
891
+ ],
892
+
893
+ [
894
+ "examples/images/female-009.png",
895
+ "examples/audios/script.wav",
896
+ 10,
897
+ ''
898
+ ],
899
+
900
+ ],
901
+ label="Cached Examples",
902
+ inputs=[image_input, audio_input, num_steps, raw_img_text],
903
+ outputs=[output_video],
904
+ fn=infer_example,
905
+ cache_examples=True
906
+ )
907
+
908
+ uncached_examples = gr.Examples(
909
+ examples=[
910
+
911
+ [
912
+ "examples/images/male-001.png",
913
+ "examples/audios/ocean.wav",
914
+ 10,
915
+ ''
916
+ ],
917
+
918
+
919
+ ],
920
+ label="Uncached Examples",
921
+ inputs=[image_input, audio_input, num_steps, raw_img_text],
922
+ cache_examples=False
923
+ )
924
+
925
+
926
+ infer_btn.click(
927
+ fn=infer,
928
+ inputs=[image_input, audio_input, text_input, num_steps, session_state],
929
+ outputs=[output_video]
930
+ )
931
+
932
+ image_input.orientation(fn=orientation_changed, inputs=[session_state]).then(fn=preprocess_img, inputs=[image_input, raw_img_text, session_state], outputs=[image_input, raw_img_text])
933
+ image_input.clear(fn=clear_raw_image, outputs=[raw_img_text])
934
+ image_input.upload(fn=preprocess_img, inputs=[image_input, raw_img_text, session_state], outputs=[image_input, raw_img_text])
935
+ image_input.change(fn=update_generate_button, inputs=[image_input, audio_input, text_input, num_steps, session_state], outputs=[time_required])
936
+ audio_input.change(fn=update_generate_button, inputs=[image_input, audio_input, text_input, num_steps, session_state], outputs=[time_required])
937
+ num_steps.change(fn=slider_value_change, inputs=[image_input, audio_input, text_input, num_steps, session_state], outputs=[time_required, text_input])
938
+ audio_input.upload(fn=apply_audio, inputs=[audio_input], outputs=[audio_input]
939
+ ).then(
940
+ fn=preprocess_audio_first_nseconds_librosa,
941
+ inputs=[audio_input, limit_in_seconds, session_state],
942
+ outputs=[audio_input],
943
+ ).then(
944
+ fn=apply_audio,
945
+ inputs=[audio_input],
946
+ outputs=[audio_input]
947
+ )
948
+
949
+ if __name__ == "__main__":
950
+ demo.unload(cleanup)
951
+ demo.queue()
952
+ demo.launch(ssr_mode=False)