braindeck committed · Commit bcdf9fa · 0 Parent(s)

Initial commit

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the complete changeset.

Files changed (50)
  1. .gitignore +1 -0
  2. README.md +20 -0
  3. _infer.py +309 -0
  4. app.py +43 -0
  5. config/infer.yaml +56 -0
  6. prompts/__pycache__/base_instruction.cpython-311.pyc +0 -0
  7. prompts/__pycache__/infer_prompt.cpython-311.pyc +0 -0
  8. prompts/__pycache__/sft_prompt.cpython-311.pyc +0 -0
  9. prompts/base_instruction.py +24 -0
  10. prompts/infer_prompt.py +101 -0
  11. requirements.txt +4 -0
  12. requirements.txt.txt +200 -0
  13. scripts/sft_infer_pass1.sh +33 -0
  14. verl/__init__.py +40 -0
  15. verl/__pycache__/__init__.cpython-311.pyc +0 -0
  16. verl/__pycache__/protocol.cpython-311.pyc +0 -0
  17. verl/models/README.md +35 -0
  18. verl/models/__init__.py +13 -0
  19. verl/models/__pycache__/__init__.cpython-311.pyc +0 -0
  20. verl/models/__pycache__/registry.cpython-311.pyc +0 -0
  21. verl/models/llama/__init__.py +13 -0
  22. verl/models/llama/megatron/__init__.py +34 -0
  23. verl/models/llama/megatron/checkpoint_utils/__init__.py +13 -0
  24. verl/models/llama/megatron/checkpoint_utils/llama_loader.py +295 -0
  25. verl/models/llama/megatron/checkpoint_utils/llama_loader_depracated.py +425 -0
  26. verl/models/llama/megatron/checkpoint_utils/llama_saver.py +430 -0
  27. verl/models/llama/megatron/layers/__init__.py +25 -0
  28. verl/models/llama/megatron/layers/parallel_attention.py +425 -0
  29. verl/models/llama/megatron/layers/parallel_decoder.py +150 -0
  30. verl/models/llama/megatron/layers/parallel_linear.py +106 -0
  31. verl/models/llama/megatron/layers/parallel_mlp.py +74 -0
  32. verl/models/llama/megatron/layers/parallel_rmsnorm.py +48 -0
  33. verl/models/llama/megatron/modeling_llama_megatron.py +662 -0
  34. verl/models/mcore/__init__.py +18 -0
  35. verl/models/mcore/config_converter.py +197 -0
  36. verl/models/mcore/loader.py +468 -0
  37. verl/models/mcore/model_forward.py +50 -0
  38. verl/models/mcore/model_initializer.py +160 -0
  39. verl/models/mcore/readme.md +99 -0
  40. verl/models/mcore/registry.py +179 -0
  41. verl/models/mcore/saver.py +459 -0
  42. verl/models/mcore/util.py +190 -0
  43. verl/models/mcore/weight_converter.py +207 -0
  44. verl/models/qwen2/__init__.py +13 -0
  45. verl/models/qwen2/megatron/__init__.py +34 -0
  46. verl/models/qwen2/megatron/checkpoint_utils/__init__.py +13 -0
  47. verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader.py +312 -0
  48. verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader_depracated.py +442 -0
  49. verl/models/qwen2/megatron/checkpoint_utils/qwen2_saver.py +436 -0
  50. verl/models/qwen2/megatron/layers/__init__.py +20 -0
.gitignore ADDED
@@ -0,0 +1 @@
1
+ checkpoints/
README.md ADDED
@@ -0,0 +1,20 @@
1
+ ---
2
+ title: Text-to-Text Generation with DeepSeek-R1-Distill-Qwen-7B
3
+ emoji: 🚀
4
+ colorFrom: blue
5
+ colorTo: green
6
+ sdk: gradio
7
+ sdk_version: 4.19.2
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
+ # Text-to-Text Generation with DeepSeek-R1-Distill-Qwen-7B
13
+
14
+ This is a simple Gradio interface for text-to-text generation using the `deepseek-ai/DeepSeek-R1-Distill-Qwen-7B` model.
15
+
16
+ ## How to use
17
+
18
+ 1. Enter a prompt in the text box.
19
+ 2. Click the "Generate" button.
20
+ 3. The model will generate a response in the "Response" text box.
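
Once the Space is up, the same endpoint can also be called from Python. A minimal sketch using `gradio_client`; the Space id and `api_name` below are placeholders (assumptions), so replace them with the values shown on the Space's "Use via API" page:

```python
# Sketch: call the hosted Gradio app programmatically.
# "braindeck/your-space-name" and api_name are placeholders; check client.view_api() for the real ones.
from gradio_client import Client

client = Client("braindeck/your-space-name")
client.view_api()  # lists the available endpoints and their signatures
result = client.predict("Explain LoRA in one sentence.", api_name="/generate_response")
print(result)
```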
_infer.py ADDED
@@ -0,0 +1,309 @@
1
+ #!/usr/bin/env python
2
+ # Copyright 2024 Bytedance
3
+ # Apache-2.0
4
+ #
5
+ # VERL + vLLM inference with runtime LoRA (no merge).
6
+ # - Wraps a LoRA .pt into a PEFT adapter and attaches via rollout.lora_modules
7
+ # - Mixed precision defaults for H100: dtype=bf16, kv_cache_dtype=fp8_e5m2
8
+ # - Pins max_model_len, max_num_batched_tokens, sets swap_space
9
+ # - Uses OmegaConf.open_dict to add keys safely (no "not in struct" errors)
10
+ # - Prevents FSDP from trying to load LoRA .pt as a full model
11
+
12
+ import os
13
+ import ast
14
+ import json
15
+ import hydra
16
+ import numpy as np
17
+ import ray
18
+ import torch
19
+ from pathlib import Path
20
+ from pprint import pprint
21
+
22
+ # Quiet logs
23
+ os.environ["NCCL_DEBUG"] = os.environ.get("NCCL_DEBUG", "WARN")
24
+ os.environ["TOKENIZERS_PARALLELISM"] = os.environ.get("TOKENIZERS_PARALLELISM", "true")
25
+
26
+ # vLLM CuMem allocator is incompatible with expandable_segments
27
+ _bad = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "")
28
+ if "expandable_segments:True" in _bad:
29
+ print(f"[fix] Removing incompatible PYTORCH_CUDA_ALLOC_CONF={_bad}")
30
+ os.environ.pop("PYTORCH_CUDA_ALLOC_CONF", None)
31
+
32
+ import pandas as pd
33
+ from omegaconf import OmegaConf, open_dict
34
+
35
+ from verl import DataProto
36
+ from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
37
+ from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
38
+ from verl.utils import hf_tokenizer
39
+ from verl.utils.fs import copy_to_local
40
+ from verl.utils.hdfs_io import makedirs
41
+ from verl.utils.model import compute_position_id_with_mask
42
+ from verl.workers.fsdp_workers import ActorRolloutRefWorker
43
+
44
+ # ---------------- LoRA helpers ----------------
45
+
46
+ DEFAULT_TARGET_MODULES = [
47
+ "q_proj","k_proj","v_proj","o_proj",
48
+ "up_proj","gate_proj","down_proj",
49
+ ]
50
+
51
+ def _infer_lengths_and_defaults(config):
52
+ """Ensure rollout/data keys exist and set reasonable H100 defaults."""
53
+ # Ensure nested structs exist
54
+ with open_dict(config):
55
+ if "rollout" not in config:
56
+ config["rollout"] = OmegaConf.create()
57
+ if "data" not in config:
58
+ config["data"] = OmegaConf.create()
59
+ if "trainer" not in config:
60
+ config["trainer"] = OmegaConf.create()
61
+ if "ray_init" not in config:
62
+ config["ray_init"] = OmegaConf.create()
63
+
64
+ # Defaults that work on a single H100
65
+ with open_dict(config.rollout):
66
+ # If user didn't set these, choose H100-friendly defaults
67
+ config.rollout.setdefault("dtype", "bfloat16") # weights/activations
68
+ config.rollout.setdefault("kv_cache_dtype", "fp8_e5m2") # KV cache precision
69
+ config.rollout.setdefault("tensor_model_parallel_size", 1)
70
+ config.rollout.setdefault("enable_chunked_prefill", True)
71
+ config.rollout.setdefault("swap_space", 8) # GB of host swap for KV
72
+ config.rollout.setdefault("gpu_memory_utilization", 0.62) # adjust 0.60~0.75 if needed
73
+
74
+ # Pin lengths to avoid vLLM over-reserving KV cache
75
+ pl = int(config.rollout.get("prompt_length", 1024))
76
+ rl = int(config.rollout.get("response_length", 128))
77
+ need = int(pl + rl)
78
+ config.rollout.setdefault("max_model_len", need)
79
+ config.rollout.setdefault("max_num_batched_tokens", need)
80
+
81
+ # Users may pass +rollout.quantization={fp8|awq|gptq} to shrink weights further
82
+ # We don't force it here.
83
+
84
+ with open_dict(config.data):
85
+ config.data.setdefault("batch_size", 1)
86
+ config.data.setdefault("n_samples", 1)
87
+ config.data.setdefault("prompt_key", "prompt")
88
+
89
+ with open_dict(config.trainer):
90
+ config.trainer.setdefault("n_gpus_per_node", 1)
91
+ config.trainer.setdefault("nnodes", 1)
92
+
93
+ with open_dict(config.ray_init):
94
+ config.ray_init.setdefault("num_cpus", 4)
95
+
96
+ def _infer_lora_rank_from_state(sd):
97
+ for k, v in sd.items():
98
+ if k.endswith("lora_A.weight") and hasattr(v, "dim") and v.dim() == 2:
99
+ return int(v.shape[0])
100
+ return None
101
+
102
+ def _list_target_modules_from_state(sd):
103
+ found = set()
104
+ for k in sd.keys():
105
+ if "lora_A.weight" in k or "lora_B.weight" in k:
106
+ if ".q_proj." in k: found.add("q_proj")
107
+ if ".k_proj." in k: found.add("k_proj")
108
+ if ".v_proj." in k: found.add("v_proj")
109
+ if ".o_proj." in k: found.add("o_proj")
110
+ if ".up_proj." in k: found.add("up_proj")
111
+ if ".gate_proj." in k: found.add("gate_proj")
112
+ if ".down_proj." in k: found.add("down_proj")
113
+ return sorted(found)
114
+
115
+ def _write_adapter_config(adapter_dir, r, alpha, target_modules, dropout=0.0):
116
+ cfg = {
117
+ "peft_type": "LORA",
118
+ "auto_mapping": None,
119
+ "base_model_name_or_path": "",
120
+ "bias": "none",
121
+ "inference_mode": True,
122
+ "lora_alpha": int(alpha),
123
+ "lora_dropout": float(dropout),
124
+ "r": int(r),
125
+ "target_modules": target_modules,
126
+ "task_type": "CAUSAL_LM",
127
+ }
128
+ with open(os.path.join(adapter_dir, "adapter_config.json"), "w", encoding="utf-8") as f:
129
+ json.dump(cfg, f, ensure_ascii=False, indent=2)
130
+
131
+ def _wrap_lora_pt_as_peft(adapter_pt_path: str, out_dir: str,
132
+ fallback_rank=32, fallback_alpha=16):
133
+ os.makedirs(out_dir, exist_ok=True)
134
+ print(f"[lora] Loading LoRA state from: {adapter_pt_path}")
135
+ sd = torch.load(adapter_pt_path, map_location="cpu")
136
+ if isinstance(sd, dict) and "state_dict" in sd:
137
+ sd = sd["state_dict"]
138
+
139
+ r = _infer_lora_rank_from_state(sd) or int(fallback_rank)
140
+ tmods = _list_target_modules_from_state(sd) or DEFAULT_TARGET_MODULES
141
+ print(f"[lora] inferred rank={r}, target_modules={tmods}")
142
+
143
+ _write_adapter_config(out_dir, r=r, alpha=fallback_alpha, target_modules=tmods)
144
+ torch.save(sd, os.path.join(out_dir, "adapter_model.bin"))
145
+ return r, tmods
146
+
147
+ def _maybe_attach_lora_adapter(config):
148
+ """Attach LoRA adapter directory to vLLM rollout (runtime LoRA)."""
149
+ # Accept either +lora.pt_path or model.load_param_path as a hint
150
+ lora_pt = None
151
+ if "lora" in config and getattr(config.lora, "pt_path", ""):
152
+ lora_pt = config.lora.pt_path
153
+ elif getattr(config.model, "load_param_path", ""):
154
+ lora_pt = config.model.load_param_path
155
+
156
+ if not lora_pt or not Path(lora_pt).is_file():
157
+ print("[lora] No LoRA .pt provided; running base model only.")
158
+ return
159
+
160
+ adapter_dir = os.path.join("/tmp", "lora_adapter_vllm")
161
+ r, _ = _wrap_lora_pt_as_peft(lora_pt, adapter_dir, fallback_rank=32, fallback_alpha=16)
162
+
163
+ # Ensure rollout keys exist and add LoRA knobs required by vLLM
164
+ with open_dict(config):
165
+ if "rollout" not in config:
166
+ config["rollout"] = OmegaConf.create()
167
+ with open_dict(config.rollout):
168
+ config.rollout.setdefault("max_loras", 1)
169
+ config.rollout.setdefault("max_lora_rank", int(r))
170
+ config.rollout["lora_modules"] = [{"path": adapter_dir, "scale": 1.0}]
171
+ print(f"[lora] Attached PEFT adapter: {adapter_dir} (rank={r})")
172
+
173
+ # CRITICAL: don't let FSDP try to load the LoRA .pt as a full state dict
174
+ with open_dict(config.model):
175
+ if getattr(config.model, "load_param", False):
176
+ print("[lora] Disabling model.load_param to avoid FSDP load_state_dict mismatch.")
177
+ config.model["load_param"] = False
178
+
179
+ # ---------------- Hydra entry ----------------
180
+
181
+ @hydra.main(config_path="config", config_name="infer", version_base=None)
182
+ def main(config):
183
+ _infer_lengths_and_defaults(config)
184
+
185
+ # Ray env for workers
186
+ if not ray.is_initialized():
187
+ ray.init(
188
+ runtime_env={"env_vars": {
189
+ "TOKENIZERS_PARALLELISM": "true",
190
+ "NCCL_DEBUG": "WARN",
191
+ "PYTORCH_CUDA_ALLOC_CONF": "", # keep allocator happy for vLLM
192
+ }},
193
+ num_cpus=config.ray_init.num_cpus,
194
+ )
195
+
196
+ ray.get(main_task.remote(config))
197
+
198
+ @ray.remote(num_cpus=1)
199
+ def main_task(config):
200
+ print("[worker] PYTORCH_CUDA_ALLOC_CONF =", os.environ.get("PYTORCH_CUDA_ALLOC_CONF"))
201
+ pprint(OmegaConf.to_container(config, resolve=True))
202
+ OmegaConf.resolve(config)
203
+
204
+ # Build LoRA adapter if provided
205
+ _maybe_attach_lora_adapter(config)
206
+
207
+ # Optionally pre-gen dataset schema if your repo provides it
208
+ try:
209
+ from prompts.infer_prompt import infer_dataset
210
+ infer_dataset(
211
+ model_name=config.model.path,
212
+ data_path=os.path.dirname(os.path.dirname(config.data.path)),
213
+ )
214
+ except Exception as e:
215
+ print(f"[info] infer_dataset() skipped: {e}")
216
+
217
+ # ---- Tokenizer from base model
218
+ local_path = copy_to_local(config.model.path)
219
+ trust_remote_code = getattr(config.model, "trust_remote_code", False)
220
+ tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
221
+ tokenizer.padding_side = "left"
222
+ if tokenizer.pad_token is None:
223
+ tokenizer.pad_token = tokenizer.eos_token
224
+
225
+ # ---- Sampling checks
226
+ if float(config.rollout.temperature) == 0.0:
227
+ assert int(config.data.n_samples) == 1, "When temperature=0, n_samples must be 1."
228
+ assert int(config.data.n_samples) >= 1, "n_samples should always >= 1"
229
+
230
+ # ---- Load dataset
231
+ dataset = pd.read_parquet(config.data.path)
232
+ prompt_key = getattr(config.data, "prompt_key", "prompt")
233
+ if prompt_key not in dataset.columns:
234
+ raise KeyError(f"Dataset missing column '{prompt_key}'")
235
+ chat_lst = dataset[prompt_key].tolist()
236
+ chat_lst = [chat.tolist() if hasattr(chat, "tolist") else chat for chat in chat_lst]
237
+
238
+ # ---- Worker group (vLLM inside Rollout)
239
+ ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorRolloutRefWorker), config=config, role="rollout")
240
+ resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes)
241
+ print("[debug] rollout.lora_modules =", config.rollout.get("lora_modules", None))
242
+ wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init)
243
+ wg.init_model() # vLLM spins up; adapter used if set in rollout.lora_modules
244
+
245
+ total = len(dataset)
246
+ bs = int(config.data.batch_size)
247
+ num_batch = -(-total // bs)
248
+ slots = [[] for _ in range(int(config.data.n_samples))]
249
+
250
+ for b in range(num_batch):
251
+ print(f"[{b+1}/{num_batch}] Start to process.")
252
+ batch_chat = chat_lst[b * bs : (b + 1) * bs]
253
+
254
+ inputs = tokenizer.apply_chat_template(
255
+ batch_chat,
256
+ add_generation_prompt=True,
257
+ padding=True,
258
+ truncation=True,
259
+ max_length=int(config.rollout.prompt_length),
260
+ return_tensors="pt",
261
+ return_dict=True,
262
+ tokenize=True,
263
+ )
264
+ input_ids = inputs["input_ids"]
265
+ attention_mask = inputs["attention_mask"]
266
+ position_ids = compute_position_id_with_mask(attention_mask)
267
+ batch_dict = {"input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids}
268
+
269
+ data = DataProto.from_dict(batch_dict)
270
+ data_padded, pad_size = pad_dataproto_to_divisor(data, wg.world_size)
271
+
272
+ print(f"[{b+1}/{num_batch}] Start to generate.")
273
+ for n in range(int(config.data.n_samples)):
274
+ output_padded = wg.generate_sequences(data_padded)
275
+ output = unpad_dataproto(output_padded, pad_size=pad_size)
276
+ texts = []
277
+ for i in range(len(output)):
278
+ item = output[i]
279
+ pl = item.batch["prompts"].shape[-1]
280
+ valid_len = item.batch["attention_mask"][pl:].sum()
281
+ resp_ids = item.batch["responses"][:valid_len]
282
+ s = tokenizer.decode(resp_ids, skip_special_tokens=True)
283
+ print(f"[raw] Response {i}: {s!r}")
284
+ ix = s.find("</think>")
285
+ if ix != -1:
286
+ s = s[ix + len("</think>") :].lstrip()
287
+ print(f"Response {i}: {s!r}")
288
+ try:
289
+ texts.append(ast.literal_eval(s))
290
+ except Exception:
291
+ texts.append(s)
292
+ slots[n].extend(texts)
293
+
294
+ outputs = np.array(slots, dtype=object)
295
+ outputs = np.transpose(outputs, (1, 0)).tolist()
296
+ dataset["response"] = outputs
297
+
298
+ keep = ["file_id", "vt", "gt", "response"]
299
+ cols = [c for c in keep if c in dataset.columns]
300
+ if cols:
301
+ dataset = dataset[cols]
302
+
303
+ out_path = config.data.output_path
304
+ makedirs(os.path.dirname(out_path), exist_ok=True)
305
+ dataset.to_json(out_path, orient="records", lines=True, force_ascii=False)
306
+ print(f"[done] Wrote: {out_path}")
307
+
308
+ if __name__ == "__main__":
309
+ main()
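
`_wrap_lora_pt_as_peft` turns a bare LoRA `.pt` state dict into a PEFT-style adapter directory (`adapter_config.json` + `adapter_model.bin`) so vLLM can attach it at runtime through `rollout.lora_modules`. A minimal sanity-check sketch, assuming the checkpoint's keys already follow PEFT's `...lora_A.weight` / `...lora_B.weight` naming and that the default `/tmp/lora_adapter_vllm` directory has been written; the CPU load is only a structural check, not how the script serves the model:

```python
# Sketch only: verify the wrapped adapter directory loads with PEFT before handing it to vLLM.
import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained(
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
    torch_dtype=torch.bfloat16,
    device_map="cpu",          # CPU is enough for a structural check
)
# /tmp/lora_adapter_vllm is what _wrap_lora_pt_as_peft writes by default
model = PeftModel.from_pretrained(base, "/tmp/lora_adapter_vllm")
print(model.peft_config)       # should report the rank and target_modules inferred from the .pt
```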
app.py ADDED
@@ -0,0 +1,43 @@
1
+
2
+ import gradio as gr
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM
4
+ import torch
5
+
6
+ # Load the model and tokenizer
7
+ tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", trust_remote_code=True)
8
+ model = AutoModelForCausalLM.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="auto")
9
+
10
+ def generate_response(prompt):
11
+ """
12
+ Generates a response from the model.
13
+ """
14
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
15
+ outputs = model.generate(**inputs, max_new_tokens=512)
16
+
17
+ # Decode the generated text
18
+ generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
19
+
20
+ return generated_text
21
+
22
+ # Create the Gradio interface
23
+ with gr.Blocks() as demo:
24
+ gr.Markdown("# Text-to-Text Generation with DeepSeek-R1-Distill-Qwen-7B")
25
+ gr.Markdown("Enter a prompt and the model will generate a response.")
26
+
27
+ with gr.Row():
28
+ prompt_input = gr.Textbox(label="Prompt", lines=4, placeholder="Enter your prompt here...")
29
+
30
+ with gr.Row():
31
+ generate_button = gr.Button("Generate")
32
+
33
+ with gr.Row():
34
+ response_output = gr.Textbox(label="Response", lines=8, interactive=False)
35
+
36
+ generate_button.click(
37
+ fn=generate_response,
38
+ inputs=prompt_input,
39
+ outputs=response_output
40
+ )
41
+
42
+ if __name__ == "__main__":
43
+ demo.launch()
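
app.py feeds the raw prompt string straight into the tokenizer. DeepSeek-R1-Distill-Qwen-7B is a chat-tuned model, so wrapping the prompt with its bundled chat template usually yields better-behaved generations. A hedged alternative sketch for `generate_response`, reusing the `tokenizer` and `model` objects already created above:

```python
# Sketch: same generation, but formatted through the model's chat template.
def generate_response_chat(prompt: str) -> str:
    messages = [{"role": "user", "content": prompt}]
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,   # append the assistant-turn marker
        return_tensors="pt",
    ).to(model.device)
    outputs = model.generate(input_ids, max_new_tokens=512)
    # Decode only the newly generated tokens, not the prompt.
    new_tokens = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
```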
config/infer.yaml ADDED
@@ -0,0 +1,56 @@
1
+ trainer:
2
+ nnodes: 1
3
+ n_gpus_per_node: 1
4
+
5
+ data:
6
+ path: ./data/parquet/test.parquet
7
+ prompt_key: prompt
8
+ n_samples: 1
9
+ output_path: ./checkpoints/grammar_generation.parquet
10
+ batch_size: 1
11
+
12
+ model:
13
+ path: deepseek-ai/DeepSeek-R1-Distill-Qwen-7B
14
+ external_lib: null
15
+ load_param: False
16
+ load_param_path: null
17
+
18
+ rollout:
19
+ name: vllm
20
+ mode: sync # sync: LLM, async: AsyncLLM
21
+ temperature: 0.0
22
+ top_k: -1 # 0 for hf rollout, -1 for vllm rollout
23
+ top_p: 1.0
24
+ max_loras: 1
25
+ prompt_length: 1800
26
+ response_length: 512
27
+ # for vllm rollout
28
+ dtype: bfloat16 # should align with FSDP
29
+ gpu_memory_utilization: 0.9 # fraction of GPU memory vLLM may reserve (raised to leave room for the KV cache)
30
+ ignore_eos: False
31
+ enforce_eager: True
32
+ free_cache_engine: True
33
+ load_format: dummy_dtensor
34
+ tensor_model_parallel_size: 1
35
+ max_num_batched_tokens: 8192
36
+ max_model_len: 1800 # ≥ 1200 + 512
37
+ max_num_seqs: 1024
38
+ log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
39
+ log_prob_micro_batch_size_per_gpu: 1
40
+ # for fire vllm rollout
41
+ use_fire_sampling: False # enable FIRE https://arxiv.org/abs/2410.21236
42
+ # for hf rollout
43
+ do_sample: True
44
+ disable_log_stats: False
45
+ enable_chunked_prefill: True # OK because 8192 ≥ 3072
46
+ n: 1
47
+ # if beam search activated, top_k, temperature and top_p will be ignored
48
+
49
+ actor:
50
+ strategy: fsdp # This is for backward-compatibility
51
+ ulysses_sequence_parallel_size: 1 # sp size
52
+ fsdp_config:
53
+ fsdp_size: -1
54
+
55
+ ray_init:
56
+ num_cpus: null # `null` uses all available CPUs, which can hang on CPU-limited systems such as SLURM; set an explicit allowed number there.
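
The inline comments tie `max_model_len` to `prompt_length + response_length`, and `_infer.py` only fills these keys when they are missing, so a pinned value smaller than the sum silently wins. A small sketch (not part of the repo) that loads this YAML and flags that situation before launching inference:

```python
# Sketch: sanity-check the length budget in config/infer.yaml.
from omegaconf import OmegaConf

cfg = OmegaConf.load("config/infer.yaml")
need = int(cfg.rollout.prompt_length) + int(cfg.rollout.response_length)
if int(cfg.rollout.max_model_len) < need:
    print(f"warning: max_model_len={cfg.rollout.max_model_len} < prompt+response={need}; "
          "long prompts may be truncated or rejected by vLLM")
```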
prompts/__pycache__/base_instruction.cpython-311.pyc ADDED
Binary file (1.44 kB)
prompts/__pycache__/infer_prompt.cpython-311.pyc ADDED
Binary file (6.33 kB)
prompts/__pycache__/sft_prompt.cpython-311.pyc ADDED
Binary file (7.64 kB)
prompts/base_instruction.py ADDED
@@ -0,0 +1,24 @@
1
+ def basic_instruction(content, modelname):
2
+ system_instruction = (
3
+ "당신은 한국어 문장 교정 전문가입니다. "
4
+ "입력 문장은 다양한 오류(자모 분리, 철자 오류, 단어 누락 등)를 포함할 수 있습니다. "
5
+ "당신의 임무는 이러한 잘못된 문장을 완전하고 올바른 한국어 문장으로 복원하는 것입니다.\n"
6
+ "규칙:\n"
7
+ "•출력은 반드시 교정된 한국어 문장만 작성합니다.\n"
8
+ "•불필요한 설명, 이유, 따옴표는 포함하지 않습니다.\n"
9
+ )
10
+
11
+ user_instruction = (
12
+ f"잘못된 문장(노이즈): {content}\n\n"
13
+ "위 문장을 올바른 한국어 문장으로 교정하세요.\n"
14
+ "출력은 반드시 교정된 문장 하나만 작성하세요."
15
+ )
16
+
17
+ return [
18
+ {"role": "system", "content": system_instruction},
19
+ {"role": "user", "content": user_instruction},
20
+ ]
21
+
22
+
23
+ def get_instruction_func(modelname):
24
+ return lambda desc, _: basic_instruction(desc, modelname)
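
The instruction strings are Korean: the system message says, roughly, "You are an expert in correcting Korean sentences; the input may contain various errors (split jamo, spelling mistakes, missing words); restore it to a complete, correct Korean sentence and output only the corrected sentence, with no explanations, reasons, or quotes," and the user message supplies the noisy sentence and repeats the output-only rule. A small sketch showing how the returned message list is rendered into a prompt, assuming the DeepSeek tokenizer used elsewhere in the repo and a made-up noisy input:

```python
# Sketch: render the chat messages produced by basic_instruction into a single prompt string.
from transformers import AutoTokenizer
from prompts.base_instruction import get_instruction_func

model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

build = get_instruction_func(model_name)
messages = build("나 는 오늘 학교에 갓다", None)   # hypothetical noisy input sentence
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
print(prompt)
```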
prompts/infer_prompt.py ADDED
@@ -0,0 +1,101 @@
1
+ from prompts.base_instruction import get_instruction_func
2
+
3
+ def infer_dataset(
4
+ model_name: str,
5
+ data_path: str,
6
+ ):
7
+ import os, json
8
+ from typing import Any, Dict, List
9
+ from datasets import Dataset
10
+ from transformers import AutoTokenizer
11
+
12
+ MAX_TOKENS = 1200 # same as SFT
13
+
14
+ jsonl_path = os.path.join(data_path, "jsonl")
15
+ parquet_path = os.path.join(data_path, "parquet")
16
+ os.makedirs(parquet_path, exist_ok=True)
17
+
18
+ test_jsonl = os.path.join(jsonl_path, "test.jsonl")
19
+
20
+ # --- robust load: tolerant JSONL/array/concatenated JSON
21
+ rows = []
22
+ with open(test_jsonl, "r", encoding="utf-8") as f:
23
+ raw = f.read().strip()
24
+
25
+ try:
26
+ obj = json.loads(raw)
27
+ if isinstance(obj, list):
28
+ rows = [x for x in obj if isinstance(x, dict)]
29
+ except Exception:
30
+ pass
31
+
32
+ if not rows:
33
+ for ln in raw.replace("}{", "}\n{").splitlines():
34
+ ln = ln.strip()
35
+ if not ln:
36
+ continue
37
+ try:
38
+ x = json.loads(ln)
39
+ if isinstance(x, dict):
40
+ rows.append(x)
41
+ except Exception:
42
+ continue
43
+
44
+ test_dataset = Dataset.from_list(rows)
45
+
46
+ instruction_func = get_instruction_func(model_name)
47
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
48
+
49
+ # ─── helpers ───
50
+ def _coerce(rec: Dict[str, Any]) -> Dict[str, Any]:
51
+ r = dict(rec)
52
+ r["vt"] = str(r.get("vt", "") or "")
53
+ return r
54
+
55
+ def _prompt_tokens(prompt_messages) -> int:
56
+ prompt_str = tokenizer.apply_chat_template(
57
+ prompt_messages, add_generation_prompt=True, tokenize=False
58
+ )
59
+ return len(tokenizer(prompt_str, add_special_tokens=False).input_ids)
60
+
61
+ def make_map_fn(split: str):
62
+ def process_fn(example: Dict[str, Any], idx: int) -> Dict[str, Any]:
63
+ ex = _coerce(example)
64
+ vt = ex.get("vt", "").strip()
65
+ if not vt:
66
+ return {}
67
+
68
+ chat_prompt = instruction_func(vt, model_name)
69
+ total_tokens = _prompt_tokens(chat_prompt)
70
+
71
+ extra = {
72
+ "split": split,
73
+ "index": idx,
74
+ "total_tokens": int(total_tokens),
75
+ "file_id": ex.get("file_id")
76
+ }
77
+
78
+ return {
79
+ "prompt": chat_prompt,
80
+ "extra_info": extra,
81
+ "total_tokens": int(total_tokens)
82
+ }
83
+ return process_fn
84
+
85
+ # build prompts + token counts
86
+ test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True)
87
+
88
+ # drop rows where prompt is empty
89
+ test_dataset = test_dataset.filter(lambda ex: bool(ex.get("prompt")))
90
+
91
+ # drop long prompts (> MAX_TOKENS)
92
+ n_before_len = len(test_dataset)
93
+ test_dataset = test_dataset.filter(lambda ex: ex["total_tokens"] <= MAX_TOKENS)
94
+ kept = len(test_dataset)
95
+ dropped_long = n_before_len - kept
96
+
97
+ out_path = os.path.join(parquet_path, "test.parquet")
98
+ test_dataset.to_parquet(out_path)
99
+
100
+ print(f"[test] kept {kept} rows, dropped_long(>{MAX_TOKENS}) {dropped_long}")
101
+ print(f"Wrote {kept} rows → {out_path}")
requirements.txt ADDED
@@ -0,0 +1,4 @@
1
+ gradio
2
+ transformers
3
+ torch
4
+ accelerate
requirements.txt.txt ADDED
@@ -0,0 +1,200 @@
1
+ absl-py==2.3.1
2
+ accelerate==1.6.0
3
+ aiohappyeyeballs==2.6.1
4
+ aiohttp==3.11.16
5
+ aiohttp-cors==0.8.1
6
+ aiosignal==1.3.2
7
+ airportsdata==20250224
8
+ alabaster==1.0.0
9
+ annotated-types==0.7.0
10
+ antlr4-python3-runtime==4.9.3
11
+ anyio==4.9.0
12
+ astor==0.8.1
13
+ attrs==25.3.0
14
+ babel==2.17.0
15
+ blake3==1.0.4
16
+ cachetools==5.5.2
17
+ certifi==2025.1.31
18
+ charset-normalizer==3.4.1
19
+ click==8.1.8
20
+ cloudpickle==3.1.1
21
+ codetiming==1.4.0
22
+ colorful==0.5.6
23
+ compressed-tensors==0.9.2
24
+ cupy-cuda12x==13.4.1
25
+ datasets==3.5.0
26
+ depyf==0.18.0
27
+ dill==0.3.8
28
+ diskcache==5.6.3
29
+ distlib==0.3.9
30
+ distro==1.9.0
31
+ dnspython==2.7.0
32
+ docker-pycreds==0.4.0
33
+ docutils==0.21.2
34
+ einops==0.8.1
35
+ email_validator==2.2.0
36
+ fastapi==0.115.12
37
+ fastapi-cli==0.0.7
38
+ fastrlock==0.8.3
39
+ filelock==3.18.0
40
+ flash_attn==2.7.4.post1
41
+ flashinfer-python==0.2.5
42
+ frozenlist==1.5.0
43
+ fsspec==2024.6.1
44
+ gguf==0.10.0
45
+ gitdb==4.0.12
46
+ GitPython==3.1.44
47
+ google-api-core==2.24.2
48
+ google-auth==2.38.0
49
+ googleapis-common-protos==1.69.2
50
+ grpcio==1.71.0
51
+ h11==0.14.0
52
+ httpcore==1.0.8
53
+ httptools==0.6.4
54
+ httpx==0.28.1
55
+ huggingface-hub==0.30.2
56
+ hydra-core==1.3.2
57
+ idna==3.10
58
+ imagesize==1.4.1
59
+ importlib_metadata==8.6.1
60
+ interegular==0.3.3
61
+ Jinja2==3.1.6
62
+ jiter==0.9.0
63
+ jiwer==4.0.0
64
+ joblib==1.5.2
65
+ jsonschema==4.23.0
66
+ jsonschema-specifications==2024.10.1
67
+ lark==1.2.2
68
+ Levenshtein==0.27.1
69
+ liger_kernel==0.5.9
70
+ llguidance==0.7.14
71
+ llvmlite==0.43.0
72
+ lm-format-enforcer==0.10.11
73
+ markdown-it-py==3.0.0
74
+ MarkupSafe==2.1.5
75
+ mdurl==0.1.2
76
+ mistral_common==1.5.4
77
+ mpmath==1.3.0
78
+ msgpack==1.1.0
79
+ msgspec==0.19.0
80
+ multidict==6.4.3
81
+ multiprocess==0.70.16
82
+ nest-asyncio==1.6.0
83
+ networkx==3.3
84
+ ninja==1.11.1.4
85
+ nltk==3.9.1
86
+ numba==0.60.0
87
+ numpy==1.26.4
88
+ nvidia-cublas-cu12==12.4.5.8
89
+ nvidia-cuda-cupti-cu12==12.4.127
90
+ nvidia-cuda-nvrtc-cu12==12.4.127
91
+ nvidia-cuda-runtime-cu12==12.4.127
92
+ nvidia-cudnn-cu12==9.1.0.70
93
+ nvidia-cufft-cu12==11.2.1.3
94
+ nvidia-curand-cu12==10.3.5.147
95
+ nvidia-cusolver-cu12==11.6.1.9
96
+ nvidia-cusparse-cu12==12.3.1.170
97
+ nvidia-cusparselt-cu12==0.6.2
98
+ nvidia-ml-py==12.570.86
99
+ nvidia-nccl-cu12==2.21.5
100
+ nvidia-nvjitlink-cu12==12.4.127
101
+ nvidia-nvtx-cu12==12.4.127
102
+ omegaconf==2.3.0
103
+ openai==1.73.0
104
+ opencensus==0.11.4
105
+ opencensus-context==0.1.3
106
+ opencv-python-headless==4.11.0.86
107
+ orjson==3.10.16
108
+ outlines==0.1.11
109
+ outlines_core==0.1.26
110
+ packaging==24.2
111
+ pandas==2.2.3
112
+ partial-json-parser==0.2.1.1.post5
113
+ peft==0.15.1
114
+ pillow==11.2.1
115
+ platformdirs==4.3.7
116
+ prometheus-fastapi-instrumentator==7.1.0
117
+ prometheus_client==0.21.1
118
+ propcache==0.3.1
119
+ proto-plus==1.26.1
120
+ protobuf==5.29.4
121
+ psutil==7.0.0
122
+ py-cpuinfo==9.0.0
123
+ py-spy==0.4.0
124
+ pyarrow==19.0.1
125
+ pyasn1==0.6.1
126
+ pyasn1_modules==0.4.2
127
+ pybind11==2.13.6
128
+ pycountry==24.6.1
129
+ pydantic==2.11.3
130
+ pydantic_core==2.33.1
131
+ Pygments==2.19.1
132
+ pylatexenc==2.10
133
+ python-dateutil==2.9.0.post0
134
+ python-dotenv==1.1.0
135
+ python-json-logger==3.3.0
136
+ python-Levenshtein==0.27.1
137
+ python-multipart==0.0.20
138
+ pytz==2025.2
139
+ PyYAML==6.0.2
140
+ pyzmq==26.4.0
141
+ RapidFuzz==3.14.1
142
+ ray==2.44.1
143
+ referencing==0.36.2
144
+ regex==2024.11.6
145
+ requests==2.32.3
146
+ rich==14.0.0
147
+ rich-toolkit==0.14.1
148
+ roman-numerals-py==3.1.0
149
+ rouge_score==0.1.2
150
+ rpds-py==0.24.0
151
+ rsa==4.9
152
+ safetensors==0.5.3
153
+ scipy==1.15.2
154
+ sentencepiece==0.2.0
155
+ sentry-sdk==2.25.1
156
+ setproctitle==1.3.5
157
+ shellingham==1.5.4
158
+ six==1.17.0
159
+ smart-open==7.1.0
160
+ smmap==5.0.2
161
+ sniffio==1.3.1
162
+ snowballstemmer==2.2.0
163
+ Sphinx==8.2.3
164
+ sphinxcontrib-applehelp==2.0.0
165
+ sphinxcontrib-devhelp==2.0.0
166
+ sphinxcontrib-htmlhelp==2.1.0
167
+ sphinxcontrib-jsmath==1.0.1
168
+ sphinxcontrib-qthelp==2.0.0
169
+ sphinxcontrib-serializinghtml==2.0.0
170
+ starlette==0.46.1
171
+ sympy==1.13.1
172
+ tensordict==0.6.2
173
+ tiktoken==0.9.0
174
+ timeout-decorator==0.5.0
175
+ tokenizers==0.21.1
176
+ torch==2.6.0
177
+ torchaudio==2.6.0
178
+ torchdata==0.11.0
179
+ torchvision==0.21.0
180
+ tqdm==4.67.1
181
+ transformers==4.51.2
182
+ triton==3.2.0
183
+ typer==0.15.2
184
+ typing-inspection==0.4.0
185
+ typing_extensions==4.12.2
186
+ tzdata==2025.2
187
+ urllib3==2.4.0
188
+ uvicorn==0.34.0
189
+ uvloop==0.21.0
190
+ virtualenv==20.30.0
191
+ vllm==0.8.2
192
+ wandb==0.19.9
193
+ watchfiles==1.0.5
194
+ websockets==15.0.1
195
+ wrapt==1.17.2
196
+ xformers==0.0.29.post2
197
+ xgrammar==0.1.16
198
+ xxhash==3.5.0
199
+ yarl==1.19.0
200
+ zipp==3.21.0
scripts/sft_infer_pass1.sh ADDED
@@ -0,0 +1,33 @@
1
+ #!/bin/bash
2
+
4
+ set -x
5
+
6
+ python ./_infer.py \
7
+ model.path=./checkpoints/model \
8
+ model.load_param=False \
9
+ data.path=./data/parquet/test.parquet \
10
+ data.output_path=./model_output/[email protected] \
11
+ data.batch_size=32 data.n_samples=1 \
12
+ rollout.tensor_model_parallel_size=1 \
13
+ rollout.temperature=0.7 rollout.top_p=0.9 rollout.n=1 rollout.do_sample=True \
14
+ rollout.prompt_length=1200 rollout.response_length=512 \
15
+ rollout.enable_chunked_prefill=True \
16
+ +rollout.kv_cache_dtype=fp8_e5m2 \
17
+ rollout.max_model_len=1800 \
18
+ rollout.max_num_batched_tokens=1800 \
19
+ rollout.max_num_seqs=1 \
20
+ +model.trust_remote_code=True \
21
+ +rollout.kv_cache_block_size=16 \
22
+ +rollout.swap_space=16 \
23
+ rollout.gpu_memory_utilization=0.7
24
+
25
+ # python ./_infer.py \
26
+ # model.load_param=True \
27
+ # model.load_param_path="./checkpoints/merged_r1qwen14b/model.pt" \
28
+ # data.output_path="./model_output/[email protected]" \
29
+ # data.n_samples=10\
30
+ # data.path="./data/parquet/test.parquet" \
31
+ # rollout.temperature=0.9\
32
+ # rollout.top_p=0.9 \
33
+ # rollout.n=1\
verl/__init__.py ADDED
@@ -0,0 +1,40 @@
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import logging
16
+ import os
17
+
18
+ from .protocol import DataProto
19
+ from .utils.logging_utils import set_basic_config
20
+
21
+ version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__)))
22
+
23
+ with open(os.path.join(version_folder, "version/version")) as f:
24
+ __version__ = f.read().strip()
25
+
26
+
27
+ set_basic_config(level=logging.WARNING)
28
+
29
+
30
+ __all__ = ["DataProto", "__version__"]
31
+
32
+ if os.getenv("VERL_USE_MODELSCOPE", "False").lower() == "true":
33
+ import importlib
34
+
35
+ if importlib.util.find_spec("modelscope") is None:
36
+ raise ImportError("You are using the modelscope hub, please install modelscope by `pip install modelscope -U`")
37
+ # Patch hub to download models from modelscope to speed up.
38
+ from modelscope.utils.hf_util import patch_hub
39
+
40
+ patch_hub()
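
The ModelScope branch only fires if the environment variable is set before `verl` is first imported. A minimal sketch, assuming `modelscope` is installed:

```python
# Sketch: route Hugging Face Hub downloads through ModelScope (faster in some regions).
# VERL_USE_MODELSCOPE must be set before the first `import verl`.
import os

os.environ["VERL_USE_MODELSCOPE"] = "true"

import verl  # patch_hub() runs inside verl/__init__.py at import time
print(verl.__version__)
```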
verl/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.69 kB)
verl/__pycache__/protocol.cpython-311.pyc ADDED
Binary file (47 kB)
verl/models/README.md ADDED
@@ -0,0 +1,35 @@
1
+ # Models
2
+ Common model zoos such as huggingface/transformers struggle when using PyTorch native model parallelism. Following the design principle of vLLM, we keep a simple, parallelizable, highly optimized implementation with packed inputs in verl.
3
+ ## Adding a New Huggingface Model
4
+ ### Step 1: Copy the model file from HF to verl
5
+ - Add a new file under verl/models/hf
6
+ - Copy ONLY the model file from huggingface/transformers/models to verl/models/hf
7
+
8
+ ### Step 2: Modify the model file to use packed inputs
9
+ - Remove all the code related to inference (kv cache)
10
+ - Modify the inputs to include only
11
+ - input_ids (total_nnz,)
12
+ - cu_seqlens (total_nnz + 1,)
13
+ - max_seqlen_in_batch: int
14
+ - Note that this requires using flash attention with causal mask.
15
+
16
+ ### Step 2.5: Add tests
17
+ - Add a test to compare this version and the huggingface version
18
+ - Following the infrastructure and add tests to tests/models/hf
19
+
20
+ ### Step 3: Add a function to apply tensor parallelism
21
+ - Please follow
22
+ - https://pytorch.org/docs/stable/distributed.tensor.parallel.html
23
+ - https://pytorch.org/tutorials/intermediate/TP_tutorial.html
24
+ - General comments
25
+ - Tensor Parallelism in native PyTorch is NOT auto-parallelism. It works by specifying, via configs, how model parameters and inputs/outputs are resharded; these configs are then registered as hooks that reshard inputs/outputs before/after the model forward.
26
+
27
+ ### Step 4: Add a function to apply data parallelism
28
+ - Please use FSDP2 APIs
29
+ - See demo here https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py#L413
30
+
31
+ ### Step 5: Add a function to apply pipeline parallelism
32
+ - Comes in Pytorch 2.4
33
+ - Currently only in alpha in nightly version
34
+ - Check torchtitan for more details
35
+
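
Step 2 asks for packed (unpadded) inputs. Below is a small, framework-free sketch of that format for a right-padded batch; it only illustrates the shapes (`input_ids` of shape `(total_nnz,)`, `cu_seqlens` of shape `(batch_size + 1,)`, and `max_seqlen_in_batch`), not verl's actual packing utilities:

```python
# Sketch: pack a right-padded batch into the unpadded layout described above.
import torch

input_ids_padded = torch.tensor([[11, 12, 13, 0, 0],
                                 [21, 22, 23, 24, 25]])         # (batch, seqlen), 0 = pad
attention_mask = (input_ids_padded != 0).int()

seqlens = attention_mask.sum(dim=-1)                            # tokens per sequence
input_ids = input_ids_padded[attention_mask.bool()]             # (total_nnz,)
cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(0), (1, 0)) # (batch_size + 1,) = [0, 3, 8]
max_seqlen_in_batch = int(seqlens.max())

print(input_ids.shape, cu_seqlens.tolist(), max_seqlen_in_batch)
```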
verl/models/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
verl/models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (176 Bytes)
verl/models/__pycache__/registry.cpython-311.pyc ADDED
Binary file (2.13 kB)
verl/models/llama/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
verl/models/llama/megatron/__init__.py ADDED
@@ -0,0 +1,34 @@
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .modeling_llama_megatron import (
16
+ ParallelLlamaForCausalLM,
17
+ # rmpad with megatron
18
+ ParallelLlamaForCausalLMRmPad,
19
+ # rmpad with megatron and pipeline parallelism
20
+ ParallelLlamaForCausalLMRmPadPP,
21
+ ParallelLlamaForValueRmPad,
22
+ ParallelLlamaForValueRmPadPP,
23
+ # original model with megatron
24
+ ParallelLlamaModel,
25
+ )
26
+
27
+ __all__ = [
28
+ "ParallelLlamaForCausalLM",
29
+ "ParallelLlamaForCausalLMRmPad",
30
+ "ParallelLlamaForCausalLMRmPadPP",
31
+ "ParallelLlamaForValueRmPad",
32
+ "ParallelLlamaForValueRmPadPP",
33
+ "ParallelLlamaModel",
34
+ ]
verl/models/llama/megatron/checkpoint_utils/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
verl/models/llama/megatron/checkpoint_utils/llama_loader.py ADDED
@@ -0,0 +1,295 @@
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import time
16
+
17
+ import torch
18
+ import torch.distributed as dist
19
+
20
+
21
+ def _megatron_calc_layer_map(config):
22
+ """Calculate the mapping of global layer_idx to local layer_idx
23
+ Returns:
24
+ layer_map (Dict: int -> tuple(int, int, int)):
25
+ mapping from the global layer index to
26
+ a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)
27
+ """
28
+ from megatron.core import mpu
29
+
30
+ print(f"get megatron data parallel size: {mpu.get_data_parallel_world_size()}")
31
+
32
+ pp_size = mpu.get_pipeline_model_parallel_world_size()
33
+ virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
34
+
35
+ layer_map = dict()
36
+ num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
37
+ assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers
38
+
39
+ for pp_rank_idx in range(pp_size):
40
+ for virtual_pp_rank_idx in range(virtual_pp_size):
41
+ layer_offset = virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model
42
+ for layer_idx in range(num_layers_per_model):
43
+ layer_map[layer_offset + layer_idx] = (
44
+ pp_rank_idx,
45
+ virtual_pp_rank_idx,
46
+ layer_idx,
47
+ )
48
+ return layer_map
49
+
50
+
51
+ def load_state_dict_to_megatron_llama(state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False):
52
+ """Load merged state_dict to sharded Megatron module in training."""
53
+ from megatron.core import DistributedDataParallel as LocalDDP
54
+ from megatron.core import mpu
55
+ from megatron.core.transformer.module import Float16Module
56
+ from torch.nn.parallel import DistributedDataParallel as torchDDP
57
+
58
+ from verl.utils.megatron_utils import print_rank_0, unwrap_model
59
+
60
+ start_time = time.time()
61
+
62
+ def _get_gpt_model(model):
63
+ return model
64
+
65
+ def fetch_params(module):
66
+ for param in module.parameters():
67
+ torch.distributed.broadcast(param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group())  # torch.distributed has no fetch(); broadcast from the data-parallel source rank
68
+
69
+ dp_rank = mpu.get_data_parallel_rank()
70
+ pp_rank = mpu.get_pipeline_model_parallel_rank()
71
+ pp_size = mpu.get_pipeline_model_parallel_world_size()
72
+ virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
73
+ mp_group = mpu.get_model_parallel_group()
74
+
75
+ if torch.distributed.get_rank() == 0:
76
+ assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0"
77
+ assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0"
78
+ assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0"
79
+
80
+ if not isinstance(wrapped_models, (list, tuple)):
81
+ wrapped_models = list(wrapped_models)
82
+
83
+ assert len(wrapped_models) == virtual_pp_size
84
+ num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
85
+ assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, f"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size {virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}"
86
+
87
+ models = [None] * len(wrapped_models)
88
+
89
+ for i, wrapped_model in enumerate(wrapped_models):
90
+ models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))
91
+ gpt_model_module = _get_gpt_model(models[i])
92
+ assert len(gpt_model_module.model.layers) == num_layers_per_model
93
+
94
+ def _fetch_tensor(tensor, name) -> torch.Tensor:
95
+ """fetch tensor"""
96
+ nonlocal state_dict
97
+ if tensor is not None:
98
+ tensor.data.copy_(state_dict[name])
99
+
100
+ def _fetch_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:
101
+ """fetch tensor in tp shards"""
102
+ nonlocal state_dict
103
+ tp_rank = mpu.get_tensor_model_parallel_rank()
104
+ tp_size = mpu.get_tensor_model_parallel_world_size()
105
+ if name in state_dict:
106
+ full_weight = state_dict[name]
107
+
108
+ if mutate_func is not None:
109
+ full_weight = mutate_func(full_weight)
110
+ tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)
111
+ if tensor is not None:
112
+ tensor.data.copy_(tensor_chunk[tp_rank])
113
+ else:
114
+ print(f"tp_shard tensor:[{name}] not in state_dict, skip loading")
115
+
116
+ def _fetch_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:
117
+ """fetch tensor in tp shards"""
118
+ nonlocal state_dict
119
+ tp_rank = mpu.get_tensor_model_parallel_rank()
120
+ tp_size = mpu.get_tensor_model_parallel_world_size()
121
+ if name in state_dict:
122
+ full_weight = state_dict[name]
123
+
124
+ if mutate_func is not None:
125
+ full_weight = mutate_func(full_weight)
126
+ tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)
127
+ if tensor is not None:
128
+ tensor.data.copy_(tensor_chunk[tp_rank])
129
+ else:
130
+ print(f"tp_shard tensor:[{name}] not in state_dict, skip loading")
131
+
132
+ def _fetch_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor:
133
+ """fetch gate_up tensor in tp shards"""
134
+ nonlocal state_dict
135
+ nonlocal mp_group
136
+ tp_rank = mpu.get_tensor_model_parallel_rank()
137
+ tp_size = mpu.get_tensor_model_parallel_world_size()
138
+ if gate_name in state_dict and up_name in state_dict:
139
+ gate_weight = state_dict[gate_name]
140
+ up_weight = state_dict[up_name]
141
+ new_gate_up_weight = torch.empty(config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device())
142
+ for i in range(tp_size):
143
+ intermediate_size_tp = config.intermediate_size // tp_size
144
+ gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]
145
+ up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]
146
+ new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_(torch.cat([gate_weight_tp, up_weight_tp], dim=0))
147
+
148
+ tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0)
149
+ if tensor is not None:
150
+ tensor.data.copy_(tensor_chunk[tp_rank])
151
+ else:
152
+ print(f"tp_shard tensor:[{gate_name}, {up_name}] not in state_dict, skip loading")
153
+
154
+ def _fetch_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name) -> torch.Tensor:
155
+ """fetch tensor in tp shards across mp_group"""
156
+ nonlocal state_dict
157
+ nonlocal mp_group
158
+ tp_rank = mpu.get_tensor_model_parallel_rank()
159
+ tp_size = mpu.get_tensor_model_parallel_world_size()
160
+ assert q_name in state_dict and k_name in state_dict and v_name in state_dict
161
+ full_weight_q = state_dict[q_name]
162
+ full_weight_k = state_dict[k_name]
163
+ full_weight_v = state_dict[v_name]
164
+
165
+ hidden_size_per_head = config.hidden_size // config.num_attention_heads
166
+
167
+ if config.num_key_value_heads >= tp_size:
168
+ q_size_tp = config.hidden_size // tp_size
169
+ kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size
170
+ total_size = q_size_tp + 2 * kv_size_tp
171
+ new_weight_qkv = torch.empty(total_size * tp_size, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device())
172
+ for i in range(tp_size):
173
+ q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]
174
+ k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp]
175
+ v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp]
176
+ new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0))
177
+
178
+ else:
179
+ q_size_tp = config.hidden_size // tp_size
180
+ kv_size_tp = hidden_size_per_head
181
+ total_size = q_size_tp + 2 * kv_size_tp
182
+ new_weight_qkv = torch.empty(total_size * tp_size, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device())
183
+ for i in range(tp_size):
184
+ q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]
185
+ start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head
186
+ end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head
187
+ k_part = full_weight_k[start_idx:end_idx]
188
+ v_part = full_weight_v[start_idx:end_idx]
189
+ new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0))
190
+
191
+ tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0)
192
+ if tensor is not None:
193
+ tensor.data.copy_(tensor_chunk[tp_rank])
194
+
195
+ # Embeddings
196
+ # -------------------
197
+ print_rank_0("loading embeddings...")
198
+ gpt_model_module = _get_gpt_model(models[0])
199
+ embed_tokens_weight = None
200
+ if pp_rank == 0:
201
+ embed_tokens_weight = gpt_model_module.model.embed_tokens.weight
202
+ _fetch_tp_shard_tensor_vocab(embed_tokens_weight, "model.embed_tokens.weight")
203
+
204
+ # Transformer layers
205
+ # -------------------
206
+ layer_map = _megatron_calc_layer_map(config)
207
+
208
+ pp_rank = mpu.get_pipeline_model_parallel_rank()
209
+ pp_size = mpu.get_pipeline_model_parallel_world_size()
210
+ num_layer_per_pp = config.num_hidden_layers // pp_size
211
+ vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size()
212
+
213
+ layer_list = []
214
+ if vpp_size is not None:
215
+ for vpp_rank in range(vpp_size):
216
+ num_layer_vpp_chunk = num_layer_per_pp // vpp_size
217
+ num_layer_this_model = num_layer_vpp_chunk
218
+ offset = vpp_rank * (config.num_hidden_layers // mpu.get_virtual_pipeline_model_parallel_world_size()) + (mpu.get_pipeline_model_parallel_rank() * num_layer_vpp_chunk)
219
+ layer_list.extend(list(range(offset, offset + num_layer_this_model)))
220
+ else:
221
+ num_layer_this_model = num_layer_per_pp
222
+ offset = pp_rank * num_layer_per_pp
223
+ layer_list.extend(list(range(offset, offset + num_layer_this_model)))
224
+
225
+ for layer in layer_list:
226
+ print_rank_0(f"loading layer #{layer}...")
227
+ layer_name = f"model.layers.{layer}"
228
+ dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer]
229
+
230
+ gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank])
231
+ sync_layer = gpt_model_module.model.layers[dst_layer_idx]
232
+
233
+ _fetch_tensor(
234
+ sync_layer.input_layernorm.weight if dst_pp_rank == pp_rank else None,
235
+ f"{layer_name}.input_layernorm.weight",
236
+ )
237
+
238
+ _fetch_tp_shard_tensor_qkv(
239
+ sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else None,
240
+ f"{layer_name}.self_attn.q_proj.weight",
241
+ f"{layer_name}.self_attn.k_proj.weight",
242
+ f"{layer_name}.self_attn.v_proj.weight",
243
+ )
244
+
245
+ _fetch_tp_shard_tensor(
246
+ sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None,
247
+ f"{layer_name}.self_attn.o_proj.weight",
248
+ chunk_dim=1,
249
+ )
250
+
251
+ _fetch_tensor(
252
+ sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None,
253
+ f"{layer_name}.post_attention_layernorm.weight",
254
+ )
255
+
256
+ _fetch_tp_shard_tensor_gate_up(
257
+ sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None,
258
+ f"{layer_name}.mlp.gate_proj.weight",
259
+ f"{layer_name}.mlp.up_proj.weight",
260
+ )
261
+
262
+ _fetch_tp_shard_tensor(
263
+ sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None,
264
+ f"{layer_name}.mlp.down_proj.weight",
265
+ chunk_dim=1,
266
+ )
267
+ # Final Layernorm
268
+ # -------------------
269
+ print_rank_0("loading final layernorm...")
270
+ gpt_model_module = _get_gpt_model(models[-1])
271
+ _fetch_tensor(
272
+ getattr(gpt_model_module.model.norm, "weight", None),
273
+ "model.norm.weight",
274
+ )
275
+
276
+ print_rank_0("loading lm_head...")
277
+ if pp_rank + 1 == pp_size:
278
+ lm_head_weight = gpt_model_module.lm_head.weight
279
+
280
+ if is_value_model:
281
+ if "lm_head.weight" in state_dict and state_dict["lm_head.weight"].shape[0] == 1:
282
+ _fetch_tensor(lm_head_weight, "lm_head.weight")
283
+ print_rank_0("load lm_head weight")
284
+ elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1:
285
+ _fetch_tensor(lm_head_weight, "reward_head.weight")
286
+ print_rank_0("load lm_head from value_head weight")
287
+ else:
288
+ _fetch_tensor(None, "lm_head.weight")
289
+ print_rank_0("fail to match lm_head in value_model")
290
+ else:
291
+ _fetch_tp_shard_tensor(lm_head_weight, "lm_head.weight")
292
+
293
+ dist.barrier()
294
+ torch.cuda.empty_cache()
295
+ print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s")
verl/models/llama/megatron/checkpoint_utils/llama_loader_depracated.py ADDED
@@ -0,0 +1,425 @@
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import time
16
+
17
+ import torch
18
+ import torch.distributed as dist
19
+
20
+
21
+ def _megatron_calc_layer_map(config):
22
+ """Calculate the mapping of global layer_idx to local layer_idx
23
+ Returns:
24
+ layer_map (Dict: int -> tuple(int, int, int)):
25
+ mapping from the global layer index to
26
+ a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)
27
+ """
28
+ from megatron.core import mpu
29
+
30
+ print(f"get megatron data parallel size: {mpu.get_data_parallel_world_size()}")
31
+
32
+ pp_size = mpu.get_pipeline_model_parallel_world_size()
33
+ virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
34
+
35
+ layer_map = dict()
36
+ num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
37
+ assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers
38
+
39
+ for pp_rank_idx in range(pp_size):
40
+ for virtual_pp_rank_idx in range(virtual_pp_size):
41
+ layer_offset = virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model
42
+ for layer_idx in range(num_layers_per_model):
43
+ layer_map[layer_offset + layer_idx] = (
44
+ pp_rank_idx,
45
+ virtual_pp_rank_idx,
46
+ layer_idx,
47
+ )
48
+ return layer_map
49
+
50
+
51
+ def load_state_dict_to_megatron_llama(state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False):
52
+ """Load merged state_dict to sharded Megatron module in training."""
53
+ from megatron.core import DistributedDataParallel as LocalDDP
54
+ from megatron.core import mpu
55
+ from megatron.core.transformer.module import Float16Module
56
+ from torch.nn.parallel import DistributedDataParallel as torchDDP
57
+
58
+ from verl.utils.megatron_utils import print_rank_0, unwrap_model
59
+
60
+ start_time = time.time()
61
+
62
+ def _get_gpt_model(model):
63
+ return model
64
+
65
+ def broadcast_params(module):
66
+ for param in module.parameters():
67
+ torch.distributed.broadcast(param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group())
68
+
69
+ dp_rank = mpu.get_data_parallel_rank()
70
+ pp_rank = mpu.get_pipeline_model_parallel_rank()
71
+ pp_size = mpu.get_pipeline_model_parallel_world_size()
72
+ virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
73
+ mp_group = mpu.get_model_parallel_group()
74
+
75
+ if torch.distributed.get_rank() == 0:
76
+ assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0"
77
+ assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0"
78
+ assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0"
79
+
80
+ if not isinstance(wrapped_models, (list, tuple)):
81
+ wrapped_models = list(wrapped_models)
82
+
83
+ assert len(wrapped_models) == virtual_pp_size
84
+ num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
85
+ assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, f"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size {virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}"
86
+
87
+ models = [None] * len(wrapped_models)
88
+
89
+ for i, wrapped_model in enumerate(wrapped_models):
90
+ models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))
91
+ gpt_model_module = _get_gpt_model(models[i])
92
+ assert len(gpt_model_module.model.layers) == num_layers_per_model
93
+
94
+ def _broadcast_tensor(tensor, name) -> torch.Tensor:
95
+ """broadcast tensor from rank0 across mp_group"""
96
+ nonlocal state_dict
97
+ nonlocal mp_group
98
+ if torch.distributed.get_rank() == 0:
99
+ if name in state_dict:
100
+ weight = state_dict[name]
101
+ tensor_shape = weight.shape
102
+ else:
103
+ tensor_shape = None
104
+ else:
105
+ weight = None
106
+ tensor_shape = None
107
+
108
+ obj_list = [tensor_shape]
109
+ dist.broadcast_object_list(obj_list, src=0, group=mp_group)
110
+ tensor_shape = obj_list[0]
111
+
112
+ if tensor_shape is None:
113
+ # all or none ranks in the mp_group should reach here
114
+ print_rank_0(f"tensor:[{name}] not in state_dict, skip load")
115
+ return
116
+
117
+ if tensor is None:
118
+ tensor = torch.empty(
119
+ tensor_shape,
120
+ dtype=params_dtype,
121
+ device=torch.cuda.current_device(),
122
+ requires_grad=False,
123
+ )
124
+ if torch.distributed.get_rank() == 0:
125
+ tensor.data.copy_(weight)
126
+ dist.broadcast(tensor, src=0, group=mp_group)
127
+
128
+ def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:
129
+ """broadcast tensor in tp shards across mp_group"""
130
+ nonlocal state_dict
131
+ nonlocal mp_group
132
+ tp_rank = mpu.get_tensor_model_parallel_rank()
133
+ tp_size = mpu.get_tensor_model_parallel_world_size()
134
+
135
+ if torch.distributed.get_rank() == 0:
136
+ if name in state_dict:
137
+ full_weight = state_dict[name]
138
+
139
+ if mutate_func is not None:
140
+ full_weight = mutate_func(full_weight)
141
+ tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)
142
+ chunk_shape = tensor_chunk[0].shape
143
+ else:
144
+ chunk_shape = None
145
+ else:
146
+ chunk_shape = None
147
+
148
+ obj_list = [chunk_shape]
149
+ dist.broadcast_object_list(obj_list, src=0, group=mp_group)
150
+ chunk_shape = obj_list[0]
151
+ if chunk_shape is None:
152
+ # all or none ranks in the mp_group should reach here
153
+ print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading")
154
+ return
155
+
156
+ if tensor is None:
157
+ sync_tensor = torch.empty(
158
+ chunk_shape,
159
+ dtype=params_dtype,
160
+ device=torch.cuda.current_device(),
161
+ requires_grad=False,
162
+ )
163
+ else:
164
+ assert tensor.shape == chunk_shape, f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}"
165
+ sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)
166
+
167
+ for i in range(tp_size):
168
+ if torch.distributed.get_rank() == 0:
169
+ sync_tensor.data.copy_(tensor_chunk[i])
170
+ dist.broadcast(sync_tensor, src=0, group=mp_group)
171
+ if (i == tp_rank) and (tensor is not None):
172
+ tensor.data.copy_(sync_tensor)
173
+
174
+ def _broadcast_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:
175
+ """broadcast tensor in tp shards across mp_group"""
176
+ nonlocal state_dict
177
+ nonlocal mp_group
178
+ tp_rank = mpu.get_tensor_model_parallel_rank()
179
+ tp_size = mpu.get_tensor_model_parallel_world_size()
180
+
181
+ if torch.distributed.get_rank() == 0:
182
+ if name in state_dict:
183
+ full_weight = state_dict[name]
184
+ if mutate_func is not None:
185
+ full_weight = mutate_func(full_weight)
186
+ tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)
187
+ chunk_shape = tensor_chunk[0].shape
188
+ else:
189
+ chunk_shape = None
190
+ else:
191
+ chunk_shape = None
192
+
193
+ obj_list = [chunk_shape]
194
+ dist.broadcast_object_list(obj_list, src=0, group=mp_group)
195
+ chunk_shape = obj_list[0]
196
+ if chunk_shape is None:
197
+ # all or none ranks in the mp_group should reach here
198
+ print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading")
199
+ return
200
+
201
+ if tensor is None:
202
+ sync_tensor = torch.empty(
203
+ chunk_shape,
204
+ dtype=params_dtype,
205
+ device=torch.cuda.current_device(),
206
+ requires_grad=False,
207
+ )
208
+ else:
209
+ assert tensor.shape == chunk_shape, f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}"
210
+ sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)
211
+
212
+ for i in range(tp_size):
213
+ if torch.distributed.get_rank() == 0:
214
+ sync_tensor.data.copy_(tensor_chunk[i])
215
+ dist.broadcast(sync_tensor, src=0, group=mp_group)
216
+ if (i == tp_rank) and (tensor is not None):
217
+ tensor.data.copy_(sync_tensor)
218
+
219
+ def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor:
220
+ """broadcast tensor in tp shards across mp_group"""
221
+ nonlocal state_dict
222
+ nonlocal mp_group
223
+ tp_rank = mpu.get_tensor_model_parallel_rank()
224
+ tp_size = mpu.get_tensor_model_parallel_world_size()
225
+
226
+ if torch.distributed.get_rank() == 0:
227
+ gate_weight = state_dict[gate_name]
228
+ up_weight = state_dict[up_name]
229
+ new_gate_up_weight = torch.empty(config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device())
230
+ for i in range(tp_size):
231
+ intermediate_size_tp = config.intermediate_size // tp_size
232
+ gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]
233
+ up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]
234
+ new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_(torch.cat([gate_weight_tp, up_weight_tp], dim=0))
235
+
236
+ tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0)
237
+ chunk_shape = tensor_chunk[0].shape
238
+ else:
239
+ chunk_shape = None
240
+
241
+ obj_list = [chunk_shape]
242
+ dist.broadcast_object_list(obj_list, src=0, group=mp_group)
243
+ chunk_shape = obj_list[0]
244
+ if chunk_shape is None:
245
+ # all or none ranks in the mp_group should reach here
246
+ print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not in state_dict, skip loading")
247
+ return
248
+
249
+ if tensor is None:
250
+ sync_tensor = torch.empty(
251
+ chunk_shape,
252
+ dtype=params_dtype,
253
+ device=torch.cuda.current_device(),
254
+ requires_grad=False,
255
+ )
256
+ else:
257
+ assert tensor.shape == chunk_shape, f"rank #{torch.distributed.get_rank() == 0:} tensor {gate_name, up_name} shape {tensor.shape} != {chunk_shape}"
258
+ sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)
259
+
260
+ for i in range(tp_size):
261
+ if torch.distributed.get_rank() == 0:
262
+ sync_tensor.data.copy_(tensor_chunk[i])
263
+ dist.broadcast(sync_tensor, src=0, group=mp_group)
264
+ if (i == tp_rank) and (tensor is not None):
265
+ tensor.data.copy_(sync_tensor)
266
+
267
+ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name) -> torch.Tensor:
268
+ """broadcast tensor in tp shards across mp_group"""
269
+ nonlocal state_dict
270
+ nonlocal mp_group
271
+ tp_rank = mpu.get_tensor_model_parallel_rank()
272
+ tp_size = mpu.get_tensor_model_parallel_world_size()
273
+
274
+ if torch.distributed.get_rank() == 0:
275
+ assert q_name in state_dict and k_name in state_dict and v_name in state_dict
276
+ full_weight_q = state_dict[q_name]
277
+ full_weight_k = state_dict[k_name]
278
+ full_weight_v = state_dict[v_name]
279
+
280
+ hidden_size_per_head = config.hidden_size // config.num_attention_heads
281
+
282
+ if config.num_key_value_heads >= tp_size:
283
+ q_size_tp = config.hidden_size // tp_size
284
+ kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size
285
+ total_size = q_size_tp + 2 * kv_size_tp
286
+ new_weight_qkv = torch.empty(total_size * tp_size, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device())
287
+ for i in range(tp_size):
288
+ q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]
289
+ k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp]
290
+ v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp]
291
+ new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0))
292
+
293
+ else:
294
+ q_size_tp = config.hidden_size // tp_size
295
+ kv_size_tp = hidden_size_per_head
296
+ total_size = q_size_tp + 2 * kv_size_tp
297
+ new_weight_qkv = torch.empty(total_size * tp_size, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device())
298
+ for i in range(tp_size):
299
+ q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]
300
+ start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head
301
+ end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head
302
+ k_part = full_weight_k[start_idx:end_idx]
303
+ v_part = full_weight_v[start_idx:end_idx]
304
+ new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0))
305
+
306
+ tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0)
307
+ chunk_shape = tensor_chunk[0].shape
308
+ else:
309
+ chunk_shape = None
310
+
311
+ obj_list = [chunk_shape]
312
+ dist.broadcast_object_list(obj_list, src=0, group=mp_group)
313
+ chunk_shape = obj_list[0]
314
+ if chunk_shape is None:
315
+ # all or none ranks in the mp_group should reach here
316
+ print_rank_0(f"tp_shard tensor:[{q_name, k_name, v_name}] not in state_dict, skip loading")
317
+ return
318
+
319
+ if tensor is None:
320
+ sync_tensor = torch.empty(
321
+ chunk_shape,
322
+ dtype=params_dtype,
323
+ device=torch.cuda.current_device(),
324
+ requires_grad=False,
325
+ )
326
+ else:
327
+ assert tensor.shape == chunk_shape, f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}"
328
+ sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)
329
+
330
+ for i in range(tp_size):
331
+ if torch.distributed.get_rank() == 0:
332
+ sync_tensor.data.copy_(tensor_chunk[i])
333
+ dist.broadcast(sync_tensor, src=0, group=mp_group)
334
+ if (i == tp_rank) and (tensor is not None):
335
+ tensor.data.copy_(sync_tensor)
336
+
337
+ if dp_rank == 0:
338
+ # Embeddings
339
+ # -------------------
340
+ print_rank_0("loading embeddings...")
341
+ gpt_model_module = _get_gpt_model(models[0])
342
+ embed_tokens_weight = None
343
+ if pp_rank == 0:
344
+ embed_tokens_weight = gpt_model_module.model.embed_tokens.weight
345
+ _broadcast_tp_shard_tensor_vocab(embed_tokens_weight, "model.embed_tokens.weight")
346
+
347
+ # Transformer layers
348
+ # -------------------
349
+ layer_map = _megatron_calc_layer_map(config)
350
+
351
+ for layer in range(config.num_hidden_layers):
352
+ print_rank_0(f"loading layer #{layer}...")
353
+ layer_name = f"model.layers.{layer}"
354
+ dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer]
355
+
356
+ gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank])
357
+ sync_layer = gpt_model_module.model.layers[dst_layer_idx]
358
+
359
+ _broadcast_tensor(
360
+ sync_layer.input_layernorm.weight if dst_pp_rank == pp_rank else None,
361
+ f"{layer_name}.input_layernorm.weight",
362
+ )
363
+
364
+ _broadcast_tp_shard_tensor_qkv(
365
+ sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else None,
366
+ f"{layer_name}.self_attn.q_proj.weight",
367
+ f"{layer_name}.self_attn.k_proj.weight",
368
+ f"{layer_name}.self_attn.v_proj.weight",
369
+ )
370
+
371
+ _broadcast_tp_shard_tensor(
372
+ sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None,
373
+ f"{layer_name}.self_attn.o_proj.weight",
374
+ chunk_dim=1,
375
+ )
376
+
377
+ _broadcast_tensor(
378
+ sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None,
379
+ f"{layer_name}.post_attention_layernorm.weight",
380
+ )
381
+
382
+ _broadcast_tp_shard_tensor_gate_up(
383
+ sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None,
384
+ f"{layer_name}.mlp.gate_proj.weight",
385
+ f"{layer_name}.mlp.up_proj.weight",
386
+ )
387
+
388
+ _broadcast_tp_shard_tensor(
389
+ sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None,
390
+ f"{layer_name}.mlp.down_proj.weight",
391
+ chunk_dim=1,
392
+ )
393
+ # Final Layernorm
394
+ # -------------------
395
+ print_rank_0("loading final layernorm...")
396
+ gpt_model_module = _get_gpt_model(models[-1])
397
+ _broadcast_tensor(
398
+ getattr(gpt_model_module.model.norm, "weight", None),
399
+ "model.norm.weight",
400
+ )
401
+
402
+ print_rank_0("loading lm_head...")
403
+ lm_head_weight = None
404
+ if pp_rank + 1 == pp_size:
405
+ lm_head_weight = gpt_model_module.lm_head.weight
406
+
407
+ if is_value_model:
408
+ if "lm_head.weight" in state_dict and state_dict["lm_head.weight"].shape[0] == 1:
409
+ _broadcast_tensor(lm_head_weight, "lm_head.weight")
410
+ print_rank_0("load lm_head weight")
411
+ elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1:
412
+ _broadcast_tensor(lm_head_weight, "reward_head.weight")
413
+ print_rank_0("load lm_head from value_head weight")
414
+ else:
415
+ _broadcast_tensor(None, "lm_head.weight")
416
+ print_rank_0("fail to match lm_head in value_model")
417
+ else:
418
+ _broadcast_tp_shard_tensor(lm_head_weight, "lm_head.weight")
419
+ dist.barrier()
420
+ # Broadcast weights inside data parallel groups
421
+ for wrapped_model in wrapped_models:
422
+ broadcast_params(wrapped_model)
423
+
424
+ torch.cuda.empty_cache()
425
+ print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s")
verl/models/llama/megatron/checkpoint_utils/llama_saver.py ADDED
@@ -0,0 +1,430 @@
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import time
16
+
17
+ import torch
18
+ import torch.distributed as dist
19
+ from megatron.core import mpu
20
+ from megatron.core.distributed import DistributedDataParallel as LocalDDP
21
+ from megatron.core.transformer.module import Float16Module
22
+ from torch.nn.parallel import DistributedDataParallel as torchDDP
23
+
24
+ from verl.utils.megatron_utils import print_rank_0, unwrap_model
25
+
26
+
27
+ def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0):
28
+ """given TP,DP,PP rank to get the global rank."""
29
+
30
+ tp_size = mpu.get_tensor_model_parallel_world_size()
31
+ dp_size = mpu.get_data_parallel_world_size()
32
+ pp_size = mpu.get_pipeline_model_parallel_world_size()
33
+ assert tp_size * dp_size * pp_size == torch.distributed.get_world_size(), f"{tp_size} x {dp_size} x {pp_size} != {torch.distributed.get_world_size()}"
34
+ # We only support TP-DP-PP grouping, for correctness when resharding
35
+ return (pp_rank * dp_size + dp_rank) * tp_size + tp_rank
36
+
37
+
38
+ def _megatron_calc_layer_map(config):
39
+ """Calculate the mapping of global layer_idx to local layer_idx
40
+ Returns:
41
+ layer_map (Dict: int -> tuple(int, int, int)):
42
+ mapping from the global layer index to
43
+ a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)
44
+ """
45
+ from megatron.core import mpu
46
+
47
+ pp_size = mpu.get_pipeline_model_parallel_world_size()
48
+ virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
49
+
50
+ layer_map = dict()
51
+ num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
52
+ assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers
53
+
54
+ for pp_rank_idx in range(pp_size):
55
+ for virtual_pp_rank_idx in range(virtual_pp_size):
56
+ layer_offset = virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model
57
+ for layer_idx in range(num_layers_per_model):
58
+ layer_map[layer_offset + layer_idx] = (
59
+ pp_rank_idx,
60
+ virtual_pp_rank_idx,
61
+ layer_idx,
62
+ )
63
+ return layer_map
64
+
65
+
66
+ def merge_megatron_ckpt_llama(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False):
67
+ """Merge sharded parameters of a Megatron module into a merged checkpoint.
68
+
69
+ Args:
70
+ wrapped_models (list of megatron.core.distributed.DistributedDataParallel):
71
+ The local DDP wrapped megatron modules.
72
+ config (PretrainedConfig):
73
+ the HF model config
74
+ dtype: dtype of the parameters in the merged state_dict
75
+ is_value_model: whether the model is a value model
76
+ tie_word_embeddings: unused for llama; kept only to match the qwen2 interface
77
+ Returns:
78
+ state_dict (dict):
79
+ The merged state_dict on rank 0, and an empty dictionary on all other ranks.
80
+ """
81
+ start_time = time.time()
82
+
83
+ def _get_gpt_model(model):
84
+ return model
85
+
86
+ dp_rank = mpu.get_data_parallel_rank()
87
+ pp_size = mpu.get_pipeline_model_parallel_world_size()
88
+ pp_rank = mpu.get_pipeline_model_parallel_rank()
89
+ virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
90
+ mp_group = mpu.get_model_parallel_group()
91
+
92
+ if dist.get_rank() == 0:
93
+ assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0"
94
+ assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0"
95
+ assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0"
96
+
97
+ if not isinstance(wrapped_models, (list, tuple)):
98
+ wrapped_models = list(wrapped_models)
99
+
100
+ assert len(wrapped_models) == virtual_pp_size
101
+ num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
102
+ assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers
103
+
104
+ models = [None] * len(wrapped_models)
105
+
106
+ for i, wrapped_model in enumerate(wrapped_models):
107
+ models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))
108
+ assert len(models[i].model.layers) == num_layers_per_model, "len model layers {} not equal to num_layers_per_model {}".format(len(models[i].model.layers), num_layers_per_model)
109
+
110
+ state_dict = dict()
111
+
112
+ def _get_cpu_tensor(tensor: torch.Tensor):
113
+ if tensor is None:
114
+ return None
115
+ if tensor.device == torch.device("cpu"):
116
+ return tensor.detach().clone()
117
+ return tensor.detach().cpu()
118
+
119
+ def _broadcast_tensor(tensor, name, src_pp_rank) -> torch.Tensor:
120
+ """broadcast tensor across mp_group"""
121
+ nonlocal state_dict
122
+ nonlocal mp_group
123
+ src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)
124
+
125
+ if torch.distributed.get_rank() == src_rank:
126
+ if tensor is None:
127
+ weight = None
128
+ tensor_shape = None
129
+ else:
130
+ weight = tensor
131
+ tensor_shape = weight.shape
132
+ else:
133
+ weight = None
134
+ tensor_shape = None
135
+
136
+ obj_list = [tensor_shape]
137
+ dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
138
+ tensor_shape = obj_list[0]
139
+
140
+ if tensor_shape is None:
141
+ # all or none ranks in the mp_group should reach here
142
+ print_rank_0(f"tensor:[{name}] not exist, skip collect")
143
+ return
144
+
145
+ if weight is None:
146
+ weight = torch.empty(
147
+ tensor_shape,
148
+ dtype=dtype,
149
+ device=torch.cuda.current_device(),
150
+ requires_grad=False,
151
+ )
152
+
153
+ dist.broadcast(weight, src=src_rank, group=mp_group)
154
+
155
+ if torch.distributed.get_rank() == 0:
156
+ state_dict[name] = _get_cpu_tensor(weight)
157
+
158
+ def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_func=None) -> torch.Tensor:
159
+ """broadcast tensor in tp shards across mp_group"""
160
+ nonlocal state_dict
161
+ nonlocal mp_group
162
+ tp_size = mpu.get_tensor_model_parallel_world_size()
163
+ src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)
164
+
165
+ chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None
166
+
167
+ obj_list = [chunk_shape]
168
+ dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
169
+ chunk_shape = obj_list[0]
170
+ if chunk_shape is None:
171
+ # all or none ranks in the mp_group should reach here
172
+ print_rank_0(f"tp_shard tensor:[{name}] not exist, skip collecting")
173
+ return
174
+
175
+ buffer_tensor = torch.empty(
176
+ chunk_shape,
177
+ dtype=dtype,
178
+ device=torch.cuda.current_device(),
179
+ requires_grad=False,
180
+ )
181
+
182
+ chunk_tensors = [None] * tp_size
183
+
184
+ for i in range(tp_size):
185
+ cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank)
186
+ sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor
187
+ dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)
188
+
189
+ if torch.distributed.get_rank() == 0:
190
+ chunk_tensors[i] = _get_cpu_tensor(sync_tensor)
191
+
192
+ if torch.distributed.get_rank() == 0:
193
+ full_tensor = torch.concat(chunk_tensors, dim=concat_dim)
194
+ if mutate_func is not None:
195
+ full_tensor = mutate_func(full_tensor)
196
+ state_dict[name] = full_tensor
197
+
198
+ def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) -> torch.Tensor:
199
+ """broadcast tensor in tp shards across mp_group"""
200
+ nonlocal state_dict
201
+ nonlocal mp_group
202
+ tp_size = mpu.get_tensor_model_parallel_world_size()
203
+ src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)
204
+
205
+ chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None
206
+
207
+ obj_list = [chunk_shape]
208
+ dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
209
+ chunk_shape = obj_list[0]
210
+ if chunk_shape is None:
211
+ # all or none ranks in the mp_group should reach here
212
+ print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not exist, skip collecting")
213
+ return
214
+
215
+ buffer_tensor = torch.empty(
216
+ chunk_shape,
217
+ dtype=dtype,
218
+ device=torch.cuda.current_device(),
219
+ requires_grad=False,
220
+ )
221
+
222
+ chunk_tensors = [None] * tp_size
223
+
224
+ for i in range(tp_size):
225
+ cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank)
226
+ sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor
227
+ dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)
228
+
229
+ if torch.distributed.get_rank() == 0:
230
+ chunk_tensors[i] = _get_cpu_tensor(sync_tensor)
231
+
232
+ if torch.distributed.get_rank() == 0:
233
+ full_tensor = torch.concat(chunk_tensors, dim=0)
234
+ intermediate_size_tp = config.intermediate_size // tp_size
235
+ gate_weight_list = []
236
+ up_weight_list = []
237
+ for i in range(tp_size):
238
+ gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)]
239
+ gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp]
240
+ up_weight_tp = gate_up_weight_tp[intermediate_size_tp:]
241
+ gate_weight_list.append(gate_weight_tp)
242
+ up_weight_list.append(up_weight_tp)
243
+
244
+ state_dict[gate_name] = torch.cat(gate_weight_list, dim=0)
245
+ state_dict[up_name] = torch.cat(up_weight_list, dim=0)
246
+
247
+ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank):
248
+ """broadcast tensor in tp shards across mp_group"""
249
+ nonlocal state_dict
250
+ nonlocal mp_group
251
+ tp_size = mpu.get_tensor_model_parallel_world_size()
252
+ src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)
253
+
254
+ chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None
255
+
256
+ obj_list = [chunk_shape]
257
+ dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
258
+ chunk_shape = obj_list[0]
259
+ if chunk_shape is None:
260
+ # all or none ranks in the mp_group should reach here
261
+ print_rank_0(f"tp_shard tensor:[{q_name}] not exist, skip collecting")
262
+ return
263
+
264
+ buffer_tensor = torch.empty(
265
+ chunk_shape,
266
+ dtype=dtype,
267
+ device=torch.cuda.current_device(),
268
+ requires_grad=False,
269
+ )
270
+
271
+ chunk_tensors = [None] * tp_size
272
+
273
+ for i in range(tp_size):
274
+ cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank)
275
+ sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor
276
+ dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)
277
+
278
+ if torch.distributed.get_rank() == 0:
279
+ chunk_tensors[i] = _get_cpu_tensor(sync_tensor)
280
+
281
+ if torch.distributed.get_rank() == 0:
282
+ full_tensor = torch.concat(chunk_tensors, dim=0)
283
+ q_weight_list = []
284
+ k_weight_list = []
285
+ v_weight_list = []
286
+ hidden_size_per_head = config.hidden_size // config.num_attention_heads
287
+
288
+ if config.num_key_value_heads >= tp_size:
289
+ q_size_tp = config.hidden_size // tp_size
290
+ kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size
291
+ total_size = q_size_tp + 2 * kv_size_tp
292
+ for i in range(tp_size):
293
+ qkv_part = full_tensor[i * total_size : (i + 1) * total_size]
294
+ q_part = qkv_part[:q_size_tp]
295
+ k_part = qkv_part[q_size_tp : q_size_tp + kv_size_tp]
296
+ v_part = qkv_part[q_size_tp + kv_size_tp : total_size]
297
+ q_weight_list.append(q_part)
298
+ k_weight_list.append(k_part)
299
+ v_weight_list.append(v_part)
300
+ else:
301
+ q_size_tp = config.hidden_size // tp_size
302
+ kv_size_tp = hidden_size_per_head
303
+ total_size = q_size_tp + 2 * kv_size_tp
304
+ for i in range(tp_size):
305
+ qkv_part = full_tensor[i * total_size : (i + 1) * total_size]
306
+ q_part = qkv_part[:q_size_tp]
307
+ k_part = qkv_part[q_size_tp : q_size_tp + kv_size_tp]
308
+ v_part = qkv_part[q_size_tp + kv_size_tp : total_size]
309
+ q_weight_list.append(q_part)
310
+ if i * config.num_key_value_heads % tp_size == 0:
311
+ k_weight_list.append(k_part)
312
+ v_weight_list.append(v_part)
313
+
314
+ state_dict[q_name] = torch.cat(q_weight_list, dim=0)
315
+ state_dict[k_name] = torch.cat(k_weight_list, dim=0)
316
+ state_dict[v_name] = torch.cat(v_weight_list, dim=0)
317
+
318
+ # empty cache before collecting weights
319
+ torch.cuda.empty_cache()
320
+ # Embeddings
321
+ # -------------------
322
+ if dp_rank == 0:
323
+ # Embeddings
324
+ # -------------------
325
+ print_rank_0("collecting embeddings...")
326
+ gpt_model_module = _get_gpt_model(models[0])
327
+ _broadcast_tp_shard_tensor(
328
+ gpt_model_module.model.embed_tokens.weight if pp_rank == 0 else None,
329
+ "model.embed_tokens.weight",
330
+ src_pp_rank=0,
331
+ )
332
+
333
+ # Transformer layers
334
+ # -------------------
335
+ layer_map = _megatron_calc_layer_map(config)
336
+ for layer in range(config.num_hidden_layers):
337
+ print_rank_0(f"collecting layer #{layer}...")
338
+ layer_name = f"model.layers.{layer}"
339
+ src_pp_rank, src_virtual_pp_rank, src_layer_idx = layer_map[layer]
340
+
341
+ gpt_model_module = _get_gpt_model(models[src_virtual_pp_rank])
342
+ sync_layer = gpt_model_module.model.layers[src_layer_idx]
343
+
344
+ _broadcast_tensor(
345
+ sync_layer.input_layernorm.weight,
346
+ f"{layer_name}.input_layernorm.weight",
347
+ src_pp_rank=src_pp_rank,
348
+ )
349
+
350
+ _broadcast_tp_shard_tensor_qkv(
351
+ sync_layer.self_attn.qkv_proj.weight,
352
+ f"{layer_name}.self_attn.q_proj.weight",
353
+ f"{layer_name}.self_attn.k_proj.weight",
354
+ f"{layer_name}.self_attn.v_proj.weight",
355
+ src_pp_rank=src_pp_rank,
356
+ )
357
+
358
+ _broadcast_tp_shard_tensor(
359
+ sync_layer.self_attn.o_proj.weight,
360
+ f"{layer_name}.self_attn.o_proj.weight",
361
+ concat_dim=1,
362
+ src_pp_rank=src_pp_rank,
363
+ )
364
+
365
+ _broadcast_tensor(
366
+ sync_layer.post_attention_layernorm.weight,
367
+ f"{layer_name}.post_attention_layernorm.weight",
368
+ src_pp_rank=src_pp_rank,
369
+ )
370
+
371
+ _broadcast_tp_shard_tensor_gate_up(
372
+ sync_layer.mlp.gate_up_proj.weight,
373
+ f"{layer_name}.mlp.gate_proj.weight",
374
+ f"{layer_name}.mlp.up_proj.weight",
375
+ src_pp_rank=src_pp_rank,
376
+ )
377
+
378
+ _broadcast_tp_shard_tensor(
379
+ sync_layer.mlp.down_proj.weight,
380
+ f"{layer_name}.mlp.down_proj.weight",
381
+ concat_dim=1,
382
+ src_pp_rank=src_pp_rank,
383
+ )
384
+
385
+ # Final Layernorm
386
+ # -------------------
387
+ print_rank_0("collecting final layernorm...")
388
+ gpt_model_module = _get_gpt_model(models[-1])
389
+ _broadcast_tensor(
390
+ getattr(gpt_model_module.model.norm, "weight", None),
391
+ "model.norm.weight",
392
+ src_pp_rank=pp_size - 1,
393
+ )
394
+
395
+ print_rank_0("collecting lm_head...")
396
+
397
+ if is_value_model:
398
+ if pp_rank == pp_size - 1:
399
+ print(f"gpt_model_module.lm_head.weight: {gpt_model_module.lm_head.weight.shape}")
400
+ _broadcast_tensor(
401
+ gpt_model_module.lm_head.weight if pp_rank == pp_size - 1 else None,
402
+ "lm_head.weight",
403
+ src_pp_rank=pp_size - 1,
404
+ )
405
+ _broadcast_tensor(
406
+ gpt_model_module.reward_head.weight if pp_rank == pp_size - 1 and getattr(gpt_model_module, "reward_weight", None) is not None else None,
407
+ "reward_head.weight",
408
+ src_pp_rank=pp_size - 1,
409
+ )
410
+
411
+ else:
412
+ _broadcast_tp_shard_tensor(
413
+ getattr(gpt_model_module.lm_head, "weight", None) if pp_rank == pp_size - 1 else None,
414
+ "lm_head.weight",
415
+ src_pp_rank=pp_size - 1,
416
+ )
417
+
418
+ dist.barrier()
419
+
420
+ torch.cuda.empty_cache()
421
+ if torch.distributed.get_rank() == 0:
422
+ if dtype not in [torch.float16, torch.bfloat16, torch.float32]:
423
+ print(f'Unknown/unsupported dtype to save: {dtype}"')
424
+ exit(1)
425
+ for k, v in state_dict.items():
426
+ if dtype != v.dtype:
427
+ state_dict[k] = v.to(dtype)
428
+
429
+ print_rank_0(f"merge megatron ckpt done, time elapsed {time.time() - start_time}s")
430
+ return state_dict
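Note on the saver above: `_broadcast_tp_shard_tensor_gate_up` undoes the per-tensor-parallel-shard gate/up interleaving that the loader applies. Here is a standalone round-trip sketch with plain PyTorch and made-up sizes (illustrative only, not repository code):

```python
import torch

intermediate_size, hidden_size, tp_size = 8, 4, 2     # illustrative sizes only
gate = torch.randn(intermediate_size, hidden_size)
up = torch.randn(intermediate_size, hidden_size)

# Megatron layout: each TP shard stores [gate_shard; up_shard] stacked along dim 0.
per_tp = intermediate_size // tp_size
merged = torch.cat(
    [torch.cat([gate[i * per_tp:(i + 1) * per_tp], up[i * per_tp:(i + 1) * per_tp]], dim=0) for i in range(tp_size)],
    dim=0,
)

# Recover the HF layout, mirroring the slicing in the saver above.
gate_parts, up_parts = [], []
for i in range(tp_size):
    block = merged[2 * per_tp * i : 2 * per_tp * (i + 1)]
    gate_parts.append(block[:per_tp])
    up_parts.append(block[per_tp:])

assert torch.equal(torch.cat(gate_parts), gate)
assert torch.equal(torch.cat(up_parts), up)
```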
verl/models/llama/megatron/layers/__init__.py ADDED
@@ -0,0 +1,25 @@
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .parallel_attention import ParallelLlamaAttention
16
+ from .parallel_decoder import ParallelLlamaDecoderLayer, ParallelLlamaDecoderLayerRmPad
17
+ from .parallel_linear import (
18
+ LinearForLastLayer,
19
+ MergedColumnParallelLinear,
20
+ QKVParallelLinear,
21
+ )
22
+ from .parallel_mlp import ParallelLlamaMLP
23
+ from .parallel_rmsnorm import ParallelLlamaRMSNorm
24
+
25
+ __all__ = ["LinearForLastLayer", "MergedColumnParallelLinear", "QKVParallelLinear", "ParallelLlamaAttention", "ParallelLlamaDecoderLayer", "ParallelLlamaDecoderLayerRmPad", "ParallelLlamaMLP", "ParallelLlamaRMSNorm"]
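The `QKVParallelLinear` exported here emits one fused projection per tensor-parallel rank; the attention layer in the next file splits it by `q_size`/`k_size`/`v_size`. A hedged sketch of that split with illustrative GQA sizes (plain PyTorch, not repository code):

```python
import torch

hidden_size, num_heads, num_kv_heads, tp_size = 32, 8, 4, 2   # illustrative sizes only
head_dim = hidden_size // num_heads                           # 4
q_size = (num_heads // tp_size) * head_dim                    # 16 per TP rank
kv_size = (num_kv_heads // tp_size) * head_dim                # 8 per TP rank

qkv = torch.randn(3, 5, q_size + 2 * kv_size)                 # (batch, seq, fused projection)
q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
assert q.shape[-1] == 16 and k.shape[-1] == 8 and v.shape[-1] == 8
```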
verl/models/llama/megatron/layers/parallel_attention.py ADDED
@@ -0,0 +1,425 @@
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+
21
+ import math
22
+ from typing import Optional, Tuple
23
+
24
+ import torch
25
+ import torch.nn.functional as F
26
+ from einops import rearrange
27
+ from flash_attn.layers.rotary import apply_rotary_emb
28
+ from megatron.core import ModelParallelConfig, tensor_parallel
29
+ from megatron.core import parallel_state as mpu
30
+ from torch import nn
31
+ from transformers import LlamaConfig
32
+ from transformers.utils import is_flash_attn_2_available
33
+
34
+ from verl.models.llama.megatron.layers.parallel_linear import QKVParallelLinear
35
+ from verl.utils.megatron import tensor_parallel as tp_utils
36
+
37
+
38
+ class LlamaRotaryEmbedding(nn.Module):
39
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
40
+ super().__init__()
41
+
42
+ self.dim = dim
43
+ self.max_position_embeddings = max_position_embeddings
44
+ self.base = base
45
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
46
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
47
+
48
+ # Build here to make `torch.jit.trace` work.
49
+ self._set_cos_sin_cache(seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype())
50
+
51
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
52
+ self.max_seq_len_cached = seq_len
53
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
54
+
55
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
56
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
57
+ emb = torch.cat((freqs, freqs), dim=-1)
58
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
59
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
60
+
61
+ def forward(self, x, seq_len=None):
62
+ # x: [bs, num_attention_heads, seq_len, head_size]
63
+ if seq_len > self.max_seq_len_cached:
64
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
65
+
66
+ return (
67
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
68
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
69
+ )
70
+
71
+
72
+ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
73
+ """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
74
+
75
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
76
+ self.scaling_factor = scaling_factor
77
+ super().__init__(dim, max_position_embeddings, base, device)
78
+
79
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
80
+ self.max_seq_len_cached = seq_len
81
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
82
+ t = t / self.scaling_factor
83
+
84
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
85
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
86
+ emb = torch.cat((freqs, freqs), dim=-1)
87
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
88
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
89
+
90
+
91
+ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
92
+ """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
93
+
94
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
95
+ self.scaling_factor = scaling_factor
96
+ super().__init__(dim, max_position_embeddings, base, device)
97
+
98
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
99
+ self.max_seq_len_cached = seq_len
100
+
101
+ if seq_len > self.max_position_embeddings:
102
+ base = self.base * ((self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)) ** (self.dim / (self.dim - 2))
103
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
104
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
105
+
106
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
107
+
108
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
109
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
110
+ emb = torch.cat((freqs, freqs), dim=-1)
111
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
112
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
113
+
114
+
115
+ class LlamaLlama3ScalingRotaryEmbedding(LlamaRotaryEmbedding):
116
+ def __init__(self, dim, config, max_position_embeddings=2048, base=10000, device=None):
117
+ super().__init__(dim, max_position_embeddings, base, device)
118
+
119
+ self.factor = config.rope_scaling["factor"] # `8` in the original implementation
120
+ self.high_freq_factor = config.rope_scaling["high_freq_factor"] # `1` in the original implementation
121
+ self.low_freq_factor = config.rope_scaling["low_freq_factor"] # `4` in the original implementation
122
+ self.old_context_len = config.rope_scaling["original_max_position_embeddings"] # `8192` in the original implementation
123
+
124
+ low_freq_wavelen = self.old_context_len / self.low_freq_factor
125
+ high_freq_wavelen = self.old_context_len / self.high_freq_factor
126
+
127
+ wavelen = 2 * math.pi / self.inv_freq
128
+ # wavelen < high_freq_wavelen: do nothing; wavelen > low_freq_wavelen: divide by factor
129
+ inv_freq_llama = torch.where(wavelen > low_freq_wavelen, self.inv_freq / self.factor, self.inv_freq)
130
+ # otherwise: interpolate between the two, using a smooth factor
131
+ smooth_factor = (self.old_context_len / wavelen - self.low_freq_factor) / (self.high_freq_factor - self.low_freq_factor)
132
+ smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / self.factor + smooth_factor * inv_freq_llama
133
+ is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
134
+ inv_freq = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
135
+
136
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
137
+
138
+ # Build here to make `torch.jit.trace` work.
139
+ self._set_cos_sin_cache(seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype())
140
+
141
+
142
+ def rotate_half(x):
143
+ """Rotates half the hidden dims of the input."""
144
+ x1 = x[..., : x.shape[-1] // 2]
145
+ x2 = x[..., x.shape[-1] // 2 :]
146
+ return torch.cat((-x2, x1), dim=-1)
147
+
148
+
149
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
150
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
151
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
152
+ q_embed = (q * cos) + (rotate_half(q) * sin)
153
+ k_embed = (k * cos) + (rotate_half(k) * sin)
154
+ return q_embed, k_embed
155
+
156
+
157
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
158
+ """
159
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
160
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
161
+ """
162
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
163
+ if n_rep == 1:
164
+ return hidden_states
165
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
166
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
167
+
168
+
169
+ class ParallelLlamaAttention(nn.Module):
170
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
171
+
172
+ def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
173
+ super().__init__()
174
+ self.config = config
175
+ self.megatron_config = megatron_config
176
+ self.hidden_size = config.hidden_size
177
+ self.num_heads = config.num_attention_heads
178
+ self.head_dim = self.hidden_size // self.num_heads
179
+ self.num_key_value_heads = config.num_key_value_heads
180
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
181
+ self.max_position_embeddings = config.max_position_embeddings
182
+ self.rope_theta = config.rope_theta
183
+
184
+ # assign values after tp
185
+ tp_size = mpu.get_tensor_model_parallel_world_size()
186
+ assert self.num_heads % tp_size == 0, f"num_head must be divisible by tp_size. Got num_head={self.num_heads}, tp_size={tp_size}"
187
+ assert self.num_key_value_heads % tp_size == 0, f"num_key_value_heads must be divisible by tp_size. Got num_key_value_heads={self.num_key_value_heads}, tp_size={tp_size}"
188
+
189
+ self.num_heads_per_tp = self.num_heads // tp_size
190
+ self.num_key_value_heads_per_tp = self.num_key_value_heads // tp_size
191
+ self.hidden_size_per_tp = self.hidden_size // tp_size
192
+
193
+ if (self.head_dim * self.num_heads) != self.hidden_size:
194
+ raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`: {self.num_heads}).")
195
+
196
+ column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
197
+ row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear()
198
+
199
+ if megatron_config is not None:
200
+ assert column_kwargs.get("config", False), "must have ModelParallelConfig"
201
+ assert row_kwargs.get("config", False), "must have ModelParallelConfig"
202
+ tp_utils.update_kwargs_with_config(column_kwargs, megatron_config)
203
+ tp_utils.update_kwargs_with_config(row_kwargs, megatron_config)
204
+
205
+ # [self.q_size, self.k_size, self.v_size]
206
+ self.qkv_proj = QKVParallelLinear(
207
+ input_size=self.hidden_size,
208
+ num_heads=self.num_heads,
209
+ num_key_value_heads=self.num_key_value_heads,
210
+ head_dim=self.head_dim,
211
+ bias=config.attention_bias,
212
+ gather_output=False,
213
+ skip_bias_add=False,
214
+ **column_kwargs,
215
+ )
216
+
217
+ self.q_size = self.num_heads_per_tp * self.head_dim
218
+ self.k_size = self.num_key_value_heads_per_tp * self.head_dim
219
+ self.v_size = self.num_key_value_heads_per_tp * self.head_dim
220
+
221
+ self.o_proj = tensor_parallel.RowParallelLinear(
222
+ input_size=self.num_heads * self.head_dim,
223
+ output_size=self.hidden_size,
224
+ bias=config.attention_bias,
225
+ input_is_parallel=True,
226
+ skip_bias_add=False,
227
+ **row_kwargs,
228
+ )
229
+
230
+ self._init_rope()
231
+
232
+ def _init_rope(self):
233
+ if self.config.rope_scaling is None:
234
+ self.rotary_emb = LlamaRotaryEmbedding(
235
+ self.head_dim,
236
+ max_position_embeddings=self.max_position_embeddings,
237
+ base=self.rope_theta,
238
+ )
239
+ else:
240
+ rope_type_key = "type" if "type" in self.config.rope_scaling else "rope_type"
241
+ scaling_type = self.config.rope_scaling[rope_type_key]
242
+ scaling_factor = self.config.rope_scaling["factor"]
243
+ if scaling_type == "linear":
244
+ self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
245
+ self.head_dim,
246
+ max_position_embeddings=self.max_position_embeddings,
247
+ scaling_factor=scaling_factor,
248
+ base=self.rope_theta,
249
+ )
250
+ elif scaling_type == "dynamic":
251
+ self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
252
+ self.head_dim,
253
+ max_position_embeddings=self.max_position_embeddings,
254
+ scaling_factor=scaling_factor,
255
+ base=self.rope_theta,
256
+ )
257
+ elif scaling_type == "llama3":
258
+ self.rotary_emb = LlamaLlama3ScalingRotaryEmbedding(
259
+ self.head_dim,
260
+ self.config,
261
+ max_position_embeddings=self.max_position_embeddings,
262
+ base=self.rope_theta,
263
+ )
264
+ else:
265
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
266
+
267
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
268
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
269
+
270
+ def forward(
271
+ self,
272
+ hidden_states: torch.Tensor,
273
+ attention_mask: Optional[torch.Tensor] = None,
274
+ position_ids: Optional[torch.LongTensor] = None,
275
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
276
+ bsz, q_len, _ = hidden_states.size()
277
+ qkv = self.qkv_proj(hidden_states)[0]
278
+ query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)
279
+
280
+ query_states = query_states.view(bsz, q_len, self.num_heads_per_tp, self.head_dim).transpose(1, 2)
281
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2)
282
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2)
283
+
284
+ kv_seq_len = key_states.shape[-2]
285
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
286
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
287
+
288
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
289
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
290
+
291
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
292
+
293
+ if attn_weights.size() != (bsz, self.num_heads_per_tp, q_len, kv_seq_len):
294
+ raise ValueError(f"Attention weights should be of size {(bsz, self.num_heads_per_tp, q_len, kv_seq_len)}, but is {attn_weights.size()}")
295
+
296
+ if attention_mask is not None:
297
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
298
+ raise ValueError(f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}")
299
+ attn_weights = attn_weights + attention_mask
300
+
301
+ # upcast attention to fp32
302
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
303
+ attn_output = torch.matmul(attn_weights, value_states)
304
+
305
+ if attn_output.size() != (bsz, self.num_heads_per_tp, q_len, self.head_dim):
306
+ raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads_per_tp, q_len, self.head_dim)}, but is {attn_output.size()}")
307
+
308
+ attn_output = attn_output.transpose(1, 2).contiguous()
309
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size_per_tp)
310
+ attn_output = self.o_proj(attn_output)[0]
311
+ return attn_output
312
+
313
+
314
+ """
315
+ Remove padding Attention
316
+ - Using Flash-attn 2
317
+ - Compatible with sequence parallel
318
+ """
319
+
320
+
321
+ if is_flash_attn_2_available():
322
+ from flash_attn import flash_attn_varlen_func
323
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
324
+
325
+
326
+ def apply_rotary_pos_emb_rmpad(q, k, cos, sin, position_ids, indices, sequence_length):
327
+ batch_size = position_ids.shape[0]
328
+
329
+ q = pad_input(q, indices, batch_size, sequence_length) # (batch_size, seqlen, num_head, head_dim)
330
+ k = pad_input(k, indices, batch_size, sequence_length)
331
+ cos = cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim]
332
+ sin = sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim]
333
+ q_embed = (q * cos) + (rotate_half(q) * sin)
334
+ k_embed = (k * cos) + (rotate_half(k) * sin)
335
+
336
+ q_embed = index_first_axis(rearrange(q_embed, "b s ... -> (b s) ..."), indices)
337
+ k_embed = index_first_axis(rearrange(k_embed, "b s ... -> (b s) ..."), indices)
338
+
339
+ return q_embed, k_embed
340
+
341
+
342
+ # use flash-attn rotary embeddings with rmpad
343
+ # cos/sin should be: (seq_length, rotary_dim / 2)
344
+ def apply_rotary_pos_emb_rmpad_flash(q, k, cos, sin, cu_seqlens, max_seqlen):
345
+ q_embed = apply_rotary_emb(q, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
346
+ k_embed = apply_rotary_emb(k, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
347
+ return q_embed, k_embed
348
+
349
+
350
+ class ParallelLlamaAttentionRmPad(ParallelLlamaAttention):
351
+ def forward(
352
+ self,
353
+ hidden_states: torch.Tensor,
354
+ position_ids: Optional[torch.LongTensor] = None,
355
+ sequence_length: int = None,
356
+ indices: torch.Tensor = None,
357
+ cu_seqlens: torch.Tensor = None,
358
+ max_seqlen_in_batch: int = None,
359
+ ):
360
+ total_nnz, _, _ = hidden_states.size() # This is the total_nnz padded after sequence parallel
361
+
362
+ if self.megatron_config.sequence_parallel:
363
+ total_nnz = total_nnz * mpu.get_tensor_model_parallel_world_size()
364
+
365
+ qkv = self.qkv_proj(hidden_states)[0]
366
+ query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1) # (total_nnz, 1, hidden_size)
367
+
368
+ if self.megatron_config.sequence_parallel:
369
+ sequence_parallel_pad = total_nnz - cu_seqlens[-1]
370
+ total_nnz = cu_seqlens[-1] # total_nnz before sp padding
371
+ query_states = query_states[:total_nnz]
372
+ key_states = key_states[:total_nnz]
373
+ value_states = value_states[:total_nnz]
374
+
375
+ # Flash attention requires the input to have the shape
376
+ # batch_size x seq_length x head_dim x hidden_dim
377
+ # therefore we just need to keep the original shape
378
+ query_states = query_states.view(total_nnz, self.num_heads_per_tp, self.head_dim)
379
+ key_states = key_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim)
380
+ value_states = value_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim)
381
+
382
+ cos, sin = self.rotary_emb(value_states, seq_len=sequence_length)
383
+ cos, sin = cos[:, : cos.shape[1] // 2], sin[:, : sin.shape[1] // 2] # flash attn only needs half
384
+ query_states, key_states = apply_rotary_pos_emb_rmpad_flash(query_states, key_states, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen_in_batch)
385
+ # query_states, key_states = apply_rotary_pos_emb_rmpad(query_states, key_states, cos, sin, position_ids, indices,
386
+
387
+ # TODO: llama does not have dropout in the config??
388
+ # It is recommended to use dropout with FA according to the docs
389
+ # when training.
390
+ dropout_rate = 0.0 # if not self.training else self.attn_dropout
391
+
392
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
393
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
394
+ # cast them back to float16 just to be sure everything works as expected.
395
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
396
+ # in fp32. (LlamaRMSNorm handles it correctly)
397
+ input_dtype = query_states.dtype
398
+ if input_dtype == torch.float32:
399
+ query_states = query_states.to(torch.float16)
400
+ key_states = key_states.to(torch.float16)
401
+ value_states = value_states.to(torch.float16)
402
+
403
+ attn_output_unpad = flash_attn_varlen_func(
404
+ query_states,
405
+ key_states,
406
+ value_states,
407
+ cu_seqlens_q=cu_seqlens,
408
+ cu_seqlens_k=cu_seqlens,
409
+ max_seqlen_q=max_seqlen_in_batch,
410
+ max_seqlen_k=max_seqlen_in_batch,
411
+ dropout_p=dropout_rate,
412
+ softmax_scale=None,
413
+ causal=True,
414
+ )
415
+
416
+ attn_output_unpad = attn_output_unpad.to(input_dtype)
417
+ attn_output_unpad = attn_output_unpad.reshape(total_nnz, 1, self.hidden_size_per_tp).contiguous()
418
+
419
+ # sequence parallel reduce_scatter is performed inside RowColumnParallel if enabled
420
+ # Here we need to repad
421
+ if self.megatron_config.sequence_parallel:
422
+ attn_output_unpad = F.pad(attn_output_unpad, pad=(0, 0, 0, 0, 0, sequence_parallel_pad))
423
+
424
+ attn_output_unpad = self.o_proj(attn_output_unpad)[0]
425
+ return attn_output_unpad
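A small self-contained sanity check of the rotary helpers defined above (`rotate_half` / `apply_rotary_pos_emb`), rebuilt here with plain PyTorch and made-up shapes; it only relies on the fact that RoPE is a rotation and therefore norm-preserving (illustrative only, not repository code):

```python
import torch

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

bsz, heads, seq, dim = 1, 2, 4, 8
q = torch.randn(bsz, heads, seq, dim)

# Build cos/sin caches the same way LlamaRotaryEmbedding does.
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
freqs = torch.einsum("i,j->ij", torch.arange(seq).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)          # (seq, dim)
cos, sin = emb.cos(), emb.sin()

position_ids = torch.arange(seq).unsqueeze(0)    # (bsz, seq)
cos_sel = cos[position_ids].unsqueeze(1)         # (bsz, 1, seq, dim)
sin_sel = sin[position_ids].unsqueeze(1)
q_rot = q * cos_sel + rotate_half(q) * sin_sel

# Each (x_i, x_{i+dim/2}) pair is rotated by the same angle, so norms are preserved.
assert torch.allclose(q_rot.norm(dim=-1), q.norm(dim=-1), atol=1e-5)
```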
verl/models/llama/megatron/layers/parallel_decoder.py ADDED
@@ -0,0 +1,150 @@
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+
21
+ from typing import Optional, Tuple
22
+
23
+ import torch
24
+ from megatron.core import ModelParallelConfig
25
+ from torch import nn
26
+ from transformers import LlamaConfig
27
+
28
+ from verl.utils.megatron_utils import TransformerConfig, convert_config
29
+
30
+ from .parallel_attention import ParallelLlamaAttention, ParallelLlamaAttentionRmPad
31
+ from .parallel_mlp import ParallelLlamaMLP
32
+ from .parallel_rmsnorm import ParallelLlamaRMSNorm
33
+
34
+
35
+ class ParallelLlamaDecoderLayer(nn.Module):
36
+ def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, layer_idx: int):
37
+ super().__init__()
38
+ self.config: TransformerConfig = convert_config(config, megatron_config)
39
+ self.layer_idx = layer_idx
40
+ self.hidden_size = config.hidden_size
41
+ self.self_attn = ParallelLlamaAttention(config=config, megatron_config=megatron_config)
42
+
43
+ self.mlp = ParallelLlamaMLP(config, megatron_config=megatron_config)
44
+ self.input_layernorm = ParallelLlamaRMSNorm(config, megatron_config)
45
+ self.post_attention_layernorm = ParallelLlamaRMSNorm(config, megatron_config)
46
+
47
+ def forward(
48
+ self,
49
+ hidden_states: torch.Tensor,
50
+ attention_mask: Optional[torch.Tensor] = None,
51
+ position_ids: Optional[torch.LongTensor] = None,
52
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
53
+ """
54
+ Args:
55
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
56
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
57
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
58
+ output_attentions (`bool`, *optional*):
59
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
60
+ returned tensors for more detail.
61
+ use_cache (`bool`, *optional*):
62
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
63
+ (see `past_key_values`).
64
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
65
+ """
66
+
67
+ residual = hidden_states
68
+
69
+ hidden_states = self.input_layernorm(hidden_states)
70
+
71
+ # Note: sequence parallel is hidden inside ColumnParallelLinear
72
+ # reduce scatter is hidden inside RowParallelLinear
73
+
74
+ # Self Attention
75
+ hidden_states = self.self_attn(
76
+ hidden_states=hidden_states,
77
+ attention_mask=attention_mask,
78
+ position_ids=position_ids,
79
+ )
80
+
81
+ # TODO: add sequence parallel operator reduce_scatter here
82
+
83
+ hidden_states = residual + hidden_states
84
+
85
+ # Fully Connected
86
+ residual = hidden_states
87
+ hidden_states = self.post_attention_layernorm(hidden_states)
88
+
89
+ # TODO: add sequence parallel operator all_gather here
90
+
91
+ hidden_states = self.mlp(hidden_states)
92
+
93
+ # TODO: add sequence parallel operator reduce_scatter here
94
+
95
+ hidden_states = residual + hidden_states
96
+
97
+ outputs = hidden_states
98
+
99
+ return outputs
100
+
101
+
102
+ class ParallelLlamaDecoderLayerRmPad(nn.Module):
103
+ def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, layer_idx: int):
104
+ super().__init__()
105
+ self.config: TransformerConfig = convert_config(config, megatron_config)
106
+ self.layer_idx = layer_idx
107
+ self.hidden_size = config.hidden_size
108
+ self.self_attn = ParallelLlamaAttentionRmPad(config=config, megatron_config=megatron_config)
109
+
110
+ self.mlp = ParallelLlamaMLP(config, megatron_config=megatron_config)
111
+ self.input_layernorm = ParallelLlamaRMSNorm(config, megatron_config)
112
+ self.post_attention_layernorm = ParallelLlamaRMSNorm(config, megatron_config)
113
+
114
+ def forward(
115
+ self,
116
+ hidden_states: torch.Tensor,
117
+ position_ids: Optional[torch.LongTensor] = None,
118
+ sequence_length: int = None,
119
+ indices: torch.Tensor = None,
120
+ cu_seqlens: int = None,
121
+ max_seqlen_in_batch: int = None,
122
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
123
+ residual = hidden_states # (total_nnz // sp, 1, hidden_size)
124
+
125
+ hidden_states = self.input_layernorm(hidden_states)
126
+
127
+ # Self Attention
128
+ # (total_nnz // sp, 1, hidden_size) -> all-gather (total_nnz, 1, hidden_size)
129
+ # -> col + row -> reduce-scatter -> (total_nnz // sp, 1, hidden_size)
130
+ hidden_states = self.self_attn(
131
+ hidden_states=hidden_states,
132
+ position_ids=position_ids,
133
+ sequence_length=sequence_length,
134
+ indices=indices,
135
+ cu_seqlens=cu_seqlens,
136
+ max_seqlen_in_batch=max_seqlen_in_batch,
137
+ )
138
+
139
+ hidden_states = residual + hidden_states
140
+
141
+ # Fully Connected
142
+ # shape changes same as attn
143
+ residual = hidden_states
144
+ hidden_states = self.post_attention_layernorm(hidden_states)
145
+ hidden_states = self.mlp(hidden_states)
146
+ hidden_states = residual + hidden_states
147
+
148
+ outputs = hidden_states
149
+
150
+ return outputs
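
The decoder layers above follow the standard pre-norm residual pattern: RMSNorm, attention, residual add, then RMSNorm, MLP, residual add, with the tensor-parallel gather/scatter hidden inside the column/row parallel linears. A minimal single-GPU sketch of that control flow (plain nn.Modules standing in for the parallel attention/MLP, LayerNorm standing in for RMSNorm; all names here are illustrative only, not part of the commit):

import torch
from torch import nn

class TinyPreNormBlock(nn.Module):
    # Illustrative stand-in for ParallelLlamaDecoderLayer's residual structure.
    def __init__(self, hidden_size: int, attn: nn.Module, mlp: nn.Module):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(hidden_size)
        self.post_attention_layernorm = nn.LayerNorm(hidden_size)
        self.attn, self.mlp = attn, mlp

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # attention sub-block with residual connection
        hidden_states = hidden_states + self.attn(self.input_layernorm(hidden_states))
        # MLP sub-block with residual connection
        return hidden_states + self.mlp(self.post_attention_layernorm(hidden_states))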
verl/models/llama/megatron/layers/parallel_linear.py ADDED
@@ -0,0 +1,106 @@
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ # Copyright 2023 The vLLM team.
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/linear.py
15
+
16
+ import torch
17
+ from megatron.core import tensor_parallel
18
+
19
+
20
+ class QKVParallelLinear(tensor_parallel.ColumnParallelLinear):
21
+ def __init__(
22
+ self,
23
+ input_size,
24
+ num_heads,
25
+ num_key_value_heads,
26
+ head_dim,
27
+ *,
28
+ bias=True,
29
+ gather_output=True,
30
+ skip_bias_add=False,
31
+ **kwargs,
32
+ ):
33
+ # Keep input parameters, and already restrict the head numbers
34
+ self.input_size = input_size
35
+ self.q_output_size = num_heads * head_dim
36
+ self.kv_output_size = num_key_value_heads * head_dim
37
+ self.head_dim = head_dim
38
+ self.gather_output = gather_output
39
+ self.skip_bias_add = skip_bias_add
40
+
41
+ input_size = self.input_size
42
+ output_size = (num_heads + 2 * num_key_value_heads) * self.head_dim
43
+
44
+ super().__init__(
45
+ input_size=input_size,
46
+ output_size=output_size,
47
+ bias=bias,
48
+ gather_output=gather_output,
49
+ skip_bias_add=skip_bias_add,
50
+ **kwargs,
51
+ )
52
+
53
+
54
+ class MergedColumnParallelLinear(tensor_parallel.ColumnParallelLinear):
55
+ def __init__(
56
+ self,
57
+ input_size,
58
+ gate_ouput_size,
59
+ up_output_size,
60
+ *,
61
+ bias=True,
62
+ gather_output=True,
63
+ skip_bias_add=False,
64
+ **kwargs,
65
+ ):
66
+ # Keep input parameters, and already restrict the head numbers
67
+ self.input_size = input_size
68
+ self.output_size = gate_ouput_size + up_output_size
69
+ self.gather_output = gather_output
70
+ self.skip_bias_add = skip_bias_add
71
+
72
+ super().__init__(
73
+ input_size=self.input_size,
74
+ output_size=self.output_size,
75
+ bias=bias,
76
+ gather_output=gather_output,
77
+ skip_bias_add=skip_bias_add,
78
+ **kwargs,
79
+ )
80
+
81
+
82
+ class LinearForLastLayer(torch.nn.Linear):
83
+ def __init__(
84
+ self,
85
+ input_size,
86
+ output_size,
87
+ *,
88
+ config,
89
+ bias=True,
90
+ ):
91
+ super().__init__(in_features=input_size, out_features=output_size, bias=bias)
92
+ self.sequence_parallel = config.sequence_parallel
93
+ if self.sequence_parallel:
94
+ self.weight.sequence_parallel = True
95
+
96
+ def forward(
97
+ self,
98
+ input_,
99
+ weight=None,
100
+ runtime_gather_output=None,
101
+ ):
102
+ logits = super().forward(input_)
103
+ logits = logits.float()
104
+ if self.sequence_parallel:
105
+ logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False)
106
+ return logits, None
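
QKVParallelLinear above fuses the query, key and value projections into a single column-parallel matmul whose output width is (num_heads + 2 * num_key_value_heads) * head_dim; the attention layer then splits that output back into q/k/v. A standalone sketch of the packed projection and split, without Megatron and with made-up sizes (in the parallel version each rank only holds its share of the heads):

import torch
from torch import nn

# Hypothetical sizes for illustration only.
hidden_size, num_heads, num_kv_heads, head_dim = 64, 8, 2, 8

qkv_proj = nn.Linear(hidden_size, (num_heads + 2 * num_kv_heads) * head_dim, bias=False)
x = torch.randn(4, 16, hidden_size)  # (batch, seq, hidden)
q, k, v = qkv_proj(x).split(
    [num_heads * head_dim, num_kv_heads * head_dim, num_kv_heads * head_dim], dim=-1
)
print(q.shape, k.shape, v.shape)  # (4, 16, 64) (4, 16, 16) (4, 16, 16)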
verl/models/llama/megatron/layers/parallel_mlp.py ADDED
@@ -0,0 +1,74 @@
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+
21
+ from megatron.core import ModelParallelConfig, tensor_parallel
22
+ from megatron.core import parallel_state as mpu
23
+ from torch import nn
24
+ from transformers.activations import ACT2FN
25
+
26
+ from verl.models.llama.megatron.layers.parallel_linear import MergedColumnParallelLinear
27
+ from verl.utils.megatron import tensor_parallel as tp_utils
28
+
29
+
30
+ class ParallelLlamaMLP(nn.Module):
31
+ def __init__(self, config, megatron_config: ModelParallelConfig = None) -> None:
32
+ super().__init__()
33
+ self.config = config
34
+ self.hidden_size = config.hidden_size
35
+ self.intermediate_size = config.intermediate_size
36
+ # The weight is only [hidden_size, intermediate_size // model_parallel_world_size]
37
+
38
+ column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
39
+ row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear()
40
+
41
+ if megatron_config is not None:
42
+ assert column_kwargs.get("config", False), "must have ModelParallelConfig"
43
+ assert row_kwargs.get("config", False), "must have ModelParallelConfig"
44
+ tp_utils.update_kwargs_with_config(row_kwargs, megatron_config)
45
+ tp_utils.update_kwargs_with_config(column_kwargs, megatron_config)
46
+
47
+ tp_size = mpu.get_tensor_model_parallel_world_size()
48
+
49
+ self.gate_up_proj = MergedColumnParallelLinear(
50
+ input_size=self.hidden_size,
51
+ gate_ouput_size=self.intermediate_size,
52
+ up_output_size=self.intermediate_size,
53
+ bias=False,
54
+ gather_output=False,
55
+ skip_bias_add=False,
56
+ **column_kwargs,
57
+ )
58
+ self.gate_size = self.intermediate_size // tp_size
59
+
60
+ self.down_proj = tensor_parallel.RowParallelLinear(
61
+ input_size=self.intermediate_size,
62
+ output_size=self.hidden_size,
63
+ bias=False,
64
+ input_is_parallel=True,
65
+ skip_bias_add=False,
66
+ **row_kwargs,
67
+ )
68
+
69
+ self.act_fn = ACT2FN[config.hidden_act]
70
+
71
+ def forward(self, x):
72
+ gate_up = self.gate_up_proj(x)[0]
73
+ gate, up = gate_up.split(self.gate_size, dim=-1)
74
+ return self.down_proj(self.act_fn(gate) * up)[0]
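
ParallelLlamaMLP implements the SwiGLU feed-forward used by Llama: one fused gate/up projection, SiLU applied to the gate half, an elementwise product, then the down projection. A non-parallel reference of the same computation (sizes and class name are illustrative):

import torch
import torch.nn.functional as F
from torch import nn

class TinySwiGLU(nn.Module):
    # Plain-PyTorch equivalent of the gate_up_proj / down_proj pattern above.
    def __init__(self, hidden_size: int = 64, intermediate_size: int = 128):
        super().__init__()
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.intermediate_size = intermediate_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, up = self.gate_up_proj(x).split(self.intermediate_size, dim=-1)
        return self.down_proj(F.silu(gate) * up)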
verl/models/llama/megatron/layers/parallel_rmsnorm.py ADDED
@@ -0,0 +1,48 @@
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import numbers
16
+
17
+ import torch
18
+ from apex.normalization.fused_layer_norm import fused_rms_norm_affine
19
+ from megatron.core import ModelParallelConfig
20
+ from torch import nn
21
+ from transformers import LlamaConfig
22
+
23
+ from verl.utils.megatron import sequence_parallel as sp_utils
24
+
25
+
26
+ class ParallelLlamaRMSNorm(nn.Module):
27
+ def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
28
+ """
29
+ LlamaRMSNorm is equivalent to T5LayerNorm
30
+ """
31
+ super().__init__()
32
+ if isinstance(config.hidden_size, numbers.Integral):
33
+ normalized_shape = (config.hidden_size,)
34
+ self.normalized_shape = torch.Size(normalized_shape)
35
+ self.weight = nn.Parameter(torch.ones(self.normalized_shape))
36
+ self.variance_epsilon = config.rms_norm_eps
37
+
38
+ if megatron_config.sequence_parallel:
39
+ sp_utils.mark_parameter_as_sequence_parallel(self.weight)
40
+
41
+ def forward(self, hidden_states):
42
+ return fused_rms_norm_affine(
43
+ input=hidden_states,
44
+ weight=self.weight,
45
+ normalized_shape=self.normalized_shape,
46
+ eps=self.variance_epsilon,
47
+ memory_efficient=True,
48
+ )
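
The Apex fused kernel above computes standard RMS normalization. A plain eager-mode reference, useful for sanity-checking numerics and written under the usual LlamaRMSNorm semantics (this helper is not part of the commit):

import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # y = x / sqrt(mean(x^2) + eps) * weight, with the statistics computed in fp32.
    orig_dtype = x.dtype
    x = x.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return (x * torch.rsqrt(variance + eps)).to(orig_dtype) * weight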
verl/models/llama/megatron/modeling_llama_megatron.py ADDED
@@ -0,0 +1,662 @@
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """PyTorch LLaMA model with Megatron-style acceleration."""
21
+
22
+ from typing import Optional, Tuple, Union
23
+
24
+ import torch
25
+ import torch.utils.checkpoint
26
+ from megatron.core import ModelParallelConfig, mpu, tensor_parallel
27
+ from torch import nn
28
+ from transformers.modeling_outputs import BaseModelOutputWithPast
29
+ from transformers.models.llama.configuration_llama import LlamaConfig
30
+ from transformers.models.llama.modeling_llama import CausalLMOutputWithPast
31
+
32
+ from verl.utils.megatron import sequence_parallel as sp_utils
33
+ from verl.utils.megatron import tensor_parallel as tp_utils
34
+ from verl.utils.megatron_utils import TransformerConfig, convert_config
35
+
36
+ from .layers import ParallelLlamaDecoderLayer, ParallelLlamaDecoderLayerRmPad, ParallelLlamaRMSNorm
37
+
38
+ """
39
+ TODO:
40
+ 1. Add weight initialization. Here we need to be careful on TP weight init.
41
+ 2. Add sequence parallel
42
+ 3. Load checkpoint from meta LLama pretrained checkpoint
43
+ """
44
+
45
+
46
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
47
+ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device):
48
+ """
49
+ Make causal mask used for bi-directional self-attention.
50
+ """
51
+ bsz, tgt_len = input_ids_shape
52
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
53
+ mask_cond = torch.arange(mask.size(-1), device=device)
54
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
55
+ mask = mask.to(dtype)
56
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len)
57
+
58
+
59
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
60
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
61
+ """
62
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
63
+ """
64
+ bsz, src_len = mask.size()
65
+ tgt_len = tgt_len if tgt_len is not None else src_len
66
+
67
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
68
+
69
+ inverted_mask = 1.0 - expanded_mask
70
+
71
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
72
+
73
+
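Together these two helpers produce the additive attention mask: _make_causal_mask puts the dtype's most negative value above the diagonal, and _expand_mask lifts a (batch, seq) padding mask into the same (batch, 1, tgt, src) shape so the two can simply be summed. A toy, self-contained illustration of the same idea for seq_len = 3 with the last token padded (values and shapes are for illustration only):

import torch

dtype = torch.float32
neg = torch.finfo(dtype).min
causal = torch.triu(torch.full((3, 3), neg), diagonal=1)    # zeros on/below the diagonal
padding = torch.tensor([[0.0, 0.0, neg]]).expand(3, 3)      # block attention to the padded key
combined = causal + padding                                  # analogous to the summed mask above
print(combined)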
74
+ class ParallelLlamaModel(nn.Module):
75
+ """
76
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
77
+
78
+ Args:
79
+ config: LlamaConfig
80
+ """
81
+
82
+ def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
83
+ super().__init__()
84
+ self.config: TransformerConfig = convert_config(config, megatron_config)
85
+ self.padding_idx = config.pad_token_id
86
+ self.vocab_size = config.vocab_size
87
+ embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()
88
+ if megatron_config is not None:
89
+ assert embedding_kwargs.get("config", False), "must have ModelParallelConfig"
90
+ tp_utils.update_kwargs_with_config(embedding_kwargs, megatron_config)
91
+ self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs)
92
+
93
+ self.layers = nn.ModuleList([ParallelLlamaDecoderLayer(config, megatron_config, layer_idx=idx) for idx in range(config.num_hidden_layers)])
94
+ self.norm = ParallelLlamaRMSNorm(config, megatron_config)
95
+
96
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
97
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds):
98
+ # create causal mask
99
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
100
+ combined_attention_mask = None
101
+ if input_shape[-1] > 1:
102
+ combined_attention_mask = _make_causal_mask(
103
+ input_shape,
104
+ inputs_embeds.dtype,
105
+ device=inputs_embeds.device,
106
+ )
107
+
108
+ if attention_mask is not None:
109
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
110
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device)
111
+ combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
112
+
113
+ return combined_attention_mask
114
+
115
+ def forward(
116
+ self,
117
+ input_ids: torch.LongTensor = None,
118
+ attention_mask: Optional[torch.Tensor] = None,
119
+ position_ids: Optional[torch.LongTensor] = None,
120
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
121
+ """
122
+
123
+ Args:
124
+ input_ids: input ids. shape (batch_size, seq_length)
125
+ attention_mask: attention_mask. shape (batch_size, seq_length)
126
+ position_ids: position ids. shape (batch_size, seq_length)
127
+
128
+ Returns:
129
+
130
+ """
131
+ batch_size, seq_length = input_ids.shape
132
+ inputs_embeds = self.embed_tokens(input_ids)
133
+ # embed positions
134
+
135
+ attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds)
136
+
137
+ hidden_states = inputs_embeds
138
+
139
+ for idx, decoder_layer in enumerate(self.layers):
140
+ layer_outputs = decoder_layer(
141
+ hidden_states,
142
+ attention_mask=attention_mask,
143
+ position_ids=position_ids,
144
+ )
145
+
146
+ hidden_states = layer_outputs
147
+
148
+ hidden_states = self.norm(hidden_states)
149
+
150
+ return hidden_states
151
+
152
+
153
+ class ParallelLlamaForCausalLM(nn.Module):
154
+ def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
155
+ super().__init__()
156
+ self.config: TransformerConfig = convert_config(config, megatron_config)
157
+ self.model = ParallelLlamaModel(config, megatron_config=megatron_config)
158
+ self.vocab_size = config.vocab_size
159
+
160
+ column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
161
+ if megatron_config is not None:
162
+ assert column_kwargs.get("config", False), "must have ModelParallelConfig"
163
+ tp_utils.update_kwargs_with_config(column_kwargs, megatron_config)
164
+
165
+ self.lm_head = tensor_parallel.ColumnParallelLinear(
166
+ input_size=config.hidden_size,
167
+ output_size=config.vocab_size,
168
+ bias=False,
169
+ gather_output=False,
170
+ skip_bias_add=False,
171
+ **column_kwargs,
172
+ )
173
+
174
+ def forward(
175
+ self,
176
+ input_ids: torch.LongTensor = None,
177
+ attention_mask: Optional[torch.Tensor] = None,
178
+ position_ids: Optional[torch.LongTensor] = None,
179
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
180
+ r"""
181
+ Args:
182
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
183
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
184
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
185
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
186
+
187
+ Returns:
188
+ """
189
+
190
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
191
+ outputs = self.model(
192
+ input_ids=input_ids,
193
+ attention_mask=attention_mask,
194
+ position_ids=position_ids,
195
+ )
196
+
197
+ hidden_states = outputs
198
+ logits = self.lm_head(hidden_states)[0]
199
+
200
+ logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits)
201
+
202
+ logits = logits.float()
203
+ return CausalLMOutputWithPast(
204
+ loss=None,
205
+ logits=logits,
206
+ past_key_values=None,
207
+ hidden_states=None,
208
+ attentions=None,
209
+ )
210
+
211
+
212
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
213
+
214
+
215
+ class ParallelLlamaModelRmPad(nn.Module):
216
+ """
217
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
218
+
219
+ Args:
220
+ config: LlamaConfig
221
+ """
222
+
223
+ def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
224
+ super().__init__()
225
+ self.config: TransformerConfig = convert_config(config, megatron_config)
226
+ self.padding_idx = config.pad_token_id
227
+ self.vocab_size = config.vocab_size
228
+ embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()
229
+ self.megatron_config = megatron_config
230
+ if megatron_config is not None:
231
+ assert embedding_kwargs.get("config", False), "must have ModelParallelConfig"
232
+ tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config)
233
+ self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs)
234
+
235
+ self.layers = nn.ModuleList([ParallelLlamaDecoderLayerRmPad(config, megatron_config, layer_idx=idx) for idx in range(config.num_hidden_layers)])
236
+ self.norm = ParallelLlamaRMSNorm(config, megatron_config)
237
+
238
+ def forward(
239
+ self,
240
+ input_ids: torch.Tensor,
241
+ position_ids: Optional[torch.LongTensor] = None,
242
+ sequence_length: int = None,
243
+ indices: torch.Tensor = None,
244
+ cu_seqlens: int = None,
245
+ max_seqlen_in_batch: int = None,
246
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
247
+ """
248
+
249
+ Args:
250
+ input_ids: input ids. shape (1, total_nnz)
251
+ position_ids: position ids. shape (batch_size, seq_length)
252
+
253
+ Returns:
254
+
255
+ """
256
+ inputs_embeds = self.embed_tokens(input_ids) # (1, total_nnz) -> (1, total_nnz, hidden_size)
257
+
258
+ # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size)
259
+ inputs_embeds = inputs_embeds.transpose(0, 1)
260
+ if self.megatron_config.sequence_parallel:
261
+ inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds)
262
+
263
+ hidden_states = inputs_embeds
264
+ for idx, decoder_layer in enumerate(self.layers):
265
+ layer_outputs = decoder_layer(
266
+ hidden_states,
267
+ position_ids=position_ids,
268
+ sequence_length=sequence_length,
269
+ indices=indices,
270
+ cu_seqlens=cu_seqlens,
271
+ max_seqlen_in_batch=max_seqlen_in_batch,
272
+ )
273
+
274
+ hidden_states = layer_outputs
275
+
276
+ hidden_states = self.norm(hidden_states)
277
+
278
+ return hidden_states
279
+
280
+
281
+ class ParallelLlamaForCausalLMRmPad(nn.Module):
282
+ def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
283
+ super().__init__()
284
+ self.config: TransformerConfig = convert_config(config, megatron_config)
285
+ self.megatron_config = megatron_config
286
+ self.model = ParallelLlamaModelRmPad(config, megatron_config=megatron_config)
287
+ self.vocab_size = config.vocab_size
288
+ self._init_head(config)
289
+
290
+ def _init_head(self, config):
291
+ column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
292
+ if self.megatron_config is not None:
293
+ assert column_kwargs.get("config", False), "must have ModelParallelConfig"
294
+ tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
295
+ self.lm_head = tensor_parallel.ColumnParallelLinear(
296
+ input_size=config.hidden_size,
297
+ output_size=config.vocab_size,
298
+ bias=False,
299
+ gather_output=False,
300
+ skip_bias_add=False,
301
+ **column_kwargs,
302
+ )
303
+
304
+ def _forward_head(self, hidden_states):
305
+ # all_gather from sequence parallel region is performed inside lm_head
306
+ logits = self.lm_head(hidden_states)[0]
307
+ logits = logits.float() # (total_nnz_padded, 1, vocab_size // tp)
308
+ logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) # (total_nnz_padded, 1, vocab_size)
309
+ return logits
310
+
311
+ def forward(
312
+ self,
313
+ input_ids: torch.LongTensor = None,
314
+ attention_mask: Optional[torch.Tensor] = None,
315
+ position_ids: Optional[torch.LongTensor] = None,
316
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
317
+ r"""
318
+ Args:
319
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
320
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
321
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
322
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
323
+
324
+ Returns:
325
+ """
326
+ batch_size, sequence_length = input_ids.shape
327
+
328
+ # remove padding here
329
+ input_ids, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(input_ids.unsqueeze(dim=-1), attention_mask) # (total_nnz, 1)
330
+
331
+ # pad input_ids to multiple of tp for all tp ranks
332
+ # TODO: for better performance, the sp padding should be removed at each layer; the performance gap has not been measured
333
+ if self.megatron_config.sequence_parallel:
334
+ input_ids = sp_utils.pad_to_sequence_parallel(input_ids)
335
+
336
+ input_ids = input_ids.transpose(0, 1) # (1, total_nnz+pad)
337
+
338
+ outputs = self.model(
339
+ input_ids=input_ids,
340
+ position_ids=position_ids,
341
+ sequence_length=sequence_length,
342
+ indices=indices,
343
+ cu_seqlens=cu_seqlens,
344
+ max_seqlen_in_batch=max_seqlen_in_batch,
345
+ )
346
+
347
+ hidden_states = outputs
348
+
349
+ logits = self._forward_head(hidden_states)
350
+
351
+ # remove padding from sequence parallel
352
+ if self.megatron_config.sequence_parallel:
353
+ total_nnz = cu_seqlens[-1]
354
+ logits = logits[:total_nnz]  # (total_nnz, 1, vocab_size)
355
+
356
+ logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension
357
+ # add removed padding back
358
+ logits = pad_input(logits, indices, batch_size, seqlen=sequence_length) # (batch_size, sequence_length, vocab_size)
359
+
360
+ return CausalLMOutputWithPast(
361
+ loss=None,
362
+ logits=logits,
363
+ past_key_values=None,
364
+ hidden_states=None,
365
+ attentions=None,
366
+ )
367
+
368
+
369
+ class ParallelLlamaForValueRmPad(ParallelLlamaForCausalLMRmPad):
370
+ def _init_head(self, config):
371
+ column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
372
+ if self.megatron_config is not None:
373
+ assert column_kwargs.get("config", False), "must have ModelParallelConfig"
374
+ tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
375
+ self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=1, bias=False)
376
+ # lm_head is effectively the same as sequence parallel
377
+ sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight)
378
+
379
+ def _forward_head(self, hidden_states):
380
+ logits = self.lm_head(hidden_states) # (total_nnz_padded // tp, 1, 1)
381
+ logits = logits.float()
382
+ if self.megatron_config.sequence_parallel:
383
+ logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False)
384
+ return logits
385
+
386
+ def forward(
387
+ self,
388
+ input_ids: torch.LongTensor = None,
389
+ attention_mask: Optional[torch.Tensor] = None,
390
+ position_ids: Optional[torch.LongTensor] = None,
391
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
392
+ output = super().forward(input_ids, attention_mask, position_ids)
393
+ output.logits = torch.squeeze(output.logits, dim=-1)
394
+ return output
395
+
396
+
397
+ """
398
+ Support pipeline parallelism
399
+ """
400
+
401
+
402
+ class ParallelLlamaModelRmPadPP(nn.Module):
403
+ """
404
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
405
+ This model definition supports pipeline parallelism. To support pp and vpp,
406
+ - This model only contains layer in this pp stage and vpp chunk
407
+ - When calling get_model in Megatron, this rank will instantiate all the vpp chunks in this pp.
408
+ Args:
409
+ config: LlamaConfig
410
+ """
411
+
412
+ def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, pre_process, post_process):
413
+ super().__init__()
414
+ self.config: TransformerConfig = convert_config(config, megatron_config)
415
+ self.padding_idx = config.pad_token_id
416
+ self.vocab_size = config.vocab_size
417
+ self.pre_process = pre_process
418
+ self.post_process = post_process
419
+ self.megatron_config = megatron_config
420
+ embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()
421
+ if megatron_config is not None:
422
+ assert embedding_kwargs.get("config", False), "must have ModelParallelConfig"
423
+ tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config)
424
+ if pre_process:
425
+ self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs)
426
+ else:
427
+ self.embed_tokens = None
428
+
429
+ pp_rank = mpu.get_pipeline_model_parallel_rank()
430
+ pp_size = megatron_config.pipeline_model_parallel_size
431
+ self.num_layer_per_pp = config.num_hidden_layers // pp_size
432
+ vpp_size = megatron_config.virtual_pipeline_model_parallel_size
433
+ vpp_rank = mpu.get_virtual_pipeline_model_parallel_rank()
434
+
435
+ if vpp_size is not None:
436
+ self.layers = nn.ModuleList()
437
+ self.num_layer_vpp_chunk = self.num_layer_per_pp // vpp_size
438
+ self.num_layer_this_model = self.num_layer_vpp_chunk
439
+ offset = vpp_rank * (config.num_hidden_layers // vpp_size) + (pp_rank * self.num_layer_vpp_chunk)
440
+ else:
441
+ self.num_layer_this_model = self.num_layer_per_pp
442
+ offset = pp_rank * self.num_layer_per_pp
443
+
444
+ self.layers = nn.ModuleList()
445
+ for i in range(self.num_layer_this_model):
446
+ layer = ParallelLlamaDecoderLayerRmPad(config, megatron_config, layer_idx=offset + i)
447
+ self.layers.add_module(f"{i}", layer)
448
+
449
+ if post_process:
450
+ self.norm = ParallelLlamaRMSNorm(config, megatron_config)
451
+ else:
452
+ self.norm = None
453
+
454
+ def set_input_tensor(self, input_tensor):
455
+ """Set input tensor to be used instead of forward()'s input.
456
+
457
+ When doing pipeline parallelism the input from the previous
458
+ stage comes from communication, not from the input, so the
459
+ model's forward_step_func won't have it. This function is thus
460
+ used by internal code to bypass the input provided by the
461
+ forward_step_func"""
462
+ self.input_tensor = input_tensor
463
+
464
+ def forward(
465
+ self,
466
+ input_ids: torch.Tensor,
467
+ position_ids: Optional[torch.LongTensor] = None,
468
+ sequence_length: int = None,
469
+ indices: torch.Tensor = None,
470
+ cu_seqlens: int = None,
471
+ max_seqlen_in_batch: int = None,
472
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
473
+ """
474
+
475
+ Args:
476
+ input_ids: input ids. shape (1, total_nnz)
477
+ position_ids: position ids. shape (batch_size, seq_length)
478
+
479
+ Returns:
480
+
481
+ """
482
+ if self.pre_process:
483
+ inputs_embeds = self.embed_tokens(input_ids) # (1, total_nnz) -> (1, total_nnz, hidden_size)
484
+
485
+ # vocab parallel embedding will not do sequence parallel reduce-scatter in open source megatron
486
+ # so we need to handle it here:
487
+ # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size)
488
+ inputs_embeds = inputs_embeds.transpose(0, 1)
489
+ if self.megatron_config.sequence_parallel:
490
+ inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds)
491
+
492
+ hidden_states = inputs_embeds
493
+ else:
494
+ # self.hidden_states should be passed by Megatron
495
+ hidden_states = self.input_tensor
496
+
497
+ for idx, decoder_layer in enumerate(self.layers):
498
+ layer_outputs = decoder_layer(
499
+ hidden_states,
500
+ position_ids=position_ids,
501
+ sequence_length=sequence_length,
502
+ indices=indices,
503
+ cu_seqlens=cu_seqlens,
504
+ max_seqlen_in_batch=max_seqlen_in_batch,
505
+ )
506
+
507
+ hidden_states = layer_outputs
508
+
509
+ if self.post_process:
510
+ hidden_states = self.norm(hidden_states)
511
+
512
+ return hidden_states
513
+
514
+
515
+ class ParallelLlamaForCausalLMRmPadPP(nn.Module):
516
+ def __init__(
517
+ self,
518
+ config: LlamaConfig,
519
+ megatron_config: ModelParallelConfig,
520
+ pre_process,
521
+ post_process,
522
+ share_embeddings_and_output_weights=False,
523
+ ):
524
+ super().__init__()
525
+ self.config: TransformerConfig = convert_config(config, megatron_config)
526
+ self.megatron_config = megatron_config
527
+ self.model = ParallelLlamaModelRmPadPP(config, megatron_config=megatron_config, pre_process=pre_process, post_process=post_process)
528
+ assert share_embeddings_and_output_weights is False, "Llama model does not support sharing embedding and output weights"
529
+ self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
530
+ self.vocab_size = config.vocab_size
531
+ self.pre_process = pre_process
532
+ self.post_process = post_process
533
+ if post_process:
534
+ self._init_head(config)
535
+
536
+ def set_input_tensor(self, input_tensor):
537
+ """Set input tensor to be used instead of forward()'s input.
538
+
539
+ When doing pipeline parallelism the input from the previous
540
+ stage comes from communication, not from the input, so the
541
+ model's forward_step_func won't have it. This function is thus
542
+ used by internal code to bypass the input provided by the
543
+ forward_step_func"""
544
+ assert len(input_tensor) == 1
545
+ self.model.set_input_tensor(input_tensor[0])
546
+
547
+ def _init_head(self, config):
548
+ column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
549
+ if self.megatron_config is not None:
550
+ assert column_kwargs.get("config", False), "must have ModelParallelConfig"
551
+ tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
552
+ self.lm_head = tensor_parallel.ColumnParallelLinear(
553
+ input_size=config.hidden_size,
554
+ output_size=config.vocab_size,
555
+ bias=False,
556
+ gather_output=False,
557
+ skip_bias_add=False,
558
+ **column_kwargs,
559
+ )
560
+
561
+ def _forward_head(self, hidden_states):
562
+ # all_gather from sequence parallel region is performed inside lm_head
563
+ # logits shape before forward_head hidden_states.shape: [4, 32, 4096]
564
+ logits = self.lm_head(hidden_states)[0]
565
+ # logits shape after forward_head logits.shape: [8, 32, 8]
566
+ logits = logits.float() # (total_nnz_padded, 1, vocab_size // tp)
567
+ return logits
568
+
569
+ def forward(
570
+ self,
571
+ # original input
572
+ *,
573
+ input_ids: torch.LongTensor = None,
574
+ attention_mask: Optional[torch.Tensor] = None,
575
+ position_ids: Optional[torch.LongTensor] = None,
576
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
577
+ r"""
578
+ Args:
579
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
580
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
581
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
582
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
583
+
584
+ Returns:
585
+ """
586
+
587
+ # Note that input_ids, attention_mask and position_ids should be passed to every pp layer.
588
+ # In the first pp, input_ids will be used, in other pp layers hidden_states will be used inside self.model
589
+ batch_size, sequence_length = input_ids.shape
590
+ # remove padding here
591
+ input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(input_ids.unsqueeze(dim=-1), attention_mask) # (total_nnz, 1)
592
+
593
+ # pad input_ids to multiple of tp for all tp ranks
594
+ # TODO: for better performance, the sp padding should be removed at each layer; the performance gap has not been measured
595
+ if self.megatron_config.sequence_parallel:
596
+ input_ids_rmpad = sp_utils.pad_to_sequence_parallel(input_ids_rmpad)
597
+
598
+ input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz+pad)
599
+
600
+ outputs = self.model(
601
+ input_ids=input_ids_rmpad,
602
+ position_ids=position_ids,
603
+ sequence_length=sequence_length,
604
+ indices=indices,
605
+ cu_seqlens=cu_seqlens,
606
+ max_seqlen_in_batch=max_seqlen_in_batch,
607
+ )
608
+
609
+ if self.post_process:
610
+ hidden_states = outputs
611
+ # print(f'hidden_states.shape = {hidden_states.shape}') # torch.Size([4, 32, 4096])
612
+ logits = self._forward_head(hidden_states)
613
+ logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension # torch.Size([8, 32, 16])
614
+
615
+ # remove padding from sequence parallel
616
+ if self.megatron_config.sequence_parallel:
617
+ total_nnz = cu_seqlens[-1]
618
+ logits = logits[:total_nnz]  # (total_nnz, 1, vocab_size)
619
+ # add removed padding back. If input is already rmpad, we let the caller pad_input
620
+ logits = pad_input(logits, indices, batch_size, seqlen=sequence_length) # (batch_size, sequence_length, vocab_size)
621
+
622
+ return CausalLMOutputWithPast(
623
+ loss=None,
624
+ logits=logits,
625
+ past_key_values=None,
626
+ hidden_states=None,
627
+ attentions=None,
628
+ )
629
+ else:
630
+ return outputs
631
+
632
+
633
+ class ParallelLlamaForValueRmPadPP(ParallelLlamaForCausalLMRmPadPP):
634
+ def _init_head(self, config):
635
+ column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
636
+ if self.megatron_config is not None:
637
+ assert column_kwargs.get("config", False), "must have ModelParallelConfig"
638
+ tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
639
+ self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=1, bias=False)
640
+ # lm_head is effectively the same as sequence parallel
641
+ sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight)
642
+
643
+ def _forward_head(self, hidden_states):
644
+ logits = self.lm_head(hidden_states) # (total_nnz_padded // tp, 1, 1)
645
+ logits = logits.float()
646
+ if self.megatron_config.sequence_parallel:
647
+ logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False)
648
+ return logits
649
+
650
+ def forward(
651
+ self,
652
+ *,
653
+ input_ids: torch.LongTensor = None,
654
+ attention_mask: Optional[torch.Tensor] = None,
655
+ position_ids: Optional[torch.LongTensor] = None,
656
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
657
+ output = super().forward(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)
658
+ if self.post_process:
659
+ output.logits = torch.squeeze(output.logits, dim=-1)
660
+ return output
661
+ else:
662
+ return output
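
All of the RmPad variants in this file run the transformer on sequences with padding removed: real tokens are flattened into a (1, total_nnz) tensor, and the logits are scattered back into the (batch, seq, ...) layout at the end. A minimal pure-PyTorch sketch of that unpad/re-pad round trip, conceptually what flash_attn's unpad_input/pad_input do (tensors and sizes below are made up):

import torch

input_ids = torch.tensor([[5, 6, 7, 0], [8, 9, 0, 0]])  # (batch=2, seq=4), 0 = pad
attention_mask = (input_ids != 0).int()

# "unpad": keep only the real tokens and remember their flat positions
indices = attention_mask.flatten().nonzero(as_tuple=False).squeeze(-1)
tokens_rmpad = input_ids.flatten()[indices]              # shape (total_nnz,) == (5,)

# ... the model would run on tokens_rmpad and emit one row per real token ...
hidden_rmpad = torch.randn(tokens_rmpad.numel(), 8)

# "re-pad": scatter the rows back into the (batch * seq, hidden) layout
hidden_padded = torch.zeros(input_ids.numel(), 8)
hidden_padded[indices] = hidden_rmpad
hidden_padded = hidden_padded.view(2, 4, 8)              # (batch, seq, hidden)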
verl/models/mcore/__init__.py ADDED
@@ -0,0 +1,18 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates
2
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from .registry import get_mcore_forward_fn, get_mcore_weight_converter, hf_to_mcore_config, init_mcore_model
17
+
18
+ __all__ = ["hf_to_mcore_config", "init_mcore_model", "get_mcore_forward_fn", "get_mcore_weight_converter"]
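
The package deliberately exposes only these four registry entry points. A hedged sketch of how they fit together (the exact signatures live in registry.py elsewhere in this commit; the argument names and keyword usage below are assumptions, not the confirmed API):

import torch
from transformers import AutoConfig
from verl.models.mcore import hf_to_mcore_config, init_mcore_model

hf_config = AutoConfig.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
# Assumed (hf_config, dtype) signature for hf_to_mcore_config.
tf_config = hf_to_mcore_config(hf_config, torch.bfloat16)
# Assumed keyword arguments for init_mcore_model; consult registry.py for the real interface.
model = init_mcore_model(tf_config, hf_config, pre_process=True, post_process=True)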
verl/models/mcore/config_converter.py ADDED
@@ -0,0 +1,197 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates
2
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
3
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ # convert huggingface config to mcore transformer config
18
+
19
+
20
+ import torch
21
+ import torch.nn.functional as F
22
+ from megatron.core.transformer import MLATransformerConfig, TransformerConfig
23
+ from transformers import PretrainedConfig
24
+
25
+
26
+ def _get_base_transformer_config(hf_config: PretrainedConfig, dtype: torch.dtype, **kwargs) -> TransformerConfig:
27
+ """
28
+ Create a base TransformerConfig with common parameters across different model architectures.
29
+ TODO: (ycl) use dataclass or converter config?
30
+
31
+ Args:
32
+ hf_config: HuggingFace model configuration
33
+ dtype: Data type for the model
34
+ **kwargs: Additional parameters to override defaults
35
+
36
+ Returns:
37
+ TransformerConfig with common parameters
38
+ """
39
+ from megatron.core import parallel_state as mpu
40
+
41
+ # Common parallel state parameters
42
+ overlap_p2p_comm = mpu.get_virtual_pipeline_model_parallel_world_size() is not None and mpu.get_virtual_pipeline_model_parallel_world_size() > 1
43
+ batch_p2p_comm = False
44
+
45
+ # Base configuration with common parameters
46
+ base_config = {
47
+ # Model architecture parameters
48
+ "num_layers": hf_config.num_hidden_layers,
49
+ "hidden_size": hf_config.hidden_size,
50
+ "num_attention_heads": hf_config.num_attention_heads,
51
+ "num_query_groups": hf_config.num_key_value_heads,
52
+ "ffn_hidden_size": hf_config.intermediate_size,
53
+ "attention_dropout": hf_config.attention_dropout,
54
+ "hidden_dropout": getattr(hf_config, "hidden_dropout", 0.0),
55
+ "kv_channels": getattr(hf_config, "head_dim", None),
56
+ "layernorm_epsilon": hf_config.rms_norm_eps,
57
+ # Activation and normalization
58
+ "activation_func": F.silu,
59
+ "normalization": "RMSNorm",
60
+ "gated_linear_unit": True,
61
+ # Data types
62
+ "pipeline_dtype": dtype,
63
+ "params_dtype": dtype,
64
+ "bf16": dtype is torch.bfloat16,
65
+ # Parallel configuration
66
+ "tensor_model_parallel_size": mpu.get_tensor_model_parallel_world_size(),
67
+ "pipeline_model_parallel_size": mpu.get_pipeline_model_parallel_world_size(),
68
+ "virtual_pipeline_model_parallel_size": mpu.get_virtual_pipeline_model_parallel_world_size(),
69
+ "context_parallel_size": mpu.get_context_parallel_world_size(),
70
+ "overlap_p2p_comm": overlap_p2p_comm,
71
+ "batch_p2p_comm": batch_p2p_comm,
72
+ "sequence_parallel": mpu.get_tensor_model_parallel_world_size() > 1,
73
+ # Common settings
74
+ "variable_seq_lengths": True,
75
+ "masked_softmax_fusion": True,
76
+ "moe_token_dispatcher_type": "alltoall",
77
+ }
78
+
79
+ # Update with any provided overrides
80
+ base_config.update(kwargs)
81
+ print(f"Overridden TF init config: {base_config}")
82
+
83
+ return TransformerConfig(**base_config)
84
+
85
+
86
+ def hf_to_mcore_config_dense(hf_config: PretrainedConfig, dtype: torch.dtype) -> TransformerConfig:
87
+ # for LlamaForCausalLM or Qwen2ForCausalLM
88
+ qkv_bias = True if "Qwen2ForCausalLM" in hf_config.architectures else getattr(hf_config, "attention_bias", False)
89
+ qk_layernorm = True if "Qwen3ForCausalLM" in hf_config.architectures else False
90
+
91
+ return _get_base_transformer_config(
92
+ hf_config=hf_config,
93
+ dtype=dtype,
94
+ use_cpu_initialization=False,
95
+ add_bias_linear=False,
96
+ add_qkv_bias=qkv_bias,
97
+ qk_layernorm=qk_layernorm,
98
+ )
99
+
100
+
101
+ def hf_to_mcore_config_qwen2moe(hf_config: PretrainedConfig, dtype: torch.dtype) -> TransformerConfig:
102
+ return _get_base_transformer_config(
103
+ hf_config=hf_config,
104
+ dtype=dtype,
105
+ use_cpu_initialization=False,
106
+ add_bias_linear=False,
107
+ layernorm_epsilon=hf_config.rms_norm_eps,
108
+ # MoE specific
109
+ moe_ffn_hidden_size=hf_config.moe_intermediate_size,
110
+ moe_router_bias_update_rate=0.001,
111
+ moe_router_topk=hf_config.num_experts_per_tok,
112
+ num_moe_experts=hf_config.num_experts,
113
+ moe_shared_expert_intermediate_size=hf_config.shared_expert_intermediate_size,
114
+ moe_aux_loss_coeff=hf_config.router_aux_loss_coef,
115
+ # moe_aux_loss_coeff=0.0,
116
+ moe_router_load_balancing_type="aux_loss",
117
+ moe_shared_expert_overlap=True,
118
+ moe_grouped_gemm=True,
119
+ moe_router_score_function="softmax",
120
+ # Other optimizations
121
+ persist_layer_norm=True,
122
+ bias_activation_fusion=True,
123
+ bias_dropout_fusion=True,
124
+ # Qwen specific
125
+ moe_router_pre_softmax=True,
126
+ add_qkv_bias=True,
127
+ )
128
+
129
+
130
+ def hf_to_mcore_config_mixtral(hf_config: PretrainedConfig, dtype: torch.dtype) -> TransformerConfig:
131
+ return _get_base_transformer_config(
132
+ hf_config=hf_config,
133
+ dtype=dtype,
134
+ use_cpu_initialization=False,
135
+ add_bias_linear=False,
136
+ layernorm_epsilon=hf_config.rms_norm_eps,
137
+ # MoE specific
138
+ num_moe_experts=hf_config.num_local_experts,
139
+ moe_aux_loss_coeff=hf_config.router_aux_loss_coef,
140
+ moe_router_topk=hf_config.num_experts_per_tok,
141
+ moe_router_pre_softmax=True,
142
+ moe_router_load_balancing_type="aux_loss",
143
+ moe_router_score_function="softmax",
144
+ moe_shared_expert_intermediate_size=None, # mixtral has no shared expert
145
+ moe_shared_expert_overlap=False, # mixtral has no shared expert
146
+ moe_ffn_hidden_size=hf_config.intermediate_size,
147
+ moe_router_bias_update_rate=0.001,
148
+ # moe_permute_fusion=True, # need TE 2.1+
149
+ moe_grouped_gemm=True,
150
+ # Other optimizations
151
+ persist_layer_norm=True,
152
+ apply_rope_fusion=True,
153
+ bias_activation_fusion=True,
154
+ bias_dropout_fusion=True,
155
+ )
156
+
157
+
158
+ def hf_to_mcore_config_qwen3moe(hf_config: PretrainedConfig, dtype: torch.dtype) -> TransformerConfig:
159
+ return _get_base_transformer_config(
160
+ hf_config=hf_config,
161
+ dtype=dtype,
162
+ use_cpu_initialization=False,
163
+ add_bias_linear=False,
164
+ layernorm_epsilon=hf_config.rms_norm_eps,
165
+ # MoE specific
166
+ moe_ffn_hidden_size=hf_config.moe_intermediate_size,
167
+ moe_router_bias_update_rate=0.001,
168
+ moe_router_topk=hf_config.num_experts_per_tok,
169
+ num_moe_experts=hf_config.num_experts,
170
+ moe_aux_loss_coeff=hf_config.router_aux_loss_coef,
171
+ # moe_aux_loss_coeff=0.0,
172
+ moe_router_load_balancing_type="aux_loss",
173
+ moe_grouped_gemm=True,
174
+ moe_router_score_function="softmax",
175
+ # Other optimizations
176
+ persist_layer_norm=True,
177
+ bias_activation_fusion=True,
178
+ bias_dropout_fusion=True,
179
+ # Qwen specific
180
+ moe_router_pre_softmax=True,
181
+ qk_layernorm=True,
182
+ )
183
+
184
+
185
+ def hf_to_mcore_config_dpskv3(hf_config: PretrainedConfig, dtype: torch.dtype) -> MLATransformerConfig:
186
+ # DeepseekV3ForCausalLM
187
+ raise NotImplementedError("DeepseekV3ForCausalLM is not supported yet")
188
+
189
+
190
+ def hf_to_mcore_config_qwen2_5_vl(hf_config: PretrainedConfig, dtype: torch.dtype) -> TransformerConfig:
191
+ # Qwen2_5_VLForConditionalGeneration
192
+ raise NotImplementedError("Qwen2_5_VLForConditionalGeneration is not supported yet")
193
+
194
+
195
+ def hf_to_mcore_config_llama4(hf_config: PretrainedConfig, dtype: torch.dtype) -> TransformerConfig:
196
+ # Llama4ForConditionalGeneration
197
+ raise NotImplementedError("Llama4ForConditionalGeneration is not supported yet")
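
Each converter above maps one HuggingFace architecture onto a Megatron TransformerConfig; the dispatch on hf_config.architectures happens in registry.py. A hypothetical sketch of that dispatch, reusing the functions defined in this file (the table below is illustrative, not the registry's actual contents):

import torch
from transformers import PretrainedConfig

_CONVERTERS = {
    "LlamaForCausalLM": hf_to_mcore_config_dense,
    "Qwen2ForCausalLM": hf_to_mcore_config_dense,
    "Qwen2MoeForCausalLM": hf_to_mcore_config_qwen2moe,
    "MixtralForCausalLM": hf_to_mcore_config_mixtral,
    "Qwen3MoeForCausalLM": hf_to_mcore_config_qwen3moe,
}

def convert(hf_config: PretrainedConfig, dtype: torch.dtype = torch.bfloat16):
    # Pick the converter for the first declared architecture.
    return _CONVERTERS[hf_config.architectures[0]](hf_config, dtype)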
verl/models/mcore/loader.py ADDED
@@ -0,0 +1,468 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates
2
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import time
17
+
18
+ import torch
19
+ import torch.distributed as dist
20
+
21
+ from .saver import _megatron_calc_global_rank
22
+
23
+
24
+ def _megatron_calc_layer_map(config):
25
+ """Calculate the mapping of global layer_idx to local layer_idx
26
+ Returns:
27
+ layer_map (Dict: int -> tuple(int, int, int)):
28
+ mapping from the global layer index to
29
+ a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)
30
+ """
31
+ from megatron.core import mpu
32
+
33
+ pp_size = mpu.get_pipeline_model_parallel_world_size()
34
+ virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
35
+
36
+ layer_map = dict()
37
+ num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
38
+ assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers
39
+
40
+ for pp_rank_idx in range(pp_size):
41
+ for virtual_pp_rank_idx in range(virtual_pp_size):
42
+ layer_offset = virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model
43
+ for layer_idx in range(num_layers_per_model):
44
+ layer_map[layer_offset + layer_idx] = (
45
+ pp_rank_idx,
46
+ virtual_pp_rank_idx,
47
+ layer_idx,
48
+ )
49
+ return layer_map
50
+
51
+
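_megatron_calc_layer_map interleaves virtual-pipeline chunks across pipeline stages: with 32 layers, pp_size=4 and virtual_pp_size=2, each pipeline rank owns two chunks of 4 consecutive layers, 16 layers apart. A standalone recomputation of the same arithmetic (no parallel state needed; the sizes are just an example):

num_layers, pp_size, vpp_size = 32, 4, 2
per_model = num_layers // pp_size // vpp_size          # 4 layers per (pp, vpp) chunk

layer_map = {}
for pp in range(pp_size):
    for vpp in range(vpp_size):
        offset = vpp * (num_layers // vpp_size) + pp * per_model
        for local in range(per_model):
            layer_map[offset + local] = (pp, vpp, local)

assert layer_map[0] == (0, 0, 0)    # first chunk on pp rank 0
assert layer_map[4] == (1, 0, 0)    # first chunk on pp rank 1
assert layer_map[16] == (0, 1, 0)   # second (virtual) chunk on pp rank 0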
52
+ def load_state_dict_to_megatron_gptmodel(state_dict, wrapped_models, config, params_dtype, is_value_model=False):
53
+ """Load merged state_dict to sharded Megatron module in training."""
54
+ from megatron.core import DistributedDataParallel as LocalDDP
55
+ from megatron.core import mpu
56
+ from megatron.core.transformer.module import Float16Module
57
+ from torch.nn.parallel import DistributedDataParallel as torchDDP
58
+
59
+ from verl.utils.megatron_utils import print_rank_0, unwrap_model
60
+
61
+ start_time = time.time()
62
+
63
+ def _get_gpt_model(model):
64
+ return model
65
+
66
+ def broadcast_params(module):
67
+ for param in module.parameters():
68
+ torch.distributed.broadcast(param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group())
69
+
70
+ dp_rank = mpu.get_data_parallel_rank()
71
+ pp_rank = mpu.get_pipeline_model_parallel_rank()
72
+ cp_rank = mpu.get_context_parallel_rank()
73
+ src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=0, cp_rank=cp_rank)
74
+ pp_size = mpu.get_pipeline_model_parallel_world_size()
75
+ virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
76
+ mp_group = mpu.get_model_parallel_group()
77
+
78
+ if torch.distributed.get_rank() == src_rank:
79
+ assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank()}] != 0 on rank #0"
80
+ assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0"
81
+ assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0"
82
+
83
+ if not isinstance(wrapped_models, (list, tuple)):
84
+ wrapped_models = [wrapped_models]
85
+
86
+ assert len(wrapped_models) == virtual_pp_size
87
+ num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
88
+ assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers
89
+
90
+ models = [None] * len(wrapped_models)
91
+
92
+ for i, wrapped_model in enumerate(wrapped_models):
93
+ models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))
94
+ gpt_model_module = _get_gpt_model(models[i])
95
+ assert len(gpt_model_module.decoder.layers) == num_layers_per_model
96
+
97
+ def _broadcast_tensor(tensor, name) -> torch.Tensor:
98
+ """broadcast tensor from rank0 across mp_group"""
99
+ nonlocal state_dict
100
+ nonlocal mp_group
101
+ if torch.distributed.get_rank() == src_rank:
102
+ if name in state_dict:
103
+ weight = state_dict[name]
104
+ tensor_shape = weight.shape
105
+ else:
106
+ tensor_shape = None
107
+ else:
108
+ weight = None
109
+ tensor_shape = None
110
+
111
+ obj_list = [tensor_shape]
112
+ dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
113
+ tensor_shape = obj_list[0]
114
+
115
+ if tensor_shape is None:
116
+ # all or none ranks in the mp_group should reach here
117
+ print_rank_0(f"tensor:[{name}] not in state_dict, skip load")
118
+ return
119
+
120
+ if tensor is None:
121
+ tensor = torch.empty(
122
+ tensor_shape,
123
+ dtype=params_dtype,
124
+ device=torch.cuda.current_device(),
125
+ requires_grad=False,
126
+ )
127
+ if torch.distributed.get_rank() == src_rank:
128
+ tensor.data.copy_(weight)
129
+ dist.broadcast(tensor, src=src_rank, group=mp_group)
130
+
131
+ def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:
132
+ """broadcast tensor in tp shards across mp_group"""
133
+ nonlocal state_dict
134
+ nonlocal mp_group
135
+ tp_rank = mpu.get_tensor_model_parallel_rank()
136
+ tp_size = mpu.get_tensor_model_parallel_world_size()
137
+
138
+ if torch.distributed.get_rank() == src_rank:
139
+ if name in state_dict:
140
+ full_weight = state_dict[name]
141
+
142
+ if mutate_func is not None:
143
+ full_weight = mutate_func(full_weight)
144
+ tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)
145
+ chunk_shape = tensor_chunk[0].shape
146
+ else:
147
+ chunk_shape = None
148
+ else:
149
+ chunk_shape = None
150
+
151
+ obj_list = [chunk_shape]
152
+ dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
153
+ chunk_shape = obj_list[0]
154
+ if chunk_shape is None:
155
+ # all or none ranks in the mp_group should reach here
156
+ print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading")
157
+ return
158
+
159
+ if tensor is None:
160
+ sync_tensor = torch.empty(
161
+ chunk_shape,
162
+ dtype=params_dtype,
163
+ device=torch.cuda.current_device(),
164
+ requires_grad=False,
165
+ )
166
+ else:
167
+ assert tensor.shape == chunk_shape, f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}"
168
+ sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)
169
+
170
+ for i in range(tp_size):
171
+ if torch.distributed.get_rank() == src_rank:
172
+ sync_tensor.data.copy_(tensor_chunk[i])
173
+ dist.broadcast(sync_tensor, src=src_rank, group=mp_group)
174
+ if (i == tp_rank) and (tensor is not None):
175
+ tensor.data.copy_(sync_tensor)
176
+
177
+ def _broadcast_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:
178
+ """broadcast tensor in tp shards across mp_group"""
179
+ nonlocal state_dict
180
+ nonlocal mp_group
181
+ tp_rank = mpu.get_tensor_model_parallel_rank()
182
+ tp_size = mpu.get_tensor_model_parallel_world_size()
183
+
184
+ if torch.distributed.get_rank() == src_rank:
185
+ if name in state_dict:
186
+ full_weight = state_dict[name]
187
+ if mutate_func is not None:
188
+ full_weight = mutate_func(full_weight)
189
+ tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)
190
+ chunk_shape = tensor_chunk[0].shape
191
+ else:
192
+ chunk_shape = None
193
+ else:
194
+ chunk_shape = None
195
+
196
+ obj_list = [chunk_shape]
197
+ dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
198
+ chunk_shape = obj_list[0]
199
+ if chunk_shape is None:
200
+ # all or none ranks in the mp_group should reach here
201
+ print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading")
202
+ return
203
+
204
+ if tensor is None:
205
+ sync_tensor = torch.empty(
206
+ chunk_shape,
207
+ dtype=params_dtype,
208
+ device=torch.cuda.current_device(),
209
+ requires_grad=False,
210
+ )
211
+ else:
212
+ assert tensor.shape == chunk_shape, f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}"
213
+ sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)
214
+
215
+ for i in range(tp_size):
216
+ if torch.distributed.get_rank() == src_rank:
217
+ sync_tensor.data.copy_(tensor_chunk[i])
218
+ dist.broadcast(sync_tensor, src=src_rank, group=mp_group)
219
+ if (i == tp_rank) and (tensor is not None):
220
+ tensor.data.copy_(sync_tensor)
221
+
222
+ def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor:
223
+ """broadcast tensor in tp shards across mp_group"""
224
+ nonlocal state_dict
225
+ nonlocal mp_group
226
+ tp_rank = mpu.get_tensor_model_parallel_rank()
227
+ tp_size = mpu.get_tensor_model_parallel_world_size()
228
+
229
+ if torch.distributed.get_rank() == src_rank:
230
+ gate_weight = state_dict[gate_name]
231
+ up_weight = state_dict[up_name]
232
+ new_gate_up_weight = torch.empty(config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device())
233
+ for i in range(tp_size):
234
+ intermediate_size_tp = config.intermediate_size // tp_size
235
+ gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]
236
+ up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]
237
+ new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_(torch.cat([gate_weight_tp, up_weight_tp], dim=0))
238
+
239
+ tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0)
240
+ chunk_shape = tensor_chunk[0].shape
241
+ else:
242
+ chunk_shape = None
243
+
244
+ obj_list = [chunk_shape]
245
+ dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
246
+ chunk_shape = obj_list[0]
247
+ if chunk_shape is None:
248
+ # all or none ranks in the mp_group should reach here
249
+ print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not in state_dict, skip loading")
250
+ return
251
+
252
+ if tensor is None:
253
+ sync_tensor = torch.empty(
254
+ chunk_shape,
255
+ dtype=params_dtype,
256
+ device=torch.cuda.current_device(),
257
+ requires_grad=False,
258
+ )
259
+ else:
260
+ assert tensor.shape == chunk_shape, f"rank #{torch.distributed.get_rank()} tensor {gate_name, up_name} shape {tensor.shape} != {chunk_shape}"
261
+ sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)
262
+
263
+ for i in range(tp_size):
264
+ if torch.distributed.get_rank() == src_rank:
265
+ sync_tensor.data.copy_(tensor_chunk[i])
266
+ dist.broadcast(sync_tensor, src=src_rank, group=mp_group)
267
+ if (i == tp_rank) and (tensor is not None):
268
+ tensor.data.copy_(sync_tensor)
269
+
270
+ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> torch.Tensor:
271
+ """broadcast tensor in tp shards across mp_group"""
272
+ nonlocal state_dict
273
+ nonlocal mp_group
274
+ tp_rank = mpu.get_tensor_model_parallel_rank()
275
+ tp_size = mpu.get_tensor_model_parallel_world_size()
276
+
277
+ if torch.distributed.get_rank() == src_rank:
278
+ assert q_name in state_dict and k_name in state_dict and v_name in state_dict
279
+ full_weight_q = state_dict[q_name]
280
+ full_weight_k = state_dict[k_name]
281
+ full_weight_v = state_dict[v_name]
282
+
283
+ hidden_size_per_head = config.hidden_size // config.num_attention_heads
284
+
285
+ if config.num_key_value_heads >= tp_size:
286
+ q_size_tp = config.hidden_size // tp_size
287
+ kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size
288
+ total_size = q_size_tp + 2 * kv_size_tp
289
+ sizes = [total_size * tp_size]
290
+ if not bias:
291
+ sizes.append(config.hidden_size)
292
+ new_weight_qkv = torch.empty(*sizes, dtype=params_dtype, device=torch.cuda.current_device())
293
+ for i in range(tp_size):
294
+ q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]
295
+ k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp]
296
+ v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp]
297
+ num_query_groups_per_partition = models[0].config.num_query_groups // tp_size
298
+ new_weight_qkv_this_tp = new_weight_qkv[i * total_size : (i + 1) * total_size]
299
+ q_part_per_head = torch.chunk(q_part, num_query_groups_per_partition, dim=0)
300
+ k_part_per_head = torch.chunk(k_part, num_query_groups_per_partition, dim=0)
301
+ v_part_per_head = torch.chunk(v_part, num_query_groups_per_partition, dim=0)
302
+ total_size_per_head = total_size // num_query_groups_per_partition
303
+ for j in range(num_query_groups_per_partition):
304
+ new_weight_qkv_this_tp[j * total_size_per_head : (j + 1) * total_size_per_head].copy_(torch.cat([q_part_per_head[j], k_part_per_head[j], v_part_per_head[j]], dim=0))
305
+
306
+ else:
307
+ q_size_tp = config.hidden_size // tp_size
308
+ kv_size_tp = hidden_size_per_head
309
+ total_size = q_size_tp + 2 * kv_size_tp
310
+ sizes = [total_size * tp_size]
311
+ if not bias:
312
+ sizes.append(config.hidden_size)
313
+ new_weight_qkv = torch.empty(*sizes, dtype=params_dtype, device=torch.cuda.current_device())
314
+ for i in range(tp_size):
315
+ q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]
316
+ start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head
317
+ end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head
318
+ k_part = full_weight_k[start_idx:end_idx]
319
+ v_part = full_weight_v[start_idx:end_idx]
320
+ new_weight_qkv_this_tp = new_weight_qkv[i * total_size : (i + 1) * total_size]
321
+ q_part_per_head = torch.chunk(q_part, config.num_attention_heads, dim=0)
322
+ k_part_per_head = torch.chunk(k_part, config.num_attention_heads, dim=0)
323
+ v_part_per_head = torch.chunk(v_part, config.num_attention_heads, dim=0)
324
+ total_size_per_head = total_size // config.num_attention_heads
325
+ for j in range(config.num_attention_heads):
326
+ new_weight_qkv_this_tp[j * total_size_per_head : (j + 1) * total_size_per_head].copy_(torch.cat([q_part_per_head[j], k_part_per_head[j], v_part_per_head[j]], dim=0))
327
+
328
+ tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0)
329
+ chunk_shape = tensor_chunk[0].shape
330
+ else:
331
+ chunk_shape = None
332
+
333
+ obj_list = [chunk_shape]
334
+ dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
335
+ chunk_shape = obj_list[0]
336
+ if chunk_shape is None:
337
+ # all or none ranks in the mp_group should reach here
338
+ print_rank_0(f"tp_shard tensor:[{q_name, k_name, v_name}] not in state_dict, skip loading")
339
+ return
340
+
341
+ if tensor is None:
342
+ sync_tensor = torch.empty(
343
+ chunk_shape,
344
+ dtype=params_dtype,
345
+ device=torch.cuda.current_device(),
346
+ requires_grad=False,
347
+ )
348
+ else:
349
+ assert tensor.shape == chunk_shape, f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}"
350
+ sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)
351
+
352
+ for i in range(tp_size):
353
+ if torch.distributed.get_rank() == src_rank:
354
+ sync_tensor.data.copy_(tensor_chunk[i])
355
+ dist.broadcast(sync_tensor, src=src_rank, group=mp_group)
356
+ if (i == tp_rank) and (tensor is not None):
357
+ tensor.data.copy_(sync_tensor)
358
+
359
+ if dp_rank == 0:
360
+ # Embeddings
361
+ # -------------------
362
+ print_rank_0("loading embeddings...")
363
+ gpt_model_module = _get_gpt_model(models[0])
364
+ embed_tokens_weight = None
365
+ if pp_rank == 0:
366
+ embed_tokens_weight = gpt_model_module.embedding.word_embeddings.weight
367
+ _broadcast_tp_shard_tensor_vocab(embed_tokens_weight, "model.embed_tokens.weight")
368
+
369
+ # Transformer layers
370
+ # -------------------
371
+ layer_map = _megatron_calc_layer_map(config)
372
+
373
+ for layer in range(config.num_hidden_layers):
374
+ print_rank_0(f"loading layer #{layer}...")
375
+ layer_name = f"model.layers.{layer}"
376
+ dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer]
377
+
378
+ gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank])
379
+ sync_layer = gpt_model_module.decoder.layers[dst_layer_idx]
380
+
381
+ _broadcast_tensor(
382
+ sync_layer.self_attention.linear_qkv.layer_norm_weight if dst_pp_rank == pp_rank else None,
383
+ f"{layer_name}.input_layernorm.weight",
384
+ )
385
+
386
+ if f"{layer_name}.self_attn.q_norm.weight" in state_dict:
387
+ _broadcast_tensor(
388
+ sync_layer.self_attention.q_layernorm.weight if dst_pp_rank == pp_rank else None,
389
+ f"{layer_name}.self_attn.q_norm.weight",
390
+ )
391
+ _broadcast_tensor(
392
+ sync_layer.self_attention.k_layernorm.weight if dst_pp_rank == pp_rank else None,
393
+ f"{layer_name}.self_attn.k_norm.weight",
394
+ )
395
+
396
+ _broadcast_tp_shard_tensor_qkv(
397
+ sync_layer.self_attention.linear_qkv.weight if dst_pp_rank == pp_rank else None,
398
+ f"{layer_name}.self_attn.q_proj.weight",
399
+ f"{layer_name}.self_attn.k_proj.weight",
400
+ f"{layer_name}.self_attn.v_proj.weight",
401
+ )
402
+ if f"{layer_name}.self_attn.q_proj.bias" in state_dict:
403
+ _broadcast_tp_shard_tensor_qkv(
404
+ sync_layer.self_attention.linear_qkv.bias if dst_pp_rank == pp_rank else None,
405
+ f"{layer_name}.self_attn.q_proj.bias",
406
+ f"{layer_name}.self_attn.k_proj.bias",
407
+ f"{layer_name}.self_attn.v_proj.bias",
408
+ bias=True,
409
+ )
410
+
411
+ _broadcast_tp_shard_tensor(
412
+ sync_layer.self_attention.linear_proj.weight if dst_pp_rank == pp_rank else None,
413
+ f"{layer_name}.self_attn.o_proj.weight",
414
+ chunk_dim=1,
415
+ )
416
+ _broadcast_tensor(
417
+ sync_layer.mlp.linear_fc1.layer_norm_weight if dst_pp_rank == pp_rank else None,
418
+ f"{layer_name}.post_attention_layernorm.weight",
419
+ )
420
+
421
+ _broadcast_tp_shard_tensor_gate_up(
422
+ sync_layer.mlp.linear_fc1.weight if dst_pp_rank == pp_rank else None,
423
+ f"{layer_name}.mlp.gate_proj.weight",
424
+ f"{layer_name}.mlp.up_proj.weight",
425
+ )
426
+
427
+ _broadcast_tp_shard_tensor(
428
+ sync_layer.mlp.linear_fc2.weight if dst_pp_rank == pp_rank else None,
429
+ f"{layer_name}.mlp.down_proj.weight",
430
+ chunk_dim=1,
431
+ )
432
+ # Final Layernorm
433
+ # -------------------
434
+ print_rank_0("loading final layernorm...")
435
+ gpt_model_module = _get_gpt_model(models[-1])
436
+ _broadcast_tensor(
437
+ getattr(gpt_model_module.decoder.final_layernorm, "weight", None),
438
+ "model.norm.weight",
439
+ )
440
+
441
+ print_rank_0("loading lm_head...")
442
+ lm_head_weight = None
443
+ if pp_rank + 1 == pp_size:
444
+ lm_head_weight = gpt_model_module.output_layer.weight
445
+
446
+ if is_value_model:
447
+ # if torch.distributed.get_rank() == src_rank:
448
+ if "lm_head.weight" in state_dict and state_dict["lm_head.weight"].shape[0] == 1:
449
+ _broadcast_tensor(lm_head_weight, "lm_head.weight")
450
+ elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1:
451
+ _broadcast_tensor(lm_head_weight, "reward_head.weight")
452
+ print_rank_0("load lm_head from value_head weight")
453
+ else:
454
+ _broadcast_tensor(None, "lm_head.weight")
455
+ print_rank_0("failed to match lm_head in value_model")
456
+ # else:
457
+
458
+ # _broadcast_tensor(lm_head_weight, "lm_head.weight")
459
+
460
+ else:
461
+ _broadcast_tp_shard_tensor(lm_head_weight, "lm_head.weight")
462
+ dist.barrier()
463
+ # Broadcast weights inside data parallel groups
464
+ for wrapped_model in wrapped_models:
465
+ broadcast_params(wrapped_model)
466
+ pass
467
+ torch.cuda.empty_cache()
468
+ print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s")
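All of the `_broadcast_*` helpers above follow the same two-phase pattern: the ranks first agree on the tensor shape with `broadcast_object_list`, then move the data with a single `broadcast`; a `None` shape lets every rank skip a missing key consistently instead of deadlocking on a mismatched collective. Below is a minimal stand-alone sketch of that pattern (the function and argument names are illustrative, and it assumes an initialized process group and a CUDA device):

```python
import torch
import torch.distributed as dist

def broadcast_from_src(state_dict, name, src_rank, group, dtype):
    """Illustrative two-phase broadcast: shape handshake first, then the data."""
    # Phase 1: only the source rank knows whether the key exists and what shape it has.
    if dist.get_rank() == src_rank:
        weight = state_dict.get(name)
        shape = None if weight is None else weight.shape
    else:
        weight, shape = None, None
    obj = [shape]
    dist.broadcast_object_list(obj, src=src_rank, group=group)
    shape = obj[0]
    if shape is None:
        return None  # every rank takes this branch together when the key is missing

    # Phase 2: allocate a buffer of the agreed shape and broadcast the actual data.
    buf = torch.empty(shape, dtype=dtype, device=torch.cuda.current_device())
    if dist.get_rank() == src_rank:
        buf.copy_(weight)
    dist.broadcast(buf, src=src_rank, group=group)
    return buf
```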
verl/models/mcore/model_forward.py ADDED
@@ -0,0 +1,50 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates
2
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
3
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from verl.utils.megatron_utils import unwrap_model
18
+
19
+ from .util import postprocess_packed_seqs, preprocess_packed_seqs, recover_left_padding, remove_left_padding
20
+
21
+
22
+ def gptmodel_forward(model, input_ids, attention_mask, position_ids, sequence_parallel, value_model=False, pack_seqs=True):
23
+ """Default forward pass for GPT models with optional sequence packing."""
24
+ pre_process = unwrap_model(model).pre_process
25
+ post_process = unwrap_model(model).post_process
26
+ if pack_seqs:
27
+ batch_size, seq_len = attention_mask.shape[:2]
28
+ input_ids_rmpad, packed_seq_params = preprocess_packed_seqs(input_ids, attention_mask, pre_process=pre_process)
29
+ input_ids_rmpad = input_ids_rmpad.contiguous()
30
+ output_orig = model(
31
+ input_ids=input_ids_rmpad,
32
+ attention_mask=None,
33
+ position_ids=position_ids,
34
+ packed_seq_params=packed_seq_params,
35
+ )
36
+
37
+ output = postprocess_packed_seqs(output_orig, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process)
38
+ else:
39
+ batch_size, sequence_length = attention_mask.shape
40
+ new_input_ids, new_attention_mask, new_position_ids = remove_left_padding(input_ids, attention_mask, position_ids, sequence_parallel, pre_process=pre_process)
41
+ output = model(input_ids=new_input_ids, attention_mask=new_attention_mask, position_ids=new_position_ids)
42
+ output = recover_left_padding(output, new_attention_mask, attention_mask, sequence_length, post_process=post_process)
43
+ if value_model and post_process:
44
+ output = output[..., 0]
45
+ return output
46
+
47
+
48
+ def gptmodel_forward_qwen2_5_vl(*args, **kwargs):
49
+ """Forward pass for Qwen2.5 VL model (not implemented)."""
50
+ raise NotImplementedError("VLM is not supported yet")
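The packed-sequence path above relies on `preprocess_packed_seqs`/`postprocess_packed_seqs` (defined in `util.py`) to drop padding before the forward pass and restore the padded layout afterwards. The snippet below is a self-contained, CPU-only illustration of that round trip using plain tensor indexing; it mirrors the idea, not the exact helper API:

```python
import torch

input_ids = torch.tensor([[11, 12, 13, 0, 0],
                          [21, 22, 0, 0, 0]])   # 0 = padding token (toy example)
attention_mask = input_ids.ne(0)

# Pack: drop padding and remember per-sequence lengths (thd-style layout).
seqlens = attention_mask.sum(dim=1)                                             # tensor([3, 2])
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.long), seqlens.cumsum(0)])   # [0, 3, 5]
packed = input_ids[attention_mask]                                              # tensor([11, 12, 13, 21, 22])

# Stand-in for the model output: one value per packed token.
packed_out = packed.float() * 10

# Unpack: scatter the packed results back into the padded (batch, seq) layout.
restored = torch.zeros(input_ids.shape, dtype=packed_out.dtype)
restored[attention_mask] = packed_out
assert restored[0, 2] == 130.0 and restored[1, 4] == 0.0
```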
verl/models/mcore/model_initializer.py ADDED
@@ -0,0 +1,160 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates
2
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
3
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ # use mcore transformer config to initialize the model
18
+ from abc import ABC, abstractmethod
19
+
20
+ from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec
21
+ from megatron.core.models.gpt.gpt_model import GPTModel
22
+
23
+ from .config_converter import PretrainedConfig, TransformerConfig
24
+
25
+
26
+ class BaseModelInitializer(ABC):
27
+ """Base class for model initializers."""
28
+
29
+ def __init__(self, tfconfig: TransformerConfig, hf_config: PretrainedConfig):
30
+ self.tfconfig = tfconfig
31
+ self.hf_config = hf_config
32
+
33
+ @abstractmethod
34
+ def get_transformer_layer_spec(self):
35
+ """Get the transformer layer specification.
36
+ https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/models/gpt/gpt_layer_specs.py"""
37
+ pass
38
+
39
+ def get_rope_scaling_args(self) -> dict:
40
+ """Get rope scaling args."""
41
+ rope_scaling_args = {}
42
+ if "rope_scaling" in self.hf_config:
43
+ if self.hf_config.rope_scaling is not None:
44
+ assert self.hf_config.rope_scaling["type"] == "linear", "only linear scaling is supported for now"
45
+ rope_scaling_args["seq_len_interpolation_factor"] = self.hf_config.rope_scaling["factor"]
46
+ return rope_scaling_args
47
+
48
+ def initialize(
49
+ self,
50
+ pre_process: bool = True,
51
+ post_process: bool = True,
52
+ share_embeddings_and_output_weights: bool = False,
53
+ value: bool = False,
54
+ **extra_kwargs,
55
+ ) -> GPTModel:
56
+ """Initialize a GPT model with the given configuration.
57
+ https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/models/gpt/gpt_model.py
58
+
59
+ Args:
60
+ pre_process (bool): include embedding layer.
61
+ post_process (bool): include an output layer.
62
+ share_embeddings_and_output_weights (bool): input embeddings and output logit weights are shared.
63
+ value (bool): add an extra linear layer for classification or regression.
64
+
65
+ Returns:
66
+ GPTModel: An initialized GPT model instance
67
+ """
68
+ transformer_layer_spec = self.get_transformer_layer_spec()
69
+ rope_scaling_args = self.get_rope_scaling_args()
70
+
71
+ model = GPTModel(
72
+ config=self.tfconfig,
73
+ transformer_layer_spec=transformer_layer_spec,
74
+ vocab_size=self.hf_config.vocab_size,
75
+ max_sequence_length=self.hf_config.max_position_embeddings,
76
+ pre_process=pre_process,
77
+ post_process=post_process,
78
+ share_embeddings_and_output_weights=share_embeddings_and_output_weights,
79
+ position_embedding_type="rope",
80
+ rotary_base=self.hf_config.rope_theta,
81
+ **rope_scaling_args,
82
+ )
83
+
84
+ if post_process and value:
85
+ from verl.models.llama.megatron.layers.parallel_linear import LinearForLastLayer
86
+
87
+ model.output_layer = LinearForLastLayer(input_size=self.tfconfig.hidden_size, output_size=1, config=self.tfconfig)
88
+
89
+ return model
90
+
91
+
92
+ class DenseModel(BaseModelInitializer):
93
+ """Initializer for dense models like Llama and Qwen2."""
94
+
95
+ def get_transformer_layer_spec(self):
96
+ assert self.tfconfig.normalization == "RMSNorm", "only RMSNorm is supported for now"
97
+ return get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True)
98
+
99
+
100
+ class Qwen2MoEModel(BaseModelInitializer):
101
+ """Initializer for Qwen2 MoE models."""
102
+
103
+ def get_transformer_layer_spec(self):
104
+ assert self.tfconfig.normalization == "RMSNorm", "only RMSNorm is supported for now"
105
+ transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True)
106
+
107
+ # Patch layer spec for shared experts
108
+ for i in range(len(transformer_layer_spec.layer_specs)):
109
+ transformer_layer_spec.layer_specs[i].submodules.mlp.submodules.shared_experts.params["gate"] = True
110
+
111
+ return transformer_layer_spec
112
+
113
+ def initialize(self, freeze_moe_router: bool = True, **kwargs):
114
+ # Qwen default freeze_moe_router: true
115
+ model = super().initialize(**kwargs)
116
+ if freeze_moe_router:
117
+ for layer in model.decoder.layers:
118
+ layer.mlp.router.weight.requires_grad = False
119
+ layer.mlp.shared_experts.gate_weight.requires_grad = False
120
+ return model
121
+
122
+
123
+ class MixtralModel(BaseModelInitializer):
124
+ """Initializer for Mixtral models."""
125
+
126
+ def get_transformer_layer_spec(self):
127
+ assert self.tfconfig.normalization == "RMSNorm", "only RMSNorm is supported for now"
128
+ transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True)
129
+ return transformer_layer_spec
130
+
131
+ def initialize(self, freeze_moe_router: bool = False, **kwargs):
132
+ model = super().initialize(**kwargs)
133
+ if freeze_moe_router:
134
+ for layer in model.decoder.layers:
135
+ layer.mlp.router.weight.requires_grad = False
136
+ return model
137
+
138
+
139
+ class Qwen3MoEModel(BaseModelInitializer):
140
+ """Initializer for Qwen3 MoE models."""
141
+
142
+ def get_transformer_layer_spec(self):
143
+ assert self.tfconfig.normalization == "RMSNorm", "only RMSNorm is supported for now"
144
+ transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True)
145
+ return transformer_layer_spec
146
+
147
+ def initialize(self, freeze_moe_router: bool = True, **kwargs):
148
+ # Qwen default freeze_moe_router: true
149
+ model = super().initialize(**kwargs)
150
+ if freeze_moe_router:
151
+ for layer in model.decoder.layers:
152
+ layer.mlp.router.weight.requires_grad = False
153
+ return model
154
+
155
+
156
+ class Qwen25VLModel(BaseModelInitializer):
157
+ """Initializer for Qwen2.5 VL models."""
158
+
159
+ def get_transformer_layer_spec(self):
160
+ raise NotImplementedError("VLM is not supported yet")
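A minimal usage sketch for the initializers above, assuming Megatron's parallel state has already been initialized and using the dense config converter from this commit; the model name and dtype are examples only. With `value=True`, the language-model head is replaced by a single-output `LinearForLastLayer`, which is how critic/reward models are built:

```python
import torch
from transformers import AutoConfig

from verl.models.mcore.config_converter import hf_to_mcore_config_dense
from verl.models.mcore.model_initializer import DenseModel

hf_config = AutoConfig.from_pretrained("Qwen/Qwen2-7B-Instruct")    # example model
tfconfig = hf_to_mcore_config_dense(hf_config, torch.bfloat16)      # HF config -> TransformerConfig

initializer = DenseModel(tfconfig, hf_config)
critic = initializer.initialize(
    pre_process=True,      # this pipeline stage owns the embedding
    post_process=True,     # this pipeline stage owns the output layer
    share_embeddings_and_output_weights=False,
    value=True,            # swap the LM head for a 1-dim value head
)
```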
verl/models/mcore/readme.md ADDED
@@ -0,0 +1,99 @@
1
+ # verl Megatron-Core Models
2
+ Earlier versions of verl use `Megatron-LM` 0.4 and workaround huggingface model classes. To benefit from the latest features and speedups of modern Megatron, we are migrating to `Megatron-Core` (mcore) and adopting the recommended `GPTModel` class for all language models. With mcore `GPTModel`, we can use the latest features such as `context parallel`, `expert parallel`, and `dist_checkpointing`, and we can adopt new mcore features with little effort in the future.
3
+
4
+ The migration has been successful with the help of the mcore team and the community. What we have done is:
5
+ 1. update `Megatron` version to `0.11.0`
6
+ 2. migrate `LlamaForCausalLM` and `Qwen2ForCausalLM` to mcore `GPTModel`
7
+ 3. support sequence packing/thd format.
8
+ 4. support `tensor parallel`, `pipeline parallel`, `sequence parallel`, `virtual pipeline parallel`, `context parallel`.
9
+ 5. support the mcore `dist_checkpointing` feature and a basic offline weights conversion script from huggingface to the mcore `dist_checkpointing` format.
10
+
11
+ We are working on the following features:
12
+ - support `Qwen2MoeForCausalLM`
13
+ - support `MixtralForCausalLM`
14
+ - support `DeepseekV3ForCausalLM`
15
+ - support `expert parallel`
16
+
17
+ Features we invite the community to contribute:
18
+ - better scripts for offline weights conversion from huggingface to the mcore `dist_checkpointing` format.
19
+ - conversion of large models with multiple GPUs
20
+ - conversion of large models with single GPU
21
+ - refactor `megatron_checkpoint_manager.py` to use the `dist_checkpointing` format.
22
+ - support llama4
23
+ - support qwen2.5-vl
24
+
25
+ To track the progress of verl mcore integration, please refer to the [mcore integration issue](https://github.com/volcengine/verl/issues/1033).
26
+
27
+ ## How things work now
28
+ To engage the community in contributing, here are the key steps in our mcore integration process and features under development.
29
+
30
+ The huggingface `transformers` library is the de facto standard model zoo, while mcore excels at computation efficiency. The main challenge is converting between the two.
31
+ The main steps are:
32
+ 1. modelling the huggingface model with mcore `GPTModel`
33
+ - a. convert the huggingface config to mcore `TransformerConfig`
34
+ - b. init the mcore `GPTModel` with the converted config
35
+ - c. load the huggingface model weights to the `GPTModel`
36
+ 2. online weight conversion from mcore to huggingface (because the rollout engine `vLLM` consumes the huggingface format)
37
+ - a. bridge the gap between mcore and huggingface weights format and name mapping
38
+ - b. online resharding the mcore weights to rollout engine
39
+ - this part is complicated by the composition of multiple parallel strategies between mcore and the rollout engine
40
+ 3. support the mcore features in verl
41
+ - a. support `tensor parallel`, `pipeline parallel`, `sequence parallel`, `virtual pipeline parallel`, `context parallel`
42
+ - b. support recompute and other mcore speed-up features
43
+
44
+ 4. checkpointing
45
+ - a. support resuming verl training.
46
+ - b. support exporting the mcore checkpoint to huggingface format, for downstream inference.
47
+
48
+ ### Modelling the huggingface model with mcore `GPTModel`
49
+ The first step is to convert the huggingface config to an mcore `TransformerConfig` and initialize the mcore `GPTModel` with the converted config. See the code in `verl/models/mcore/config_converter.py` and `verl/models/mcore/model_initializer.py`; the corresponding model forward code is in `verl/models/mcore/model_forward.py`. A short usage sketch follows the list below.
50
+
51
+ There are two ways of loading the huggingface model weights into the `GPTModel`:
52
+ 1. Runtime loading
53
+ - every rank loads the entire huggingface model weights and then shards and converts them to mcore weights.
54
+ - speed is slow and memory consumption is high.
55
+ - this way is deprecated and will not support new models.
56
+ 2. Offline loading
57
+ - use an offline script to convert the huggingface model weights to mcore weights and save them in the mcore `dist_checkpointing` format.
58
+ - online loading and sharding are handled automatically by the mcore `dist_checkpointing` format; the speed is fast and memory consumption is low.
59
+ - the offline script is in `verl/scripts/converter_hf_to_mcore.py`.
60
+
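For example, a sketch of steps a and b using the registry helpers added in this commit (it assumes Megatron's parallel state is already initialized; the model name is illustrative):

```python
import torch
from transformers import AutoConfig
from verl.models.mcore.registry import hf_to_mcore_config, init_mcore_model

hf_config = AutoConfig.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")  # example model
tfconfig = hf_to_mcore_config(hf_config, torch.bfloat16)    # step a: HF config -> TransformerConfig
model = init_mcore_model(tfconfig, hf_config, pre_process=True, post_process=True)  # step b
# step c: weights are then loaded either at runtime (deprecated) or from the offline
# dist_checkpointing directory produced by verl/scripts/converter_hf_to_mcore.py.
```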
61
+ ### online weight conversion from mcore to huggingface
62
+ See function `convert_megatron_model_to_transformers_model` in `verl/utils/megatron_utils.py` for the details.
63
+
64
+ It should be refactored for extensibility and better performance.
65
+
66
+ ### support the mcore features in verl
67
+ Most of the features of `GPTModel` are supported out of the box in verl by changing the `TransformerConfig`, except those related to parallel strategies, such as `expert parallel`.
68
+ Features related to parallel strategies require changes to the online weight conversion (especially the resharding part) and to verl work dispatching.
69
+
70
+ ### checkpointing
71
+ The existing checkpointing code is in `verl/utils/checkpoint/megatron_checkpoint_manager.py`, and the script to convert a checkpoint to huggingface format is in `verl/scripts/model_merger.py`.
72
+
73
+ The existing checkpoint format simply saves every rank's weights and optimizer states. It should be refactored to use the `dist_checkpointing` format.
74
+
75
+
76
+ ## How to support new models
77
+ 1. make sure the model is supported by vLLM
78
+ 2. modelling the huggingface model with mcore `GPTModel` (the [Pai-Megatron-Patch](https://github.com/alibaba/Pai-Megatron-Patch/tree/main) is a good reference)
79
+ - a. convert the huggingface config to mcore `TransformerConfig`
80
+ - b. init the mcore `GPTModel` with the converted config
81
+ - c. load the huggingface model weights to the `GPTModel`
82
+ - d. for VLMs the interface might be different; it is OK to add a new model class with `GPTModel` as its module.
83
+ 3. offline weights conversion from huggingface to mcore `dist_checkpointing` format
84
+ 4. support online weights conversion from mcore to huggingface
85
+ - it is recommended to initialize a vLLM model with the converted mcore weights, and then test whether the generated sequences are correct.
86
+
87
+
88
+ ## How to scale up to larger models like deepseek-v3 or other 100B+ models
89
+ The greatest challenge in scaling up to larger models is memory consumption.
90
+
91
+ The necessary features under development for scaling up are
92
+ 1. Training engine part
93
+ - expert parallel
94
+ 2. Rollout engine part
95
+ - pipeline parallel
96
+ - expert parallel
97
+ - more efficient and general weight resharding and loading
98
+ 3. Offline weights conversion
99
+ - support weights larger than a single GPU's memory
verl/models/mcore/registry.py ADDED
@@ -0,0 +1,179 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates
2
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Registry module for model architecture components.
17
+ """
18
+
19
+ from enum import Enum
20
+ from typing import Callable, Dict, Type
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+
25
+ from .config_converter import (
26
+ PretrainedConfig,
27
+ TransformerConfig,
28
+ hf_to_mcore_config_dense,
29
+ hf_to_mcore_config_dpskv3,
30
+ hf_to_mcore_config_llama4,
31
+ hf_to_mcore_config_mixtral,
32
+ hf_to_mcore_config_qwen2_5_vl,
33
+ hf_to_mcore_config_qwen2moe,
34
+ hf_to_mcore_config_qwen3moe,
35
+ )
36
+ from .model_forward import (
37
+ gptmodel_forward,
38
+ )
39
+ from .model_initializer import (
40
+ BaseModelInitializer,
41
+ DenseModel,
42
+ MixtralModel,
43
+ Qwen2MoEModel,
44
+ Qwen3MoEModel,
45
+ Qwen25VLModel,
46
+ )
47
+ from .weight_converter import (
48
+ McoreToHFWeightConverterDense,
49
+ McoreToHFWeightConverterMixtral,
50
+ McoreToHFWeightConverterQwen2Moe,
51
+ McoreToHFWeightConverterQwen3Moe,
52
+ )
53
+
54
+
55
+ class SupportedModel(Enum):
56
+ LLAMA = "LlamaForCausalLM" # tested
57
+ QWEN2 = "Qwen2ForCausalLM" # tested
58
+ QWEN2_MOE = "Qwen2MoeForCausalLM" # pending
59
+ DEEPSEEK_V3 = "DeepseekV3ForCausalLM" # not tested
60
+ MIXTRAL = "MixtralForCausalLM" # tested
61
+ QWEN2_5_VL = "Qwen2_5_VLForConditionalGeneration" # not supported
62
+ LLAMA4 = "Llama4ForConditionalGeneration" # not tested
63
+ QWEN3 = "Qwen3ForCausalLM" # tested
64
+ QWEN3_MOE = "Qwen3MoeForCausalLM" # not tested
65
+
66
+
67
+ # Registry for model configuration converters
68
+ MODEL_CONFIG_CONVERTER_REGISTRY: Dict[SupportedModel, Callable[[PretrainedConfig, torch.dtype], TransformerConfig]] = {
69
+ SupportedModel.LLAMA: hf_to_mcore_config_dense,
70
+ SupportedModel.QWEN2: hf_to_mcore_config_dense,
71
+ SupportedModel.QWEN2_MOE: hf_to_mcore_config_qwen2moe,
72
+ SupportedModel.DEEPSEEK_V3: hf_to_mcore_config_dpskv3,
73
+ SupportedModel.MIXTRAL: hf_to_mcore_config_mixtral,
74
+ SupportedModel.QWEN2_5_VL: hf_to_mcore_config_qwen2_5_vl,
75
+ SupportedModel.LLAMA4: hf_to_mcore_config_llama4,
76
+ SupportedModel.QWEN3: hf_to_mcore_config_dense,
77
+ SupportedModel.QWEN3_MOE: hf_to_mcore_config_qwen3moe,
78
+ }
79
+
80
+ # Registry for model initializers
81
+ MODEL_INITIALIZER_REGISTRY: Dict[SupportedModel, Type[BaseModelInitializer]] = {
82
+ SupportedModel.LLAMA: DenseModel,
83
+ SupportedModel.QWEN2: DenseModel,
84
+ SupportedModel.QWEN2_MOE: Qwen2MoEModel,
85
+ SupportedModel.MIXTRAL: MixtralModel,
86
+ SupportedModel.DEEPSEEK_V3: DenseModel,
87
+ SupportedModel.QWEN2_5_VL: Qwen25VLModel,
88
+ SupportedModel.LLAMA4: DenseModel,
89
+ SupportedModel.QWEN3: DenseModel,
90
+ SupportedModel.QWEN3_MOE: Qwen3MoEModel,
91
+ }
92
+
93
+ # Registry for model forward functions
94
+ MODEL_FORWARD_REGISTRY: Dict[SupportedModel, Callable] = {
95
+ SupportedModel.LLAMA: gptmodel_forward,
96
+ SupportedModel.QWEN2: gptmodel_forward,
97
+ SupportedModel.QWEN2_MOE: gptmodel_forward,
98
+ SupportedModel.MIXTRAL: gptmodel_forward,
99
+ SupportedModel.DEEPSEEK_V3: gptmodel_forward,
100
+ SupportedModel.QWEN2_5_VL: gptmodel_forward,
101
+ SupportedModel.LLAMA4: gptmodel_forward,
102
+ SupportedModel.QWEN3: gptmodel_forward,
103
+ SupportedModel.QWEN3_MOE: gptmodel_forward,
104
+ }
105
+
106
+ # Registry for model weight converters
107
+ MODEL_WEIGHT_CONVERTER_REGISTRY: Dict[SupportedModel, Type] = {
108
+ SupportedModel.LLAMA: McoreToHFWeightConverterDense,
109
+ SupportedModel.QWEN2: McoreToHFWeightConverterDense,
110
+ SupportedModel.QWEN2_MOE: McoreToHFWeightConverterQwen2Moe,
111
+ SupportedModel.MIXTRAL: McoreToHFWeightConverterMixtral,
112
+ SupportedModel.QWEN3: McoreToHFWeightConverterDense,
113
+ SupportedModel.QWEN3_MOE: McoreToHFWeightConverterQwen3Moe,
114
+ }
115
+
116
+
117
+ def get_supported_model(model_type: str) -> SupportedModel:
118
+ try:
119
+ return SupportedModel(model_type)
120
+ except ValueError as err:
121
+ supported_models = [e.value for e in SupportedModel]
122
+ raise NotImplementedError(f"Model Type: {model_type} not supported. Supported models: {supported_models}") from err
123
+
124
+
125
+ def hf_to_mcore_config(hf_config: PretrainedConfig, dtype: torch.dtype) -> TransformerConfig:
126
+ assert len(hf_config.architectures) == 1, "Only one architecture is supported for now"
127
+ model = get_supported_model(hf_config.architectures[0])
128
+ return MODEL_CONFIG_CONVERTER_REGISTRY[model](hf_config, dtype)
129
+
130
+
131
+ def init_mcore_model(
132
+ tfconfig: TransformerConfig,
133
+ hf_config: PretrainedConfig,
134
+ pre_process: bool = True,
135
+ post_process: bool = None,
136
+ *,
137
+ share_embeddings_and_output_weights: bool = False,
138
+ value: bool = False,
139
+ **extra_kwargs, # may be used for vlm and moe
140
+ ) -> nn.Module:
141
+ """
142
+ Initialize a Mcore model.
143
+
144
+ Args:
145
+ tfconfig: The transformer config.
146
+ hf_config: The HuggingFace config.
147
+ pre_process: Whether this pipeline stage includes the embedding layer.
148
+ post_process: Whether this pipeline stage includes the output layer.
149
+ share_embeddings_and_output_weights: Whether to share embeddings and output weights.
150
+ value: Whether to add an extra linear output layer for value prediction.
151
+ **extra_kwargs: Additional keyword arguments.
152
+
153
+ Returns:
154
+ The initialized model.
155
+ """
156
+ assert len(hf_config.architectures) == 1, "Only one architecture is supported for now"
157
+ model = get_supported_model(hf_config.architectures[0])
158
+ initializer_cls = MODEL_INITIALIZER_REGISTRY[model]
159
+ initializer = initializer_cls(tfconfig, hf_config)
160
+ return initializer.initialize(pre_process=pre_process, post_process=post_process, share_embeddings_and_output_weights=share_embeddings_and_output_weights, value=value, **extra_kwargs)
161
+
162
+
163
+ def get_mcore_forward_fn(hf_config: PretrainedConfig) -> Callable:
164
+ """
165
+ Get the forward function for given model architecture.
166
+ """
167
+ assert len(hf_config.architectures) == 1, "Only one architecture is supported for now"
168
+ model = get_supported_model(hf_config.architectures[0])
169
+ return MODEL_FORWARD_REGISTRY[model]
170
+
171
+
172
+ def get_mcore_weight_converter(hf_config: PretrainedConfig, dtype: torch.dtype) -> Callable:
173
+ """
174
+ Get the weight converter for given model architecture.
175
+ """
176
+ assert len(hf_config.architectures) == 1, "Only one architecture is supported for now"
177
+ model = get_supported_model(hf_config.architectures[0])
178
+ tfconfig = hf_to_mcore_config(hf_config, dtype)
179
+ return MODEL_WEIGHT_CONVERTER_REGISTRY[model](hf_config, tfconfig)
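A usage sketch for the remaining lookups, assuming `hf_config` describes one of the architectures listed in `SupportedModel`; the model name and dtype are examples:

```python
import torch
from transformers import AutoConfig

from verl.models.mcore.registry import get_mcore_forward_fn, get_mcore_weight_converter

hf_config = AutoConfig.from_pretrained("Qwen/Qwen2-7B-Instruct")   # example model

forward_fn = get_mcore_forward_fn(hf_config)                       # gptmodel_forward for dense models
weight_converter = get_mcore_weight_converter(hf_config, torch.bfloat16)
# forward_fn is called inside the Megatron worker's forward step, while weight_converter
# maps mcore parameter names/layouts back to huggingface format for the rollout engine.
```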
verl/models/mcore/saver.py ADDED
@@ -0,0 +1,459 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates
2
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import time
17
+
18
+ import torch
19
+ import torch.distributed as dist
20
+ from megatron.core import mpu
21
+ from megatron.core.distributed import DistributedDataParallel as LocalDDP
22
+ from megatron.core.transformer.module import Float16Module
23
+ from torch.nn.parallel import DistributedDataParallel as torchDDP
24
+
25
+ from verl.utils.megatron_utils import print_rank_0, unwrap_model
26
+
27
+
28
+ def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0, cp_rank: int = 0, ep_rank: int = 0):
29
+ """Calculate global rank with support for CP/EP parallelism"""
30
+
31
+ # Get parallel sizes for each dimension
32
+ tp_size = mpu.get_tensor_model_parallel_world_size()
33
+ dp_size = mpu.get_data_parallel_world_size()
34
+ pp_size = mpu.get_pipeline_model_parallel_world_size()
35
+ cp_size = mpu.get_context_parallel_world_size()
36
+ # ep_size = mpu.get_expert_model_parallel_world_size()
37
+
38
+ # Verify total GPU count matches (must be consistent with parallel_state.py)
39
+ total_size = tp_size * dp_size * pp_size * cp_size
40
+ assert total_size == torch.distributed.get_world_size(), f"{tp_size}x{dp_size}x{pp_size}x{cp_size} != {torch.distributed.get_world_size()}"
41
+
42
+ # Core calculation logic (corresponds to RankGenerator order parameter)
43
+ # Assumes default order is "tp-cp-ep-dp-pp"
44
+ return ((pp_rank * dp_size + dp_rank) * cp_size + cp_rank) * tp_size + tp_rank
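A quick self-contained check of this formula under an assumed 8-GPU layout (tp=2, cp=1, dp=2, pp=2), confirming that tp is the fastest-varying dimension and pp the slowest:

```python
tp_size, cp_size, dp_size, pp_size = 2, 1, 2, 2   # illustrative 8-GPU layout

def global_rank(tp, dp, pp, cp=0):
    return ((pp * dp_size + dp) * cp_size + cp) * tp_size + tp

assert global_rank(tp=1, dp=0, pp=0) == 1   # tp moves first
assert global_rank(tp=0, dp=1, pp=0) == 2   # then dp (cp_size is 1 here)
assert global_rank(tp=0, dp=0, pp=1) == 4   # pp moves slowest
assert global_rank(tp=1, dp=1, pp=1) == 7
```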
45
+
46
+
47
+ def _megatron_calc_layer_map(config):
48
+ """Calculate the mapping of global layer_idx to local layer_idx
49
+ Returns:
50
+ layer_map (Dict: int -> tuple(int, int, int)):
51
+ mapping from the global layer index to
52
+ a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)
53
+ """
54
+ from megatron.core import mpu
55
+
56
+ pp_size = mpu.get_pipeline_model_parallel_world_size()
57
+ virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
58
+
59
+ layer_map = dict()
60
+ num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
61
+ assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers
62
+
63
+ for pp_rank_idx in range(pp_size):
64
+ for virtual_pp_rank_idx in range(virtual_pp_size):
65
+ layer_offset = virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model
66
+ for layer_idx in range(num_layers_per_model):
67
+ layer_map[layer_offset + layer_idx] = (
68
+ pp_rank_idx,
69
+ virtual_pp_rank_idx,
70
+ layer_idx,
71
+ )
72
+ return layer_map
73
+
74
+
75
+ def merge_megatron_ckpt_gptmodel(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False):
76
+ """Merge sharded parameters of a Megatron module into a merged checkpoint.
77
+
78
+ Args:
79
+ wrapped_models (list of megatron.core.distributed.DistributedDataParallel):
80
+ The local DDP wrapped megatron modules.
81
+ config (str or None):
82
+ HF config for model
83
+ dtype: model params type
84
+ is_value_model: if model is value model
85
+ tie_word_embeddings: tie_word_embeddings
86
+ Returns:
87
+ state_dict (dict):
88
+ The merged state_dict in rank 0, and an empty dictionary in other ranks.
89
+ """
90
+ start_time = time.time()
91
+
92
+ def _get_gpt_model(model):
93
+ return model
94
+
95
+ dp_rank = mpu.get_data_parallel_rank()
96
+ pp_size = mpu.get_pipeline_model_parallel_world_size()
97
+ pp_rank = mpu.get_pipeline_model_parallel_rank()
98
+ cp_rank = mpu.get_context_parallel_rank()
99
+ virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
100
+ mp_group = mpu.get_model_parallel_group()
101
+
102
+ if dist.get_rank() == 0:
103
+ assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank()}] != 0 on rank #0"
104
+ assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0"
105
+ assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0"
106
+
107
+ if not isinstance(wrapped_models, (list, tuple)):
108
+ wrapped_models = [wrapped_models]  # wrap a single module into a list
109
+
110
+ assert len(wrapped_models) == virtual_pp_size
111
+ num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
112
+ assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers
113
+
114
+ models = [None] * len(wrapped_models)
115
+
116
+ for i, wrapped_model in enumerate(wrapped_models):
117
+ models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))
118
+ assert len(models[i].decoder.layers) == num_layers_per_model, "len model layers {} not equal to num_layers_per_model {}".format(len(models[i].decoder.layers), num_layers_per_model)
119
+
120
+ state_dict = dict()
121
+
122
+ def _get_cpu_tensor(tensor: torch.Tensor):
123
+ if tensor is None:
124
+ return None
125
+ if tensor.device == torch.device("cpu"):
126
+ return tensor.detach().clone()
127
+ return tensor.detach().cpu()
128
+
129
+ def _broadcast_tensor(tensor, name, src_pp_rank) -> torch.Tensor:
130
+ """broadcast tensor across mp_group"""
131
+ nonlocal state_dict
132
+ nonlocal mp_group
133
+ src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank)
134
+
135
+ if torch.distributed.get_rank() == src_rank:
136
+ if tensor is None:
137
+ weight = None
138
+ tensor_shape = None
139
+ else:
140
+ weight = tensor
141
+ tensor_shape = weight.shape
142
+ else:
143
+ weight = None
144
+ tensor_shape = None
145
+
146
+ obj_list = [tensor_shape]
147
+ dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
148
+ tensor_shape = obj_list[0]
149
+
150
+ if tensor_shape is None:
151
+ # all or none ranks in the mp_group should reach here
152
+ print_rank_0(f"tensor:[{name}] not exist, skip collect")
153
+ return
154
+
155
+ if weight is None:
156
+ weight = torch.empty(
157
+ tensor_shape,
158
+ dtype=dtype,
159
+ device=torch.cuda.current_device(),
160
+ requires_grad=False,
161
+ )
162
+
163
+ dist.broadcast(weight, src=src_rank, group=mp_group)
164
+
165
+ if torch.distributed.get_rank() == 0:
166
+ state_dict[name] = _get_cpu_tensor(weight)
167
+
168
+ def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_func=None) -> torch.Tensor:
169
+ """broadcast tensor in tp shards across mp_group"""
170
+ nonlocal state_dict
171
+ nonlocal mp_group
172
+ # tp_rank = mpu.get_tensor_model_parallel_rank()
173
+ tp_size = mpu.get_tensor_model_parallel_world_size()
174
+ src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank)
175
+
176
+ chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None
177
+
178
+ obj_list = [chunk_shape]
179
+ dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
180
+ chunk_shape = obj_list[0]
181
+ if chunk_shape is None:
182
+ # all or none ranks in the mp_group should reach here
183
+ print_rank_0(f"tp_shard tensor:[{name}] not exist, skip collecting")
184
+ return
185
+
186
+ buffer_tensor = torch.empty(
187
+ chunk_shape,
188
+ dtype=dtype,
189
+ device=torch.cuda.current_device(),
190
+ requires_grad=False,
191
+ )
192
+
193
+ chunk_tensors = [None] * tp_size
194
+
195
+ for i in range(tp_size):
196
+ cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank)
197
+ sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor
198
+ dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)
199
+
200
+ if torch.distributed.get_rank() == 0:
201
+ chunk_tensors[i] = _get_cpu_tensor(sync_tensor)
202
+
203
+ if torch.distributed.get_rank() == 0:
204
+ full_tensor = torch.concat(chunk_tensors, dim=concat_dim)
205
+ if mutate_func is not None:
206
+ full_tensor = mutate_func(full_tensor)
207
+ state_dict[name] = full_tensor
208
+
209
+ def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) -> torch.Tensor:
210
+ """broadcast tensor in tp shards across mp_group"""
211
+ nonlocal state_dict
212
+ nonlocal mp_group
213
+ # tp_rank = mpu.get_tensor_model_parallel_rank()
214
+ tp_size = mpu.get_tensor_model_parallel_world_size()
215
+ src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank)
216
+
217
+ chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None
218
+
219
+ obj_list = [chunk_shape]
220
+ dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
221
+ chunk_shape = obj_list[0]
222
+ if chunk_shape is None:
223
+ # all or none ranks in the mp_group should reach here
224
+ print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not exist, skip collecting")
225
+ return
226
+
227
+ buffer_tensor = torch.empty(
228
+ chunk_shape,
229
+ dtype=dtype,
230
+ device=torch.cuda.current_device(),
231
+ requires_grad=False,
232
+ )
233
+
234
+ chunk_tensors = [None] * tp_size
235
+
236
+ for i in range(tp_size):
237
+ cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank)
238
+ sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor
239
+ dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)
240
+
241
+ if torch.distributed.get_rank() == 0:
242
+ chunk_tensors[i] = _get_cpu_tensor(sync_tensor)
243
+
244
+ if torch.distributed.get_rank() == 0:
245
+ full_tensor = torch.concat(chunk_tensors, dim=0)
246
+ intermediate_size_tp = config.intermediate_size // tp_size
247
+ gate_weight_list = []
248
+ up_weight_list = []
249
+ for i in range(tp_size):
250
+ gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)]
251
+ gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp]
252
+ up_weight_tp = gate_up_weight_tp[intermediate_size_tp:]
253
+ gate_weight_list.append(gate_weight_tp)
254
+ up_weight_list.append(up_weight_tp)
255
+
256
+ state_dict[gate_name] = torch.cat(gate_weight_list, dim=0)
257
+ state_dict[up_name] = torch.cat(up_weight_list, dim=0)
258
+
259
+ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank):
260
+ """broadcast tensor in tp shards across mp_group"""
261
+ nonlocal state_dict
262
+ nonlocal mp_group
263
+ # tp_rank = mpu.get_tensor_model_parallel_rank()
264
+ tp_size = mpu.get_tensor_model_parallel_world_size()
265
+ src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank)
266
+
267
+ chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None
268
+
269
+ obj_list = [chunk_shape]
270
+ dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
271
+ chunk_shape = obj_list[0]
272
+ if chunk_shape is None:
273
+ # all or none ranks in the mp_group should reach here
274
+ print_rank_0(f"tp_shard tensor:[{q_name}] not exist, skip collecting")
275
+ return
276
+
277
+ buffer_tensor = torch.empty(
278
+ chunk_shape,
279
+ dtype=dtype,
280
+ device=torch.cuda.current_device(),
281
+ requires_grad=False,
282
+ )
283
+
284
+ chunk_tensors = [None] * tp_size
285
+
286
+ for i in range(tp_size):
287
+ cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank)
288
+ sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor
289
+ dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)
290
+
291
+ if torch.distributed.get_rank() == 0:
292
+ chunk_tensors[i] = _get_cpu_tensor(sync_tensor)
293
+
294
+ if torch.distributed.get_rank() == 0:
295
+ full_tensor = torch.concat(chunk_tensors, dim=0)
296
+ q_weight_list = []
297
+ k_weight_list = []
298
+ v_weight_list = []
299
+ hidden_size_per_head = config.hidden_size // config.num_attention_heads
300
+
301
+ if config.num_key_value_heads >= tp_size:
302
+ q_size_tp = config.hidden_size // tp_size
303
+ kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size
304
+ total_size = q_size_tp + 2 * kv_size_tp
305
+ for i in range(tp_size):
306
+ num_query_groups_per_partition = wrapped_models[0].config.num_query_groups // tp_size
307
+ qkv_part = full_tensor[i * total_size : (i + 1) * total_size]
308
+ q_size_chunk = q_size_tp // num_query_groups_per_partition
309
+ kv_size_chunk = kv_size_tp // num_query_groups_per_partition
310
+ for qkv_part_chunk in qkv_part.chunk(num_query_groups_per_partition):
311
+ q_part = qkv_part_chunk[:q_size_chunk]
312
+ k_part = qkv_part_chunk[q_size_chunk : q_size_chunk + kv_size_chunk]
313
+ v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk :]
314
+ q_weight_list.append(q_part)
315
+ k_weight_list.append(k_part)
316
+ v_weight_list.append(v_part)
317
+ else:
318
+ q_size_tp = config.hidden_size // tp_size
319
+ kv_size_tp = hidden_size_per_head
320
+ total_size = q_size_tp + 2 * kv_size_tp
321
+ for i in range(tp_size):
322
+ num_query_groups_per_partition = wrapped_models[0].config.num_query_groups // tp_size
323
+ qkv_part = full_tensor[i * total_size : (i + 1) * total_size]
324
+ q_size_chunk = q_size_tp // num_query_groups_per_partition
325
+ kv_size_chunk = kv_size_tp // num_query_groups_per_partition
326
+ for qkv_part_chunk in qkv_part.chunk(num_query_groups_per_partition):
327
+ q_part = qkv_part_chunk[:q_size_chunk]
328
+ k_part = qkv_part_chunk[q_size_chunk : q_size_chunk + kv_size_chunk]
329
+ v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk :]
330
+ q_weight_list.append(q_part)
331
+ if i * config.num_key_value_heads % tp_size == 0:
332
+ k_weight_list.append(k_part)
333
+ v_weight_list.append(v_part)
334
+
335
+ state_dict[q_name] = torch.cat(q_weight_list, dim=0)
336
+ state_dict[k_name] = torch.cat(k_weight_list, dim=0)
337
+ state_dict[v_name] = torch.cat(v_weight_list, dim=0)
338
+
339
+ # empty cache before collecting weights
340
+ torch.cuda.empty_cache()
341
+ # Embeddings
342
+ # -------------------
343
+ if dp_rank == 0 and cp_rank == 0: # models are identical across cp ranks
344
+ # Embeddings
345
+ # -------------------
346
+ print_rank_0("collecting embeddings...")
347
+ gpt_model_module = _get_gpt_model(models[0])
348
+ _broadcast_tp_shard_tensor(
349
+ gpt_model_module.embedding.word_embeddings.weight if pp_rank == 0 else None,
350
+ "model.embed_tokens.weight",
351
+ src_pp_rank=0,
352
+ )
353
+
354
+ # Transformer layers
355
+ # -------------------
356
+ layer_map = _megatron_calc_layer_map(config)
357
+ for layer in range(config.num_hidden_layers):
358
+ print_rank_0(f"collecting layer #{layer}...")
359
+ layer_name = f"model.layers.{layer}"
360
+ src_pp_rank, src_virtual_pp_rank, src_layer_idx = layer_map[layer]
361
+
362
+ gpt_model_module = _get_gpt_model(models[src_virtual_pp_rank])
363
+ sync_layer = gpt_model_module.decoder.layers[src_layer_idx]
364
+
365
+ _broadcast_tensor(
366
+ sync_layer.self_attention.linear_qkv.layer_norm_weight,
367
+ f"{layer_name}.input_layernorm.weight",
368
+ src_pp_rank=src_pp_rank,
369
+ )
370
+
371
+ _broadcast_tp_shard_tensor_qkv(
372
+ sync_layer.self_attention.linear_qkv.weight,
373
+ f"{layer_name}.self_attn.q_proj.weight",
374
+ f"{layer_name}.self_attn.k_proj.weight",
375
+ f"{layer_name}.self_attn.v_proj.weight",
376
+ src_pp_rank=src_pp_rank,
377
+ )
378
+
379
+ if getattr(sync_layer.self_attention.linear_qkv, "bias", None) is not None and sync_layer.self_attention.linear_qkv.bias.numel() > 0:
380
+ _broadcast_tp_shard_tensor_qkv(
381
+ sync_layer.self_attention.linear_qkv.bias,
382
+ f"{layer_name}.self_attn.q_proj.bias",
383
+ f"{layer_name}.self_attn.k_proj.bias",
384
+ f"{layer_name}.self_attn.v_proj.bias",
385
+ src_pp_rank=src_pp_rank,
386
+ )
387
+
388
+ _broadcast_tp_shard_tensor(
389
+ sync_layer.self_attention.linear_proj.weight,
390
+ f"{layer_name}.self_attn.o_proj.weight",
391
+ concat_dim=1,
392
+ src_pp_rank=src_pp_rank,
393
+ )
394
+
395
+ _broadcast_tensor(
396
+ sync_layer.mlp.linear_fc1.layer_norm_weight,
397
+ f"{layer_name}.post_attention_layernorm.weight",
398
+ src_pp_rank=src_pp_rank,
399
+ )
400
+
401
+ _broadcast_tp_shard_tensor_gate_up(
402
+ sync_layer.mlp.linear_fc1.weight,
403
+ f"{layer_name}.mlp.gate_proj.weight",
404
+ f"{layer_name}.mlp.up_proj.weight",
405
+ src_pp_rank=src_pp_rank,
406
+ )
407
+
408
+ _broadcast_tp_shard_tensor(
409
+ sync_layer.mlp.linear_fc2.weight,
410
+ f"{layer_name}.mlp.down_proj.weight",
411
+ concat_dim=1,
412
+ src_pp_rank=src_pp_rank,
413
+ )
414
+
415
+ # Final Layernorm
416
+ # -------------------
417
+ print_rank_0("collecting final layernorm...")
418
+ gpt_model_module = _get_gpt_model(models[-1])
419
+ _broadcast_tensor(
420
+ getattr(gpt_model_module.decoder.final_layernorm, "weight", None),
421
+ "model.norm.weight",
422
+ src_pp_rank=pp_size - 1,
423
+ )
424
+
425
+ if tie_word_embeddings:
426
+ print_rank_0("tie word embedding skip load lm_head...")
427
+ else:
428
+ print_rank_0("collecting lm_head...")
429
+
430
+ if is_value_model:
431
+ lm_head_weight = None
432
+ if pp_rank == pp_size - 1:
433
+ lm_head_weight = getattr(gpt_model_module.output_layer, "weight", None)
434
+ _broadcast_tensor(lm_head_weight, "lm_head.weight", src_pp_rank=pp_size - 1)
435
+
436
+ else:
437
+ _broadcast_tp_shard_tensor(
438
+ getattr(gpt_model_module.output_layer, "weight", None) if pp_rank == pp_size - 1 else None,
439
+ "lm_head.weight",
440
+ src_pp_rank=pp_size - 1,
441
+ )
442
+
443
+ dist.barrier()
444
+ torch.cuda.empty_cache()
445
+ if torch.distributed.get_rank() == 0:
446
+ for k, v in state_dict.items():
447
+ if dtype != v.dtype:
448
+ state_dict[k] = v.to(dtype)
449
+
450
+ print_rank_0(f"merge megatron ckpt done, time elapsed {time.time() - start_time}s")
451
+ return state_dict
452
+
453
+
454
+ def merge_megatron_ckpt_gptmodel_qwen_moe(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False):
455
+ raise NotImplementedError("merge_megatron_ckpt_gptmodel_qwen_moe is not implemented")
456
+
457
+
458
+ def merge_megatron_ckpt_gptmodel_mixtral(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False):
459
+ raise NotImplementedError("merge_megatron_ckpt_gptmodel_mixtral is not implemented")
verl/models/mcore/util.py ADDED
@@ -0,0 +1,190 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates
2
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+ from megatron.core import parallel_state as mpu
18
+ from megatron.core.packed_seq_params import PackedSeqParams
19
+
20
+
21
+ def preprocess_packed_seqs(input_ids: torch.Tensor, attention_mask: torch.Tensor, pre_process: bool = True) -> tuple[torch.Tensor, PackedSeqParams]:
22
+ """
23
+ Preprocess packed sequences
24
+ CP splits the sequence into CP*2 chunks, and each GPU keeps 2 of them (GPU0 gets the first and last chunks, GPU1 the second and second-to-last, and so on); this balances the causal-attention load across CP ranks.
25
+ See https://github.com/NVIDIA/TransformerEngine/issues/1368
26
+ """
27
+ batch_size = input_ids.shape[0]
28
+
29
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
30
+ tp_size = mpu.get_tensor_model_parallel_world_size()
31
+ cp_size = mpu.get_context_parallel_world_size()
32
+ cp_rank = mpu.get_context_parallel_rank()
33
+ align_size = tp_size * cp_size * 2 if cp_size > 1 else tp_size
34
+
35
+ pad_size = (align_size - seqlens_in_batch % align_size) % align_size
36
+ seqlens_in_batch_padded = seqlens_in_batch + pad_size
37
+ cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device)
38
+ cu_seqlens[1:] = torch.cumsum(seqlens_in_batch, dim=0)
39
+ cu_seqlens_padded = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device)
40
+ cu_seqlens_padded[1:] = torch.cumsum(seqlens_in_batch_padded, dim=0)
41
+ max_seqlen_in_batch = seqlens_in_batch_padded.max().item()
42
+
43
+ shape = list(input_ids.shape[1:])
44
+ shape[0] = seqlens_in_batch_padded.sum().item() // cp_size
45
+ if pre_process:
46
+ input_ids_rmpad = torch.zeros(shape, dtype=input_ids.dtype, device=input_ids.device)
47
+ for i in range(batch_size):
48
+ if cp_size <= 1:
49
+ seqlen = seqlens_in_batch[i]
50
+ input_ids_rmpad[cu_seqlens_padded[i] : cu_seqlens_padded[i] + seqlen] = input_ids[i, attention_mask[i]]
51
+ continue
52
+ seqlen = seqlens_in_batch_padded[i] // cp_size
53
+ half_seqlen = seqlen // 2
54
+ start_idx = cu_seqlens_padded[i] // cp_size
55
+ # split to 2 chunks
56
+ d = input_ids[i, attention_mask[i]]
57
+ input_ids_rmpad[start_idx : start_idx + half_seqlen] = d[half_seqlen * cp_rank : half_seqlen * (cp_rank + 1)]
58
+
59
+ remain_start = seqlens_in_batch_padded[i] - half_seqlen * (cp_rank + 1)
60
+ remain_end = seqlens_in_batch_padded[i] - half_seqlen * cp_rank
61
+ remain_end = min(remain_end, d.shape[0])
62
+ remain_len = remain_end - remain_start
63
+ if remain_len > 0:
64
+ input_ids_rmpad[start_idx + half_seqlen : start_idx + half_seqlen + remain_len] = d[remain_start:remain_end]
65
+
66
+ packed_seq_params = PackedSeqParams(
67
+ qkv_format="thd",
68
+ cu_seqlens_q=cu_seqlens_padded,
69
+ max_seqlen_q=max_seqlen_in_batch,
70
+ cu_seqlens_kv=cu_seqlens_padded,
71
+ max_seqlen_kv=max_seqlen_in_batch,
72
+ cu_seqlens_q_padded=cu_seqlens_padded,
73
+ cu_seqlens_kv_padded=cu_seqlens_padded,
74
+ )
75
+ if pre_process:
76
+ return input_ids_rmpad.unsqueeze(0), packed_seq_params
77
+ else:
78
+ return input_ids, packed_seq_params
79
+
80
+
81
+ def postprocess_packed_seqs(
82
+ output: torch.Tensor,
83
+ packed_seq_params: PackedSeqParams,
84
+ attention_mask: torch.Tensor,
85
+ batch_size: int,
86
+ seq_len: int,
87
+ post_process: bool = True,
88
+ ) -> torch.Tensor:
89
+ """
90
+ Postprocess packed sequences
91
+ """
92
+ if not post_process:
93
+ return output
94
+ shape = [batch_size, seq_len] + list(output.shape[2:]) # 1,packed, dim -> batch_size, seq_len, dim
95
+ output_new = torch.zeros(shape, dtype=output.dtype, device=output.device)
96
+
97
+ cp_size = mpu.get_context_parallel_world_size()
98
+ # all gather output across context parallel group
99
+ if cp_size > 1:
100
+ # output shape: [1, packed_len, hidden_dim]
101
+ # need to gather across cp group and concatenate in sequence dimension
102
+ output_list = [torch.empty_like(output) for _ in range(cp_size)]
103
+ torch.distributed.all_gather(output_list, output.detach(), group=mpu.get_context_parallel_group())
104
+ output_list[mpu.get_context_parallel_rank()] = output
105
+ else:
106
+ output_list = [output]
107
+ for i in range(batch_size):
108
+ if cp_size <= 1:
109
+ s = attention_mask[i].sum().item()
110
+ output_new[i, attention_mask[i]] = output[0][packed_seq_params.cu_seqlens_q_padded[i] : packed_seq_params.cu_seqlens_q_padded[i] + s]
111
+ continue
112
+ s_len_padded_chunk = (packed_seq_params.cu_seqlens_q_padded[i + 1] - packed_seq_params.cu_seqlens_q_padded[i]) // cp_size
113
+ half_seqlen = s_len_padded_chunk // 2
114
+ s_len = attention_mask[i].sum().item()
115
+ s_len_padded = s_len_padded_chunk * cp_size
116
+ tmp = torch.empty(s_len_padded, *output.shape[2:], device=output.device)
117
+ for j in range(cp_size):
118
+ o = output_list[j][0]
119
+ # split to 2 chunks
120
+ packed_start_idx = packed_seq_params.cu_seqlens_q_padded[i] // cp_size
121
+ o0, o1 = (
122
+ o[packed_start_idx : packed_start_idx + half_seqlen],
123
+ o[packed_start_idx + half_seqlen : packed_start_idx + s_len_padded_chunk],
124
+ )
125
+ tmp[j * half_seqlen : (j + 1) * half_seqlen] = o0
126
+ tmp[s_len_padded - (j + 1) * half_seqlen : s_len_padded - j * half_seqlen] = o1
127
+ output_new[i, attention_mask[i]] = tmp[:s_len]
128
+
129
+ return output_new
130
+
131
+
132
+ def remove_left_padding(
133
+ input_ids: torch.Tensor,
134
+ attention_mask: torch.Tensor,
135
+ position_ids: torch.Tensor,
136
+ sequence_parallel: bool = False,
137
+ pre_process: bool = True,
138
+ ):
139
+ """
140
+ Remove left padding from input_ids, attention_mask and position_ids
141
+ return new_input_ids, new_attention_mask, new_position_ids
142
+ """
143
+ assert attention_mask.ndim == 2
144
+ assert position_ids.ndim == 2
145
+ cp_size = mpu.get_context_parallel_world_size()
146
+ assert cp_size == 1, "Context parallel size without seq_pack is not supported"
147
+ batch_size = input_ids.shape[0]
148
+ shape = list(input_ids.shape) # batch_size, seq_len,...
149
+ seq_lens = attention_mask.sum(dim=1)
150
+ seq_len = seq_lens.max().item()
151
+ if sequence_parallel:
152
+ sp_world_size = mpu.get_tensor_model_parallel_world_size()
153
+ pad_size = (sp_world_size - seq_len % sp_world_size) % sp_world_size
154
+ seq_len = seq_len + pad_size
155
+ shape[1] = seq_len
156
+ if pre_process:
157
+ new_input_ids = torch.zeros(dtype=input_ids.dtype, device=input_ids.device, size=shape)
158
+ new_attention_mask = torch.zeros(dtype=attention_mask.dtype, device=attention_mask.device, size=(batch_size, seq_len))
159
+ new_position_ids = torch.zeros(dtype=position_ids.dtype, device=position_ids.device, size=(batch_size, seq_len))
160
+ for i in range(batch_size):
161
+ if pre_process:
162
+ new_input_ids[i, : seq_lens[i]] = input_ids[i, attention_mask[i]]
163
+ new_attention_mask[i, : seq_lens[i]] = attention_mask[i, attention_mask[i]]
164
+ new_position_ids[i, : seq_lens[i]] = position_ids[i, attention_mask[i]]
165
+ if pre_process:
166
+ return new_input_ids, new_attention_mask, new_position_ids
167
+ else:
168
+ return input_ids, new_attention_mask, new_position_ids
169
+
170
+
171
+ def recover_left_padding(
172
+ result,
173
+ attention_mask: torch.Tensor,
174
+ original_attention_mask: torch.Tensor,
175
+ origin_seqlen: int,
176
+ post_process: bool = True,
177
+ ):
178
+ """
179
+ Recover left padding from result
180
+ return result
181
+ """
182
+ if not post_process:
183
+ return result
184
+ shape = list(result.shape)
185
+ batch_size = shape[0]
186
+ shape[1] = origin_seqlen
187
+ new_result = torch.zeros(dtype=result.dtype, device=result.device, size=shape)
188
+ for i in range(batch_size):
189
+ new_result[i, original_attention_mask[i]] = result[i, attention_mask[i]]
190
+ return new_result
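Note: the context-parallel path in `preprocess_packed_seqs` above is easiest to see on a toy example. The sketch below is illustrative only (single process, no `torch.distributed` or Megatron setup) and shows the 2-chunks-per-rank assignment the function implements for load-balanced causal attention:

```python
import torch

# Each padded sequence is cut into cp_size * 2 chunks; CP rank r keeps chunk r and
# chunk (2 * cp_size - 1 - r), so early and late tokens are paired on each rank.
cp_size = 2
seq = torch.arange(16)                     # one padded sequence of length 16
chunks = list(seq.chunk(cp_size * 2))      # 4 chunks of length 4

for cp_rank in range(cp_size):
    local = torch.cat([chunks[cp_rank], chunks[2 * cp_size - 1 - cp_rank]])
    print(f"cp_rank {cp_rank} keeps tokens: {local.tolist()}")
# cp_rank 0 keeps tokens: [0, 1, 2, 3, 12, 13, 14, 15]
# cp_rank 1 keeps tokens: [4, 5, 6, 7, 8, 9, 10, 11]
```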
verl/models/mcore/weight_converter.py ADDED
@@ -0,0 +1,207 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates
2
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
3
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ # online conversion of mcore weights to plain huggingface weights, without any fusion
18
+ # including format conversion and name mapping
19
+ # not including resharding
20
+ import torch
21
+ from megatron.core.transformer import TransformerConfig
22
+ from transformers import PretrainedConfig
23
+
24
+
25
+ class McoreToHFWeightConverterBase:
26
+ def __init__(self, hf_config: PretrainedConfig, mcore_config: TransformerConfig):
27
+ self.hf_config = hf_config
28
+ self.mcore_config = mcore_config
29
+
30
+ def convert_param(self, name: str, params_one_group: list[torch.Tensor]) -> torch.Tensor:
31
+ raise NotImplementedError
32
+
33
+
34
+ class McoreToHFWeightConverterDense(McoreToHFWeightConverterBase):
35
+ def _convert_attention_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:
36
+ # 'decoder.layers.0.self_attention.linear_proj.weight'
37
+ # 'decoder.layers.0.self_attention.linear_qkv.layer_norm_weight'
38
+ # 'decoder.layers.0.self_attention.linear_qkv.weight'
39
+ # 'decoder.layers.0.self_attention.linear_qkv.bias'
40
+ layer_number = name.split(".")[2]
41
+ convert_names = []
42
+ if "self_attention.linear_qkv.bias" in name or "self_attention.linear_qkv.weight" in name:
43
+ param_type = name.split(".")[-1]
44
+ assert param_type == "bias" or param_type == "weight"
45
+ convert_names.append(f"model.layers.{layer_number}.self_attn.q_proj.{param_type}")
46
+ convert_names.append(f"model.layers.{layer_number}.self_attn.k_proj.{param_type}")
47
+ convert_names.append(f"model.layers.{layer_number}.self_attn.v_proj.{param_type}")
48
+ assert len(params) == 3
49
+ elif "self_attention.linear_proj.weight" in name:
50
+ convert_names.append(f"model.layers.{layer_number}.self_attn.o_proj.weight")
51
+ assert len(params) == 1
52
+ elif "self_attention.linear_qkv.layer_norm_weight" in name:
53
+ convert_names.append(f"model.layers.{layer_number}.input_layernorm.weight")
54
+ assert len(params) == 1
55
+ elif "self_attention.q_layernorm.weight" in name:
56
+ convert_names.append(f"model.layers.{layer_number}.self_attn.q_norm.weight")
57
+ assert len(params) == 1
58
+ elif "self_attention.k_layernorm.weight" in name:
59
+ convert_names.append(f"model.layers.{layer_number}.self_attn.k_norm.weight")
60
+ assert len(params) == 1
61
+ else:
62
+ raise NotImplementedError(f"Unsupported parameter name: {name}")
63
+ return convert_names, params
64
+
65
+ def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:
66
+ # 'decoder.layers.0.mlp.linear_fc1.layer_norm_weight'
67
+ # 'decoder.layers.0.mlp.linear_fc1.weight'
68
+ # 'decoder.layers.0.mlp.linear_fc2.weight'
69
+ layer_number = name.split(".")[2]
70
+ convert_names = []
71
+ if "mlp.linear_fc1.weight" in name:
72
+ # split gate_proj and up_proj
73
+ convert_names.append(f"model.layers.{layer_number}.mlp.gate_proj.weight")
74
+ convert_names.append(f"model.layers.{layer_number}.mlp.up_proj.weight")
75
+ assert len(params) == 2
76
+ elif "mlp.linear_fc1.layer_norm_weight" in name:
77
+ convert_names.append(f"model.layers.{layer_number}.post_attention_layernorm.weight")
78
+ assert len(params) == 1
79
+ elif "mlp.linear_fc2.weight" in name:
80
+ convert_names.append(f"model.layers.{layer_number}.mlp.down_proj.weight")
81
+ assert len(params) == 1
82
+ else:
83
+ raise NotImplementedError(f"Unsupported parameter name: {name}")
84
+ return convert_names, params
85
+
86
+ def convert_param(self, name: str, params_one_group: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:
87
+ direct_name_mapping = {
88
+ "embedding.word_embeddings.weight": "model.embed_tokens.weight",
89
+ "decoder.final_layernorm.weight": "model.norm.weight",
90
+ "output_layer.weight": "lm_head.weight",
91
+ }
92
+ if name in direct_name_mapping:
93
+ return [direct_name_mapping[name]], [params_one_group[0]]
94
+
95
+ if "self_attention" in name:
96
+ return self._convert_attention_param(name, params_one_group)
97
+ elif "mlp" in name:
98
+ return self._convert_mlp_param(name, params_one_group)
99
+ else:
100
+ raise NotImplementedError(f"Unsupported parameter name: {name}")
101
+
102
+
103
+ class McoreToHFWeightConverterQwen2Moe(McoreToHFWeightConverterDense):
104
+ def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:
105
+ # 'decoder.layers.0.pre_mlp_layernorm.weight',
106
+ # 'decoder.layers.0.mlp.router.weight',
107
+ # 'decoder.layers.0.mlp.shared_experts.gate_weight',
108
+ # 'decoder.layers.0.mlp.shared_experts.linear_fc1.weight',
109
+ # 'decoder.layers.0.mlp.shared_experts.linear_fc2.weight'
110
+ # moe1
111
+ # 'decoder.layers.0.mlp.experts.linear_fc1.weight0',
112
+ # 'decoder.layers.0.mlp.experts.linear_fc1.weight1',
113
+ # 'decoder.layers.0.mlp.experts.linear_fc1.weight2',
114
+ # 'decoder.layers.0.mlp.experts.linear_fc1.weight3',
115
+ # moe2
116
+ # 'decoder.layers.0.mlp.experts.linear_fc2.weight0',
117
+ # 'decoder.layers.0.mlp.experts.linear_fc2.weight1',
118
+ layer_number = name.split(".")[2]
119
+ convert_names = []
120
+ if "pre_mlp_layernorm" in name:
121
+ convert_names.append(f"model.layers.{layer_number}.post_attention_layernorm.weight")
122
+ assert len(params) == 1
123
+ elif "mlp.router.weight" in name:
124
+ convert_names.append(f"model.layers.{layer_number}.mlp.gate.weight")
125
+ assert len(params) == 1
126
+ elif "shared_experts.gate_weight" in name:
127
+ convert_names.append(f"model.layers.{layer_number}.mlp.shared_expert_gate.weight")
128
+ assert len(params) == 1
129
+ elif "shared_experts.linear_fc1.weight" in name: # split gate_proj and up_proj
130
+ convert_names.append(f"model.layers.{layer_number}.mlp.shared_expert.gate_proj.weight")
131
+ convert_names.append(f"model.layers.{layer_number}.mlp.shared_expert.up_proj.weight")
132
+ assert len(params) == 2
133
+ elif "shared_experts.linear_fc2.weight" in name:
134
+ convert_names.append(f"model.layers.{layer_number}.mlp.shared_expert.down_proj.weight")
135
+ assert len(params) == 1
136
+ elif "mlp.experts.linear_fc1" in name: # split gate_proj and up_proj
137
+ expert_id = name.split("weight")[-1]
138
+ convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.gate_proj.weight")
139
+ convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.up_proj.weight")
140
+ assert len(params) == 2
141
+ elif "mlp.experts.linear_fc2" in name:
142
+ expert_id = name.split("weight")[-1]
143
+ convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.down_proj.weight")
144
+ assert len(params) == 1
145
+ else:
146
+ raise NotImplementedError(f"Unsupported parameter name: {name}")
147
+ return convert_names, params
148
+
149
+
150
+ class McoreToHFWeightConverterMixtral(McoreToHFWeightConverterDense):
151
+ def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:
152
+ # decoder.layers.0.mlp.router.weight
153
+ # decoder.layers.0.mlp.experts.linear_fc1.weight0 - weight7
154
+ # decoder.layers.0.mlp.experts.linear_fc2.weight0 - weight7
155
+
156
+ layer_number = name.split(".")[2]
157
+ convert_names = []
158
+ if "pre_mlp_layernorm" in name:
159
+ convert_names.append(f"model.layers.{layer_number}.post_attention_layernorm.weight")
160
+ elif "mlp.router.weight" in name:
161
+ convert_names.append(f"model.layers.{layer_number}.block_sparse_moe.gate.weight")
162
+ elif "mlp.experts.linear_fc1.weight" in name:
163
+ expert_id = name.split("weight")[-1]
164
+ convert_names.append(f"model.layers.{layer_number}.block_sparse_moe.experts.{expert_id}.w1.weight")
165
+ convert_names.append(f"model.layers.{layer_number}.block_sparse_moe.experts.{expert_id}.w3.weight")
166
+ elif "mlp.experts.linear_fc2.weight" in name:
167
+ expert_id = name.split("weight")[-1]
168
+ convert_names.append(f"model.layers.{layer_number}.block_sparse_moe.experts.{expert_id}.w2.weight")
169
+ else:
170
+ raise NotImplementedError(f"Unsupported parameter name: {name}")
171
+ return convert_names, params
172
+
173
+
174
+ class McoreToHFWeightConverterQwen3Moe(McoreToHFWeightConverterDense):
175
+ def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:
176
+ # qwen3 moe has no shared expert
177
+
178
+ # 'decoder.layers.0.pre_mlp_layernorm.weight',
179
+ # 'decoder.layers.0.mlp.router.weight',
180
+ # moe1
181
+ # 'decoder.layers.0.mlp.experts.linear_fc1.weight0',
182
+ # 'decoder.layers.0.mlp.experts.linear_fc1.weight1',
183
+ # 'decoder.layers.0.mlp.experts.linear_fc1.weight2',
184
+ # 'decoder.layers.0.mlp.experts.linear_fc1.weight3',
185
+ # moe2
186
+ # 'decoder.layers.0.mlp.experts.linear_fc2.weight0',
187
+ # 'decoder.layers.0.mlp.experts.linear_fc2.weight1',
188
+ layer_number = name.split(".")[2]
189
+ convert_names = []
190
+ if "pre_mlp_layernorm" in name:
191
+ convert_names.append(f"model.layers.{layer_number}.post_attention_layernorm.weight")
192
+ assert len(params) == 1
193
+ elif "mlp.router.weight" in name:
194
+ convert_names.append(f"model.layers.{layer_number}.mlp.gate.weight")
195
+ assert len(params) == 1
196
+ elif "mlp.experts.linear_fc1" in name: # split gate_proj and up_proj
197
+ expert_id = name.split("weight")[-1]
198
+ convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.gate_proj.weight")
199
+ convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.up_proj.weight")
200
+ assert len(params) == 2
201
+ elif "mlp.experts.linear_fc2" in name:
202
+ expert_id = name.split("weight")[-1]
203
+ convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.down_proj.weight")
204
+ assert len(params) == 1
205
+ else:
206
+ raise NotImplementedError(f"Unsupported parameter name: {name}")
207
+ return convert_names, params
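Note: a minimal usage sketch of the dense converter defined above. It assumes the package is importable as `verl.models.mcore.weight_converter`; the configs and tensor shapes are placeholders, and the caller is expected to have already split the fused `linear_qkv` weight into separate q/k/v tensors:

```python
import torch
from transformers import PretrainedConfig

from verl.models.mcore.weight_converter import McoreToHFWeightConverterDense

hf_cfg = PretrainedConfig()   # placeholder HF config
mcore_cfg = None              # placeholder; a Megatron TransformerConfig in real use
converter = McoreToHFWeightConverterDense(hf_cfg, mcore_cfg)

# q/k/v already split out of the fused linear_qkv weight (toy shapes)
q, k, v = torch.zeros(8, 8), torch.zeros(2, 8), torch.zeros(2, 8)
names, params = converter.convert_param(
    "decoder.layers.0.self_attention.linear_qkv.weight", [q, k, v]
)
print(names)
# ['model.layers.0.self_attn.q_proj.weight',
#  'model.layers.0.self_attn.k_proj.weight',
#  'model.layers.0.self_attn.v_proj.weight']
```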
verl/models/qwen2/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
verl/models/qwen2/megatron/__init__.py ADDED
@@ -0,0 +1,34 @@
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .modeling_qwen2_megatron import (
16
+ ParallelQwen2ForCausalLM,
17
+ # rmpad with megatron
18
+ ParallelQwen2ForCausalLMRmPad,
19
+ # rmpad with megatron and pipeline parallelism
20
+ ParallelQwen2ForCausalLMRmPadPP,
21
+ ParallelQwen2ForValueRmPad,
22
+ ParallelQwen2ForValueRmPadPP,
23
+ # original model with megatron
24
+ ParallelQwen2Model,
25
+ )
26
+
27
+ __all__ = [
28
+ "ParallelQwen2ForCausalLM",
29
+ "ParallelQwen2ForCausalLMRmPad",
30
+ "ParallelQwen2ForCausalLMRmPadPP",
31
+ "ParallelQwen2ForValueRmPad",
32
+ "ParallelQwen2ForValueRmPadPP",
33
+ "ParallelQwen2Model",
34
+ ]
verl/models/qwen2/megatron/checkpoint_utils/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader.py ADDED
@@ -0,0 +1,312 @@
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import time
16
+
17
+ import torch
18
+ import torch.distributed as dist
19
+
20
+
21
+ def _megatron_calc_layer_map(config):
22
+ """Calculate the mapping of global layer_idx to local layer_idx
23
+ Returns:
24
+ layer_map (Dict: int -> tuple(int, int, int)):
25
+ mapping from the global layer index to
26
+ a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)
27
+ """
28
+ from megatron.core import mpu
29
+
30
+ pp_size = mpu.get_pipeline_model_parallel_world_size()
31
+ virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
32
+
33
+ layer_map = dict()
34
+ num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
35
+ assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers
36
+
37
+ for pp_rank_idx in range(pp_size):
38
+ for virtual_pp_rank_idx in range(virtual_pp_size):
39
+ layer_offset = virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model
40
+ for layer_idx in range(num_layers_per_model):
41
+ layer_map[layer_offset + layer_idx] = (
42
+ pp_rank_idx,
43
+ virtual_pp_rank_idx,
44
+ layer_idx,
45
+ )
46
+ return layer_map
47
+
48
+
49
+ def load_state_dict_to_megatron_qwen2(state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False):
50
+ """Load merged state_dict to sharded Megatron module in training."""
51
+ from megatron.core import DistributedDataParallel as LocalDDP
52
+ from megatron.core import mpu
53
+ from megatron.core.transformer.module import Float16Module
54
+ from torch.nn.parallel import DistributedDataParallel as torchDDP
55
+
56
+ from verl.utils.megatron_utils import print_rank_0, unwrap_model
57
+
58
+ start_time = time.time()
59
+
60
+ def _get_gpt_model(model):
61
+ return model
62
+
63
+ def fetch_params(module):
64
+ for param in module.parameters():
65
+ torch.distributed.broadcast(param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group())
66
+
67
+ dp_rank = mpu.get_data_parallel_rank()
68
+ pp_rank = mpu.get_pipeline_model_parallel_rank()
69
+ pp_size = mpu.get_pipeline_model_parallel_world_size()
70
+ virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
71
+ mp_group = mpu.get_model_parallel_group()
72
+
73
+ if torch.distributed.get_rank() == 0:
74
+ assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank()}] != 0 on rank #0"
75
+ assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0"
76
+ assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0"
77
+
78
+ if not isinstance(wrapped_models, (list, tuple)):
79
+ wrapped_models = list(wrapped_models)
80
+
81
+ assert len(wrapped_models) == virtual_pp_size
82
+ num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
83
+ assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, f"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size: {virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}"
84
+
85
+ models = [None] * len(wrapped_models)
86
+
87
+ for i, wrapped_model in enumerate(wrapped_models):
88
+ models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))
89
+ gpt_model_module = _get_gpt_model(models[i])
90
+ assert len(gpt_model_module.model.layers) == num_layers_per_model
91
+
92
+ def _fetch_tensor(tensor, name) -> torch.Tensor:
93
+ """fetch tensor"""
94
+ nonlocal state_dict
95
+ if tensor is not None:
96
+ tensor = tensor.data.copy_(state_dict[name], non_blocking=True)
97
+
98
+ def _fetch_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:
99
+ """fetch tensor in tp shards"""
100
+ nonlocal state_dict
101
+ tp_rank = mpu.get_tensor_model_parallel_rank()
102
+ tp_size = mpu.get_tensor_model_parallel_world_size()
103
+ if name in state_dict:
104
+ full_weight = state_dict[name]
105
+
106
+ if mutate_func is not None:
107
+ full_weight = mutate_func(full_weight)
108
+ tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)
109
+ if tensor is not None:
110
+ tensor = tensor.data.copy_(tensor_chunk[tp_rank], non_blocking=True)
111
+ else:
112
+ print(f"tp_shard tensor:[{name}] not in state_dict, skip loading")
113
+
114
+ def _fetch_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:
115
+ """fetch tensor in tp shards"""
116
+ nonlocal state_dict
117
+ tp_rank = mpu.get_tensor_model_parallel_rank()
118
+ tp_size = mpu.get_tensor_model_parallel_world_size()
119
+ if name in state_dict:
120
+ full_weight = state_dict[name]
121
+
122
+ if mutate_func is not None:
123
+ full_weight = mutate_func(full_weight)
124
+ tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)
125
+ if tensor is not None:
126
+ tensor = tensor.data.copy_(tensor_chunk[tp_rank], non_blocking=True)
127
+ else:
128
+ print(f"tp_shard tensor:[{name}] not in state_dict, skip loading")
129
+
130
+ def _fetch_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor:
131
+ """fetch gate_up tensor in tp shards"""
132
+ nonlocal state_dict
133
+ nonlocal mp_group
134
+ tp_rank = mpu.get_tensor_model_parallel_rank()
135
+ tp_size = mpu.get_tensor_model_parallel_world_size()
136
+ if gate_name in state_dict and up_name in state_dict:
137
+ gate_weight = state_dict[gate_name]
138
+ up_weight = state_dict[up_name]
139
+ new_gate_up_weight = torch.empty(config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device())
140
+ for i in range(tp_size):
141
+ intermediate_size_tp = config.intermediate_size // tp_size
142
+ gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]
143
+ up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]
144
+ new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_(torch.cat([gate_weight_tp, up_weight_tp], dim=0))
145
+
146
+ tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0)
147
+ if tensor is not None:
148
+ tensor = tensor.data.copy_(tensor_chunk[tp_rank], non_blocking=True)
149
+ else:
150
+ print(f"tp_shard tensor:[{gate_name}, {up_name}] not in state_dict, skip loading")
151
+
152
+ def _fetch_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> torch.Tensor:
153
+ """fetch tensor in tp shards across mp_group"""
154
+ nonlocal state_dict
155
+ nonlocal mp_group
156
+ tp_rank = mpu.get_tensor_model_parallel_rank()
157
+ tp_size = mpu.get_tensor_model_parallel_world_size()
158
+ assert q_name in state_dict and k_name in state_dict and v_name in state_dict
159
+ full_weight_q = state_dict[q_name]
160
+ full_weight_k = state_dict[k_name]
161
+ full_weight_v = state_dict[v_name]
162
+
163
+ hidden_size_per_head = config.hidden_size // config.num_attention_heads
164
+
165
+ if config.num_key_value_heads >= tp_size:
166
+ q_size_tp = config.hidden_size // tp_size
167
+ kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size
168
+ total_size = q_size_tp + 2 * kv_size_tp
169
+ if not bias:
170
+ new_weight_qkv = torch.empty(total_size * tp_size, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device())
171
+ else:
172
+ new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=torch.cuda.current_device())
173
+ for i in range(tp_size):
174
+ q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]
175
+ k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp]
176
+ v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp]
177
+ new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0))
178
+
179
+ else:
180
+ q_size_tp = config.hidden_size // tp_size
181
+ kv_size_tp = hidden_size_per_head
182
+ total_size = q_size_tp + 2 * kv_size_tp
183
+ if not bias:
184
+ new_weight_qkv = torch.empty(total_size * tp_size, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device())
185
+ else:
186
+ new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=torch.cuda.current_device())
187
+ for i in range(tp_size):
188
+ q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]
189
+ start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head
190
+ end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head
191
+ k_part = full_weight_k[start_idx:end_idx]
192
+ v_part = full_weight_v[start_idx:end_idx]
193
+ new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0))
194
+
195
+ tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0)
196
+ if tensor is not None:
197
+ tensor = tensor.data.copy_(tensor_chunk[tp_rank], non_blocking=True)
198
+
199
+ # Embeddings
200
+ # -------------------
201
+ print_rank_0("loading embeddings...")
202
+ gpt_model_module = _get_gpt_model(models[0])
203
+ if pp_rank == 0:
204
+ embed_tokens_weight = gpt_model_module.model.embed_tokens.weight
205
+ _fetch_tp_shard_tensor_vocab(embed_tokens_weight, "model.embed_tokens.weight")
206
+
207
+ # Transformer layers
208
+ # -------------------
209
+ layer_map = _megatron_calc_layer_map(config)
210
+
211
+ pp_rank = mpu.get_pipeline_model_parallel_rank()
212
+ pp_size = mpu.get_pipeline_model_parallel_world_size()
213
+ num_layer_per_pp = config.num_hidden_layers // pp_size
214
+ vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size()
215
+
216
+ layer_list = []
217
+ if vpp_size is not None:
218
+ for vpp_rank in range(vpp_size):
219
+ num_layer_vpp_chunk = num_layer_per_pp // vpp_size
220
+ num_layer_this_model = num_layer_vpp_chunk
221
+ offset = vpp_rank * (config.num_hidden_layers // mpu.get_virtual_pipeline_model_parallel_world_size()) + (mpu.get_pipeline_model_parallel_rank() * num_layer_vpp_chunk)
222
+ layer_list.extend(list(range(offset, offset + num_layer_this_model)))
223
+ else:
224
+ num_layer_this_model = num_layer_per_pp
225
+ offset = pp_rank * num_layer_per_pp
226
+ layer_list.extend(list(range(offset, offset + num_layer_this_model)))
227
+
228
+ for layer in layer_list:
229
+ print(f"{torch.distributed.get_rank()} loading layer #{layer}...")
230
+ layer_name = f"model.layers.{layer}"
231
+ dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer]
232
+
233
+ print(f"{torch.distributed.get_rank()} offset: {offset}, num_layer_this_model: {num_layer_this_model}, layer_name: {layer_name}, layer_map[layer]: {layer_map[layer]}")
234
+
235
+ gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank])
236
+ sync_layer = gpt_model_module.model.layers[dst_layer_idx]
237
+
238
+ _fetch_tensor(
239
+ sync_layer.input_layernorm.weight if dst_pp_rank == pp_rank else None,
240
+ f"{layer_name}.input_layernorm.weight",
241
+ )
242
+
243
+ _fetch_tp_shard_tensor_qkv(
244
+ sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else None,
245
+ f"{layer_name}.self_attn.q_proj.weight",
246
+ f"{layer_name}.self_attn.k_proj.weight",
247
+ f"{layer_name}.self_attn.v_proj.weight",
248
+ )
249
+
250
+ _fetch_tp_shard_tensor_qkv(
251
+ sync_layer.self_attn.qkv_proj.bias if dst_pp_rank == pp_rank else None,
252
+ f"{layer_name}.self_attn.q_proj.bias",
253
+ f"{layer_name}.self_attn.k_proj.bias",
254
+ f"{layer_name}.self_attn.v_proj.bias",
255
+ bias=True,
256
+ )
257
+
258
+ _fetch_tp_shard_tensor(
259
+ sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None,
260
+ f"{layer_name}.self_attn.o_proj.weight",
261
+ chunk_dim=1,
262
+ )
263
+
264
+ _fetch_tensor(
265
+ sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None,
266
+ f"{layer_name}.post_attention_layernorm.weight",
267
+ )
268
+
269
+ _fetch_tp_shard_tensor_gate_up(
270
+ sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None,
271
+ f"{layer_name}.mlp.gate_proj.weight",
272
+ f"{layer_name}.mlp.up_proj.weight",
273
+ )
274
+
275
+ _fetch_tp_shard_tensor(
276
+ sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None,
277
+ f"{layer_name}.mlp.down_proj.weight",
278
+ chunk_dim=1,
279
+ )
280
+ # Final Layernorm
281
+ # -------------------
282
+ print_rank_0("loading final layernorm...")
283
+ gpt_model_module = _get_gpt_model(models[-1])
284
+ _fetch_tensor(
285
+ getattr(gpt_model_module.model.norm, "weight", None),
286
+ "model.norm.weight",
287
+ )
288
+
289
+ if tie_word_embeddings:
290
+ print_rank_0("tie_word_embeddings skip load lm_head")
291
+ else:
292
+ print_rank_0("loading lm_head...")
293
+ if pp_rank + 1 == pp_size:
294
+ lm_head_weight = gpt_model_module.lm_head.weight
295
+
296
+ if is_value_model:
297
+ if "lm_head.weight" in state_dict and state_dict["lm_head.weight"].shape[0] == 1:
298
+ _fetch_tensor(lm_head_weight, "lm_head.weight")
299
+ print_rank_0("load lm_head from value_head weight")
300
+ elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1:
301
+ _fetch_tensor(lm_head_weight, "reward_head.weight")
302
+ print_rank_0("load lm_head from value_head weight")
303
+ else:
304
+ _fetch_tensor(None, "lm_head.weight")
305
+ print_rank_0("fail to match lm_head in value_model")
306
+
307
+ else:
308
+ _fetch_tp_shard_tensor(lm_head_weight, "lm_head.weight")
309
+
310
+ dist.barrier()
311
+ torch.cuda.empty_cache()
312
+ print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s")
verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader_depracated.py ADDED
@@ -0,0 +1,442 @@
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import time
16
+
17
+ import torch
18
+ import torch.distributed as dist
19
+
20
+
21
+ def _megatron_calc_layer_map(config):
22
+ """Calculate the mapping of global layer_idx to local layer_idx
23
+ Returns:
24
+ layer_map (Dict: int -> tuple(int, int, int)):
25
+ mapping from the global layer index to
26
+ a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)
27
+ """
28
+ from megatron.core import mpu
29
+
30
+ pp_size = mpu.get_pipeline_model_parallel_world_size()
31
+ virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
32
+
33
+ layer_map = dict()
34
+ num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
35
+ assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers
36
+
37
+ for pp_rank_idx in range(pp_size):
38
+ for virtual_pp_rank_idx in range(virtual_pp_size):
39
+ layer_offset = virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model
40
+ for layer_idx in range(num_layers_per_model):
41
+ layer_map[layer_offset + layer_idx] = (
42
+ pp_rank_idx,
43
+ virtual_pp_rank_idx,
44
+ layer_idx,
45
+ )
46
+ return layer_map
47
+
48
+
49
+ def load_state_dict_to_megatron_qwen2(state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False):
50
+ """Load merged state_dict to sharded Megatron module in training."""
51
+ from megatron.core import DistributedDataParallel as LocalDDP
52
+ from megatron.core import mpu
53
+ from megatron.core.transformer.module import Float16Module
54
+ from torch.nn.parallel import DistributedDataParallel as torchDDP
55
+
56
+ from verl.utils.megatron_utils import print_rank_0, unwrap_model
57
+
58
+ start_time = time.time()
59
+
60
+ def _get_gpt_model(model):
61
+ return model
62
+
63
+ def broadcast_params(module):
64
+ for param in module.parameters():
65
+ torch.distributed.broadcast(param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group())
66
+
67
+ dp_rank = mpu.get_data_parallel_rank()
68
+ pp_rank = mpu.get_pipeline_model_parallel_rank()
69
+ pp_size = mpu.get_pipeline_model_parallel_world_size()
70
+ virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
71
+ mp_group = mpu.get_model_parallel_group()
72
+
73
+ if torch.distributed.get_rank() == 0:
74
+ assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank()}] != 0 on rank #0"
75
+ assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0"
76
+ assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0"
77
+
78
+ if not isinstance(wrapped_models, (list, tuple)):
79
+ wrapped_models = list(wrapped_models)
80
+
81
+ assert len(wrapped_models) == virtual_pp_size
82
+ num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
83
+ assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, f"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size: {virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}"
84
+
85
+ models = [None] * len(wrapped_models)
86
+
87
+ for i, wrapped_model in enumerate(wrapped_models):
88
+ models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))
89
+ gpt_model_module = _get_gpt_model(models[i])
90
+ assert len(gpt_model_module.model.layers) == num_layers_per_model
91
+
92
+ def _broadcast_tensor(tensor, name) -> torch.Tensor:
93
+ """broadcast tensor from rank0 across mp_group"""
94
+ nonlocal state_dict
95
+ nonlocal mp_group
96
+ if torch.distributed.get_rank() == 0:
97
+ if name in state_dict:
98
+ weight = state_dict[name]
99
+ tensor_shape = weight.shape
100
+ else:
101
+ tensor_shape = None
102
+ else:
103
+ weight = None
104
+ tensor_shape = None
105
+
106
+ obj_list = [tensor_shape]
107
+ dist.broadcast_object_list(obj_list, src=0, group=mp_group)
108
+ tensor_shape = obj_list[0]
109
+
110
+ if tensor_shape is None:
111
+ # all or none ranks in the mp_group should reach here
112
+ print_rank_0(f"tensor:[{name}] not in state_dict, skip load")
113
+ return
114
+
115
+ if tensor is None:
116
+ tensor = torch.empty(
117
+ tensor_shape,
118
+ dtype=params_dtype,
119
+ device=torch.cuda.current_device(),
120
+ requires_grad=False,
121
+ )
122
+ if torch.distributed.get_rank() == 0:
123
+ tensor.data.copy_(weight)
124
+ dist.broadcast(tensor, src=0, group=mp_group)
125
+
126
+ def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:
127
+ """broadcast tensor in tp shards across mp_group"""
128
+ nonlocal state_dict
129
+ nonlocal mp_group
130
+ tp_rank = mpu.get_tensor_model_parallel_rank()
131
+ tp_size = mpu.get_tensor_model_parallel_world_size()
132
+
133
+ if torch.distributed.get_rank() == 0:
134
+ if name in state_dict:
135
+ full_weight = state_dict[name]
136
+
137
+ if mutate_func is not None:
138
+ full_weight = mutate_func(full_weight)
139
+ tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)
140
+ chunk_shape = tensor_chunk[0].shape
141
+ else:
142
+ chunk_shape = None
143
+ else:
144
+ chunk_shape = None
145
+
146
+ obj_list = [chunk_shape]
147
+ dist.broadcast_object_list(obj_list, src=0, group=mp_group)
148
+ chunk_shape = obj_list[0]
149
+ if chunk_shape is None:
150
+ # all or none ranks in the mp_group should reach here
151
+ print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading")
152
+ return
153
+
154
+ if tensor is None:
155
+ sync_tensor = torch.empty(
156
+ chunk_shape,
157
+ dtype=params_dtype,
158
+ device=torch.cuda.current_device(),
159
+ requires_grad=False,
160
+ )
161
+ else:
162
+ assert tensor.shape == chunk_shape, f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}"
163
+ sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)
164
+
165
+ for i in range(tp_size):
166
+ if torch.distributed.get_rank() == 0:
167
+ sync_tensor.data.copy_(tensor_chunk[i])
168
+ dist.broadcast(sync_tensor, src=0, group=mp_group)
169
+ if (i == tp_rank) and (tensor is not None):
170
+ tensor.data.copy_(sync_tensor)
171
+
172
+ def _broadcast_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:
173
+ """broadcast tensor in tp shards across mp_group"""
174
+ nonlocal state_dict
175
+ nonlocal mp_group
176
+ tp_rank = mpu.get_tensor_model_parallel_rank()
177
+ tp_size = mpu.get_tensor_model_parallel_world_size()
178
+
179
+ if torch.distributed.get_rank() == 0:
180
+ if name in state_dict:
181
+ full_weight = state_dict[name]
182
+ if mutate_func is not None:
183
+ full_weight = mutate_func(full_weight)
184
+ tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)
185
+ chunk_shape = tensor_chunk[0].shape
186
+ else:
187
+ chunk_shape = None
188
+ else:
189
+ chunk_shape = None
190
+
191
+ obj_list = [chunk_shape]
192
+ dist.broadcast_object_list(obj_list, src=0, group=mp_group)
193
+ chunk_shape = obj_list[0]
194
+ if chunk_shape is None:
195
+ # all or none ranks in the mp_group should reach here
196
+ print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading")
197
+ return
198
+
199
+ if tensor is None:
200
+ sync_tensor = torch.empty(
201
+ chunk_shape,
202
+ dtype=params_dtype,
203
+ device=torch.cuda.current_device(),
204
+ requires_grad=False,
205
+ )
206
+ else:
207
+ assert tensor.shape == chunk_shape, f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}"
208
+ sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)
209
+
210
+ for i in range(tp_size):
211
+ if torch.distributed.get_rank() == 0:
212
+ sync_tensor.data.copy_(tensor_chunk[i])
213
+ dist.broadcast(sync_tensor, src=0, group=mp_group)
214
+ if (i == tp_rank) and (tensor is not None):
215
+ tensor.data.copy_(sync_tensor)
216
+
217
+ def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor:
218
+ """broadcast tensor in tp shards across mp_group"""
219
+ nonlocal state_dict
220
+ nonlocal mp_group
221
+ tp_rank = mpu.get_tensor_model_parallel_rank()
222
+ tp_size = mpu.get_tensor_model_parallel_world_size()
223
+
224
+ if torch.distributed.get_rank() == 0:
225
+ gate_weight = state_dict[gate_name]
226
+ up_weight = state_dict[up_name]
227
+ new_gate_up_weight = torch.empty(config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device())
228
+ for i in range(tp_size):
229
+ intermediate_size_tp = config.intermediate_size // tp_size
230
+ gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]
231
+ up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]
232
+ new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_(torch.cat([gate_weight_tp, up_weight_tp], dim=0))
233
+
234
+ tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0)
235
+ chunk_shape = tensor_chunk[0].shape
236
+ else:
237
+ chunk_shape = None
238
+
239
+ obj_list = [chunk_shape]
240
+ dist.broadcast_object_list(obj_list, src=0, group=mp_group)
241
+ chunk_shape = obj_list[0]
242
+ if chunk_shape is None:
243
+ # all or none ranks in the mp_group should reach here
244
+ print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not in state_dict, skip loading")
245
+ return
246
+
247
+ if tensor is None:
248
+ sync_tensor = torch.empty(
249
+ chunk_shape,
250
+ dtype=params_dtype,
251
+ device=torch.cuda.current_device(),
252
+ requires_grad=False,
253
+ )
254
+ else:
255
+ assert tensor.shape == chunk_shape, f"rank #{torch.distributed.get_rank() == 0:} tensor {gate_name, up_name} shape {tensor.shape} != {chunk_shape}"
256
+ sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)
257
+
258
+ for i in range(tp_size):
259
+ if torch.distributed.get_rank() == 0:
260
+ sync_tensor.data.copy_(tensor_chunk[i])
261
+ dist.broadcast(sync_tensor, src=0, group=mp_group)
262
+ if (i == tp_rank) and (tensor is not None):
263
+ tensor.data.copy_(sync_tensor)
264
+
265
+ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> torch.Tensor:
266
+ """broadcast tensor in tp shards across mp_group"""
267
+ nonlocal state_dict
268
+ nonlocal mp_group
269
+ tp_rank = mpu.get_tensor_model_parallel_rank()
270
+ tp_size = mpu.get_tensor_model_parallel_world_size()
271
+
272
+ if torch.distributed.get_rank() == 0:
273
+ assert q_name in state_dict and k_name in state_dict and v_name in state_dict
274
+ full_weight_q = state_dict[q_name]
275
+ full_weight_k = state_dict[k_name]
276
+ full_weight_v = state_dict[v_name]
277
+
278
+ hidden_size_per_head = config.hidden_size // config.num_attention_heads
279
+
280
+ if config.num_key_value_heads >= tp_size:
281
+ q_size_tp = config.hidden_size // tp_size
282
+ kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size
283
+ total_size = q_size_tp + 2 * kv_size_tp
284
+ if not bias:
285
+ new_weight_qkv = torch.empty(total_size * tp_size, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device())
286
+ else:
287
+ new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=torch.cuda.current_device())
288
+ for i in range(tp_size):
289
+ q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]
290
+ k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp]
291
+ v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp]
292
+ new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0))
293
+
294
+ else:
295
+ q_size_tp = config.hidden_size // tp_size
296
+ kv_size_tp = hidden_size_per_head
297
+ total_size = q_size_tp + 2 * kv_size_tp
298
+ if not bias:
299
+ new_weight_qkv = torch.empty(total_size * tp_size, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device())
300
+ else:
301
+ new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=torch.cuda.current_device())
302
+ for i in range(tp_size):
303
+ q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]
304
+ start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head
305
+ end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head
306
+ k_part = full_weight_k[start_idx:end_idx]
307
+ v_part = full_weight_v[start_idx:end_idx]
308
+ new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0))
309
+
310
+ tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0)
311
+ chunk_shape = tensor_chunk[0].shape
312
+ else:
313
+ chunk_shape = None
314
+
315
+ obj_list = [chunk_shape]
316
+ dist.broadcast_object_list(obj_list, src=0, group=mp_group)
317
+ chunk_shape = obj_list[0]
318
+ if chunk_shape is None:
319
+ # all or none ranks in the mp_group should reach here
320
+ print_rank_0(f"tp_shard tensor:[{q_name, k_name, v_name}] not in state_dict, skip loading")
321
+ return
322
+
323
+ if tensor is None:
324
+ sync_tensor = torch.empty(
325
+ chunk_shape,
326
+ dtype=params_dtype,
327
+ device=torch.cuda.current_device(),
328
+ requires_grad=False,
329
+ )
330
+ else:
331
+ assert tensor.shape == chunk_shape, f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}"
332
+ sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)
333
+
334
+ for i in range(tp_size):
335
+ if torch.distributed.get_rank() == 0:
336
+ sync_tensor.data.copy_(tensor_chunk[i])
337
+ dist.broadcast(sync_tensor, src=0, group=mp_group)
338
+ if (i == tp_rank) and (tensor is not None):
339
+ tensor.data.copy_(sync_tensor)
340
+
341
+ if dp_rank == 0:
342
+ # Embeddings
343
+ # -------------------
344
+ print_rank_0("loading embeddings...")
345
+ gpt_model_module = _get_gpt_model(models[0])
346
+ embed_tokens_weight = None
347
+ if pp_rank == 0:
348
+ embed_tokens_weight = gpt_model_module.model.embed_tokens.weight
349
+ _broadcast_tp_shard_tensor_vocab(embed_tokens_weight, "model.embed_tokens.weight")
350
+
351
+ # Transformer layers
352
+ # -------------------
353
+ layer_map = _megatron_calc_layer_map(config)
354
+
355
+ for layer in range(config.num_hidden_layers):
356
+ print_rank_0(f"loading layer #{layer}...")
357
+ layer_name = f"model.layers.{layer}"
358
+ dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer]
359
+
360
+ gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank])
361
+ sync_layer = gpt_model_module.model.layers[dst_layer_idx]
362
+
363
+ _broadcast_tensor(
364
+ sync_layer.input_layernorm.weight if dst_pp_rank == pp_rank else None,
365
+ f"{layer_name}.input_layernorm.weight",
366
+ )
367
+
368
+ _broadcast_tp_shard_tensor_qkv(
369
+ sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else None,
370
+ f"{layer_name}.self_attn.q_proj.weight",
371
+ f"{layer_name}.self_attn.k_proj.weight",
372
+ f"{layer_name}.self_attn.v_proj.weight",
373
+ )
374
+
375
+ _broadcast_tp_shard_tensor_qkv(
376
+ sync_layer.self_attn.qkv_proj.bias if dst_pp_rank == pp_rank else None,
377
+ f"{layer_name}.self_attn.q_proj.bias",
378
+ f"{layer_name}.self_attn.k_proj.bias",
379
+ f"{layer_name}.self_attn.v_proj.bias",
380
+ bias=True,
381
+ )
382
+
383
+ _broadcast_tp_shard_tensor(
384
+ sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None,
385
+ f"{layer_name}.self_attn.o_proj.weight",
386
+ chunk_dim=1,
387
+ )
388
+
389
+ _broadcast_tensor(
390
+ sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None,
391
+ f"{layer_name}.post_attention_layernorm.weight",
392
+ )
393
+
394
+ _broadcast_tp_shard_tensor_gate_up(
395
+ sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None,
396
+ f"{layer_name}.mlp.gate_proj.weight",
397
+ f"{layer_name}.mlp.up_proj.weight",
398
+ )
399
+
400
+ _broadcast_tp_shard_tensor(
401
+ sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None,
402
+ f"{layer_name}.mlp.down_proj.weight",
403
+ chunk_dim=1,
404
+ )
405
+ # Final Layernorm
406
+ # -------------------
407
+ print_rank_0("loading final layernorm...")
408
+ gpt_model_module = _get_gpt_model(models[-1])
409
+ _broadcast_tensor(
410
+ getattr(gpt_model_module.model.norm, "weight", None),
411
+ "model.norm.weight",
412
+ )
413
+
414
+ if tie_word_embeddings:
415
+ print_rank_0("tie_word_embeddings skip load lm_head")
416
+ else:
417
+ print_rank_0("loading lm_head...")
418
+ lm_head_weight = None
419
+ if pp_rank + 1 == pp_size:
420
+ lm_head_weight = gpt_model_module.lm_head.weight
421
+
422
+ if is_value_model:
423
+ if "lm_head.weight" in state_dict and state_dict["lm_head.weight"].shape[0] == 1:
424
+ _broadcast_tensor(lm_head_weight, "lm_head.weight")
425
+ print_rank_0("load lm_head from value_head weight")
426
+ elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1:
427
+ _broadcast_tensor(lm_head_weight, "reward_head.weight")
428
+ print_rank_0("load lm_head from value_head weight")
429
+ else:
430
+ _broadcast_tensor(None, "lm_head.weight")
431
+ print_rank_0("fail to match lm_head in value_model")
432
+
433
+ else:
434
+ _broadcast_tp_shard_tensor(lm_head_weight, "lm_head.weight")
435
+
436
+ dist.barrier()
437
+ # Broadcast weights inside data parallel groups
438
+ for wrapped_model in wrapped_models:
439
+ broadcast_params(wrapped_model)
440
+
441
+ torch.cuda.empty_cache()
442
+ print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s")
verl/models/qwen2/megatron/checkpoint_utils/qwen2_saver.py ADDED
@@ -0,0 +1,436 @@
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import time
16
+
17
+ import torch
18
+ import torch.distributed as dist
19
+ from megatron.core import mpu
20
+ from megatron.core.distributed import DistributedDataParallel as LocalDDP
21
+ from megatron.core.transformer.module import Float16Module
22
+ from torch.nn.parallel import DistributedDataParallel as torchDDP
23
+
24
+ from verl.utils.megatron_utils import print_rank_0, unwrap_model
25
+
26
+
27
+ def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0):
28
+ """given TP,DP,PP rank to get the global rank."""
29
+
30
+ tp_size = mpu.get_tensor_model_parallel_world_size()
31
+ dp_size = mpu.get_data_parallel_world_size()
32
+ pp_size = mpu.get_pipeline_model_parallel_world_size()
33
+ assert tp_size * dp_size * pp_size == torch.distributed.get_world_size(), f"{tp_size} x {dp_size} x {pp_size} != {torch.distributed.get_world_size()}"
34
+ # We only support TP-DP-PP grouping, for correctness when resharding
35
+ return (pp_rank * dp_size + dp_rank) * tp_size + tp_rank
36
+
37
+
38
+ def _megatron_calc_layer_map(config):
39
+ """Calculate the mapping of global layer_idx to local layer_idx
40
+ Returns:
41
+ layer_map (Dict: int -> tuple(int, int, int)):
42
+ mapping from the global layer index to
43
+ a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)
44
+ """
45
+ from megatron.core import mpu
46
+
47
+ pp_size = mpu.get_pipeline_model_parallel_world_size()
48
+ virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
49
+
50
+ layer_map = dict()
51
+ num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
52
+ assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers
53
+
54
+ for pp_rank_idx in range(pp_size):
55
+ for virtual_pp_rank_idx in range(virtual_pp_size):
56
+ layer_offset = virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model
57
+ for layer_idx in range(num_layers_per_model):
58
+ layer_map[layer_offset + layer_idx] = (
59
+ pp_rank_idx,
60
+ virtual_pp_rank_idx,
61
+ layer_idx,
62
+ )
63
+ return layer_map
64
+
65
+
66
+ def merge_megatron_ckpt_qwen2(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False):
67
+ """Merge sharded parameters of a Megatron module into a merged checkpoint.
68
+
69
+ Args:
70
+ wrapped_models (list of megatron.core.distributed.DistributedDataParallel):
71
+ The local DDP wrapped megatron modules.
72
+ config (PretrainedConfig):
73
+ HF config for model
74
+ dtype: model params type
75
+ is_value_model: whether the model is a value model
76
+ tie_word_embeddings: tie_word_embeddings
77
+ Returns:
78
+ state_dict (dict):
79
+ The merged state_dict in rank 0, and an empty dictionary in other ranks.
80
+ """
81
+ start_time = time.time()
82
+
83
+ def _get_gpt_model(model):
84
+ return model
85
+
86
+ dp_rank = mpu.get_data_parallel_rank()
87
+ pp_size = mpu.get_pipeline_model_parallel_world_size()
88
+ pp_rank = mpu.get_pipeline_model_parallel_rank()
89
+ virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
90
+ mp_group = mpu.get_model_parallel_group()
91
+
92
+ if dist.get_rank() == 0:
93
+ assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank()}] != 0 on rank #0"
94
+ assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0"
95
+ assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0"
96
+
97
+ if not isinstance(wrapped_models, (list, tuple)):
98
+ wrapped_models = [wrapped_models]
99
+
100
+ assert len(wrapped_models) == virtual_pp_size
101
+ num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
102
+ assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers
103
+
104
+ models = [None] * len(wrapped_models)
105
+
106
+ for i, wrapped_model in enumerate(wrapped_models):
107
+ models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))
108
+ assert len(models[i].model.layers) == num_layers_per_model, "len model layers {} not equal to num_layers_per_model {}".format(len(models[i].model.layers), num_layers_per_model)
109
+
110
+ state_dict = dict()
111
+
112
+ def _get_cpu_tensor(tensor: torch.Tensor):
113
+ if tensor is None:
114
+ return None
115
+ if tensor.device == torch.device("cpu"):
116
+ return tensor.detach().clone()
117
+ return tensor.detach().cpu()
118
+
119
+ def _broadcast_tensor(tensor, name, src_pp_rank) -> None:
120
+ """broadcast tensor across mp_group"""
121
+ nonlocal state_dict
122
+ nonlocal mp_group
123
+ src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)
124
+
125
+ if torch.distributed.get_rank() == src_rank:
126
+ if tensor is None:
127
+ weight = None
128
+ tensor_shape = None
129
+ else:
130
+ weight = tensor
131
+ tensor_shape = weight.shape
132
+ else:
133
+ weight = None
134
+ tensor_shape = None
135
+
136
+ obj_list = [tensor_shape]
137
+ dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
138
+ tensor_shape = obj_list[0]
139
+
140
+ if tensor_shape is None:
141
+ # all or none ranks in the mp_group should reach here
142
+ print_rank_0(f"tensor:[{name}] not exist, skip collect")
143
+ return
144
+
145
+ if weight is None:
146
+ weight = torch.empty(
147
+ tensor_shape,
148
+ dtype=dtype,
149
+ device=torch.cuda.current_device(),
150
+ requires_grad=False,
151
+ )
152
+
153
+ dist.broadcast(weight, src=src_rank, group=mp_group)
154
+
155
+ if torch.distributed.get_rank() == 0:
156
+ state_dict[name] = _get_cpu_tensor(weight)
157
+
158
+ def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_func=None) -> None:
159
+ """broadcast tensor in tp shards across mp_group"""
160
+ nonlocal state_dict
161
+ nonlocal mp_group
162
+ tp_size = mpu.get_tensor_model_parallel_world_size()
163
+ src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)
164
+
165
+ chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None
166
+
167
+ obj_list = [chunk_shape]
168
+ dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
169
+ chunk_shape = obj_list[0]
170
+ if chunk_shape is None:
171
+ # all or none ranks in the mp_group should reach here
172
+ print_rank_0(f"tp_shard tensor:[{name}] not exist, skip collecting")
173
+ return
174
+
175
+ buffer_tensor = torch.empty(
176
+ chunk_shape,
177
+ dtype=dtype,
178
+ device=torch.cuda.current_device(),
179
+ requires_grad=False,
180
+ )
181
+
182
+ chunk_tensors = [None] * tp_size
183
+
184
+ for i in range(tp_size):
185
+ cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank)
186
+ sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor
187
+ dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)
188
+
189
+ if torch.distributed.get_rank() == 0:
190
+ chunk_tensors[i] = _get_cpu_tensor(sync_tensor)
191
+
192
+ if torch.distributed.get_rank() == 0:
193
+ full_tensor = torch.concat(chunk_tensors, dim=concat_dim)
194
+ if mutate_func is not None:
195
+ full_tensor = mutate_func(full_tensor)
196
+ state_dict[name] = full_tensor
197
+
198
+ def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) -> None:
199
+ """broadcast tensor in tp shards across mp_group"""
200
+ nonlocal state_dict
201
+ nonlocal mp_group
202
+ tp_size = mpu.get_tensor_model_parallel_world_size()
203
+ src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)
204
+
205
+ chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None
206
+
207
+ obj_list = [chunk_shape]
208
+ dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
209
+ chunk_shape = obj_list[0]
210
+ if chunk_shape is None:
211
+ # all or none ranks in the mp_group should reach here
212
+ print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not exist, skip collecting")
213
+ return
214
+
215
+ buffer_tensor = torch.empty(
216
+ chunk_shape,
217
+ dtype=dtype,
218
+ device=torch.cuda.current_device(),
219
+ requires_grad=False,
220
+ )
221
+
222
+ chunk_tensors = [None] * tp_size
223
+
224
+ for i in range(tp_size):
225
+ cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank)
226
+ sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor
227
+ dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)
228
+
229
+ if torch.distributed.get_rank() == 0:
230
+ chunk_tensors[i] = _get_cpu_tensor(sync_tensor)
231
+
232
+ if torch.distributed.get_rank() == 0:
233
+ full_tensor = torch.concat(chunk_tensors, dim=0)
234
+ intermediate_size_tp = config.intermediate_size // tp_size
235
+ gate_weight_list = []
236
+ up_weight_list = []
237
+ for i in range(tp_size):
238
+ gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)]
239
+ gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp]
240
+ up_weight_tp = gate_up_weight_tp[intermediate_size_tp:]
241
+ gate_weight_list.append(gate_weight_tp)
242
+ up_weight_list.append(up_weight_tp)
243
+
244
+ state_dict[gate_name] = torch.cat(gate_weight_list, dim=0)
245
+ state_dict[up_name] = torch.cat(up_weight_list, dim=0)
246
+
247
+ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank):
248
+ """broadcast tensor in tp shards across mp_group"""
249
+ nonlocal state_dict
250
+ nonlocal mp_group
251
+ tp_size = mpu.get_tensor_model_parallel_world_size()
252
+ src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)
253
+
254
+ chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None
255
+
256
+ obj_list = [chunk_shape]
257
+ dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
258
+ chunk_shape = obj_list[0]
259
+ if chunk_shape is None:
260
+ # all or none ranks in the mp_group should reach here
261
+ print_rank_0(f"tp_shard tensor:[{q_name}] not exist, skip collecting")
262
+ return
263
+
264
+ buffer_tensor = torch.empty(
265
+ chunk_shape,
266
+ dtype=dtype,
267
+ device=torch.cuda.current_device(),
268
+ requires_grad=False,
269
+ )
270
+
271
+ chunk_tensors = [None] * tp_size
272
+
273
+ for i in range(tp_size):
274
+ cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank)
275
+ sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor
276
+ dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)
277
+
278
+ if torch.distributed.get_rank() == 0:
279
+ chunk_tensors[i] = _get_cpu_tensor(sync_tensor)
280
+
281
+ if torch.distributed.get_rank() == 0:
282
+ full_tensor = torch.concat(chunk_tensors, dim=0)
283
+ q_weight_list = []
284
+ k_weight_list = []
285
+ v_weight_list = []
286
+ hidden_size_per_head = config.hidden_size // config.num_attention_heads
287
+
288
+ if config.num_key_value_heads >= tp_size:
289
+ q_size_tp = config.hidden_size // tp_size
290
+ kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size
291
+ total_size = q_size_tp + 2 * kv_size_tp
292
+ for i in range(tp_size):
293
+ qkv_part = full_tensor[i * total_size : (i + 1) * total_size]
294
+ q_part = qkv_part[:q_size_tp]
295
+ k_part = qkv_part[q_size_tp : q_size_tp + kv_size_tp]
296
+ v_part = qkv_part[q_size_tp + kv_size_tp : total_size]
297
+ q_weight_list.append(q_part)
298
+ k_weight_list.append(k_part)
299
+ v_weight_list.append(v_part)
300
+ else:
301
+ q_size_tp = config.hidden_size // tp_size
302
+ kv_size_tp = hidden_size_per_head
303
+ total_size = q_size_tp + 2 * kv_size_tp
304
+ for i in range(tp_size):
305
+ qkv_part = full_tensor[i * total_size : (i + 1) * total_size]
306
+ q_part = qkv_part[:q_size_tp]
307
+ k_part = qkv_part[q_size_tp : q_size_tp + kv_size_tp]
308
+ v_part = qkv_part[q_size_tp + kv_size_tp : total_size]
309
+ q_weight_list.append(q_part)
310
+ if i * config.num_key_value_heads % tp_size == 0:
311
+ k_weight_list.append(k_part)
312
+ v_weight_list.append(v_part)
313
+
314
+ state_dict[q_name] = torch.cat(q_weight_list, dim=0)
315
+ state_dict[k_name] = torch.cat(k_weight_list, dim=0)
316
+ state_dict[v_name] = torch.cat(v_weight_list, dim=0)
317
+
318
+ # empty cache before collecting weights
319
+ torch.cuda.empty_cache()
320
+ # Embeddings
321
+ # -------------------
322
+ if dp_rank == 0:
323
+ # Embeddings
324
+ # -------------------
325
+ print_rank_0("collecting embeddings...")
326
+ gpt_model_module = _get_gpt_model(models[0])
327
+ _broadcast_tp_shard_tensor(
328
+ gpt_model_module.model.embed_tokens.weight if pp_rank == 0 else None,
329
+ "model.embed_tokens.weight",
330
+ src_pp_rank=0,
331
+ )
332
+
333
+ # Transformer layers
334
+ # -------------------
335
+ layer_map = _megatron_calc_layer_map(config)
336
+ for layer in range(config.num_hidden_layers):
337
+ print_rank_0(f"collecting layer #{layer}...")
338
+ layer_name = f"model.layers.{layer}"
339
+ src_pp_rank, src_virtual_pp_rank, src_layer_idx = layer_map[layer]
340
+
341
+ gpt_model_module = _get_gpt_model(models[src_virtual_pp_rank])
342
+ sync_layer = gpt_model_module.model.layers[src_layer_idx]
343
+
344
+ _broadcast_tensor(
345
+ sync_layer.input_layernorm.weight,
346
+ f"{layer_name}.input_layernorm.weight",
347
+ src_pp_rank=src_pp_rank,
348
+ )
349
+
350
+ _broadcast_tp_shard_tensor_qkv(
351
+ sync_layer.self_attn.qkv_proj.weight,
352
+ f"{layer_name}.self_attn.q_proj.weight",
353
+ f"{layer_name}.self_attn.k_proj.weight",
354
+ f"{layer_name}.self_attn.v_proj.weight",
355
+ src_pp_rank=src_pp_rank,
356
+ )
357
+
358
+ _broadcast_tp_shard_tensor_qkv(
359
+ sync_layer.self_attn.qkv_proj.bias,
360
+ f"{layer_name}.self_attn.q_proj.bias",
361
+ f"{layer_name}.self_attn.k_proj.bias",
362
+ f"{layer_name}.self_attn.v_proj.bias",
363
+ src_pp_rank=src_pp_rank,
364
+ )
365
+
366
+ _broadcast_tp_shard_tensor(
367
+ sync_layer.self_attn.o_proj.weight,
368
+ f"{layer_name}.self_attn.o_proj.weight",
369
+ concat_dim=1,
370
+ src_pp_rank=src_pp_rank,
371
+ )
372
+
373
+ _broadcast_tensor(
374
+ sync_layer.post_attention_layernorm.weight,
375
+ f"{layer_name}.post_attention_layernorm.weight",
376
+ src_pp_rank=src_pp_rank,
377
+ )
378
+
379
+ _broadcast_tp_shard_tensor_gate_up(
380
+ sync_layer.mlp.gate_up_proj.weight,
381
+ f"{layer_name}.mlp.gate_proj.weight",
382
+ f"{layer_name}.mlp.up_proj.weight",
383
+ src_pp_rank=src_pp_rank,
384
+ )
385
+
386
+ _broadcast_tp_shard_tensor(
387
+ sync_layer.mlp.down_proj.weight,
388
+ f"{layer_name}.mlp.down_proj.weight",
389
+ concat_dim=1,
390
+ src_pp_rank=src_pp_rank,
391
+ )
392
+
393
+ # Final Layernorm
394
+ # -------------------
395
+ print_rank_0("collecting final layernorm...")
396
+ gpt_model_module = _get_gpt_model(models[-1])
397
+ _broadcast_tensor(
398
+ getattr(gpt_model_module.model.norm, "weight", None),
399
+ "model.norm.weight",
400
+ src_pp_rank=pp_size - 1,
401
+ )
402
+
403
+ if tie_word_embeddings:
404
+ print_rank_0("tie word embedding skip load lm_head...")
405
+ else:
406
+ print_rank_0("collecting lm_head...")
407
+
408
+ if is_value_model:
409
+ _broadcast_tensor(
410
+ gpt_model_module.lm_head.weight if pp_rank == pp_size - 1 else None,
411
+ "lm_head.weight",
412
+ src_pp_rank=pp_size - 1,
413
+ )
414
+ _broadcast_tensor(
415
+ gpt_model_module.reward_head.weight if pp_rank == pp_size - 1 and getattr(gpt_model_module, "reward_head", None) is not None else None,
416
+ "reward_head.weight",
417
+ src_pp_rank=pp_size - 1,
418
+ )
419
+
420
+ else:
421
+ _broadcast_tp_shard_tensor(
422
+ getattr(gpt_model_module.lm_head, "weight", None) if pp_rank == pp_size - 1 else None,
423
+ "lm_head.weight",
424
+ src_pp_rank=pp_size - 1,
425
+ )
426
+
427
+ dist.barrier()
428
+
429
+ torch.cuda.empty_cache()
430
+ if torch.distributed.get_rank() == 0:
431
+ for k, v in state_dict.items():
432
+ if dtype != v.dtype:
433
+ state_dict[k] = v.to(dtype)
434
+
435
+ print_rank_0(f"merge megatron ckpt done, time elapsed {time.time() - start_time}s")
436
+ return state_dict
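For readers unfamiliar with Megatron's rank layout, the bookkeeping behind `_megatron_calc_global_rank` and `_megatron_calc_layer_map` above reduces to simple arithmetic. A standalone worked example (pure Python; the parallel sizes are made up for illustration):

```python
# Worked example of the rank and layer-map arithmetic used by the loader/saver above.
# Sizes are illustrative only; no Megatron is required to follow the math.
tp_size, dp_size, pp_size, virtual_pp_size = 2, 2, 2, 1
num_hidden_layers = 8

def global_rank(tp_rank, dp_rank, pp_rank):
    # Same formula as _megatron_calc_global_rank: TP varies fastest, then DP, then PP.
    return (pp_rank * dp_size + dp_rank) * tp_size + tp_rank

assert global_rank(tp_rank=1, dp_rank=0, pp_rank=1) == 5  # (1*2 + 0)*2 + 1

# Global layer index -> (pp_rank, virtual_pp_rank, local layer index),
# following the same loop structure as _megatron_calc_layer_map.
num_layers_per_model = num_hidden_layers // pp_size // virtual_pp_size
layer_map = {}
for pp_rank in range(pp_size):
    for vpp_rank in range(virtual_pp_size):
        offset = vpp_rank * (num_hidden_layers // virtual_pp_size) + pp_rank * num_layers_per_model
        for local_idx in range(num_layers_per_model):
            layer_map[offset + local_idx] = (pp_rank, vpp_rank, local_idx)

assert layer_map[5] == (1, 0, 1)  # global layer 5 lives on pp_rank 1 as its local layer 1
```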
verl/models/qwen2/megatron/layers/__init__.py ADDED
@@ -0,0 +1,20 @@
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .parallel_attention import ParallelQwen2Attention
16
+ from .parallel_decoder import ParallelQwen2DecoderLayer, ParallelQwen2DecoderLayerRmPad
17
+ from .parallel_mlp import ParallelQwen2MLP
18
+ from .parallel_rmsnorm import ParallelQwen2RMSNorm
19
+
20
+ __all__ = ["ParallelQwen2Attention", "ParallelQwen2DecoderLayer", "ParallelQwen2DecoderLayerRmPad", "ParallelQwen2MLP", "ParallelQwen2RMSNorm"]