ZJUCQR committed · Commit e2bca25 · 1 Parent(s): e46f0d2

Add hf_AC audio generation demo


- Implement Gradio interface for video-to-audio generation
- Add Chinese language support
- Include hf_AC model integration
- Add requirements.txt and packages.txt for HF Space deployment
- Add example prompts and usage tips
- Include deployment documentation

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. .gitignore +87 -0
  2. DEPLOYMENT.md +103 -0
  3. README.md +43 -8
  4. app.py +348 -58
  5. hf_AC/.gitignore +224 -0
  6. hf_AC/README.md +25 -0
  7. hf_AC/config/__init__.py +0 -0
  8. hf_AC/config/base_config.yaml +62 -0
  9. hf_AC/config/data/base.yaml +69 -0
  10. hf_AC/config/data/base2.yaml +69 -0
  11. hf_AC/config/eval_config.yaml +17 -0
  12. hf_AC/config/eval_data/base.yaml +22 -0
  13. hf_AC/config/hydra/job_logging/custom-eval.yaml +32 -0
  14. hf_AC/config/hydra/job_logging/custom-no-rank.yaml +32 -0
  15. hf_AC/config/hydra/job_logging/custom-simplest.yaml +26 -0
  16. hf_AC/config/hydra/job_logging/custom.yaml +33 -0
  17. hf_AC/config/train_config.yaml +41 -0
  18. hf_AC/config/train_config_2node.yaml +41 -0
  19. hf_AC/config/train_config_2node2.yaml +41 -0
  20. hf_AC/inf.py +181 -0
  21. hf_AC/mmaudio/__init__.py +0 -0
  22. hf_AC/mmaudio/data/__init__.py +0 -0
  23. hf_AC/mmaudio/data/av_utils.py +162 -0
  24. hf_AC/mmaudio/data/data_setup.py +177 -0
  25. hf_AC/mmaudio/data/eval/__init__.py +0 -0
  26. hf_AC/mmaudio/data/eval/audiocaps.py +39 -0
  27. hf_AC/mmaudio/data/eval/moviegen.py +131 -0
  28. hf_AC/mmaudio/data/eval/video_dataset.py +231 -0
  29. hf_AC/mmaudio/data/extracted_audio.py +97 -0
  30. hf_AC/mmaudio/data/extracted_vgg.py +109 -0
  31. hf_AC/mmaudio/data/extraction/__init__.py +0 -0
  32. hf_AC/mmaudio/data/extraction/vgg_sound.py +208 -0
  33. hf_AC/mmaudio/data/extraction/wav_dataset.py +135 -0
  34. hf_AC/mmaudio/data/mm_dataset.py +45 -0
  35. hf_AC/mmaudio/data/utils.py +148 -0
  36. hf_AC/mmaudio/eval_utils.py +249 -0
  37. hf_AC/mmaudio/ext/__init__.py +1 -0
  38. hf_AC/mmaudio/ext/autoencoder/__init__.py +1 -0
  39. hf_AC/mmaudio/ext/autoencoder/autoencoder.py +52 -0
  40. hf_AC/mmaudio/ext/autoencoder/edm2_utils.py +168 -0
  41. hf_AC/mmaudio/ext/autoencoder/vae.py +369 -0
  42. hf_AC/mmaudio/ext/autoencoder/vae_modules.py +117 -0
  43. hf_AC/mmaudio/ext/bigvgan/LICENSE +21 -0
  44. hf_AC/mmaudio/ext/bigvgan/__init__.py +1 -0
  45. hf_AC/mmaudio/ext/bigvgan/activations.py +120 -0
  46. hf_AC/mmaudio/ext/bigvgan/alias_free_torch/__init__.py +6 -0
  47. hf_AC/mmaudio/ext/bigvgan/alias_free_torch/act.py +28 -0
  48. hf_AC/mmaudio/ext/bigvgan/alias_free_torch/filter.py +95 -0
  49. hf_AC/mmaudio/ext/bigvgan/alias_free_torch/resample.py +49 -0
  50. hf_AC/mmaudio/ext/bigvgan/bigvgan.py +32 -0
.gitignore ADDED
@@ -0,0 +1,87 @@
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ build/
8
+ develop-eggs/
9
+ dist/
10
+ downloads/
11
+ eggs/
12
+ .eggs/
13
+ lib/
14
+ lib64/
15
+ parts/
16
+ sdist/
17
+ var/
18
+ wheels/
19
+ *.egg-info/
20
+ .installed.cfg
21
+ *.egg
22
+ MANIFEST
23
+
24
+ # PyTorch
25
+ *.pth
26
+ *.pt
27
+ *.ckpt
28
+ weights/
29
+ checkpoints/
30
+
31
+ # Gradio
32
+ gradio_cached_examples/
33
+ flagged/
34
+
35
+ # Temporary files
36
+ *.tmp
37
+ *.temp
38
+ /tmp/
39
+ temp/
40
+
41
+ # Logs
42
+ *.log
43
+ logs/
44
+
45
+ # Environment
46
+ .env
47
+ .venv
48
+ env/
49
+ venv/
50
+ ENV/
51
+ env.bak/
52
+ venv.bak/
53
+
54
+ # IDE
55
+ .vscode/
56
+ .idea/
57
+ *.swp
58
+ *.swo
59
+ *~
60
+
61
+ # OS
62
+ .DS_Store
63
+ .DS_Store?
64
+ ._*
65
+ .Spotlight-V100
66
+ .Trashes
67
+ ehthumbs.db
68
+ Thumbs.db
69
+
70
+ # Model files (too large for git)
71
+ *.safetensors
72
+ model.pth
73
+ *.bin
74
+
75
+ # Audio/Video files
76
+ *.wav
77
+ *.mp3
78
+ *.mp4
79
+ *.avi
80
+ *.mov
81
+
82
+ # Jupyter
83
+ .ipynb_checkpoints/
84
+
85
+ # Cache
86
+ .cache/
87
+ *.cache
DEPLOYMENT.md ADDED
@@ -0,0 +1,103 @@
1
+ # 🚀 Hugging Face Space 部署指南
2
+
3
+ ## 📁 文件结构
4
+
5
+ 确保你的HF Space包含以下文件:
6
+
7
+ ```
8
+ Acfoley/
9
+ ├── README.md # Space配置和说明
10
+ ├── app.py # 主应用文件
11
+ ├── requirements.txt # Python依赖
12
+ ├── packages.txt # 系统依赖
13
+ ├── hf_AC/ # hf_AC模型代码
14
+ └── .gitignore # Git忽略文件
15
+ ```
16
+
17
+ ## 🔧 部署步骤
18
+
19
+ ### 1. 上传代码到HF Space
20
+
21
+ 将所有文件上传到你的Hugging Face Space仓库:
22
+
23
+ ```bash
24
+ git add .
25
+ git commit -m "Add hf_AC audio generation demo"
26
+ git push
27
+ ```
28
+
29
+ ### 2. 模型权重下载
30
+
31
+ 模型会自动从以下位置下载:
32
+ - 主模型: `https://huggingface.co/FF2416/AC-Foley/resolve/main/model.pth`
33
+ - 其他组件会根据需要自动下载
34
+
35
+ ### 3. 环境配置
36
+
37
+ HF Space会自动:
38
+ - 安装`requirements.txt`中的Python包
39
+ - 安装`packages.txt`中的系统依赖
40
+ - 运行`app.py`启动Gradio界面
41
+
42
+ ## 📋 README.md 配置
43
+
44
+ 确保README.md顶部包含正确的YAML配置:
45
+
46
+ ```yaml
47
+ ---
48
+ title: hf_AC Audio Foley Generator
49
+ emoji: 🎵
50
+ colorFrom: blue
51
+ colorTo: green
52
+ sdk: gradio
53
+ sdk_version: 5.42.0
54
+ app_file: app.py
55
+ pinned: false
56
+ license: mit
57
+ ---
58
+ ```
59
+
60
+ ## 🔍 故障排除
61
+
62
+ ### 常见问题
63
+
64
+ 1. **模型下载失败**
65
+ - 检查网络连接
66
+ - 确认模型URL可访问
67
+
68
+ 2. **依赖安装失败**
69
+ - 检查`requirements.txt`格式
70
+ - 确认包版本兼容性
71
+
72
+ 3. **内存不足**
73
+ - HF Space免费版有内存限制
74
+ - 考虑优化模型或升级到付费版
75
+
76
+ ### 调试方法
77
+
78
+ 1. 查看Space日志
79
+ 2. 运行`test_setup.py`验证环境
80
+ 3. 检查模型文件是否正确下载
81
+
82
+ ## 🎯 使用说明
83
+
84
+ 部署成功后,用户可以:
85
+
86
+ 1. 上传MP4视频文件
87
+ 2. 输入音频描述文字
88
+ 3. 调整生成参数
89
+ 4. 点击生成按钮
90
+ 5. 下载生成的音频
91
+
92
+ ## 📊 性能优化
93
+
94
+ - 首次运行需要下载模型(约几GB)
95
+ - 生成时间取决于视频长度和硬件
96
+ - 建议视频时长控制在15秒以内
97
+
98
+ ## 🔗 相关链接
99
+
100
+ - [hf_AC GitHub](https://github.com/ff2416/hf_AC)
101
+ - [模型权重](https://huggingface.co/FF2416/AC-Foley)
102
+ - [Gradio文档](https://gradio.app/docs/)
103
+ - [HF Spaces文档](https://huggingface.co/docs/hub/spaces)
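The DEPLOYMENT.md above notes that the main checkpoint is fetched from `FF2416/AC-Foley` the first time the Space runs. As a minimal sketch (not the Space's actual download path, which goes through `ModelConfig.download_if_needed()`), the same file can be pre-fetched with `huggingface_hub`; the `weights/` target directory is an assumption.

```python
# Hypothetical pre-fetch of the checkpoint referenced in DEPLOYMENT.md.
# The app itself relies on ModelConfig.download_if_needed(); this is only a sketch.
from pathlib import Path
from huggingface_hub import hf_hub_download

def prefetch_weights(local_dir: str = "weights") -> Path:
    # Downloads model.pth from FF2416/AC-Foley; huggingface_hub caches the file,
    # so repeated calls are cheap.
    path = hf_hub_download(repo_id="FF2416/AC-Foley",
                           filename="model.pth",
                           local_dir=local_dir)
    return Path(path)

if __name__ == "__main__":
    print(f"Checkpoint available at: {prefetch_weights()}")
```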
README.md CHANGED
@@ -1,15 +1,50 @@
1
  ---
2
- title: Acfoley
3
- emoji: 💬
4
- colorFrom: yellow
5
- colorTo: purple
6
  sdk: gradio
7
  sdk_version: 5.42.0
8
  app_file: app.py
9
  pinned: false
10
- hf_oauth: true
11
- hf_oauth_scopes:
12
- - inference-api
13
  ---
14
 
15
- An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
1
  ---
2
+ title: hf_AC Audio Foley Generator
3
+ emoji: 🎵
4
+ colorFrom: blue
5
+ colorTo: green
6
  sdk: gradio
7
  sdk_version: 5.42.0
8
  app_file: app.py
9
  pinned: false
10
+ license: mit
 
 
11
  ---
12
 
13
+ # 🎵 hf_AC Audio Foley Generator
14
+
15
+ A Gradio demo for generating synchronized audio from videos using the hf_AC (Audio-Conditioned Foley) model. This application allows you to upload a video and generate matching audio content based on text descriptions.
16
+
17
+ ## Features
18
+
19
+ - **Video-to-Audio Generation**: Upload a video and generate synchronized audio
20
+ - **Text-Guided Generation**: Use text prompts to describe the desired audio
21
+ - **Customizable Parameters**: Adjust duration, CFG strength, and other generation parameters
22
+ - **Real-time Processing**: Generate audio in real-time with GPU acceleration
23
+
24
+ ## How to Use
25
+
26
+ 1. **Load Model**: The model will automatically load when you start the app
27
+ 2. **Upload Video**: Choose a video file (MP4 format recommended)
28
+ 3. **Describe Audio**: Write a text description of the audio you want to generate
29
+ 4. **Generate**: Click the generate button and wait for the audio to be created
30
+ 5. **Download**: Listen to and download the generated audio
31
+
32
+ ## Example Prompts
33
+
34
+ - "Crackling fireplace with gentle flames"
35
+ - "Ocean waves crashing on rocky shore"
36
+ - "Busy city street with car horns and chatter"
37
+ - "Forest ambience with bird songs and rustling leaves"
38
+ - "Keyboard typing in a quiet office"
39
+
40
+ ## Model Information
41
+
42
+ This demo uses the hf_AC model, which is designed for audio-visual synchronization and generation. The model can generate high-quality audio that matches the visual content and text descriptions.
43
+
44
+ ## Technical Details
45
+
46
+ - **Framework**: PyTorch, Gradio
47
+ - **Model**: hf_AC (Audio-Conditioned Foley)
48
+ - **Audio Format**: WAV, 44.1kHz
49
+ - **Video Support**: MP4, various resolutions
50
+ - **Processing**: GPU-accelerated when available
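The Technical Details above state that results come back as 44.1 kHz WAV files. A quick way to confirm that for a downloaded clip (the file name below is only a placeholder) is to inspect it with torchaudio:

```python
# Inspect a generated clip; "generated_audio.wav" is a placeholder file name.
import torchaudio

info = torchaudio.info("generated_audio.wav")
print(f"sample rate: {info.sample_rate} Hz")                     # expected: 44100
print(f"channels:    {info.num_channels}")
print(f"duration:    {info.num_frames / info.sample_rate:.2f} s")
```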
app.py CHANGED
@@ -1,70 +1,360 @@
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
3
 
 
 
 
 
 
4
 
5
- def respond(
6
- message,
7
- history: list[dict[str, str]],
8
- system_message,
9
- max_tokens,
10
- temperature,
11
- top_p,
12
- hf_token: gr.OAuthToken,
13
- ):
14
- """
15
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
16
- """
17
- client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")
18
 
19
- messages = [{"role": "system", "content": system_message}]
 
20
 
21
- messages.extend(history)
22
-
23
- messages.append({"role": "user", "content": message})
24
-
25
- response = ""
26
-
27
- for message in client.chat_completion(
28
- messages,
29
- max_tokens=max_tokens,
30
- stream=True,
31
- temperature=temperature,
32
- top_p=top_p,
33
- ):
34
- choices = message.choices
35
- token = ""
36
- if len(choices) and choices[0].delta.content:
37
- token = choices[0].delta.content
38
 
39
- response += token
40
- yield response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
- """
44
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
45
- """
46
- chatbot = gr.ChatInterface(
47
- respond,
48
- type="messages",
49
- additional_inputs=[
50
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
51
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
52
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
53
- gr.Slider(
54
- minimum=0.1,
55
- maximum=1.0,
56
- value=0.95,
57
- step=0.05,
58
- label="Top-p (nucleus sampling)",
59
- ),
60
- ],
61
- )
62
 
63
- with gr.Blocks() as demo:
64
- with gr.Sidebar():
65
- gr.LoginButton()
66
- chatbot.render()
67
68
 
69
  if __name__ == "__main__":
70
- demo.launch()
 
 
1
  import gradio as gr
2
+ import torch
3
+ import torchaudio
4
+ import logging
5
+ import tempfile
6
+ import os
7
+ import sys
8
+ from pathlib import Path
9
+ import numpy as np
10
+ from typing import Optional, Tuple
11
+ import time
12
+ import traceback
13
 
14
+ # Add hf_AC to path
15
+ current_dir = Path(__file__).parent
16
+ hf_ac_path = current_dir / "hf_AC"
17
+ if hf_ac_path.exists():
18
+ sys.path.insert(0, str(hf_ac_path))
19
 
20
+ # Configuration for HF Space
21
+ EXAMPLE_PROMPTS = [
22
+ "Crackling fireplace with gentle flames",
23
+ "Ocean waves crashing on rocky shore",
24
+ "Forest ambience with bird songs",
25
+ "Keyboard typing sounds",
26
+ "Footsteps on wooden floor",
27
+ "Rain on metal roof"
28
+ ]
 
 
 
 
29
 
30
+ USAGE_TIPS = """
31
+ ### 💡 使用技巧
32
 
33
+ 1. **视频质量**: 使用清晰、光线良好的视频
34
+ 2. **提示词**: 具体描述想要的音频类型
35
+ 3. **时长**: 建议1-30秒效果最佳
36
+ 4. **CFG强度**: 数值越高越贴合提示词,但可能降低质量
37
+ """
38
 
39
+ # Import hf_AC modules with error handling
40
+ try:
41
+ from hf_AC.mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video,
42
+ setup_eval_logging)
43
+ from hf_AC.mmaudio.model.flow_matching import FlowMatching
44
+ from hf_AC.mmaudio.model.networks import MMAudio, get_my_mmaudio
45
+ from hf_AC.mmaudio.model.utils.features_utils import FeaturesUtils
46
+
47
+ # Setup logging
48
+ setup_eval_logging()
49
+ log = logging.getLogger()
50
+ HF_AC_AVAILABLE = True
51
+ except ImportError as e:
52
+ print(f"Warning: hf_AC modules not available: {e}")
53
+ log = logging.getLogger()
54
+ HF_AC_AVAILABLE = False
55
 
56
+ class AudioFoleyModel:
57
+ def __init__(self):
58
+ self.device = 'cpu'
59
+ if torch.cuda.is_available():
60
+ self.device = 'cuda'
61
+ elif torch.backends.mps.is_available():
62
+ self.device = 'mps'
63
+
64
+ self.dtype = torch.bfloat16
65
+ self.model = None
66
+ self.net = None
67
+ self.fm = None
68
+ self.feature_utils = None
69
+
70
+ def load_model(self, variant='large_44k', model_path=None):
71
+ """Load the hf_AC model"""
72
+ try:
73
+ if not HF_AC_AVAILABLE:
74
+ return "❌ hf_AC modules not available. Please install the hf_AC package."
75
+
76
+ if variant not in all_model_cfg:
77
+ available_variants = list(all_model_cfg.keys()) if all_model_cfg else []
78
+ return f"❌ Unknown model variant: {variant}. Available: {available_variants}"
79
+
80
+ log.info(f"Loading model variant: {variant}")
81
+ self.model: ModelConfig = all_model_cfg[variant]
82
+
83
+ # Download model components if needed
84
+ try:
85
+ self.model.download_if_needed()
86
+ except Exception as e:
87
+ log.warning(f"Could not download model components: {e}")
88
+
89
+ # Set custom model path if provided
90
+ if model_path and os.path.exists(model_path):
91
+ self.model.model_path = Path(model_path)
92
+ log.info(f"Using custom model path: {model_path}")
93
+
94
+ # Load network
95
+ self.net: MMAudio = get_my_mmaudio(self.model.model_name).to(self.device, self.dtype).eval()
96
+
97
+ # Load weights
98
+ if hasattr(self.model, 'model_path') and self.model.model_path and Path(self.model.model_path).exists():
99
+ try:
100
+ weights = torch.load(self.model.model_path, map_location=self.device, weights_only=True)
101
+ self.net.load_weights(weights['weights'])
102
+ log.info(f'✅ Loaded weights from {self.model.model_path}')
103
+ except Exception as e:
104
+ log.error(f"Failed to load weights: {e}")
105
+ return f"❌ Failed to load model weights: {e}"
106
+ else:
107
+ log.warning('⚠️ No model weights found, using default initialization')
108
+ return "⚠️ Model loaded but no weights found. Download model.pth from HuggingFace."
109
+
110
+ # Initialize flow matching
111
+ self.fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=25)
112
+
113
+ # Initialize feature utils
114
+ try:
115
+ self.feature_utils = FeaturesUtils(
116
+ tod_vae_ckpt=self.model.vae_path,
117
+ synchformer_ckpt=self.model.synchformer_ckpt,
118
+ enable_conditions=True,
119
+ mode=self.model.mode,
120
+ bigvgan_vocoder_ckpt=self.model.bigvgan_16k_path,
121
+ need_vae_encoder=True
122
+ )
123
+ self.feature_utils = self.feature_utils.to(self.device, self.dtype).eval()
124
+ except Exception as e:
125
+ log.error(f"Failed to initialize feature utils: {e}")
126
+ return f"❌ Failed to initialize feature utilities: {e}"
127
+
128
+ return "✅ Model loaded successfully!"
129
+
130
+ except Exception as e:
131
+ error_msg = f"❌ Error loading model: {str(e)}\n{traceback.format_exc()}"
132
+ log.error(error_msg)
133
+ return error_msg
134
+
135
+ def generate_audio(self, video_file, prompt: str, negative_prompt: str = "",
136
+ duration: float = 8.0, cfg_strength: float = 4.5,
137
+ seed: int = 42) -> Tuple[Optional[str], str]:
138
+ """Generate audio from video and text prompt"""
139
+ try:
140
+ # Validation checks
141
+ if not HF_AC_AVAILABLE:
142
+ return None, "❌ hf_AC modules not available."
143
+
144
+ if self.net is None or self.feature_utils is None:
145
+ return None, "❌ Model not loaded. Please load the model first."
146
+
147
+ if video_file is None:
148
+ return None, "❌ Please upload a video file."
149
+
150
+ if not prompt.strip():
151
+ return None, "❌ Please provide a text prompt describing the desired audio."
152
+
153
+ log.info(f'🎬 Processing video: {video_file}')
154
+ log.info(f'📝 Prompt: "{prompt}"')
155
+
156
+ # Load and process video
157
+ try:
158
+ video_path = Path(video_file)
159
+ if not video_path.exists():
160
+ return None, f"❌ Video file not found: {video_file}"
161
+
162
+ video_info = load_video(video_path, duration)
163
+ clip_frames = video_info.clip_frames
164
+ sync_frames = video_info.sync_frames
165
+ duration_sec = video_info.duration_sec
166
+
167
+ log.info(f'📹 Video loaded: {duration_sec:.2f}s duration')
168
+
169
+ except Exception as e:
170
+ return None, f"❌ Failed to load video: {str(e)}"
171
+
172
+ # Prepare frames
173
+ clip_frames = clip_frames.unsqueeze(0) if clip_frames is not None else None
174
+ sync_frames = sync_frames.unsqueeze(0)
175
+
176
+ # Update model sequence configuration
177
+ try:
178
+ self.model.seq_cfg.duration = duration_sec
179
+ self.model.seq_cfg.audio_num_sample = 89088 # Default for 44kHz
180
+ self.net.update_seq_lengths(
181
+ self.model.seq_cfg.latent_seq_len,
182
+ self.model.seq_cfg.clip_seq_len,
183
+ self.model.seq_cfg.sync_seq_len,
184
+ self.model.seq_cfg.audio_seq_len
185
+ )
186
+ except Exception as e:
187
+ return None, f"❌ Failed to configure model: {str(e)}"
188
+
189
+ # Generate audio
190
+ try:
191
+ log.info('🎵 Generating audio...')
192
+ start_time = time.time()
193
+
194
+ with torch.inference_mode():
195
+ audios = generate(
196
+ clip_frames,
197
+ sync_frames,
198
+ [prompt],
199
+ None, # No reference audio
200
+ negative_text=[negative_prompt] if negative_prompt.strip() else None,
201
+ feature_utils=self.feature_utils,
202
+ net=self.net,
203
+ fm=self.fm,
204
+ rng=torch.Generator(device=self.device).manual_seed(seed),
205
+ cfg_strength=cfg_strength
206
+ )
207
+
208
+ generation_time = time.time() - start_time
209
+ log.info(f'⏱️ Generation completed in {generation_time:.2f}s')
210
+
211
+ except Exception as e:
212
+ return None, f"❌ Audio generation failed: {str(e)}"
213
+
214
+ # Save generated audio
215
+ try:
216
+ audio = audios.float().cpu()[0]
217
+
218
+ # Create output filename with timestamp
219
+ timestamp = int(time.time())
220
+ output_filename = f"generated_audio_{timestamp}.wav"
221
+ permanent_path = f"/tmp/{output_filename}"
222
+
223
+ # Save audio file
224
+ torchaudio.save(permanent_path, audio, self.model.seq_cfg.sampling_rate)
225
+
226
+ # Verify file was created
227
+ if not os.path.exists(permanent_path):
228
+ return None, "❌ Failed to save audio file"
229
+
230
+ file_size = os.path.getsize(permanent_path) / 1024 # KB
231
+ success_msg = f"✅ Audio generated successfully!\n"
232
+ success_msg += f"📊 Duration: {duration_sec:.2f}s | "
233
+ success_msg += f"Size: {file_size:.1f}KB | "
234
+ success_msg += f"Time: {generation_time:.2f}s"
235
+
236
+ return permanent_path, success_msg
237
+
238
+ except Exception as e:
239
+ return None, f"❌ Failed to save audio: {str(e)}"
240
+
241
+ except Exception as e:
242
+ error_msg = f"❌ Unexpected error: {str(e)}\n{traceback.format_exc()}"
243
+ log.error(error_msg)
244
+ return None, error_msg
245
 
246
+ # Initialize model
247
+ audio_model = AudioFoleyModel()
248
 
249
+ def generate_audio_interface(video_file, prompt, duration, cfg_strength):
250
+ """Interface function for generating audio"""
251
+ # Use fixed seed for consistency in HF Space
252
+ seed = 42
253
+ negative_prompt = "" # Simplified interface
254
+
255
+ audio_path, message = audio_model.generate_audio(
256
+ video_file, prompt, negative_prompt, duration, cfg_strength, seed
257
+ )
258
+ return audio_path, message
259
 
260
+ # Create Gradio interface
261
+ with gr.Blocks(title="hf_AC Audio Foley Generator", theme=gr.themes.Soft()) as demo:
262
+ gr.Markdown("""
263
+ # 🎵 hf_AC Audio Foley Generator
264
+
265
+ 基于AI的视频音频生成工具。上传视频并提供文本描述,模型将生成匹配的音频内容。
266
+
267
+ **注意**: 首次使用时模型需要下载,请耐心等待。
268
+ """)
269
+
270
+ # Model status display
271
+ model_status = gr.Textbox(
272
+ label="模型状态",
273
+ value="正在初始化模型...",
274
+ interactive=False
275
+ )
276
+
277
+ with gr.Row():
278
+ with gr.Column():
279
+ video_input = gr.Video(
280
+ label="上传视频",
281
+ format="mp4"
282
+ )
283
+
284
+ prompt_input = gr.Textbox(
285
+ label="音频描述",
286
+ placeholder="描述你想要生成的音频 (例如: '脚步声', '鸟叫声', '汽车引擎声')",
287
+ lines=3
288
+ )
289
+
290
+ with gr.Row():
291
+ duration_slider = gr.Slider(
292
+ minimum=1.0,
293
+ maximum=15.0,
294
+ value=8.0,
295
+ step=0.5,
296
+ label="时长 (秒)"
297
+ )
298
+
299
+ cfg_strength_slider = gr.Slider(
300
+ minimum=1.0,
301
+ maximum=8.0,
302
+ value=4.5,
303
+ step=0.1,
304
+ label="CFG强度"
305
+ )
306
+
307
+ with gr.Column():
308
+ # Example prompts
309
+ gr.Markdown("### 🎯 示例提示词")
310
+ example_buttons = []
311
+ for prompt in EXAMPLE_PROMPTS[:6]:
312
+ btn = gr.Button(prompt, size="sm")
313
+ example_buttons.append(btn)
314
+ btn.click(
315
+ fn=lambda p=prompt: p,
316
+ outputs=prompt_input
317
+ )
318
+
319
+ generate_btn = gr.Button("🎵 生成音频", variant="primary", size="lg")
320
+
321
+ audio_output = gr.Audio(
322
+ label="生成的音频",
323
+ type="filepath"
324
+ )
325
+
326
+ generation_status = gr.Textbox(label="生成状态", interactive=False)
327
+
328
+ generate_btn.click(
329
+ fn=generate_audio_interface,
330
+ inputs=[
331
+ video_input, prompt_input, duration_slider, cfg_strength_slider
332
+ ],
333
+ outputs=[audio_output, generation_status]
334
+ )
335
+
336
+ with gr.Accordion("💡 使用说明", open=False):
337
+ gr.Markdown(USAGE_TIPS)
338
+
339
+ gr.Markdown("""
340
+ ### 🎬 更多示例提示词
341
+
342
+ - "壁炉中燃烧的柴火声"
343
+ - "海浪拍打岩石的声音"
344
+ - "繁忙街道上的���车和人声"
345
+ - "森林中的鸟叫和树叶声"
346
+ - "安静办公室里的键盘敲击声"
347
+ - "厨房里炒菜和切菜的声音"
348
+ - "雨滴打在金属屋顶上"
349
+ - "木地板上轻柔的脚步声"
350
+ """)
351
+
352
+ # Auto-load model on startup
353
+ demo.load(
354
+ fn=lambda: audio_model.load_model(),
355
+ outputs=[model_status]
356
+ )
357
 
358
  if __name__ == "__main__":
359
+ # HF Space will handle the server configuration
360
+ demo.launch()
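app.py above exposes the pipeline through the module-level `AudioFoleyModel` instance, so it can also be exercised without the Gradio UI. A minimal headless sketch, assuming the file is importable as `app` and that a local clip exists at the hypothetical path `example.mp4`:

```python
# Headless smoke test for the wrapper defined in app.py.
# "example.mp4" and the prompt are placeholders, not files shipped with the Space.
from app import audio_model

status = audio_model.load_model()   # downloads/loads weights and returns a status string
print(status)

audio_path, message = audio_model.generate_audio(
    video_file="example.mp4",
    prompt="Footsteps on wooden floor",
    duration=8.0,
    cfg_strength=4.5,
    seed=42,
)
print(message)
print("saved to:", audio_path)
```

Note that importing `app` builds the Gradio Blocks but does not start the server, since `demo.launch()` is guarded by `__main__`.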
hf_AC/.gitignore ADDED
@@ -0,0 +1,224 @@
1
+ run_*.sh
2
+ log/
3
+ saves
4
+ saves/
5
+ weights/
6
+ weights
7
+ output/
8
+ output
9
+ pretrained/
10
+ workspace
11
+ workspace/
12
+ ext_weights/
13
+ ext_weights
14
+ .checkpoints/
15
+ .vscode/
16
+ training/example_output/
17
+
18
+ # Byte-compiled / optimized / DLL files
19
+ __pycache__/
20
+ *.py[codz]
21
+ *$py.class
22
+
23
+ # C extensions
24
+ *.so
25
+
26
+ # Distribution / packaging
27
+ .Python
28
+ build/
29
+ develop-eggs/
30
+ dist/
31
+ downloads/
32
+ eggs/
33
+ .eggs/
34
+ lib/
35
+ lib64/
36
+ parts/
37
+ sdist/
38
+ var/
39
+ wheels/
40
+ share/python-wheels/
41
+ *.egg-info/
42
+ .installed.cfg
43
+ *.egg
44
+ MANIFEST
45
+
46
+ # PyInstaller
47
+ # Usually these files are written by a python script from a template
48
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
49
+ *.manifest
50
+ *.spec
51
+
52
+ # Installer logs
53
+ pip-log.txt
54
+ pip-delete-this-directory.txt
55
+
56
+ # Unit test / coverage reports
57
+ htmlcov/
58
+ .tox/
59
+ .nox/
60
+ .coverage
61
+ .coverage.*
62
+ .cache
63
+ nosetests.xml
64
+ coverage.xml
65
+ *.cover
66
+ *.py.cover
67
+ .hypothesis/
68
+ .pytest_cache/
69
+ cover/
70
+
71
+ # Translations
72
+ *.mo
73
+ *.pot
74
+
75
+ # Django stuff:
76
+ *.log
77
+ local_settings.py
78
+ db.sqlite3
79
+ db.sqlite3-journal
80
+
81
+ # Flask stuff:
82
+ instance/
83
+ .webassets-cache
84
+
85
+ # Scrapy stuff:
86
+ .scrapy
87
+
88
+ # Sphinx documentation
89
+ docs/_build/
90
+
91
+ # PyBuilder
92
+ .pybuilder/
93
+ target/
94
+
95
+ # Jupyter Notebook
96
+ .ipynb_checkpoints
97
+
98
+ # IPython
99
+ profile_default/
100
+ ipython_config.py
101
+
102
+ # pyenv
103
+ # For a library or package, you might want to ignore these files since the code is
104
+ # intended to run in multiple environments; otherwise, check them in:
105
+ # .python-version
106
+
107
+ # pipenv
108
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
109
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
110
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
111
+ # install all needed dependencies.
112
+ #Pipfile.lock
113
+
114
+ # UV
115
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
116
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
117
+ # commonly ignored for libraries.
118
+ #uv.lock
119
+
120
+ # poetry
121
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
122
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
123
+ # commonly ignored for libraries.
124
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
125
+ #poetry.lock
126
+ #poetry.toml
127
+
128
+ # pdm
129
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
130
+ # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
131
+ # https://pdm-project.org/en/latest/usage/project/#working-with-version-control
132
+ #pdm.lock
133
+ #pdm.toml
134
+ .pdm-python
135
+ .pdm-build/
136
+
137
+ # pixi
138
+ # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
139
+ #pixi.lock
140
+ # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
141
+ # in the .venv directory. It is recommended not to include this directory in version control.
142
+ .pixi
143
+
144
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
145
+ __pypackages__/
146
+
147
+ # Celery stuff
148
+ celerybeat-schedule
149
+ celerybeat.pid
150
+
151
+ # SageMath parsed files
152
+ *.sage.py
153
+
154
+ # Environments
155
+ .env
156
+ .envrc
157
+ .venv
158
+ env/
159
+ venv/
160
+ ENV/
161
+ env.bak/
162
+ venv.bak/
163
+
164
+ # Spyder project settings
165
+ .spyderproject
166
+ .spyproject
167
+
168
+ # Rope project settings
169
+ .ropeproject
170
+
171
+ # mkdocs documentation
172
+ /site
173
+
174
+ # mypy
175
+ .mypy_cache/
176
+ .dmypy.json
177
+ dmypy.json
178
+
179
+ # Pyre type checker
180
+ .pyre/
181
+
182
+ # pytype static type analyzer
183
+ .pytype/
184
+
185
+ # Cython debug symbols
186
+ cython_debug/
187
+
188
+ # PyCharm
189
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
190
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
191
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
192
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
193
+ #.idea/
194
+
195
+ # Abstra
196
+ # Abstra is an AI-powered process automation framework.
197
+ # Ignore directories containing user credentials, local state, and settings.
198
+ # Learn more at https://abstra.io/docs
199
+ .abstra/
200
+
201
+ # Visual Studio Code
202
+ # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
203
+ # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
204
+ # and can be added to the global gitignore or merged into this file. However, if you prefer,
205
+ # you could uncomment the following to ignore the entire vscode folder
206
+ # .vscode/
207
+
208
+ # Ruff stuff:
209
+ .ruff_cache/
210
+
211
+ # PyPI configuration file
212
+ .pypirc
213
+
214
+ # Cursor
215
+ # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
216
+ # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
217
+ # refer to https://docs.cursor.com/context/ignore-files
218
+ .cursorignore
219
+ .cursorindexingignore
220
+
221
+ # Marimo
222
+ marimo/_static/
223
+ marimo/_lsp/
224
+ __marimo__/
hf_AC/README.md ADDED
@@ -0,0 +1,25 @@
1
+ # hf_AC
2
+
3
+ ## Environment Setup
4
+ - Python 3.9+
5
+ - PyTorch **2.5.1+** and the matching torchvision/torchaudio builds (pick your CUDA version at https://pytorch.org/; pip install recommended)
6
+
7
+ ```bash
8
+ pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 --upgrade
9
+ git clone https://github.com/ff2416/hf_AC.git
10
+ cd hf_AC
11
+ pip install -e .
12
+ ```
13
+ ## Model Installation
14
+ https://huggingface.co/FF2416/AC-Foley/blob/main/model.pth
15
+
16
+ ## Inference
17
+ ```bash
18
+ python inf.py \
19
+ --model_path <model path> \
20
+ --duration 8 \
21
+ --prompt <prompt> \
22
+ --video_dir <videos directory or video path> \
23
+ --audio_path <audio path> \
24
+ --output <output path>
25
+ ```
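The command template above leaves every path as a placeholder. A hedged driver for the same CLI, with hypothetical file names filled in (the optional reference `--audio_path` is omitted here):

```python
# Hypothetical driver for the inf.py command shown above; every path is a placeholder.
import subprocess

subprocess.run([
    "python", "inf.py",
    "--model_path", "weights/model.pth",   # the checkpoint from FF2416/AC-Foley
    "--duration", "8",
    "--prompt", "ocean waves crashing on a rocky shore",
    "--video_dir", "sample.mp4",           # inf.py also accepts a directory of .mp4 files
    "--output", "outputs/",
], check=True)
```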
hf_AC/config/__init__.py ADDED
File without changes
hf_AC/config/base_config.yaml ADDED
@@ -0,0 +1,62 @@
1
+ defaults:
2
+ - data: base
3
+ - eval_data: base
4
+ - override hydra/job_logging: custom-simplest
5
+ - _self_
6
+
7
+ hydra:
8
+ run:
9
+ dir: ./output/${exp_id}
10
+ output_subdir: ${now:%Y-%m-%d_%H-%M-%S}-hydra
11
+
12
+ enable_email: False
13
+
14
+ model: large_16k
15
+
16
+ exp_id: default
17
+ debug: False
18
+ cudnn_benchmark: True
19
+ compile: True
20
+ amp: True
21
+ weights: null
22
+ checkpoint: null
23
+ seed: 14159265
24
+ num_workers: 10 # per-GPU
25
+ pin_memory: False # set to True if your system can handle it, i.e., have enough memory
26
+
27
+ # NOTE: This DOES NOT affect the model during inference in any way
28
+ # they are just for the dataloader to fill in the missing data in multi-modal loading
29
+ # to change the sequence length for the model, see networks.py
30
+ data_dim:
31
+ text_seq_len: 77
32
+ clip_dim: 1024
33
+ sync_dim: 768
34
+ text_dim: 1024
35
+
36
+ # ema configuration
37
+ ema:
38
+ enable: True
39
+ sigma_rels: [0.05, 0.1]
40
+ update_every: 1
41
+ checkpoint_every: 10_000
42
+ checkpoint_folder: ${hydra:run.dir}/ema_ckpts
43
+ default_output_sigma: 0.05
44
+
45
+
46
+ # sampling
47
+ sampling:
48
+ mean: 0.0
49
+ scale: 1.0
50
+ min_sigma: 0.0
51
+ method: euler
52
+ num_steps: 25
53
+
54
+ # classifier-free guidance
55
+ null_condition_probability: 0.1
56
+ cfg_strength: 4.5
57
+
58
+ # checkpoint paths to external modules
59
+ vae_16k_ckpt: ./ext_weights/v1-16.pth
60
+ vae_44k_ckpt: ./ext_weights/v1-44.pth
61
+ bigvgan_vocoder_ckpt: ./ext_weights/best_netG.pt
62
+ synchformer_ckpt: ./ext_weights/synchformer_state_dict.pth
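The `sampling` and classifier-free guidance entries above (Euler integration, 25 steps, `cfg_strength: 4.5`) describe a standard flow-matching sampler. The sketch below is a generic illustration of that loop, not the repository's `FlowMatching` class; `velocity_fn` stands in for whatever network call the real code makes.

```python
# Generic Euler sampler with classifier-free guidance, mirroring the config above.
# Illustrative only; the real logic lives in mmaudio/model/flow_matching.py.
import torch

def euler_cfg_sample(velocity_fn, x, num_steps: int = 25, cfg_strength: float = 4.5):
    """velocity_fn(x, t, conditional) -> predicted velocity (a stand-in signature)."""
    dt = 1.0 / num_steps
    t = torch.zeros(())
    for _ in range(num_steps):
        v_cond = velocity_fn(x, t, conditional=True)
        v_uncond = velocity_fn(x, t, conditional=False)
        # Guided velocity: push the update toward the conditional prediction.
        v = v_uncond + cfg_strength * (v_cond - v_uncond)
        x = x + dt * v
        t = t + dt
    return x
```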
hf_AC/config/data/base.yaml ADDED
@@ -0,0 +1,69 @@
1
+ VGGSound:
2
+ root: /project/llmsvgen/share/data_vggsound/dataset/scratch/shared/beegfs/hchen/train_data/VGGSound_final/video
3
+ subset_name: /project/llmsvgen/pengjun/MMAudio_dev/tsv/vgg-train.tsv
4
+ fps: 8
5
+ height: 384
6
+ width: 384
7
+ sample_duration_sec: 8.0
8
+
9
+ VGGSound_test:
10
+ root: /project/llmsvgen/share/data_vggsound/dataset/scratch/shared/beegfs/hchen/train_data/VGGSound_final/video
11
+ subset_name: /project/llmsvgen/pengjun/MMAudio_dev/tsv/vgg-test.tsv
12
+ fps: 8
13
+ height: 384
14
+ width: 384
15
+ sample_duration_sec: 8.0
16
+
17
+ VGGSound_val:
18
+ root: /project/llmsvgen/share/data_vggsound/dataset/scratch/shared/beegfs/hchen/train_data/VGGSound_final/video
19
+ subset_name: /project/llmsvgen/pengjun/MMAudio_dev/tsv/vgg-val.tsv
20
+ fps: 8
21
+ height: 384
22
+ width: 384
23
+ sample_duration_sec: 8.0
24
+
25
+ ExtractedVGG:
26
+ tsv: /project/llmsvgen/pengjun/MMAudio_dev/training/vgg-clap/memmap/vgg-train.tsv
27
+ memmap_dir: /project/llmsvgen/pengjun/MMAudio_dev/training/vgg-clap/memmap/vgg-train
28
+
29
+ ExtractedVGG_test:
30
+ tag: test
31
+ gt_cache: ../data/eval-cache/vggsound-test
32
+ output_subdir: null
33
+ tsv: /project/llmsvgen/pengjun/MMAudio_dev/training/vgg-clap/memmap/vgg-test.tsv
34
+ memmap_dir: /project/llmsvgen/pengjun/MMAudio_dev/training/vgg-clap/memmap/vgg-test
35
+
36
+ ExtractedVGG_val:
37
+ tag: val
38
+ gt_cache: /project/llmsvgen/pengjun/MMAudio_dev/training/val_cache
39
+ output_subdir: val
40
+ tsv: /project/llmsvgen/pengjun/MMAudio_dev/training/vgg-clap/memmap/vgg-val.tsv
41
+ memmap_dir: /project/llmsvgen/pengjun/MMAudio_dev/training/vgg-clap/memmap/vgg-val
42
+
43
+ AudioCaps:
44
+ tsv: /project/llmsvgen/pengjun/MMAudio_dev/training/audiocaps_clap/memmap/audiocaps_clap.tsv
45
+ memmap_dir: /project/llmsvgen/pengjun/MMAudio_dev/training/audiocaps_clap/memmap/audiocaps_clap
46
+
47
+ AudioSetSL:
48
+ tsv: /project/llmsvgen/pengjun/MMAudio_dev/training/audioset_clap/memmap/audioset_clap.tsv
49
+ memmap_dir: /project/llmsvgen/pengjun/MMAudio_dev/training/audioset_clap/memmap/audioset_clap
50
+ # BBCSound:
51
+ # tsv: ../data/v1-16-memmap/bbcsound.tsv
52
+ # memmap_dir: ../data/v1-16-memmap/bbcsound
53
+
54
+ FreeSound:
55
+ tsv: /project/llmsvgen/pengjun/MMAudio_dev/training/freesound_clap/memmap/freesound_clap.tsv
56
+ memmap_dir: /project/llmsvgen/pengjun/MMAudio_dev/training/freesound_clap/memmap/freesound_clap
57
+
58
+ # Clotho:
59
+ # tsv: ../data/v1-16-memmap/clotho.tsv
60
+ # memmap_dir: ../data/v1-16-memmap/clotho
61
+
62
+ # Example_video:
63
+ # tsv: ./training/example_output/memmap/vgg-example.tsv
64
+ # memmap_dir: ./training/example_output/memmap/vgg-example
65
+
66
+ # Example_audio:
67
+ # tsv: ./training/example_output/memmap/audio-example.tsv
68
+ # memmap_dir: ./training/example_output/memmap/audio-example
69
+
hf_AC/config/data/base2.yaml ADDED
@@ -0,0 +1,69 @@
1
+ VGGSound:
2
+ root: /project/llmsvgen/share/data_vggsound/dataset/scratch/shared/beegfs/hchen/train_data/VGGSound_final/video
3
+ subset_name: /project/llmsvgen/pengjun/MMAudio_dev/tsv/vgg-train.tsv
4
+ fps: 8
5
+ height: 384
6
+ width: 384
7
+ sample_duration_sec: 8.0
8
+
9
+ VGGSound_test:
10
+ root: /project/llmsvgen/share/data_vggsound/dataset/scratch/shared/beegfs/hchen/train_data/VGGSound_final/video
11
+ subset_name: /project/llmsvgen/pengjun/MMAudio_dev/tsv/vgg-test.tsv
12
+ fps: 8
13
+ height: 384
14
+ width: 384
15
+ sample_duration_sec: 8.0
16
+
17
+ VGGSound_val:
18
+ root: /project/llmsvgen/share/data_vggsound/dataset/scratch/shared/beegfs/hchen/train_data/VGGSound_final/video
19
+ subset_name: /project/llmsvgen/pengjun/MMAudio_dev/tsv/vgg-val.tsv
20
+ fps: 8
21
+ height: 384
22
+ width: 384
23
+ sample_duration_sec: 8.0
24
+
25
+ ExtractedVGG:
26
+ tsv: /project/llmsvgen/pengjun/MMAudio_dev/training/vgg/memmap/vgg-train.tsv
27
+ memmap_dir: /project/llmsvgen/pengjun/MMAudio_dev/training/vgg/memmap/vgg-train
28
+
29
+ ExtractedVGG_test:
30
+ tag: test
31
+ gt_cache: ../data/eval-cache/vggsound-test
32
+ output_subdir: null
33
+ tsv: /project/llmsvgen/pengjun/MMAudio_dev/training/vgg/memmap/vgg-test.tsv
34
+ memmap_dir: /project/llmsvgen/pengjun/MMAudio_dev/training/vgg/memmap/vgg-test
35
+
36
+ ExtractedVGG_val:
37
+ tag: val
38
+ gt_cache: /project/llmsvgen/pengjun/MMAudio_dev/training/val_cache
39
+ output_subdir: val
40
+ tsv: /project/llmsvgen/pengjun/MMAudio_dev/training/vgg/memmap/vgg-val.tsv
41
+ memmap_dir: /project/llmsvgen/pengjun/MMAudio_dev/training/vgg/memmap/vgg-val
42
+
43
+ AudioCaps:
44
+ tsv: /project/llmsvgen/pengjun/MMAudio_dev/training/audiocaps/memmap/audiocaps.tsv
45
+ memmap_dir: /project/llmsvgen/pengjun/MMAudio_dev/training/audiocaps/memmap/audiocaps
46
+
47
+ AudioSetSL:
48
+ tsv: /project/llmsvgen/pengjun/MMAudio_dev/training/audioset_sl/memmap/audioset_sl.tsv
49
+ memmap_dir: /project/llmsvgen/pengjun/MMAudio_dev/training/audioset_sl/memmap/audioset_sl
50
+ # BBCSound:
51
+ # tsv: ../data/v1-16-memmap/bbcsound.tsv
52
+ # memmap_dir: ../data/v1-16-memmap/bbcsound
53
+
54
+ FreeSound:
55
+ tsv: /project/llmsvgen/pengjun/MMAudio_dev/training/freesound/memmap/freesound.tsv
56
+ memmap_dir: /project/llmsvgen/pengjun/MMAudio_dev/training/freesound/memmap/freesound
57
+
58
+ # Clotho:
59
+ # tsv: ../data/v1-16-memmap/clotho.tsv
60
+ # memmap_dir: ../data/v1-16-memmap/clotho
61
+
62
+ # Example_video:
63
+ # tsv: ./training/example_output/memmap/vgg-example.tsv
64
+ # memmap_dir: ./training/example_output/memmap/vgg-example
65
+
66
+ # Example_audio:
67
+ # tsv: ./training/example_output/memmap/audio-example.tsv
68
+ # memmap_dir: ./training/example_output/memmap/audio-example
69
+
hf_AC/config/eval_config.yaml ADDED
@@ -0,0 +1,17 @@
1
+ defaults:
2
+ - base_config
3
+ - override hydra/job_logging: custom-simplest
4
+ - _self_
5
+
6
+ hydra:
7
+ run:
8
+ dir: ./output/${exp_id}
9
+ output_subdir: eval-${now:%Y-%m-%d_%H-%M-%S}-hydra
10
+
11
+ exp_id: ${model}
12
+ dataset: audiocaps
13
+ duration_s: 8.0
14
+
15
+ # for inference, this is the per-GPU batch size
16
+ batch_size: 16
17
+ output_name: null
hf_AC/config/eval_data/base.yaml ADDED
@@ -0,0 +1,22 @@
1
+ # AudioCaps:
2
+ # audio_path: ../data/AudioCaps-test-audioldm-ver
3
+ # # a csv file, with a header row of 'name' and 'caption'
4
+ # # name should match the audio file name without extension
5
+ # # Can be downloaded here: https://github.com/hkchengrex/MMAudio/releases/download/v0.1/AudioCaps_audioldm_data.csv
6
+ # csv_path: ../data/AudioCaps-test-audioldm-ver/data.csv
7
+
8
+ # AudioCaps_full:
9
+ # audio_path: ../data/AudioCaps-test-full-ver
10
+ # # a csv file, with a header row of 'name' and 'caption'
11
+ # # name should match the audio file name without extension
12
+ # # Can be downloaded here: https://github.com/hkchengrex/MMAudio/releases/download/v0.1/AudioCaps_full_data.csv
13
+ # csv_path: ../data/AudioCaps-test-full-ver/data.csv
14
+
15
+ # MovieGen:
16
+ # video_path: ../data/MovieGen/MovieGenAudioBenchSfx/video_with_audio
17
+ # jsonl_path: ../data/MovieGen/MovieGenAudioBenchSfx/metadata
18
+
19
+ VGGSound:
20
+ video_path: /project/llmsvgen/pengjun/MMAudio_dev/training/test_video
21
+ # from the officially released csv file
22
+ csv_path: /project/llmsvgen/share/data_vggsound/VGGSound/vggsound.csv
hf_AC/config/hydra/job_logging/custom-eval.yaml ADDED
@@ -0,0 +1,32 @@
1
+ # python logging configuration for tasks
2
+ version: 1
3
+ formatters:
4
+ simple:
5
+ format: '[%(asctime)s][%(levelname)s][r${oc.env:LOCAL_RANK}] - %(message)s'
6
+ datefmt: '%Y-%m-%d %H:%M:%S'
7
+ colorlog:
8
+ '()': 'colorlog.ColoredFormatter'
9
+ format: '[%(cyan)s%(asctime)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s'
10
+ datefmt: '%Y-%m-%d %H:%M:%S'
11
+ log_colors:
12
+ DEBUG: purple
13
+ INFO: green
14
+ WARNING: yellow
15
+ ERROR: red
16
+ CRITICAL: red
17
+ handlers:
18
+ console:
19
+ class: logging.StreamHandler
20
+ formatter: colorlog
21
+ stream: ext://sys.stdout
22
+ file:
23
+ class: logging.FileHandler
24
+ formatter: simple
25
+ # absolute file path
26
+ filename: ${hydra.runtime.output_dir}/eval-${now:%Y-%m-%d_%H-%M-%S}-rank${oc.env:LOCAL_RANK}.log
27
+ mode: w
28
+ root:
29
+ level: INFO
30
+ handlers: [console, file]
31
+
32
+ disable_existing_loggers: false
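The Hydra `job_logging` files above are ordinary `logging.config.dictConfig` payloads that Hydra applies when a job starts. For reference, a minimal sketch of the console portion applied by hand (assumes the `colorlog` package is installed):

```python
# Minimal dictConfig equivalent of the console handler in the logging YAML above.
import logging
import logging.config

logging.config.dictConfig({
    "version": 1,
    "formatters": {
        "colorlog": {
            "()": "colorlog.ColoredFormatter",   # same factory syntax as the YAML
            "format": "[%(cyan)s%(asctime)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s",
            "datefmt": "%Y-%m-%d %H:%M:%S",
        }
    },
    "handlers": {
        "console": {
            "class": "logging.StreamHandler",
            "formatter": "colorlog",
            "stream": "ext://sys.stdout",
        }
    },
    "root": {"level": "INFO", "handlers": ["console"]},
    "disable_existing_loggers": False,
})

logging.getLogger(__name__).info("colorized logging is configured")
```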
hf_AC/config/hydra/job_logging/custom-no-rank.yaml ADDED
@@ -0,0 +1,32 @@
1
+ # python logging configuration for tasks
2
+ version: 1
3
+ formatters:
4
+ simple:
5
+ format: '[%(asctime)s][%(levelname)s] - %(message)s'
6
+ datefmt: '%Y-%m-%d %H:%M:%S'
7
+ colorlog:
8
+ '()': 'colorlog.ColoredFormatter'
9
+ format: '[%(cyan)s%(asctime)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s'
10
+ datefmt: '%Y-%m-%d %H:%M:%S'
11
+ log_colors:
12
+ DEBUG: purple
13
+ INFO: green
14
+ WARNING: yellow
15
+ ERROR: red
16
+ CRITICAL: red
17
+ handlers:
18
+ console:
19
+ class: logging.StreamHandler
20
+ formatter: colorlog
21
+ stream: ext://sys.stdout
22
+ file:
23
+ class: logging.FileHandler
24
+ formatter: simple
25
+ # absolute file path
26
+ filename: ${hydra.runtime.output_dir}/${now:%Y-%m-%d_%H-%M-%S}-eval.log
27
+ mode: w
28
+ root:
29
+ level: INFO
30
+ handlers: [console, file]
31
+
32
+ disable_existing_loggers: false
hf_AC/config/hydra/job_logging/custom-simplest.yaml ADDED
@@ -0,0 +1,26 @@
1
+ # python logging configuration for tasks
2
+ version: 1
3
+ formatters:
4
+ simple:
5
+ format: '[%(asctime)s][%(levelname)s] - %(message)s'
6
+ datefmt: '%Y-%m-%d %H:%M:%S'
7
+ colorlog:
8
+ '()': 'colorlog.ColoredFormatter'
9
+ format: '[%(cyan)s%(asctime)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s'
10
+ datefmt: '%Y-%m-%d %H:%M:%S'
11
+ log_colors:
12
+ DEBUG: purple
13
+ INFO: green
14
+ WARNING: yellow
15
+ ERROR: red
16
+ CRITICAL: red
17
+ handlers:
18
+ console:
19
+ class: logging.StreamHandler
20
+ formatter: colorlog
21
+ stream: ext://sys.stdout
22
+ root:
23
+ level: INFO
24
+ handlers: [console]
25
+
26
+ disable_existing_loggers: false
hf_AC/config/hydra/job_logging/custom.yaml ADDED
@@ -0,0 +1,33 @@
1
+ # @package hydra.job_logging
2
+ # python logging configuration for tasks
3
+ version: 1
4
+ formatters:
5
+ simple:
6
+ format: '[%(asctime)s][%(levelname)s][r${oc.env:LOCAL_RANK}] - %(message)s'
7
+ datefmt: '%Y-%m-%d %H:%M:%S'
8
+ colorlog:
9
+ '()': 'colorlog.ColoredFormatter'
10
+ format: '[%(cyan)s%(asctime)s%(reset)s][%(blue)sr${oc.env:LOCAL_RANK}%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s'
11
+ datefmt: '%Y-%m-%d %H:%M:%S'
12
+ log_colors:
13
+ DEBUG: purple
14
+ INFO: green
15
+ WARNING: yellow
16
+ ERROR: red
17
+ CRITICAL: red
18
+ handlers:
19
+ console:
20
+ class: logging.StreamHandler
21
+ formatter: colorlog
22
+ stream: ext://sys.stdout
23
+ file:
24
+ class: logging.FileHandler
25
+ formatter: simple
26
+ # absolute file path
27
+ filename: ${hydra.runtime.output_dir}/train-${now:%Y-%m-%d_%H-%M-%S}-rank${oc.env:LOCAL_RANK}.log
28
+ mode: w
29
+ root:
30
+ level: INFO
31
+ handlers: [console, file]
32
+
33
+ disable_existing_loggers: false
hf_AC/config/train_config.yaml ADDED
@@ -0,0 +1,41 @@
1
+ defaults:
2
+ - base_config
3
+ - override data: base
4
+ - override hydra/job_logging: custom
5
+ - _self_
6
+
7
+ hydra:
8
+ run:
9
+ dir: ./output/${exp_id}
10
+ output_subdir: train-${now:%Y-%m-%d_%H-%M-%S}-hydra
11
+
12
+ ema:
13
+ start: 0
14
+
15
+ mini_train: False
16
+ example_train: False
17
+ enable_grad_scaler: False
18
+ vgg_oversample_rate: 4
19
+
20
+ log_text_interval: 100
21
+ log_extra_interval: 20_000
22
+ val_interval: 10_000
23
+ eval_interval: 20_000
24
+ save_eval_interval: 40_000
25
+ save_weights_interval: 5_000
26
+ save_checkpoint_interval: 5_000
27
+ save_copy_iterations: [50000,100000,150000,200000,220000,240000,260000,280000,300000]
28
+
29
+ batch_size: 340
30
+ eval_batch_size: 32 # per-GPU
31
+
32
+ num_iterations: 300_000
33
+ learning_rate: 1.0e-4
34
+ linear_warmup_steps: 1_000
35
+
36
+ lr_schedule: step
37
+ lr_schedule_steps: [200_000, 240_000]
38
+ lr_schedule_gamma: 0.1
39
+
40
+ clip_grad_norm: 1.0
41
+ weight_decay: 1.0e-6
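The schedule fields above (`linear_warmup_steps: 1_000`, step decay at `[200_000, 240_000]` with `lr_schedule_gamma: 0.1`) describe a warmup-then-step-decay curve. A small sketch of that curve as a pure function of the iteration count; it only mirrors the config values and is not the trainer's scheduler code.

```python
# Illustrative learning-rate curve for the values in train_config.yaml (not the trainer's code).
def lr_at(iteration: int,
          base_lr: float = 1.0e-4,
          warmup_steps: int = 1_000,
          decay_steps: tuple = (200_000, 240_000),
          gamma: float = 0.1) -> float:
    if iteration < warmup_steps:
        return base_lr * iteration / warmup_steps          # linear warmup
    decay_factor = gamma ** sum(iteration >= s for s in decay_steps)
    return base_lr * decay_factor

for it in (500, 1_000, 100_000, 200_000, 250_000):
    print(it, lr_at(it))
```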
hf_AC/config/train_config_2node.yaml ADDED
@@ -0,0 +1,41 @@
1
+ defaults:
2
+ - base_config
3
+ - override data: base
4
+ - override hydra/job_logging: custom
5
+ - _self_
6
+
7
+ hydra:
8
+ run:
9
+ dir: ./output/${exp_id}
10
+ output_subdir: train-${now:%Y-%m-%d_%H-%M-%S}-hydra
11
+
12
+ ema:
13
+ start: 0
14
+
15
+ mini_train: False
16
+ example_train: False
17
+ enable_grad_scaler: False
18
+ vgg_oversample_rate: 4
19
+
20
+ log_text_interval: 200
21
+ log_extra_interval: 20_000
22
+ val_interval: 5_000
23
+ eval_interval: 20_000
24
+ save_eval_interval: 40_000
25
+ save_weights_interval: 10_000
26
+ save_checkpoint_interval: 10_000
27
+ save_copy_iterations: [40000,60000,80000,100000,150000,200000,220000,240000,260000,280000,300000]
28
+
29
+ batch_size: 320
30
+ eval_batch_size: 32 # per-GPU
31
+
32
+ num_iterations: 220_000
33
+ learning_rate: 1.0e-4
34
+ linear_warmup_steps: 1_000
35
+
36
+ lr_schedule: step
37
+ lr_schedule_steps: [200_000, 240_000, 270_000]
38
+ lr_schedule_gamma: 0.1
39
+
40
+ clip_grad_norm: 1.0
41
+ weight_decay: 1.0e-6
hf_AC/config/train_config_2node2.yaml ADDED
@@ -0,0 +1,41 @@
1
+ defaults:
2
+ - base_config
3
+ - override data: base2
4
+ - override hydra/job_logging: custom
5
+ - _self_
6
+
7
+ hydra:
8
+ run:
9
+ dir: ./output/${exp_id}
10
+ output_subdir: train-${now:%Y-%m-%d_%H-%M-%S}-hydra
11
+
12
+ ema:
13
+ start: 0
14
+
15
+ mini_train: False
16
+ example_train: False
17
+ enable_grad_scaler: False
18
+ vgg_oversample_rate: 4
19
+
20
+ log_text_interval: 200
21
+ log_extra_interval: 20_000
22
+ val_interval: 5_000
23
+ eval_interval: 20_000
24
+ save_eval_interval: 40_000
25
+ save_weights_interval: 10_000
26
+ save_checkpoint_interval: 10_000
27
+ save_copy_iterations: [100000,150000,200000,220000,240000,260000,280000,300000]
28
+
29
+ batch_size: 320
30
+ eval_batch_size: 32 # per-GPU
31
+
32
+ num_iterations: 300_000
33
+ learning_rate: 1.0e-4
34
+ linear_warmup_steps: 1_000
35
+
36
+ lr_schedule: step
37
+ lr_schedule_steps: [200_000, 240_000, 270_000]
38
+ lr_schedule_gamma: 0.1
39
+
40
+ clip_grad_norm: 1.0
41
+ weight_decay: 1.0e-6
hf_AC/inf.py ADDED
@@ -0,0 +1,181 @@
1
+ import logging
2
+ from argparse import ArgumentParser
3
+ from pathlib import Path
4
+
5
+ import torch
6
+ import torchaudio
7
+
8
+ from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video, make_video,
9
+ setup_eval_logging)
10
+ from mmaudio.model.flow_matching import FlowMatching
11
+ from mmaudio.model.networks import MMAudio, get_my_mmaudio
12
+ from mmaudio.model.utils.features_utils import FeaturesUtils
13
+ import os
14
+ from mmaudio.ext.mel_converter import get_mel_converter
15
+ from mmaudio.ext.autoencoder import AutoEncoderModule
16
+ import time
17
+ torch.backends.cuda.matmul.allow_tf32 = True
18
+ torch.backends.cudnn.allow_tf32 = True
19
+ import tqdm
20
+ import glob
21
+ log = logging.getLogger()
22
+
23
+ class Audio:
24
+ def __init__(self, audio_path, sample_rate):
25
+ self.audio_paths = audio_path
26
+ self.sample_rate = sample_rate
27
+ self.num_timbre_sample = 89088 if sample_rate == 44100 else 32768
28
+ self.resampler = {}
29
+
30
+ def load_audio(self):
31
+ chunk_list=[]
32
+ for audio_path in self.audio_paths:
33
+ audio_chunk, sample_rate = torchaudio.load(audio_path)
34
+ audio_chunk = audio_chunk.mean(dim=0) # mono
35
+ abs_max = audio_chunk.abs().max()
36
+ audio_chunk = audio_chunk / abs_max * 0.95
37
+
38
+ # resample
39
+ if sample_rate == self.sample_rate:
40
+ audio_chunk = audio_chunk
41
+ else:
42
+ if sample_rate not in self.resampler:
43
+ # https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
44
+ self.resampler[sample_rate] = torchaudio.transforms.Resample(
45
+ sample_rate,
46
+ self.sample_rate,
47
+ lowpass_filter_width=64,
48
+ rolloff=0.9475937167399596,
49
+ resampling_method='sinc_interp_kaiser',
50
+ beta=14.769656459379492,
51
+ )
52
+ audio_chunk = self.resampler[sample_rate](audio_chunk)
53
+ if audio_chunk.size(0) < self.num_timbre_sample:
54
+ padding_length = self.num_timbre_sample - audio_chunk.size(0)
55
+ audio_chunk = torch.cat([audio_chunk, torch.zeros(padding_length)], dim=0)
56
+ else:
57
+ audio_chunk = audio_chunk[:self.num_timbre_sample]
58
+ # audio_chunk = audio_chunk[:self.num_timbre_sample]
59
+ chunk_list.append(audio_chunk)
60
+ return chunk_list
61
+
62
+ def process_video(video_path: Path, args, model: ModelConfig, net: MMAudio, fm: FlowMatching, feature_utils: FeaturesUtils, device: str, dtype: torch.dtype, audio: torch.Tensor, i):
63
+ log.info(f'Processing video: {video_path}')
64
+ t=time.time()
65
+ audio_num_sample = 89088
66
+ if audio is not None:
67
+ audio_num_sample = audio.shape[0]
68
+ video_info = load_video(video_path, args.duration)
69
+ clip_frames = video_info.clip_frames
70
+ sync_frames = video_info.sync_frames
71
+ duration = video_info.duration_sec
72
+ if args.mask_away_clip:
73
+ clip_frames = None
74
+ else:
75
+ clip_frames = clip_frames.unsqueeze(0)
76
+ sync_frames = sync_frames.unsqueeze(0)
77
+
78
+ model.seq_cfg.duration = duration
79
+ model.seq_cfg.audio_num_sample = audio_num_sample
80
+ net.update_seq_lengths(model.seq_cfg.latent_seq_len, model.seq_cfg.clip_seq_len, model.seq_cfg.sync_seq_len, model.seq_cfg.audio_seq_len)
81
+
82
+ log.info(f'Prompt: {args.prompt}')
83
+ log.info(f'Negative prompt: {args.negative_prompt}')
84
+ audios = generate(clip_frames,
85
+ sync_frames, [args.prompt], audio,
86
+ negative_text=[args.negative_prompt],
87
+ feature_utils=feature_utils,
88
+ net=net,
89
+ fm=fm,
90
+ rng=torch.Generator(device=device).manual_seed(args.seed),
91
+ cfg_strength=args.cfg_strength)
92
+ audio = audios.float().cpu()[0]
93
+ save_path = args.output / f'{video_path.stem}{i}.wav'
94
+ torchaudio.save(save_path, audio, model.seq_cfg.sampling_rate)
95
+ log.info(f'Audio saved to {save_path}')
96
+
97
+ if not args.skip_video_composite:
98
+ video_save_path = args.output / f'{video_path.stem}{i}.mp4'
99
+ make_video(video_info, video_save_path, audio, sampling_rate=model.seq_cfg.sampling_rate)
100
+ log.info(f'Video saved to {video_save_path}')
101
+
102
+ @torch.inference_mode()
103
+ def main():
104
+ setup_eval_logging()
105
+
106
+ parser = ArgumentParser()
107
+ parser.add_argument('--variant',
108
+ type=str,
109
+ default='large_44k',)
110
+ parser.add_argument('--video_dir', type=Path, help='')
111
+ parser.add_argument('--audio_path', type=str, default='')
112
+ parser.add_argument('--prompt', type=str, help='Input prompt', default='')
113
+ parser.add_argument('--negative_prompt', type=str, help='Negative prompt', default='')
114
+ parser.add_argument('--duration', type=float, default=8.0)
115
+ parser.add_argument('--cfg_strength', type=float, default=4.5)
116
+ parser.add_argument('--num_steps', type=int, default=25)
117
+ parser.add_argument('--mask_away_clip', action='store_true')
118
+ parser.add_argument('--output', type=Path, help='Output directory', default='./')
119
+ parser.add_argument('--seed', type=int, help='Random seed', default=42)
120
+ parser.add_argument('--skip_video_composite', action='store_true')
121
+ parser.add_argument('--full_precision', action='store_true')
122
+ parser.add_argument('--model_path', type=str, default='weights/model.pth', help='Path to the model weights')
123
+
124
+ args = parser.parse_args()
125
+
126
+ if args.variant not in all_model_cfg:
127
+ raise ValueError(f'Unknown model variant: {args.variant}')
128
+ model: ModelConfig = all_model_cfg[args.variant]
129
+ model.download_if_needed()
130
+
131
+ device = 'cpu'
132
+ if torch.cuda.is_available():
133
+ device = 'cuda'
134
+ elif torch.backends.mps.is_available():
135
+ device = 'mps'
136
+ else:
137
+ log.warning('CUDA/MPS are not available, running on CPU')
138
+ dtype = torch.float32 if args.full_precision else torch.bfloat16
139
+
140
+ args.output.mkdir(parents=True, exist_ok=True)
141
+
142
+ if args.audio_path != '':
143
+ SAMPLE_RATE = 44100
144
+ audio = Audio([args.audio_path], SAMPLE_RATE)
145
+ audio_list = audio.load_audio()
146
+ else:
147
+ audio_list = None
148
+
149
+ model.model_path = Path(args.model_path)
150
+ net: MMAudio = get_my_mmaudio(model.model_name).to(device, dtype).eval()
151
+ net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True)['weights'])
152
+ log.info(f'Loaded weights from {model.model_path}')
153
+
154
+ fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=args.num_steps)
155
+ feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
156
+ synchformer_ckpt=model.synchformer_ckpt,
157
+ enable_conditions=True,
158
+ mode=model.mode,
159
+ bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
160
+ need_vae_encoder=True)
161
+ feature_utils = feature_utils.to(device, dtype).eval()
162
+
163
+ if args.video_dir:
164
+ video_dir: Path = args.video_dir.expanduser()
165
+ video_files = sorted(list(video_dir.glob('*.mp4')))
166
+ if os.path.isfile(args.video_dir):
167
+ video_files=[args.video_dir]
168
+ if not video_files:
169
+ log.warning(f'No video files found in {video_dir}')
170
+ else:
171
+ if audio_list is None:
172
+ audio_list = [None] * len(video_files)
173
+ if len(audio_list)==1:
174
+ audio_list = audio_list * len(video_files)
175
+ for i in range(1):
176
+ for video_path, audio in tqdm.tqdm(zip(video_files,audio_list)):
177
+ args.seed = torch.seed()
178
+ process_video(video_path, args, model, net, fm, feature_utils, device, dtype, audio, i)
179
+
180
+ if __name__ == '__main__':
181
+ main()
hf_AC/mmaudio/__init__.py ADDED
File without changes
hf_AC/mmaudio/data/__init__.py ADDED
File without changes
hf_AC/mmaudio/data/av_utils.py ADDED
@@ -0,0 +1,162 @@
1
+ from dataclasses import dataclass
2
+ from fractions import Fraction
3
+ from pathlib import Path
4
+ from typing import Optional
5
+
6
+ import av
7
+ import numpy as np
8
+ import torch
9
+ from av import AudioFrame
10
+
11
+
12
+ @dataclass
13
+ class VideoInfo:
14
+ duration_sec: float
15
+ fps: Fraction
16
+ clip_frames: torch.Tensor
17
+ sync_frames: torch.Tensor
18
+ all_frames: Optional[list[np.ndarray]]
19
+
20
+ @property
21
+ def height(self):
22
+ return self.all_frames[0].shape[0]
23
+
24
+ @property
25
+ def width(self):
26
+ return self.all_frames[0].shape[1]
27
+
28
+ @classmethod
29
+ def from_image_info(cls, image_info: 'ImageInfo', duration_sec: float,
30
+ fps: Fraction) -> 'VideoInfo':
31
+ num_frames = int(duration_sec * fps)
32
+ all_frames = [image_info.original_frame] * num_frames
33
+ return cls(duration_sec=duration_sec,
34
+ fps=fps,
35
+ clip_frames=image_info.clip_frames,
36
+ sync_frames=image_info.sync_frames,
37
+ all_frames=all_frames)
38
+
39
+
40
+ @dataclass
41
+ class ImageInfo:
42
+ clip_frames: torch.Tensor
43
+ sync_frames: torch.Tensor
44
+ original_frame: Optional[np.ndarray]
45
+
46
+ @property
47
+ def height(self):
48
+ return self.original_frame.shape[0]
49
+
50
+ @property
51
+ def width(self):
52
+ return self.original_frame.shape[1]
53
+
54
+
55
+ def read_frames(video_path: Path, list_of_fps: list[float], start_sec: float, end_sec: float,
56
+ need_all_frames: bool) -> tuple[list[np.ndarray], list[np.ndarray], Fraction]:
57
+ output_frames = [[] for _ in list_of_fps]
58
+ next_frame_time_for_each_fps = [0.0 for _ in list_of_fps]
59
+ time_delta_for_each_fps = [1 / fps for fps in list_of_fps]
60
+ all_frames = []
61
+
62
+ # container = av.open(video_path)
63
+ with av.open(video_path) as container:
64
+ stream = container.streams.video[0]
65
+ fps = stream.guessed_rate
66
+ stream.thread_type = 'AUTO'
67
+ for packet in container.demux(stream):
68
+ for frame in packet.decode():
69
+ frame_time = frame.time
70
+ if frame_time < start_sec:
71
+ continue
72
+ if frame_time > end_sec:
73
+ break
74
+
75
+ frame_np = None
76
+ if need_all_frames:
77
+ frame_np = frame.to_ndarray(format='rgb24')
78
+ all_frames.append(frame_np)
79
+
80
+ for i, _ in enumerate(list_of_fps):
81
+ this_time = frame_time
82
+ while this_time >= next_frame_time_for_each_fps[i]:
83
+ if frame_np is None:
84
+ frame_np = frame.to_ndarray(format='rgb24')
85
+
86
+ output_frames[i].append(frame_np)
87
+ next_frame_time_for_each_fps[i] += time_delta_for_each_fps[i]
88
+
89
+ output_frames = [np.stack(frames) for frames in output_frames]
90
+ return output_frames, all_frames, fps
91
+
92
+
93
+ def reencode_with_audio(video_info: VideoInfo, output_path: Path, audio: torch.Tensor,
94
+ sampling_rate: int):
95
+ container = av.open(output_path, 'w')
96
+ output_video_stream = container.add_stream('h264', video_info.fps)
97
+ output_video_stream.codec_context.bit_rate = 10 * 1e6 # 10 Mbps
98
+ output_video_stream.width = video_info.width
99
+ output_video_stream.height = video_info.height
100
+ output_video_stream.pix_fmt = 'yuv420p'
101
+
102
+ output_audio_stream = container.add_stream('aac', sampling_rate)
103
+
104
+ # encode video
105
+ for image in video_info.all_frames:
106
+ image = av.VideoFrame.from_ndarray(image)
107
+ packet = output_video_stream.encode(image)
108
+ container.mux(packet)
109
+
110
+ for packet in output_video_stream.encode():
111
+ container.mux(packet)
112
+
113
+ # convert float tensor audio to numpy array
114
+ audio_np = audio.numpy().astype(np.float32)
115
+ audio_frame = AudioFrame.from_ndarray(audio_np, format='flt', layout='mono')
116
+ audio_frame.sample_rate = sampling_rate
117
+
118
+ for packet in output_audio_stream.encode(audio_frame):
119
+ container.mux(packet)
120
+
121
+ for packet in output_audio_stream.encode():
122
+ container.mux(packet)
123
+
124
+ container.close()
125
+
126
+
127
+ def remux_with_audio(video_path: Path, audio: torch.Tensor, output_path: Path, sampling_rate: int):
128
+ """
129
+ NOTE: I don't think we can get the exact video duration right without re-encoding
130
+ so we are not using this but keeping it here for reference
131
+ """
132
+ video = av.open(video_path)
133
+ output = av.open(output_path, 'w')
134
+ input_video_stream = video.streams.video[0]
135
+ output_video_stream = output.add_stream(template=input_video_stream)
136
+ output_audio_stream = output.add_stream('aac', sampling_rate)
137
+
138
+ duration_sec = audio.shape[-1] / sampling_rate
139
+
140
+ for packet in video.demux(input_video_stream):
141
+ # We need to skip the "flushing" packets that `demux` generates.
142
+ if packet.dts is None:
143
+ continue
144
+ # We need to assign the packet to the new stream.
145
+ packet.stream = output_video_stream
146
+ output.mux(packet)
147
+
148
+ # convert float tensor audio to numpy array
149
+ audio_np = audio.numpy().astype(np.float32)
150
+ audio_frame = av.AudioFrame.from_ndarray(audio_np, format='flt', layout='mono')
151
+ audio_frame.sample_rate = sampling_rate
152
+
153
+ for packet in output_audio_stream.encode(audio_frame):
154
+ output.mux(packet)
155
+
156
+ for packet in output_audio_stream.encode():
157
+ output.mux(packet)
158
+
159
+ video.close()
160
+ output.close()
161
+
162
+ output.close()
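`read_frames` and `reencode_with_audio` above cover the read and write ends of the video pipeline. A minimal sketch of how they might be combined (the file names, the 8-second window and the silent placeholder audio are illustrative assumptions, not part of this module):

from pathlib import Path

import torch

# Decode CLIP-rate (8 fps) and Synchformer-rate (25 fps) frames from a local clip.
# 'clip.mp4' is a placeholder path.
(clip_np, sync_np), all_frames, fps = read_frames(Path('clip.mp4'),
                                                  list_of_fps=[8.0, 25.0],
                                                  start_sec=0.0, end_sec=8.0,
                                                  need_all_frames=True)

info = VideoInfo(duration_sec=8.0, fps=fps,
                 clip_frames=torch.from_numpy(clip_np),
                 sync_frames=torch.from_numpy(sync_np),
                 all_frames=all_frames)

# Mux 8 s of 16 kHz silence as a stand-in for generated audio.
reencode_with_audio(info, Path('clip_with_audio.mp4'),
                    torch.zeros(16_000 * 8), sampling_rate=16_000)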
hf_AC/mmaudio/data/data_setup.py ADDED
@@ -0,0 +1,177 @@
1
+ import logging
2
+ import random
3
+
4
+ import numpy as np
5
+ import torch
6
+ from omegaconf import DictConfig
7
+ from torch.utils.data import DataLoader, Dataset
8
+ from torch.utils.data.dataloader import default_collate
9
+ from torch.utils.data.distributed import DistributedSampler
10
+
11
+ from mmaudio.data.eval.audiocaps import AudioCapsData
12
+ from mmaudio.data.eval.video_dataset import MovieGen, VGGSound
13
+ from mmaudio.data.extracted_audio import ExtractedAudio
14
+ from mmaudio.data.extracted_vgg import ExtractedVGG
15
+ from mmaudio.data.mm_dataset import MultiModalDataset
16
+ from mmaudio.utils.dist_utils import local_rank
17
+
18
+ log = logging.getLogger()
19
+
20
+
21
+ # Re-seed randomness every time we start a worker
22
+ def worker_init_fn(worker_id: int):
23
+ worker_seed = torch.initial_seed() % (2**31) + worker_id + local_rank * 1000
24
+ np.random.seed(worker_seed)
25
+ random.seed(worker_seed)
26
+ log.debug(f'Worker {worker_id} re-seeded with seed {worker_seed} in rank {local_rank}')
27
+
28
+
29
+ def load_vgg_data(cfg: DictConfig, data_cfg: DictConfig) -> Dataset:
30
+ dataset = ExtractedVGG(tsv_path=data_cfg.tsv,
31
+ data_dim=cfg.data_dim,
32
+ premade_mmap_dir=data_cfg.memmap_dir)
33
+
34
+ return dataset
35
+
36
+
37
+ def load_audio_data(cfg: DictConfig, data_cfg: DictConfig) -> Dataset:
38
+ dataset = ExtractedAudio(tsv_path=data_cfg.tsv,
39
+ data_dim=cfg.data_dim,
40
+ premade_mmap_dir=data_cfg.memmap_dir)
41
+
42
+ return dataset
43
+
44
+
45
+ def setup_training_datasets(cfg: DictConfig) -> tuple[Dataset, DistributedSampler, DataLoader]:
46
+ if cfg.mini_train:
47
+ vgg = load_vgg_data(cfg, cfg.data.ExtractedVGG_val)
48
+ audiocaps = load_audio_data(cfg, cfg.data.AudioCaps)
49
+ dataset = MultiModalDataset([vgg], [audiocaps])
50
+ if cfg.example_train:
51
+ video = load_vgg_data(cfg, cfg.data.Example_video)
52
+ audio = load_audio_data(cfg, cfg.data.Example_audio)
53
+ dataset = MultiModalDataset([video], [audio])
54
+ else:
55
+ # load the largest one first
56
+ freesound = load_audio_data(cfg, cfg.data.FreeSound)
57
+ vgg = load_vgg_data(cfg, cfg.data.ExtractedVGG)
58
+ audiocaps = load_audio_data(cfg, cfg.data.AudioCaps)
59
+ audioset_sl = load_audio_data(cfg, cfg.data.AudioSetSL)
60
+ # bbcsound = load_audio_data(cfg, cfg.data.BBCSound)
61
+ # clotho = load_audio_data(cfg, cfg.data.Clotho)
62
+ dataset = MultiModalDataset([vgg] * cfg.vgg_oversample_rate,
63
+ [audiocaps, audioset_sl, freesound])
64
+ # dataset = MultiModalDataset([vgg],[])
65
+
66
+ batch_size = cfg.batch_size
67
+ num_workers = cfg.num_workers
68
+ pin_memory = cfg.pin_memory
69
+ sampler, loader = construct_loader(dataset,
70
+ batch_size,
71
+ num_workers,
72
+ shuffle=True,
73
+ drop_last=True,
74
+ pin_memory=pin_memory)
75
+
76
+ return dataset, sampler, loader
77
+
78
+
79
+ def setup_test_datasets(cfg):
80
+ dataset = load_vgg_data(cfg, cfg.data.ExtractedVGG_test)
81
+
82
+ batch_size = cfg.batch_size
83
+ num_workers = cfg.num_workers
84
+ pin_memory = cfg.pin_memory
85
+ sampler, loader = construct_loader(dataset,
86
+ batch_size,
87
+ num_workers,
88
+ shuffle=False,
89
+ drop_last=False,
90
+ pin_memory=pin_memory)
91
+
92
+ return dataset, sampler, loader
93
+
94
+
95
+ def setup_val_datasets(cfg: DictConfig) -> tuple[Dataset, DataLoader, DataLoader]:
96
+ if cfg.example_train:
97
+ dataset = load_vgg_data(cfg, cfg.data.Example_video)
98
+ else:
99
+ dataset = load_vgg_data(cfg, cfg.data.ExtractedVGG_val)
100
+
101
+ val_batch_size = cfg.batch_size
102
+ val_eval_batch_size = cfg.eval_batch_size
103
+ num_workers = cfg.num_workers
104
+ pin_memory = cfg.pin_memory
105
+ _, val_loader = construct_loader(dataset,
106
+ val_batch_size,
107
+ num_workers,
108
+ shuffle=False,
109
+ drop_last=False,
110
+ pin_memory=pin_memory)
111
+ _, eval_loader = construct_loader(dataset,
112
+ val_eval_batch_size,
113
+ num_workers,
114
+ shuffle=False,
115
+ drop_last=False,
116
+ pin_memory=pin_memory)
117
+
118
+ return dataset, val_loader, eval_loader
119
+
120
+
121
+ def setup_eval_dataset(dataset_name: str, cfg: DictConfig) -> tuple[Dataset, DataLoader]:
122
+ if dataset_name.startswith('audiocaps_full'):
123
+ dataset = AudioCapsData(cfg.eval_data.AudioCaps_full.audio_path,
124
+ cfg.eval_data.AudioCaps_full.csv_path)
125
+ elif dataset_name.startswith('audiocaps'):
126
+ dataset = AudioCapsData(cfg.eval_data.AudioCaps.audio_path,
127
+ cfg.eval_data.AudioCaps.csv_path)
128
+ elif dataset_name.startswith('moviegen'):
129
+ dataset = MovieGen(cfg.eval_data.MovieGen.video_path,
130
+ cfg.eval_data.MovieGen.jsonl_path,
131
+ duration_sec=cfg.duration_s)
132
+ elif dataset_name.startswith('vggsound'):
133
+ dataset = VGGSound(cfg.eval_data.VGGSound.video_path,
134
+ cfg.eval_data.VGGSound.csv_path,
135
+ duration_sec=cfg.duration_s)
136
+ else:
137
+ raise ValueError(f'Invalid dataset name: {dataset_name}')
138
+
139
+ batch_size = cfg.batch_size
140
+ num_workers = cfg.num_workers
141
+ pin_memory = cfg.pin_memory
142
+ _, loader = construct_loader(dataset,
143
+ batch_size,
144
+ num_workers,
145
+ shuffle=False,
146
+ drop_last=False,
147
+ pin_memory=pin_memory,
148
+ error_avoidance=True)
149
+ return dataset, loader
150
+
151
+
152
+ def error_avoidance_collate(batch):
153
+ batch = list(filter(lambda x: x is not None, batch))
154
+ if len(batch) == 0:
155
+ return None
156
+ return default_collate(batch)
157
+
158
+
159
+ def construct_loader(dataset: Dataset,
160
+ batch_size: int,
161
+ num_workers: int,
162
+ *,
163
+ shuffle: bool = True,
164
+ drop_last: bool = True,
165
+ pin_memory: bool = False,
166
+ error_avoidance: bool = False) -> tuple[DistributedSampler, DataLoader]:
167
+ train_sampler = DistributedSampler(dataset, rank=local_rank, shuffle=shuffle)
168
+ train_loader = DataLoader(dataset,
169
+ batch_size,
170
+ sampler=train_sampler,
171
+ num_workers=num_workers,
172
+ worker_init_fn=worker_init_fn,
173
+ drop_last=drop_last,
174
+ persistent_workers=num_workers > 0,
175
+ pin_memory=pin_memory,
176
+ collate_fn=error_avoidance_collate if error_avoidance else None)
177
+ return train_sampler, train_loader
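The `error_avoidance` path exists because the evaluation video datasets return `None` for clips that fail to decode; `error_avoidance_collate` drops those entries before default collation. A small illustration (the sample dicts are stand-ins, not real dataset output):

import torch

batch = [
    {'name': 'a', 'clip_video': torch.zeros(3)},
    None,  # a sample whose video failed to decode
    {'name': 'b', 'clip_video': torch.ones(3)},
]

collated = error_avoidance_collate(batch)
assert collated['clip_video'].shape == (2, 3)  # the None entry was dropped

# If every sample failed, the whole batch collapses to None and callers must skip it.
assert error_avoidance_collate([None, None]) is None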
hf_AC/mmaudio/data/eval/__init__.py ADDED
File without changes
hf_AC/mmaudio/data/eval/audiocaps.py ADDED
@@ -0,0 +1,39 @@
1
+ import logging
2
+ import os
3
+ from collections import defaultdict
4
+ from pathlib import Path
5
+ from typing import Union
6
+
7
+ import pandas as pd
8
+ import torch
9
+ from torch.utils.data.dataset import Dataset
10
+
11
+ log = logging.getLogger()
12
+
13
+
14
+ class AudioCapsData(Dataset):
15
+
16
+ def __init__(self, audio_path: Union[str, Path], csv_path: Union[str, Path]):
17
+ df = pd.read_csv(csv_path).to_dict(orient='records')
18
+
19
+ audio_files = sorted(os.listdir(audio_path))
20
+ audio_files = set(
21
+ [Path(f).stem for f in audio_files if f.endswith('.wav') or f.endswith('.flac')])
22
+
23
+ self.data = []
24
+ for row in df:
25
+ self.data.append({
26
+ 'name': row['name'],
27
+ 'caption': row['caption'],
28
+ })
29
+
30
+ self.audio_path = Path(audio_path)
31
+ self.csv_path = Path(csv_path)
32
+
33
+ log.info(f'Found {len(self.data)} matching audio files in {self.audio_path}')
34
+
35
+ def __getitem__(self, idx: int) -> torch.Tensor:
36
+ return self.data[idx]
37
+
38
+ def __len__(self):
39
+ return len(self.data)
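`AudioCapsData` only needs a directory of `.wav`/`.flac` files and a CSV with `name` and `caption` columns; a minimal sketch with made-up paths:

import os

import pandas as pd

os.makedirs('/tmp/audiocaps_audio', exist_ok=True)  # placeholder audio directory
pd.DataFrame([
    {'name': 'Y-abc123', 'caption': 'a dog barks twice'},
    {'name': 'Y-def456', 'caption': 'rain falls on a tin roof'},
]).to_csv('/tmp/audiocaps_demo.csv', index=False)

ds = AudioCapsData('/tmp/audiocaps_audio', '/tmp/audiocaps_demo.csv')
print(len(ds), ds[0]['caption'])  # -> 2 a dog barks twice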
hf_AC/mmaudio/data/eval/moviegen.py ADDED
@@ -0,0 +1,131 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ from pathlib import Path
5
+ from typing import Union
6
+
7
+ import torch
8
+ from torch.utils.data.dataset import Dataset
9
+ from torchvision.transforms import v2
10
+ from torio.io import StreamingMediaDecoder
11
+
12
+ from mmaudio.utils.dist_utils import local_rank
13
+
14
+ log = logging.getLogger()
15
+
16
+ _CLIP_SIZE = 384
17
+ _CLIP_FPS = 8.0
18
+
19
+ _SYNC_SIZE = 224
20
+ _SYNC_FPS = 25.0
21
+
22
+
23
+ class MovieGenData(Dataset):
24
+
25
+ def __init__(
26
+ self,
27
+ video_root: Union[str, Path],
28
+ sync_root: Union[str, Path],
29
+ jsonl_root: Union[str, Path],
30
+ *,
31
+ duration_sec: float = 10.0,
32
+ read_clip: bool = True,
33
+ ):
34
+ self.video_root = Path(video_root)
35
+ self.sync_root = Path(sync_root)
36
+ self.jsonl_root = Path(jsonl_root)
37
+ self.read_clip = read_clip
38
+
39
+ videos = sorted(os.listdir(self.video_root))
40
+ videos = [v[:-4] for v in videos] # remove extensions
41
+ self.captions = {}
42
+
43
+ for v in videos:
44
+ with open(self.jsonl_root / (v + '.jsonl')) as f:
45
+ data = json.load(f)
46
+ self.captions[v] = data['audio_prompt']
47
+
48
+ if local_rank == 0:
49
+ log.info(f'{len(videos)} videos found in {video_root}')
50
+
51
+ self.duration_sec = duration_sec
52
+
53
+ self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
54
+ self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
55
+
56
+ self.clip_augment = v2.Compose([
57
+ v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
58
+ v2.ToImage(),
59
+ v2.ToDtype(torch.float32, scale=True),
60
+ ])
61
+
62
+ self.sync_augment = v2.Compose([
63
+ v2.Resize((_SYNC_SIZE, _SYNC_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
64
+ v2.CenterCrop(_SYNC_SIZE),
65
+ v2.ToImage(),
66
+ v2.ToDtype(torch.float32, scale=True),
67
+ v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
68
+ ])
69
+
70
+ self.videos = videos
71
+
72
+ def sample(self, idx: int) -> dict[str, torch.Tensor]:
73
+ video_id = self.videos[idx]
74
+ caption = self.captions[video_id]
75
+
76
+ reader = StreamingMediaDecoder(self.video_root / (video_id + '.mp4'))
77
+ reader.add_basic_video_stream(
78
+ frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
79
+ frame_rate=_CLIP_FPS,
80
+ format='rgb24',
81
+ )
82
+ reader.add_basic_video_stream(
83
+ frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
84
+ frame_rate=_SYNC_FPS,
85
+ format='rgb24',
86
+ )
87
+
88
+ reader.fill_buffer()
89
+ data_chunk = reader.pop_chunks()
90
+
91
+ clip_chunk = data_chunk[0]
92
+ sync_chunk = data_chunk[1]
93
+ if clip_chunk is None:
94
+ raise RuntimeError(f'CLIP video returned None {video_id}')
95
+ if clip_chunk.shape[0] < self.clip_expected_length:
96
+ raise RuntimeError(f'CLIP video too short {video_id}')
97
+
98
+ if sync_chunk is None:
99
+ raise RuntimeError(f'Sync video returned None {video_id}')
100
+ if sync_chunk.shape[0] < self.sync_expected_length:
101
+ raise RuntimeError(f'Sync video too short {video_id}')
102
+
103
+ # truncate the video
104
+ clip_chunk = clip_chunk[:self.clip_expected_length]
105
+ if clip_chunk.shape[0] != self.clip_expected_length:
106
+ raise RuntimeError(f'CLIP video wrong length {video_id}, '
107
+ f'expected {self.clip_expected_length}, '
108
+ f'got {clip_chunk.shape[0]}')
109
+ clip_chunk = self.clip_augment(clip_chunk)
110
+
111
+ sync_chunk = sync_chunk[:self.sync_expected_length]
112
+ if sync_chunk.shape[0] != self.sync_expected_length:
113
+ raise RuntimeError(f'Sync video wrong length {video_id}, '
114
+ f'expected {self.sync_expected_length}, '
115
+ f'got {sync_chunk.shape[0]}')
116
+ sync_chunk = self.sync_augment(sync_chunk)
117
+
118
+ data = {
119
+ 'name': video_id,
120
+ 'caption': caption,
121
+ 'clip_video': clip_chunk,
122
+ 'sync_video': sync_chunk,
123
+ }
124
+
125
+ return data
126
+
127
+ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
128
+ return self.sample(idx)
129
+
130
+ def __len__(self):
131
+ return len(self.captions)
hf_AC/mmaudio/data/eval/video_dataset.py ADDED
@@ -0,0 +1,231 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ from pathlib import Path
5
+ from typing import Union
6
+
7
+ import pandas as pd
8
+ import torch
9
+ from torch.utils.data.dataset import Dataset
10
+ from torchvision.transforms import v2
11
+ from torio.io import StreamingMediaDecoder
12
+ import torchaudio
13
+ from mmaudio.utils.dist_utils import local_rank
14
+ import random
15
+ log = logging.getLogger()
16
+
17
+ _CLIP_SIZE = 384
18
+ _CLIP_FPS = 8.0
19
+
20
+ _SYNC_SIZE = 224
21
+ _SYNC_FPS = 25.0
22
+
23
+
24
+ class VideoDataset(Dataset):
25
+
26
+ def __init__(
27
+ self,
28
+ video_root: Union[str, Path],
29
+ *,
30
+ duration_sec: float = 8.0,
31
+ ):
32
+ self.video_root = Path(video_root)
33
+
34
+ self.duration_sec = duration_sec
35
+ self.sample_rate = 44100
36
+ self.resampler = {}
37
+ self.expected_audio_length = 89088
38
+ self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
39
+ self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
40
+
41
+ self.clip_transform = v2.Compose([
42
+ v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
43
+ v2.ToImage(),
44
+ v2.ToDtype(torch.float32, scale=True),
45
+ ])
46
+
47
+ self.sync_transform = v2.Compose([
48
+ v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
49
+ v2.CenterCrop(_SYNC_SIZE),
50
+ v2.ToImage(),
51
+ v2.ToDtype(torch.float32, scale=True),
52
+ v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
53
+ ])
54
+
55
+ # to be implemented by subclasses
56
+ self.captions = {}
57
+ self.videos = sorted(list(self.captions.keys()))
58
+
59
+ def sample(self, idx: int) -> dict[str, torch.Tensor]:
60
+ video_id = self.videos[idx]
61
+ caption = self.captions[video_id]
62
+
63
+ reader = StreamingMediaDecoder(self.video_root / (video_id + '.mp4'))
64
+ reader.add_basic_video_stream(
65
+ frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
66
+ frame_rate=_CLIP_FPS,
67
+ format='rgb24',
68
+ )
69
+ reader.add_basic_video_stream(
70
+ frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
71
+ frame_rate=_SYNC_FPS,
72
+ format='rgb24',
73
+ )
74
+ reader.add_basic_audio_stream(frames_per_chunk=2**30, )
75
+
76
+ reader.fill_buffer()
77
+ data_chunk = reader.pop_chunks()
78
+
79
+ clip_chunk = data_chunk[0]
80
+ sync_chunk = data_chunk[1]
81
+ audio_chunk = data_chunk[2]
82
+
83
+ if clip_chunk is None:
84
+ raise RuntimeError(f'CLIP video returned None {video_id}')
85
+ if clip_chunk.shape[0] < self.clip_expected_length:
86
+ raise RuntimeError(
87
+ f'CLIP video too short {video_id}, expected {self.clip_expected_length}, got {clip_chunk.shape[0]}'
88
+ )
89
+
90
+ if sync_chunk is None:
91
+ raise RuntimeError(f'Sync video returned None {video_id}')
92
+ if sync_chunk.shape[0] < self.sync_expected_length:
93
+ raise RuntimeError(
94
+ f'Sync video too short {video_id}, expected {self.sync_expected_length}, got {sync_chunk.shape[0]}'
95
+ )
96
+
97
+ # process audio
98
+ sample_rate = int(reader.get_out_stream_info(2).sample_rate)
99
+ audio_chunk = audio_chunk.transpose(0, 1)
100
+ audio_chunk = audio_chunk.mean(dim=0) # mono
101
+ abs_max = audio_chunk.abs().max()
102
+ audio_chunk = audio_chunk / abs_max.clamp(min=1e-6) * 0.95  # guard against silent clips
103
+
104
+ # resample
105
+ if sample_rate == self.sample_rate:
106
+ audio_chunk = audio_chunk
107
+ else:
108
+ if sample_rate not in self.resampler:
109
+ # https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
110
+ self.resampler[sample_rate] = torchaudio.transforms.Resample(
111
+ sample_rate,
112
+ self.sample_rate,
113
+ lowpass_filter_width=64,
114
+ rolloff=0.9475937167399596,
115
+ resampling_method='sinc_interp_kaiser',
116
+ beta=14.769656459379492,
117
+ )
118
+ audio_chunk = self.resampler[sample_rate](audio_chunk)
119
+
120
+ if audio_chunk.shape[0] < self.expected_audio_length:
121
+ raise RuntimeError(f'Audio too short {video_id}')
122
+ # start_index = random.randint(0, audio_chunk.shape[0] - self.expected_audio_length)
123
+ timbre_sample = audio_chunk[audio_chunk.shape[0]-self.expected_audio_length:]
124
+
125
+ # truncate the video
126
+ clip_chunk = clip_chunk[:self.clip_expected_length]
127
+ if clip_chunk.shape[0] != self.clip_expected_length:
128
+ raise RuntimeError(f'CLIP video wrong length {video_id}, '
129
+ f'expected {self.clip_expected_length}, '
130
+ f'got {clip_chunk.shape[0]}')
131
+ clip_chunk = self.clip_transform(clip_chunk)
132
+
133
+ sync_chunk = sync_chunk[:self.sync_expected_length]
134
+ if sync_chunk.shape[0] != self.sync_expected_length:
135
+ raise RuntimeError(f'Sync video wrong length {video_id}, '
136
+ f'expected {self.sync_expected_length}, '
137
+ f'got {sync_chunk.shape[0]}')
138
+ sync_chunk = self.sync_transform(sync_chunk)
139
+
140
+ data = {
141
+ 'name': video_id,
142
+ 'caption': caption,
143
+ 'clip_video': clip_chunk,
144
+ 'sync_video': sync_chunk,
145
+ 'audio': timbre_sample
146
+ }
147
+
148
+ return data
149
+
150
+ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
151
+ try:
152
+ return self.sample(idx)
153
+ except Exception as e:
154
+ log.error(f'Error loading video {self.videos[idx]}: {e}')
155
+ return None
156
+
157
+ def __len__(self):
158
+ return len(self.captions)
159
+
160
+
161
+ class VGGSound(VideoDataset):
162
+
163
+ def __init__(
164
+ self,
165
+ video_root: Union[str, Path],
166
+ csv_path: Union[str, Path],
167
+ *,
168
+ duration_sec: float = 8.0,
169
+ ):
170
+ super().__init__(video_root, duration_sec=duration_sec)
171
+ self.video_root = Path(video_root)
172
+ self.csv_path = Path(csv_path)
173
+
174
+ videos = sorted(os.listdir(self.video_root))
175
+ if local_rank == 0:
176
+ log.info(f'{len(videos)} videos found in {video_root}')
177
+ self.captions = {}
178
+
179
+ df = pd.read_csv(csv_path, header=None, names=['id', 'sec', 'caption',
180
+ 'split']).to_dict(orient='records')
181
+
182
+ videos_no_found = []
183
+ for row in df:
184
+ if row['split'] == 'test':
185
+ start_sec = int(row['sec'])
186
+ video_id = str(row['id'])
187
+ # this is how our videos are named
188
+ video_name = f'{video_id}_{start_sec:06d}'
189
+ if video_name + '.mp4' not in videos:
190
+ videos_no_found.append(video_name)
191
+ continue
192
+
193
+ self.captions[video_name] = row['caption']
194
+
195
+ if local_rank == 0:
196
+ log.info(f'{len(videos)} videos found in {video_root}')
197
+ log.info(f'{len(self.captions)} usable videos found')
198
+ if videos_no_found:
199
+ log.info(f'{len(videos_no_found)} found in {csv_path} but not in {video_root}')
200
+ log.info(
201
+ 'A small amount is expected, as not all videos are still available on YouTube')
202
+
203
+ self.videos = sorted(list(self.captions.keys()))
204
+
205
+
206
+ class MovieGen(VideoDataset):
207
+
208
+ def __init__(
209
+ self,
210
+ video_root: Union[str, Path],
211
+ jsonl_root: Union[str, Path],
212
+ *,
213
+ duration_sec: float = 10.0,
214
+ ):
215
+ super().__init__(video_root, duration_sec=duration_sec)
216
+ self.video_root = Path(video_root)
217
+ self.jsonl_root = Path(jsonl_root)
218
+
219
+ videos = sorted(os.listdir(self.video_root))
220
+ videos = [v[:-4] for v in videos] # remove extensions
221
+ self.captions = {}
222
+
223
+ for v in videos:
224
+ with open(self.jsonl_root / (v + '.jsonl')) as f:
225
+ data = json.load(f)
226
+ self.captions[v] = data['audio_prompt']
227
+
228
+ if local_rank == 0:
229
+ log.info(f'{len(videos)} videos found in {video_root}')
230
+
231
+ self.videos = videos
hf_AC/mmaudio/data/extracted_audio.py ADDED
@@ -0,0 +1,97 @@
1
+ import logging
2
+ from pathlib import Path
3
+ from typing import Union
4
+
5
+ import pandas as pd
6
+ import torch
7
+ from tensordict import TensorDict
8
+ from torch.utils.data.dataset import Dataset
9
+
10
+ from mmaudio.utils.dist_utils import local_rank
11
+
12
+ log = logging.getLogger()
13
+
14
+
15
+ class ExtractedAudio(Dataset):
16
+
17
+ def __init__(
18
+ self,
19
+ tsv_path: Union[str, Path],
20
+ *,
21
+ premade_mmap_dir: Union[str, Path],
22
+ data_dim: dict[str, int],
23
+ ):
24
+ super().__init__()
25
+
26
+ self.data_dim = data_dim
27
+ self.df_list = pd.read_csv(tsv_path, sep='\t').to_dict('records')
28
+ self.ids = [str(d['id']) for d in self.df_list]
29
+
30
+ log.info(f'Loading precomputed mmap from {premade_mmap_dir}')
31
+ # load precomputed memory mapped tensors
32
+ premade_mmap_dir = Path(premade_mmap_dir)
33
+ td = TensorDict.load_memmap(premade_mmap_dir)
34
+ log.info(f'Loaded precomputed mmap from {premade_mmap_dir}')
35
+ self.mean = td['mean']
36
+ self.std = td['std']
37
+ self.text_features = td['text_features']
38
+ rng = torch.Generator(device=self.text_features.device)
39
+ rng.manual_seed(42)
40
+ randn = torch.empty_like(td['audio_feature_mean']).normal_(generator=rng)
41
+ self.audio_features = td['audio_feature_mean'] + td['audio_feature_std'] * randn
42
+
43
+ log.info(f'Loaded {len(self)} samples from {premade_mmap_dir}.')
44
+ log.info(f'Loaded mean: {self.mean.shape}.')
45
+ log.info(f'Loaded std: {self.std.shape}.')
46
+ log.info(f'Loaded text features: {self.text_features.shape}.')
47
+ log.info(f'Loaded audio features: {self.audio_features.shape}.')
48
+
49
+ assert self.mean.shape[1] == self.data_dim['latent_seq_len'], \
50
+ f'{self.mean.shape[1]} != {self.data_dim["latent_seq_len"]}'
51
+ assert self.std.shape[1] == self.data_dim['latent_seq_len'], \
52
+ f'{self.std.shape[1]} != {self.data_dim["latent_seq_len"]}'
53
+
54
+ assert self.text_features.shape[1] == self.data_dim['text_seq_len'], \
55
+ f'{self.text_features.shape[1]} != {self.data_dim["text_seq_len"]}'
56
+ assert self.text_features.shape[-1] == self.data_dim['text_dim'], \
57
+ f'{self.text_features.shape[-1]} != {self.data_dim["text_dim"]}'
58
+
59
+ self.fake_clip_features = torch.zeros(self.data_dim['clip_seq_len'],
60
+ self.data_dim['clip_dim'])
61
+ self.fake_sync_features = torch.zeros(self.data_dim['sync_seq_len'],
62
+ self.data_dim['sync_dim'])
63
+ self.video_exist = torch.tensor(0, dtype=torch.bool)
64
+ self.text_exist = torch.tensor(1, dtype=torch.bool)
65
+ self.audio_exist = torch.tensor(1, dtype=torch.bool)
66
+
67
+ def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]:
68
+ latents = self.mean
69
+ return latents.mean(dim=(0, 1)), latents.std(dim=(0, 1))
70
+
71
+ def get_memory_mapped_tensor(self) -> TensorDict:
72
+ td = TensorDict({
73
+ 'mean': self.mean,
74
+ 'std': self.std,
75
+ 'text_features': self.text_features,
76
+ 'audio_features': self.audio_features,
77
+ })
78
+ return td
79
+
80
+ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
81
+ data = {
82
+ 'id': str(self.df_list[idx]['id']),
83
+ 'a_mean': self.mean[idx],
84
+ 'a_std': self.std[idx],
85
+ 'clip_features': self.fake_clip_features,
86
+ 'sync_features': self.fake_sync_features,
87
+ 'text_features': self.text_features[idx],
88
+ 'audio_features': self.audio_features[idx],
89
+ 'caption': self.df_list[idx]['caption'],
90
+ 'video_exist': self.video_exist,
91
+ 'text_exist': self.text_exist,
92
+ 'audio_exist': self.audio_exist,
93
+ }
94
+ return data
95
+
96
+ def __len__(self):
97
+ return len(self.ids)
hf_AC/mmaudio/data/extracted_vgg.py ADDED
@@ -0,0 +1,109 @@
1
+ import logging
2
+ from pathlib import Path
3
+ from typing import Union
4
+
5
+ import pandas as pd
6
+ import torch
7
+ from tensordict import TensorDict
8
+ from torch.utils.data.dataset import Dataset
9
+
10
+ from mmaudio.utils.dist_utils import local_rank
11
+
12
+ log = logging.getLogger()
13
+
14
+
15
+ class ExtractedVGG(Dataset):
16
+
17
+ def __init__(
18
+ self,
19
+ tsv_path: Union[str, Path],
20
+ *,
21
+ premade_mmap_dir: Union[str, Path],
22
+ data_dim: dict[str, int],
23
+ ):
24
+ super().__init__()
25
+
26
+ self.data_dim = data_dim
27
+ self.df_list = pd.read_csv(tsv_path, sep='\t').to_dict('records')
28
+ self.ids = [d['id'] for d in self.df_list]
29
+
30
+ log.info(f'Loading precomputed mmap from {premade_mmap_dir}')
31
+ # load precomputed memory mapped tensors
32
+ premade_mmap_dir = Path(premade_mmap_dir)
33
+ td = TensorDict.load_memmap(premade_mmap_dir)
34
+ log.info(f'Loaded precomputed mmap from {premade_mmap_dir}')
35
+ self.mean = td['mean']
36
+ self.std = td['std']
37
+ self.clip_features = td['clip_features']
38
+ self.sync_features = td['sync_features']
39
+ self.text_features = td['text_features']
40
+ rng = torch.Generator(device=self.clip_features.device)
41
+ rng.manual_seed(14159265)
42
+ randn = torch.empty_like(td['audio_feature_mean']).normal_(generator=rng)
43
+ self.audio_features = td['audio_feature_mean'] + td['audio_feature_std'] * randn
44
+
45
+ if local_rank == 0:
46
+ log.info(f'Loaded {len(self)} samples.')
47
+ log.info(f'Loaded mean: {self.mean.shape}.')
48
+ log.info(f'Loaded std: {self.std.shape}.')
49
+ log.info(f'Loaded clip_features: {self.clip_features.shape}.')
50
+ log.info(f'Loaded sync_features: {self.sync_features.shape}.')
51
+ log.info(f'Loaded text_features: {self.text_features.shape}.')
52
+
53
+ assert self.mean.shape[1] == self.data_dim['latent_seq_len'], \
54
+ f'{self.mean.shape[1]} != {self.data_dim["latent_seq_len"]}'
55
+ assert self.std.shape[1] == self.data_dim['latent_seq_len'], \
56
+ f'{self.std.shape[1]} != {self.data_dim["latent_seq_len"]}'
57
+
58
+ assert self.clip_features.shape[1] == self.data_dim['clip_seq_len'], \
59
+ f'{self.clip_features.shape[1]} != {self.data_dim["clip_seq_len"]}'
60
+ assert self.sync_features.shape[1] == self.data_dim['sync_seq_len'], \
61
+ f'{self.sync_features.shape[1]} != {self.data_dim["sync_seq_len"]}'
62
+ assert self.text_features.shape[1] == self.data_dim['text_seq_len'], \
63
+ f'{self.text_features.shape[1]} != {self.data_dim["text_seq_len"]}'
64
+
65
+ assert self.clip_features.shape[-1] == self.data_dim['clip_dim'], \
66
+ f'{self.clip_features.shape[-1]} != {self.data_dim["clip_dim"]}'
67
+ assert self.sync_features.shape[-1] == self.data_dim['sync_dim'], \
68
+ f'{self.sync_features.shape[-1]} != {self.data_dim["sync_dim"]}'
69
+ assert self.text_features.shape[-1] == self.data_dim['text_dim'], \
70
+ f'{self.text_features.shape[-1]} != {self.data_dim["text_dim"]}'
71
+
72
+ self.video_exist = torch.tensor(1, dtype=torch.bool)
73
+ self.text_exist = torch.tensor(1, dtype=torch.bool)
74
+ self.audio_exist = torch.tensor(1, dtype=torch.bool)
75
+
76
+ def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]:
77
+ latents = self.mean
78
+ return latents.mean(dim=(0, 1)), latents.std(dim=(0, 1))
79
+
80
+ def get_memory_mapped_tensor(self) -> TensorDict:
81
+ td = TensorDict({
82
+ 'mean': self.mean,
83
+ 'std': self.std,
84
+ 'clip_features': self.clip_features,
85
+ 'sync_features': self.sync_features,
86
+ 'text_features': self.text_features,
87
+ 'audio_features': self.audio_features,
88
+ })
89
+ return td
90
+
91
+ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
92
+ data = {
93
+ 'id': self.df_list[idx]['id'],
94
+ 'a_mean': self.mean[idx],
95
+ 'a_std': self.std[idx],
96
+ 'clip_features': self.clip_features[idx],
97
+ 'sync_features': self.sync_features[idx],
98
+ 'text_features': self.text_features[idx],
99
+ 'audio_features': self.audio_features[idx],
100
+ 'caption': self.df_list[idx]['label'],
101
+ 'video_exist': self.video_exist,
102
+ 'text_exist': self.text_exist,
103
+ 'audio_exist': self.audio_exist,
104
+ }
105
+
106
+ return data
107
+
108
+ def __len__(self):
109
+ return len(self.ids)
hf_AC/mmaudio/data/extraction/__init__.py ADDED
File without changes
hf_AC/mmaudio/data/extraction/vgg_sound.py ADDED
@@ -0,0 +1,208 @@
1
+ import logging
2
+ import os
3
+ from pathlib import Path
4
+ from typing import Optional, Union
5
+
6
+ import pandas as pd
7
+ import torch
8
+ import torchaudio
9
+ from torch.utils.data.dataset import Dataset
10
+ from torchvision.transforms import v2
11
+ from torio.io import StreamingMediaDecoder
12
+
13
+ from mmaudio.utils.dist_utils import local_rank
14
+ import random
15
+ log = logging.getLogger()
16
+
17
+ _CLIP_SIZE = 384
18
+ _CLIP_FPS = 8.0
19
+
20
+ _SYNC_SIZE = 224
21
+ _SYNC_FPS = 25.0
22
+
23
+
24
+ class VGGSound(Dataset):
25
+
26
+ def __init__(
27
+ self,
28
+ root: Union[str, Path],
29
+ *,
30
+ tsv_path: Union[str, Path] = 'sets/vgg3-train.tsv',
31
+ sample_rate: int = 16_000,
32
+ duration_sec: float = 8.0,
33
+ audio_samples: Optional[int] = None,
34
+ normalize_audio: bool = False,
35
+ exclude_path: Optional[Union[str, Path]] = None,
36
+ ):
37
+ self.root = Path(root)
38
+ self.normalize_audio = normalize_audio
39
+ if audio_samples is None:
40
+ self.audio_samples = int(sample_rate * duration_sec)
41
+ else:
42
+ self.audio_samples = audio_samples
43
+ effective_duration = audio_samples / sample_rate
44
+ # make sure the duration is close enough, within 15ms
45
+ assert abs(effective_duration - duration_sec) < 0.015, \
46
+ f'audio_samples {audio_samples} does not match duration_sec {duration_sec}'
47
+
48
+ videos = sorted(os.listdir(self.root))
49
+ videos = set([Path(v).stem for v in videos]) # remove extensions
50
+ self.labels = {}
51
+ self.videos = []
52
+ missing_videos = []
53
+ excluded_videos = []
54
+ self.exclude_list = []
55
+ if exclude_path is not None:
56
+
57
+ for t in sorted(os.listdir(exclude_path)):
58
+ data = torch.load(exclude_path / t, weights_only=True)
59
+ self.exclude_list.extend(data['id'])
60
+ # read the tsv for subset information
61
+ df_list = pd.read_csv(tsv_path, sep='\t', dtype={'id': str}).to_dict('records')
62
+ for record in df_list:
63
+ id = record['id']
64
+ label = record['label']
65
+ if id in self.exclude_list:
66
+ excluded_videos.append(id)
67
+ continue
68
+ if id in videos:
69
+ self.labels[id] = label
70
+ self.videos.append(id)
71
+ else:
72
+ missing_videos.append(id)
73
+
74
+ if local_rank == 0:
75
+ log.info(f'{len(excluded_videos)} videos excluded as per exclude list')
76
+ log.info(f'{len(videos)} videos found in {root}')
77
+ log.info(f'{len(self.videos)} videos found in {tsv_path}')
78
+ log.info(f'{len(missing_videos)} videos missing in {root}')
79
+
80
+ self.sample_rate = sample_rate
81
+ self.duration_sec = duration_sec
82
+
83
+ self.expected_audio_length = audio_samples
84
+ self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
85
+ self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
86
+
87
+ self.clip_transform = v2.Compose([
88
+ v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
89
+ v2.ToImage(),
90
+ v2.ToDtype(torch.float32, scale=True),
91
+ ])
92
+
93
+ self.sync_transform = v2.Compose([
94
+ v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
95
+ v2.CenterCrop(_SYNC_SIZE),
96
+ v2.ToImage(),
97
+ v2.ToDtype(torch.float32, scale=True),
98
+ v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
99
+ ])
100
+ self.num_timbre_sample = 89088 if sample_rate == 44100 else 32768
101
+ self.resampler = {}
102
+
103
+ def sample(self, idx: int) -> dict[str, torch.Tensor]:
104
+ video_id = self.videos[idx]
105
+ # if video_id in self.exclude_list:
106
+ # raise RuntimeError(f'Video {video_id} is in the exclude list')
107
+ label = self.labels[video_id]
108
+
109
+ reader = StreamingMediaDecoder(self.root / (video_id + '.mp4'))
110
+ reader.add_basic_video_stream(
111
+ frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
112
+ frame_rate=_CLIP_FPS,
113
+ format='rgb24',
114
+ )
115
+ reader.add_basic_video_stream(
116
+ frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
117
+ frame_rate=_SYNC_FPS,
118
+ format='rgb24',
119
+ )
120
+ reader.add_basic_audio_stream(frames_per_chunk=2**30, )
121
+
122
+ reader.fill_buffer()
123
+ data_chunk = reader.pop_chunks()
124
+
125
+ clip_chunk = data_chunk[0]
126
+ sync_chunk = data_chunk[1]
127
+ audio_chunk = data_chunk[2]
128
+
129
+ if clip_chunk is None:
130
+ raise RuntimeError(f'CLIP video returned None {video_id}')
131
+ if clip_chunk.shape[0] < self.clip_expected_length:
132
+ raise RuntimeError(
133
+ f'CLIP video too short {video_id}, expected {self.clip_expected_length}, got {clip_chunk.shape[0]}'
134
+ )
135
+
136
+ if sync_chunk is None:
137
+ raise RuntimeError(f'Sync video returned None {video_id}')
138
+ if sync_chunk.shape[0] < self.sync_expected_length:
139
+ raise RuntimeError(
140
+ f'Sync video too short {video_id}, expected {self.sync_expected_length}, got {sync_chunk.shape[0]}'
141
+ )
142
+
143
+ # process audio
144
+ sample_rate = int(reader.get_out_stream_info(2).sample_rate)
145
+ audio_chunk = audio_chunk.transpose(0, 1)
146
+ audio_chunk = audio_chunk.mean(dim=0) # mono
147
+ if self.normalize_audio:
148
+ abs_max = audio_chunk.abs().max()
149
+ audio_chunk = audio_chunk / abs_max * 0.95
150
+ if abs_max <= 1e-6:
151
+ raise RuntimeError(f'Audio is silent {video_id}')
152
+
153
+ # resample
154
+ if sample_rate == self.sample_rate:
155
+ audio_chunk = audio_chunk
156
+ else:
157
+ if sample_rate not in self.resampler:
158
+ # https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
159
+ self.resampler[sample_rate] = torchaudio.transforms.Resample(
160
+ sample_rate,
161
+ self.sample_rate,
162
+ lowpass_filter_width=64,
163
+ rolloff=0.9475937167399596,
164
+ resampling_method='sinc_interp_kaiser',
165
+ beta=14.769656459379492,
166
+ )
167
+ audio_chunk = self.resampler[sample_rate](audio_chunk)
168
+
169
+ if audio_chunk.shape[0] < self.expected_audio_length:
170
+ raise RuntimeError(f'Audio too short {video_id}')
171
+ audio_chunk = audio_chunk[:self.expected_audio_length]
172
+ timbre_sample = audio_chunk[audio_chunk.shape[0]-self.num_timbre_sample:]
173
+
174
+ # truncate the video
175
+ clip_chunk = clip_chunk[:self.clip_expected_length]
176
+ if clip_chunk.shape[0] != self.clip_expected_length:
177
+ raise RuntimeError(f'CLIP video wrong length {video_id}, '
178
+ f'expected {self.clip_expected_length}, '
179
+ f'got {clip_chunk.shape[0]}')
180
+ clip_chunk = self.clip_transform(clip_chunk)
181
+
182
+ sync_chunk = sync_chunk[:self.sync_expected_length]
183
+ if sync_chunk.shape[0] != self.sync_expected_length:
184
+ raise RuntimeError(f'Sync video wrong length {video_id}, '
185
+ f'expected {self.sync_expected_length}, '
186
+ f'got {sync_chunk.shape[0]}')
187
+ sync_chunk = self.sync_transform(sync_chunk)
188
+
189
+ data = {
190
+ 'id': video_id,
191
+ 'caption': label,
192
+ 'audio': audio_chunk,
193
+ 'clip_video': clip_chunk,
194
+ 'sync_video': sync_chunk,
195
+ 'timbre_sample': timbre_sample
196
+ }
197
+
198
+ return data
199
+
200
+ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
201
+ try:
202
+ return self.sample(idx)
203
+ except Exception as e:
204
+ log.error(f'Error loading video {self.videos[idx]}: {e}')
205
+ return None
206
+
207
+ def __len__(self):
208
+ return len(self.labels)
hf_AC/mmaudio/data/extraction/wav_dataset.py ADDED
@@ -0,0 +1,135 @@
1
+ import logging
2
+ import os
3
+ from pathlib import Path
4
+ from typing import Union
5
+
6
+ import open_clip
7
+ import pandas as pd
8
+ import torch
9
+ import torchaudio
10
+ from torch.utils.data.dataset import Dataset
11
+ import random
12
+ log = logging.getLogger()
13
+
14
+
15
+ class WavTextClipsDataset(Dataset):
16
+
17
+ def __init__(
18
+ self,
19
+ root: Union[str, Path],
20
+ *,
21
+ captions_tsv: Union[str, Path],
22
+ clips_tsv: Union[str, Path],
23
+ sample_rate: int,
24
+ num_samples: int,
25
+ normalize_audio: bool = False,
26
+ reject_silent: bool = False,
27
+ tokenizer_id: str = 'ViT-H-14-378-quickgelu',
28
+ ):
29
+ self.root = Path(root)
30
+ self.sample_rate = sample_rate
31
+ self.num_samples = num_samples
32
+ self.normalize_audio = normalize_audio
33
+ self.reject_silent = reject_silent
34
+ self.tokenizer = open_clip.get_tokenizer(tokenizer_id)
35
+ self.num_timbre_sample = 89088 if sample_rate == 44100 else 32768
36
+
37
+ audios = sorted(os.listdir(self.root))
38
+ audios = set([
39
+ Path(audio).stem for audio in audios
40
+ if audio.endswith('.wav') or audio.endswith('.flac')
41
+ ])
42
+ self.captions = {}
43
+
44
+ # read the caption tsv
45
+ df_list = pd.read_csv(captions_tsv, sep='\t', dtype={'id': str}).to_dict('records')
46
+ for record in df_list:
47
+ id = record['id']
48
+ caption = record['caption']
49
+ self.captions[id] = caption
50
+
51
+ # read the clip tsv
52
+ df_list = pd.read_csv(clips_tsv, sep='\t', dtype={
53
+ 'id': str,
54
+ 'name': str
55
+ }).to_dict('records')
56
+ self.clips = []
57
+ for record in df_list:
58
+ record['id'] = record['id']
59
+ record['name'] = record['name']
60
+ id = record['id']
61
+ name = record['name']
62
+ if name not in self.captions:
63
+ log.warning(f'Audio {name} not found in {captions_tsv}')
64
+ continue
65
+ record['caption'] = self.captions[name]
66
+ self.clips.append(record)
67
+
68
+ log.info(f'Found {len(self.clips)} audio files in {self.root}')
69
+
70
+ self.resampler = {}
71
+
72
+ def __getitem__(self, idx: int) -> torch.Tensor:
73
+ try:
74
+ clip = self.clips[idx]
75
+ audio_name = clip['name']
76
+ audio_id = clip['id']
77
+ caption = clip['caption']
78
+ start_sample = clip['start_sample']
79
+ end_sample = clip['end_sample']
80
+
81
+ audio_path = self.root / f'{audio_name}.flac'
82
+ if not audio_path.exists():
83
+ audio_path = self.root / f'{audio_name}.wav'
84
+ assert audio_path.exists()
85
+
86
+ audio_chunk, sample_rate = torchaudio.load(audio_path)
87
+ audio_chunk = audio_chunk.mean(dim=0) # mono
88
+ abs_max = audio_chunk.abs().max()
89
+ if self.normalize_audio:
90
+ audio_chunk = audio_chunk / abs_max * 0.95
91
+
92
+ if self.reject_silent and abs_max < 1e-6:
93
+ log.warning(f'Rejecting silent audio')
94
+ return None
95
+
96
+ audio_chunk = audio_chunk[start_sample:end_sample]
97
+
98
+ # resample
99
+ if sample_rate == self.sample_rate:
100
+ audio_chunk = audio_chunk
101
+ else:
102
+ if sample_rate not in self.resampler:
103
+ # https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
104
+ self.resampler[sample_rate] = torchaudio.transforms.Resample(
105
+ sample_rate,
106
+ self.sample_rate,
107
+ lowpass_filter_width=64,
108
+ rolloff=0.9475937167399596,
109
+ resampling_method='sinc_interp_kaiser',
110
+ beta=14.769656459379492,
111
+ )
112
+ audio_chunk = self.resampler[sample_rate](audio_chunk)
113
+
114
+ if audio_chunk.shape[0] < self.num_samples:
115
+ raise ValueError('Audio is too short')
116
+ timbre_sample = audio_chunk[:self.num_timbre_sample]
117
+ audio_chunk = audio_chunk[audio_chunk.shape[0]-self.num_samples:]
118
+
119
+ tokens = self.tokenizer([caption])[0]
120
+
121
+ output = {
122
+ 'waveform': audio_chunk,
123
+ 'id': audio_id,
124
+ 'caption': caption,
125
+ 'tokens': tokens,
126
+ 'timbre_sample': timbre_sample,
127
+ }
128
+
129
+ return output
130
+ except Exception as e:
131
+ log.error(f'Error reading {audio_path}: {e}')
132
+ return None
133
+
134
+ def __len__(self):
135
+ return len(self.clips)
hf_AC/mmaudio/data/mm_dataset.py ADDED
@@ -0,0 +1,45 @@
1
+ import bisect
2
+
3
+ import torch
4
+ from torch.utils.data.dataset import Dataset
5
+
6
+
7
+ # modified from https://pytorch.org/docs/stable/_modules/torch/utils/data/dataset.html#ConcatDataset
8
+ class MultiModalDataset(Dataset):
9
+ datasets: list[Dataset]
10
+ cumulative_sizes: list[int]
11
+
12
+ @staticmethod
13
+ def cumsum(sequence):
14
+ r, s = [], 0
15
+ for e in sequence:
16
+ l = len(e)
17
+ r.append(l + s)
18
+ s += l
19
+ return r
20
+
21
+ def __init__(self, video_datasets: list[Dataset], audio_datasets: list[Dataset]):
22
+ super().__init__()
23
+ self.video_datasets = list(video_datasets)
24
+ self.audio_datasets = list(audio_datasets)
25
+ self.datasets = self.video_datasets + self.audio_datasets
26
+
27
+ self.cumulative_sizes = self.cumsum(self.datasets)
28
+
29
+ def __len__(self):
30
+ return self.cumulative_sizes[-1]
31
+
32
+ def __getitem__(self, idx):
33
+ if idx < 0:
34
+ if -idx > len(self):
35
+ raise ValueError("absolute value of index should not exceed dataset length")
36
+ idx = len(self) + idx
37
+ dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
38
+ if dataset_idx == 0:
39
+ sample_idx = idx
40
+ else:
41
+ sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
42
+ return self.datasets[dataset_idx][sample_idx]
43
+
44
+ def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]:
45
+ return self.video_datasets[0].compute_latent_stats()
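`MultiModalDataset` concatenates the video and audio datasets and maps a global index back to a (dataset, sample) pair via the cumulative sizes; a worked example with assumed dataset lengths:

import bisect

# Suppose two video datasets of 100 and 50 samples and one audio dataset of 200.
cumulative_sizes = [100, 150, 350]  # as produced by MultiModalDataset.cumsum

idx = 160                                                   # global index
dataset_idx = bisect.bisect_right(cumulative_sizes, idx)    # -> 2 (the audio dataset)
sample_idx = idx - cumulative_sizes[dataset_idx - 1]        # -> 10
assert (dataset_idx, sample_idx) == (2, 10)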
hf_AC/mmaudio/data/utils.py ADDED
@@ -0,0 +1,148 @@
1
+ import logging
2
+ import os
3
+ import random
4
+ import tempfile
5
+ from pathlib import Path
6
+ from typing import Any, Optional, Union
7
+
8
+ import torch
9
+ import torch.distributed as dist
10
+ from tensordict import MemoryMappedTensor
11
+ from torch.utils.data import DataLoader
12
+ from torch.utils.data.dataset import Dataset
13
+ from tqdm import tqdm
14
+
15
+ from mmaudio.utils.dist_utils import local_rank, world_size
16
+
17
+ scratch_path = Path(os.environ['SLURM_SCRATCH'] if 'SLURM_SCRATCH' in os.environ else '/dev/shm')
18
+ shm_path = Path('/dev/shm')
19
+
20
+ log = logging.getLogger()
21
+
22
+
23
+ def reseed(seed):
24
+ random.seed(seed)
25
+ torch.manual_seed(seed)
26
+
27
+
28
+ def local_scatter_torch(obj: Optional[Any]):
29
+ if world_size == 1:
30
+ # Just one worker. Do nothing.
31
+ return obj
32
+
33
+ array = [obj] * world_size
34
+ target_array = [None]
35
+ if local_rank == 0:
36
+ dist.scatter_object_list(target_array, scatter_object_input_list=array, src=0)
37
+ else:
38
+ dist.scatter_object_list(target_array, scatter_object_input_list=None, src=0)
39
+ return target_array[0]
40
+
41
+
42
+ class ShardDataset(Dataset):
43
+
44
+ def __init__(self, root):
45
+ self.root = root
46
+ self.shards = sorted(os.listdir(root))
47
+
48
+ def __len__(self):
49
+ return len(self.shards)
50
+
51
+ def __getitem__(self, idx):
52
+ return torch.load(os.path.join(self.root, self.shards[idx]), weights_only=True)
53
+
54
+
55
+ def get_tmp_dir(in_memory: bool) -> Path:
56
+ return shm_path if in_memory else scratch_path
57
+
58
+
59
+ def load_shards_and_share(data_path: Union[str, Path], ids: list[int],
60
+ in_memory: bool) -> MemoryMappedTensor:
61
+ if local_rank == 0:
62
+ with tempfile.NamedTemporaryFile(prefix='shared-tensor-', dir=get_tmp_dir(in_memory)) as f:
63
+ log.info(f'Loading shards from {data_path} into {f.name}...')
64
+ data = load_shards(data_path, ids=ids, tmp_file_path=f.name)
65
+ data = share_tensor_to_all(data)
66
+ torch.distributed.barrier()
67
+ f.close() # why does the context manager not close the file for me?
68
+ else:
69
+ log.info('Waiting for the data to be shared with me...')
70
+ data = share_tensor_to_all(None)
71
+ torch.distributed.barrier()
72
+
73
+ return data
74
+
75
+
76
+ def load_shards(
77
+ data_path: Union[str, Path],
78
+ ids: list[int],
79
+ *,
80
+ tmp_file_path: str,
81
+ ) -> Union[torch.Tensor, dict[str, torch.Tensor]]:
82
+
83
+ id_set = set(ids)
84
+ shards = sorted(os.listdir(data_path))
85
+ log.info(f'Found {len(shards)} shards in {data_path}.')
86
+ first_shard = torch.load(os.path.join(data_path, shards[0]), weights_only=True)
87
+
88
+ log.info(f'Rank {local_rank} created file {tmp_file_path}')
89
+ first_item = next(iter(first_shard.values()))
90
+ log.info(f'First item shape: {first_item.shape}')
91
+ mm_tensor = MemoryMappedTensor.empty(shape=(len(ids), *first_item.shape),
92
+ dtype=torch.float32,
93
+ filename=tmp_file_path,
94
+ existsok=True)
95
+ total_count = 0
96
+ used_index = set()
97
+ id_indexing = {i: idx for idx, i in enumerate(ids)}
98
+ # faster with no workers; otherwise we need to set_sharing_strategy('file_system')
99
+ loader = DataLoader(ShardDataset(data_path), batch_size=1, num_workers=0)
100
+ for data in tqdm(loader, desc='Loading shards'):
101
+ for i, v in data.items():
102
+ if i not in id_set:
103
+ continue
104
+
105
+ # tensor_index = ids.index(i)
106
+ tensor_index = id_indexing[i]
107
+ if tensor_index in used_index:
108
+ raise ValueError(f'Duplicate id {i} found in {data_path}.')
109
+ used_index.add(tensor_index)
110
+ mm_tensor[tensor_index] = v
111
+ total_count += 1
112
+
113
+ assert total_count == len(ids), f'Expected {len(ids)} tensors, got {total_count}.'
114
+ log.info(f'Loaded {total_count} tensors from {data_path}.')
115
+
116
+ return mm_tensor
117
+
118
+
119
+ def share_tensor_to_all(x: Optional[MemoryMappedTensor]) -> MemoryMappedTensor:
120
+ """
121
+ x: the tensor to be shared; None if local_rank != 0
122
+ return: the shared tensor
123
+ """
124
+
125
+ # there is no need to share your stuff with anyone if you are alone; must be in memory
126
+ if world_size == 1:
127
+ return x
128
+
129
+ if local_rank == 0:
130
+ assert x is not None, 'x must not be None if local_rank == 0'
131
+ else:
132
+ assert x is None, 'x must be None if local_rank != 0'
133
+
134
+ if local_rank == 0:
135
+ filename = x.filename
136
+ meta_information = (filename, x.shape, x.dtype)
137
+ else:
138
+ meta_information = None
139
+
140
+ filename, data_shape, data_type = local_scatter_torch(meta_information)
141
+ if local_rank == 0:
142
+ data = x
143
+ else:
144
+ data = MemoryMappedTensor.from_filename(filename=filename,
145
+ dtype=data_type,
146
+ shape=data_shape)
147
+
148
+ return data
hf_AC/mmaudio/eval_utils.py ADDED
@@ -0,0 +1,249 @@
1
+ import dataclasses
2
+ import logging
3
+ from pathlib import Path
4
+ from typing import Optional
5
+
6
+ import numpy as np
7
+ import torch
8
+ from colorlog import ColoredFormatter
9
+ from PIL import Image
10
+ from torchvision.transforms import v2
11
+
12
+ from mmaudio.data.av_utils import ImageInfo, VideoInfo, read_frames, reencode_with_audio
13
+ from mmaudio.model.flow_matching import FlowMatching
14
+ from mmaudio.model.networks import MMAudio
15
+ from mmaudio.model.sequence_config import CONFIG_16K, CONFIG_44K, SequenceConfig
16
+ from mmaudio.model.utils.features_utils import FeaturesUtils
17
+ from mmaudio.utils.download_utils import download_model_if_needed
18
+
19
+ log = logging.getLogger()
20
+
21
+
22
+ @dataclasses.dataclass
23
+ class ModelConfig:
24
+ model_name: str
25
+ model_path: Path
26
+ vae_path: Path
27
+ bigvgan_16k_path: Optional[Path]
28
+ mode: str
29
+ synchformer_ckpt: Path = Path('./ext_weights/synchformer_state_dict.pth')
30
+
31
+ @property
32
+ def seq_cfg(self) -> SequenceConfig:
33
+ if self.mode == '16k':
34
+ return CONFIG_16K
35
+ elif self.mode == '44k':
36
+ return CONFIG_44K
37
+
38
+ def download_if_needed(self):
39
+ # download_model_if_needed(self.model_path)
40
+ download_model_if_needed(self.vae_path)
41
+ if self.bigvgan_16k_path is not None:
42
+ download_model_if_needed(self.bigvgan_16k_path)
43
+ download_model_if_needed(self.synchformer_ckpt)
44
+
45
+ large_44k = ModelConfig(model_name='large_44k',
46
+ model_path=Path('./weights/mmaudio_large_44k.pth'),
47
+ vae_path=Path('./ext_weights/v1-44.pth'),
48
+ bigvgan_16k_path=None,
49
+ mode='44k')
50
+
51
+ all_model_cfg: dict[str, ModelConfig] = {
52
+ 'large_44k': large_44k,
53
+ }
54
+
55
+
56
+ def generate(
57
+ clip_video: Optional[torch.Tensor],
58
+ sync_video: Optional[torch.Tensor],
59
+ text: Optional[list[str]],
60
+ audio: Optional[torch.Tensor],
61
+ *,
62
+ negative_text: Optional[list[str]] = None,
63
+ feature_utils: FeaturesUtils,
64
+ net: MMAudio,
65
+ fm: FlowMatching,
66
+ rng: torch.Generator,
67
+ cfg_strength: float,
68
+ clip_batch_size_multiplier: int = 40,
69
+ sync_batch_size_multiplier: int = 40,
70
+ image_input: bool = False,
71
+ ) -> torch.Tensor:
72
+ device = feature_utils.device
73
+ dtype = feature_utils.dtype
74
+
75
+ bs = len(text)
76
+ if clip_video is not None:
77
+ clip_video = clip_video.to(device, dtype, non_blocking=True)
78
+ clip_features = feature_utils.encode_video_with_clip(clip_video,
79
+ batch_size=bs *
80
+ clip_batch_size_multiplier)
81
+ if image_input:
82
+ clip_features = clip_features.expand(-1, net.clip_seq_len, -1)
83
+ else:
84
+ clip_features = net.get_empty_clip_sequence(bs)
85
+
86
+ if sync_video is not None and not image_input:
87
+ sync_video = sync_video.to(device, dtype, non_blocking=True)
88
+ sync_features = feature_utils.encode_video_with_sync(sync_video,
89
+ batch_size=bs *
90
+ sync_batch_size_multiplier)
91
+ else:
92
+ sync_features = net.get_empty_sync_sequence(bs)
93
+
94
+ if text is not None:
95
+ text_features = feature_utils.encode_text(text)
96
+ else:
97
+ text_features = net.get_empty_string_sequence(bs)
98
+
99
+ if negative_text is not None:
100
+ assert len(negative_text) == bs
101
+ negative_text_features = feature_utils.encode_text(negative_text)
102
+ else:
103
+ negative_text_features = net.get_empty_string_sequence(bs)
104
+
105
+ if audio is None:
106
+ audio_features = net.get_empty_audio_sequence(bs)
107
+ else:
108
+ if len(audio.shape) == 1:
109
+ audio = audio.to(device).unsqueeze(0)
110
+ audio = audio.repeat(bs, 1)
111
+ else:
112
+ audio = audio.to(device)
113
+ feature_utils_audio = feature_utils.to(device, torch.float32).eval()
114
+ dist = feature_utils_audio.encode_audio(audio)
115
+ audio_mean = dist.mean.detach().to(device).transpose(1, 2)
116
+ audio_std = dist.std.detach().to(device).transpose(1, 2)
117
+ randn = torch.empty_like(audio_mean).normal_(generator=rng)
118
+ audio_features = audio_mean + audio_std * randn
119
+ audio_features = audio_features.to(device, dtype, non_blocking=True)
120
+ feature_utils = feature_utils.to(device, dtype).eval()
121
+
122
+ x0 = torch.randn(bs,
123
+ net.latent_seq_len,
124
+ net.latent_dim,
125
+ device=device,
126
+ dtype=dtype,
127
+ generator=rng)
128
+ preprocessed_conditions = net.preprocess_conditions(clip_features, sync_features, text_features, audio_features)
129
+ empty_conditions = net.get_empty_conditions(
130
+ bs, negative_text_features=negative_text_features if negative_text is not None else None)
131
+
132
+ cfg_ode_wrapper = lambda t, x: net.ode_wrapper(t, x, preprocessed_conditions, empty_conditions,
133
+ cfg_strength)
134
+ x1 = fm.to_data(cfg_ode_wrapper, x0)
135
+ x1 = net.unnormalize(x1)
136
+ spec = feature_utils.decode(x1)
137
+ audio = feature_utils.vocode(spec)
138
+ return audio
139
+
140
+
141
+ LOGFORMAT = "[%(log_color)s%(levelname)-8s%(reset)s]: %(log_color)s%(message)s%(reset)s"
142
+
143
+
144
+ def setup_eval_logging(log_level: int = logging.INFO):
145
+ logging.root.setLevel(log_level)
146
+ formatter = ColoredFormatter(LOGFORMAT)
147
+ stream = logging.StreamHandler()
148
+ stream.setLevel(log_level)
149
+ stream.setFormatter(formatter)
150
+ log = logging.getLogger()
151
+ log.setLevel(log_level)
152
+ log.addHandler(stream)
153
+
154
+
155
+ _CLIP_SIZE = 384
156
+ _CLIP_FPS = 8.0
157
+
158
+ _SYNC_SIZE = 224
159
+ _SYNC_FPS = 25.0
160
+
161
+
162
+ def load_video(video_path: Path, duration_sec: float, load_all_frames: bool = True) -> VideoInfo:
163
+
164
+ clip_transform = v2.Compose([
165
+ v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
166
+ v2.ToImage(),
167
+ v2.ToDtype(torch.float32, scale=True),
168
+ ])
169
+
170
+ sync_transform = v2.Compose([
171
+ v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
172
+ v2.CenterCrop(_SYNC_SIZE),
173
+ v2.ToImage(),
174
+ v2.ToDtype(torch.float32, scale=True),
175
+ v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
176
+ ])
177
+
178
+ output_frames, all_frames, orig_fps = read_frames(video_path,
179
+ list_of_fps=[_CLIP_FPS, _SYNC_FPS],
180
+ start_sec=0,
181
+ end_sec=duration_sec,
182
+ need_all_frames=load_all_frames)
183
+
184
+ clip_chunk, sync_chunk = output_frames
185
+ clip_chunk = torch.from_numpy(clip_chunk).permute(0, 3, 1, 2)
186
+ sync_chunk = torch.from_numpy(sync_chunk).permute(0, 3, 1, 2)
187
+
188
+ clip_frames = clip_transform(clip_chunk)
189
+ sync_frames = sync_transform(sync_chunk)
190
+
191
+ clip_length_sec = clip_frames.shape[0] / _CLIP_FPS
192
+ sync_length_sec = sync_frames.shape[0] / _SYNC_FPS
193
+
194
+ if clip_length_sec < duration_sec:
195
+ log.warning(f'Clip video is too short: {clip_length_sec:.2f} < {duration_sec:.2f}')
196
+ log.warning(f'Truncating to {clip_length_sec:.2f} sec')
197
+ duration_sec = clip_length_sec
198
+
199
+ if sync_length_sec < duration_sec:
200
+ log.warning(f'Sync video is too short: {sync_length_sec:.2f} < {duration_sec:.2f}')
201
+ log.warning(f'Truncating to {sync_length_sec:.2f} sec')
202
+ duration_sec = sync_length_sec
203
+
204
+ clip_frames = clip_frames[:int(_CLIP_FPS * duration_sec)]
205
+ sync_frames = sync_frames[:int(_SYNC_FPS * duration_sec)]
206
+
207
+ video_info = VideoInfo(
208
+ duration_sec=duration_sec,
209
+ fps=orig_fps,
210
+ clip_frames=clip_frames,
211
+ sync_frames=sync_frames,
212
+ all_frames=all_frames if load_all_frames else None,
213
+ )
214
+ return video_info
215
+
216
+
217
+ def load_image(image_path: Path) -> VideoInfo:
218
+ clip_transform = v2.Compose([
219
+ v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
220
+ v2.ToImage(),
221
+ v2.ToDtype(torch.float32, scale=True),
222
+ ])
223
+
224
+ sync_transform = v2.Compose([
225
+ v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
226
+ v2.CenterCrop(_SYNC_SIZE),
227
+ v2.ToImage(),
228
+ v2.ToDtype(torch.float32, scale=True),
229
+ v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
230
+ ])
231
+
232
+ frame = np.array(Image.open(image_path))
233
+
234
+ clip_chunk = torch.from_numpy(frame).unsqueeze(0).permute(0, 3, 1, 2)
235
+ sync_chunk = torch.from_numpy(frame).unsqueeze(0).permute(0, 3, 1, 2)
236
+
237
+ clip_frames = clip_transform(clip_chunk)
238
+ sync_frames = sync_transform(sync_chunk)
239
+
240
+ video_info = ImageInfo(
241
+ clip_frames=clip_frames,
242
+ sync_frames=sync_frames,
243
+ original_frame=frame,
244
+ )
245
+ return video_info
246
+
247
+
248
+ def make_video(video_info: VideoInfo, output_path: Path, audio: torch.Tensor, sampling_rate: int):
249
+ reencode_with_audio(video_info, output_path, audio, sampling_rate)
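Editor's note: a minimal sketch of how the helpers above chain together for video-to-audio inference. It assumes net (MMAudio), feature_utils (FeaturesUtils) and fm (FlowMatching) have already been constructed elsewhere in the repository from a ModelConfig such as all_model_cfg['large_44k']; the file paths, prompt, seed, cfg_strength value, 8-second duration and 44100 Hz sampling rate are illustrative assumptions, not values taken from this diff.

    # Hedged sketch only: load a clip, synthesize audio, and mux it back into a video.
    from pathlib import Path

    import torch

    duration_sec = 8.0  # must match the sequence length the model was configured for
    video_info = load_video(Path('input.mp4'), duration_sec)
    clip_frames = video_info.clip_frames.unsqueeze(0)   # add batch dim: (1, T_clip, 3, 384, 384)
    sync_frames = video_info.sync_frames.unsqueeze(0)   # (1, T_sync, 3, 224, 224)

    rng = torch.Generator(device=feature_utils.device).manual_seed(42)
    audios = generate(clip_frames,
                      sync_frames, ['rain hitting a tin roof'],
                      audio=None,
                      negative_text=['music'],
                      feature_utils=feature_utils,
                      net=net,
                      fm=fm,
                      rng=rng,
                      cfg_strength=4.5)

    # Take the first (and only) item in the batch and write the final video.
    make_video(video_info, Path('output.mp4'), audios.float().cpu()[0], sampling_rate=44100)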
hf_AC/mmaudio/ext/__init__.py ADDED
@@ -0,0 +1 @@
1
+
hf_AC/mmaudio/ext/autoencoder/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .autoencoder import AutoEncoderModule
hf_AC/mmaudio/ext/autoencoder/autoencoder.py ADDED
@@ -0,0 +1,52 @@
1
+ from typing import Literal, Optional
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from mmaudio.ext.autoencoder.vae import VAE, get_my_vae
7
+ from mmaudio.ext.bigvgan import BigVGAN
8
+ from mmaudio.ext.bigvgan_v2.bigvgan import BigVGAN as BigVGANv2
9
+ from mmaudio.model.utils.distributions import DiagonalGaussianDistribution
10
+
11
+
12
+ class AutoEncoderModule(nn.Module):
13
+
14
+ def __init__(self,
15
+ *,
16
+ vae_ckpt_path,
17
+ vocoder_ckpt_path: Optional[str] = None,
18
+ mode: Literal['16k', '44k'],
19
+ need_vae_encoder: bool = True):
20
+ super().__init__()
21
+ self.vae: VAE = get_my_vae(mode).eval()
22
+ vae_state_dict = torch.load(vae_ckpt_path, weights_only=True, map_location='cpu')
23
+ self.vae.load_state_dict(vae_state_dict)
24
+ self.vae.remove_weight_norm()
25
+
26
+ if mode == '16k':
27
+ assert vocoder_ckpt_path is not None
28
+ self.vocoder = BigVGAN(vocoder_ckpt_path).eval()
29
+ elif mode == '44k':
30
+ self.vocoder = BigVGANv2.from_pretrained('nvidia/bigvgan_v2_44khz_128band_512x',
31
+ use_cuda_kernel=False)
32
+ self.vocoder.remove_weight_norm()
33
+ else:
34
+ raise ValueError(f'Unknown mode: {mode}')
35
+
36
+ for param in self.parameters():
37
+ param.requires_grad = False
38
+
39
+ if not need_vae_encoder:
40
+ del self.vae.encoder
41
+
42
+ @torch.inference_mode()
43
+ def encode(self, x: torch.Tensor) -> DiagonalGaussianDistribution:
44
+ return self.vae.encode(x)
45
+
46
+ @torch.inference_mode()
47
+ def decode(self, z: torch.Tensor) -> torch.Tensor:
48
+ return self.vae.decode(z)
49
+
50
+ @torch.inference_mode()
51
+ def vocode(self, spec: torch.Tensor) -> torch.Tensor:
52
+ return self.vocoder(spec)
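Editor's note: a small usage sketch for AutoEncoderModule as defined above. The checkpoint path is a placeholder (no weights ship with this commit); in '44k' mode the vocoder is fetched from the nvidia/bigvgan_v2_44khz_128band_512x hub repo, so no vocoder_ckpt_path is needed, while '16k' mode would require one.

    # Hedged sketch: mel -> latent -> mel -> waveform round trip.
    import torch

    from mmaudio.ext.autoencoder import AutoEncoderModule

    ae = AutoEncoderModule(vae_ckpt_path='weights/v1-44.pth',  # placeholder path
                           mode='44k')

    mel = torch.randn(1, 128, 256)   # (B, n_mels, T); the 44k VAE expects 128 mel bands
    dist = ae.encode(mel)            # DiagonalGaussianDistribution over the latent
    z = dist.mean                    # deterministic latent; dist.std is available for sampling
    mel_rec = ae.decode(z)           # back to (1, 128, 256)
    wav = ae.vocode(mel_rec)         # BigVGAN-v2 vocoder output (waveform tensor)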
hf_AC/mmaudio/ext/autoencoder/edm2_utils.py ADDED
@@ -0,0 +1,168 @@
1
+ # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is licensed under a Creative Commons
4
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
5
+ # You should have received a copy of the license along with this
6
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7
+ """Improved diffusion model architecture proposed in the paper
8
+ "Analyzing and Improving the Training Dynamics of Diffusion Models"."""
9
+
10
+ import numpy as np
11
+ import torch
12
+
13
+ #----------------------------------------------------------------------------
14
+ # Variant of constant() that inherits dtype and device from the given
15
+ # reference tensor by default.
16
+
17
+ _constant_cache = dict()
18
+
19
+
20
+ def constant(value, shape=None, dtype=None, device=None, memory_format=None):
21
+ value = np.asarray(value)
22
+ if shape is not None:
23
+ shape = tuple(shape)
24
+ if dtype is None:
25
+ dtype = torch.get_default_dtype()
26
+ if device is None:
27
+ device = torch.device('cpu')
28
+ if memory_format is None:
29
+ memory_format = torch.contiguous_format
30
+
31
+ key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
32
+ tensor = _constant_cache.get(key, None)
33
+ if tensor is None:
34
+ tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
35
+ if shape is not None:
36
+ tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
37
+ tensor = tensor.contiguous(memory_format=memory_format)
38
+ _constant_cache[key] = tensor
39
+ return tensor
40
+
41
+
42
+ def const_like(ref, value, shape=None, dtype=None, device=None, memory_format=None):
43
+ if dtype is None:
44
+ dtype = ref.dtype
45
+ if device is None:
46
+ device = ref.device
47
+ return constant(value, shape=shape, dtype=dtype, device=device, memory_format=memory_format)
48
+
49
+
50
+ #----------------------------------------------------------------------------
51
+ # Normalize given tensor to unit magnitude with respect to the given
52
+ # dimensions. Default = all dimensions except the first.
53
+
54
+
55
+ def normalize(x, dim=None, eps=1e-4):
56
+ if dim is None:
57
+ dim = list(range(1, x.ndim))
58
+ norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
59
+ norm = torch.add(eps, norm, alpha=np.sqrt(norm.numel() / x.numel()))
60
+ return x / norm.to(x.dtype)
61
+
62
+
63
+ class Normalize(torch.nn.Module):
64
+
65
+ def __init__(self, dim=None, eps=1e-4):
66
+ super().__init__()
67
+ self.dim = dim
68
+ self.eps = eps
69
+
70
+ def forward(self, x):
71
+ return normalize(x, dim=self.dim, eps=self.eps)
72
+
73
+
74
+ #----------------------------------------------------------------------------
75
+ # Upsample or downsample the given tensor with the given filter,
76
+ # or keep it as is.
77
+
78
+
79
+ def resample(x, f=[1, 1], mode='keep'):
80
+ if mode == 'keep':
81
+ return x
82
+ f = np.float32(f)
83
+ assert f.ndim == 1 and len(f) % 2 == 0
84
+ pad = (len(f) - 1) // 2
85
+ f = f / f.sum()
86
+ f = np.outer(f, f)[np.newaxis, np.newaxis, :, :]
87
+ f = const_like(x, f)
88
+ c = x.shape[1]
89
+ if mode == 'down':
90
+ return torch.nn.functional.conv2d(x,
91
+ f.tile([c, 1, 1, 1]),
92
+ groups=c,
93
+ stride=2,
94
+ padding=(pad, ))
95
+ assert mode == 'up'
96
+ return torch.nn.functional.conv_transpose2d(x, (f * 4).tile([c, 1, 1, 1]),
97
+ groups=c,
98
+ stride=2,
99
+ padding=(pad, ))
100
+
101
+
102
+ #----------------------------------------------------------------------------
103
+ # Magnitude-preserving SiLU (Equation 81).
104
+
105
+
106
+ def mp_silu(x):
107
+ return torch.nn.functional.silu(x) / 0.596
108
+
109
+
110
+ class MPSiLU(torch.nn.Module):
111
+
112
+ def forward(self, x):
113
+ return mp_silu(x)
114
+
115
+
116
+ #----------------------------------------------------------------------------
117
+ # Magnitude-preserving sum (Equation 88).
118
+
119
+
120
+ def mp_sum(a, b, t=0.5):
121
+ return a.lerp(b, t) / np.sqrt((1 - t)**2 + t**2)
122
+
123
+
124
+ #----------------------------------------------------------------------------
125
+ # Magnitude-preserving concatenation (Equation 103).
126
+
127
+
128
+ def mp_cat(a, b, dim=1, t=0.5):
129
+ Na = a.shape[dim]
130
+ Nb = b.shape[dim]
131
+ C = np.sqrt((Na + Nb) / ((1 - t)**2 + t**2))
132
+ wa = C / np.sqrt(Na) * (1 - t)
133
+ wb = C / np.sqrt(Nb) * t
134
+ return torch.cat([wa * a, wb * b], dim=dim)
135
+
136
+
137
+ #----------------------------------------------------------------------------
138
+ # Magnitude-preserving convolution or fully-connected layer (Equation 47)
139
+ # with forced weight normalization (Equation 66).
140
+
141
+
142
+ class MPConv1D(torch.nn.Module):
143
+
144
+ def __init__(self, in_channels, out_channels, kernel_size):
145
+ super().__init__()
146
+ self.out_channels = out_channels
147
+ self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, kernel_size))
148
+
149
+ self.weight_norm_removed = False
150
+
151
+ def forward(self, x, gain=1):
152
+ assert self.weight_norm_removed, 'call remove_weight_norm() before inference'
153
+
154
+ w = self.weight * gain
155
+ if w.ndim == 2:
156
+ return x @ w.t()
157
+ assert w.ndim == 3
158
+ return torch.nn.functional.conv1d(x, w, padding=(w.shape[-1] // 2, ))
159
+
160
+ def remove_weight_norm(self):
161
+ w = self.weight.to(torch.float32)
162
+ w = normalize(w) # traditional weight normalization
163
+ w = w / np.sqrt(w[0].numel())
164
+ w = w.to(self.weight.dtype)
165
+ self.weight.data.copy_(w)
166
+
167
+ self.weight_norm_removed = True
168
+ return self
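Editor's note: the magnitude-preserving layers above refuse to run until remove_weight_norm() has been called (the forward pass asserts on it), which is why AutoEncoderModule calls vae.remove_weight_norm() right after loading weights. A minimal sketch with made-up sizes:

    # Hedged sketch: MPConv1D and the magnitude-preserving helpers.
    import torch

    from mmaudio.ext.autoencoder.edm2_utils import MPConv1D, mp_sum, normalize

    conv = MPConv1D(in_channels=64, out_channels=64, kernel_size=3)
    conv.remove_weight_norm()        # bake the normalization into the weights first

    x = torch.randn(2, 64, 100)      # (B, C, T)
    y = conv(x)                      # same shape; 'same' padding of kernel_size // 2

    z = mp_sum(x, y, t=0.3)          # blend while keeping unit magnitude (Equation 88)
    print(normalize(z).std())        # roughly 1 after pixel norm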
hf_AC/mmaudio/ext/autoencoder/vae.py ADDED
@@ -0,0 +1,369 @@
1
+ import logging
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from mmaudio.ext.autoencoder.edm2_utils import MPConv1D
8
+ from mmaudio.ext.autoencoder.vae_modules import (AttnBlock1D, Downsample1D, ResnetBlock1D,
9
+ Upsample1D, nonlinearity)
10
+ from mmaudio.model.utils.distributions import DiagonalGaussianDistribution
11
+
12
+ log = logging.getLogger()
13
+
14
+ DATA_MEAN_80D = [
15
+ -1.6058, -1.3676, -1.2520, -1.2453, -1.2078, -1.2224, -1.2419, -1.2439, -1.2922, -1.2927,
16
+ -1.3170, -1.3543, -1.3401, -1.3836, -1.3907, -1.3912, -1.4313, -1.4152, -1.4527, -1.4728,
17
+ -1.4568, -1.5101, -1.5051, -1.5172, -1.5623, -1.5373, -1.5746, -1.5687, -1.6032, -1.6131,
18
+ -1.6081, -1.6331, -1.6489, -1.6489, -1.6700, -1.6738, -1.6953, -1.6969, -1.7048, -1.7280,
19
+ -1.7361, -1.7495, -1.7658, -1.7814, -1.7889, -1.8064, -1.8221, -1.8377, -1.8417, -1.8643,
20
+ -1.8857, -1.8929, -1.9173, -1.9379, -1.9531, -1.9673, -1.9824, -2.0042, -2.0215, -2.0436,
21
+ -2.0766, -2.1064, -2.1418, -2.1855, -2.2319, -2.2767, -2.3161, -2.3572, -2.3954, -2.4282,
22
+ -2.4659, -2.5072, -2.5552, -2.6074, -2.6584, -2.7107, -2.7634, -2.8266, -2.8981, -2.9673
23
+ ]
24
+
25
+ DATA_STD_80D = [
26
+ 1.0291, 1.0411, 1.0043, 0.9820, 0.9677, 0.9543, 0.9450, 0.9392, 0.9343, 0.9297, 0.9276, 0.9263,
27
+ 0.9242, 0.9254, 0.9232, 0.9281, 0.9263, 0.9315, 0.9274, 0.9247, 0.9277, 0.9199, 0.9188, 0.9194,
28
+ 0.9160, 0.9161, 0.9146, 0.9161, 0.9100, 0.9095, 0.9145, 0.9076, 0.9066, 0.9095, 0.9032, 0.9043,
29
+ 0.9038, 0.9011, 0.9019, 0.9010, 0.8984, 0.8983, 0.8986, 0.8961, 0.8962, 0.8978, 0.8962, 0.8973,
30
+ 0.8993, 0.8976, 0.8995, 0.9016, 0.8982, 0.8972, 0.8974, 0.8949, 0.8940, 0.8947, 0.8936, 0.8939,
31
+ 0.8951, 0.8956, 0.9017, 0.9167, 0.9436, 0.9690, 1.0003, 1.0225, 1.0381, 1.0491, 1.0545, 1.0604,
32
+ 1.0761, 1.0929, 1.1089, 1.1196, 1.1176, 1.1156, 1.1117, 1.1070
33
+ ]
34
+
35
+ DATA_MEAN_128D = [
36
+ -3.3462, -2.6723, -2.4893, -2.3143, -2.2664, -2.3317, -2.1802, -2.4006, -2.2357, -2.4597,
37
+ -2.3717, -2.4690, -2.5142, -2.4919, -2.6610, -2.5047, -2.7483, -2.5926, -2.7462, -2.7033,
38
+ -2.7386, -2.8112, -2.7502, -2.9594, -2.7473, -3.0035, -2.8891, -2.9922, -2.9856, -3.0157,
39
+ -3.1191, -2.9893, -3.1718, -3.0745, -3.1879, -3.2310, -3.1424, -3.2296, -3.2791, -3.2782,
40
+ -3.2756, -3.3134, -3.3509, -3.3750, -3.3951, -3.3698, -3.4505, -3.4509, -3.5089, -3.4647,
41
+ -3.5536, -3.5788, -3.5867, -3.6036, -3.6400, -3.6747, -3.7072, -3.7279, -3.7283, -3.7795,
42
+ -3.8259, -3.8447, -3.8663, -3.9182, -3.9605, -3.9861, -4.0105, -4.0373, -4.0762, -4.1121,
43
+ -4.1488, -4.1874, -4.2461, -4.3170, -4.3639, -4.4452, -4.5282, -4.6297, -4.7019, -4.7960,
44
+ -4.8700, -4.9507, -5.0303, -5.0866, -5.1634, -5.2342, -5.3242, -5.4053, -5.4927, -5.5712,
45
+ -5.6464, -5.7052, -5.7619, -5.8410, -5.9188, -6.0103, -6.0955, -6.1673, -6.2362, -6.3120,
46
+ -6.3926, -6.4797, -6.5565, -6.6511, -6.8130, -6.9961, -7.1275, -7.2457, -7.3576, -7.4663,
47
+ -7.6136, -7.7469, -7.8815, -8.0132, -8.1515, -8.3071, -8.4722, -8.7418, -9.3975, -9.6628,
48
+ -9.7671, -9.8863, -9.9992, -10.0860, -10.1709, -10.5418, -11.2795, -11.3861
49
+ ]
50
+
51
+ DATA_STD_128D = [
52
+ 2.3804, 2.4368, 2.3772, 2.3145, 2.2803, 2.2510, 2.2316, 2.2083, 2.1996, 2.1835, 2.1769, 2.1659,
53
+ 2.1631, 2.1618, 2.1540, 2.1606, 2.1571, 2.1567, 2.1612, 2.1579, 2.1679, 2.1683, 2.1634, 2.1557,
54
+ 2.1668, 2.1518, 2.1415, 2.1449, 2.1406, 2.1350, 2.1313, 2.1415, 2.1281, 2.1352, 2.1219, 2.1182,
55
+ 2.1327, 2.1195, 2.1137, 2.1080, 2.1179, 2.1036, 2.1087, 2.1036, 2.1015, 2.1068, 2.0975, 2.0991,
56
+ 2.0902, 2.1015, 2.0857, 2.0920, 2.0893, 2.0897, 2.0910, 2.0881, 2.0925, 2.0873, 2.0960, 2.0900,
57
+ 2.0957, 2.0958, 2.0978, 2.0936, 2.0886, 2.0905, 2.0845, 2.0855, 2.0796, 2.0840, 2.0813, 2.0817,
58
+ 2.0838, 2.0840, 2.0917, 2.1061, 2.1431, 2.1976, 2.2482, 2.3055, 2.3700, 2.4088, 2.4372, 2.4609,
59
+ 2.4731, 2.4847, 2.5072, 2.5451, 2.5772, 2.6147, 2.6529, 2.6596, 2.6645, 2.6726, 2.6803, 2.6812,
60
+ 2.6899, 2.6916, 2.6931, 2.6998, 2.7062, 2.7262, 2.7222, 2.7158, 2.7041, 2.7485, 2.7491, 2.7451,
61
+ 2.7485, 2.7233, 2.7297, 2.7233, 2.7145, 2.6958, 2.6788, 2.6439, 2.6007, 2.4786, 2.2469, 2.1877,
62
+ 2.1392, 2.0717, 2.0107, 1.9676, 1.9140, 1.7102, 0.9101, 0.7164
63
+ ]
64
+
65
+
66
+ class VAE(nn.Module):
67
+
68
+ def __init__(
69
+ self,
70
+ *,
71
+ data_dim: int,
72
+ embed_dim: int,
73
+ hidden_dim: int,
74
+ ):
75
+ super().__init__()
76
+
77
+ if data_dim == 80:
78
+ self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_80D, dtype=torch.float32))
79
+ self.data_std = nn.Buffer(torch.tensor(DATA_STD_80D, dtype=torch.float32))
80
+ elif data_dim == 128:
81
+ self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_128D, dtype=torch.float32))
82
+ self.data_std = nn.Buffer(torch.tensor(DATA_STD_128D, dtype=torch.float32))
83
+
84
+ self.data_mean = self.data_mean.view(1, -1, 1)
85
+ self.data_std = self.data_std.view(1, -1, 1)
86
+
87
+ self.encoder = Encoder1D(
88
+ dim=hidden_dim,
89
+ ch_mult=(1, 2, 4),
90
+ num_res_blocks=2,
91
+ attn_layers=[3],
92
+ down_layers=[0],
93
+ in_dim=data_dim,
94
+ embed_dim=embed_dim,
95
+ )
96
+ self.decoder = Decoder1D(
97
+ dim=hidden_dim,
98
+ ch_mult=(1, 2, 4),
99
+ num_res_blocks=2,
100
+ attn_layers=[3],
101
+ down_layers=[0],
102
+ in_dim=data_dim,
103
+ out_dim=data_dim,
104
+ embed_dim=embed_dim,
105
+ )
106
+
107
+ self.embed_dim = embed_dim
108
+ # self.quant_conv = nn.Conv1d(2 * embed_dim, 2 * embed_dim, 1)
109
+ # self.post_quant_conv = nn.Conv1d(embed_dim, embed_dim, 1)
110
+
111
+ self.initialize_weights()
112
+
113
+ def initialize_weights(self):
114
+ pass
115
+
116
+ def encode(self, x: torch.Tensor, normalize: bool = True) -> DiagonalGaussianDistribution:
117
+ if normalize:
118
+ x = self.normalize(x)
119
+ moments = self.encoder(x)
120
+ posterior = DiagonalGaussianDistribution(moments)
121
+ return posterior
122
+
123
+ def decode(self, z: torch.Tensor, unnormalize: bool = True) -> torch.Tensor:
124
+ dec = self.decoder(z)
125
+ if unnormalize:
126
+ dec = self.unnormalize(dec)
127
+ return dec
128
+
129
+ def normalize(self, x: torch.Tensor) -> torch.Tensor:
130
+ return (x - self.data_mean) / self.data_std
131
+
132
+ def unnormalize(self, x: torch.Tensor) -> torch.Tensor:
133
+ return x * self.data_std + self.data_mean
134
+
135
+ def forward(
136
+ self,
137
+ x: torch.Tensor,
138
+ sample_posterior: bool = True,
139
+ rng: Optional[torch.Generator] = None,
140
+ normalize: bool = True,
141
+ unnormalize: bool = True,
142
+ ) -> tuple[torch.Tensor, DiagonalGaussianDistribution]:
143
+
144
+ posterior = self.encode(x, normalize=normalize)
145
+ if sample_posterior:
146
+ z = posterior.sample(rng)
147
+ else:
148
+ z = posterior.mode()
149
+ dec = self.decode(z, unnormalize=unnormalize)
150
+ return dec, posterior
151
+
152
+ def load_weights(self, src_dict) -> None:
153
+ self.load_state_dict(src_dict, strict=True)
154
+
155
+ @property
156
+ def device(self) -> torch.device:
157
+ return next(self.parameters()).device
158
+
159
+ def get_last_layer(self):
160
+ return self.decoder.conv_out.weight
161
+
162
+ def remove_weight_norm(self):
163
+ for name, m in self.named_modules():
164
+ if isinstance(m, MPConv1D):
165
+ m.remove_weight_norm()
166
+ log.debug(f"Removed weight norm from {name}")
167
+ return self
168
+
169
+
170
+ class Encoder1D(nn.Module):
171
+
172
+ def __init__(self,
173
+ *,
174
+ dim: int,
175
+ ch_mult: tuple[int] = (1, 2, 4, 8),
176
+ num_res_blocks: int,
177
+ attn_layers: list[int] = [],
178
+ down_layers: list[int] = [],
179
+ resamp_with_conv: bool = True,
180
+ in_dim: int,
181
+ embed_dim: int,
182
+ double_z: bool = True,
183
+ kernel_size: int = 3,
184
+ clip_act: float = 256.0):
185
+ super().__init__()
186
+ self.dim = dim
187
+ self.num_layers = len(ch_mult)
188
+ self.num_res_blocks = num_res_blocks
189
+ self.in_channels = in_dim
190
+ self.clip_act = clip_act
191
+ self.down_layers = down_layers
192
+ self.attn_layers = attn_layers
193
+ self.conv_in = MPConv1D(in_dim, self.dim, kernel_size=kernel_size)
194
+
195
+ in_ch_mult = (1, ) + tuple(ch_mult)
196
+ self.in_ch_mult = in_ch_mult
197
+ # downsampling
198
+ self.down = nn.ModuleList()
199
+ for i_level in range(self.num_layers):
200
+ block = nn.ModuleList()
201
+ attn = nn.ModuleList()
202
+ block_in = dim * in_ch_mult[i_level]
203
+ block_out = dim * ch_mult[i_level]
204
+ for i_block in range(self.num_res_blocks):
205
+ block.append(
206
+ ResnetBlock1D(in_dim=block_in,
207
+ out_dim=block_out,
208
+ kernel_size=kernel_size,
209
+ use_norm=True))
210
+ block_in = block_out
211
+ if i_level in attn_layers:
212
+ attn.append(AttnBlock1D(block_in))
213
+ down = nn.Module()
214
+ down.block = block
215
+ down.attn = attn
216
+ if i_level in down_layers:
217
+ down.downsample = Downsample1D(block_in, resamp_with_conv)
218
+ self.down.append(down)
219
+
220
+ # middle
221
+ self.mid = nn.Module()
222
+ self.mid.block_1 = ResnetBlock1D(in_dim=block_in,
223
+ out_dim=block_in,
224
+ kernel_size=kernel_size,
225
+ use_norm=True)
226
+ self.mid.attn_1 = AttnBlock1D(block_in)
227
+ self.mid.block_2 = ResnetBlock1D(in_dim=block_in,
228
+ out_dim=block_in,
229
+ kernel_size=kernel_size,
230
+ use_norm=True)
231
+
232
+ # end
233
+ self.conv_out = MPConv1D(block_in,
234
+ 2 * embed_dim if double_z else embed_dim,
235
+ kernel_size=kernel_size)
236
+
237
+ self.learnable_gain = nn.Parameter(torch.zeros([]))
238
+
239
+ def forward(self, x):
240
+
241
+ # downsampling
242
+ hs = [self.conv_in(x)]
243
+ for i_level in range(self.num_layers):
244
+ for i_block in range(self.num_res_blocks):
245
+ h = self.down[i_level].block[i_block](hs[-1])
246
+ if len(self.down[i_level].attn) > 0:
247
+ h = self.down[i_level].attn[i_block](h)
248
+ h = h.clamp(-self.clip_act, self.clip_act)
249
+ hs.append(h)
250
+ if i_level in self.down_layers:
251
+ hs.append(self.down[i_level].downsample(hs[-1]))
252
+
253
+ # middle
254
+ h = hs[-1]
255
+ h = self.mid.block_1(h)
256
+ h = self.mid.attn_1(h)
257
+ h = self.mid.block_2(h)
258
+ h = h.clamp(-self.clip_act, self.clip_act)
259
+
260
+ # end
261
+ h = nonlinearity(h)
262
+ h = self.conv_out(h, gain=(self.learnable_gain + 1))
263
+ return h
264
+
265
+
266
+ class Decoder1D(nn.Module):
267
+
268
+ def __init__(self,
269
+ *,
270
+ dim: int,
271
+ out_dim: int,
272
+ ch_mult: tuple[int] = (1, 2, 4, 8),
273
+ num_res_blocks: int,
274
+ attn_layers: list[int] = [],
275
+ down_layers: list[int] = [],
276
+ kernel_size: int = 3,
277
+ resamp_with_conv: bool = True,
278
+ in_dim: int,
279
+ embed_dim: int,
280
+ clip_act: float = 256.0):
281
+ super().__init__()
282
+ self.ch = dim
283
+ self.num_layers = len(ch_mult)
284
+ self.num_res_blocks = num_res_blocks
285
+ self.in_channels = in_dim
286
+ self.clip_act = clip_act
287
+ self.down_layers = [i + 1 for i in down_layers] # each downlayer add one
288
+
289
+ # compute in_ch_mult, block_in and curr_res at lowest res
290
+ block_in = dim * ch_mult[self.num_layers - 1]
291
+
292
+ # z to block_in
293
+ self.conv_in = MPConv1D(embed_dim, block_in, kernel_size=kernel_size)
294
+
295
+ # middle
296
+ self.mid = nn.Module()
297
+ self.mid.block_1 = ResnetBlock1D(in_dim=block_in, out_dim=block_in, use_norm=True)
298
+ self.mid.attn_1 = AttnBlock1D(block_in)
299
+ self.mid.block_2 = ResnetBlock1D(in_dim=block_in, out_dim=block_in, use_norm=True)
300
+
301
+ # upsampling
302
+ self.up = nn.ModuleList()
303
+ for i_level in reversed(range(self.num_layers)):
304
+ block = nn.ModuleList()
305
+ attn = nn.ModuleList()
306
+ block_out = dim * ch_mult[i_level]
307
+ for i_block in range(self.num_res_blocks + 1):
308
+ block.append(ResnetBlock1D(in_dim=block_in, out_dim=block_out, use_norm=True))
309
+ block_in = block_out
310
+ if i_level in attn_layers:
311
+ attn.append(AttnBlock1D(block_in))
312
+ up = nn.Module()
313
+ up.block = block
314
+ up.attn = attn
315
+ if i_level in self.down_layers:
316
+ up.upsample = Upsample1D(block_in, resamp_with_conv)
317
+ self.up.insert(0, up) # prepend to get consistent order
318
+
319
+ # end
320
+ self.conv_out = MPConv1D(block_in, out_dim, kernel_size=kernel_size)
321
+ self.learnable_gain = nn.Parameter(torch.zeros([]))
322
+
323
+ def forward(self, z):
324
+ # z to block_in
325
+ h = self.conv_in(z)
326
+
327
+ # middle
328
+ h = self.mid.block_1(h)
329
+ h = self.mid.attn_1(h)
330
+ h = self.mid.block_2(h)
331
+ h = h.clamp(-self.clip_act, self.clip_act)
332
+
333
+ # upsampling
334
+ for i_level in reversed(range(self.num_layers)):
335
+ for i_block in range(self.num_res_blocks + 1):
336
+ h = self.up[i_level].block[i_block](h)
337
+ if len(self.up[i_level].attn) > 0:
338
+ h = self.up[i_level].attn[i_block](h)
339
+ h = h.clamp(-self.clip_act, self.clip_act)
340
+ if i_level in self.down_layers:
341
+ h = self.up[i_level].upsample(h)
342
+
343
+ h = nonlinearity(h)
344
+ h = self.conv_out(h, gain=(self.learnable_gain + 1))
345
+ return h
346
+
347
+
348
+ def VAE_16k(**kwargs) -> VAE:
349
+ return VAE(data_dim=80, embed_dim=20, hidden_dim=384, **kwargs)
350
+
351
+
352
+ def VAE_44k(**kwargs) -> VAE:
353
+ return VAE(data_dim=128, embed_dim=40, hidden_dim=512, **kwargs)
354
+
355
+
356
+ def get_my_vae(name: str, **kwargs) -> VAE:
357
+ if name == '16k':
358
+ return VAE_16k(**kwargs)
359
+ if name == '44k':
360
+ return VAE_44k(**kwargs)
361
+ raise ValueError(f'Unknown model: {name}')
362
+
363
+
364
+ if __name__ == '__main__':
365
+ network = get_my_vae('standard')
366
+
367
+ # print the number of parameters in terms of millions
368
+ num_params = sum(p.numel() for p in network.parameters()) / 1e6
369
+ print(f'Number of parameters: {num_params:.2f}M')
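Editor's note: a shape-level sketch of the 1-D VAE with randomly initialised weights (the reconstruction is meaningless; only the tensor plumbing is shown). remove_weight_norm() must be called before any forward pass because every MPConv1D asserts on it. Also note that the __main__ block above calls get_my_vae('standard'), which would raise ValueError since only '16k' and '44k' are registered.

    # Hedged sketch: 44k VAE shapes; 128 mel bands in, 40-dim latent at half the time resolution.
    import torch

    from mmaudio.ext.autoencoder.vae import get_my_vae

    vae = get_my_vae('44k').eval()
    vae.remove_weight_norm()              # required before running the MPConv1D layers

    mel = torch.randn(1, 128, 256)        # (B, data_dim, T); T should be even (one Downsample1D stage)
    with torch.no_grad():
        dec, posterior = vae(mel, sample_posterior=False)

    print(posterior.mean.shape)           # torch.Size([1, 40, 128])  -> embed_dim x T/2
    print(dec.shape)                      # torch.Size([1, 128, 256]) -> reconstructed mel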
hf_AC/mmaudio/ext/autoencoder/vae_modules.py ADDED
@@ -0,0 +1,117 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from einops import rearrange
5
+
6
+ from mmaudio.ext.autoencoder.edm2_utils import (MPConv1D, mp_silu, mp_sum, normalize)
7
+
8
+
9
+ def nonlinearity(x):
10
+ # swish
11
+ return mp_silu(x)
12
+
13
+
14
+ class ResnetBlock1D(nn.Module):
15
+
16
+ def __init__(self, *, in_dim, out_dim=None, conv_shortcut=False, kernel_size=3, use_norm=True):
17
+ super().__init__()
18
+ self.in_dim = in_dim
19
+ out_dim = in_dim if out_dim is None else out_dim
20
+ self.out_dim = out_dim
21
+ self.use_conv_shortcut = conv_shortcut
22
+ self.use_norm = use_norm
23
+
24
+ self.conv1 = MPConv1D(in_dim, out_dim, kernel_size=kernel_size)
25
+ self.conv2 = MPConv1D(out_dim, out_dim, kernel_size=kernel_size)
26
+ if self.in_dim != self.out_dim:
27
+ if self.use_conv_shortcut:
28
+ self.conv_shortcut = MPConv1D(in_dim, out_dim, kernel_size=kernel_size)
29
+ else:
30
+ self.nin_shortcut = MPConv1D(in_dim, out_dim, kernel_size=1)
31
+
32
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
33
+
34
+ # pixel norm
35
+ if self.use_norm:
36
+ x = normalize(x, dim=1)
37
+
38
+ h = x
39
+ h = nonlinearity(h)
40
+ h = self.conv1(h)
41
+
42
+ h = nonlinearity(h)
43
+ h = self.conv2(h)
44
+
45
+ if self.in_dim != self.out_dim:
46
+ if self.use_conv_shortcut:
47
+ x = self.conv_shortcut(x)
48
+ else:
49
+ x = self.nin_shortcut(x)
50
+
51
+ return mp_sum(x, h, t=0.3)
52
+
53
+
54
+ class AttnBlock1D(nn.Module):
55
+
56
+ def __init__(self, in_channels, num_heads=1):
57
+ super().__init__()
58
+ self.in_channels = in_channels
59
+
60
+ self.num_heads = num_heads
61
+ self.qkv = MPConv1D(in_channels, in_channels * 3, kernel_size=1)
62
+ self.proj_out = MPConv1D(in_channels, in_channels, kernel_size=1)
63
+
64
+ def forward(self, x):
65
+ h = x
66
+ y = self.qkv(h)
67
+ y = y.reshape(y.shape[0], self.num_heads, -1, 3, y.shape[-1])
68
+ q, k, v = normalize(y, dim=2).unbind(3)
69
+
70
+ q = rearrange(q, 'b h c l -> b h l c')
71
+ k = rearrange(k, 'b h c l -> b h l c')
72
+ v = rearrange(v, 'b h c l -> b h l c')
73
+
74
+ h = F.scaled_dot_product_attention(q, k, v)
75
+ h = rearrange(h, 'b h l c -> b (h c) l')
76
+
77
+ h = self.proj_out(h)
78
+
79
+ return mp_sum(x, h, t=0.3)
80
+
81
+
82
+ class Upsample1D(nn.Module):
83
+
84
+ def __init__(self, in_channels, with_conv):
85
+ super().__init__()
86
+ self.with_conv = with_conv
87
+ if self.with_conv:
88
+ self.conv = MPConv1D(in_channels, in_channels, kernel_size=3)
89
+
90
+ def forward(self, x):
91
+ x = F.interpolate(x, scale_factor=2.0, mode='nearest-exact') # support 3D tensor(B,C,T)
92
+ if self.with_conv:
93
+ x = self.conv(x)
94
+ return x
95
+
96
+
97
+ class Downsample1D(nn.Module):
98
+
99
+ def __init__(self, in_channels, with_conv):
100
+ super().__init__()
101
+ self.with_conv = with_conv
102
+ if self.with_conv:
103
+ # no asymmetric padding in torch conv, must do it ourselves
104
+ self.conv1 = MPConv1D(in_channels, in_channels, kernel_size=1)
105
+ self.conv2 = MPConv1D(in_channels, in_channels, kernel_size=1)
106
+
107
+ def forward(self, x):
108
+
109
+ if self.with_conv:
110
+ x = self.conv1(x)
111
+
112
+ x = F.avg_pool1d(x, kernel_size=2, stride=2)
113
+
114
+ if self.with_conv:
115
+ x = self.conv2(x)
116
+
117
+ return x
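Editor's note: the building blocks above all operate on (B, C, T) tensors; only Downsample1D/Upsample1D change T. A small shape sketch (random weights, with weight norm removed so the MPConv1D layers will run):

    # Hedged sketch: vae_modules building blocks.
    import torch

    from mmaudio.ext.autoencoder.edm2_utils import MPConv1D
    from mmaudio.ext.autoencoder.vae_modules import (AttnBlock1D, Downsample1D, ResnetBlock1D,
                                                     Upsample1D)

    block = ResnetBlock1D(in_dim=64, out_dim=64)
    attn = AttnBlock1D(64)
    down = Downsample1D(64, with_conv=True)
    up = Upsample1D(64, with_conv=True)

    for module in (block, attn, down, up):
        for m in module.modules():
            if isinstance(m, MPConv1D):
                m.remove_weight_norm()

    x = torch.randn(2, 64, 128)
    h = attn(block(x))    # (2, 64, 128): residual block and attention keep T
    h = down(h)           # (2, 64, 64):  average pooling halves T
    h = up(h)             # (2, 64, 128): nearest-exact interpolation restores T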
hf_AC/mmaudio/ext/bigvgan/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2022 NVIDIA CORPORATION.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
hf_AC/mmaudio/ext/bigvgan/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .bigvgan import BigVGAN
hf_AC/mmaudio/ext/bigvgan/activations.py ADDED
@@ -0,0 +1,120 @@
1
+ # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch
5
+ from torch import nn, sin, pow
6
+ from torch.nn import Parameter
7
+
8
+
9
+ class Snake(nn.Module):
10
+ '''
11
+ Implementation of a sine-based periodic activation function
12
+ Shape:
13
+ - Input: (B, C, T)
14
+ - Output: (B, C, T), same shape as the input
15
+ Parameters:
16
+ - alpha - trainable parameter
17
+ References:
18
+ - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
19
+ https://arxiv.org/abs/2006.08195
20
+ Examples:
21
+ >>> a1 = snake(256)
22
+ >>> x = torch.randn(256)
23
+ >>> x = a1(x)
24
+ '''
25
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
26
+ '''
27
+ Initialization.
28
+ INPUT:
29
+ - in_features: shape of the input
30
+ - alpha: trainable parameter
31
+ alpha is initialized to 1 by default, higher values = higher-frequency.
32
+ alpha will be trained along with the rest of your model.
33
+ '''
34
+ super(Snake, self).__init__()
35
+ self.in_features = in_features
36
+
37
+ # initialize alpha
38
+ self.alpha_logscale = alpha_logscale
39
+ if self.alpha_logscale: # log scale alphas initialized to zeros
40
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
41
+ else: # linear scale alphas initialized to ones
42
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
43
+
44
+ self.alpha.requires_grad = alpha_trainable
45
+
46
+ self.no_div_by_zero = 0.000000001
47
+
48
+ def forward(self, x):
49
+ '''
50
+ Forward pass of the function.
51
+ Applies the function to the input elementwise.
52
+ Snake ∶= x + 1/a * sin^2 (xa)
53
+ '''
54
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
55
+ if self.alpha_logscale:
56
+ alpha = torch.exp(alpha)
57
+ x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
58
+
59
+ return x
60
+
61
+
62
+ class SnakeBeta(nn.Module):
63
+ '''
64
+ A modified Snake function which uses separate parameters for the magnitude of the periodic components
65
+ Shape:
66
+ - Input: (B, C, T)
67
+ - Output: (B, C, T), same shape as the input
68
+ Parameters:
69
+ - alpha - trainable parameter that controls frequency
70
+ - beta - trainable parameter that controls magnitude
71
+ References:
72
+ - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
73
+ https://arxiv.org/abs/2006.08195
74
+ Examples:
75
+ >>> a1 = snakebeta(256)
76
+ >>> x = torch.randn(256)
77
+ >>> x = a1(x)
78
+ '''
79
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
80
+ '''
81
+ Initialization.
82
+ INPUT:
83
+ - in_features: shape of the input
84
+ - alpha - trainable parameter that controls frequency
85
+ - beta - trainable parameter that controls magnitude
86
+ alpha is initialized to 1 by default, higher values = higher-frequency.
87
+ beta is initialized to 1 by default, higher values = higher-magnitude.
88
+ alpha will be trained along with the rest of your model.
89
+ '''
90
+ super(SnakeBeta, self).__init__()
91
+ self.in_features = in_features
92
+
93
+ # initialize alpha
94
+ self.alpha_logscale = alpha_logscale
95
+ if self.alpha_logscale: # log scale alphas initialized to zeros
96
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
97
+ self.beta = Parameter(torch.zeros(in_features) * alpha)
98
+ else: # linear scale alphas initialized to ones
99
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
100
+ self.beta = Parameter(torch.ones(in_features) * alpha)
101
+
102
+ self.alpha.requires_grad = alpha_trainable
103
+ self.beta.requires_grad = alpha_trainable
104
+
105
+ self.no_div_by_zero = 0.000000001
106
+
107
+ def forward(self, x):
108
+ '''
109
+ Forward pass of the function.
110
+ Applies the function to the input elementwise.
111
+ SnakeBeta ∶= x + 1/b * sin^2 (xa)
112
+ '''
113
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
114
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
115
+ if self.alpha_logscale:
116
+ alpha = torch.exp(alpha)
117
+ beta = torch.exp(beta)
118
+ x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
119
+
120
+ return x
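Editor's note: Snake and SnakeBeta are element-wise activations over (B, C, T) tensors; in_features must match the channel dimension because alpha (and beta) are per-channel parameters. A quick sketch:

    # Hedged sketch: both activations preserve the input shape.
    import torch

    from mmaudio.ext.bigvgan.activations import Snake, SnakeBeta

    x = torch.randn(2, 256, 100)                                  # (B, C, T)

    snake = Snake(in_features=256)                                # one alpha per channel
    snake_beta = SnakeBeta(in_features=256, alpha_logscale=True)  # log-scale alpha/beta variant

    print(snake(x).shape)        # torch.Size([2, 256, 100])
    print(snake_beta(x).shape)   # torch.Size([2, 256, 100])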
hf_AC/mmaudio/ext/bigvgan/alias_free_torch/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ from .filter import *
5
+ from .resample import *
6
+ from .act import *
hf_AC/mmaudio/ext/bigvgan/alias_free_torch/act.py ADDED
@@ -0,0 +1,28 @@
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch.nn as nn
5
+ from .resample import UpSample1d, DownSample1d
6
+
7
+
8
+ class Activation1d(nn.Module):
9
+ def __init__(self,
10
+ activation,
11
+ up_ratio: int = 2,
12
+ down_ratio: int = 2,
13
+ up_kernel_size: int = 12,
14
+ down_kernel_size: int = 12):
15
+ super().__init__()
16
+ self.up_ratio = up_ratio
17
+ self.down_ratio = down_ratio
18
+ self.act = activation
19
+ self.upsample = UpSample1d(up_ratio, up_kernel_size)
20
+ self.downsample = DownSample1d(down_ratio, down_kernel_size)
21
+
22
+ # x: [B,C,T]
23
+ def forward(self, x):
24
+ x = self.upsample(x)
25
+ x = self.act(x)
26
+ x = self.downsample(x)
27
+
28
+ return x
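Editor's note: Activation1d implements the anti-aliased activation trick: upsample 2x, apply the nonlinearity, then low-pass and downsample 2x, so the output keeps the input length. A sketch combining it with SnakeBeta from activations.py:

    # Hedged sketch: anti-aliased activation wrapper.
    import torch

    from mmaudio.ext.bigvgan.activations import SnakeBeta
    from mmaudio.ext.bigvgan.alias_free_torch.act import Activation1d

    act = Activation1d(activation=SnakeBeta(in_features=128))

    x = torch.randn(1, 128, 400)   # (B, C, T)
    y = act(x)                     # upsample 2x -> SnakeBeta -> downsample 2x
    print(y.shape)                 # torch.Size([1, 128, 400]); length is preserved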
hf_AC/mmaudio/ext/bigvgan/alias_free_torch/filter.py ADDED
@@ -0,0 +1,95 @@
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import math
8
+
9
+ if 'sinc' in dir(torch):
10
+ sinc = torch.sinc
11
+ else:
12
+ # This code is adapted from adefossez's julius.core.sinc under the MIT License
13
+ # https://adefossez.github.io/julius/julius/core.html
14
+ # LICENSE is in incl_licenses directory.
15
+ def sinc(x: torch.Tensor):
16
+ """
17
+ Implementation of sinc, i.e. sin(pi * x) / (pi * x)
18
+ __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
19
+ """
20
+ return torch.where(x == 0,
21
+ torch.tensor(1., device=x.device, dtype=x.dtype),
22
+ torch.sin(math.pi * x) / math.pi / x)
23
+
24
+
25
+ # This code is adapted from adefossez's julius.lowpass.LowPassFilters under the MIT License
26
+ # https://adefossez.github.io/julius/julius/lowpass.html
27
+ # LICENSE is in incl_licenses directory.
28
+ def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
29
+ even = (kernel_size % 2 == 0)
30
+ half_size = kernel_size // 2
31
+
32
+ #For kaiser window
33
+ delta_f = 4 * half_width
34
+ A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
35
+ if A > 50.:
36
+ beta = 0.1102 * (A - 8.7)
37
+ elif A >= 21.:
38
+ beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
39
+ else:
40
+ beta = 0.
41
+ window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
42
+
43
+ # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
44
+ if even:
45
+ time = (torch.arange(-half_size, half_size) + 0.5)
46
+ else:
47
+ time = torch.arange(kernel_size) - half_size
48
+ if cutoff == 0:
49
+ filter_ = torch.zeros_like(time)
50
+ else:
51
+ filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
52
+ # Normalize filter to have sum = 1, otherwise we will have a small leakage
53
+ # of the constant component in the input signal.
54
+ filter_ /= filter_.sum()
55
+ filter = filter_.view(1, 1, kernel_size)
56
+
57
+ return filter
58
+
59
+
60
+ class LowPassFilter1d(nn.Module):
61
+ def __init__(self,
62
+ cutoff=0.5,
63
+ half_width=0.6,
64
+ stride: int = 1,
65
+ padding: bool = True,
66
+ padding_mode: str = 'replicate',
67
+ kernel_size: int = 12):
68
+ # kernel_size should be even number for stylegan3 setup,
69
+ # in this implementation, odd number is also possible.
70
+ super().__init__()
71
+ if cutoff < -0.:
72
+ raise ValueError("Minimum cutoff must be larger than zero.")
73
+ if cutoff > 0.5:
74
+ raise ValueError("A cutoff above 0.5 does not make sense.")
75
+ self.kernel_size = kernel_size
76
+ self.even = (kernel_size % 2 == 0)
77
+ self.pad_left = kernel_size // 2 - int(self.even)
78
+ self.pad_right = kernel_size // 2
79
+ self.stride = stride
80
+ self.padding = padding
81
+ self.padding_mode = padding_mode
82
+ filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
83
+ self.register_buffer("filter", filter)
84
+
85
+ #input [B, C, T]
86
+ def forward(self, x):
87
+ _, C, _ = x.shape
88
+
89
+ if self.padding:
90
+ x = F.pad(x, (self.pad_left, self.pad_right),
91
+ mode=self.padding_mode)
92
+ out = F.conv1d(x, self.filter.expand(C, -1, -1),
93
+ stride=self.stride, groups=C)
94
+
95
+ return out
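Editor's note: LowPassFilter1d is a fixed (non-trainable) Kaiser-windowed sinc filter; cutoff is a fraction of the sampling rate, so 0.5 corresponds to Nyquist. With stride=1 and padding enabled it preserves the input length:

    # Hedged sketch: low-pass filtering a (B, C, T) signal.
    import torch

    from mmaudio.ext.bigvgan.alias_free_torch.filter import LowPassFilter1d, kaiser_sinc_filter1d

    lp = LowPassFilter1d(cutoff=0.25, half_width=0.05, stride=1, kernel_size=12)

    x = torch.randn(1, 4, 1000)
    y = lp(x)
    print(y.shape)                          # torch.Size([1, 4, 1000])

    # The FIR taps are normalized to sum to 1, so the DC component passes through unchanged.
    taps = kaiser_sinc_filter1d(cutoff=0.25, half_width=0.05, kernel_size=12)
    print(taps.shape, float(taps.sum()))    # torch.Size([1, 1, 12]) ~1.0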
hf_AC/mmaudio/ext/bigvgan/alias_free_torch/resample.py ADDED
@@ -0,0 +1,49 @@
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch.nn as nn
5
+ from torch.nn import functional as F
6
+ from .filter import LowPassFilter1d
7
+ from .filter import kaiser_sinc_filter1d
8
+
9
+
10
+ class UpSample1d(nn.Module):
11
+ def __init__(self, ratio=2, kernel_size=None):
12
+ super().__init__()
13
+ self.ratio = ratio
14
+ self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
15
+ self.stride = ratio
16
+ self.pad = self.kernel_size // ratio - 1
17
+ self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
18
+ self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
19
+ filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio,
20
+ half_width=0.6 / ratio,
21
+ kernel_size=self.kernel_size)
22
+ self.register_buffer("filter", filter)
23
+
24
+ # x: [B, C, T]
25
+ def forward(self, x):
26
+ _, C, _ = x.shape
27
+
28
+ x = F.pad(x, (self.pad, self.pad), mode='replicate')
29
+ x = self.ratio * F.conv_transpose1d(
30
+ x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
31
+ x = x[..., self.pad_left:-self.pad_right]
32
+
33
+ return x
34
+
35
+
36
+ class DownSample1d(nn.Module):
37
+ def __init__(self, ratio=2, kernel_size=None):
38
+ super().__init__()
39
+ self.ratio = ratio
40
+ self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
41
+ self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio,
42
+ half_width=0.6 / ratio,
43
+ stride=ratio,
44
+ kernel_size=self.kernel_size)
45
+
46
+ def forward(self, x):
47
+ xx = self.lowpass(x)
48
+
49
+ return xx
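Editor's note: UpSample1d and DownSample1d are the fixed 2x resamplers used inside Activation1d. They only change the time axis; a ratio-2 round trip returns a tensor of the original length:

    # Hedged sketch: 2x up/down resampling round trip.
    import torch

    from mmaudio.ext.bigvgan.alias_free_torch.resample import DownSample1d, UpSample1d

    up = UpSample1d(ratio=2)
    down = DownSample1d(ratio=2)

    x = torch.randn(1, 8, 256)    # (B, C, T)
    hi = up(x)                    # torch.Size([1, 8, 512])
    lo = down(hi)                 # torch.Size([1, 8, 256])
    print(hi.shape, lo.shape)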
hf_AC/mmaudio/ext/bigvgan/bigvgan.py ADDED
@@ -0,0 +1,32 @@
1
+ from pathlib import Path
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from omegaconf import OmegaConf
6
+
7
+ from mmaudio.ext.bigvgan.models import BigVGANVocoder
8
+
9
+ _bigvgan_vocoder_path = Path(__file__).parent / 'bigvgan_vocoder.yml'
10
+
11
+
12
+ class BigVGAN(nn.Module):
13
+
14
+ def __init__(self, ckpt_path, config_path=_bigvgan_vocoder_path):
15
+ super().__init__()
16
+ vocoder_cfg = OmegaConf.load(config_path)
17
+ self.vocoder = BigVGANVocoder(vocoder_cfg).eval()
18
+ vocoder_ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True)['generator']
19
+ self.vocoder.load_state_dict(vocoder_ckpt)
20
+
21
+ self.weight_norm_removed = False
22
+ self.remove_weight_norm()
23
+
24
+ @torch.inference_mode()
25
+ def forward(self, x):
26
+ assert self.weight_norm_removed, 'call remove_weight_norm() before inference'
27
+ return self.vocoder(x)
28
+
29
+ def remove_weight_norm(self):
30
+ self.vocoder.remove_weight_norm()
31
+ self.weight_norm_removed = True
32
+ return self
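Editor's note: this wrapper is the 16 kHz vocoder path; the 44k path instead loads BigVGANv2 from the hub (see autoencoder.py above). A heavily hedged sketch: the checkpoint path is a placeholder and the 80-band mel shape is an assumption based on the 16k VAE's data_dim, not on files shipped in this commit.

    # Hedged sketch: mel spectrogram -> waveform with the bundled BigVGAN wrapper.
    import torch

    from mmaudio.ext.bigvgan import BigVGAN

    vocoder = BigVGAN('weights/bigvgan_16k_generator.pt')   # placeholder checkpoint path

    mel = torch.randn(1, 80, 512)   # (B, n_mels, T)
    wav = vocoder(mel)              # waveform tensor (exact layout comes from BigVGANVocoder)
    print(wav.shape)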