braindeck committed
Commit bcdf9fa (initial commit, no parent)
Initial commit

This view is limited to 50 files because it contains too many changes; see the raw diff for the full change set.
- .gitignore +1 -0
- README.md +20 -0
- _infer.py +309 -0
- app.py +43 -0
- config/infer.yaml +56 -0
- prompts/__pycache__/base_instruction.cpython-311.pyc +0 -0
- prompts/__pycache__/infer_prompt.cpython-311.pyc +0 -0
- prompts/__pycache__/sft_prompt.cpython-311.pyc +0 -0
- prompts/base_instruction.py +24 -0
- prompts/infer_prompt.py +101 -0
- requirements.txt +4 -0
- requirements.txt.txt +200 -0
- scripts/sft_infer_pass1.sh +33 -0
- verl/__init__.py +40 -0
- verl/__pycache__/__init__.cpython-311.pyc +0 -0
- verl/__pycache__/protocol.cpython-311.pyc +0 -0
- verl/models/README.md +35 -0
- verl/models/__init__.py +13 -0
- verl/models/__pycache__/__init__.cpython-311.pyc +0 -0
- verl/models/__pycache__/registry.cpython-311.pyc +0 -0
- verl/models/llama/__init__.py +13 -0
- verl/models/llama/megatron/__init__.py +34 -0
- verl/models/llama/megatron/checkpoint_utils/__init__.py +13 -0
- verl/models/llama/megatron/checkpoint_utils/llama_loader.py +295 -0
- verl/models/llama/megatron/checkpoint_utils/llama_loader_depracated.py +425 -0
- verl/models/llama/megatron/checkpoint_utils/llama_saver.py +430 -0
- verl/models/llama/megatron/layers/__init__.py +25 -0
- verl/models/llama/megatron/layers/parallel_attention.py +425 -0
- verl/models/llama/megatron/layers/parallel_decoder.py +150 -0
- verl/models/llama/megatron/layers/parallel_linear.py +106 -0
- verl/models/llama/megatron/layers/parallel_mlp.py +74 -0
- verl/models/llama/megatron/layers/parallel_rmsnorm.py +48 -0
- verl/models/llama/megatron/modeling_llama_megatron.py +662 -0
- verl/models/mcore/__init__.py +18 -0
- verl/models/mcore/config_converter.py +197 -0
- verl/models/mcore/loader.py +468 -0
- verl/models/mcore/model_forward.py +50 -0
- verl/models/mcore/model_initializer.py +160 -0
- verl/models/mcore/readme.md +99 -0
- verl/models/mcore/registry.py +179 -0
- verl/models/mcore/saver.py +459 -0
- verl/models/mcore/util.py +190 -0
- verl/models/mcore/weight_converter.py +207 -0
- verl/models/qwen2/__init__.py +13 -0
- verl/models/qwen2/megatron/__init__.py +34 -0
- verl/models/qwen2/megatron/checkpoint_utils/__init__.py +13 -0
- verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader.py +312 -0
- verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader_depracated.py +442 -0
- verl/models/qwen2/megatron/checkpoint_utils/qwen2_saver.py +436 -0
- verl/models/qwen2/megatron/layers/__init__.py +20 -0
.gitignore
ADDED
@@ -0,0 +1 @@
+checkpoints/
README.md
ADDED
@@ -0,0 +1,20 @@
+---
+title: Text-to-Text Generation with DeepSeek-R1-Distill-Qwen-7B
+emoji: 🚀
+colorFrom: blue
+colorTo: green
+sdk: gradio
+sdk_version: 4.19.2
+app_file: app.py
+pinned: false
+---
+
+# Text-to-Text Generation with DeepSeek-R1-Distill-Qwen-7B
+
+This is a simple Gradio interface for text-to-text generation using the `deepseek-ai/DeepSeek-R1-Distill-Qwen-7B` model.
+
+## How to use
+
+1. Enter a prompt in the text box.
+2. Click the "Generate" button.
+3. The model will generate a response in the "Response" text box.
_infer.py
ADDED
@@ -0,0 +1,309 @@
+#!/usr/bin/env python
+# Copyright 2024 Bytedance
+# Apache-2.0
+#
+# VERL + vLLM inference with runtime LoRA (no merge).
+# - Wraps a LoRA .pt into a PEFT adapter and attaches via rollout.lora_modules
+# - Mixed precision defaults for H100: dtype=bf16, kv_cache_dtype=fp8_e5m2
+# - Pins max_model_len, max_num_batched_tokens, sets swap_space
+# - Uses OmegaConf.open_dict to add keys safely (no "not in struct" errors)
+# - Prevents FSDP from trying to load LoRA .pt as a full model
+
+import os
+import ast
+import json
+import hydra
+import numpy as np
+import ray
+import torch
+from pathlib import Path
+from pprint import pprint
+
+# Quiet logs
+os.environ["NCCL_DEBUG"] = os.environ.get("NCCL_DEBUG", "WARN")
+os.environ["TOKENIZERS_PARALLELISM"] = os.environ.get("TOKENIZERS_PARALLELISM", "true")
+
+# vLLM CuMem allocator is incompatible with expandable_segments
+_bad = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "")
+if "expandable_segments:True" in _bad:
+    print(f"[fix] Removing incompatible PYTORCH_CUDA_ALLOC_CONF={_bad}")
+    os.environ.pop("PYTORCH_CUDA_ALLOC_CONF", None)
+
+import pandas as pd
+from omegaconf import OmegaConf, open_dict
+
+from verl import DataProto
+from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
+from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
+from verl.utils import hf_tokenizer
+from verl.utils.fs import copy_to_local
+from verl.utils.hdfs_io import makedirs
+from verl.utils.model import compute_position_id_with_mask
+from verl.workers.fsdp_workers import ActorRolloutRefWorker
+
+# ---------------- LoRA helpers ----------------
+
+DEFAULT_TARGET_MODULES = [
+    "q_proj", "k_proj", "v_proj", "o_proj",
+    "up_proj", "gate_proj", "down_proj",
+]
+
+def _infer_lengths_and_defaults(config):
+    """Ensure rollout/data keys exist and set reasonable H100 defaults."""
+    # Ensure nested structs exist
+    with open_dict(config):
+        if "rollout" not in config:
+            config["rollout"] = OmegaConf.create()
+        if "data" not in config:
+            config["data"] = OmegaConf.create()
+        if "trainer" not in config:
+            config["trainer"] = OmegaConf.create()
+        if "ray_init" not in config:
+            config["ray_init"] = OmegaConf.create()
+
+    # Defaults that work on a single H100
+    with open_dict(config.rollout):
+        # If user didn't set these, choose H100-friendly defaults
+        config.rollout.setdefault("dtype", "bfloat16")            # weights/activations
+        config.rollout.setdefault("kv_cache_dtype", "fp8_e5m2")   # KV cache precision
+        config.rollout.setdefault("tensor_model_parallel_size", 1)
+        config.rollout.setdefault("enable_chunked_prefill", True)
+        config.rollout.setdefault("swap_space", 8)                # GB of host swap for KV
+        config.rollout.setdefault("gpu_memory_utilization", 0.62) # adjust 0.60~0.75 if needed
+
+        # Pin lengths to avoid vLLM over-reserving KV cache
+        pl = int(config.rollout.get("prompt_length", 1024))
+        rl = int(config.rollout.get("response_length", 128))
+        need = int(pl + rl)
+        config.rollout.setdefault("max_model_len", need)
+        config.rollout.setdefault("max_num_batched_tokens", need)
+
+        # Users may pass +rollout.quantization={fp8|awq|gptq} to shrink weights further
+        # We don't force it here.
+
+    with open_dict(config.data):
+        config.data.setdefault("batch_size", 1)
+        config.data.setdefault("n_samples", 1)
+        config.data.setdefault("prompt_key", "prompt")
+
+    with open_dict(config.trainer):
+        config.trainer.setdefault("n_gpus_per_node", 1)
+        config.trainer.setdefault("nnodes", 1)
+
+    with open_dict(config.ray_init):
+        config.ray_init.setdefault("num_cpus", 4)
+
+def _infer_lora_rank_from_state(sd):
+    for k, v in sd.items():
+        if k.endswith("lora_A.weight") and hasattr(v, "dim") and v.dim() == 2:
+            return int(v.shape[0])
+    return None
+
+def _list_target_modules_from_state(sd):
+    found = set()
+    for k in sd.keys():
+        if "lora_A.weight" in k or "lora_B.weight" in k:
+            if ".q_proj." in k: found.add("q_proj")
+            if ".k_proj." in k: found.add("k_proj")
+            if ".v_proj." in k: found.add("v_proj")
+            if ".o_proj." in k: found.add("o_proj")
+            if ".up_proj." in k: found.add("up_proj")
+            if ".gate_proj." in k: found.add("gate_proj")
+            if ".down_proj." in k: found.add("down_proj")
+    return sorted(found)
+
+def _write_adapter_config(adapter_dir, r, alpha, target_modules, dropout=0.0):
+    cfg = {
+        "peft_type": "LORA",
+        "auto_mapping": None,
+        "base_model_name_or_path": "",
+        "bias": "none",
+        "inference_mode": True,
+        "lora_alpha": int(alpha),
+        "lora_dropout": float(dropout),
+        "r": int(r),
+        "target_modules": target_modules,
+        "task_type": "CAUSAL_LM",
+    }
+    with open(os.path.join(adapter_dir, "adapter_config.json"), "w", encoding="utf-8") as f:
+        json.dump(cfg, f, ensure_ascii=False, indent=2)
+
+def _wrap_lora_pt_as_peft(adapter_pt_path: str, out_dir: str,
+                          fallback_rank=32, fallback_alpha=16):
+    os.makedirs(out_dir, exist_ok=True)
+    print(f"[lora] Loading LoRA state from: {adapter_pt_path}")
+    sd = torch.load(adapter_pt_path, map_location="cpu")
+    if isinstance(sd, dict) and "state_dict" in sd:
+        sd = sd["state_dict"]
+
+    r = _infer_lora_rank_from_state(sd) or int(fallback_rank)
+    tmods = _list_target_modules_from_state(sd) or DEFAULT_TARGET_MODULES
+    print(f"[lora] inferred rank={r}, target_modules={tmods}")
+
+    _write_adapter_config(out_dir, r=r, alpha=fallback_alpha, target_modules=tmods)
+    torch.save(sd, os.path.join(out_dir, "adapter_model.bin"))
+    return r, tmods
+
+def _maybe_attach_lora_adapter(config):
+    """Attach LoRA adapter directory to vLLM rollout (runtime LoRA)."""
+    # Accept either +lora.pt_path or model.load_param_path as a hint
+    lora_pt = None
+    if "lora" in config and getattr(config.lora, "pt_path", ""):
+        lora_pt = config.lora.pt_path
+    elif getattr(config.model, "load_param_path", ""):
+        lora_pt = config.model.load_param_path
+
+    if not lora_pt or not Path(lora_pt).is_file():
+        print("[lora] No LoRA .pt provided; running base model only.")
+        return
+
+    adapter_dir = os.path.join("/tmp", "lora_adapter_vllm")
+    r, _ = _wrap_lora_pt_as_peft(lora_pt, adapter_dir, fallback_rank=32, fallback_alpha=16)
+
+    # Ensure rollout keys exist and add LoRA knobs required by vLLM
+    with open_dict(config):
+        if "rollout" not in config:
+            config["rollout"] = OmegaConf.create()
+    with open_dict(config.rollout):
+        config.rollout.setdefault("max_loras", 1)
+        config.rollout.setdefault("max_lora_rank", int(r))
+        config.rollout["lora_modules"] = [{"path": adapter_dir, "scale": 1.0}]
+    print(f"[lora] Attached PEFT adapter: {adapter_dir} (rank={r})")
+
+    # CRITICAL: don't let FSDP try to load the LoRA .pt as a full state dict
+    with open_dict(config.model):
+        if getattr(config.model, "load_param", False):
+            print("[lora] Disabling model.load_param to avoid FSDP load_state_dict mismatch.")
+            config.model["load_param"] = False
+
+# ---------------- Hydra entry ----------------
+
+@hydra.main(config_path="config", config_name="infer", version_base=None)
+def main(config):
+    _infer_lengths_and_defaults(config)
+
+    # Ray env for workers
+    if not ray.is_initialized():
+        ray.init(
+            runtime_env={"env_vars": {
+                "TOKENIZERS_PARALLELISM": "true",
+                "NCCL_DEBUG": "WARN",
+                "PYTORCH_CUDA_ALLOC_CONF": "",  # keep allocator happy for vLLM
+            }},
+            num_cpus=config.ray_init.num_cpus,
+        )
+
+    ray.get(main_task.remote(config))
+
+@ray.remote(num_cpus=1)
+def main_task(config):
+    print("[worker] PYTORCH_CUDA_ALLOC_CONF =", os.environ.get("PYTORCH_CUDA_ALLOC_CONF"))
+    pprint(OmegaConf.to_container(config, resolve=True))
+    OmegaConf.resolve(config)
+
+    # Build LoRA adapter if provided
+    _maybe_attach_lora_adapter(config)
+
+    # Optionally pre-gen dataset schema if your repo provides it
+    try:
+        from prompts.infer_prompt import infer_dataset
+        infer_dataset(
+            model_name=config.model.path,
+            data_path=os.path.dirname(os.path.dirname(config.data.path)),
+        )
+    except Exception as e:
+        print(f"[info] infer_dataset() skipped: {e}")
+
+    # ---- Tokenizer from base model
+    local_path = copy_to_local(config.model.path)
+    trust_remote_code = getattr(config.model, "trust_remote_code", False)
+    tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
+    tokenizer.padding_side = "left"
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    # ---- Sampling checks
+    if float(config.rollout.temperature) == 0.0:
+        assert int(config.data.n_samples) == 1, "When temperature=0, n_samples must be 1."
+    assert int(config.data.n_samples) >= 1, "n_samples should always >= 1"
+
+    # ---- Load dataset
+    dataset = pd.read_parquet(config.data.path)
+    prompt_key = getattr(config.data, "prompt_key", "prompt")
+    if prompt_key not in dataset.columns:
+        raise KeyError(f"Dataset missing column '{prompt_key}'")
+    chat_lst = dataset[prompt_key].tolist()
+    chat_lst = [chat.tolist() if hasattr(chat, "tolist") else chat for chat in chat_lst]
+
+    # ---- Worker group (vLLM inside Rollout)
+    ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorRolloutRefWorker), config=config, role="rollout")
+    resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes)
+    print("[debug] rollout.lora_modules =", config.rollout.get("lora_modules", None))
+    wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init)
+    wg.init_model()  # vLLM spins up; adapter used if set in rollout.lora_modules
+
+    total = len(dataset)
+    bs = int(config.data.batch_size)
+    num_batch = -(-total // bs)
+    slots = [[] for _ in range(int(config.data.n_samples))]
+
+    for b in range(num_batch):
+        print(f"[{b+1}/{num_batch}] Start to process.")
+        batch_chat = chat_lst[b * bs : (b + 1) * bs]
+
+        inputs = tokenizer.apply_chat_template(
+            batch_chat,
+            add_generation_prompt=True,
+            padding=True,
+            truncation=True,
+            max_length=int(config.rollout.prompt_length),
+            return_tensors="pt",
+            return_dict=True,
+            tokenize=True,
+        )
+        input_ids = inputs["input_ids"]
+        attention_mask = inputs["attention_mask"]
+        position_ids = compute_position_id_with_mask(attention_mask)
+        batch_dict = {"input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids}
+
+        data = DataProto.from_dict(batch_dict)
+        data_padded, pad_size = pad_dataproto_to_divisor(data, wg.world_size)
+
+        print(f"[{b+1}/{num_batch}] Start to generate.")
+        for n in range(int(config.data.n_samples)):
+            output_padded = wg.generate_sequences(data_padded)
+            output = unpad_dataproto(output_padded, pad_size=pad_size)
+            texts = []
+            for i in range(len(output)):
+                item = output[i]
+                pl = item.batch["prompts"].shape[-1]
+                valid_len = item.batch["attention_mask"][pl:].sum()
+                resp_ids = item.batch["responses"][:valid_len]
+                s = tokenizer.decode(resp_ids, skip_special_tokens=True)
+                print(f"[raw] Response {i}: {s!r}")
+                ix = s.find("</think>")
+                if ix != -1:
+                    s = s[ix + len("</think>") :].lstrip()
+                print(f"Response {i}: {s!r}")
+                try:
+                    texts.append(ast.literal_eval(s))
+                except Exception:
+                    texts.append(s)
+            slots[n].extend(texts)
+
+    outputs = np.array(slots, dtype=object)
+    outputs = np.transpose(outputs, (1, 0)).tolist()
+    dataset["response"] = outputs
+
+    keep = ["file_id", "vt", "gt", "response"]
+    cols = [c for c in keep if c in dataset.columns]
+    if cols:
+        dataset = dataset[cols]
+
+    out_path = config.data.output_path
+    makedirs(os.path.dirname(out_path), exist_ok=True)
+    dataset.to_json(out_path, orient="records", lines=True, force_ascii=False)
+    print(f"[done] Wrote: {out_path}")
+
+if __name__ == "__main__":
+    main()
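Since `_wrap_lora_pt_as_peft` writes a plain PEFT-style adapter directory (`adapter_config.json` plus `adapter_model.bin`), one way to check the result is to load the config back with `peft`. This is a minimal sketch, not part of the committed script; the `/tmp/lora_adapter_vllm` path is simply the default used by `_maybe_attach_lora_adapter` above.

```python
# Minimal sanity check (assumes peft is installed and the adapter directory
# was already produced by _maybe_attach_lora_adapter; not part of the script).
from peft import PeftConfig

peft_cfg = PeftConfig.from_pretrained("/tmp/lora_adapter_vllm")
print(peft_cfg.peft_type, peft_cfg.r, peft_cfg.lora_alpha, sorted(peft_cfg.target_modules))
```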
app.py
ADDED
@@ -0,0 +1,43 @@
+
+import gradio as gr
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
+
+# Load the model and tokenizer
+tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", trust_remote_code=True)
+model = AutoModelForCausalLM.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="auto")
+
+def generate_response(prompt):
+    """
+    Generates a response from the model.
+    """
+    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+    outputs = model.generate(**inputs, max_new_tokens=512)
+
+    # Decode the generated text
+    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+    return generated_text
+
+# Create the Gradio interface
+with gr.Blocks() as demo:
+    gr.Markdown("# Text-to-Text Generation with DeepSeek-R1-Distill-Qwen-7B")
+    gr.Markdown("Enter a prompt and the model will generate a response.")
+
+    with gr.Row():
+        prompt_input = gr.Textbox(label="Prompt", lines=4, placeholder="Enter your prompt here...")
+
+    with gr.Row():
+        generate_button = gr.Button("Generate")
+
+    with gr.Row():
+        response_output = gr.Textbox(label="Response", lines=8, interactive=False)
+
+    generate_button.click(
+        fn=generate_response,
+        inputs=prompt_input,
+        outputs=response_output
+    )
+
+if __name__ == "__main__":
+    demo.launch()
config/infer.yaml
ADDED
@@ -0,0 +1,56 @@
+trainer:
+  nnodes: 1
+  n_gpus_per_node: 1
+
+data:
+  path: ./data/parquet/test.parquet
+  prompt_key: prompt
+  n_samples: 1
+  output_path: ./checkpoints/grammar_generation.parquet
+  batch_size: 1
+
+model:
+  path: deepseek-ai/DeepSeek-R1-Distill-Qwen-7B
+  external_lib: null
+  load_param: False
+  load_param_path: null
+
+rollout:
+  name: vllm
+  mode: sync # sync: LLM, async: AsyncLLM
+  temperature: 0.0
+  top_k: -1 # 0 for hf rollout, -1 for vllm rollout
+  top_p: 1.0
+  max_loras: 1
+  prompt_length: 1800
+  response_length: 512
+  # for vllm rollout
+  dtype: bfloat16 # should align with FSDP
+  gpu_memory_utilization: 0.9 # ↑ allow cache to allocate
+  ignore_eos: False
+  enforce_eager: True
+  free_cache_engine: True
+  load_format: dummy_dtensor
+  tensor_model_parallel_size: 1
+  max_num_batched_tokens: 8192
+  max_model_len: 1800 # ≥ 1200 + 512
+  max_num_seqs: 1024
+  log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
+  log_prob_micro_batch_size_per_gpu: 1
+  # for fire vllm rollout
+  use_fire_sampling: False # enable FIRE https://arxiv.org/abs/2410.21236
+  # for hf rollout
+  do_sample: True
+  disable_log_stats: False
+  enable_chunked_prefill: True # OK because 8192 ≥ 3072
+  n: 1
+  # if beam search is activated, top_k, temperature and top_p will be ignored
+
+actor:
+  strategy: fsdp # This is for backward-compatibility
+  ulysses_sequence_parallel_size: 1 # sp size
+  fsdp_config:
+    fsdp_size: -1
+
+ray_init:
+  num_cpus: null # `None` means using all CPUs, which might cause a hang if limited in systems like SLURM. Please set an allowed number instead.
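The `_infer.py` header notes that it "uses OmegaConf.open_dict to add keys safely". A minimal standalone sketch of that pattern, independent of verl (the key names below simply mirror the defaults set in `_infer_lengths_and_defaults`):

```python
# Sketch of the struct-mode / open_dict pattern relied on by _infer.py.
from omegaconf import OmegaConf, open_dict

cfg = OmegaConf.create({"rollout": {"dtype": "bfloat16"}})
OmegaConf.set_struct(cfg, True)      # composed Hydra configs are struct by default

with open_dict(cfg.rollout):         # temporarily allow unknown keys
    cfg.rollout.setdefault("kv_cache_dtype", "fp8_e5m2")
    cfg.rollout.setdefault("swap_space", 8)

print(OmegaConf.to_yaml(cfg))
```

Without the `open_dict` context, assigning `kv_cache_dtype` on a struct config would raise an error, which is the "not in struct" failure the script header refers to.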
prompts/__pycache__/base_instruction.cpython-311.pyc
ADDED
Binary file (1.44 kB)
prompts/__pycache__/infer_prompt.cpython-311.pyc
ADDED
Binary file (6.33 kB)
prompts/__pycache__/sft_prompt.cpython-311.pyc
ADDED
Binary file (7.64 kB)
prompts/base_instruction.py
ADDED
@@ -0,0 +1,24 @@
+def basic_instruction(content, modelname):
+    system_instruction = (
+        "당신은 한국어 문장 교정 전문가입니다. "
+        "입력 문장은 다양한 오류(자모 분리, 철자 오류, 단어 누락 등)를 포함할 수 있습니다. "
+        "당신의 임무는 이러한 잘못된 문장을 완전하고 올바른 한국어 문장으로 복원하는 것입니다.\n"
+        "규칙:\n"
+        "•출력은 반드시 교정된 한국어 문장만 작성합니다.\n"
+        "•불필요한 설명, 이유, 따옴표는 포함하지 않습니다.\n"
+    )
+
+    user_instruction = (
+        f"잘못된 문장(노이즈): {content}\n\n"
+        "위 문장을 올바른 한국어 문장으로 교정하세요.\n"
+        "출력은 반드시 교정된 문장 하나만 작성하세요."
+    )
+
+    return [
+        {"role": "system", "content": system_instruction},
+        {"role": "user", "content": user_instruction},
+    ]
+
+
+def get_instruction_func(modelname):
+    return lambda desc, _: basic_instruction(desc, modelname)
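The two Korean instruction strings tell the model it is a Korean sentence-correction expert: the noisy input may contain split jamo, spelling errors, or missing words, and the output must be only the corrected Korean sentence, with no explanations, reasons, or quotation marks. A quick look at the messages the builder produces (run from the repo root; the noisy sentence below is made up for illustration):

```python
# Illustrative only; the noisy example sentence is invented.
from prompts.base_instruction import get_instruction_func

build_prompt = get_instruction_func("deepseek-ai/DeepSeek-R1-Distill-Qwen-7B")
messages = build_prompt("안뇽하세요 반갑습니당", None)  # (noisy sentence, unused second argument)
for m in messages:
    print(f"{m['role']}: {m['content'][:60]}...")
```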
prompts/infer_prompt.py
ADDED
@@ -0,0 +1,101 @@
+from prompts.base_instruction import get_instruction_func
+
+def infer_dataset(
+    model_name: str,
+    data_path: str,
+):
+    import os, json
+    from typing import Any, Dict, List
+    from datasets import Dataset
+    from transformers import AutoTokenizer
+
+    MAX_TOKENS = 1200  # same as SFT
+
+    jsonl_path = os.path.join(data_path, "jsonl")
+    parquet_path = os.path.join(data_path, "parquet")
+    os.makedirs(parquet_path, exist_ok=True)
+
+    test_jsonl = os.path.join(jsonl_path, "test.jsonl")
+
+    # --- robust load: tolerant JSONL/array/concatenated JSON
+    rows = []
+    with open(test_jsonl, "r", encoding="utf-8") as f:
+        raw = f.read().strip()
+
+    try:
+        obj = json.loads(raw)
+        if isinstance(obj, list):
+            rows = [x for x in obj if isinstance(x, dict)]
+    except Exception:
+        pass
+
+    if not rows:
+        for ln in raw.replace("}{", "}\n{").splitlines():
+            ln = ln.strip()
+            if not ln:
+                continue
+            try:
+                x = json.loads(ln)
+                if isinstance(x, dict):
+                    rows.append(x)
+            except Exception:
+                continue
+
+    test_dataset = Dataset.from_list(rows)
+
+    instruction_func = get_instruction_func(model_name)
+    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+
+    # ─── helpers ───
+    def _coerce(rec: Dict[str, Any]) -> Dict[str, Any]:
+        r = dict(rec)
+        r["vt"] = str(r.get("vt", "") or "")
+        return r
+
+    def _prompt_tokens(prompt_messages) -> int:
+        prompt_str = tokenizer.apply_chat_template(
+            prompt_messages, add_generation_prompt=True, tokenize=False
+        )
+        return len(tokenizer(prompt_str, add_special_tokens=False).input_ids)
+
+    def make_map_fn(split: str):
+        def process_fn(example: Dict[str, Any], idx: int) -> Dict[str, Any]:
+            ex = _coerce(example)
+            vt = ex.get("vt", "").strip()
+            if not vt:
+                return {}
+
+            chat_prompt = instruction_func(vt, model_name)
+            total_tokens = _prompt_tokens(chat_prompt)
+
+            extra = {
+                "split": split,
+                "index": idx,
+                "total_tokens": int(total_tokens),
+                "file_id": ex.get("file_id")
+            }
+
+            return {
+                "prompt": chat_prompt,
+                "extra_info": extra,
+                "total_tokens": int(total_tokens)
+            }
+        return process_fn
+
+    # build prompts + token counts
+    test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True)
+
+    # drop rows where prompt is empty
+    test_dataset = test_dataset.filter(lambda ex: bool(ex.get("prompt")))
+
+    # drop long prompts (> MAX_TOKENS)
+    n_before_len = len(test_dataset)
+    test_dataset = test_dataset.filter(lambda ex: ex["total_tokens"] <= MAX_TOKENS)
+    kept = len(test_dataset)
+    dropped_long = n_before_len - kept
+
+    out_path = os.path.join(parquet_path, "test.parquet")
+    test_dataset.to_parquet(out_path)
+
+    print(f"[test] kept {kept} rows, dropped_long(>{MAX_TOKENS}) {dropped_long}")
+    print(f"Wrote {kept} rows → {out_path}")
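One rough way to confirm what `infer_dataset` wrote is to open the resulting parquet with pandas; the path below assumes the default `./data` layout used elsewhere in this commit.

```python
# Quick inspection of the generated test split (the path is an assumption).
import pandas as pd

df = pd.read_parquet("./data/parquet/test.parquet")
print(df.columns.tolist())        # expect "prompt", "extra_info", "total_tokens" among the columns
print(df["total_tokens"].max())   # should be <= 1200 after the length filter
print(df.iloc[0]["prompt"])       # chat messages for the first kept row
```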
requirements.txt
ADDED
@@ -0,0 +1,4 @@
+gradio
+transformers
+torch
+accelerate
requirements.txt.txt
ADDED
@@ -0,0 +1,200 @@
+absl-py==2.3.1
+accelerate==1.6.0
+aiohappyeyeballs==2.6.1
+aiohttp==3.11.16
+aiohttp-cors==0.8.1
+aiosignal==1.3.2
+airportsdata==20250224
+alabaster==1.0.0
+annotated-types==0.7.0
+antlr4-python3-runtime==4.9.3
+anyio==4.9.0
+astor==0.8.1
+attrs==25.3.0
+babel==2.17.0
+blake3==1.0.4
+cachetools==5.5.2
+certifi==2025.1.31
+charset-normalizer==3.4.1
+click==8.1.8
+cloudpickle==3.1.1
+codetiming==1.4.0
+colorful==0.5.6
+compressed-tensors==0.9.2
+cupy-cuda12x==13.4.1
+datasets==3.5.0
+depyf==0.18.0
+dill==0.3.8
+diskcache==5.6.3
+distlib==0.3.9
+distro==1.9.0
+dnspython==2.7.0
+docker-pycreds==0.4.0
+docutils==0.21.2
+einops==0.8.1
+email_validator==2.2.0
+fastapi==0.115.12
+fastapi-cli==0.0.7
+fastrlock==0.8.3
+filelock==3.18.0
+flash_attn==2.7.4.post1
+flashinfer-python==0.2.5
+frozenlist==1.5.0
+fsspec==2024.6.1
+gguf==0.10.0
+gitdb==4.0.12
+GitPython==3.1.44
+google-api-core==2.24.2
+google-auth==2.38.0
+googleapis-common-protos==1.69.2
+grpcio==1.71.0
+h11==0.14.0
+httpcore==1.0.8
+httptools==0.6.4
+httpx==0.28.1
+huggingface-hub==0.30.2
+hydra-core==1.3.2
+idna==3.10
+imagesize==1.4.1
+importlib_metadata==8.6.1
+interegular==0.3.3
+Jinja2==3.1.6
+jiter==0.9.0
+jiwer==4.0.0
+joblib==1.5.2
+jsonschema==4.23.0
+jsonschema-specifications==2024.10.1
+lark==1.2.2
+Levenshtein==0.27.1
+liger_kernel==0.5.9
+llguidance==0.7.14
+llvmlite==0.43.0
+lm-format-enforcer==0.10.11
+markdown-it-py==3.0.0
+MarkupSafe==2.1.5
+mdurl==0.1.2
+mistral_common==1.5.4
+mpmath==1.3.0
+msgpack==1.1.0
+msgspec==0.19.0
+multidict==6.4.3
+multiprocess==0.70.16
+nest-asyncio==1.6.0
+networkx==3.3
+ninja==1.11.1.4
+nltk==3.9.1
+numba==0.60.0
+numpy==1.26.4
+nvidia-cublas-cu12==12.4.5.8
+nvidia-cuda-cupti-cu12==12.4.127
+nvidia-cuda-nvrtc-cu12==12.4.127
+nvidia-cuda-runtime-cu12==12.4.127
+nvidia-cudnn-cu12==9.1.0.70
+nvidia-cufft-cu12==11.2.1.3
+nvidia-curand-cu12==10.3.5.147
+nvidia-cusolver-cu12==11.6.1.9
+nvidia-cusparse-cu12==12.3.1.170
+nvidia-cusparselt-cu12==0.6.2
+nvidia-ml-py==12.570.86
+nvidia-nccl-cu12==2.21.5
+nvidia-nvjitlink-cu12==12.4.127
+nvidia-nvtx-cu12==12.4.127
+omegaconf==2.3.0
+openai==1.73.0
+opencensus==0.11.4
+opencensus-context==0.1.3
+opencv-python-headless==4.11.0.86
+orjson==3.10.16
+outlines==0.1.11
+outlines_core==0.1.26
+packaging==24.2
+pandas==2.2.3
+partial-json-parser==0.2.1.1.post5
+peft==0.15.1
+pillow==11.2.1
+platformdirs==4.3.7
+prometheus-fastapi-instrumentator==7.1.0
+prometheus_client==0.21.1
+propcache==0.3.1
+proto-plus==1.26.1
+protobuf==5.29.4
+psutil==7.0.0
+py-cpuinfo==9.0.0
+py-spy==0.4.0
+pyarrow==19.0.1
+pyasn1==0.6.1
+pyasn1_modules==0.4.2
+pybind11==2.13.6
+pycountry==24.6.1
+pydantic==2.11.3
+pydantic_core==2.33.1
+Pygments==2.19.1
+pylatexenc==2.10
+python-dateutil==2.9.0.post0
+python-dotenv==1.1.0
+python-json-logger==3.3.0
+python-Levenshtein==0.27.1
+python-multipart==0.0.20
+pytz==2025.2
+PyYAML==6.0.2
+pyzmq==26.4.0
+RapidFuzz==3.14.1
+ray==2.44.1
+referencing==0.36.2
+regex==2024.11.6
+requests==2.32.3
+rich==14.0.0
+rich-toolkit==0.14.1
+roman-numerals-py==3.1.0
+rouge_score==0.1.2
+rpds-py==0.24.0
+rsa==4.9
+safetensors==0.5.3
+scipy==1.15.2
+sentencepiece==0.2.0
+sentry-sdk==2.25.1
+setproctitle==1.3.5
+shellingham==1.5.4
+six==1.17.0
+smart-open==7.1.0
+smmap==5.0.2
+sniffio==1.3.1
+snowballstemmer==2.2.0
+Sphinx==8.2.3
+sphinxcontrib-applehelp==2.0.0
+sphinxcontrib-devhelp==2.0.0
+sphinxcontrib-htmlhelp==2.1.0
+sphinxcontrib-jsmath==1.0.1
+sphinxcontrib-qthelp==2.0.0
+sphinxcontrib-serializinghtml==2.0.0
+starlette==0.46.1
+sympy==1.13.1
+tensordict==0.6.2
+tiktoken==0.9.0
+timeout-decorator==0.5.0
+tokenizers==0.21.1
+torch==2.6.0
+torchaudio==2.6.0
+torchdata==0.11.0
+torchvision==0.21.0
+tqdm==4.67.1
+transformers==4.51.2
+triton==3.2.0
+typer==0.15.2
+typing-inspection==0.4.0
+typing_extensions==4.12.2
+tzdata==2025.2
+urllib3==2.4.0
+uvicorn==0.34.0
+uvloop==0.21.0
+virtualenv==20.30.0
+vllm==0.8.2
+wandb==0.19.9
+watchfiles==1.0.5
+websockets==15.0.1
+wrapt==1.17.2
+xformers==0.0.29.post2
+xgrammar==0.1.16
+xxhash==3.5.0
+yarl==1.19.0
+zipp==3.21.0
scripts/sft_infer_pass1.sh
ADDED
@@ -0,0 +1,33 @@
+#!/bin/bash
+
+#!/bin/bash
+set -x
+
+python ./_infer.py \
+    model.path=./checkpoints/model \
+    model.load_param=False \
+    data.path=./data/parquet/test.parquet \
+    data.output_path=./model_output/[email protected] \
+    data.batch_size=32 data.n_samples=1 \
+    rollout.tensor_model_parallel_size=1 \
+    rollout.temperature=0.7 rollout.top_p=0.9 rollout.n=1 rollout.do_sample=True \
+    rollout.prompt_length=1200 rollout.response_length=512 \
+    rollout.enable_chunked_prefill=True \
+    +rollout.kv_cache_dtype=fp8_e5m2 \
+    rollout.max_model_len=1800 \
+    rollout.max_num_batched_tokens=1800 \
+    rollout.max_num_seqs=1 \
+    +model.trust_remote_code=True \
+    +rollout.kv_cache_block_size=16 \
+    +rollout.swap_space=16 \
+    rollout.gpu_memory_utilization=0.7
+
+# python ./_infer.py \
+#     model.load_param=True \
+#     model.load_param_path="./checkpoints/merged_r1qwen14b/model.pt" \
+#     data.output_path="./model_output/[email protected]" \
+#     data.n_samples=10 \
+#     data.path="./data/parquet/test.parquet" \
+#     rollout.temperature=0.9 \
+#     rollout.top_p=0.9 \
+#     rollout.n=1 \
verl/__init__.py
ADDED
@@ -0,0 +1,40 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+
+from .protocol import DataProto
+from .utils.logging_utils import set_basic_config
+
+version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__)))
+
+with open(os.path.join(version_folder, "version/version")) as f:
+    __version__ = f.read().strip()
+
+
+set_basic_config(level=logging.WARNING)
+
+
+__all__ = ["DataProto", "__version__"]
+
+if os.getenv("VERL_USE_MODELSCOPE", "False").lower() == "true":
+    import importlib
+
+    if importlib.util.find_spec("modelscope") is None:
+        raise ImportError("You are using the modelscope hub, please install modelscope by `pip install modelscope -U`")
+    # Patch hub to download models from modelscope to speed up.
+    from modelscope.utils.hf_util import patch_hub
+
+    patch_hub()
verl/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (1.69 kB)
verl/__pycache__/protocol.cpython-311.pyc
ADDED
Binary file (47 kB)
verl/models/README.md
ADDED
@@ -0,0 +1,35 @@
+# Models
+Common model zoos such as huggingface/transformers struggle when using PyTorch native model parallelism. Following the design principle of vLLM, we keep a simple, parallelizable, highly optimized model implementation with packed inputs in verl.
+## Adding a New Huggingface Model
+### Step 1: Copy the model file from HF to verl
+- Add a new file under verl/models/hf
+- Copy ONLY the model file from huggingface/transformers/models to verl/models/hf
+
+### Step 2: Modify the model file to use packed inputs
+- Remove all the code related to inference (kv cache)
+- Modify the inputs to include only
+  - input_ids (total_nnz,)
+  - cu_seqlens (total_nnz + 1,)
+  - max_seqlen_in_batch: int
+- Note that this requires using flash attention with a causal mask.
+
+### Step 2.5: Add tests
+- Add a test to compare this version and the huggingface version
+- Follow the existing infrastructure and add tests to tests/models/hf
+
+### Step 3: Add a function to apply tensor parallelism
+- Please follow
+  - https://pytorch.org/docs/stable/distributed.tensor.parallel.html
+  - https://pytorch.org/tutorials/intermediate/TP_tutorial.html
+- General comments
+  - Tensor parallelism in native PyTorch is NOT auto-parallelism. It works by specifying, via configs, how model parameters and inputs/outputs are resharded. These configs are then registered as hooks that perform input/output resharding before/after the model forward.
+
+### Step 4: Add a function to apply data parallelism
+- Please use FSDP2 APIs
+- See the demo here: https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py#L413
+
+### Step 5: Add a function to apply pipeline parallelism
+- Comes in PyTorch 2.4
+- Currently only in alpha in the nightly version
+- Check torchtitan for more details
+
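For Step 2, a rough standalone sketch of what "packed inputs" means in practice: padded `(batch, seqlen)` tensors are flattened into one long token stream plus cumulative sequence lengths. The helper below is only an illustration of the idea, not verl's actual implementation, and its exact shapes and conventions are assumptions.

```python
import torch

def pack_inputs(input_ids: torch.Tensor, attention_mask: torch.Tensor):
    """Flatten padded (batch, seqlen) inputs into a packed token stream (illustrative sketch)."""
    seqlens = attention_mask.sum(dim=-1)                  # valid length of each sequence
    packed_input_ids = input_ids[attention_mask.bool()]   # all non-pad tokens, concatenated
    cu_seqlens = torch.nn.functional.pad(                 # cumulative lengths with a leading 0
        torch.cumsum(seqlens, dim=0), (1, 0)
    ).to(torch.int32)
    max_seqlen_in_batch = int(seqlens.max())
    return packed_input_ids, cu_seqlens, max_seqlen_in_batch

ids = torch.tensor([[5, 6, 7, 0], [8, 9, 0, 0]])
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
print(pack_inputs(ids, mask))  # (tensor([5, 6, 7, 8, 9]), tensor([0, 3, 5], dtype=torch.int32), 3)
```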
verl/models/__init__.py
ADDED
@@ -0,0 +1,13 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
verl/models/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (176 Bytes)
verl/models/__pycache__/registry.cpython-311.pyc
ADDED
Binary file (2.13 kB)
verl/models/llama/__init__.py
ADDED
@@ -0,0 +1,13 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
verl/models/llama/megatron/__init__.py
ADDED
@@ -0,0 +1,34 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .modeling_llama_megatron import (
+    ParallelLlamaForCausalLM,
+    # rmpad with megatron
+    ParallelLlamaForCausalLMRmPad,
+    # rmpad with megatron and pipeline parallelism
+    ParallelLlamaForCausalLMRmPadPP,
+    ParallelLlamaForValueRmPad,
+    ParallelLlamaForValueRmPadPP,
+    # original model with megatron
+    ParallelLlamaModel,
+)
+
+__all__ = [
+    "ParallelLlamaForCausalLM",
+    "ParallelLlamaForCausalLMRmPad",
+    "ParallelLlamaForCausalLMRmPadPP",
+    "ParallelLlamaForValueRmPad",
+    "ParallelLlamaForValueRmPadPP",
+    "ParallelLlamaModel",
+]
verl/models/llama/megatron/checkpoint_utils/__init__.py
ADDED
@@ -0,0 +1,13 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
verl/models/llama/megatron/checkpoint_utils/llama_loader.py
ADDED
|
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import time
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.distributed as dist
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _megatron_calc_layer_map(config):
|
| 22 |
+
"""Calculate the mapping of global layer_idx to local layer_idx
|
| 23 |
+
Returns:
|
| 24 |
+
layer_map (Dict: int -> tuple(int, int, int)):
|
| 25 |
+
mapping from the global layer index to
|
| 26 |
+
a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)
|
| 27 |
+
"""
|
| 28 |
+
from megatron.core import mpu
|
| 29 |
+
|
| 30 |
+
print(f"get megatron data parallel size: {mpu.get_data_parallel_world_size()}")
|
| 31 |
+
|
| 32 |
+
pp_size = mpu.get_pipeline_model_parallel_world_size()
|
| 33 |
+
virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
|
| 34 |
+
|
| 35 |
+
layer_map = dict()
|
| 36 |
+
num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
|
| 37 |
+
assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers
|
| 38 |
+
|
| 39 |
+
for pp_rank_idx in range(pp_size):
|
| 40 |
+
for virtual_pp_rank_idx in range(virtual_pp_size):
|
| 41 |
+
layer_offset = virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model
|
| 42 |
+
for layer_idx in range(num_layers_per_model):
|
| 43 |
+
layer_map[layer_offset + layer_idx] = (
|
| 44 |
+
pp_rank_idx,
|
| 45 |
+
virtual_pp_rank_idx,
|
| 46 |
+
layer_idx,
|
| 47 |
+
)
|
| 48 |
+
return layer_map
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def load_state_dict_to_megatron_llama(state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False):
|
| 52 |
+
"""Load merged state_dict to sharded Megatron module in training."""
|
| 53 |
+
from megatron.core import DistributedDataParallel as LocalDDP
|
| 54 |
+
from megatron.core import mpu
|
| 55 |
+
from megatron.core.transformer.module import Float16Module
|
| 56 |
+
from torch.nn.parallel import DistributedDataParallel as torchDDP
|
| 57 |
+
|
| 58 |
+
from verl.utils.megatron_utils import print_rank_0, unwrap_model
|
| 59 |
+
|
| 60 |
+
start_time = time.time()
|
| 61 |
+
|
| 62 |
+
def _get_gpt_model(model):
|
| 63 |
+
return model
|
| 64 |
+
|
| 65 |
+
def fetch_params(module):
|
| 66 |
+
for param in module.parameters():
|
| 67 |
+
torch.distributed.fetch(param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group())
|
| 68 |
+
|
| 69 |
+
dp_rank = mpu.get_data_parallel_rank()
|
| 70 |
+
pp_rank = mpu.get_pipeline_model_parallel_rank()
|
| 71 |
+
pp_size = mpu.get_pipeline_model_parallel_world_size()
|
| 72 |
+
virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
|
| 73 |
+
mp_group = mpu.get_model_parallel_group()
|
| 74 |
+
|
| 75 |
+
if torch.distributed.get_rank() == 0:
|
| 76 |
+
assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0"
|
| 77 |
+
assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0"
|
| 78 |
+
assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0"
|
| 79 |
+
|
| 80 |
+
if not isinstance(wrapped_models, (list, tuple)):
|
| 81 |
+
wrapped_models = list(wrapped_models)
|
| 82 |
+
|
| 83 |
+
assert len(wrapped_models) == virtual_pp_size
|
| 84 |
+
num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
|
| 85 |
+
assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, f"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size {virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}"
|
| 86 |
+
|
| 87 |
+
models = [None] * len(wrapped_models)
|
| 88 |
+
|
| 89 |
+
for i, wrapped_model in enumerate(wrapped_models):
|
| 90 |
+
models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))
|
| 91 |
+
gpt_model_module = _get_gpt_model(models[i])
|
| 92 |
+
assert len(gpt_model_module.model.layers) == num_layers_per_model
|
| 93 |
+
|
| 94 |
+
def _fetch_tensor(tensor, name) -> torch.Tensor:
|
| 95 |
+
"""fetch tensor"""
|
| 96 |
+
nonlocal state_dict
|
| 97 |
+
if tensor is not None:
|
| 98 |
+
tensor.data.copy_(state_dict[name])
|
| 99 |
+
|
| 100 |
+
def _fetch_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:
|
| 101 |
+
"""fetch tensor in tp shards"""
|
| 102 |
+
nonlocal state_dict
|
| 103 |
+
tp_rank = mpu.get_tensor_model_parallel_rank()
|
| 104 |
+
tp_size = mpu.get_tensor_model_parallel_world_size()
|
| 105 |
+
if name in state_dict:
|
| 106 |
+
full_weight = state_dict[name]
|
| 107 |
+
|
| 108 |
+
if mutate_func is not None:
|
| 109 |
+
full_weight = mutate_func(full_weight)
|
| 110 |
+
tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)
|
| 111 |
+
if tensor is not None:
|
| 112 |
+
tensor.data.copy_(tensor_chunk[tp_rank])
|
| 113 |
+
else:
|
| 114 |
+
print(f"tp_shard tensor:[{name}] not in state_dict, skip loading")
|
| 115 |
+
|
| 116 |
+
def _fetch_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:
|
| 117 |
+
"""fetch tensor in tp shards"""
|
| 118 |
+
nonlocal state_dict
|
| 119 |
+
tp_rank = mpu.get_tensor_model_parallel_rank()
|
| 120 |
+
tp_size = mpu.get_tensor_model_parallel_world_size()
|
| 121 |
+
if name in state_dict:
|
| 122 |
+
full_weight = state_dict[name]
|
| 123 |
+
|
| 124 |
+
if mutate_func is not None:
|
| 125 |
+
full_weight = mutate_func(full_weight)
|
| 126 |
+
tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)
|
| 127 |
+
if tensor is not None:
|
| 128 |
+
tensor.data.copy_(tensor_chunk[tp_rank])
|
| 129 |
+
else:
|
| 130 |
+
print(f"tp_shard tensor:[{name}] not in state_dict, skip loading")
|
| 131 |
+
|
| 132 |
+
def _fetch_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor:
|
| 133 |
+
"""fetch gate_up tensor in tp shards"""
|
| 134 |
+
nonlocal state_dict
|
| 135 |
+
nonlocal mp_group
|
| 136 |
+
tp_rank = mpu.get_tensor_model_parallel_rank()
|
| 137 |
+
tp_size = mpu.get_tensor_model_parallel_world_size()
|
| 138 |
+
if gate_name in state_dict and up_name in state_dict:
|
| 139 |
+
gate_weight = state_dict[gate_name]
|
| 140 |
+
up_weight = state_dict[up_name]
|
| 141 |
+
new_gate_up_weight = torch.empty(config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device())
|
| 142 |
+
for i in range(tp_size):
|
| 143 |
+
intermediate_size_tp = config.intermediate_size // tp_size
|
| 144 |
+
gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]
|
| 145 |
+
up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]
|
| 146 |
+
new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_(torch.cat([gate_weight_tp, up_weight_tp], dim=0))
|
| 147 |
+
|
| 148 |
+
tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0)
|
| 149 |
+
if tensor is not None:
|
| 150 |
+
tensor.data.copy_(tensor_chunk[tp_rank])
|
| 151 |
+
else:
|
| 152 |
+
print(f"tp_shard tensor:[{gate_name}, {up_name}] not in state_dict, skip loading")
|
| 153 |
+
|
| 154 |
+
def _fetch_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name) -> torch.Tensor:
|
| 155 |
+
"""fetch tensor in tp shards across mp_group"""
|
| 156 |
+
nonlocal state_dict
|
| 157 |
+
nonlocal mp_group
|
| 158 |
+
tp_rank = mpu.get_tensor_model_parallel_rank()
|
| 159 |
+
tp_size = mpu.get_tensor_model_parallel_world_size()
|
| 160 |
+
assert q_name in state_dict and k_name in state_dict and v_name in state_dict
|
| 161 |
+
full_weight_q = state_dict[q_name]
|
| 162 |
+
full_weight_k = state_dict[k_name]
|
| 163 |
+
full_weight_v = state_dict[v_name]
|
| 164 |
+
|
| 165 |
+
hidden_size_per_head = config.hidden_size // config.num_attention_heads
|
| 166 |
+
|
| 167 |
+
if config.num_key_value_heads >= tp_size:
|
| 168 |
+
q_size_tp = config.hidden_size // tp_size
|
| 169 |
+
kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size
|
| 170 |
+
total_size = q_size_tp + 2 * kv_size_tp
|
| 171 |
+
new_weight_qkv = torch.empty(total_size * tp_size, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device())
|
| 172 |
+
for i in range(tp_size):
|
| 173 |
+
q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]
|
| 174 |
+
k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp]
|
| 175 |
+
v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp]
|
| 176 |
+
new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0))
|
| 177 |
+
|
| 178 |
+
else:
|
| 179 |
+
q_size_tp = config.hidden_size // tp_size
|
| 180 |
+
kv_size_tp = hidden_size_per_head
|
| 181 |
+
total_size = q_size_tp + 2 * kv_size_tp
|
| 182 |
+
new_weight_qkv = torch.empty(total_size * tp_size, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device())
|
| 183 |
+
for i in range(tp_size):
|
| 184 |
+
q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]
|
| 185 |
+
start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head
|
| 186 |
+
end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head
|
| 187 |
+
k_part = full_weight_k[start_idx:end_idx]
|
| 188 |
+
v_part = full_weight_v[start_idx:end_idx]
|
| 189 |
+
new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0))
|
| 190 |
+
|
| 191 |
+
tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0)
|
| 192 |
+
if tensor is not None:
|
| 193 |
+
tensor.data.copy_(tensor_chunk[tp_rank])
|
| 194 |
+
|
| 195 |
+
# Embeddings
|
| 196 |
+
# -------------------
|
| 197 |
+
print_rank_0("loading embeddings...")
|
| 198 |
+
gpt_model_module = _get_gpt_model(models[0])
|
| 199 |
+
embed_tokens_weight = None
|
| 200 |
+
if pp_rank == 0:
|
| 201 |
+
embed_tokens_weight = gpt_model_module.model.embed_tokens.weight
|
| 202 |
+
_fetch_tp_shard_tensor_vocab(embed_tokens_weight, "model.embed_tokens.weight")
|
| 203 |
+
|
| 204 |
+
# Transformer layers
|
| 205 |
+
# -------------------
|
| 206 |
+
layer_map = _megatron_calc_layer_map(config)
|
| 207 |
+
|
| 208 |
+
pp_rank = mpu.get_pipeline_model_parallel_rank()
|
| 209 |
+
pp_size = mpu.get_pipeline_model_parallel_world_size()
|
| 210 |
+
num_layer_per_pp = config.num_hidden_layers // pp_size
|
| 211 |
+
vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size()
|
| 212 |
+
|
| 213 |
+
layer_list = []
|
| 214 |
+
if vpp_size is not None:
|
| 215 |
+
for vpp_rank in range(vpp_size):
|
| 216 |
+
num_layer_vpp_chunk = num_layer_per_pp // vpp_size
|
| 217 |
+
num_layer_this_model = num_layer_vpp_chunk
|
| 218 |
+
offset = vpp_rank * (config.num_hidden_layers // mpu.get_virtual_pipeline_model_parallel_world_size()) + (mpu.get_pipeline_model_parallel_rank() * num_layer_vpp_chunk)
|
| 219 |
+
layer_list.extend(list(range(offset, offset + num_layer_this_model)))
|
| 220 |
+
else:
|
| 221 |
+
num_layer_this_model = num_layer_per_pp
|
| 222 |
+
offset = pp_rank * num_layer_per_pp
|
| 223 |
+
layer_list.extend(list(range(offset, offset + num_layer_this_model)))
|
| 224 |
+
|
| 225 |
+
for layer in layer_list:
|
| 226 |
+
print_rank_0(f"loading layer #{layer}...")
|
| 227 |
+
layer_name = f"model.layers.{layer}"
|
| 228 |
+
dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer]
|
| 229 |
+
|
| 230 |
+
gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank])
|
| 231 |
+
sync_layer = gpt_model_module.model.layers[dst_layer_idx]
|
| 232 |
+
|
| 233 |
+
_fetch_tensor(
|
| 234 |
+
sync_layer.input_layernorm.weight if dst_pp_rank == pp_rank else None,
|
| 235 |
+
f"{layer_name}.input_layernorm.weight",
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
_fetch_tp_shard_tensor_qkv(
|
| 239 |
+
sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else None,
|
| 240 |
+
f"{layer_name}.self_attn.q_proj.weight",
|
| 241 |
+
f"{layer_name}.self_attn.k_proj.weight",
|
| 242 |
+
f"{layer_name}.self_attn.v_proj.weight",
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
_fetch_tp_shard_tensor(
|
| 246 |
+
sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None,
|
| 247 |
+
f"{layer_name}.self_attn.o_proj.weight",
|
| 248 |
+
chunk_dim=1,
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
_fetch_tensor(
|
| 252 |
+
sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None,
|
| 253 |
+
f"{layer_name}.post_attention_layernorm.weight",
|
| 254 |
+
)
|
| 255 |
+
|
| 256 |
+
_fetch_tp_shard_tensor_gate_up(
|
| 257 |
+
sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None,
|
| 258 |
+
f"{layer_name}.mlp.gate_proj.weight",
|
| 259 |
+
f"{layer_name}.mlp.up_proj.weight",
|
| 260 |
+
)
|
| 261 |
+
|
| 262 |
+
_fetch_tp_shard_tensor(
|
| 263 |
+
sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None,
|
| 264 |
+
f"{layer_name}.mlp.down_proj.weight",
|
| 265 |
+
chunk_dim=1,
|
| 266 |
+
)
|
| 267 |
+
# Final Layernorm
|
| 268 |
+
# -------------------
|
| 269 |
+
print_rank_0("loading final layernorm...")
|
| 270 |
+
gpt_model_module = _get_gpt_model(models[-1])
|
| 271 |
+
_fetch_tensor(
|
| 272 |
+
getattr(gpt_model_module.model.norm, "weight", None),
|
| 273 |
+
"model.norm.weight",
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
print_rank_0("loading lm_head...")
|
| 277 |
+
if pp_rank + 1 == pp_size:
|
| 278 |
+
lm_head_weight = gpt_model_module.lm_head.weight
|
| 279 |
+
|
| 280 |
+
if is_value_model:
|
| 281 |
+
if "lm_head.weight" in state_dict and state_dict["lm_head.weight"].shape[0] == 1:
|
| 282 |
+
_fetch_tensor(lm_head_weight, "lm_head.weight")
|
| 283 |
+
print_rank_0("load lm_head weight")
|
| 284 |
+
elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1:
|
| 285 |
+
_fetch_tensor(lm_head_weight, "reward_head.weight")
|
| 286 |
+
print_rank_0("load lm_head from value_head weight")
|
| 287 |
+
else:
|
| 288 |
+
_fetch_tensor(None, "lm_head.weight")
|
| 289 |
+
print_rank_0("fail to match lm_head in value_model")
|
| 290 |
+
else:
|
| 291 |
+
_fetch_tp_shard_tensor(lm_head_weight, "lm_head.weight")
|
| 292 |
+
|
| 293 |
+
dist.barrier()
|
| 294 |
+
torch.cuda.empty_cache()
|
| 295 |
+
print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s")
|
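A minimal standalone sketch of the qkv packing layout assumed by `_fetch_tp_shard_tensor_qkv` above, with made-up sizes (hidden_size=8, 4 attention heads, 4 KV heads, tp_size=2); it is illustrative only and is not one of the committed files:

import torch

hidden_size, num_heads, num_kv_heads, tp_size = 8, 4, 4, 2   # toy sizes, assumptions only
head_dim = hidden_size // num_heads
q_size_tp = hidden_size // tp_size                # rows of Q per TP rank
kv_size_tp = head_dim * num_kv_heads // tp_size   # rows of K (and of V) per TP rank
total_size = q_size_tp + 2 * kv_size_tp           # rows per TP rank in the fused weight

q = torch.randn(hidden_size, hidden_size)
k = torch.randn(head_dim * num_kv_heads, hidden_size)
v = torch.randn(head_dim * num_kv_heads, hidden_size)

# Pack [q_i, k_i, v_i] for every TP rank i, mirroring the loop in _fetch_tp_shard_tensor_qkv.
fused = torch.empty(total_size * tp_size, hidden_size)
for i in range(tp_size):
    fused[i * total_size : (i + 1) * total_size] = torch.cat(
        [
            q[i * q_size_tp : (i + 1) * q_size_tp],
            k[i * kv_size_tp : (i + 1) * kv_size_tp],
            v[i * kv_size_tp : (i + 1) * kv_size_tp],
        ],
        dim=0,
    )

# TP rank i copies torch.chunk(fused, tp_size, dim=0)[i]: its own Q rows followed by its K and V rows.
shard1 = torch.chunk(fused, tp_size, dim=0)[1]
assert torch.equal(shard1[:q_size_tp], q[q_size_tp : 2 * q_size_tp])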
verl/models/llama/megatron/checkpoint_utils/llama_loader_depracated.py
ADDED
|
@@ -0,0 +1,425 @@
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import time
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.distributed as dist
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _megatron_calc_layer_map(config):
|
| 22 |
+
"""Calculate the mapping of global layer_idx to local layer_idx
|
| 23 |
+
Returns:
|
| 24 |
+
layer_map (Dict: int -> tuple(int, int, int)):
|
| 25 |
+
mapping from the global layer index to
|
| 26 |
+
a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)
|
| 27 |
+
"""
|
| 28 |
+
from megatron.core import mpu
|
| 29 |
+
|
| 30 |
+
print(f"get megatron data parallel size: {mpu.get_data_parallel_world_size()}")
|
| 31 |
+
|
| 32 |
+
pp_size = mpu.get_pipeline_model_parallel_world_size()
|
| 33 |
+
virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
|
| 34 |
+
|
| 35 |
+
layer_map = dict()
|
| 36 |
+
num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
|
| 37 |
+
assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers
|
| 38 |
+
|
| 39 |
+
for pp_rank_idx in range(pp_size):
|
| 40 |
+
for virtual_pp_rank_idx in range(virtual_pp_size):
|
| 41 |
+
layer_offset = virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model
|
| 42 |
+
for layer_idx in range(num_layers_per_model):
|
| 43 |
+
layer_map[layer_offset + layer_idx] = (
|
| 44 |
+
pp_rank_idx,
|
| 45 |
+
virtual_pp_rank_idx,
|
| 46 |
+
layer_idx,
|
| 47 |
+
)
|
| 48 |
+
return layer_map
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def load_state_dict_to_megatron_llama(state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False):
|
| 52 |
+
"""Load merged state_dict to sharded Megatron module in training."""
|
| 53 |
+
from megatron.core import DistributedDataParallel as LocalDDP
|
| 54 |
+
from megatron.core import mpu
|
| 55 |
+
from megatron.core.transformer.module import Float16Module
|
| 56 |
+
from torch.nn.parallel import DistributedDataParallel as torchDDP
|
| 57 |
+
|
| 58 |
+
from verl.utils.megatron_utils import print_rank_0, unwrap_model
|
| 59 |
+
|
| 60 |
+
start_time = time.time()
|
| 61 |
+
|
| 62 |
+
def _get_gpt_model(model):
|
| 63 |
+
return model
|
| 64 |
+
|
| 65 |
+
def broadcast_params(module):
|
| 66 |
+
for param in module.parameters():
|
| 67 |
+
torch.distributed.broadcast(param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group())
|
| 68 |
+
|
| 69 |
+
dp_rank = mpu.get_data_parallel_rank()
|
| 70 |
+
pp_rank = mpu.get_pipeline_model_parallel_rank()
|
| 71 |
+
pp_size = mpu.get_pipeline_model_parallel_world_size()
|
| 72 |
+
virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
|
| 73 |
+
mp_group = mpu.get_model_parallel_group()
|
| 74 |
+
|
| 75 |
+
if torch.distributed.get_rank() == 0:
|
| 76 |
+
assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0"
|
| 77 |
+
assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0"
|
| 78 |
+
assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0"
|
| 79 |
+
|
| 80 |
+
if not isinstance(wrapped_models, (list, tuple)):
|
| 81 |
+
wrapped_models = [wrapped_models]
|
| 82 |
+
|
| 83 |
+
assert len(wrapped_models) == virtual_pp_size
|
| 84 |
+
num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
|
| 85 |
+
assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, f"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size {virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}"
|
| 86 |
+
|
| 87 |
+
models = [None] * len(wrapped_models)
|
| 88 |
+
|
| 89 |
+
for i, wrapped_model in enumerate(wrapped_models):
|
| 90 |
+
models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))
|
| 91 |
+
gpt_model_module = _get_gpt_model(models[i])
|
| 92 |
+
assert len(gpt_model_module.model.layers) == num_layers_per_model
|
| 93 |
+
|
| 94 |
+
def _broadcast_tensor(tensor, name) -> torch.Tensor:
|
| 95 |
+
"""broadcast tensor from rank0 across mp_group"""
|
| 96 |
+
nonlocal state_dict
|
| 97 |
+
nonlocal mp_group
|
| 98 |
+
if torch.distributed.get_rank() == 0:
|
| 99 |
+
if name in state_dict:
|
| 100 |
+
weight = state_dict[name]
|
| 101 |
+
tensor_shape = weight.shape
|
| 102 |
+
else:
|
| 103 |
+
tensor_shape = None
|
| 104 |
+
else:
|
| 105 |
+
weight = None
|
| 106 |
+
tensor_shape = None
|
| 107 |
+
|
| 108 |
+
obj_list = [tensor_shape]
|
| 109 |
+
dist.broadcast_object_list(obj_list, src=0, group=mp_group)
|
| 110 |
+
tensor_shape = obj_list[0]
|
| 111 |
+
|
| 112 |
+
if tensor_shape is None:
|
| 113 |
+
# all or none ranks in the mp_group should reach here
|
| 114 |
+
print_rank_0(f"tensor:[{name}] not in state_dict, skip load")
|
| 115 |
+
return
|
| 116 |
+
|
| 117 |
+
if tensor is None:
|
| 118 |
+
tensor = torch.empty(
|
| 119 |
+
tensor_shape,
|
| 120 |
+
dtype=params_dtype,
|
| 121 |
+
device=torch.cuda.current_device(),
|
| 122 |
+
requires_grad=False,
|
| 123 |
+
)
|
| 124 |
+
if torch.distributed.get_rank() == 0:
|
| 125 |
+
tensor.data.copy_(weight)
|
| 126 |
+
dist.broadcast(tensor, src=0, group=mp_group)
|
| 127 |
+
|
| 128 |
+
def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:
|
| 129 |
+
"""broadcast tensor in tp shards across mp_group"""
|
| 130 |
+
nonlocal state_dict
|
| 131 |
+
nonlocal mp_group
|
| 132 |
+
tp_rank = mpu.get_tensor_model_parallel_rank()
|
| 133 |
+
tp_size = mpu.get_tensor_model_parallel_world_size()
|
| 134 |
+
|
| 135 |
+
if torch.distributed.get_rank() == 0:
|
| 136 |
+
if name in state_dict:
|
| 137 |
+
full_weight = state_dict[name]
|
| 138 |
+
|
| 139 |
+
if mutate_func is not None:
|
| 140 |
+
full_weight = mutate_func(full_weight)
|
| 141 |
+
tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)
|
| 142 |
+
chunk_shape = tensor_chunk[0].shape
|
| 143 |
+
else:
|
| 144 |
+
chunk_shape = None
|
| 145 |
+
else:
|
| 146 |
+
chunk_shape = None
|
| 147 |
+
|
| 148 |
+
obj_list = [chunk_shape]
|
| 149 |
+
dist.broadcast_object_list(obj_list, src=0, group=mp_group)
|
| 150 |
+
chunk_shape = obj_list[0]
|
| 151 |
+
if chunk_shape is None:
|
| 152 |
+
# all or none ranks in the mp_group should reach here
|
| 153 |
+
print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading")
|
| 154 |
+
return
|
| 155 |
+
|
| 156 |
+
if tensor is None:
|
| 157 |
+
sync_tensor = torch.empty(
|
| 158 |
+
chunk_shape,
|
| 159 |
+
dtype=params_dtype,
|
| 160 |
+
device=torch.cuda.current_device(),
|
| 161 |
+
requires_grad=False,
|
| 162 |
+
)
|
| 163 |
+
else:
|
| 164 |
+
assert tensor.shape == chunk_shape, f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}"
|
| 165 |
+
sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)
|
| 166 |
+
|
| 167 |
+
for i in range(tp_size):
|
| 168 |
+
if torch.distributed.get_rank() == 0:
|
| 169 |
+
sync_tensor.data.copy_(tensor_chunk[i])
|
| 170 |
+
dist.broadcast(sync_tensor, src=0, group=mp_group)
|
| 171 |
+
if (i == tp_rank) and (tensor is not None):
|
| 172 |
+
tensor.data.copy_(sync_tensor)
|
| 173 |
+
|
| 174 |
+
def _broadcast_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:
|
| 175 |
+
"""broadcast tensor in tp shards across mp_group"""
|
| 176 |
+
nonlocal state_dict
|
| 177 |
+
nonlocal mp_group
|
| 178 |
+
tp_rank = mpu.get_tensor_model_parallel_rank()
|
| 179 |
+
tp_size = mpu.get_tensor_model_parallel_world_size()
|
| 180 |
+
|
| 181 |
+
if torch.distributed.get_rank() == 0:
|
| 182 |
+
if name in state_dict:
|
| 183 |
+
full_weight = state_dict[name]
|
| 184 |
+
if mutate_func is not None:
|
| 185 |
+
full_weight = mutate_func(full_weight)
|
| 186 |
+
tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)
|
| 187 |
+
chunk_shape = tensor_chunk[0].shape
|
| 188 |
+
else:
|
| 189 |
+
chunk_shape = None
|
| 190 |
+
else:
|
| 191 |
+
chunk_shape = None
|
| 192 |
+
|
| 193 |
+
obj_list = [chunk_shape]
|
| 194 |
+
dist.broadcast_object_list(obj_list, src=0, group=mp_group)
|
| 195 |
+
chunk_shape = obj_list[0]
|
| 196 |
+
if chunk_shape is None:
|
| 197 |
+
# all or none ranks in the mp_group should reach here
|
| 198 |
+
print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading")
|
| 199 |
+
return
|
| 200 |
+
|
| 201 |
+
if tensor is None:
|
| 202 |
+
sync_tensor = torch.empty(
|
| 203 |
+
chunk_shape,
|
| 204 |
+
dtype=params_dtype,
|
| 205 |
+
device=torch.cuda.current_device(),
|
| 206 |
+
requires_grad=False,
|
| 207 |
+
)
|
| 208 |
+
else:
|
| 209 |
+
assert tensor.shape == chunk_shape, f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}"
|
| 210 |
+
sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)
|
| 211 |
+
|
| 212 |
+
for i in range(tp_size):
|
| 213 |
+
if torch.distributed.get_rank() == 0:
|
| 214 |
+
sync_tensor.data.copy_(tensor_chunk[i])
|
| 215 |
+
dist.broadcast(sync_tensor, src=0, group=mp_group)
|
| 216 |
+
if (i == tp_rank) and (tensor is not None):
|
| 217 |
+
tensor.data.copy_(sync_tensor)
|
| 218 |
+
|
| 219 |
+
def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor:
|
| 220 |
+
"""broadcast tensor in tp shards across mp_group"""
|
| 221 |
+
nonlocal state_dict
|
| 222 |
+
nonlocal mp_group
|
| 223 |
+
tp_rank = mpu.get_tensor_model_parallel_rank()
|
| 224 |
+
tp_size = mpu.get_tensor_model_parallel_world_size()
|
| 225 |
+
|
| 226 |
+
if torch.distributed.get_rank() == 0:
|
| 227 |
+
gate_weight = state_dict[gate_name]
|
| 228 |
+
up_weight = state_dict[up_name]
|
| 229 |
+
new_gate_up_weight = torch.empty(config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device())
|
| 230 |
+
for i in range(tp_size):
|
| 231 |
+
intermediate_size_tp = config.intermediate_size // tp_size
|
| 232 |
+
gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]
|
| 233 |
+
up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]
|
| 234 |
+
new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_(torch.cat([gate_weight_tp, up_weight_tp], dim=0))
|
| 235 |
+
|
| 236 |
+
tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0)
|
| 237 |
+
chunk_shape = tensor_chunk[0].shape
|
| 238 |
+
else:
|
| 239 |
+
chunk_shape = None
|
| 240 |
+
|
| 241 |
+
obj_list = [chunk_shape]
|
| 242 |
+
dist.broadcast_object_list(obj_list, src=0, group=mp_group)
|
| 243 |
+
chunk_shape = obj_list[0]
|
| 244 |
+
if chunk_shape is None:
|
| 245 |
+
# all or none ranks in the mp_group should reach here
|
| 246 |
+
print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not in state_dict, skip loading")
|
| 247 |
+
return
|
| 248 |
+
|
| 249 |
+
if tensor is None:
|
| 250 |
+
sync_tensor = torch.empty(
|
| 251 |
+
chunk_shape,
|
| 252 |
+
dtype=params_dtype,
|
| 253 |
+
device=torch.cuda.current_device(),
|
| 254 |
+
requires_grad=False,
|
| 255 |
+
)
|
| 256 |
+
else:
|
| 257 |
+
assert tensor.shape == chunk_shape, f"rank #{torch.distributed.get_rank() == 0:} tensor {gate_name, up_name} shape {tensor.shape} != {chunk_shape}"
|
| 258 |
+
sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)
|
| 259 |
+
|
| 260 |
+
for i in range(tp_size):
|
| 261 |
+
if torch.distributed.get_rank() == 0:
|
| 262 |
+
sync_tensor.data.copy_(tensor_chunk[i])
|
| 263 |
+
dist.broadcast(sync_tensor, src=0, group=mp_group)
|
| 264 |
+
if (i == tp_rank) and (tensor is not None):
|
| 265 |
+
tensor.data.copy_(sync_tensor)
|
| 266 |
+
|
| 267 |
+
def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name) -> torch.Tensor:
|
| 268 |
+
"""broadcast tensor in tp shards across mp_group"""
|
| 269 |
+
nonlocal state_dict
|
| 270 |
+
nonlocal mp_group
|
| 271 |
+
tp_rank = mpu.get_tensor_model_parallel_rank()
|
| 272 |
+
tp_size = mpu.get_tensor_model_parallel_world_size()
|
| 273 |
+
|
| 274 |
+
if torch.distributed.get_rank() == 0:
|
| 275 |
+
assert q_name in state_dict and k_name in state_dict and v_name in state_dict
|
| 276 |
+
full_weight_q = state_dict[q_name]
|
| 277 |
+
full_weight_k = state_dict[k_name]
|
| 278 |
+
full_weight_v = state_dict[v_name]
|
| 279 |
+
|
| 280 |
+
hidden_size_per_head = config.hidden_size // config.num_attention_heads
|
| 281 |
+
|
| 282 |
+
if config.num_key_value_heads >= tp_size:
|
| 283 |
+
q_size_tp = config.hidden_size // tp_size
|
| 284 |
+
kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size
|
| 285 |
+
total_size = q_size_tp + 2 * kv_size_tp
|
| 286 |
+
new_weight_qkv = torch.empty(total_size * tp_size, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device())
|
| 287 |
+
for i in range(tp_size):
|
| 288 |
+
q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]
|
| 289 |
+
k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp]
|
| 290 |
+
v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp]
|
| 291 |
+
new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0))
|
| 292 |
+
|
| 293 |
+
else:
|
| 294 |
+
q_size_tp = config.hidden_size // tp_size
|
| 295 |
+
kv_size_tp = hidden_size_per_head
|
| 296 |
+
total_size = q_size_tp + 2 * kv_size_tp
|
| 297 |
+
new_weight_qkv = torch.empty(total_size * tp_size, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device())
|
| 298 |
+
for i in range(tp_size):
|
| 299 |
+
q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]
|
| 300 |
+
start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head
|
| 301 |
+
end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head
|
| 302 |
+
k_part = full_weight_k[start_idx:end_idx]
|
| 303 |
+
v_part = full_weight_v[start_idx:end_idx]
|
| 304 |
+
new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0))
|
| 305 |
+
|
| 306 |
+
tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0)
|
| 307 |
+
chunk_shape = tensor_chunk[0].shape
|
| 308 |
+
else:
|
| 309 |
+
chunk_shape = None
|
| 310 |
+
|
| 311 |
+
obj_list = [chunk_shape]
|
| 312 |
+
dist.broadcast_object_list(obj_list, src=0, group=mp_group)
|
| 313 |
+
chunk_shape = obj_list[0]
|
| 314 |
+
if chunk_shape is None:
|
| 315 |
+
# all or none ranks in the mp_group should reach here
|
| 316 |
+
print_rank_0(f"tp_shard tensor:[{q_name, k_name, v_name}] not in state_dict, skip loading")
|
| 317 |
+
return
|
| 318 |
+
|
| 319 |
+
if tensor is None:
|
| 320 |
+
sync_tensor = torch.empty(
|
| 321 |
+
chunk_shape,
|
| 322 |
+
dtype=params_dtype,
|
| 323 |
+
device=torch.cuda.current_device(),
|
| 324 |
+
requires_grad=False,
|
| 325 |
+
)
|
| 326 |
+
else:
|
| 327 |
+
assert tensor.shape == chunk_shape, f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}"
|
| 328 |
+
sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)
|
| 329 |
+
|
| 330 |
+
for i in range(tp_size):
|
| 331 |
+
if torch.distributed.get_rank() == 0:
|
| 332 |
+
sync_tensor.data.copy_(tensor_chunk[i])
|
| 333 |
+
dist.broadcast(sync_tensor, src=0, group=mp_group)
|
| 334 |
+
if (i == tp_rank) and (tensor is not None):
|
| 335 |
+
tensor.data.copy_(sync_tensor)
|
| 336 |
+
|
| 337 |
+
if dp_rank == 0:
|
| 338 |
+
# Embeddings
|
| 339 |
+
# -------------------
|
| 340 |
+
print_rank_0("loading embeddings...")
|
| 341 |
+
gpt_model_module = _get_gpt_model(models[0])
|
| 342 |
+
embed_tokens_weight = None
|
| 343 |
+
if pp_rank == 0:
|
| 344 |
+
embed_tokens_weight = gpt_model_module.model.embed_tokens.weight
|
| 345 |
+
_broadcast_tp_shard_tensor_vocab(embed_tokens_weight, "model.embed_tokens.weight")
|
| 346 |
+
|
| 347 |
+
# Transformer layers
|
| 348 |
+
# -------------------
|
| 349 |
+
layer_map = _megatron_calc_layer_map(config)
|
| 350 |
+
|
| 351 |
+
for layer in range(config.num_hidden_layers):
|
| 352 |
+
print_rank_0(f"loading layer #{layer}...")
|
| 353 |
+
layer_name = f"model.layers.{layer}"
|
| 354 |
+
dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer]
|
| 355 |
+
|
| 356 |
+
gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank])
|
| 357 |
+
sync_layer = gpt_model_module.model.layers[dst_layer_idx]
|
| 358 |
+
|
| 359 |
+
_broadcast_tensor(
|
| 360 |
+
sync_layer.input_layernorm.weight if dst_pp_rank == pp_rank else None,
|
| 361 |
+
f"{layer_name}.input_layernorm.weight",
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
_broadcast_tp_shard_tensor_qkv(
|
| 365 |
+
sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else None,
|
| 366 |
+
f"{layer_name}.self_attn.q_proj.weight",
|
| 367 |
+
f"{layer_name}.self_attn.k_proj.weight",
|
| 368 |
+
f"{layer_name}.self_attn.v_proj.weight",
|
| 369 |
+
)
|
| 370 |
+
|
| 371 |
+
_broadcast_tp_shard_tensor(
|
| 372 |
+
sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None,
|
| 373 |
+
f"{layer_name}.self_attn.o_proj.weight",
|
| 374 |
+
chunk_dim=1,
|
| 375 |
+
)
|
| 376 |
+
|
| 377 |
+
_broadcast_tensor(
|
| 378 |
+
sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None,
|
| 379 |
+
f"{layer_name}.post_attention_layernorm.weight",
|
| 380 |
+
)
|
| 381 |
+
|
| 382 |
+
_broadcast_tp_shard_tensor_gate_up(
|
| 383 |
+
sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None,
|
| 384 |
+
f"{layer_name}.mlp.gate_proj.weight",
|
| 385 |
+
f"{layer_name}.mlp.up_proj.weight",
|
| 386 |
+
)
|
| 387 |
+
|
| 388 |
+
_broadcast_tp_shard_tensor(
|
| 389 |
+
sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None,
|
| 390 |
+
f"{layer_name}.mlp.down_proj.weight",
|
| 391 |
+
chunk_dim=1,
|
| 392 |
+
)
|
| 393 |
+
# Final Layernorm
|
| 394 |
+
# -------------------
|
| 395 |
+
print_rank_0("loading final layernorm...")
|
| 396 |
+
gpt_model_module = _get_gpt_model(models[-1])
|
| 397 |
+
_broadcast_tensor(
|
| 398 |
+
getattr(gpt_model_module.model.norm, "weight", None),
|
| 399 |
+
"model.norm.weight",
|
| 400 |
+
)
|
| 401 |
+
|
| 402 |
+
print_rank_0("loading lm_head...")
|
| 403 |
+
lm_head_weight = None
|
| 404 |
+
if pp_rank + 1 == pp_size:
|
| 405 |
+
lm_head_weight = gpt_model_module.lm_head.weight
|
| 406 |
+
|
| 407 |
+
if is_value_model:
|
| 408 |
+
if "lm_head.weight" in state_dict and state_dict["lm_head.weight"].shape[0] == 1:
|
| 409 |
+
_broadcast_tensor(lm_head_weight, "lm_head.weight")
|
| 410 |
+
print_rank_0("load lm_head weight")
|
| 411 |
+
elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1:
|
| 412 |
+
_broadcast_tensor(lm_head_weight, "reward_head.weight")
|
| 413 |
+
print_rank_0("load lm_head from value_head weight")
|
| 414 |
+
else:
|
| 415 |
+
_broadcast_tensor(None, "lm_head.weight")
|
| 416 |
+
print_rank_0("fail to match lm_head in value_model")
|
| 417 |
+
else:
|
| 418 |
+
_broadcast_tp_shard_tensor(lm_head_weight, "lm_head.weight")
|
| 419 |
+
dist.barrier()
|
| 420 |
+
# Broadcast weights inside data parallel groups
|
| 421 |
+
for wrapped_model in wrapped_models:
|
| 422 |
+
broadcast_params(wrapped_model)
|
| 423 |
+
|
| 424 |
+
torch.cuda.empty_cache()
|
| 425 |
+
print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s")
|
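A small standalone sketch of the gate/up fusion layout these loaders build (and that the saver below splits back out), with made-up sizes (hidden_size=4, intermediate_size=8, tp_size=2); illustrative only, not one of the committed files:

import torch

hidden_size, intermediate_size, tp_size = 4, 8, 2   # toy sizes, assumptions only
inter_tp = intermediate_size // tp_size

gate = torch.arange(intermediate_size * hidden_size, dtype=torch.float32).reshape(intermediate_size, hidden_size)
up = gate + 100.0

# Fuse: for each TP rank i, stack [gate_i; up_i] so that one chunk feeds exactly one rank.
fused = torch.empty(intermediate_size * 2, hidden_size)
for i in range(tp_size):
    fused[inter_tp * 2 * i : inter_tp * 2 * (i + 1)] = torch.cat(
        [gate[i * inter_tp : (i + 1) * inter_tp], up[i * inter_tp : (i + 1) * inter_tp]], dim=0
    )

# Each TP rank takes torch.chunk(fused, tp_size, dim=0)[rank]; the first half of its chunk
# is the gate shard and the second half the up shard.
shard0 = torch.chunk(fused, tp_size, dim=0)[0]
assert torch.equal(shard0[:inter_tp], gate[:inter_tp])
assert torch.equal(shard0[inter_tp:], up[:inter_tp])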
verl/models/llama/megatron/checkpoint_utils/llama_saver.py
ADDED
|
@@ -0,0 +1,430 @@
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import time
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.distributed as dist
|
| 19 |
+
from megatron.core import mpu
|
| 20 |
+
from megatron.core.distributed import DistributedDataParallel as LocalDDP
|
| 21 |
+
from megatron.core.transformer.module import Float16Module
|
| 22 |
+
from torch.nn.parallel import DistributedDataParallel as torchDDP
|
| 23 |
+
|
| 24 |
+
from verl.utils.megatron_utils import print_rank_0, unwrap_model
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0):
|
| 28 |
+
"""given TP,DP,PP rank to get the global rank."""
|
| 29 |
+
|
| 30 |
+
tp_size = mpu.get_tensor_model_parallel_world_size()
|
| 31 |
+
dp_size = mpu.get_data_parallel_world_size()
|
| 32 |
+
pp_size = mpu.get_pipeline_model_parallel_world_size()
|
| 33 |
+
assert tp_size * dp_size * pp_size == torch.distributed.get_world_size(), f"{tp_size} x {dp_size} x {pp_size} != {torch.distributed.get_world_size()}"
|
| 34 |
+
# We only support TP-DP-PP grouping, for correctness when resharding
|
| 35 |
+
return (pp_rank * dp_size + dp_rank) * tp_size + tp_rank
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _megatron_calc_layer_map(config):
|
| 39 |
+
"""Calculate the mapping of global layer_idx to local layer_idx
|
| 40 |
+
Returns:
|
| 41 |
+
layer_map (Dict: int -> tuple(int, int, int)):
|
| 42 |
+
mapping from the global layer index to
|
| 43 |
+
a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)
|
| 44 |
+
"""
|
| 45 |
+
from megatron.core import mpu
|
| 46 |
+
|
| 47 |
+
pp_size = mpu.get_pipeline_model_parallel_world_size()
|
| 48 |
+
virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
|
| 49 |
+
|
| 50 |
+
layer_map = dict()
|
| 51 |
+
num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
|
| 52 |
+
assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers
|
| 53 |
+
|
| 54 |
+
for pp_rank_idx in range(pp_size):
|
| 55 |
+
for virtual_pp_rank_idx in range(virtual_pp_size):
|
| 56 |
+
layer_offset = virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model
|
| 57 |
+
for layer_idx in range(num_layers_per_model):
|
| 58 |
+
layer_map[layer_offset + layer_idx] = (
|
| 59 |
+
pp_rank_idx,
|
| 60 |
+
virtual_pp_rank_idx,
|
| 61 |
+
layer_idx,
|
| 62 |
+
)
|
| 63 |
+
return layer_map
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def merge_megatron_ckpt_llama(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False):
|
| 67 |
+
"""Merge sharded parameters of a Megatron module into a merged checkpoint.
|
| 68 |
+
|
| 69 |
+
Args:
|
| 70 |
+
wrapped_models (list of megatron.core.distributed.DistributedDataParallel):
|
| 71 |
+
The local DDP wrapped megatron modules.
|
| 72 |
+
config (LlamaConfig):
|
| 73 |
+
HF config for model
|
| 74 |
+
dtype: dtype the merged parameters are cast to before returning
|
| 75 |
+
is_value_model: whether the model is a value (reward) model
|
| 76 |
+
tie_word_embeddings: unused for llama; kept only to match the qwen2 interface
|
| 77 |
+
Returns:
|
| 78 |
+
state_dict (dict):
|
| 79 |
+
The merged state_dict in rank 0, and an empty dictionary in other ranks.
|
| 80 |
+
"""
|
| 81 |
+
start_time = time.time()
|
| 82 |
+
|
| 83 |
+
def _get_gpt_model(model):
|
| 84 |
+
return model
|
| 85 |
+
|
| 86 |
+
dp_rank = mpu.get_data_parallel_rank()
|
| 87 |
+
pp_size = mpu.get_pipeline_model_parallel_world_size()
|
| 88 |
+
pp_rank = mpu.get_pipeline_model_parallel_rank()
|
| 89 |
+
virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
|
| 90 |
+
mp_group = mpu.get_model_parallel_group()
|
| 91 |
+
|
| 92 |
+
if dist.get_rank() == 0:
|
| 93 |
+
assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0"
|
| 94 |
+
assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0"
|
| 95 |
+
assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0"
|
| 96 |
+
|
| 97 |
+
if not isinstance(wrapped_models, (list, tuple)):
|
| 98 |
+
wrapped_models = [wrapped_models]
|
| 99 |
+
|
| 100 |
+
assert len(wrapped_models) == virtual_pp_size
|
| 101 |
+
num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
|
| 102 |
+
assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers
|
| 103 |
+
|
| 104 |
+
models = [None] * len(wrapped_models)
|
| 105 |
+
|
| 106 |
+
for i, wrapped_model in enumerate(wrapped_models):
|
| 107 |
+
models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))
|
| 108 |
+
assert len(models[i].model.layers) == num_layers_per_model, "len model layers {} not equal to num_layers_per_model {}".format(len(models[i].model.layers), num_layers_per_model)
|
| 109 |
+
|
| 110 |
+
state_dict = dict()
|
| 111 |
+
|
| 112 |
+
def _get_cpu_tensor(tensor: torch.Tensor):
|
| 113 |
+
if tensor is None:
|
| 114 |
+
return None
|
| 115 |
+
if tensor.device == torch.device("cpu"):
|
| 116 |
+
return tensor.detach().clone()
|
| 117 |
+
return tensor.detach().cpu()
|
| 118 |
+
|
| 119 |
+
def _broadcast_tensor(tensor, name, src_pp_rank) -> torch.Tensor:
|
| 120 |
+
"""broadcast tensor across mp_group"""
|
| 121 |
+
nonlocal state_dict
|
| 122 |
+
nonlocal mp_group
|
| 123 |
+
src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)
|
| 124 |
+
|
| 125 |
+
if torch.distributed.get_rank() == src_rank:
|
| 126 |
+
if tensor is None:
|
| 127 |
+
weight = None
|
| 128 |
+
tensor_shape = None
|
| 129 |
+
else:
|
| 130 |
+
weight = tensor
|
| 131 |
+
tensor_shape = weight.shape
|
| 132 |
+
else:
|
| 133 |
+
weight = None
|
| 134 |
+
tensor_shape = None
|
| 135 |
+
|
| 136 |
+
obj_list = [tensor_shape]
|
| 137 |
+
dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
|
| 138 |
+
tensor_shape = obj_list[0]
|
| 139 |
+
|
| 140 |
+
if tensor_shape is None:
|
| 141 |
+
# all or none ranks in the mp_group should reach here
|
| 142 |
+
print_rank_0(f"tensor:[{name}] not exist, skip collect")
|
| 143 |
+
return
|
| 144 |
+
|
| 145 |
+
if weight is None:
|
| 146 |
+
weight = torch.empty(
|
| 147 |
+
tensor_shape,
|
| 148 |
+
dtype=dtype,
|
| 149 |
+
device=torch.cuda.current_device(),
|
| 150 |
+
requires_grad=False,
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
dist.broadcast(weight, src=src_rank, group=mp_group)
|
| 154 |
+
|
| 155 |
+
if torch.distributed.get_rank() == 0:
|
| 156 |
+
state_dict[name] = _get_cpu_tensor(weight)
|
| 157 |
+
|
| 158 |
+
def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_func=None) -> torch.Tensor:
|
| 159 |
+
"""broadcast tensor in tp shards across mp_group"""
|
| 160 |
+
nonlocal state_dict
|
| 161 |
+
nonlocal mp_group
|
| 162 |
+
tp_size = mpu.get_tensor_model_parallel_world_size()
|
| 163 |
+
src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)
|
| 164 |
+
|
| 165 |
+
chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None
|
| 166 |
+
|
| 167 |
+
obj_list = [chunk_shape]
|
| 168 |
+
dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
|
| 169 |
+
chunk_shape = obj_list[0]
|
| 170 |
+
if chunk_shape is None:
|
| 171 |
+
# all or none ranks in the mp_group should reach here
|
| 172 |
+
print_rank_0(f"tp_shard tensor:[{name}] not exist, skip collecting")
|
| 173 |
+
return
|
| 174 |
+
|
| 175 |
+
buffer_tensor = torch.empty(
|
| 176 |
+
chunk_shape,
|
| 177 |
+
dtype=dtype,
|
| 178 |
+
device=torch.cuda.current_device(),
|
| 179 |
+
requires_grad=False,
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
chunk_tensors = [None] * tp_size
|
| 183 |
+
|
| 184 |
+
for i in range(tp_size):
|
| 185 |
+
cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank)
|
| 186 |
+
sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor
|
| 187 |
+
dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)
|
| 188 |
+
|
| 189 |
+
if torch.distributed.get_rank() == 0:
|
| 190 |
+
chunk_tensors[i] = _get_cpu_tensor(sync_tensor)
|
| 191 |
+
|
| 192 |
+
if torch.distributed.get_rank() == 0:
|
| 193 |
+
full_tensor = torch.concat(chunk_tensors, dim=concat_dim)
|
| 194 |
+
if mutate_func is not None:
|
| 195 |
+
full_tensor = mutate_func(full_tensor)
|
| 196 |
+
state_dict[name] = full_tensor
|
| 197 |
+
|
| 198 |
+
def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) -> torch.Tensor:
|
| 199 |
+
"""broadcast tensor in tp shards across mp_group"""
|
| 200 |
+
nonlocal state_dict
|
| 201 |
+
nonlocal mp_group
|
| 202 |
+
tp_size = mpu.get_tensor_model_parallel_world_size()
|
| 203 |
+
src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)
|
| 204 |
+
|
| 205 |
+
chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None
|
| 206 |
+
|
| 207 |
+
obj_list = [chunk_shape]
|
| 208 |
+
dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
|
| 209 |
+
chunk_shape = obj_list[0]
|
| 210 |
+
if chunk_shape is None:
|
| 211 |
+
# all or none ranks in the mp_group should reach here
|
| 212 |
+
print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not exist, skip collecting")
|
| 213 |
+
return
|
| 214 |
+
|
| 215 |
+
buffer_tensor = torch.empty(
|
| 216 |
+
chunk_shape,
|
| 217 |
+
dtype=dtype,
|
| 218 |
+
device=torch.cuda.current_device(),
|
| 219 |
+
requires_grad=False,
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
chunk_tensors = [None] * tp_size
|
| 223 |
+
|
| 224 |
+
for i in range(tp_size):
|
| 225 |
+
cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank)
|
| 226 |
+
sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor
|
| 227 |
+
dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)
|
| 228 |
+
|
| 229 |
+
if torch.distributed.get_rank() == 0:
|
| 230 |
+
chunk_tensors[i] = _get_cpu_tensor(sync_tensor)
|
| 231 |
+
|
| 232 |
+
if torch.distributed.get_rank() == 0:
|
| 233 |
+
full_tensor = torch.concat(chunk_tensors, dim=0)
|
| 234 |
+
intermediate_size_tp = config.intermediate_size // tp_size
|
| 235 |
+
gate_weight_list = []
|
| 236 |
+
up_weight_list = []
|
| 237 |
+
for i in range(tp_size):
|
| 238 |
+
gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)]
|
| 239 |
+
gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp]
|
| 240 |
+
up_weight_tp = gate_up_weight_tp[intermediate_size_tp:]
|
| 241 |
+
gate_weight_list.append(gate_weight_tp)
|
| 242 |
+
up_weight_list.append(up_weight_tp)
|
| 243 |
+
|
| 244 |
+
state_dict[gate_name] = torch.cat(gate_weight_list, dim=0)
|
| 245 |
+
state_dict[up_name] = torch.cat(up_weight_list, dim=0)
|
| 246 |
+
|
| 247 |
+
def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank):
|
| 248 |
+
"""broadcast tensor in tp shards across mp_group"""
|
| 249 |
+
nonlocal state_dict
|
| 250 |
+
nonlocal mp_group
|
| 251 |
+
tp_size = mpu.get_tensor_model_parallel_world_size()
|
| 252 |
+
src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)
|
| 253 |
+
|
| 254 |
+
chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None
|
| 255 |
+
|
| 256 |
+
obj_list = [chunk_shape]
|
| 257 |
+
dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
|
| 258 |
+
chunk_shape = obj_list[0]
|
| 259 |
+
if chunk_shape is None:
|
| 260 |
+
# all or none ranks in the mp_group should reach here
|
| 261 |
+
print_rank_0(f"tp_shard tensor:[{q_name}] not exist, skip collecting")
|
| 262 |
+
return
|
| 263 |
+
|
| 264 |
+
buffer_tensor = torch.empty(
|
| 265 |
+
chunk_shape,
|
| 266 |
+
dtype=dtype,
|
| 267 |
+
device=torch.cuda.current_device(),
|
| 268 |
+
requires_grad=False,
|
| 269 |
+
)
|
| 270 |
+
|
| 271 |
+
chunk_tensors = [None] * tp_size
|
| 272 |
+
|
| 273 |
+
for i in range(tp_size):
|
| 274 |
+
cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank)
|
| 275 |
+
sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor
|
| 276 |
+
dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)
|
| 277 |
+
|
| 278 |
+
if torch.distributed.get_rank() == 0:
|
| 279 |
+
chunk_tensors[i] = _get_cpu_tensor(sync_tensor)
|
| 280 |
+
|
| 281 |
+
if torch.distributed.get_rank() == 0:
|
| 282 |
+
full_tensor = torch.concat(chunk_tensors, dim=0)
|
| 283 |
+
q_weight_list = []
|
| 284 |
+
k_weight_list = []
|
| 285 |
+
v_weight_list = []
|
| 286 |
+
hidden_size_per_head = config.hidden_size // config.num_attention_heads
|
| 287 |
+
|
| 288 |
+
if config.num_key_value_heads >= tp_size:
|
| 289 |
+
q_size_tp = config.hidden_size // tp_size
|
| 290 |
+
kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size
|
| 291 |
+
total_size = q_size_tp + 2 * kv_size_tp
|
| 292 |
+
for i in range(tp_size):
|
| 293 |
+
qkv_part = full_tensor[i * total_size : (i + 1) * total_size]
|
| 294 |
+
q_part = qkv_part[:q_size_tp]
|
| 295 |
+
k_part = qkv_part[q_size_tp : q_size_tp + kv_size_tp]
|
| 296 |
+
v_part = qkv_part[q_size_tp + kv_size_tp : total_size]
|
| 297 |
+
q_weight_list.append(q_part)
|
| 298 |
+
k_weight_list.append(k_part)
|
| 299 |
+
v_weight_list.append(v_part)
|
| 300 |
+
else:
|
| 301 |
+
q_size_tp = config.hidden_size // tp_size
|
| 302 |
+
kv_size_tp = hidden_size_per_head
|
| 303 |
+
total_size = q_size_tp + 2 * kv_size_tp
|
| 304 |
+
for i in range(tp_size):
|
| 305 |
+
qkv_part = full_tensor[i * total_size : (i + 1) * total_size]
|
| 306 |
+
q_part = qkv_part[:q_size_tp]
|
| 307 |
+
k_part = qkv_part[q_size_tp : q_size_tp + kv_size_tp]
|
| 308 |
+
v_part = qkv_part[q_size_tp + kv_size_tp : total_size]
|
| 309 |
+
q_weight_list.append(q_part)
|
| 310 |
+
if i * config.num_key_value_heads % tp_size == 0:
|
| 311 |
+
k_weight_list.append(k_part)
|
| 312 |
+
v_weight_list.append(v_part)
|
| 313 |
+
|
| 314 |
+
state_dict[q_name] = torch.cat(q_weight_list, dim=0)
|
| 315 |
+
state_dict[k_name] = torch.cat(k_weight_list, dim=0)
|
| 316 |
+
state_dict[v_name] = torch.cat(v_weight_list, dim=0)
|
| 317 |
+
|
| 318 |
+
# empty cache before collecting weights
|
| 319 |
+
torch.cuda.empty_cache()
|
| 320 |
+
# Embeddings
|
| 321 |
+
# -------------------
|
| 322 |
+
if dp_rank == 0:
|
| 323 |
+
# Embeddings
|
| 324 |
+
# -------------------
|
| 325 |
+
print_rank_0("collecting embeddings...")
|
| 326 |
+
gpt_model_module = _get_gpt_model(models[0])
|
| 327 |
+
_broadcast_tp_shard_tensor(
|
| 328 |
+
gpt_model_module.model.embed_tokens.weight if pp_rank == 0 else None,
|
| 329 |
+
"model.embed_tokens.weight",
|
| 330 |
+
src_pp_rank=0,
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
# Transformer layers
|
| 334 |
+
# -------------------
|
| 335 |
+
layer_map = _megatron_calc_layer_map(config)
|
| 336 |
+
for layer in range(config.num_hidden_layers):
|
| 337 |
+
print_rank_0(f"collecting layer #{layer}...")
|
| 338 |
+
layer_name = f"model.layers.{layer}"
|
| 339 |
+
src_pp_rank, src_virtual_pp_rank, src_layer_idx = layer_map[layer]
|
| 340 |
+
|
| 341 |
+
gpt_model_module = _get_gpt_model(models[src_virtual_pp_rank])
|
| 342 |
+
sync_layer = gpt_model_module.model.layers[src_layer_idx]
|
| 343 |
+
|
| 344 |
+
_broadcast_tensor(
|
| 345 |
+
sync_layer.input_layernorm.weight,
|
| 346 |
+
f"{layer_name}.input_layernorm.weight",
|
| 347 |
+
src_pp_rank=src_pp_rank,
|
| 348 |
+
)
|
| 349 |
+
|
| 350 |
+
_broadcast_tp_shard_tensor_qkv(
|
| 351 |
+
sync_layer.self_attn.qkv_proj.weight,
|
| 352 |
+
f"{layer_name}.self_attn.q_proj.weight",
|
| 353 |
+
f"{layer_name}.self_attn.k_proj.weight",
|
| 354 |
+
f"{layer_name}.self_attn.v_proj.weight",
|
| 355 |
+
src_pp_rank=src_pp_rank,
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
_broadcast_tp_shard_tensor(
|
| 359 |
+
sync_layer.self_attn.o_proj.weight,
|
| 360 |
+
f"{layer_name}.self_attn.o_proj.weight",
|
| 361 |
+
concat_dim=1,
|
| 362 |
+
src_pp_rank=src_pp_rank,
|
| 363 |
+
)
|
| 364 |
+
|
| 365 |
+
_broadcast_tensor(
|
| 366 |
+
sync_layer.post_attention_layernorm.weight,
|
| 367 |
+
f"{layer_name}.post_attention_layernorm.weight",
|
| 368 |
+
src_pp_rank=src_pp_rank,
|
| 369 |
+
)
|
| 370 |
+
|
| 371 |
+
_broadcast_tp_shard_tensor_gate_up(
|
| 372 |
+
sync_layer.mlp.gate_up_proj.weight,
|
| 373 |
+
f"{layer_name}.mlp.gate_proj.weight",
|
| 374 |
+
f"{layer_name}.mlp.up_proj.weight",
|
| 375 |
+
src_pp_rank=src_pp_rank,
|
| 376 |
+
)
|
| 377 |
+
|
| 378 |
+
_broadcast_tp_shard_tensor(
|
| 379 |
+
sync_layer.mlp.down_proj.weight,
|
| 380 |
+
f"{layer_name}.mlp.down_proj.weight",
|
| 381 |
+
concat_dim=1,
|
| 382 |
+
src_pp_rank=src_pp_rank,
|
| 383 |
+
)
|
| 384 |
+
|
| 385 |
+
# Final Layernorm
|
| 386 |
+
# -------------------
|
| 387 |
+
print_rank_0("collecting final layernorm...")
|
| 388 |
+
gpt_model_module = _get_gpt_model(models[-1])
|
| 389 |
+
_broadcast_tensor(
|
| 390 |
+
getattr(gpt_model_module.model.norm, "weight", None),
|
| 391 |
+
"model.norm.weight",
|
| 392 |
+
src_pp_rank=pp_size - 1,
|
| 393 |
+
)
|
| 394 |
+
|
| 395 |
+
print_rank_0("collecting lm_head...")
|
| 396 |
+
|
| 397 |
+
if is_value_model:
|
| 398 |
+
if pp_rank == pp_size - 1:
|
| 399 |
+
print(f"gpt_model_module.lm_head.weight: {gpt_model_module.lm_head.weight.shape}")
|
| 400 |
+
_broadcast_tensor(
|
| 401 |
+
gpt_model_module.lm_head.weight if pp_rank == pp_size - 1 else None,
|
| 402 |
+
"lm_head.weight",
|
| 403 |
+
src_pp_rank=pp_size - 1,
|
| 404 |
+
)
|
| 405 |
+
_broadcast_tensor(
|
| 406 |
+
gpt_model_module.reward_head.weight if pp_rank == pp_size - 1 and getattr(gpt_model_module, "reward_head", None) is not None else None,
|
| 407 |
+
"reward_head.weight",
|
| 408 |
+
src_pp_rank=pp_size - 1,
|
| 409 |
+
)
|
| 410 |
+
|
| 411 |
+
else:
|
| 412 |
+
_broadcast_tp_shard_tensor(
|
| 413 |
+
getattr(gpt_model_module.lm_head, "weight", None) if pp_rank == pp_size - 1 else None,
|
| 414 |
+
"lm_head.weight",
|
| 415 |
+
src_pp_rank=pp_size - 1,
|
| 416 |
+
)
|
| 417 |
+
|
| 418 |
+
dist.barrier()
|
| 419 |
+
|
| 420 |
+
torch.cuda.empty_cache()
|
| 421 |
+
if torch.distributed.get_rank() == 0:
|
| 422 |
+
if dtype not in [torch.float16, torch.bfloat16, torch.float32]:
|
| 423 |
+
print(f'Unknown/unsupported dtype to save: {dtype}')
|
| 424 |
+
exit(1)
|
| 425 |
+
for k, v in state_dict.items():
|
| 426 |
+
if dtype != v.dtype:
|
| 427 |
+
state_dict[k] = v.to(dtype)
|
| 428 |
+
|
| 429 |
+
print_rank_0(f"merge megatron ckpt done, time elapsed {time.time() - start_time}s")
|
| 430 |
+
return state_dict
|
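A pure-Python sketch of the TP-DP-PP rank layout assumed by `_megatron_calc_global_rank` above, with toy sizes (tp_size=dp_size=pp_size=2); illustrative only, not one of the committed files:

tp_size, dp_size, pp_size = 2, 2, 2   # toy sizes, assumptions only

def global_rank(tp_rank, dp_rank, pp_rank):
    # Same arithmetic as _megatron_calc_global_rank: PP is the slowest-varying axis, TP the fastest.
    return (pp_rank * dp_size + dp_rank) * tp_size + tp_rank

# The eight (tp, dp, pp) triples map one-to-one onto global ranks 0..7.
ranks = {global_rank(t, d, p) for p in range(pp_size) for d in range(dp_size) for t in range(tp_size)}
assert ranks == set(range(tp_size * dp_size * pp_size))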
verl/models/llama/megatron/layers/__init__.py
ADDED
|
@@ -0,0 +1,25 @@
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from .parallel_attention import ParallelLlamaAttention
|
| 16 |
+
from .parallel_decoder import ParallelLlamaDecoderLayer, ParallelLlamaDecoderLayerRmPad
|
| 17 |
+
from .parallel_linear import (
|
| 18 |
+
LinearForLastLayer,
|
| 19 |
+
MergedColumnParallelLinear,
|
| 20 |
+
QKVParallelLinear,
|
| 21 |
+
)
|
| 22 |
+
from .parallel_mlp import ParallelLlamaMLP
|
| 23 |
+
from .parallel_rmsnorm import ParallelLlamaRMSNorm
|
| 24 |
+
|
| 25 |
+
__all__ = ["LinearForLastLayer", "MergedColumnParallelLinear", "QKVParallelLinear", "ParallelLlamaAttention", "ParallelLlamaDecoderLayer", "ParallelLlamaDecoderLayerRmPad", "ParallelLlamaMLP", "ParallelLlamaRMSNorm"]
|
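Usage sketch (illustrative only, not one of the committed files): the re-exports above let downstream code import the parallel layers from the package root, for example:

from verl.models.llama.megatron.layers import (
    ParallelLlamaAttention,
    ParallelLlamaMLP,
    ParallelLlamaRMSNorm,
)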
verl/models/llama/megatron/layers/parallel_attention.py
ADDED
|
@@ -0,0 +1,425 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from flash_attn.layers.rotary import apply_rotary_emb
+from megatron.core import ModelParallelConfig, tensor_parallel
+from megatron.core import parallel_state as mpu
+from torch import nn
+from transformers import LlamaConfig
+from transformers.utils import is_flash_attn_2_available
+
+from verl.models.llama.megatron.layers.parallel_linear import QKVParallelLinear
+from verl.utils.megatron import tensor_parallel as tp_utils
+
+
+class LlamaRotaryEmbedding(nn.Module):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+        super().__init__()
+
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+        # Build here to make `torch.jit.trace` work.
+        self._set_cos_sin_cache(seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype())
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        self.max_seq_len_cached = seq_len
+        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
+
+        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+
+    def forward(self, x, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        if seq_len > self.max_seq_len_cached:
+            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+
+        return (
+            self.cos_cached[:seq_len].to(dtype=x.dtype),
+            self.sin_cached[:seq_len].to(dtype=x.dtype),
+        )
+
+
+class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
+    """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
+
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
+        self.scaling_factor = scaling_factor
+        super().__init__(dim, max_position_embeddings, base, device)
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        self.max_seq_len_cached = seq_len
+        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
+        t = t / self.scaling_factor
+
+        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+
+
+class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
+    """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
+
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
+        self.scaling_factor = scaling_factor
+        super().__init__(dim, max_position_embeddings, base, device)
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        self.max_seq_len_cached = seq_len
+
+        if seq_len > self.max_position_embeddings:
+            base = self.base * ((self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)) ** (self.dim / (self.dim - 2))
+            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
+            self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
+
+        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+
+
+class LlamaLlama3ScalingRotaryEmbedding(LlamaRotaryEmbedding):
+    def __init__(self, dim, config, max_position_embeddings=2048, base=10000, device=None):
+        super().__init__(dim, max_position_embeddings, base, device)
+
+        self.factor = config.rope_scaling["factor"]  # `8` in the original implementation
+        self.high_freq_factor = config.rope_scaling["high_freq_factor"]  # `1` in the original implementation
+        self.low_freq_factor = config.rope_scaling["low_freq_factor"]  # `4` in the original implementation
+        self.old_context_len = config.rope_scaling["original_max_position_embeddings"]  # `8192` in the original implementation
+
+        low_freq_wavelen = self.old_context_len / self.low_freq_factor
+        high_freq_wavelen = self.old_context_len / self.high_freq_factor
+
+        wavelen = 2 * math.pi / self.inv_freq
+        # wavelen < high_freq_wavelen: do nothing; wavelen > low_freq_wavelen: divide by factor
+        inv_freq_llama = torch.where(wavelen > low_freq_wavelen, self.inv_freq / self.factor, self.inv_freq)
+        # otherwise: interpolate between the two, using a smooth factor
+        smooth_factor = (self.old_context_len / wavelen - self.low_freq_factor) / (self.high_freq_factor - self.low_freq_factor)
+        smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / self.factor + smooth_factor * inv_freq_llama
+        is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
+        inv_freq = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
+
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+        # Build here to make `torch.jit.trace` work.
+        self._set_cos_sin_cache(seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype())
+
+
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
+    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+class ParallelLlamaAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
+        super().__init__()
+        self.config = config
+        self.megatron_config = megatron_config
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.hidden_size // self.num_heads
+        self.num_key_value_heads = config.num_key_value_heads
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+        self.max_position_embeddings = config.max_position_embeddings
+        self.rope_theta = config.rope_theta
+
+        # assign values after tp
+        tp_size = mpu.get_tensor_model_parallel_world_size()
+        assert self.num_heads % tp_size == 0, f"num_head must be divisible by tp_size. Got num_head={self.num_heads}, tp_size={tp_size}"
+        assert self.num_key_value_heads % tp_size == 0, f"num_key_value_heads must be divisible by tp_size. Got num_key_value_heads={self.num_key_value_heads}, tp_size={tp_size}"
+
+        self.num_heads_per_tp = self.num_heads // tp_size
+        self.num_key_value_heads_per_tp = self.num_key_value_heads // tp_size
+        self.hidden_size_per_tp = self.hidden_size // tp_size
+
+        if (self.head_dim * self.num_heads) != self.hidden_size:
+            raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`: {self.num_heads}).")
+
+        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
+        row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear()
+
+        if megatron_config is not None:
+            assert column_kwargs.get("config", False), "must have ModelParallelConfig"
+            assert row_kwargs.get("config", False), "must have ModelParallelConfig"
+            tp_utils.update_kwargs_with_config(column_kwargs, megatron_config)
+            tp_utils.update_kwargs_with_config(row_kwargs, megatron_config)
+
+        # [self.q_size, self.k_size, self.v_size]
+        self.qkv_proj = QKVParallelLinear(
+            input_size=self.hidden_size,
+            num_heads=self.num_heads,
+            num_key_value_heads=self.num_key_value_heads,
+            head_dim=self.head_dim,
+            bias=config.attention_bias,
+            gather_output=False,
+            skip_bias_add=False,
+            **column_kwargs,
+        )
+
+        self.q_size = self.num_heads_per_tp * self.head_dim
+        self.k_size = self.num_key_value_heads_per_tp * self.head_dim
+        self.v_size = self.num_key_value_heads_per_tp * self.head_dim
+
+        self.o_proj = tensor_parallel.RowParallelLinear(
+            input_size=self.num_heads * self.head_dim,
+            output_size=self.hidden_size,
+            bias=config.attention_bias,
+            input_is_parallel=True,
+            skip_bias_add=False,
+            **row_kwargs,
+        )
+
+        self._init_rope()
+
+    def _init_rope(self):
+        if self.config.rope_scaling is None:
+            self.rotary_emb = LlamaRotaryEmbedding(
+                self.head_dim,
+                max_position_embeddings=self.max_position_embeddings,
+                base=self.rope_theta,
+            )
+        else:
+            rope_type_key = "type" if "type" in self.config.rope_scaling else "rope_type"
+            scaling_type = self.config.rope_scaling[rope_type_key]
+            scaling_factor = self.config.rope_scaling["factor"]
+            if scaling_type == "linear":
+                self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
+                    self.head_dim,
+                    max_position_embeddings=self.max_position_embeddings,
+                    scaling_factor=scaling_factor,
+                    base=self.rope_theta,
+                )
+            elif scaling_type == "dynamic":
+                self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
+                    self.head_dim,
+                    max_position_embeddings=self.max_position_embeddings,
+                    scaling_factor=scaling_factor,
+                    base=self.rope_theta,
+                )
+            elif scaling_type == "llama3":
+                self.rotary_emb = LlamaLlama3ScalingRotaryEmbedding(
+                    self.head_dim,
+                    self.config,
+                    max_position_embeddings=self.max_position_embeddings,
+                    base=self.rope_theta,
+                )
+            else:
+                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        bsz, q_len, _ = hidden_states.size()
+        qkv = self.qkv_proj(hidden_states)[0]
+        query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)
+
+        query_states = query_states.view(bsz, q_len, self.num_heads_per_tp, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2)
+
+        kv_seq_len = key_states.shape[-2]
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+        if attn_weights.size() != (bsz, self.num_heads_per_tp, q_len, kv_seq_len):
+            raise ValueError(f"Attention weights should be of size {(bsz, self.num_heads_per_tp, q_len, kv_seq_len)}, but is {attn_weights.size()}")
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                raise ValueError(f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}")
+            attn_weights = attn_weights + attention_mask
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        if attn_output.size() != (bsz, self.num_heads_per_tp, q_len, self.head_dim):
+            raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads_per_tp, q_len, self.head_dim)}, but is {attn_output.size()}")
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size_per_tp)
+        attn_output = self.o_proj(attn_output)[0]
+        return attn_output
+
+
+"""
+Remove padding Attention
+- Using Flash-attn 2
+- Compatible with sequence parallel
+"""
+
+
+if is_flash_attn_2_available():
+    from flash_attn import flash_attn_varlen_func
+    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+
+
+def apply_rotary_pos_emb_rmpad(q, k, cos, sin, position_ids, indices, sequence_length):
+    batch_size = position_ids.shape[0]
+
+    q = pad_input(q, indices, batch_size, sequence_length)  # (batch_size, seqlen, num_head, head_dim)
+    k = pad_input(k, indices, batch_size, sequence_length)
+    cos = cos[position_ids].unsqueeze(2)  # [bs, seq_len, 1, dim]
+    sin = sin[position_ids].unsqueeze(2)  # [bs, seq_len, 1, dim]
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+
+    q_embed = index_first_axis(rearrange(q_embed, "b s ... -> (b s) ..."), indices)
+    k_embed = index_first_axis(rearrange(k_embed, "b s ... -> (b s) ..."), indices)
+
+    return q_embed, k_embed
+
+
+# use flash-attn rotary embeddings with rmpad
+# cos/sin should be: (seq_length, rotary_dim / 2)
+def apply_rotary_pos_emb_rmpad_flash(q, k, cos, sin, cu_seqlens, max_seqlen):
+    q_embed = apply_rotary_emb(q, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
+    k_embed = apply_rotary_emb(k, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
+    return q_embed, k_embed
+
+
+class ParallelLlamaAttentionRmPad(ParallelLlamaAttention):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_ids: Optional[torch.LongTensor] = None,
+        sequence_length: int = None,
+        indices: torch.Tensor = None,
+        cu_seqlens: torch.Tensor = None,
+        max_seqlen_in_batch: int = None,
+    ):
+        total_nnz, _, _ = hidden_states.size()  # This is the total_nnz padded after sequence parallel
+
+        if self.megatron_config.sequence_parallel:
+            total_nnz = total_nnz * mpu.get_tensor_model_parallel_world_size()
+
+        qkv = self.qkv_proj(hidden_states)[0]
+        query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)  # (total_nnz, 1, hidden_size)
+
+        if self.megatron_config.sequence_parallel:
+            sequence_parallel_pad = total_nnz - cu_seqlens[-1]
+            total_nnz = cu_seqlens[-1]  # total_nnz before sp padding
+            query_states = query_states[:total_nnz]
+            key_states = key_states[:total_nnz]
+            value_states = value_states[:total_nnz]
+
+        # Flash attention requires the input to have the shape
+        # batch_size x seq_length x head_dim x hidden_dim
+        # therefore we just need to keep the original shape
+        query_states = query_states.view(total_nnz, self.num_heads_per_tp, self.head_dim)
+        key_states = key_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim)
+        value_states = value_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim)
+
+        cos, sin = self.rotary_emb(value_states, seq_len=sequence_length)
+        cos, sin = cos[:, : cos.shape[1] // 2], sin[:, : sin.shape[1] // 2]  # flash attn only needs half
+        query_states, key_states = apply_rotary_pos_emb_rmpad_flash(query_states, key_states, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen_in_batch)
+        # query_states, key_states = apply_rotary_pos_emb_rmpad(query_states, key_states, cos, sin, position_ids, indices,
+
+        # TODO: llama does not have dropout in the config??
+        # It is recommended to use dropout with FA according to the docs
+        # when training.
+        dropout_rate = 0.0  # if not self.training else self.attn_dropout
+
+        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+        # therefore the input hidden states gets silently casted in float32. Hence, we need
+        # cast them back in float16 just to be sure everything works as expected.
+        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+        # in fp32. (LlamaRMSNorm handles it correctly)
+        input_dtype = query_states.dtype
+        if input_dtype == torch.float32:
+            query_states = query_states.to(torch.float16)
+            key_states = key_states.to(torch.float16)
+            value_states = value_states.to(torch.float16)
+
+        attn_output_unpad = flash_attn_varlen_func(
+            query_states,
+            key_states,
+            value_states,
+            cu_seqlens_q=cu_seqlens,
+            cu_seqlens_k=cu_seqlens,
+            max_seqlen_q=max_seqlen_in_batch,
+            max_seqlen_k=max_seqlen_in_batch,
+            dropout_p=dropout_rate,
+            softmax_scale=None,
+            causal=True,
+        )
+
+        attn_output_unpad = attn_output_unpad.to(input_dtype)
+        attn_output_unpad = attn_output_unpad.reshape(total_nnz, 1, self.hidden_size_per_tp).contiguous()
+
+        # sequence parallel reduce_scatter is performed inside RowColumnParallel if enabled
+        # Here we need to repad
+        if self.megatron_config.sequence_parallel:
+            attn_output_unpad = F.pad(attn_output_unpad, pad=(0, 0, 0, 0, 0, sequence_parallel_pad))
+
+        attn_output_unpad = self.o_proj(attn_output_unpad)[0]
+        return attn_output_unpad
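As a reading aid (not part of the committed file above), here is a minimal standalone sketch of the rotary-embedding math that rotate_half/apply_rotary_pos_emb implement, written in plain PyTorch with made-up tensor sizes and no Megatron, tensor parallelism, or flash-attn involved.

# Standalone sketch of the rotary position embedding applied above; shapes are illustrative.
import torch

def rotate_half(x):
    # Same math as the module's rotate_half: swap halves and negate the second one.
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

bs, n_heads, seq_len, head_dim = 2, 4, 8, 16
q = torch.randn(bs, n_heads, seq_len, head_dim)
k = torch.randn(bs, n_heads, seq_len, head_dim)

# cos/sin caches as LlamaRotaryEmbedding builds them: shape (seq_len, head_dim)
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.einsum("i,j->ij", torch.arange(seq_len).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos, sin = emb.cos(), emb.sin()

position_ids = torch.arange(seq_len).unsqueeze(0).expand(bs, -1)
cos_b = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, head_dim]
sin_b = sin[position_ids].unsqueeze(1)
q_rot = (q * cos_b) + (rotate_half(q) * sin_b)
k_rot = (k * cos_b) + (rotate_half(k) * sin_b)
print(q_rot.shape, k_rot.shape)  # torch.Size([2, 4, 8, 16]) twice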
verl/models/llama/megatron/layers/parallel_decoder.py
ADDED
@@ -0,0 +1,150 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple
+
+import torch
+from megatron.core import ModelParallelConfig
+from torch import nn
+from transformers import LlamaConfig
+
+from verl.utils.megatron_utils import TransformerConfig, convert_config
+
+from .parallel_attention import ParallelLlamaAttention, ParallelLlamaAttentionRmPad
+from .parallel_mlp import ParallelLlamaMLP
+from .parallel_rmsnorm import ParallelLlamaRMSNorm
+
+
+class ParallelLlamaDecoderLayer(nn.Module):
+    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, layer_idx: int):
+        super().__init__()
+        self.config: TransformerConfig = convert_config(config, megatron_config)
+        self.layer_idx = layer_idx
+        self.hidden_size = config.hidden_size
+        self.self_attn = ParallelLlamaAttention(config=config, megatron_config=megatron_config)
+
+        self.mlp = ParallelLlamaMLP(config, megatron_config=megatron_config)
+        self.input_layernorm = ParallelLlamaRMSNorm(config, megatron_config)
+        self.post_attention_layernorm = ParallelLlamaRMSNorm(config, megatron_config)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+        """
+
+        residual = hidden_states
+
+        hidden_states = self.input_layernorm(hidden_states)
+
+        # Note: sequence parallel is hidden inside ColumnParallelLinear
+        # reduce scatter is hidden inside RowParallelLinear
+
+        # Self Attention
+        hidden_states = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+        )
+
+        # TODO: add sequence parallel operator reduce_scatter here
+
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+
+        # TODO: add sequence parallel operator all_gather here
+
+        hidden_states = self.mlp(hidden_states)
+
+        # TODO: add sequence parallel operator reduce_scatter here
+
+        hidden_states = residual + hidden_states
+
+        outputs = hidden_states
+
+        return outputs
+
+
+class ParallelLlamaDecoderLayerRmPad(nn.Module):
+    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, layer_idx: int):
+        super().__init__()
+        self.config: TransformerConfig = convert_config(config, megatron_config)
+        self.layer_idx = layer_idx
+        self.hidden_size = config.hidden_size
+        self.self_attn = ParallelLlamaAttentionRmPad(config=config, megatron_config=megatron_config)
+
+        self.mlp = ParallelLlamaMLP(config, megatron_config=megatron_config)
+        self.input_layernorm = ParallelLlamaRMSNorm(config, megatron_config)
+        self.post_attention_layernorm = ParallelLlamaRMSNorm(config, megatron_config)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_ids: Optional[torch.LongTensor] = None,
+        sequence_length: int = None,
+        indices: torch.Tensor = None,
+        cu_seqlens: int = None,
+        max_seqlen_in_batch: int = None,
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        residual = hidden_states  # (total_nnz // sp, 1, hidden_size)
+
+        hidden_states = self.input_layernorm(hidden_states)
+
+        # Self Attention
+        # (total_nnz // sp, 1, hidden_size) -> all-gather (total_nnz, 1, hidden_size)
+        # -> col + row -> reduce-scatter -> (total_nnz // sp, 1, hidden_size)
+        hidden_states = self.self_attn(
+            hidden_states=hidden_states,
+            position_ids=position_ids,
+            sequence_length=sequence_length,
+            indices=indices,
+            cu_seqlens=cu_seqlens,
+            max_seqlen_in_batch=max_seqlen_in_batch,
+        )
+
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        # shape changes same as attn
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = hidden_states
+
+        return outputs
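As a reading aid (not part of the committed file above), the toy block below mirrors only the pre-norm residual ordering of the decoder layer (norm -> attention -> add, then norm -> MLP -> add). It uses stock nn.LayerNorm and nn.MultiheadAttention as stand-ins, so it is an assumption-laden sketch of the data flow, not the Megatron-parallel implementation.

# Single-GPU sketch of the pre-norm residual pattern used by the decoder layer above.
import torch
from torch import nn

class ToyPreNormBlock(nn.Module):
    def __init__(self, hidden_size: int, n_heads: int):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(hidden_size)
        self.self_attn = nn.MultiheadAttention(hidden_size, n_heads, batch_first=True)
        self.post_attention_layernorm = nn.LayerNorm(hidden_size)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size), nn.SiLU(), nn.Linear(4 * hidden_size, hidden_size)
        )

    def forward(self, hidden_states):
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, _ = self.self_attn(hidden_states, hidden_states, hidden_states)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        return residual + hidden_states

x = torch.randn(2, 8, 64)
print(ToyPreNormBlock(64, 4)(x).shape)  # torch.Size([2, 8, 64])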
verl/models/llama/megatron/layers/parallel_linear.py
ADDED
@@ -0,0 +1,106 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+# Copyright 2023 The vLLM team.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/linear.py
+
+import torch
+from megatron.core import tensor_parallel
+
+
+class QKVParallelLinear(tensor_parallel.ColumnParallelLinear):
+    def __init__(
+        self,
+        input_size,
+        num_heads,
+        num_key_value_heads,
+        head_dim,
+        *,
+        bias=True,
+        gather_output=True,
+        skip_bias_add=False,
+        **kwargs,
+    ):
+        # Keep input parameters, and already restrict the head numbers
+        self.input_size = input_size
+        self.q_output_size = num_heads * head_dim
+        self.kv_output_size = num_key_value_heads * head_dim
+        self.head_dim = head_dim
+        self.gather_output = gather_output
+        self.skip_bias_add = skip_bias_add
+
+        input_size = self.input_size
+        output_size = (num_heads + 2 * num_key_value_heads) * self.head_dim
+
+        super().__init__(
+            input_size=input_size,
+            output_size=output_size,
+            bias=bias,
+            gather_output=gather_output,
+            skip_bias_add=skip_bias_add,
+            **kwargs,
+        )
+
+
+class MergedColumnParallelLinear(tensor_parallel.ColumnParallelLinear):
+    def __init__(
+        self,
+        input_size,
+        gate_ouput_size,
+        up_output_size,
+        *,
+        bias=True,
+        gather_output=True,
+        skip_bias_add=False,
+        **kwargs,
+    ):
+        # Keep input parameters, and already restrict the head numbers
+        self.input_size = input_size
+        self.output_size = gate_ouput_size + up_output_size
+        self.gather_output = gather_output
+        self.skip_bias_add = skip_bias_add
+
+        super().__init__(
+            input_size=self.input_size,
+            output_size=self.output_size,
+            bias=bias,
+            gather_output=gather_output,
+            skip_bias_add=skip_bias_add,
+            **kwargs,
+        )
+
+
+class LinearForLastLayer(torch.nn.Linear):
+    def __init__(
+        self,
+        input_size,
+        output_size,
+        *,
+        config,
+        bias=True,
+    ):
+        super().__init__(in_features=input_size, out_features=output_size, bias=bias)
+        self.sequence_parallel = config.sequence_parallel
+        if self.sequence_parallel:
+            self.weight.sequence_parallel = True
+
+    def forward(
+        self,
+        input_,
+        weight=None,
+        runtime_gather_output=None,
+    ):
+        logits = super().forward(input_)
+        logits = logits.float()
+        if self.sequence_parallel:
+            logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False)
+        return logits, None
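As a reading aid (not part of the committed file above), this sketch shows the size arithmetic behind the fused QKV projection with a plain nn.Linear; without tensor parallelism the per-rank sizes equal the full sizes. All numbers are made up.

# Plain-PyTorch sketch of the fused QKV sizing and split used above.
import torch
from torch import nn

hidden_size, num_heads, num_kv_heads, head_dim = 512, 8, 2, 64
q_size = num_heads * head_dim       # 512
kv_size = num_kv_heads * head_dim   # 128

# Fused projection: one matmul producing Q, K and V concatenated on the last dim.
qkv_proj = nn.Linear(hidden_size, (num_heads + 2 * num_kv_heads) * head_dim, bias=False)

x = torch.randn(2, 16, hidden_size)
q, k, v = qkv_proj(x).split([q_size, kv_size, kv_size], dim=-1)
print(q.shape, k.shape, v.shape)  # (2, 16, 512) (2, 16, 128) (2, 16, 128)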
verl/models/llama/megatron/layers/parallel_mlp.py
ADDED
@@ -0,0 +1,74 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from megatron.core import ModelParallelConfig, tensor_parallel
+from megatron.core import parallel_state as mpu
+from torch import nn
+from transformers.activations import ACT2FN
+
+from verl.models.llama.megatron.layers.parallel_linear import MergedColumnParallelLinear
+from verl.utils.megatron import tensor_parallel as tp_utils
+
+
+class ParallelLlamaMLP(nn.Module):
+    def __init__(self, config, megatron_config: ModelParallelConfig = None) -> None:
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        # The weight is only [hidden_size, intermediate_size // model_parallel_world_size]
+
+        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
+        row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear()
+
+        if megatron_config is not None:
+            assert column_kwargs.get("config", False), "must have ModelParallelConfig"
+            assert row_kwargs.get("config", False), "must have ModelParallelConfig"
+            tp_utils.update_kwargs_with_config(row_kwargs, megatron_config)
+            tp_utils.update_kwargs_with_config(column_kwargs, megatron_config)
+
+        tp_size = mpu.get_tensor_model_parallel_world_size()
+
+        self.gate_up_proj = MergedColumnParallelLinear(
+            input_size=self.hidden_size,
+            gate_ouput_size=self.intermediate_size,
+            up_output_size=self.intermediate_size,
+            bias=False,
+            gather_output=False,
+            skip_bias_add=False,
+            **column_kwargs,
+        )
+        self.gate_size = self.intermediate_size // tp_size
+
+        self.down_proj = tensor_parallel.RowParallelLinear(
+            input_size=self.intermediate_size,
+            output_size=self.hidden_size,
+            bias=False,
+            input_is_parallel=True,
+            skip_bias_add=False,
+            **row_kwargs,
+        )
+
+        self.act_fn = ACT2FN[config.hidden_act]
+
+    def forward(self, x):
+        gate_up = self.gate_up_proj(x)[0]
+        gate, up = gate_up.split(self.gate_size, dim=-1)
+        return self.down_proj(self.act_fn(gate) * up)[0]
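As a reading aid (not part of the committed file above), this single-GPU sketch reproduces the gated-MLP computation of ParallelLlamaMLP: a fused gate+up projection, a split, an activation on the gate, an elementwise product, then a down projection. Plain nn.Linear layers stand in for the Megatron-parallel linears, and SiLU is assumed as the activation, as in standard Llama configs.

# Plain-PyTorch sketch of the SwiGLU-style MLP computed above; sizes are illustrative.
import torch
import torch.nn.functional as F
from torch import nn

hidden_size, intermediate_size = 512, 1376
gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

x = torch.randn(2, 16, hidden_size)
gate, up = gate_up_proj(x).split(intermediate_size, dim=-1)
y = down_proj(F.silu(gate) * up)
print(y.shape)  # torch.Size([2, 16, 512])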
verl/models/llama/megatron/layers/parallel_rmsnorm.py
ADDED
@@ -0,0 +1,48 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numbers
+
+import torch
+from apex.normalization.fused_layer_norm import fused_rms_norm_affine
+from megatron.core import ModelParallelConfig
+from torch import nn
+from transformers import LlamaConfig
+
+from verl.utils.megatron import sequence_parallel as sp_utils
+
+
+class ParallelLlamaRMSNorm(nn.Module):
+    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
+        """
+        LlamaRMSNorm is equivalent to T5LayerNorm
+        """
+        super().__init__()
+        if isinstance(config.hidden_size, numbers.Integral):
+            normalized_shape = (config.hidden_size,)
+        self.normalized_shape = torch.Size(normalized_shape)
+        self.weight = nn.Parameter(torch.ones(self.normalized_shape))
+        self.variance_epsilon = config.rms_norm_eps
+
+        if megatron_config.sequence_parallel:
+            sp_utils.mark_parameter_as_sequence_parallel(self.weight)
+
+    def forward(self, hidden_states):
+        return fused_rms_norm_affine(
+            input=hidden_states,
+            weight=self.weight,
+            normalized_shape=self.normalized_shape,
+            eps=self.variance_epsilon,
+            memory_efficient=True,
+        )
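As a reading aid (not part of the committed file above), this reference sketch spells out the RMSNorm math that the fused apex kernel computes: scale each vector by the reciprocal root-mean-square of its last dimension, then by a learned weight. Plain PyTorch, toy sizes.

# Reference RMSNorm sketch matching the math of the fused kernel used above.
import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    variance = x.pow(2).mean(-1, keepdim=True)
    return weight * (x * torch.rsqrt(variance + eps))

hidden = torch.randn(2, 8, 64)
weight = torch.ones(64)
print(rms_norm(hidden, weight).shape)  # torch.Size([2, 8, 64])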
verl/models/llama/megatron/modeling_llama_megatron.py
ADDED
@@ -0,0 +1,662 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch LLaMA model with Megatron-style acceleration."""
+
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from megatron.core import ModelParallelConfig, mpu, tensor_parallel
+from torch import nn
+from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.models.llama.configuration_llama import LlamaConfig
+from transformers.models.llama.modeling_llama import CausalLMOutputWithPast
+
+from verl.utils.megatron import sequence_parallel as sp_utils
+from verl.utils.megatron import tensor_parallel as tp_utils
+from verl.utils.megatron_utils import TransformerConfig, convert_config
+
+from .layers import ParallelLlamaDecoderLayer, ParallelLlamaDecoderLayerRmPad, ParallelLlamaRMSNorm
+
+"""
+TODO:
+1. Add weight initialization. Here we need to be careful on TP weight init.
+2. Add sequence parallel
+3. Load checkpoint from meta LLama pretrained checkpoint
+"""
+
+
+# Copied from transformers.models.bart.modeling_bart._make_causal_mask
+def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device):
+    """
+    Make causal mask used for bi-directional self-attention.
+    """
+    bsz, tgt_len = input_ids_shape
+    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
+    mask_cond = torch.arange(mask.size(-1), device=device)
+    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+    mask = mask.to(dtype)
+    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len)
+
+
+# Copied from transformers.models.bart.modeling_bart._expand_mask
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+    """
+    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+    """
+    bsz, src_len = mask.size()
+    tgt_len = tgt_len if tgt_len is not None else src_len
+
+    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+    inverted_mask = 1.0 - expanded_mask
+
+    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+
+
+class ParallelLlamaModel(nn.Module):
+    """
+    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
+
+    Args:
+        config: LlamaConfig
+    """
+
+    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
+        super().__init__()
+        self.config: TransformerConfig = convert_config(config, megatron_config)
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()
+        if megatron_config is not None:
+            assert embedding_kwargs.get("config", False), "must have ModelParallelConfig"
+            tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config)
+        self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs)
+
+        self.layers = nn.ModuleList([ParallelLlamaDecoderLayer(config, megatron_config) for _ in range(config.num_hidden_layers)])
+        self.norm = ParallelLlamaRMSNorm(config, megatron_config)
+
+    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
+    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds):
+        # create causal mask
+        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+        combined_attention_mask = None
+        if input_shape[-1] > 1:
+            combined_attention_mask = _make_causal_mask(
+                input_shape,
+                inputs_embeds.dtype,
+                device=inputs_embeds.device,
+            )
+
+        if attention_mask is not None:
+            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device)
+            combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
+
+        return combined_attention_mask
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        """
+
+        Args:
+            input_ids: input ids. shape (batch_size, seq_length)
+            attention_mask: attention_mask. shape (batch_size, seq_length)
+            position_ids: position ids. shape (batch_size, seq_length)
+
+        Returns:
+
+        """
+        batch_size, seq_length = input_ids.shape
+        inputs_embeds = self.embed_tokens(input_ids)
+        # embed positions
+
+        attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds)
+
+        hidden_states = inputs_embeds
+
+        for idx, decoder_layer in enumerate(self.layers):
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+            )
+
+            hidden_states = layer_outputs
+
+        hidden_states = self.norm(hidden_states)
+
+        return hidden_states
+
+
+class ParallelLlamaForCausalLM(nn.Module):
+    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
+        super().__init__()
+        self.config: TransformerConfig = convert_config(config, megatron_config)
+        self.model = ParallelLlamaModel(config, megatron_config=megatron_config)
+        self.vocab_size = config.vocab_size
+
+        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
+        if megatron_config is not None:
+            assert column_kwargs.get("config", False), "must have ModelParallelConfig"
+            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
+
+        self.lm_head = tensor_parallel.ColumnParallelLinear(
+            input_size=config.hidden_size,
+            output_size=config.vocab_size,
+            bias=False,
+            gather_output=False,
+            skip_bias_add=False,
+            **column_kwargs,
+        )
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        r"""
+        Args:
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Returns:
+        ```"""
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+        )
+
+        hidden_states = outputs
+        logits = self.lm_head(hidden_states)[0]
+
+        logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits)
+
+        logits = logits.float()
+        return CausalLMOutputWithPast(
+            loss=None,
+            logits=logits,
+            past_key_values=None,
+            hidden_states=None,
+            attentions=None,
+        )
+
+
+from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+
+
+class ParallelLlamaModelRmPad(nn.Module):
+    """
+    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
+
+    Args:
+        config: LlamaConfig
+    """
+
+    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
+        super().__init__()
+        self.config: TransformerConfig = convert_config(config, megatron_config)
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()
+        self.megatron_config = megatron_config
+        if megatron_config is not None:
+            assert embedding_kwargs.get("config", False), "must have ModelParallelConfig"
+            tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config)
+        self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs)
+
+        self.layers = nn.ModuleList([ParallelLlamaDecoderLayerRmPad(config, megatron_config) for _ in range(config.num_hidden_layers)])
+        self.norm = ParallelLlamaRMSNorm(config, megatron_config)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: Optional[torch.LongTensor] = None,
+        sequence_length: int = None,
+        indices: torch.Tensor = None,
+        cu_seqlens: int = None,
+        max_seqlen_in_batch: int = None,
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        """
+
+        Args:
+            input_ids: input ids. shape (1, total_nnz)
+            position_ids: position ids. shape (batch_size, seq_length)
+
+        Returns:
+
+        """
+        inputs_embeds = self.embed_tokens(input_ids)  # (1, total_nnz) -> (1, total_nnz, hidden_size)
+
+        # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size)
+        inputs_embeds = inputs_embeds.transpose(0, 1)
+        if self.megatron_config.sequence_parallel:
+            inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds)
+
+        hidden_states = inputs_embeds
+        for idx, decoder_layer in enumerate(self.layers):
+            layer_outputs = decoder_layer(
+                hidden_states,
+                position_ids=position_ids,
+                sequence_length=sequence_length,
+                indices=indices,
+                cu_seqlens=cu_seqlens,
+                max_seqlen_in_batch=max_seqlen_in_batch,
+            )
+
+            hidden_states = layer_outputs
+
+        hidden_states = self.norm(hidden_states)
+
+        return hidden_states
+
+
+class ParallelLlamaForCausalLMRmPad(nn.Module):
+    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
+        super().__init__()
+        self.config: TransformerConfig = convert_config(config, megatron_config)
+        self.megatron_config = megatron_config
+        self.model = ParallelLlamaModelRmPad(config, megatron_config=megatron_config)
+        self.vocab_size = config.vocab_size
+        self._init_head(config)
+
+    def _init_head(self, config):
+        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
+        if self.megatron_config is not None:
+            assert column_kwargs.get("config", False), "must have ModelParallelConfig"
+            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
|
| 295 |
+
self.lm_head = tensor_parallel.ColumnParallelLinear(
|
| 296 |
+
input_size=config.hidden_size,
|
| 297 |
+
output_size=config.vocab_size,
|
| 298 |
+
bias=False,
|
| 299 |
+
gather_output=False,
|
| 300 |
+
skip_bias_add=False,
|
| 301 |
+
**column_kwargs,
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
def _forward_head(self, hidden_states):
|
| 305 |
+
# all_gather from sequence parallel region is performed inside lm_head
|
| 306 |
+
logits = self.lm_head(hidden_states)[0]
|
| 307 |
+
logits = logits.float() # (total_nnz_padded, 1, vocab_size // tp)
|
| 308 |
+
logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) # (total_nnz_padded, 1, vocab_size)
|
| 309 |
+
return logits
|
| 310 |
+
|
| 311 |
+
def forward(
|
| 312 |
+
self,
|
| 313 |
+
input_ids: torch.LongTensor = None,
|
| 314 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 315 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 316 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
| 317 |
+
r"""
|
| 318 |
+
Args:
|
| 319 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 320 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
| 321 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
| 322 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
| 323 |
+
|
| 324 |
+
Returns:
|
| 325 |
+
```"""
|
| 326 |
+
batch_size, sequence_length = input_ids.shape
|
| 327 |
+
|
| 328 |
+
# remove padding here
|
| 329 |
+
input_ids, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(input_ids.unsqueeze(dim=-1), attention_mask) # (total_nnz, 1)
|
| 330 |
+
|
| 331 |
+
# pad input_ids to multiple of tp for all tp ranks
|
| 332 |
+
# TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap
|
| 333 |
+
if self.megatron_config.sequence_parallel:
|
| 334 |
+
input_ids = sp_utils.pad_to_sequence_parallel(input_ids)
|
| 335 |
+
|
| 336 |
+
input_ids = input_ids.transpose(0, 1) # (1, total_nnz+pad)
|
| 337 |
+
|
| 338 |
+
outputs = self.model(
|
| 339 |
+
input_ids=input_ids,
|
| 340 |
+
position_ids=position_ids,
|
| 341 |
+
sequence_length=sequence_length,
|
| 342 |
+
indices=indices,
|
| 343 |
+
cu_seqlens=cu_seqlens,
|
| 344 |
+
max_seqlen_in_batch=max_seqlen_in_batch,
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
hidden_states = outputs
|
| 348 |
+
|
| 349 |
+
logits = self._forward_head(hidden_states)
|
| 350 |
+
|
| 351 |
+
# remove padding from sequence parallel
|
| 352 |
+
if self.megatron_config.sequence_parallel:
|
| 353 |
+
totol_nnz = cu_seqlens[-1]
|
| 354 |
+
logits = logits[:totol_nnz] # (total_nnz_padded)
|
| 355 |
+
|
| 356 |
+
logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension
|
| 357 |
+
# add removed padding back
|
| 358 |
+
logits = pad_input(logits, indices, batch_size, seqlen=sequence_length) # (batch_size, sequence_length, vocab_size)
|
| 359 |
+
|
| 360 |
+
return CausalLMOutputWithPast(
|
| 361 |
+
loss=None,
|
| 362 |
+
logits=logits,
|
| 363 |
+
past_key_values=None,
|
| 364 |
+
hidden_states=None,
|
| 365 |
+
attentions=None,
|
| 366 |
+
)
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
class ParallelLlamaForValueRmPad(ParallelLlamaForCausalLMRmPad):
|
| 370 |
+
def _init_head(self, config):
|
| 371 |
+
column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
|
| 372 |
+
if self.megatron_config is not None:
|
| 373 |
+
assert column_kwargs.get("config", False), "must have ModelParallelConfig"
|
| 374 |
+
tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
|
| 375 |
+
self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=1, bias=False)
|
| 376 |
+
# lm_head is effectively the same as sequence parallel
|
| 377 |
+
sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight)
|
| 378 |
+
|
| 379 |
+
def _forward_head(self, hidden_states):
|
| 380 |
+
logits = self.lm_head(hidden_states) # (total_nnz_padded // tp, 1, 1)
|
| 381 |
+
logits = logits.float()
|
| 382 |
+
if self.megatron_config.sequence_parallel:
|
| 383 |
+
logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False)
|
| 384 |
+
return logits
|
| 385 |
+
|
| 386 |
+
def forward(
|
| 387 |
+
self,
|
| 388 |
+
input_ids: torch.LongTensor = None,
|
| 389 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 390 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 391 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
| 392 |
+
output = super().forward(input_ids, attention_mask, position_ids)
|
| 393 |
+
output.logits = torch.squeeze(output.logits, dim=-1)
|
| 394 |
+
return output
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
"""
|
| 398 |
+
Support pipeline parallelism
|
| 399 |
+
"""
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
class ParallelLlamaModelRmPadPP(nn.Module):
|
| 403 |
+
"""
|
| 404 |
+
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
|
| 405 |
+
This model definition supports pipeline parallelism. To support pp and vpp,
|
| 406 |
+
- This model only contains layer in this pp stage and vpp chunk
|
| 407 |
+
- When calling get_model in Megatron, this rank will instantiate all the vpp chunks in this pp.
|
| 408 |
+
Args:
|
| 409 |
+
config: LlamaConfig
|
| 410 |
+
"""
|
| 411 |
+
|
| 412 |
+
def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, pre_process, post_process):
|
| 413 |
+
super().__init__()
|
| 414 |
+
self.config: TransformerConfig = convert_config(config, megatron_config)
|
| 415 |
+
self.padding_idx = config.pad_token_id
|
| 416 |
+
self.vocab_size = config.vocab_size
|
| 417 |
+
self.pre_process = pre_process
|
| 418 |
+
self.post_process = post_process
|
| 419 |
+
self.megatron_config = megatron_config
|
| 420 |
+
embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()
|
| 421 |
+
if megatron_config is not None:
|
| 422 |
+
assert embedding_kwargs.get("config", False), "must have ModelParallelConfig"
|
| 423 |
+
tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config)
|
| 424 |
+
if pre_process:
|
| 425 |
+
self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs)
|
| 426 |
+
else:
|
| 427 |
+
self.embed_tokens = None
|
| 428 |
+
|
| 429 |
+
pp_rank = mpu.get_pipeline_model_parallel_rank()
|
| 430 |
+
pp_size = megatron_config.pipeline_model_parallel_size
|
| 431 |
+
self.num_layer_per_pp = config.num_hidden_layers // pp_size
|
| 432 |
+
vpp_size = megatron_config.virtual_pipeline_model_parallel_size
|
| 433 |
+
vpp_rank = mpu.get_virtual_pipeline_model_parallel_rank()
|
| 434 |
+
|
| 435 |
+
if vpp_size is not None:
|
| 436 |
+
self.layers = nn.ModuleList()
|
| 437 |
+
self.num_layer_vpp_chunk = self.num_layer_per_pp // vpp_size
|
| 438 |
+
self.num_layer_this_model = self.num_layer_vpp_chunk
|
| 439 |
+
offset = vpp_rank * (config.num_hidden_layers // vpp_size) + (pp_rank * self.num_layer_vpp_chunk)
|
| 440 |
+
else:
|
| 441 |
+
self.num_layer_this_model = self.num_layer_per_pp
|
| 442 |
+
offset = pp_rank * self.num_layer_per_pp
|
| 443 |
+
|
| 444 |
+
self.layers = nn.ModuleList()
|
| 445 |
+
for i in range(self.num_layer_this_model):
|
| 446 |
+
layer = ParallelLlamaDecoderLayerRmPad(config, megatron_config, layer_idx=offset + i)
|
| 447 |
+
self.layers.add_module(f"{i}", layer)
|
| 448 |
+
|
| 449 |
+
if post_process:
|
| 450 |
+
self.norm = ParallelLlamaRMSNorm(config, megatron_config)
|
| 451 |
+
else:
|
| 452 |
+
self.norm = None
|
| 453 |
+
|
| 454 |
+
def set_input_tensor(self, input_tensor):
|
| 455 |
+
"""Set input tensor to be used instead of forward()'s input.
|
| 456 |
+
|
| 457 |
+
When doing pipeline parallelism the input from the previous
|
| 458 |
+
stage comes from communication, not from the input, so the
|
| 459 |
+
model's forward_step_func won't have it. This function is thus
|
| 460 |
+
used by internal code to bypass the input provided by the
|
| 461 |
+
forward_step_func"""
|
| 462 |
+
self.input_tensor = input_tensor
|
| 463 |
+
|
| 464 |
+
def forward(
|
| 465 |
+
self,
|
| 466 |
+
input_ids: torch.Tensor,
|
| 467 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 468 |
+
sequence_length: int = None,
|
| 469 |
+
indices: torch.Tensor = None,
|
| 470 |
+
cu_seqlens: int = None,
|
| 471 |
+
max_seqlen_in_batch: int = None,
|
| 472 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
| 473 |
+
"""
|
| 474 |
+
|
| 475 |
+
Args:
|
| 476 |
+
input_ids: input ids. shape (1, totol_nnz)
|
| 477 |
+
position_ids: position ids. shape (batch_size, seq_length)
|
| 478 |
+
|
| 479 |
+
Returns:
|
| 480 |
+
|
| 481 |
+
"""
|
| 482 |
+
if self.pre_process:
|
| 483 |
+
inputs_embeds = self.embed_tokens(input_ids) # (1, total_nnz) -> (1, total_nnz, hidden_size)
|
| 484 |
+
|
| 485 |
+
# vocab parallel embedding will not do sequence parallel reduce-scatter in open source megatron
|
| 486 |
+
# so need to deal with it by handle here:
|
| 487 |
+
# (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size)
|
| 488 |
+
inputs_embeds = inputs_embeds.transpose(0, 1)
|
| 489 |
+
if self.megatron_config.sequence_parallel:
|
| 490 |
+
inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds)
|
| 491 |
+
|
| 492 |
+
hidden_states = inputs_embeds
|
| 493 |
+
else:
|
| 494 |
+
# self.hidden_states should be passed by Megatron
|
| 495 |
+
hidden_states = self.input_tensor
|
| 496 |
+
|
| 497 |
+
for idx, decoder_layer in enumerate(self.layers):
|
| 498 |
+
layer_outputs = decoder_layer(
|
| 499 |
+
hidden_states,
|
| 500 |
+
position_ids=position_ids,
|
| 501 |
+
sequence_length=sequence_length,
|
| 502 |
+
indices=indices,
|
| 503 |
+
cu_seqlens=cu_seqlens,
|
| 504 |
+
max_seqlen_in_batch=max_seqlen_in_batch,
|
| 505 |
+
)
|
| 506 |
+
|
| 507 |
+
hidden_states = layer_outputs
|
| 508 |
+
|
| 509 |
+
if self.post_process:
|
| 510 |
+
hidden_states = self.norm(hidden_states)
|
| 511 |
+
|
| 512 |
+
return hidden_states
|
| 513 |
+
|
| 514 |
+
|
| 515 |
+
class ParallelLlamaForCausalLMRmPadPP(nn.Module):
|
| 516 |
+
def __init__(
|
| 517 |
+
self,
|
| 518 |
+
config: LlamaConfig,
|
| 519 |
+
megatron_config: ModelParallelConfig,
|
| 520 |
+
pre_process,
|
| 521 |
+
post_process,
|
| 522 |
+
share_embeddings_and_output_weights=False,
|
| 523 |
+
):
|
| 524 |
+
super().__init__()
|
| 525 |
+
self.config: TransformerConfig = convert_config(config, megatron_config)
|
| 526 |
+
self.megatron_config = megatron_config
|
| 527 |
+
self.model = ParallelLlamaModelRmPadPP(config, megatron_config=megatron_config, pre_process=pre_process, post_process=post_process)
|
| 528 |
+
assert share_embeddings_and_output_weights is False, "Llama Model not supports sharing embedding and output weights"
|
| 529 |
+
self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
|
| 530 |
+
self.vocab_size = config.vocab_size
|
| 531 |
+
self.pre_process = pre_process
|
| 532 |
+
self.post_process = post_process
|
| 533 |
+
if post_process:
|
| 534 |
+
self._init_head(config)
|
| 535 |
+
|
| 536 |
+
def set_input_tensor(self, input_tensor):
|
| 537 |
+
"""Set input tensor to be used instead of forward()'s input.
|
| 538 |
+
|
| 539 |
+
When doing pipeline parallelism the input from the previous
|
| 540 |
+
stage comes from communication, not from the input, so the
|
| 541 |
+
model's forward_step_func won't have it. This function is thus
|
| 542 |
+
used by internal code to bypass the input provided by the
|
| 543 |
+
forward_step_func"""
|
| 544 |
+
assert len(input_tensor) == 1
|
| 545 |
+
self.model.set_input_tensor(input_tensor[0])
|
| 546 |
+
|
| 547 |
+
def _init_head(self, config):
|
| 548 |
+
column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
|
| 549 |
+
if self.megatron_config is not None:
|
| 550 |
+
assert column_kwargs.get("config", False), "must have ModelParallelConfig"
|
| 551 |
+
tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
|
| 552 |
+
self.lm_head = tensor_parallel.ColumnParallelLinear(
|
| 553 |
+
input_size=config.hidden_size,
|
| 554 |
+
output_size=config.vocab_size,
|
| 555 |
+
bias=False,
|
| 556 |
+
gather_output=False,
|
| 557 |
+
skip_bias_add=False,
|
| 558 |
+
**column_kwargs,
|
| 559 |
+
)
|
| 560 |
+
|
| 561 |
+
def _forward_head(self, hidden_states):
|
| 562 |
+
# all_gather from sequence parallel region is performed inside lm_head
|
| 563 |
+
# logits shape before forward_head hidden_states.shape: [4, 32, 4096]
|
| 564 |
+
logits = self.lm_head(hidden_states)[0]
|
| 565 |
+
# logits shape after forward_head logits.shape: [8, 32, 8]
|
| 566 |
+
logits = logits.float() # (total_nnz_padded, 1, vocab_size // tp)
|
| 567 |
+
return logits
|
| 568 |
+
|
| 569 |
+
def forward(
|
| 570 |
+
self,
|
| 571 |
+
# original input
|
| 572 |
+
*,
|
| 573 |
+
input_ids: torch.LongTensor = None,
|
| 574 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 575 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 576 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
| 577 |
+
r"""
|
| 578 |
+
Args:
|
| 579 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 580 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
| 581 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
| 582 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
| 583 |
+
|
| 584 |
+
Returns:
|
| 585 |
+
```"""
|
| 586 |
+
|
| 587 |
+
# Note that input_ids, attention_mask and position_ids should be passed to every pp layer.
|
| 588 |
+
# In the first pp, input_ids will be used, in other pp layers hidden_states will be used inside self.model
|
| 589 |
+
batch_size, sequence_length = input_ids.shape
|
| 590 |
+
# remove padding here
|
| 591 |
+
input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(input_ids.unsqueeze(dim=-1), attention_mask) # (total_nnz, 1)
|
| 592 |
+
|
| 593 |
+
# pad input_ids to multiple of tp for all tp ranks
|
| 594 |
+
# TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap
|
| 595 |
+
if self.megatron_config.sequence_parallel:
|
| 596 |
+
input_ids_rmpad = sp_utils.pad_to_sequence_parallel(input_ids_rmpad)
|
| 597 |
+
|
| 598 |
+
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz+pad)
|
| 599 |
+
|
| 600 |
+
outputs = self.model(
|
| 601 |
+
input_ids=input_ids_rmpad,
|
| 602 |
+
position_ids=position_ids,
|
| 603 |
+
sequence_length=sequence_length,
|
| 604 |
+
indices=indices,
|
| 605 |
+
cu_seqlens=cu_seqlens,
|
| 606 |
+
max_seqlen_in_batch=max_seqlen_in_batch,
|
| 607 |
+
)
|
| 608 |
+
|
| 609 |
+
if self.post_process:
|
| 610 |
+
hidden_states = outputs
|
| 611 |
+
# print(f'hidden_states.shape = {hidden_states.shape}') # torch.Size([4, 32, 4096])
|
| 612 |
+
logits = self._forward_head(hidden_states)
|
| 613 |
+
logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension # torch.Size([8, 32, 16])
|
| 614 |
+
|
| 615 |
+
# remove padding from sequence parallel
|
| 616 |
+
if self.megatron_config.sequence_parallel:
|
| 617 |
+
totol_nnz = cu_seqlens[-1]
|
| 618 |
+
logits = logits[:totol_nnz] # (total_nnz_padded)
|
| 619 |
+
# add removed padding back. If input is already rmpad, we let the caller pad_input
|
| 620 |
+
logits = pad_input(logits, indices, batch_size, seqlen=sequence_length) # (batch_size, sequence_length, vocab_size)
|
| 621 |
+
|
| 622 |
+
return CausalLMOutputWithPast(
|
| 623 |
+
loss=None,
|
| 624 |
+
logits=logits,
|
| 625 |
+
past_key_values=None,
|
| 626 |
+
hidden_states=None,
|
| 627 |
+
attentions=None,
|
| 628 |
+
)
|
| 629 |
+
else:
|
| 630 |
+
return outputs
|
| 631 |
+
|
| 632 |
+
|
| 633 |
+
class ParallelLlamaForValueRmPadPP(ParallelLlamaForCausalLMRmPadPP):
|
| 634 |
+
def _init_head(self, config):
|
| 635 |
+
column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
|
| 636 |
+
if self.megatron_config is not None:
|
| 637 |
+
assert column_kwargs.get("config", False), "must have ModelParallelConfig"
|
| 638 |
+
tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
|
| 639 |
+
self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=1, bias=False)
|
| 640 |
+
# lm_head is effectively the same as sequence parallel
|
| 641 |
+
sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight)
|
| 642 |
+
|
| 643 |
+
def _forward_head(self, hidden_states):
|
| 644 |
+
logits = self.lm_head(hidden_states) # (total_nnz_padded // tp, 1, 1)
|
| 645 |
+
logits = logits.float()
|
| 646 |
+
if self.megatron_config.sequence_parallel:
|
| 647 |
+
logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False)
|
| 648 |
+
return logits
|
| 649 |
+
|
| 650 |
+
def forward(
|
| 651 |
+
self,
|
| 652 |
+
*,
|
| 653 |
+
input_ids: torch.LongTensor = None,
|
| 654 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 655 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 656 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
| 657 |
+
output = super().forward(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)
|
| 658 |
+
if self.post_process:
|
| 659 |
+
output.logits = torch.squeeze(output.logits, dim=-1)
|
| 660 |
+
return output
|
| 661 |
+
else:
|
| 662 |
+
return output
|
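The RmPad variants above rely on flash-attn's `unpad_input`/`pad_input` to drop padding tokens before the decoder stack and to scatter the logits back afterwards. Below is a minimal torch-only sketch of that same bookkeeping, written for this note rather than taken from the commit: the `indices`/`cu_seqlens` it computes mirror what `unpad_input` returns, and `pad_demo` plays the role of `pad_input`.

import torch

def unpad_demo(input_ids: torch.Tensor, attention_mask: torch.Tensor):
    # Flatten (batch, seq) and keep only positions where attention_mask == 1.
    seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)           # tokens per sample
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens, 0, dtype=torch.int32), (1, 0))
    flat = input_ids.flatten()[indices]                               # (total_nnz,)
    return flat, indices, cu_seqlens, int(seqlens.max())

def pad_demo(flat_values: torch.Tensor, indices: torch.Tensor, batch_size: int, seqlen: int):
    # Scatter the unpadded values back into a (batch, seq) tensor, zeros at pad slots.
    out = flat_values.new_zeros(batch_size * seqlen)
    out[indices] = flat_values
    return out.view(batch_size, seqlen)

if __name__ == "__main__":
    ids = torch.tensor([[11, 12, 13, 0], [21, 22, 0, 0]])
    mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
    flat, idx, cu, max_len = unpad_demo(ids, mask)
    print(flat.tolist(), cu.tolist(), max_len)   # [11, 12, 13, 21, 22] [0, 3, 5] 3
    print(pad_demo(flat, idx, 2, 4))             # round-trips back to ids with zeros at pads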
verl/models/mcore/__init__.py
ADDED
@@ -0,0 +1,18 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .registry import get_mcore_forward_fn, get_mcore_weight_converter, hf_to_mcore_config, init_mcore_model

__all__ = ["hf_to_mcore_config", "init_mcore_model", "get_mcore_forward_fn", "get_mcore_weight_converter"]
verl/models/mcore/config_converter.py
ADDED
@@ -0,0 +1,197 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# convert huggingface config to mcore transformer config


import torch
import torch.nn.functional as F
from megatron.core.transformer import MLATransformerConfig, TransformerConfig
from transformers import PretrainedConfig


def _get_base_transformer_config(hf_config: PretrainedConfig, dtype: torch.dtype, **kwargs) -> TransformerConfig:
    """
    Create a base TransformerConfig with common parameters across different model architectures.
    TODO: (ycl) use dataclass or converter config?

    Args:
        hf_config: HuggingFace model configuration
        dtype: Data type for the model
        **kwargs: Additional parameters to override defaults

    Returns:
        TransformerConfig with common parameters
    """
    from megatron.core import parallel_state as mpu

    # Common parallel state parameters
    overlap_p2p_comm = mpu.get_virtual_pipeline_model_parallel_world_size() is not None and mpu.get_virtual_pipeline_model_parallel_world_size() > 1
    batch_p2p_comm = False

    # Base configuration with common parameters
    base_config = {
        # Model architecture parameters
        "num_layers": hf_config.num_hidden_layers,
        "hidden_size": hf_config.hidden_size,
        "num_attention_heads": hf_config.num_attention_heads,
        "num_query_groups": hf_config.num_key_value_heads,
        "ffn_hidden_size": hf_config.intermediate_size,
        "attention_dropout": hf_config.attention_dropout,
        "hidden_dropout": getattr(hf_config, "hidden_dropout", 0.0),
        "kv_channels": getattr(hf_config, "head_dim", None),
        "layernorm_epsilon": hf_config.rms_norm_eps,
        # Activation and normalization
        "activation_func": F.silu,
        "normalization": "RMSNorm",
        "gated_linear_unit": True,
        # Data types
        "pipeline_dtype": dtype,
        "params_dtype": dtype,
        "bf16": dtype is torch.bfloat16,
        # Parallel configuration
        "tensor_model_parallel_size": mpu.get_tensor_model_parallel_world_size(),
        "pipeline_model_parallel_size": mpu.get_pipeline_model_parallel_world_size(),
        "virtual_pipeline_model_parallel_size": mpu.get_virtual_pipeline_model_parallel_world_size(),
        "context_parallel_size": mpu.get_context_parallel_world_size(),
        "overlap_p2p_comm": overlap_p2p_comm,
        "batch_p2p_comm": batch_p2p_comm,
        "sequence_parallel": mpu.get_tensor_model_parallel_world_size() > 1,
        # Common settings
        "variable_seq_lengths": True,
        "masked_softmax_fusion": True,
        "moe_token_dispatcher_type": "alltoall",
    }

    # Update with any provided overrides
    base_config.update(kwargs)
    print(f"Overridden TF init config: {base_config}")

    return TransformerConfig(**base_config)


def hf_to_mcore_config_dense(hf_config: PretrainedConfig, dtype: torch.dtype) -> TransformerConfig:
    # for LlamaForCausalLM or Qwen2ForCausalLM
    qkv_bias = True if "Qwen2ForCausalLM" in hf_config.architectures else getattr(hf_config, "attention_bias", False)
    qk_layernorm = True if "Qwen3ForCausalLM" in hf_config.architectures else False

    return _get_base_transformer_config(
        hf_config=hf_config,
        dtype=dtype,
        use_cpu_initialization=False,
        add_bias_linear=False,
        add_qkv_bias=qkv_bias,
        qk_layernorm=qk_layernorm,
    )


def hf_to_mcore_config_qwen2moe(hf_config: PretrainedConfig, dtype: torch.dtype) -> TransformerConfig:
    return _get_base_transformer_config(
        hf_config=hf_config,
        dtype=dtype,
        use_cpu_initialization=False,
        add_bias_linear=False,
        layernorm_epsilon=hf_config.rms_norm_eps,
        # MoE specific
        moe_ffn_hidden_size=hf_config.moe_intermediate_size,
        moe_router_bias_update_rate=0.001,
        moe_router_topk=hf_config.num_experts_per_tok,
        num_moe_experts=hf_config.num_experts,
        moe_shared_expert_intermediate_size=hf_config.shared_expert_intermediate_size,
        moe_aux_loss_coeff=hf_config.router_aux_loss_coef,
        # moe_aux_loss_coeff=0.0,
        moe_router_load_balancing_type="aux_loss",
        moe_shared_expert_overlap=True,
        moe_grouped_gemm=True,
        moe_router_score_function="softmax",
        # Other optimizations
        persist_layer_norm=True,
        bias_activation_fusion=True,
        bias_dropout_fusion=True,
        # Qwen specific
        moe_router_pre_softmax=True,
        add_qkv_bias=True,
    )


def hf_to_mcore_config_mixtral(hf_config: PretrainedConfig, dtype: torch.dtype) -> TransformerConfig:
    return _get_base_transformer_config(
        hf_config=hf_config,
        dtype=dtype,
        use_cpu_initialization=False,
        add_bias_linear=False,
        layernorm_epsilon=hf_config.rms_norm_eps,
        # MoE specific
        num_moe_experts=hf_config.num_local_experts,
        moe_aux_loss_coeff=hf_config.router_aux_loss_coef,
        moe_router_topk=hf_config.num_experts_per_tok,
        moe_router_pre_softmax=True,
        moe_router_load_balancing_type="aux_loss",
        moe_router_score_function="softmax",
        moe_shared_expert_intermediate_size=None,  # mixtral has no shared expert
        moe_shared_expert_overlap=False,  # mixtral has no shared expert
        moe_ffn_hidden_size=hf_config.intermediate_size,
        moe_router_bias_update_rate=0.001,
        # moe_permute_fusion=True,  # need TE 2.1+
        moe_grouped_gemm=True,
        # Other optimizations
        persist_layer_norm=True,
        apply_rope_fusion=True,
        bias_activation_fusion=True,
        bias_dropout_fusion=True,
    )


def hf_to_mcore_config_qwen3moe(hf_config: PretrainedConfig, dtype: torch.dtype) -> TransformerConfig:
    return _get_base_transformer_config(
        hf_config=hf_config,
        dtype=dtype,
        use_cpu_initialization=False,
        add_bias_linear=False,
        layernorm_epsilon=hf_config.rms_norm_eps,
        # MoE specific
        moe_ffn_hidden_size=hf_config.moe_intermediate_size,
        moe_router_bias_update_rate=0.001,
        moe_router_topk=hf_config.num_experts_per_tok,
        num_moe_experts=hf_config.num_experts,
        moe_aux_loss_coeff=hf_config.router_aux_loss_coef,
        # moe_aux_loss_coeff=0.0,
        moe_router_load_balancing_type="aux_loss",
        moe_grouped_gemm=True,
        moe_router_score_function="softmax",
        # Other optimizations
        persist_layer_norm=True,
        bias_activation_fusion=True,
        bias_dropout_fusion=True,
        # Qwen specific
        moe_router_pre_softmax=True,
        qk_layernorm=True,
    )


def hf_to_mcore_config_dpskv3(hf_config: PretrainedConfig, dtype: torch.dtype) -> MLATransformerConfig:
    # DeepseekV3ForCausalLM
    raise NotImplementedError("DeepseekV3ForCausalLM is not supported yet")


def hf_to_mcore_config_qwen2_5_vl(hf_config: PretrainedConfig, dtype: torch.dtype) -> TransformerConfig:
    # Qwen2_5_VLForConditionalGeneration
    raise NotImplementedError("Qwen2_5_VLForConditionalGeneration is not supported yet")


def hf_to_mcore_config_llama4(hf_config: PretrainedConfig, dtype: torch.dtype) -> TransformerConfig:
    # Llama4ForConditionalGeneration
    raise NotImplementedError("Llama4ForConditionalGeneration is not supported yet")
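`_get_base_transformer_config` above is essentially a field-by-field mapping from a HuggingFace config onto Megatron's `TransformerConfig` kwargs. The dependency-free sketch below shows just that mapping, using a stand-in namespace for `hf_config` with made-up Llama-style values and hard-coded parallel sizes (the real code reads them from `megatron.core.parallel_state`); it is an illustration for this note, not part of the commit.

from types import SimpleNamespace

# Stand-in for a HuggingFace LlamaConfig-like object (illustrative values only).
hf_config = SimpleNamespace(
    num_hidden_layers=32,
    hidden_size=4096,
    num_attention_heads=32,
    num_key_value_heads=8,
    intermediate_size=11008,
    attention_dropout=0.0,
    rms_norm_eps=1e-5,
)

# The architecture-related subset of the kwargs built in _get_base_transformer_config;
# parallel sizes are fixed here instead of coming from an initialized process group.
base_config = {
    "num_layers": hf_config.num_hidden_layers,
    "hidden_size": hf_config.hidden_size,
    "num_attention_heads": hf_config.num_attention_heads,
    "num_query_groups": hf_config.num_key_value_heads,
    "ffn_hidden_size": hf_config.intermediate_size,
    "attention_dropout": hf_config.attention_dropout,
    "hidden_dropout": getattr(hf_config, "hidden_dropout", 0.0),
    "layernorm_epsilon": hf_config.rms_norm_eps,
    "normalization": "RMSNorm",
    "gated_linear_unit": True,
    "tensor_model_parallel_size": 1,   # real code: mpu.get_tensor_model_parallel_world_size()
    "sequence_parallel": False,        # enabled only when tensor parallel size > 1
}

print(base_config)  # these kwargs would then be passed to TransformerConfig(**base_config)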
verl/models/mcore/loader.py
ADDED
@@ -0,0 +1,468 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time

import torch
import torch.distributed as dist

from .saver import _megatron_calc_global_rank


def _megatron_calc_layer_map(config):
    """Calculate the mapping of global layer_idx to local layer_idx
    Returns:
        layer_map (Dict: int -> tuple(int, int, int)):
            mapping from the global layer index to
                a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)
    """
    from megatron.core import mpu

    pp_size = mpu.get_pipeline_model_parallel_world_size()
    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1

    layer_map = dict()
    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers

    for pp_rank_idx in range(pp_size):
        for virtual_pp_rank_idx in range(virtual_pp_size):
            layer_offset = virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model
            for layer_idx in range(num_layers_per_model):
                layer_map[layer_offset + layer_idx] = (
                    pp_rank_idx,
                    virtual_pp_rank_idx,
                    layer_idx,
                )
    return layer_map


def load_state_dict_to_megatron_gptmodel(state_dict, wrapped_models, config, params_dtype, is_value_model=False):
    """Load merged state_dict to sharded Megatron module in training."""
    from megatron.core import DistributedDataParallel as LocalDDP
    from megatron.core import mpu
    from megatron.core.transformer.module import Float16Module
    from torch.nn.parallel import DistributedDataParallel as torchDDP

    from verl.utils.megatron_utils import print_rank_0, unwrap_model

    start_time = time.time()

    def _get_gpt_model(model):
        return model

    def broadcast_params(module):
        for param in module.parameters():
            torch.distributed.broadcast(param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group())

    dp_rank = mpu.get_data_parallel_rank()
    pp_rank = mpu.get_pipeline_model_parallel_rank()
    cp_rank = mpu.get_context_parallel_rank()
    src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=0, cp_rank=cp_rank)
    pp_size = mpu.get_pipeline_model_parallel_world_size()
    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
    mp_group = mpu.get_model_parallel_group()

    if torch.distributed.get_rank() == src_rank:
        assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0"
        assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0"
        assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0"

    if not isinstance(wrapped_models, (list, tuple)):
        wrapped_models = list(wrapped_models)

    assert len(wrapped_models) == virtual_pp_size
    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers

    models = [None] * len(wrapped_models)

    for i, wrapped_model in enumerate(wrapped_models):
        models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))
        gpt_model_module = _get_gpt_model(models[i])
        assert len(gpt_model_module.decoder.layers) == num_layers_per_model

    def _broadcast_tensor(tensor, name) -> torch.Tensor:
        """broadcast tensor from rank0 across mp_group"""
        nonlocal state_dict
        nonlocal mp_group
        if torch.distributed.get_rank() == src_rank:
            if name in state_dict:
                weight = state_dict[name]
                tensor_shape = weight.shape
            else:
                tensor_shape = None
        else:
            weight = None
            tensor_shape = None

        obj_list = [tensor_shape]
        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
        tensor_shape = obj_list[0]

        if tensor_shape is None:
            # all or none ranks in the mp_group should reach here
            print_rank_0(f"tensor:[{name}] not in state_dict, skip load")
            return

        if tensor is None:
            tensor = torch.empty(
                tensor_shape,
                dtype=params_dtype,
                device=torch.cuda.current_device(),
                requires_grad=False,
            )
        if torch.distributed.get_rank() == src_rank:
            tensor.data.copy_(weight)
        dist.broadcast(tensor, src=src_rank, group=mp_group)

    def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:
        """broadcast tensor in tp shards across mp_group"""
        nonlocal state_dict
        nonlocal mp_group
        tp_rank = mpu.get_tensor_model_parallel_rank()
        tp_size = mpu.get_tensor_model_parallel_world_size()

        if torch.distributed.get_rank() == src_rank:
            if name in state_dict:
                full_weight = state_dict[name]

                if mutate_func is not None:
                    full_weight = mutate_func(full_weight)
                tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)
                chunk_shape = tensor_chunk[0].shape
            else:
                chunk_shape = None
        else:
            chunk_shape = None

        obj_list = [chunk_shape]
        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
        chunk_shape = obj_list[0]
        if chunk_shape is None:
            # all or none ranks in the mp_group should reach here
            print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading")
            return

        if tensor is None:
            sync_tensor = torch.empty(
                chunk_shape,
                dtype=params_dtype,
                device=torch.cuda.current_device(),
                requires_grad=False,
            )
        else:
            assert tensor.shape == chunk_shape, f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}"
            sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)

        for i in range(tp_size):
            if torch.distributed.get_rank() == src_rank:
                sync_tensor.data.copy_(tensor_chunk[i])
            dist.broadcast(sync_tensor, src=src_rank, group=mp_group)
            if (i == tp_rank) and (tensor is not None):
                tensor.data.copy_(sync_tensor)

    def _broadcast_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:
        """broadcast tensor in tp shards across mp_group"""
        nonlocal state_dict
        nonlocal mp_group
        tp_rank = mpu.get_tensor_model_parallel_rank()
        tp_size = mpu.get_tensor_model_parallel_world_size()

        if torch.distributed.get_rank() == src_rank:
            if name in state_dict:
                full_weight = state_dict[name]
                if mutate_func is not None:
                    full_weight = mutate_func(full_weight)
                tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)
                chunk_shape = tensor_chunk[0].shape
            else:
                chunk_shape = None
        else:
            chunk_shape = None

        obj_list = [chunk_shape]
        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
        chunk_shape = obj_list[0]
        if chunk_shape is None:
            # all or none ranks in the mp_group should reach here
            print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading")
            return

        if tensor is None:
            sync_tensor = torch.empty(
                chunk_shape,
                dtype=params_dtype,
                device=torch.cuda.current_device(),
                requires_grad=False,
            )
        else:
            assert tensor.shape == chunk_shape, f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}"
            sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)

        for i in range(tp_size):
            if torch.distributed.get_rank() == src_rank:
                sync_tensor.data.copy_(tensor_chunk[i])
            dist.broadcast(sync_tensor, src=src_rank, group=mp_group)
            if (i == tp_rank) and (tensor is not None):
                tensor.data.copy_(sync_tensor)

    def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor:
        """broadcast tensor in tp shards across mp_group"""
        nonlocal state_dict
        nonlocal mp_group
        tp_rank = mpu.get_tensor_model_parallel_rank()
        tp_size = mpu.get_tensor_model_parallel_world_size()

        if torch.distributed.get_rank() == src_rank:
            gate_weight = state_dict[gate_name]
            up_weight = state_dict[up_name]
            new_gate_up_weight = torch.empty(config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device())
            for i in range(tp_size):
                intermediate_size_tp = config.intermediate_size // tp_size
                gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]
                up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]
                new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_(torch.cat([gate_weight_tp, up_weight_tp], dim=0))

            tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0)
            chunk_shape = tensor_chunk[0].shape
        else:
            chunk_shape = None

        obj_list = [chunk_shape]
        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
        chunk_shape = obj_list[0]
        if chunk_shape is None:
            # all or none ranks in the mp_group should reach here
            print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not in state_dict, skip loading")
            return

        if tensor is None:
            sync_tensor = torch.empty(
                chunk_shape,
                dtype=params_dtype,
                device=torch.cuda.current_device(),
                requires_grad=False,
            )
        else:
            assert tensor.shape == chunk_shape, f"rank #{torch.distributed.get_rank() == src_rank:} tensor {gate_name, up_name} shape {tensor.shape} != {chunk_shape}"
            sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)

        for i in range(tp_size):
            if torch.distributed.get_rank() == src_rank:
                sync_tensor.data.copy_(tensor_chunk[i])
            dist.broadcast(sync_tensor, src=src_rank, group=mp_group)
            if (i == tp_rank) and (tensor is not None):
                tensor.data.copy_(sync_tensor)

    def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> torch.Tensor:
        """broadcast tensor in tp shards across mp_group"""
        nonlocal state_dict
        nonlocal mp_group
        tp_rank = mpu.get_tensor_model_parallel_rank()
        tp_size = mpu.get_tensor_model_parallel_world_size()

        if torch.distributed.get_rank() == src_rank:
            assert q_name in state_dict and k_name in state_dict and v_name in state_dict
            full_weight_q = state_dict[q_name]
            full_weight_k = state_dict[k_name]
            full_weight_v = state_dict[v_name]

            hidden_size_per_head = config.hidden_size // config.num_attention_heads

            if config.num_key_value_heads >= tp_size:
                q_size_tp = config.hidden_size // tp_size
                kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size
                total_size = q_size_tp + 2 * kv_size_tp
                sizes = [total_size * tp_size]
                if not bias:
                    sizes.append(config.hidden_size)
                new_weight_qkv = torch.empty(*sizes, dtype=params_dtype, device=torch.cuda.current_device())
                for i in range(tp_size):
                    q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]
                    k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp]
                    v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp]
                    num_query_groups_per_partition = models[0].config.num_query_groups // tp_size
                    new_weight_qkv_this_tp = new_weight_qkv[i * total_size : (i + 1) * total_size]
                    q_part_per_head = torch.chunk(q_part, num_query_groups_per_partition, dim=0)
                    k_part_per_head = torch.chunk(k_part, num_query_groups_per_partition, dim=0)
                    v_part_per_head = torch.chunk(v_part, num_query_groups_per_partition, dim=0)
                    total_size_per_head = total_size // num_query_groups_per_partition
                    for j in range(num_query_groups_per_partition):
                        new_weight_qkv_this_tp[j * total_size_per_head : (j + 1) * total_size_per_head].copy_(torch.cat([q_part_per_head[j], k_part_per_head[j], v_part_per_head[j]], dim=0))

            else:
                q_size_tp = config.hidden_size // tp_size
                kv_size_tp = hidden_size_per_head
                total_size = q_size_tp + 2 * kv_size_tp
                sizes = [total_size * tp_size]
                if not bias:
                    sizes.append(config.hidden_size)
                new_weight_qkv = torch.empty(*sizes, dtype=params_dtype, device=torch.cuda.current_device())
                for i in range(tp_size):
                    q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]
                    start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head
                    end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head
                    k_part = full_weight_k[start_idx:end_idx]
                    v_part = full_weight_v[start_idx:end_idx]
                    new_weight_qkv_this_tp = new_weight_qkv[i * total_size : (i + 1) * total_size]
                    q_part_per_head = torch.chunk(q_part, config.num_attention_heads, dim=0)
                    k_part_per_head = torch.chunk(k_part, config.num_attention_heads, dim=0)
                    v_part_per_head = torch.chunk(v_part, config.num_attention_heads, dim=0)
                    total_size_per_head = total_size // config.num_attention_heads
                    for j in range(config.num_attention_heads):
                        new_weight_qkv_this_tp[j * total_size_per_head : (j + 1) * total_size_per_head].copy_(torch.cat([q_part_per_head[j], k_part_per_head[j], v_part_per_head[j]], dim=0))

            tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0)
            chunk_shape = tensor_chunk[0].shape
        else:
            chunk_shape = None

        obj_list = [chunk_shape]
        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
        chunk_shape = obj_list[0]
        if chunk_shape is None:
            # all or none ranks in the mp_group should reach here
            print_rank_0(f"tp_shard tensor:[{q_name, k_name, v_name}] not in state_dict, skip loading")
            return

        if tensor is None:
            sync_tensor = torch.empty(
                chunk_shape,
                dtype=params_dtype,
                device=torch.cuda.current_device(),
                requires_grad=False,
            )
        else:
            assert tensor.shape == chunk_shape, f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}"
            sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)

        for i in range(tp_size):
            if torch.distributed.get_rank() == src_rank:
                sync_tensor.data.copy_(tensor_chunk[i])
            dist.broadcast(sync_tensor, src=src_rank, group=mp_group)
            if (i == tp_rank) and (tensor is not None):
                tensor.data.copy_(sync_tensor)

    if dp_rank == 0:
        # Embeddings
        # -------------------
        print_rank_0("loading embeddings...")
        gpt_model_module = _get_gpt_model(models[0])
        embed_tokens_weight = None
        if pp_rank == 0:
            embed_tokens_weight = gpt_model_module.embedding.word_embeddings.weight
        _broadcast_tp_shard_tensor_vocab(embed_tokens_weight, "model.embed_tokens.weight")

        # Transformer layers
        # -------------------
        layer_map = _megatron_calc_layer_map(config)

        for layer in range(config.num_hidden_layers):
            print_rank_0(f"loading layer #{layer}...")
            layer_name = f"model.layers.{layer}"
            dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer]

            gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank])
            sync_layer = gpt_model_module.decoder.layers[dst_layer_idx]

            _broadcast_tensor(
                sync_layer.self_attention.linear_qkv.layer_norm_weight if dst_pp_rank == pp_rank else None,
                f"{layer_name}.input_layernorm.weight",
            )

            if f"{layer_name}.self_attn.q_norm.weight" in state_dict:
                _broadcast_tensor(
                    sync_layer.self_attention.q_layernorm.weight if dst_pp_rank == pp_rank else None,
                    f"{layer_name}.self_attn.q_norm.weight",
                )
                _broadcast_tensor(
                    sync_layer.self_attention.k_layernorm.weight if dst_pp_rank == pp_rank else None,
                    f"{layer_name}.self_attn.k_norm.weight",
                )

            _broadcast_tp_shard_tensor_qkv(
                sync_layer.self_attention.linear_qkv.weight if dst_pp_rank == pp_rank else None,
                f"{layer_name}.self_attn.q_proj.weight",
                f"{layer_name}.self_attn.k_proj.weight",
                f"{layer_name}.self_attn.v_proj.weight",
            )
            if f"{layer_name}.self_attn.q_proj.bias" in state_dict:
                _broadcast_tp_shard_tensor_qkv(
                    sync_layer.self_attention.linear_qkv.bias if dst_pp_rank == pp_rank else None,
                    f"{layer_name}.self_attn.q_proj.bias",
                    f"{layer_name}.self_attn.k_proj.bias",
                    f"{layer_name}.self_attn.v_proj.bias",
                    bias=True,
                )

            _broadcast_tp_shard_tensor(
                sync_layer.self_attention.linear_proj.weight if dst_pp_rank == pp_rank else None,
                f"{layer_name}.self_attn.o_proj.weight",
                chunk_dim=1,
            )
            _broadcast_tensor(
                sync_layer.mlp.linear_fc1.layer_norm_weight if dst_pp_rank == pp_rank else None,
                f"{layer_name}.post_attention_layernorm.weight",
            )

            _broadcast_tp_shard_tensor_gate_up(
                sync_layer.mlp.linear_fc1.weight if dst_pp_rank == pp_rank else None,
                f"{layer_name}.mlp.gate_proj.weight",
                f"{layer_name}.mlp.up_proj.weight",
            )

            _broadcast_tp_shard_tensor(
                sync_layer.mlp.linear_fc2.weight if dst_pp_rank == pp_rank else None,
                f"{layer_name}.mlp.down_proj.weight",
                chunk_dim=1,
            )
        # Final Layernorm
        # -------------------
        print_rank_0("loading final layernorm...")
        gpt_model_module = _get_gpt_model(models[-1])
        _broadcast_tensor(
            getattr(gpt_model_module.decoder.final_layernorm, "weight", None),
            "model.norm.weight",
        )

        print_rank_0("loading lm_head...")
        lm_head_weight = None
        if pp_rank + 1 == pp_size:
            lm_head_weight = gpt_model_module.output_layer.weight

        if is_value_model:
            # if torch.distributed.get_rank() == src_rank:
            if "lm_head.weight" in state_dict and state_dict["lm_head.weight"].shape[0] == 1:
                _broadcast_tensor(lm_head_weight, "lm_head.weight")
            elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1:
                _broadcast_tensor(lm_head_weight, "reward_head.weight")
                print_rank_0("load lm_head from value_head weight")
            else:
                _broadcast_tensor(None, "lm_head.weight")
                print_rank_0("fail to match lm_head in value_model")
            # else:

            #     _broadcast_tensor(lm_head_weight, "lm_head.weight")

        else:
            _broadcast_tp_shard_tensor(lm_head_weight, "lm_head.weight")
    dist.barrier()
    # Broadcast weights inside data parallel groups
    for wrapped_model in wrapped_models:
        broadcast_params(wrapped_model)
    pass
    torch.cuda.empty_cache()
    print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s")
verl/models/mcore/model_forward.py
ADDED
|
@@ -0,0 +1,50 @@
|
| 1 |
+
# Copyright 2025 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
| 3 |
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
from verl.utils.megatron_utils import unwrap_model
|
| 18 |
+
|
| 19 |
+
from .util import postprocess_packed_seqs, preprocess_packed_seqs, recover_left_padding, remove_left_padding
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def gptmodel_forward(model, input_ids, attention_mask, position_ids, sequence_parallel, value_model=False, pack_seqs=True):
|
| 23 |
+
"""Default forward pass for GPT models with optional sequence packing."""
|
| 24 |
+
pre_process = unwrap_model(model).pre_process
|
| 25 |
+
post_process = unwrap_model(model).post_process
|
| 26 |
+
if pack_seqs:
|
| 27 |
+
batch_size, seq_len = attention_mask.shape[:2]
|
| 28 |
+
input_ids_rmpad, packed_seq_params = preprocess_packed_seqs(input_ids, attention_mask, pre_process=pre_process)
|
| 29 |
+
input_ids_rmpad = input_ids_rmpad.contiguous()
|
| 30 |
+
output_orig = model(
|
| 31 |
+
input_ids=input_ids_rmpad,
|
| 32 |
+
attention_mask=None,
|
| 33 |
+
position_ids=position_ids,
|
| 34 |
+
packed_seq_params=packed_seq_params,
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
output = postprocess_packed_seqs(output_orig, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process)
|
| 38 |
+
else:
|
| 39 |
+
batch_size, sequence_length = attention_mask.shape
|
| 40 |
+
new_input_ids, new_attention_mask, new_position_ids = remove_left_padding(input_ids, attention_mask, position_ids, sequence_parallel, pre_process=pre_process)
|
| 41 |
+
output = model(input_ids=new_input_ids, attention_mask=new_attention_mask, position_ids=new_position_ids)
|
| 42 |
+
output = recover_left_padding(output, new_attention_mask, attention_mask, sequence_length, post_process=post_process)
|
| 43 |
+
if value_model and post_process:
|
| 44 |
+
output = output[..., 0]
|
| 45 |
+
return output
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def gptmodel_forward_qwen2_5_vl(*args, **kwargs):
|
| 49 |
+
"""Forward pass for Qwen2.5 VL model (not implemented)."""
|
| 50 |
+
raise NotImplementedError("VLM is not supported yet")
|
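Note: a minimal sketch of how `gptmodel_forward` is typically called. It assumes Megatron parallel state is already initialized and that `model` was built elsewhere (e.g. via `init_mcore_model` in `verl/models/mcore/registry.py`); the tensors and flags below are illustrative, not prescriptive.

```python
import torch

from verl.models.mcore.model_forward import gptmodel_forward


def forward_batch(model, input_ids: torch.Tensor, attention_mask: torch.Tensor, position_ids: torch.Tensor):
    # pack_seqs=True strips padding and feeds a packed "thd" batch to GPTModel;
    # postprocess_packed_seqs scatters the result back to [batch, seq, ...].
    return gptmodel_forward(
        model,
        input_ids,          # [batch, seq] token ids
        attention_mask,     # [batch, seq] boolean mask of valid tokens
        position_ids,       # [batch, seq]
        sequence_parallel=False,
        value_model=False,  # True keeps only the value column on the last PP stage
        pack_seqs=True,
    )
```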
verl/models/mcore/model_initializer.py
ADDED
|
@@ -0,0 +1,160 @@
|
| 1 |
+
# Copyright 2025 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
| 3 |
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
# use mcore transformer config to initialize the model
|
| 18 |
+
from abc import ABC, abstractmethod
|
| 19 |
+
|
| 20 |
+
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec
|
| 21 |
+
from megatron.core.models.gpt.gpt_model import GPTModel
|
| 22 |
+
|
| 23 |
+
from .config_converter import PretrainedConfig, TransformerConfig
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class BaseModelInitializer(ABC):
|
| 27 |
+
"""Base class for model initializers."""
|
| 28 |
+
|
| 29 |
+
def __init__(self, tfconfig: TransformerConfig, hf_config: PretrainedConfig):
|
| 30 |
+
self.tfconfig = tfconfig
|
| 31 |
+
self.hf_config = hf_config
|
| 32 |
+
|
| 33 |
+
@abstractmethod
|
| 34 |
+
def get_transformer_layer_spec(self):
|
| 35 |
+
"""Get the transformer layer specification.
|
| 36 |
+
https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/models/gpt/gpt_layer_specs.py"""
|
| 37 |
+
pass
|
| 38 |
+
|
| 39 |
+
def get_rope_scaling_args(self) -> dict:
|
| 40 |
+
"""Get rope scaling args."""
|
| 41 |
+
rope_scaling_args = {}
|
| 42 |
+
if "rope_scaling" in self.hf_config:
|
| 43 |
+
if self.hf_config.rope_scaling is not None:
|
| 44 |
+
assert self.hf_config.rope_scaling["type"] == "linear", "only linear scaling is supported for now"
|
| 45 |
+
rope_scaling_args["seq_len_interpolation_factor"] = self.hf_config.rope_scaling["factor"]
|
| 46 |
+
return rope_scaling_args
|
| 47 |
+
|
| 48 |
+
def initialize(
|
| 49 |
+
self,
|
| 50 |
+
pre_process: bool = True,
|
| 51 |
+
post_process: bool = True,
|
| 52 |
+
share_embeddings_and_output_weights: bool = False,
|
| 53 |
+
value: bool = False,
|
| 54 |
+
**extra_kwargs,
|
| 55 |
+
) -> GPTModel:
|
| 56 |
+
"""Initialize a GPT model with the given configuration.
|
| 57 |
+
https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/models/gpt/gpt_model.py
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
pre_process (bool): include embedding layer.
|
| 61 |
+
post_process (bool): include an output layer.
|
| 62 |
+
share_embeddings_and_output_weights (bool): input embeddings and output logit weights are shared.
|
| 63 |
+
value (bool): add an extra linear layer for classification or regression.
|
| 64 |
+
|
| 65 |
+
Returns:
|
| 66 |
+
GPTModel: An initialized GPT model instance
|
| 67 |
+
"""
|
| 68 |
+
transformer_layer_spec = self.get_transformer_layer_spec()
|
| 69 |
+
rope_scaling_args = self.get_rope_scaling_args()
|
| 70 |
+
|
| 71 |
+
model = GPTModel(
|
| 72 |
+
config=self.tfconfig,
|
| 73 |
+
transformer_layer_spec=transformer_layer_spec,
|
| 74 |
+
vocab_size=self.hf_config.vocab_size,
|
| 75 |
+
max_sequence_length=self.hf_config.max_position_embeddings,
|
| 76 |
+
pre_process=pre_process,
|
| 77 |
+
post_process=post_process,
|
| 78 |
+
share_embeddings_and_output_weights=share_embeddings_and_output_weights,
|
| 79 |
+
position_embedding_type="rope",
|
| 80 |
+
rotary_base=self.hf_config.rope_theta,
|
| 81 |
+
**rope_scaling_args,
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
if post_process and value:
|
| 85 |
+
from verl.models.llama.megatron.layers.parallel_linear import LinearForLastLayer
|
| 86 |
+
|
| 87 |
+
model.output_layer = LinearForLastLayer(input_size=self.tfconfig.hidden_size, output_size=1, config=self.tfconfig)
|
| 88 |
+
|
| 89 |
+
return model
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class DenseModel(BaseModelInitializer):
|
| 93 |
+
"""Initializer for dense models like Llama and Qwen2."""
|
| 94 |
+
|
| 95 |
+
def get_transformer_layer_spec(self):
|
| 96 |
+
assert self.tfconfig.normalization == "RMSNorm", "only RMSNorm is supported for now"
|
| 97 |
+
return get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class Qwen2MoEModel(BaseModelInitializer):
|
| 101 |
+
"""Initializer for Qwen2 MoE models."""
|
| 102 |
+
|
| 103 |
+
def get_transformer_layer_spec(self):
|
| 104 |
+
assert self.tfconfig.normalization == "RMSNorm", "only RMSNorm is supported for now"
|
| 105 |
+
transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True)
|
| 106 |
+
|
| 107 |
+
# Patch layer spec for shared experts
|
| 108 |
+
for i in range(len(transformer_layer_spec.layer_specs)):
|
| 109 |
+
transformer_layer_spec.layer_specs[i].submodules.mlp.submodules.shared_experts.params["gate"] = True
|
| 110 |
+
|
| 111 |
+
return transformer_layer_spec
|
| 112 |
+
|
| 113 |
+
def initialize(self, freeze_moe_router: bool = True, **kwargs):
|
| 114 |
+
# Qwen default freeze_moe_router: true
|
| 115 |
+
model = super().initialize(**kwargs)
|
| 116 |
+
if freeze_moe_router:
|
| 117 |
+
for layer in model.decoder.layers:
|
| 118 |
+
layer.mlp.router.weight.requires_grad = False
|
| 119 |
+
layer.mlp.shared_experts.gate_weight.requires_grad = False
|
| 120 |
+
return model
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class MixtralModel(BaseModelInitializer):
|
| 124 |
+
"""Initializer for Mixtral models."""
|
| 125 |
+
|
| 126 |
+
def get_transformer_layer_spec(self):
|
| 127 |
+
assert self.tfconfig.normalization == "RMSNorm", "only RMSNorm is supported for now"
|
| 128 |
+
transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True)
|
| 129 |
+
return transformer_layer_spec
|
| 130 |
+
|
| 131 |
+
def initialize(self, freeze_moe_router: bool = False, **kwargs):
|
| 132 |
+
model = super().initialize(**kwargs)
|
| 133 |
+
if freeze_moe_router:
|
| 134 |
+
for layer in model.decoder.layers:
|
| 135 |
+
layer.mlp.router.weight.requires_grad = False
|
| 136 |
+
return model
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
class Qwen3MoEModel(BaseModelInitializer):
|
| 140 |
+
"""Initializer for Qwen3 MoE models."""
|
| 141 |
+
|
| 142 |
+
def get_transformer_layer_spec(self):
|
| 143 |
+
assert self.tfconfig.normalization == "RMSNorm", "only RMSNorm is supported for now"
|
| 144 |
+
transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True)
|
| 145 |
+
return transformer_layer_spec
|
| 146 |
+
|
| 147 |
+
def initialize(self, freeze_moe_router: bool = True, **kwargs):
|
| 148 |
+
# Qwen default freeze_moe_router: true
|
| 149 |
+
model = super().initialize(**kwargs)
|
| 150 |
+
if freeze_moe_router:
|
| 151 |
+
for layer in model.decoder.layers:
|
| 152 |
+
layer.mlp.router.weight.requires_grad = False
|
| 153 |
+
return model
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class Qwen25VLModel(BaseModelInitializer):
|
| 157 |
+
"""Initializer for Qwen2.5 VL models."""
|
| 158 |
+
|
| 159 |
+
def get_transformer_layer_spec(self):
|
| 160 |
+
raise NotImplementedError("VLM is not supported yet")
|
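Note: a minimal sketch of driving an initializer directly; in practice the registry in `verl/models/mcore/registry.py` selects the class from `hf_config.architectures`. It assumes Megatron parallel state and both configs (`tfconfig`, `hf_config`) already exist, and the helper name `build_actor_and_critic` is made up for illustration.

```python
from verl.models.mcore.model_initializer import DenseModel


def build_actor_and_critic(tfconfig, hf_config):
    initializer = DenseModel(tfconfig, hf_config)

    # Actor / reference policy: regular LM output layer on the last pipeline stage.
    actor = initializer.initialize(pre_process=True, post_process=True, value=False)

    # Critic / reward model: with value=True and post_process=True the output layer
    # is replaced by LinearForLastLayer(output_size=1).
    critic = initializer.initialize(pre_process=True, post_process=True, value=True)
    return actor, critic
```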
verl/models/mcore/readme.md
ADDED
|
@@ -0,0 +1,99 @@
|
| 1 |
+
# verl Megatron-Core Models
|
| 2 |
+
Earlier versions of verl used `Megatron-LM` 0.4 and worked around huggingface model classes. To benefit from the latest features and speedups of modern Megatron, we are migrating to `Megatron-Core` (mcore) and use the recommended `GPTModel` class for all language models. With mcore `GPTModel` we can use the latest features such as `context parallel`, `expert parallel`, and `dist_checkpointing`, and we can pick up new mcore features with little effort in the future.
|
| 3 |
+
|
| 4 |
+
The migration has been successful with the help of the mcore team and the community. What we have done is:
|
| 5 |
+
1. update `Megatron` version to `0.11.0`
|
| 6 |
+
2. migrate `LlamaForCausalLM` and `Qwen2ForCausalLM` to mcore `GPTModel`
|
| 7 |
+
3. support sequence packing/thd format.
|
| 8 |
+
4. support `tensor parallel`, `pipeline parallel`, `sequence parallel`, `virtual pipeline parallel`, `context parallel`.
|
| 9 |
+
5. support the mcore `dist_checkpointing` feature and a basic offline weights conversion script from huggingface to the mcore `dist_checkpointing` format.
|
| 10 |
+
|
| 11 |
+
We are working on the following features:
|
| 12 |
+
- support `Qwen2MoeForCausalLM`
|
| 13 |
+
- support `MixtralForCausalLM`
|
| 14 |
+
- support `DeepseekV3ForCausalLM`
|
| 15 |
+
- support `expert parallel`
|
| 16 |
+
|
| 17 |
+
Features we invite the community to contribute:
|
| 18 |
+
- better scripts for offline weights conversion from huggingface to the mcore `dist_checkpointing` format.
|
| 19 |
+
- conversion of large models with multiple GPUs
|
| 20 |
+
- conversion of large models with single GPU
|
| 21 |
+
- refactor `megatron_checkpoint_manager.py` to use the `dist_checkpointing` format.
|
| 22 |
+
- support llama4
|
| 23 |
+
- support qwen2.5-vl
|
| 24 |
+
|
| 25 |
+
To track the progress of verl mcore integration, please refer to the [mcore integration issue](https://github.com/volcengine/verl/issues/1033).
|
| 26 |
+
|
| 27 |
+
## How things work now
|
| 28 |
+
To engage the community in contributing, here are the key steps in our mcore integration process and features under development.
|
| 29 |
+
|
| 30 |
+
Huggingface `transformers` is the de facto standard model zoo, while mcore excels at computational efficiency. The main challenge is converting between the two.
|
| 31 |
+
The main steps are:
|
| 32 |
+
1. modelling the huggingface model with mcore `GPTModel`
|
| 33 |
+
- a. convert the huggingface config to mcore `TransformerConfig`
|
| 34 |
+
- b. init the mcore `GPTModel` with the converted config
|
| 35 |
+
- c. load the huggingface model weights to the `GPTModel`
|
| 36 |
+
2. online weight conversion from mcore to huggingface (because the rollout engine `vLLM` consumes the huggingface format)
|
| 37 |
+
- a. bridge the gap between mcore and huggingface weights format and name mapping
|
| 38 |
+
- b. online resharding the mcore weights to rollout engine
|
| 39 |
+
- this part is quite involved, since multiple parallel strategies must be composed between mcore and the rollout engine
|
| 40 |
+
3. support the mcore features in verl
|
| 41 |
+
- a. support `tensor parallel`, `pipeline parallel`, `sequence parallel`, `virtual pipeline parallel`, `context parallel`
|
| 42 |
+
- b. support recompute and other mcore speed up features
|
| 43 |
+
|
| 44 |
+
4. checkpointing
|
| 45 |
+
- a. support recovering the verl training.
|
| 46 |
+
- b. support exporting the mcore checkpoint to huggingface format, for downstream inference.
|
| 47 |
+
|
| 48 |
+
### Modelling the huggingface model with mcore `GPTModel`
|
| 49 |
+
The first step is to convert the huggingface config to an mcore `TransformerConfig` and initialize the mcore `GPTModel` with the converted config. See the code in `verl/models/mcore/config_converter.py` and `verl/models/mcore/model_initializer.py`. The corresponding model forward code is in `verl/models/mcore/model_forward.py`. A minimal initialization sketch appears at the end of this subsection.
|
| 50 |
+
|
| 51 |
+
There are two ways of loading the huggingface model weights into the `GPTModel`:
|
| 52 |
+
1. Runtime loading
|
| 53 |
+
- every rank loads the entire huggingface model weights and then shard and convert to mcore weights.
|
| 54 |
+
- speed is slow and memory consumption is high.
|
| 55 |
+
- this way is deprecated and will not support new models.
|
| 56 |
+
2. Offline loading
|
| 57 |
+
- use offline script to convert the huggingface model weights to mcore weights and save with mcore `dist_checkpointing` format.
|
| 58 |
+
- online loading and sharding is automatically done by mcore `dist_checkpointing` format. The speed is fast and memory consumption is low.
|
| 59 |
+
- the offline script is in `verl/scripts/converter_hf_to_mcore.py`.
|
| 60 |
+
|
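A minimal sketch of steps 1a/1b using the helpers added in this commit (`verl/models/mcore/registry.py`). It assumes `torch.distributed` and Megatron parallel state are already initialized; the checkpoint name below is only illustrative.

```python
import torch
from transformers import AutoConfig

from verl.models.mcore.registry import hf_to_mcore_config, init_mcore_model

hf_config = AutoConfig.from_pretrained("Qwen/Qwen2-7B-Instruct")   # 1a: huggingface config
tfconfig = hf_to_mcore_config(hf_config, torch.bfloat16)           # 1a: -> TransformerConfig
model = init_mcore_model(                                          # 1b: mcore GPTModel
    tfconfig,
    hf_config,
    pre_process=True,   # this rank holds the embedding layer
    post_process=True,  # this rank holds the output layer
    value=False,
)
# 1c: weights are then loaded either at runtime (loader.py) or from an offline
# dist_checkpointing conversion (scripts/converter_hf_to_mcore.py).
```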
| 61 |
+
### online weight conversion from mcore to huggingface
|
| 62 |
+
See function `convert_megatron_model_to_transformers_model` in `verl/utils/megatron_utils.py` for the details.
|
| 63 |
+
|
| 64 |
+
It should be refactored for extensibility and better performance.
|
| 65 |
+
|
| 66 |
+
### support the mcore features in verl
|
| 67 |
+
Most `GPTModel` features are supported out of the box in verl by changing the `TransformerConfig`, except those involving parallel strategies, such as `expert parallel`.
|
| 68 |
+
Features involving parallel strategies require changes to the online weight conversion (especially the resharding part) and to verl work dispatching.
|
| 69 |
+
|
| 70 |
+
### checkpointing
|
| 71 |
+
The existing checkpointing code is in `verl/utils/checkpoint/megatron_checkpoint_manager.py`, and the script to convert a checkpoint to huggingface format is in `verl/scripts/model_merger.py`.
|
| 72 |
+
|
| 73 |
+
The existing checkpoint format simply saves every rank's weights and optimizer states. It should be refactored to use the `dist_checkpointing` format.
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
## How to support new models
|
| 77 |
+
1. make sure the model is supported by vLLM
|
| 78 |
+
2. modelling the huggingface model with mcore `GPTModel` (the [Pai-Megatron-Patch](https://github.com/alibaba/Pai-Megatron-Patch/tree/main) is a good reference)
|
| 79 |
+
- a. convert the huggingface config to mcore `TransformerConfig`
|
| 80 |
+
- b. init the mcore `GPTModel` with the converted config
|
| 81 |
+
- c. load the huggingface model weights to the `GPTModel`
|
| 82 |
+
- d. for VLMs the interface might differ; it is fine to add a new model class that uses GPTModel as a submodule.
|
| 83 |
+
3. offline weights conversion from huggingface to mcore `dist_checkpointing` format
|
| 84 |
+
4. support online weights conversion from mcore to huggingface
|
| 85 |
+
- it is recommended to initialize a vLLM model with the converted mcore weights, and then check that the generated sequences are correct (a small sanity-check sketch follows this list).
|
| 86 |
+
|
| 87 |
+
|
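A small sanity-check sketch for step 4 (paths are placeholders): export the converted weights to a huggingface-format directory, then compare greedy vLLM generations from the original checkpoint and the converted one; with temperature 0 the outputs should match.

```python
from vllm import LLM, SamplingParams

prompts = ["The capital of France is"]
params = SamplingParams(temperature=0.0, max_tokens=16)

for path in ["/ckpts/original_hf_model", "/ckpts/exported_from_mcore"]:
    llm = LLM(model=path)
    outputs = llm.generate(prompts, params)
    print(path, outputs[0].outputs[0].text)
```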
| 88 |
+
## How to scale up to larger models like deepseek-v3 or other 100B+ models
|
| 89 |
+
The greatest challenge for scaling up to larger models is the memory consumption.
|
| 90 |
+
|
| 91 |
+
The necessary features under development for scaling up are
|
| 92 |
+
1. Training engine part
|
| 93 |
+
- expert parallel
|
| 94 |
+
2. Rollout engine part
|
| 95 |
+
- pipeline parallel
|
| 96 |
+
- expert parallel
|
| 97 |
+
- more efficient and general weight resharding and loading
|
| 98 |
+
3. Offline weights conversion
|
| 99 |
+
- support weights larger than single-GPU memory
|
verl/models/mcore/registry.py
ADDED
|
@@ -0,0 +1,179 @@
|
| 1 |
+
# Copyright 2025 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""
|
| 16 |
+
Registry module for model architecture components.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from enum import Enum
|
| 20 |
+
from typing import Callable, Dict, Type
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
import torch.nn as nn
|
| 24 |
+
|
| 25 |
+
from .config_converter import (
|
| 26 |
+
PretrainedConfig,
|
| 27 |
+
TransformerConfig,
|
| 28 |
+
hf_to_mcore_config_dense,
|
| 29 |
+
hf_to_mcore_config_dpskv3,
|
| 30 |
+
hf_to_mcore_config_llama4,
|
| 31 |
+
hf_to_mcore_config_mixtral,
|
| 32 |
+
hf_to_mcore_config_qwen2_5_vl,
|
| 33 |
+
hf_to_mcore_config_qwen2moe,
|
| 34 |
+
hf_to_mcore_config_qwen3moe,
|
| 35 |
+
)
|
| 36 |
+
from .model_forward import (
|
| 37 |
+
gptmodel_forward,
|
| 38 |
+
)
|
| 39 |
+
from .model_initializer import (
|
| 40 |
+
BaseModelInitializer,
|
| 41 |
+
DenseModel,
|
| 42 |
+
MixtralModel,
|
| 43 |
+
Qwen2MoEModel,
|
| 44 |
+
Qwen3MoEModel,
|
| 45 |
+
Qwen25VLModel,
|
| 46 |
+
)
|
| 47 |
+
from .weight_converter import (
|
| 48 |
+
McoreToHFWeightConverterDense,
|
| 49 |
+
McoreToHFWeightConverterMixtral,
|
| 50 |
+
McoreToHFWeightConverterQwen2Moe,
|
| 51 |
+
McoreToHFWeightConverterQwen3Moe,
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class SupportedModel(Enum):
|
| 56 |
+
LLAMA = "LlamaForCausalLM" # tested
|
| 57 |
+
QWEN2 = "Qwen2ForCausalLM" # tested
|
| 58 |
+
QWEN2_MOE = "Qwen2MoeForCausalLM" # pending
|
| 59 |
+
DEEPSEEK_V3 = "DeepseekV3ForCausalLM" # not tested
|
| 60 |
+
MIXTRAL = "MixtralForCausalLM" # tested
|
| 61 |
+
QWEN2_5_VL = "Qwen2_5_VLForConditionalGeneration" # not supported
|
| 62 |
+
LLAMA4 = "Llama4ForConditionalGeneration" # not tested
|
| 63 |
+
QWEN3 = "Qwen3ForCausalLM" # tested
|
| 64 |
+
QWEN3_MOE = "Qwen3MoeForCausalLM" # not tested
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# Registry for model configuration converters
|
| 68 |
+
MODEL_CONFIG_CONVERTER_REGISTRY: Dict[SupportedModel, Callable[[PretrainedConfig, torch.dtype], TransformerConfig]] = {
|
| 69 |
+
SupportedModel.LLAMA: hf_to_mcore_config_dense,
|
| 70 |
+
SupportedModel.QWEN2: hf_to_mcore_config_dense,
|
| 71 |
+
SupportedModel.QWEN2_MOE: hf_to_mcore_config_qwen2moe,
|
| 72 |
+
SupportedModel.DEEPSEEK_V3: hf_to_mcore_config_dpskv3,
|
| 73 |
+
SupportedModel.MIXTRAL: hf_to_mcore_config_mixtral,
|
| 74 |
+
SupportedModel.QWEN2_5_VL: hf_to_mcore_config_qwen2_5_vl,
|
| 75 |
+
SupportedModel.LLAMA4: hf_to_mcore_config_llama4,
|
| 76 |
+
SupportedModel.QWEN3: hf_to_mcore_config_dense,
|
| 77 |
+
SupportedModel.QWEN3_MOE: hf_to_mcore_config_qwen3moe,
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
# Registry for model initializers
|
| 81 |
+
MODEL_INITIALIZER_REGISTRY: Dict[SupportedModel, Type[BaseModelInitializer]] = {
|
| 82 |
+
SupportedModel.LLAMA: DenseModel,
|
| 83 |
+
SupportedModel.QWEN2: DenseModel,
|
| 84 |
+
SupportedModel.QWEN2_MOE: Qwen2MoEModel,
|
| 85 |
+
SupportedModel.MIXTRAL: MixtralModel,
|
| 86 |
+
SupportedModel.DEEPSEEK_V3: DenseModel,
|
| 87 |
+
SupportedModel.QWEN2_5_VL: Qwen25VLModel,
|
| 88 |
+
SupportedModel.LLAMA4: DenseModel,
|
| 89 |
+
SupportedModel.QWEN3: DenseModel,
|
| 90 |
+
SupportedModel.QWEN3_MOE: Qwen3MoEModel,
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
# Registry for model forward functions
|
| 94 |
+
MODEL_FORWARD_REGISTRY: Dict[SupportedModel, Callable] = {
|
| 95 |
+
SupportedModel.LLAMA: gptmodel_forward,
|
| 96 |
+
SupportedModel.QWEN2: gptmodel_forward,
|
| 97 |
+
SupportedModel.QWEN2_MOE: gptmodel_forward,
|
| 98 |
+
SupportedModel.MIXTRAL: gptmodel_forward,
|
| 99 |
+
SupportedModel.DEEPSEEK_V3: gptmodel_forward,
|
| 100 |
+
SupportedModel.QWEN2_5_VL: gptmodel_forward,
|
| 101 |
+
SupportedModel.LLAMA4: gptmodel_forward,
|
| 102 |
+
SupportedModel.QWEN3: gptmodel_forward,
|
| 103 |
+
SupportedModel.QWEN3_MOE: gptmodel_forward,
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
# Registry for model weight converters
|
| 107 |
+
MODEL_WEIGHT_CONVERTER_REGISTRY: Dict[SupportedModel, Type] = {
|
| 108 |
+
SupportedModel.LLAMA: McoreToHFWeightConverterDense,
|
| 109 |
+
SupportedModel.QWEN2: McoreToHFWeightConverterDense,
|
| 110 |
+
SupportedModel.QWEN2_MOE: McoreToHFWeightConverterQwen2Moe,
|
| 111 |
+
SupportedModel.MIXTRAL: McoreToHFWeightConverterMixtral,
|
| 112 |
+
SupportedModel.QWEN3: McoreToHFWeightConverterDense,
|
| 113 |
+
SupportedModel.QWEN3_MOE: McoreToHFWeightConverterQwen3Moe,
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def get_supported_model(model_type: str) -> SupportedModel:
|
| 118 |
+
try:
|
| 119 |
+
return SupportedModel(model_type)
|
| 120 |
+
except ValueError as err:
|
| 121 |
+
supported_models = [e.value for e in SupportedModel]
|
| 122 |
+
raise NotImplementedError(f"Model Type: {model_type} not supported. Supported models: {supported_models}") from err
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def hf_to_mcore_config(hf_config: PretrainedConfig, dtype: torch.dtype) -> TransformerConfig:
|
| 126 |
+
assert len(hf_config.architectures) == 1, "Only one architecture is supported for now"
|
| 127 |
+
model = get_supported_model(hf_config.architectures[0])
|
| 128 |
+
return MODEL_CONFIG_CONVERTER_REGISTRY[model](hf_config, dtype)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def init_mcore_model(
|
| 132 |
+
tfconfig: TransformerConfig,
|
| 133 |
+
hf_config: PretrainedConfig,
|
| 134 |
+
pre_process: bool = True,
|
| 135 |
+
post_process: bool = None,
|
| 136 |
+
*,
|
| 137 |
+
share_embeddings_and_output_weights: bool = False,
|
| 138 |
+
value: bool = False,
|
| 139 |
+
**extra_kwargs, # may be used for vlm and moe
|
| 140 |
+
) -> nn.Module:
|
| 141 |
+
"""
|
| 142 |
+
Initialize a Mcore model.
|
| 143 |
+
|
| 144 |
+
Args:
|
| 145 |
+
tfconfig: The transformer config.
|
| 146 |
+
hf_config: The HuggingFace config.
|
| 147 |
+
pre_process: whether this stage includes the embedding layer (first pipeline stage).
|
| 148 |
+
post_process: whether this stage includes the output layer (last pipeline stage).
|
| 149 |
+
share_embeddings_and_output_weights: Whether to share embeddings and output weights.
|
| 150 |
+
value: whether to add an extra value-head linear layer (for classification or regression).
|
| 151 |
+
**extra_kwargs: Additional keyword arguments.
|
| 152 |
+
|
| 153 |
+
Returns:
|
| 154 |
+
The initialized model.
|
| 155 |
+
"""
|
| 156 |
+
assert len(hf_config.architectures) == 1, "Only one architecture is supported for now"
|
| 157 |
+
model = get_supported_model(hf_config.architectures[0])
|
| 158 |
+
initializer_cls = MODEL_INITIALIZER_REGISTRY[model]
|
| 159 |
+
initializer = initializer_cls(tfconfig, hf_config)
|
| 160 |
+
return initializer.initialize(pre_process=pre_process, post_process=post_process, share_embeddings_and_output_weights=share_embeddings_and_output_weights, value=value, **extra_kwargs)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def get_mcore_forward_fn(hf_config: PretrainedConfig) -> Callable:
|
| 164 |
+
"""
|
| 165 |
+
Get the forward function for given model architecture.
|
| 166 |
+
"""
|
| 167 |
+
assert len(hf_config.architectures) == 1, "Only one architecture is supported for now"
|
| 168 |
+
model = get_supported_model(hf_config.architectures[0])
|
| 169 |
+
return MODEL_FORWARD_REGISTRY[model]
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def get_mcore_weight_converter(hf_config: PretrainedConfig, dtype: torch.dtype) -> Callable:
|
| 173 |
+
"""
|
| 174 |
+
Get the weight converter for given model architecture.
|
| 175 |
+
"""
|
| 176 |
+
assert len(hf_config.architectures) == 1, "Only one architecture is supported for now"
|
| 177 |
+
model = get_supported_model(hf_config.architectures[0])
|
| 178 |
+
tfconfig = hf_to_mcore_config(hf_config, dtype)
|
| 179 |
+
return MODEL_WEIGHT_CONVERTER_REGISTRY[model](hf_config, tfconfig)
|
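Note: a minimal sketch of how the registry lookups above compose in a training worker. It assumes `hf_config` comes from `transformers.AutoConfig` and that Megatron parallel state is already initialized; the helper name `build_pipeline` is made up for illustration.

```python
import torch

from verl.models.mcore.registry import (
    get_mcore_forward_fn,
    get_mcore_weight_converter,
    hf_to_mcore_config,
    init_mcore_model,
)


def build_pipeline(hf_config, dtype=torch.bfloat16):
    tfconfig = hf_to_mcore_config(hf_config, dtype)
    model = init_mcore_model(tfconfig, hf_config, pre_process=True, post_process=True)
    forward_fn = get_mcore_forward_fn(hf_config)                      # gptmodel_forward for current entries
    weight_converter = get_mcore_weight_converter(hf_config, dtype)   # mcore -> HF name/shape mapping
    return model, forward_fn, weight_converter
```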
verl/models/mcore/saver.py
ADDED
|
@@ -0,0 +1,459 @@
|
| 1 |
+
# Copyright 2025 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import time
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import torch.distributed as dist
|
| 20 |
+
from megatron.core import mpu
|
| 21 |
+
from megatron.core.distributed import DistributedDataParallel as LocalDDP
|
| 22 |
+
from megatron.core.transformer.module import Float16Module
|
| 23 |
+
from torch.nn.parallel import DistributedDataParallel as torchDDP
|
| 24 |
+
|
| 25 |
+
from verl.utils.megatron_utils import print_rank_0, unwrap_model
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0, cp_rank: int = 0, ep_rank: int = 0):
|
| 29 |
+
"""Calculate global rank with support for CP/EP parallelism"""
|
| 30 |
+
|
| 31 |
+
# Get parallel sizes for each dimension
|
| 32 |
+
tp_size = mpu.get_tensor_model_parallel_world_size()
|
| 33 |
+
dp_size = mpu.get_data_parallel_world_size()
|
| 34 |
+
pp_size = mpu.get_pipeline_model_parallel_world_size()
|
| 35 |
+
cp_size = mpu.get_context_parallel_world_size()
|
| 36 |
+
# ep_size = mpu.get_expert_model_parallel_world_size()
|
| 37 |
+
|
| 38 |
+
# Verify total GPU count matches (must be consistent with parallel_state.py)
|
| 39 |
+
total_size = tp_size * dp_size * pp_size * cp_size
|
| 40 |
+
assert total_size == torch.distributed.get_world_size(), f"{tp_size}x{dp_size}x{pp_size}x{cp_size} != {torch.distributed.get_world_size()}"
|
| 41 |
+
|
| 42 |
+
# Core calculation logic (corresponds to RankGenerator order parameter)
|
| 43 |
+
# Assumes default order is "tp-cp-ep-dp-pp"
|
| 44 |
+
return ((pp_rank * dp_size + dp_rank) * cp_size + cp_rank) * tp_size + tp_rank
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _megatron_calc_layer_map(config):
|
| 48 |
+
"""Calculate the mapping of global layer_idx to local layer_idx
|
| 49 |
+
Returns:
|
| 50 |
+
layer_map (Dict: int -> tuple(int, int, int)):
|
| 51 |
+
mapping from the global layer index to
|
| 52 |
+
a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)
|
| 53 |
+
"""
|
| 54 |
+
from megatron.core import mpu
|
| 55 |
+
|
| 56 |
+
pp_size = mpu.get_pipeline_model_parallel_world_size()
|
| 57 |
+
virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
|
| 58 |
+
|
| 59 |
+
layer_map = dict()
|
| 60 |
+
num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
|
| 61 |
+
assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers
|
| 62 |
+
|
| 63 |
+
for pp_rank_idx in range(pp_size):
|
| 64 |
+
for virtual_pp_rank_idx in range(virtual_pp_size):
|
| 65 |
+
layer_offset = virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model
|
| 66 |
+
for layer_idx in range(num_layers_per_model):
|
| 67 |
+
layer_map[layer_offset + layer_idx] = (
|
| 68 |
+
pp_rank_idx,
|
| 69 |
+
virtual_pp_rank_idx,
|
| 70 |
+
layer_idx,
|
| 71 |
+
)
|
| 72 |
+
return layer_map
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def merge_megatron_ckpt_gptmodel(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False):
|
| 76 |
+
"""Merge sharded parameters of a Megatron module into a merged checkpoint.
|
| 77 |
+
|
| 78 |
+
Args:
|
| 79 |
+
wrapped_models (list of megatron.core.distributed.DistributedDataParallel):
|
| 80 |
+
The local DDP wrapped megatron modules.
|
| 81 |
+
config (str or None):
|
| 82 |
+
HF config for model
|
| 83 |
+
dtype: model params type
|
| 84 |
+
is_value_model: if model is value model
|
| 85 |
+
tie_word_embeddings: tie_word_embeddings
|
| 86 |
+
Returns:
|
| 87 |
+
state_dict (dict):
|
| 88 |
+
The merged state_dict in rank 0, and an empty dictionary in other ranks.
|
| 89 |
+
"""
|
| 90 |
+
start_time = time.time()
|
| 91 |
+
|
| 92 |
+
def _get_gpt_model(model):
|
| 93 |
+
return model
|
| 94 |
+
|
| 95 |
+
dp_rank = mpu.get_data_parallel_rank()
|
| 96 |
+
pp_size = mpu.get_pipeline_model_parallel_world_size()
|
| 97 |
+
pp_rank = mpu.get_pipeline_model_parallel_rank()
|
| 98 |
+
cp_rank = mpu.get_context_parallel_rank()
|
| 99 |
+
virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
|
| 100 |
+
mp_group = mpu.get_model_parallel_group()
|
| 101 |
+
|
| 102 |
+
if dist.get_rank() == 0:
|
| 103 |
+
assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0"
|
| 104 |
+
assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0"
|
| 105 |
+
assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0"
|
| 106 |
+
|
| 107 |
+
if not isinstance(wrapped_models, (list, tuple)):
|
| 108 |
+
wrapped_models = list(wrapped_models)
|
| 109 |
+
|
| 110 |
+
assert len(wrapped_models) == virtual_pp_size
|
| 111 |
+
num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
|
| 112 |
+
assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers
|
| 113 |
+
|
| 114 |
+
models = [None] * len(wrapped_models)
|
| 115 |
+
|
| 116 |
+
for i, wrapped_model in enumerate(wrapped_models):
|
| 117 |
+
models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))
|
| 118 |
+
assert len(models[i].decoder.layers) == num_layers_per_model, "len model layers {} not equal to num_layers_per_model {}".format(len(models[i].decoder.layers), num_layers_per_model)
|
| 119 |
+
|
| 120 |
+
state_dict = dict()
|
| 121 |
+
|
| 122 |
+
def _get_cpu_tensor(tensor: torch.Tensor):
|
| 123 |
+
if tensor is None:
|
| 124 |
+
return None
|
| 125 |
+
if tensor.device == torch.device("cpu"):
|
| 126 |
+
return tensor.detach().clone()
|
| 127 |
+
return tensor.detach().cpu()
|
| 128 |
+
|
| 129 |
+
def _broadcast_tensor(tensor, name, src_pp_rank) -> torch.Tensor:
|
| 130 |
+
"""broadcast tensor across mp_group"""
|
| 131 |
+
nonlocal state_dict
|
| 132 |
+
nonlocal mp_group
|
| 133 |
+
src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank)
|
| 134 |
+
|
| 135 |
+
if torch.distributed.get_rank() == src_rank:
|
| 136 |
+
if tensor is None:
|
| 137 |
+
weight = None
|
| 138 |
+
tensor_shape = None
|
| 139 |
+
else:
|
| 140 |
+
weight = tensor
|
| 141 |
+
tensor_shape = weight.shape
|
| 142 |
+
else:
|
| 143 |
+
weight = None
|
| 144 |
+
tensor_shape = None
|
| 145 |
+
|
| 146 |
+
obj_list = [tensor_shape]
|
| 147 |
+
dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
|
| 148 |
+
tensor_shape = obj_list[0]
|
| 149 |
+
|
| 150 |
+
if tensor_shape is None:
|
| 151 |
+
# all or none ranks in the mp_group should reach here
|
| 152 |
+
print_rank_0(f"tensor:[{name}] not exist, skip collect")
|
| 153 |
+
return
|
| 154 |
+
|
| 155 |
+
if weight is None:
|
| 156 |
+
weight = torch.empty(
|
| 157 |
+
tensor_shape,
|
| 158 |
+
dtype=dtype,
|
| 159 |
+
device=torch.cuda.current_device(),
|
| 160 |
+
requires_grad=False,
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
dist.broadcast(weight, src=src_rank, group=mp_group)
|
| 164 |
+
|
| 165 |
+
if torch.distributed.get_rank() == 0:
|
| 166 |
+
state_dict[name] = _get_cpu_tensor(weight)
|
| 167 |
+
|
| 168 |
+
def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_func=None) -> torch.Tensor:
|
| 169 |
+
"""broadcast tensor in tp shards across mp_group"""
|
| 170 |
+
nonlocal state_dict
|
| 171 |
+
nonlocal mp_group
|
| 172 |
+
# tp_rank = mpu.get_tensor_model_parallel_rank()
|
| 173 |
+
tp_size = mpu.get_tensor_model_parallel_world_size()
|
| 174 |
+
src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank)
|
| 175 |
+
|
| 176 |
+
chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None
|
| 177 |
+
|
| 178 |
+
obj_list = [chunk_shape]
|
| 179 |
+
dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
|
| 180 |
+
chunk_shape = obj_list[0]
|
| 181 |
+
if chunk_shape is None:
|
| 182 |
+
# all or none ranks in the mp_group should reach here
|
| 183 |
+
print_rank_0(f"tp_shard tensor:[{name}] not exist, skip collecting")
|
| 184 |
+
return
|
| 185 |
+
|
| 186 |
+
buffer_tensor = torch.empty(
|
| 187 |
+
chunk_shape,
|
| 188 |
+
dtype=dtype,
|
| 189 |
+
device=torch.cuda.current_device(),
|
| 190 |
+
requires_grad=False,
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
chunk_tensors = [None] * tp_size
|
| 194 |
+
|
| 195 |
+
for i in range(tp_size):
|
| 196 |
+
cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank)
|
| 197 |
+
sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor
|
| 198 |
+
dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)
|
| 199 |
+
|
| 200 |
+
if torch.distributed.get_rank() == 0:
|
| 201 |
+
chunk_tensors[i] = _get_cpu_tensor(sync_tensor)
|
| 202 |
+
|
| 203 |
+
if torch.distributed.get_rank() == 0:
|
| 204 |
+
full_tensor = torch.concat(chunk_tensors, dim=concat_dim)
|
| 205 |
+
if mutate_func is not None:
|
| 206 |
+
full_tensor = mutate_func(full_tensor)
|
| 207 |
+
state_dict[name] = full_tensor
|
| 208 |
+
|
| 209 |
+
def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) -> torch.Tensor:
|
| 210 |
+
"""broadcast tensor in tp shards across mp_group"""
|
| 211 |
+
nonlocal state_dict
|
| 212 |
+
nonlocal mp_group
|
| 213 |
+
# tp_rank = mpu.get_tensor_model_parallel_rank()
|
| 214 |
+
tp_size = mpu.get_tensor_model_parallel_world_size()
|
| 215 |
+
src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank)
|
| 216 |
+
|
| 217 |
+
chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None
|
| 218 |
+
|
| 219 |
+
obj_list = [chunk_shape]
|
| 220 |
+
dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
|
| 221 |
+
chunk_shape = obj_list[0]
|
| 222 |
+
if chunk_shape is None:
|
| 223 |
+
# all or none ranks in the mp_group should reach here
|
| 224 |
+
print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not exist, skip collecting")
|
| 225 |
+
return
|
| 226 |
+
|
| 227 |
+
buffer_tensor = torch.empty(
|
| 228 |
+
chunk_shape,
|
| 229 |
+
dtype=dtype,
|
| 230 |
+
device=torch.cuda.current_device(),
|
| 231 |
+
requires_grad=False,
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
chunk_tensors = [None] * tp_size
|
| 235 |
+
|
| 236 |
+
for i in range(tp_size):
|
| 237 |
+
cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank)
|
| 238 |
+
sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor
|
| 239 |
+
dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)
|
| 240 |
+
|
| 241 |
+
if torch.distributed.get_rank() == 0:
|
| 242 |
+
chunk_tensors[i] = _get_cpu_tensor(sync_tensor)
|
| 243 |
+
|
| 244 |
+
if torch.distributed.get_rank() == 0:
|
| 245 |
+
full_tensor = torch.concat(chunk_tensors, dim=0)
|
| 246 |
+
intermediate_size_tp = config.intermediate_size // tp_size
|
| 247 |
+
gate_weight_list = []
|
| 248 |
+
up_weight_list = []
|
| 249 |
+
for i in range(tp_size):
|
| 250 |
+
gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)]
|
| 251 |
+
gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp]
|
| 252 |
+
up_weight_tp = gate_up_weight_tp[intermediate_size_tp:]
|
| 253 |
+
gate_weight_list.append(gate_weight_tp)
|
| 254 |
+
up_weight_list.append(up_weight_tp)
|
| 255 |
+
|
| 256 |
+
state_dict[gate_name] = torch.cat(gate_weight_list, dim=0)
|
| 257 |
+
state_dict[up_name] = torch.cat(up_weight_list, dim=0)
|
| 258 |
+
|
| 259 |
+
def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank):
|
| 260 |
+
"""broadcast tensor in tp shards across mp_group"""
|
| 261 |
+
nonlocal state_dict
|
| 262 |
+
nonlocal mp_group
|
| 263 |
+
# tp_rank = mpu.get_tensor_model_parallel_rank()
|
| 264 |
+
tp_size = mpu.get_tensor_model_parallel_world_size()
|
| 265 |
+
src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank)
|
| 266 |
+
|
| 267 |
+
chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None
|
| 268 |
+
|
| 269 |
+
obj_list = [chunk_shape]
|
| 270 |
+
dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
|
| 271 |
+
chunk_shape = obj_list[0]
|
| 272 |
+
if chunk_shape is None:
|
| 273 |
+
# all or none ranks in the mp_group should reach here
|
| 274 |
+
print_rank_0(f"tp_shard tensor:[{q_name}] not exist, skip collecting")
|
| 275 |
+
return
|
| 276 |
+
|
| 277 |
+
buffer_tensor = torch.empty(
|
| 278 |
+
chunk_shape,
|
| 279 |
+
dtype=dtype,
|
| 280 |
+
device=torch.cuda.current_device(),
|
| 281 |
+
requires_grad=False,
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
chunk_tensors = [None] * tp_size
|
| 285 |
+
|
| 286 |
+
for i in range(tp_size):
|
| 287 |
+
cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank)
|
| 288 |
+
sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor
|
| 289 |
+
dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)
|
| 290 |
+
|
| 291 |
+
if torch.distributed.get_rank() == 0:
|
| 292 |
+
chunk_tensors[i] = _get_cpu_tensor(sync_tensor)
|
| 293 |
+
|
| 294 |
+
if torch.distributed.get_rank() == 0:
|
| 295 |
+
full_tensor = torch.concat(chunk_tensors, dim=0)
|
| 296 |
+
q_weight_list = []
|
| 297 |
+
k_weight_list = []
|
| 298 |
+
v_weight_list = []
|
| 299 |
+
hidden_size_per_head = config.hidden_size // config.num_attention_heads
|
| 300 |
+
|
| 301 |
+
if config.num_key_value_heads >= tp_size:
|
| 302 |
+
q_size_tp = config.hidden_size // tp_size
|
| 303 |
+
kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size
|
| 304 |
+
total_size = q_size_tp + 2 * kv_size_tp
|
| 305 |
+
for i in range(tp_size):
|
| 306 |
+
num_query_groups_per_partition = wrapped_models[0].config.num_query_groups // tp_size
|
| 307 |
+
qkv_part = full_tensor[i * total_size : (i + 1) * total_size]
|
| 308 |
+
q_size_chunk = q_size_tp // num_query_groups_per_partition
|
| 309 |
+
kv_size_chunk = kv_size_tp // num_query_groups_per_partition
|
| 310 |
+
for qkv_part_chunk in qkv_part.chunk(num_query_groups_per_partition):
|
| 311 |
+
q_part = qkv_part_chunk[:q_size_chunk]
|
| 312 |
+
k_part = qkv_part_chunk[q_size_chunk : q_size_chunk + kv_size_chunk]
|
| 313 |
+
v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk :]
|
| 314 |
+
q_weight_list.append(q_part)
|
| 315 |
+
k_weight_list.append(k_part)
|
| 316 |
+
v_weight_list.append(v_part)
|
| 317 |
+
else:
|
| 318 |
+
q_size_tp = config.hidden_size // tp_size
|
| 319 |
+
kv_size_tp = hidden_size_per_head
|
| 320 |
+
total_size = q_size_tp + 2 * kv_size_tp
|
| 321 |
+
for i in range(tp_size):
|
| 322 |
+
num_query_groups_per_partition = wrapped_models[0].config.num_query_groups // tp_size
|
| 323 |
+
qkv_part = full_tensor[i * total_size : (i + 1) * total_size]
|
| 324 |
+
q_size_chunk = q_size_tp // num_query_groups_per_partition
|
| 325 |
+
kv_size_chunk = kv_size_tp // num_query_groups_per_partition
|
| 326 |
+
for qkv_part_chunk in qkv_part.chunk(num_query_groups_per_partition):
|
| 327 |
+
q_part = qkv_part_chunk[:q_size_chunk]
|
| 328 |
+
k_part = qkv_part_chunk[q_size_chunk : q_size_chunk + kv_size_chunk]
|
| 329 |
+
v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk :]
|
| 330 |
+
q_weight_list.append(q_part)
|
| 331 |
+
if i * config.num_key_value_heads % tp_size == 0:
|
| 332 |
+
k_weight_list.append(k_part)
|
| 333 |
+
v_weight_list.append(v_part)
|
| 334 |
+
|
| 335 |
+
state_dict[q_name] = torch.cat(q_weight_list, dim=0)
|
| 336 |
+
state_dict[k_name] = torch.cat(k_weight_list, dim=0)
|
| 337 |
+
state_dict[v_name] = torch.cat(v_weight_list, dim=0)
|
| 338 |
+
|
| 339 |
+
# empty cache before collecting weights
|
| 340 |
+
torch.cuda.empty_cache()
|
| 341 |
+
# Embeddings
|
| 342 |
+
# -------------------
|
| 343 |
+
if dp_rank == 0 and cp_rank == 0: # models are identical across cp ranks
|
| 344 |
+
# Embeddings
|
| 345 |
+
# -------------------
|
| 346 |
+
print_rank_0("collecting embeddings...")
|
| 347 |
+
gpt_model_module = _get_gpt_model(models[0])
|
| 348 |
+
_broadcast_tp_shard_tensor(
|
| 349 |
+
gpt_model_module.embedding.word_embeddings.weight if pp_rank == 0 else None,
|
| 350 |
+
"model.embed_tokens.weight",
|
| 351 |
+
src_pp_rank=0,
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
# Transformer layers
|
| 355 |
+
# -------------------
|
| 356 |
+
layer_map = _megatron_calc_layer_map(config)
|
| 357 |
+
for layer in range(config.num_hidden_layers):
|
| 358 |
+
print_rank_0(f"collecting layer #{layer}...")
|
| 359 |
+
layer_name = f"model.layers.{layer}"
|
| 360 |
+
src_pp_rank, src_virtual_pp_rank, src_layer_idx = layer_map[layer]
|
| 361 |
+
|
| 362 |
+
gpt_model_module = _get_gpt_model(models[src_virtual_pp_rank])
|
| 363 |
+
sync_layer = gpt_model_module.decoder.layers[src_layer_idx]
|
| 364 |
+
|
| 365 |
+
_broadcast_tensor(
|
| 366 |
+
sync_layer.self_attention.linear_qkv.layer_norm_weight,
|
| 367 |
+
f"{layer_name}.input_layernorm.weight",
|
| 368 |
+
src_pp_rank=src_pp_rank,
|
| 369 |
+
)
|
| 370 |
+
|
| 371 |
+
_broadcast_tp_shard_tensor_qkv(
|
| 372 |
+
sync_layer.self_attention.linear_qkv.weight,
|
| 373 |
+
f"{layer_name}.self_attn.q_proj.weight",
|
| 374 |
+
f"{layer_name}.self_attn.k_proj.weight",
|
| 375 |
+
f"{layer_name}.self_attn.v_proj.weight",
|
| 376 |
+
src_pp_rank=src_pp_rank,
|
| 377 |
+
)
|
| 378 |
+
|
| 379 |
+
if getattr(sync_layer.self_attention.linear_qkv, "bias", None) is not None and sync_layer.self_attention.linear_qkv.bias.numel() > 0:
|
| 380 |
+
_broadcast_tp_shard_tensor_qkv(
|
| 381 |
+
sync_layer.self_attention.linear_qkv.bias,
|
| 382 |
+
f"{layer_name}.self_attn.q_proj.bias",
|
| 383 |
+
f"{layer_name}.self_attn.k_proj.bias",
|
| 384 |
+
f"{layer_name}.self_attn.v_proj.bias",
|
| 385 |
+
src_pp_rank=src_pp_rank,
|
| 386 |
+
)
|
| 387 |
+
|
| 388 |
+
_broadcast_tp_shard_tensor(
|
| 389 |
+
sync_layer.self_attention.linear_proj.weight,
|
| 390 |
+
f"{layer_name}.self_attn.o_proj.weight",
|
| 391 |
+
concat_dim=1,
|
| 392 |
+
src_pp_rank=src_pp_rank,
|
| 393 |
+
)
|
| 394 |
+
|
| 395 |
+
_broadcast_tensor(
|
| 396 |
+
sync_layer.mlp.linear_fc1.layer_norm_weight,
|
| 397 |
+
f"{layer_name}.post_attention_layernorm.weight",
|
| 398 |
+
src_pp_rank=src_pp_rank,
|
| 399 |
+
)
|
| 400 |
+
|
| 401 |
+
_broadcast_tp_shard_tensor_gate_up(
|
| 402 |
+
sync_layer.mlp.linear_fc1.weight,
|
| 403 |
+
f"{layer_name}.mlp.gate_proj.weight",
|
| 404 |
+
f"{layer_name}.mlp.up_proj.weight",
|
| 405 |
+
src_pp_rank=src_pp_rank,
|
| 406 |
+
)
|
| 407 |
+
|
| 408 |
+
_broadcast_tp_shard_tensor(
|
| 409 |
+
sync_layer.mlp.linear_fc2.weight,
|
| 410 |
+
f"{layer_name}.mlp.down_proj.weight",
|
| 411 |
+
concat_dim=1,
|
| 412 |
+
src_pp_rank=src_pp_rank,
|
| 413 |
+
)
|
| 414 |
+
|
| 415 |
+
# Final Layernorm
|
| 416 |
+
# -------------------
|
| 417 |
+
print_rank_0("collecting final layernorm...")
|
| 418 |
+
gpt_model_module = _get_gpt_model(models[-1])
|
| 419 |
+
_broadcast_tensor(
|
| 420 |
+
getattr(gpt_model_module.decoder.final_layernorm, "weight", None),
|
| 421 |
+
"model.norm.weight",
|
| 422 |
+
src_pp_rank=pp_size - 1,
|
| 423 |
+
)
|
| 424 |
+
|
| 425 |
+
if tie_word_embeddings:
|
| 426 |
+
print_rank_0("tie word embedding skip load lm_head...")
|
| 427 |
+
else:
|
| 428 |
+
print_rank_0("collecting lm_head...")
|
| 429 |
+
|
| 430 |
+
if is_value_model:
|
| 431 |
+
lm_head_weight = None
|
| 432 |
+
if pp_rank == pp_size - 1:
|
| 433 |
+
lm_head_weight = getattr(gpt_model_module.output_layer, "weight", None)
|
| 434 |
+
_broadcast_tensor(lm_head_weight, "lm_head.weight", src_pp_rank=pp_size - 1)
|
| 435 |
+
|
| 436 |
+
else:
|
| 437 |
+
_broadcast_tp_shard_tensor(
|
| 438 |
+
getattr(gpt_model_module.output_layer, "weight", None) if pp_rank == pp_size - 1 else None,
|
| 439 |
+
"lm_head.weight",
|
| 440 |
+
src_pp_rank=pp_size - 1,
|
| 441 |
+
)
|
| 442 |
+
|
| 443 |
+
dist.barrier()
|
| 444 |
+
torch.cuda.empty_cache()
|
| 445 |
+
if torch.distributed.get_rank() == 0:
|
| 446 |
+
for k, v in state_dict.items():
|
| 447 |
+
if dtype != v.dtype:
|
| 448 |
+
state_dict[k] = v.to(dtype)
|
| 449 |
+
|
| 450 |
+
print_rank_0(f"merge megatron ckpt done, time elapsed {time.time() - start_time}s")
|
| 451 |
+
return state_dict
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
def merge_megatron_ckpt_gptmodel_qwen_moe(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False):
|
| 455 |
+
raise NotImplementedError("merge_megatron_ckpt_gptmodel_qwen_moe is not implemented")
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
def merge_megatron_ckpt_gptmodel_mixtral(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False):
|
| 459 |
+
raise NotImplementedError("merge_megatron_ckpt_gptmodel_mixtral is not implemented")
|
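Note: a standalone arithmetic check of the rank bookkeeping in `_megatron_calc_global_rank` (values chosen arbitrarily; plain Python, no torch/megatron needed). With the default "tp-cp-ep-dp-pp" order, tp varies fastest and pp slowest.

```python
def calc_global_rank(tp_rank, dp_rank, pp_rank, cp_rank, tp_size, dp_size, cp_size):
    # Same formula as _megatron_calc_global_rank (ep omitted, i.e. ep_size == 1).
    return ((pp_rank * dp_size + dp_rank) * cp_size + cp_rank) * tp_size + tp_rank

# tp=2, cp=1, dp=2, pp=2 -> world size 8
assert calc_global_rank(0, 0, 0, 0, tp_size=2, dp_size=2, cp_size=1) == 0
assert calc_global_rank(1, 0, 0, 0, tp_size=2, dp_size=2, cp_size=1) == 1
assert calc_global_rank(0, 1, 0, 0, tp_size=2, dp_size=2, cp_size=1) == 2
assert calc_global_rank(1, 0, 1, 0, tp_size=2, dp_size=2, cp_size=1) == 5
```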
verl/models/mcore/util.py
ADDED
|
@@ -0,0 +1,190 @@
|
# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from megatron.core import parallel_state as mpu
from megatron.core.packed_seq_params import PackedSeqParams


def preprocess_packed_seqs(input_ids: torch.Tensor, attention_mask: torch.Tensor, pre_process: bool = True) -> tuple[torch.Tensor, PackedSeqParams]:
    """
    Preprocess packed sequences
    CP splits sequence into CP*2 chunks, and each GPU gets 2 chunks (GPU0 gets first and last chunks, GPU1 gets second and second last chunks, and so on), this is for load balancing with causal masking.
    See https://github.com/NVIDIA/TransformerEngine/issues/1368
    """
    batch_size = input_ids.shape[0]

    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    tp_size = mpu.get_tensor_model_parallel_world_size()
    cp_size = mpu.get_context_parallel_world_size()
    cp_rank = mpu.get_context_parallel_rank()
    align_size = tp_size * cp_size * 2 if cp_size > 1 else tp_size

    pad_size = (align_size - seqlens_in_batch % align_size) % align_size
    seqlens_in_batch_padded = seqlens_in_batch + pad_size
    cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device)
    cu_seqlens[1:] = torch.cumsum(seqlens_in_batch, dim=0)
    cu_seqlens_padded = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device)
    cu_seqlens_padded[1:] = torch.cumsum(seqlens_in_batch_padded, dim=0)
    max_seqlen_in_batch = seqlens_in_batch_padded.max().item()

    shape = list(input_ids.shape[1:])
    shape[0] = seqlens_in_batch_padded.sum().item() // cp_size
    if pre_process:
        input_ids_rmpad = torch.zeros(shape, dtype=input_ids.dtype, device=input_ids.device)
        for i in range(batch_size):
            if cp_size <= 1:
                seqlen = seqlens_in_batch[i]
                input_ids_rmpad[cu_seqlens_padded[i] : cu_seqlens_padded[i] + seqlen] = input_ids[i, attention_mask[i]]
                continue
            seqlen = seqlens_in_batch_padded[i] // cp_size
            half_seqlen = seqlen // 2
            start_idx = cu_seqlens_padded[i] // cp_size
            # split to 2 chunks
            d = input_ids[i, attention_mask[i]]
            input_ids_rmpad[start_idx : start_idx + half_seqlen] = d[half_seqlen * cp_rank : half_seqlen * (cp_rank + 1)]

            remain_start = seqlens_in_batch_padded[i] - half_seqlen * (cp_rank + 1)
            remain_end = seqlens_in_batch_padded[i] - half_seqlen * cp_rank
            remain_end = min(remain_end, d.shape[0])
            remain_len = remain_end - remain_start
            if remain_len > 0:
                input_ids_rmpad[start_idx + half_seqlen : start_idx + half_seqlen + remain_len] = d[remain_start:remain_end]

    packed_seq_params = PackedSeqParams(
        qkv_format="thd",
        cu_seqlens_q=cu_seqlens_padded,
        max_seqlen_q=max_seqlen_in_batch,
        cu_seqlens_kv=cu_seqlens_padded,
        max_seqlen_kv=max_seqlen_in_batch,
        cu_seqlens_q_padded=cu_seqlens_padded,
        cu_seqlens_kv_padded=cu_seqlens_padded,
    )
    if pre_process:
        return input_ids_rmpad.unsqueeze(0), packed_seq_params
    else:
        return input_ids, packed_seq_params


def postprocess_packed_seqs(
    output: torch.Tensor,
    packed_seq_params: PackedSeqParams,
    attention_mask: torch.Tensor,
    batch_size: int,
    seq_len: int,
    post_process: bool = True,
) -> torch.Tensor:
    """
    Postprocess packed sequences
    """
    if not post_process:
        return output
    shape = [batch_size, seq_len] + list(output.shape[2:])  # 1,packed, dim -> batch_size, seq_len, dim
    output_new = torch.zeros(shape, dtype=output.dtype, device=output.device)

    cp_size = mpu.get_context_parallel_world_size()
    # all gather output across context parallel group
    if cp_size > 1:
        # output shape: [1, packed_len, hidden_dim]
        # need to gather across cp group and concatenate in sequence dimension
        output_list = [torch.empty_like(output) for _ in range(cp_size)]
        torch.distributed.all_gather(output_list, output.detach(), group=mpu.get_context_parallel_group())
        output_list[mpu.get_context_parallel_rank()] = output
    else:
        output_list = [output]
    for i in range(batch_size):
        if cp_size <= 1:
            s = attention_mask[i].sum().item()
            output_new[i, attention_mask[i]] = output[0][packed_seq_params.cu_seqlens_q_padded[i] : packed_seq_params.cu_seqlens_q_padded[i] + s]
            continue
        s_len_padded_chunk = (packed_seq_params.cu_seqlens_q_padded[i + 1] - packed_seq_params.cu_seqlens_q_padded[i]) // cp_size
        half_seqlen = s_len_padded_chunk // 2
        s_len = attention_mask[i].sum().item()
        s_len_padded = s_len_padded_chunk * cp_size
        tmp = torch.empty(s_len_padded, *output.shape[2:], device=output.device)
        for j in range(cp_size):
            o = output_list[j][0]
            # split to 2 chunks
            packed_start_idx = packed_seq_params.cu_seqlens_q_padded[i] // cp_size
            o0, o1 = (
                o[packed_start_idx : packed_start_idx + half_seqlen],
                o[packed_start_idx + half_seqlen : packed_start_idx + s_len_padded_chunk],
            )
            tmp[j * half_seqlen : (j + 1) * half_seqlen] = o0
            tmp[s_len_padded - (j + 1) * half_seqlen : s_len_padded - j * half_seqlen] = o1
        output_new[i, attention_mask[i]] = tmp[:s_len]

    return output_new


def remove_left_padding(
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    position_ids: torch.Tensor,
    sequence_parallel: bool = False,
    pre_process: bool = True,
):
    """
    Remove left padding from input_ids, attention_mask and position_ids
    return new_input_ids, new_attention_mask, new_position_ids
    """
    assert attention_mask.ndim == 2
    assert position_ids.ndim == 2
    cp_size = mpu.get_context_parallel_world_size()
    assert cp_size == 1, "Context parallel size without seq_pack is not supported"
    batch_size = input_ids.shape[0]
    shape = list(input_ids.shape)  # batch_size, seq_len,...
    seq_lens = attention_mask.sum(dim=1)
    seq_len = seq_lens.max().item()
    if sequence_parallel:
        sp_world_size = mpu.get_tensor_model_parallel_world_size()
        pad_size = (sp_world_size - seq_len % sp_world_size) % sp_world_size
        seq_len = seq_len + pad_size
    shape[1] = seq_len
    if pre_process:
        new_input_ids = torch.zeros(dtype=input_ids.dtype, device=input_ids.device, size=shape)
    new_attention_mask = torch.zeros(dtype=attention_mask.dtype, device=attention_mask.device, size=(batch_size, seq_len))
    new_position_ids = torch.zeros(dtype=position_ids.dtype, device=position_ids.device, size=(batch_size, seq_len))
    for i in range(batch_size):
        if pre_process:
            new_input_ids[i, : seq_lens[i]] = input_ids[i, attention_mask[i]]
        new_attention_mask[i, : seq_lens[i]] = attention_mask[i, attention_mask[i]]
        new_position_ids[i, : seq_lens[i]] = position_ids[i, attention_mask[i]]
    if pre_process:
        return new_input_ids, new_attention_mask, new_position_ids
    else:
        return input_ids, new_attention_mask, new_position_ids


def recover_left_padding(
    result,
    attention_mask: torch.Tensor,
    original_attention_mask: torch.Tensor,
    origin_seqlen: int,
    post_process: bool = True,
):
    """
    Recover left padding from result
    return result
    """
    if not post_process:
        return result
    shape = list(result.shape)
    batch_size = shape[0]
    shape[1] = origin_seqlen
    new_result = torch.zeros(dtype=result.dtype, device=result.device, size=shape)
    for i in range(batch_size):
        new_result[i, original_attention_mask[i]] = result[i, attention_mask[i]]
    return new_result
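For intuition, here is a standalone sketch (not part of the commit, no Megatron required; the parallel sizes are assumed values) of the alignment arithmetic that `preprocess_packed_seqs` applies before splitting each sample across context-parallel ranks:

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])              # sequence lengths 3 and 5
tp_size, cp_size = 2, 2                                        # assumed parallel sizes
align_size = tp_size * cp_size * 2 if cp_size > 1 else tp_size  # 8

seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)        # [3, 5]
pad = (align_size - seqlens % align_size) % align_size         # [5, 3]
seqlens_padded = seqlens + pad                                 # [8, 8]

cu_seqlens_padded = torch.zeros(3, dtype=torch.int32)
cu_seqlens_padded[1:] = torch.cumsum(seqlens_padded, dim=0)    # cumulative padded offsets: 0, 8, 16
# Each CP rank then owns seqlens_padded // cp_size tokens per sample,
# split into a front half and a mirrored back half for causal load balancing.
```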
verl/models/mcore/weight_converter.py
ADDED
@@ -0,0 +1,207 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# online convert mcore weight to pure huggingface weight, no any fusion
# including format conversion and name mapping
# not including resharding
import torch
from megatron.core.transformer import TransformerConfig
from transformers import PretrainedConfig


class McoreToHFWeightConverterBase:
    def __init__(self, hf_config: PretrainedConfig, mcore_config: TransformerConfig):
        self.hf_config = hf_config
        self.mcore_config = mcore_config

    def convert_param(self, name: str, params_one_group: list[torch.Tensor]) -> torch.Tensor:
        raise NotImplementedError


class McoreToHFWeightConverterDense(McoreToHFWeightConverterBase):
    def _convert_attention_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:
        # 'decoder.layers.0.self_attention.linear_proj.weight'
        # 'decoder.layers.0.self_attention.linear_qkv.layer_norm_weight'
        # 'decoder.layers.0.self_attention.linear_qkv.weight'
        # 'decoder.layers.0.self_attention.linear_qkv.bias'
        layer_number = name.split(".")[2]
        convert_names = []
        if "self_attention.linear_qkv.bias" in name or "self_attention.linear_qkv.weight" in name:
            param_type = name.split(".")[-1]
            assert param_type == "bias" or param_type == "weight"
            convert_names.append(f"model.layers.{layer_number}.self_attn.q_proj.{param_type}")
            convert_names.append(f"model.layers.{layer_number}.self_attn.k_proj.{param_type}")
            convert_names.append(f"model.layers.{layer_number}.self_attn.v_proj.{param_type}")
            assert len(params) == 3
        elif "self_attention.linear_proj.weight" in name:
            convert_names.append(f"model.layers.{layer_number}.self_attn.o_proj.weight")
            assert len(params) == 1
        elif "self_attention.linear_qkv.layer_norm_weight" in name:
            convert_names.append(f"model.layers.{layer_number}.input_layernorm.weight")
            assert len(params) == 1
        elif "self_attention.q_layernorm.weight" in name:
            convert_names.append(f"model.layers.{layer_number}.self_attn.q_norm.weight")
            assert len(params) == 1
        elif "self_attention.k_layernorm.weight" in name:
            convert_names.append(f"model.layers.{layer_number}.self_attn.k_norm.weight")
            assert len(params) == 1
        else:
            raise NotImplementedError(f"Unsupported parameter name: {name}")
        return convert_names, params

    def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:
        # 'decoder.layers.0.mlp.linear_fc1.layer_norm_weight'
        # 'decoder.layers.0.mlp.linear_fc1.weight'
        # 'decoder.layers.0.mlp.linear_fc2.weight'
        layer_number = name.split(".")[2]
        convert_names = []
        if "mlp.linear_fc1.weight" in name:
            # split gate_proj and up_proj
            convert_names.append(f"model.layers.{layer_number}.mlp.gate_proj.weight")
            convert_names.append(f"model.layers.{layer_number}.mlp.up_proj.weight")
            assert len(params) == 2
        elif "mlp.linear_fc1.layer_norm_weight" in name:
            convert_names.append(f"model.layers.{layer_number}.post_attention_layernorm.weight")
            assert len(params) == 1
        elif "mlp.linear_fc2.weight" in name:
            convert_names.append(f"model.layers.{layer_number}.mlp.down_proj.weight")
            assert len(params) == 1
        else:
            raise NotImplementedError(f"Unsupported parameter name: {name}")
        return convert_names, params

    def convert_param(self, name: str, params_one_group: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:
        direct_name_mapping = {
            "embedding.word_embeddings.weight": "model.embed_tokens.weight",
            "decoder.final_layernorm.weight": "model.norm.weight",
            "output_layer.weight": "lm_head.weight",
        }
        if name in direct_name_mapping:
            return [direct_name_mapping[name]], [params_one_group[0]]

        if "self_attention" in name:
            return self._convert_attention_param(name, params_one_group)
        elif "mlp" in name:
            return self._convert_mlp_param(name, params_one_group)
        else:
            raise NotImplementedError(f"Unsupported parameter name: {name}")


class McoreToHFWeightConverterQwen2Moe(McoreToHFWeightConverterDense):
    def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:
        # 'decoder.layers.0.pre_mlp_layernorm.weight',
        # 'decoder.layers.0.mlp.router.weight',
        # 'decoder.layers.0.mlp.shared_experts.gate_weight',
        # 'decoder.layers.0.mlp.shared_experts.linear_fc1.weight',
        # 'decoder.layers.0.mlp.shared_experts.linear_fc2.weight'
        # moe1
        # 'decoder.layers.0.mlp.experts.linear_fc1.weight0',
        # 'decoder.layers.0.mlp.experts.linear_fc1.weight1',
        # 'decoder.layers.0.mlp.experts.linear_fc1.weight2',
        # 'decoder.layers.0.mlp.experts.linear_fc1.weight3',
        # moe2
        # 'decoder.layers.0.mlp.experts.linear_fc2.weight0',
        # 'decoder.layers.0.mlp.experts.linear_fc2.weight1',
        layer_number = name.split(".")[2]
        convert_names = []
        if "pre_mlp_layernorm" in name:
            convert_names.append(f"model.layers.{layer_number}.post_attention_layernorm.weight")
            assert len(params) == 1
        elif "mlp.router.weight" in name:
            convert_names.append(f"model.layers.{layer_number}.mlp.gate.weight")
            assert len(params) == 1
        elif "shared_experts.gate_weight" in name:
            convert_names.append(f"model.layers.{layer_number}.mlp.shared_expert_gate.weight")
            assert len(params) == 1
        elif "shared_experts.linear_fc1.weight" in name:  # split gate_proj and up_proj
            convert_names.append(f"model.layers.{layer_number}.mlp.shared_expert.gate_proj.weight")
            convert_names.append(f"model.layers.{layer_number}.mlp.shared_expert.up_proj.weight")
            assert len(params) == 2
        elif "shared_experts.linear_fc2.weight" in name:
            convert_names.append(f"model.layers.{layer_number}.mlp.shared_expert.down_proj.weight")
            assert len(params) == 1
        elif "mlp.experts.linear_fc1" in name:  # split gate_proj and up_proj
            expert_id = name.split("weight")[-1]
            convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.gate_proj.weight")
            convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.up_proj.weight")
            assert len(params) == 2
        elif "mlp.experts.linear_fc2" in name:
            expert_id = name.split("weight")[-1]
            convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.down_proj.weight")
            assert len(params) == 1
        else:
            raise NotImplementedError(f"Unsupported parameter name: {name}")
        return convert_names, params


class McoreToHFWeightConverterMixtral(McoreToHFWeightConverterDense):
    def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:
        # decoder.layers.0.mlp.router.weight
        # decoder.layers.0.mlp.experts.linear_fc1.weight0 - weight7
        # decoder.layers.0.mlp.experts.linear_fc2.weight0 - weight7

        layer_number = name.split(".")[2]
        convert_names = []
        if "pre_mlp_layernorm" in name:
            convert_names.append(f"model.layers.{layer_number}.post_attention_layernorm.weight")
        elif "mlp.router.weight" in name:
            convert_names.append(f"model.layers.{layer_number}.block_sparse_moe.gate.weight")
        elif "mlp.experts.linear_fc1.weight" in name:
            expert_id = name.split("weight")[-1]
            convert_names.append(f"model.layers.{layer_number}.block_sparse_moe.experts.{expert_id}.w1.weight")
            convert_names.append(f"model.layers.{layer_number}.block_sparse_moe.experts.{expert_id}.w3.weight")
        elif "mlp.experts.linear_fc2.weight" in name:
            expert_id = name.split("weight")[-1]
            convert_names.append(f"model.layers.{layer_number}.block_sparse_moe.experts.{expert_id}.w2.weight")
        else:
            raise NotImplementedError(f"Unsupported parameter name: {name}")
        return convert_names, params


class McoreToHFWeightConverterQwen3Moe(McoreToHFWeightConverterDense):
    def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:
        # qwen3 moe no share expert

        # 'decoder.layers.0.pre_mlp_layernorm.weight',
        # 'decoder.layers.0.mlp.router.weight',
        # moe1
        # 'decoder.layers.0.mlp.experts.linear_fc1.weight0',
        # 'decoder.layers.0.mlp.experts.linear_fc1.weight1',
        # 'decoder.layers.0.mlp.experts.linear_fc1.weight2',
        # 'decoder.layers.0.mlp.experts.linear_fc1.weight3',
        # moe2
        # 'decoder.layers.0.mlp.experts.linear_fc2.weight0',
        # 'decoder.layers.0.mlp.experts.linear_fc2.weight1',
        layer_number = name.split(".")[2]
        convert_names = []
        if "pre_mlp_layernorm" in name:
            convert_names.append(f"model.layers.{layer_number}.post_attention_layernorm.weight")
            assert len(params) == 1
        elif "mlp.router.weight" in name:
            convert_names.append(f"model.layers.{layer_number}.mlp.gate.weight")
            assert len(params) == 1
        elif "mlp.experts.linear_fc1" in name:  # split gate_proj and up_proj
            expert_id = name.split("weight")[-1]
            convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.gate_proj.weight")
            convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.up_proj.weight")
            assert len(params) == 2
        elif "mlp.experts.linear_fc2" in name:
            expert_id = name.split("weight")[-1]
            convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.down_proj.weight")
            assert len(params) == 1
        else:
            raise NotImplementedError(f"Unsupported parameter name: {name}")
        return convert_names, params
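As a usage illustration of the converters above (not part of the commit): the dense converter can be exercised directly, since the attention path does not touch the configs. Passing `None` for both configs and the tensor shapes below are placeholder assumptions for the sketch.

```python
import torch

# Assumes verl/models/mcore/weight_converter.py (and its megatron/transformers imports) is importable.
converter = McoreToHFWeightConverterDense(hf_config=None, mcore_config=None)

q = torch.zeros(8, 8)  # placeholder q_proj rows
k = torch.zeros(4, 8)  # placeholder k_proj rows
v = torch.zeros(4, 8)  # placeholder v_proj rows

names, params = converter.convert_param(
    "decoder.layers.0.self_attention.linear_qkv.weight", [q, k, v]
)
# names == ['model.layers.0.self_attn.q_proj.weight',
#           'model.layers.0.self_attn.k_proj.weight',
#           'model.layers.0.self_attn.v_proj.weight']
```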
verl/models/qwen2/__init__.py
ADDED
@@ -0,0 +1,13 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
verl/models/qwen2/megatron/__init__.py
ADDED
@@ -0,0 +1,34 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .modeling_qwen2_megatron import (
    ParallelQwen2ForCausalLM,
    # rmpad with megatron
    ParallelQwen2ForCausalLMRmPad,
    # rmpad with megatron and pipeline parallelism
    ParallelQwen2ForCausalLMRmPadPP,
    ParallelQwen2ForValueRmPad,
    ParallelQwen2ForValueRmPadPP,
    # original model with megatron
    ParallelQwen2Model,
)

__all__ = [
    "ParallelQwen2ForCausalLM",
    "ParallelQwen2ForCausalLMRmPad",
    "ParallelQwen2ForCausalLMRmPadPP",
    "ParallelQwen2ForValueRmPad",
    "ParallelQwen2ForValueRmPadPP",
    "ParallelQwen2Model",
]
verl/models/qwen2/megatron/checkpoint_utils/__init__.py
ADDED
@@ -0,0 +1,13 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader.py
ADDED
@@ -0,0 +1,312 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time

import torch
import torch.distributed as dist


def _megatron_calc_layer_map(config):
    """Calculate the mapping of global layer_idx to local layer_idx
    Returns:
        layer_map (Dict: int -> tuple(int, int, int)):
            mapping from the global layer index to
            a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)
    """
    from megatron.core import mpu

    pp_size = mpu.get_pipeline_model_parallel_world_size()
    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1

    layer_map = dict()
    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers

    for pp_rank_idx in range(pp_size):
        for virtual_pp_rank_idx in range(virtual_pp_size):
            layer_offset = virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model
            for layer_idx in range(num_layers_per_model):
                layer_map[layer_offset + layer_idx] = (
                    pp_rank_idx,
                    virtual_pp_rank_idx,
                    layer_idx,
                )
    return layer_map


def load_state_dict_to_megatron_qwen2(state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False):
    """Load merged state_dict to sharded Megatron module in training."""
    from megatron.core import DistributedDataParallel as LocalDDP
    from megatron.core import mpu
    from megatron.core.transformer.module import Float16Module
    from torch.nn.parallel import DistributedDataParallel as torchDDP

    from verl.utils.megatron_utils import print_rank_0, unwrap_model

    start_time = time.time()

    def _get_gpt_model(model):
        return model

    def fetch_params(module):
        for param in module.parameters():
            torch.distributed.fetch(param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group())

    dp_rank = mpu.get_data_parallel_rank()
    pp_rank = mpu.get_pipeline_model_parallel_rank()
    pp_size = mpu.get_pipeline_model_parallel_world_size()
    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
    mp_group = mpu.get_model_parallel_group()

    if torch.distributed.get_rank() == 0:
        assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0"
        assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0"
        assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0"

    if not isinstance(wrapped_models, (list, tuple)):
        wrapped_models = list(wrapped_models)

    assert len(wrapped_models) == virtual_pp_size
    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, f"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size: {virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}"

    models = [None] * len(wrapped_models)

    for i, wrapped_model in enumerate(wrapped_models):
        models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))
        gpt_model_module = _get_gpt_model(models[i])
        assert len(gpt_model_module.model.layers) == num_layers_per_model

    def _fetch_tensor(tensor, name) -> torch.Tensor:
        """fetch tensor"""
        nonlocal state_dict
        if tensor is not None:
            tensor = tensor.data.copy_(state_dict[name], non_blocking=True)

    def _fetch_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:
        """fetch tensor in tp shards"""
        nonlocal state_dict
        tp_rank = mpu.get_tensor_model_parallel_rank()
        tp_size = mpu.get_tensor_model_parallel_world_size()
        if name in state_dict:
            full_weight = state_dict[name]

            if mutate_func is not None:
                full_weight = mutate_func(full_weight)
            tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)
            if tensor is not None:
                tensor = tensor.data.copy_(tensor_chunk[tp_rank], non_blocking=True)
        else:
            print(f"tp_shard tensor:[{name}] not in state_dict, skip loading")

    def _fetch_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:
        """fetch tensor in tp shards"""
        nonlocal state_dict
        tp_rank = mpu.get_tensor_model_parallel_rank()
        tp_size = mpu.get_tensor_model_parallel_world_size()
        if name in state_dict:
            full_weight = state_dict[name]

            if mutate_func is not None:
                full_weight = mutate_func(full_weight)
            tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)
            if tensor is not None:
                tensor = tensor.data.copy_(tensor_chunk[tp_rank], non_blocking=True)
        else:
            print(f"tp_shard tensor:[{name}] not in state_dict, skip loading")

    def _fetch_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor:
        """fetch gate_up tensor in tp shards"""
        nonlocal state_dict
        nonlocal mp_group
        tp_rank = mpu.get_tensor_model_parallel_rank()
        tp_size = mpu.get_tensor_model_parallel_world_size()
        if gate_name in state_dict and up_name in state_dict:
            gate_weight = state_dict[gate_name]
            up_weight = state_dict[up_name]
            new_gate_up_weight = torch.empty(config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device())
            for i in range(tp_size):
                intermediate_size_tp = config.intermediate_size // tp_size
                gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]
                up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]
                new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_(torch.cat([gate_weight_tp, up_weight_tp], dim=0))

            tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0)
            if tensor is not None:
                tensor = tensor.data.copy_(tensor_chunk[tp_rank], non_blocking=True)
        else:
            print(f"tp_shard tensor:[{gate_name}, {up_name}] not in state_dict, skip loading")

    def _fetch_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> torch.Tensor:
        """fetch tensor in tp shards across mp_group"""
        nonlocal state_dict
        nonlocal mp_group
        tp_rank = mpu.get_tensor_model_parallel_rank()
        tp_size = mpu.get_tensor_model_parallel_world_size()
        assert q_name in state_dict and k_name in state_dict and v_name in state_dict
        full_weight_q = state_dict[q_name]
        full_weight_k = state_dict[k_name]
        full_weight_v = state_dict[v_name]

        hidden_size_per_head = config.hidden_size // config.num_attention_heads

        if config.num_key_value_heads >= tp_size:
            q_size_tp = config.hidden_size // tp_size
            kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size
            total_size = q_size_tp + 2 * kv_size_tp
            if not bias:
                new_weight_qkv = torch.empty(total_size * tp_size, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device())
            else:
                new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=torch.cuda.current_device())
            for i in range(tp_size):
                q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]
                k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp]
                v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp]
                new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0))

        else:
            q_size_tp = config.hidden_size // tp_size
            kv_size_tp = hidden_size_per_head
            total_size = q_size_tp + 2 * kv_size_tp
            if not bias:
                new_weight_qkv = torch.empty(total_size * tp_size, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device())
            else:
                new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=torch.cuda.current_device())
            for i in range(tp_size):
                q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]
                start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head
                end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head
                k_part = full_weight_k[start_idx:end_idx]
                v_part = full_weight_v[start_idx:end_idx]
                new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0))

        tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0)
        if tensor is not None:
            tensor = tensor.data.copy_(tensor_chunk[tp_rank], non_blocking=True)

    # Embeddings
    # -------------------
    print_rank_0("loading embeddings...")
    gpt_model_module = _get_gpt_model(models[0])
    if pp_rank == 0:
        embed_tokens_weight = gpt_model_module.model.embed_tokens.weight
        _fetch_tp_shard_tensor_vocab(embed_tokens_weight, "model.embed_tokens.weight")

    # Transformer layers
    # -------------------
    layer_map = _megatron_calc_layer_map(config)

    pp_rank = mpu.get_pipeline_model_parallel_rank()
    pp_size = mpu.get_pipeline_model_parallel_world_size()
    num_layer_per_pp = config.num_hidden_layers // pp_size
    vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size()

    layer_list = []
    if vpp_size is not None:
        for vpp_rank in range(vpp_size):
            num_layer_vpp_chunk = num_layer_per_pp // vpp_size
            num_layer_this_model = num_layer_vpp_chunk
            offset = vpp_rank * (config.num_hidden_layers // mpu.get_virtual_pipeline_model_parallel_world_size()) + (mpu.get_pipeline_model_parallel_rank() * num_layer_vpp_chunk)
            layer_list.extend(list(range(offset, offset + num_layer_this_model)))
    else:
        num_layer_this_model = num_layer_per_pp
        offset = pp_rank * num_layer_per_pp
        layer_list.extend(list(range(offset, offset + num_layer_this_model)))

    for layer in layer_list:
        print(f"{torch.distributed.get_rank()} loading layer #{layer}...")
        layer_name = f"model.layers.{layer}"
        dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer]

        print(f"{torch.distributed.get_rank()} offset: {offset}, num_layer_this_model: {num_layer_this_model}, layer_name: {layer_name}, layer_map[layer]: {layer_map[layer]}")

        gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank])
        sync_layer = gpt_model_module.model.layers[dst_layer_idx]

        _fetch_tensor(
            sync_layer.input_layernorm.weight if dst_pp_rank == pp_rank else None,
            f"{layer_name}.input_layernorm.weight",
        )

        _fetch_tp_shard_tensor_qkv(
            sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else None,
            f"{layer_name}.self_attn.q_proj.weight",
            f"{layer_name}.self_attn.k_proj.weight",
            f"{layer_name}.self_attn.v_proj.weight",
        )

        _fetch_tp_shard_tensor_qkv(
            sync_layer.self_attn.qkv_proj.bias if dst_pp_rank == pp_rank else None,
            f"{layer_name}.self_attn.q_proj.bias",
            f"{layer_name}.self_attn.k_proj.bias",
            f"{layer_name}.self_attn.v_proj.bias",
            bias=True,
        )

        _fetch_tp_shard_tensor(
            sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None,
            f"{layer_name}.self_attn.o_proj.weight",
            chunk_dim=1,
        )

        _fetch_tensor(
            sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None,
            f"{layer_name}.post_attention_layernorm.weight",
        )

        _fetch_tp_shard_tensor_gate_up(
            sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None,
            f"{layer_name}.mlp.gate_proj.weight",
            f"{layer_name}.mlp.up_proj.weight",
        )

        _fetch_tp_shard_tensor(
            sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None,
            f"{layer_name}.mlp.down_proj.weight",
            chunk_dim=1,
        )
    # Final Layernorm
    # -------------------
    print_rank_0("loading final layernorm...")
    gpt_model_module = _get_gpt_model(models[-1])
    _fetch_tensor(
        getattr(gpt_model_module.model.norm, "weight", None),
        "model.norm.weight",
    )

    if tie_word_embeddings:
        print_rank_0("tie_word_embeddings skip load lm_head")
    else:
        print_rank_0("loading lm_head...")
        if pp_rank + 1 == pp_size:
            lm_head_weight = gpt_model_module.lm_head.weight

            if is_value_model:
                if "lm_head.weight" in state_dict and state_dict["lm_head.weight"].shape[0] == 1:
                    _fetch_tensor(lm_head_weight, "lm_head.weight")
                    print_rank_0("load lm_head from value_head weight")
                elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1:
                    _fetch_tensor(lm_head_weight, "reward_head.weight")
                    print_rank_0("load lm_head from value_head weight")
                else:
                    _fetch_tensor(None, "lm_head.weight")
                    print_rank_0("fail to match lm_head in value_model")

            else:
                _fetch_tp_shard_tensor(lm_head_weight, "lm_head.weight")

    dist.barrier()
    torch.cuda.empty_cache()
    print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s")
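The `_fetch_tp_shard_tensor_qkv` helper above interleaves Q, K and V row-blocks per tensor-parallel rank before chunking, so each rank receives its own fused `[q | k | v]` slice. The toy sketch below (not part of the commit) illustrates that layout with made-up sizes: 8 query rows, 4 KV rows, and an assumed `tp_size` of 2.

```python
import torch

tp_size = 2
q = torch.arange(8).reshape(8, 1)    # 8 "rows" standing in for q_proj
k = torch.arange(4).reshape(4, 1)    # 4 "rows" standing in for k_proj
v = torch.arange(4).reshape(4, 1)    # 4 "rows" standing in for v_proj

q_tp, kv_tp = 8 // tp_size, 4 // tp_size
fused = torch.cat(
    [torch.cat([q[i * q_tp:(i + 1) * q_tp],
                k[i * kv_tp:(i + 1) * kv_tp],
                v[i * kv_tp:(i + 1) * kv_tp]], dim=0) for i in range(tp_size)],
    dim=0,
)
rank0, rank1 = torch.chunk(fused, tp_size, dim=0)
# rank0 holds q rows 0-3, k rows 0-1, v rows 0-1; rank1 holds the second halves.
```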
verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader_depracated.py
ADDED
@@ -0,0 +1,442 @@
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import time
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.distributed as dist
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _megatron_calc_layer_map(config):
|
| 22 |
+
"""Calculate the mapping of global layer_idx to local layer_idx
|
| 23 |
+
Returns:
|
| 24 |
+
layer_map (Dict: int -> tuple(int, int, int)):
|
| 25 |
+
mapping from the global layer index to
|
| 26 |
+
a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)
|
| 27 |
+
"""
|
| 28 |
+
from megatron.core import mpu
|
| 29 |
+
|
| 30 |
+
pp_size = mpu.get_pipeline_model_parallel_world_size()
|
| 31 |
+
virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
|
| 32 |
+
|
| 33 |
+
layer_map = dict()
|
| 34 |
+
num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
|
| 35 |
+
assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers
|
| 36 |
+
|
| 37 |
+
for pp_rank_idx in range(pp_size):
|
| 38 |
+
for virtual_pp_rank_idx in range(virtual_pp_size):
|
| 39 |
+
layer_offset = virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model
|
| 40 |
+
for layer_idx in range(num_layers_per_model):
|
| 41 |
+
layer_map[layer_offset + layer_idx] = (
|
| 42 |
+
pp_rank_idx,
|
| 43 |
+
virtual_pp_rank_idx,
|
| 44 |
+
layer_idx,
|
| 45 |
+
)
|
| 46 |
+
return layer_map
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def load_state_dict_to_megatron_qwen2(state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False):
|
| 50 |
+
"""Load merged state_dict to sharded Megatron module in training."""
|
| 51 |
+
from megatron.core import DistributedDataParallel as LocalDDP
|
| 52 |
+
from megatron.core import mpu
|
| 53 |
+
from megatron.core.transformer.module import Float16Module
|
| 54 |
+
from torch.nn.parallel import DistributedDataParallel as torchDDP
|
| 55 |
+
|
| 56 |
+
from verl.utils.megatron_utils import print_rank_0, unwrap_model
|
| 57 |
+
|
| 58 |
+
start_time = time.time()
|
| 59 |
+
|
| 60 |
+
def _get_gpt_model(model):
|
| 61 |
+
return model
|
| 62 |
+
|
| 63 |
+
def broadcast_params(module):
|
| 64 |
+
for param in module.parameters():
|
| 65 |
+
torch.distributed.broadcast(param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group())
|
| 66 |
+
|
| 67 |
+
dp_rank = mpu.get_data_parallel_rank()
|
| 68 |
+
pp_rank = mpu.get_pipeline_model_parallel_rank()
|
| 69 |
+
pp_size = mpu.get_pipeline_model_parallel_world_size()
|
| 70 |
+
virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
|
| 71 |
+
mp_group = mpu.get_model_parallel_group()
|
| 72 |
+
|
| 73 |
+
if torch.distributed.get_rank() == 0:
|
| 74 |
+
assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0"
|
| 75 |
+
assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0"
|
| 76 |
+
assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0"
|
| 77 |
+
|
| 78 |
+
if not isinstance(wrapped_models, (list, tuple)):
|
| 79 |
+
wrapped_models = list(wrapped_models)
|
| 80 |
+
|
| 81 |
+
assert len(wrapped_models) == virtual_pp_size
|
| 82 |
+
num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
|
| 83 |
+
assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, f"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size: {virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}"
|
| 84 |
+
|
| 85 |
+
models = [None] * len(wrapped_models)
|
| 86 |
+
|
| 87 |
+
for i, wrapped_model in enumerate(wrapped_models):
|
| 88 |
+
models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))
|
| 89 |
+
gpt_model_module = _get_gpt_model(models[i])
|
| 90 |
+
assert len(gpt_model_module.model.layers) == num_layers_per_model
|
| 91 |
+
|
| 92 |
+
def _broadcast_tensor(tensor, name) -> torch.Tensor:
|
| 93 |
+
"""broadcast tensor from rank0 across mp_group"""
|
| 94 |
+
nonlocal state_dict
|
| 95 |
+
nonlocal mp_group
|
| 96 |
+
if torch.distributed.get_rank() == 0:
|
| 97 |
+
if name in state_dict:
|
| 98 |
+
weight = state_dict[name]
|
| 99 |
+
tensor_shape = weight.shape
|
| 100 |
+
else:
|
| 101 |
+
tensor_shape = None
|
| 102 |
+
else:
|
| 103 |
+
weight = None
|
| 104 |
+
tensor_shape = None
|
| 105 |
+
|
| 106 |
+
obj_list = [tensor_shape]
|
| 107 |
+
dist.broadcast_object_list(obj_list, src=0, group=mp_group)
|
| 108 |
+
tensor_shape = obj_list[0]
|
| 109 |
+
|
| 110 |
+
if tensor_shape is None:
|
| 111 |
+
# all or none ranks in the mp_group should reach here
|
| 112 |
+
print_rank_0(f"tensor:[{name}] not in state_dict, skip load")
|
| 113 |
+
return
|
| 114 |
+
|
| 115 |
+
if tensor is None:
|
| 116 |
+
tensor = torch.empty(
|
| 117 |
+
tensor_shape,
|
| 118 |
+
dtype=params_dtype,
|
| 119 |
+
device=torch.cuda.current_device(),
|
| 120 |
+
requires_grad=False,
|
| 121 |
+
)
|
| 122 |
+
if torch.distributed.get_rank() == 0:
|
| 123 |
+
tensor.data.copy_(weight)
|
| 124 |
+
dist.broadcast(tensor, src=0, group=mp_group)
|
| 125 |
+
|
| 126 |
+
def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:
|
| 127 |
+
"""broadcast tensor in tp shards across mp_group"""
|
| 128 |
+
nonlocal state_dict
|
| 129 |
+
nonlocal mp_group
|
| 130 |
+
tp_rank = mpu.get_tensor_model_parallel_rank()
|
| 131 |
+
tp_size = mpu.get_tensor_model_parallel_world_size()
|
| 132 |
+
|
| 133 |
+
if torch.distributed.get_rank() == 0:
|
| 134 |
+
if name in state_dict:
|
| 135 |
+
full_weight = state_dict[name]
|
| 136 |
+
|
| 137 |
+
if mutate_func is not None:
|
| 138 |
+
full_weight = mutate_func(full_weight)
|
| 139 |
+
tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)
|
| 140 |
+
chunk_shape = tensor_chunk[0].shape
|
| 141 |
+
else:
|
| 142 |
+
chunk_shape = None
|
| 143 |
+
else:
|
| 144 |
+
chunk_shape = None
|
| 145 |
+
|
| 146 |
+
obj_list = [chunk_shape]
|
| 147 |
+
dist.broadcast_object_list(obj_list, src=0, group=mp_group)
|
| 148 |
+
chunk_shape = obj_list[0]
|
| 149 |
+
if chunk_shape is None:
|
| 150 |
+
# all or none ranks in the mp_group should reach here
|
| 151 |
+
print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading")
|
| 152 |
+
return
|
| 153 |
+
|
| 154 |
+
if tensor is None:
|
| 155 |
+
sync_tensor = torch.empty(
|
| 156 |
+
chunk_shape,
|
| 157 |
+
dtype=params_dtype,
|
| 158 |
+
device=torch.cuda.current_device(),
|
| 159 |
+
requires_grad=False,
|
| 160 |
+
)
|
| 161 |
+
else:
|
| 162 |
+
assert tensor.shape == chunk_shape, f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}"
|
| 163 |
+
sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)
|
| 164 |
+
|
| 165 |
+
for i in range(tp_size):
|
| 166 |
+
if torch.distributed.get_rank() == 0:
|
| 167 |
+
sync_tensor.data.copy_(tensor_chunk[i])
|
| 168 |
+
dist.broadcast(sync_tensor, src=0, group=mp_group)
|
| 169 |
+
if (i == tp_rank) and (tensor is not None):
|
| 170 |
+
tensor.data.copy_(sync_tensor)
|
| 171 |
+
|
| 172 |
+
def _broadcast_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:
|
| 173 |
+
"""broadcast tensor in tp shards across mp_group"""
|
| 174 |
+
nonlocal state_dict
|
| 175 |
+
nonlocal mp_group
|
| 176 |
+
tp_rank = mpu.get_tensor_model_parallel_rank()
|
| 177 |
+
tp_size = mpu.get_tensor_model_parallel_world_size()
|
| 178 |
+
|
| 179 |
+
if torch.distributed.get_rank() == 0:
|
| 180 |
+
if name in state_dict:
|
| 181 |
+
full_weight = state_dict[name]
|
| 182 |
+
if mutate_func is not None:
|
| 183 |
+
full_weight = mutate_func(full_weight)
|
| 184 |
+
tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)
|
| 185 |
+
chunk_shape = tensor_chunk[0].shape
|
| 186 |
+
else:
|
| 187 |
+
chunk_shape = None
|
| 188 |
+
else:
|
| 189 |
+
chunk_shape = None
|
| 190 |
+
|
| 191 |
+
obj_list = [chunk_shape]
|
| 192 |
+
dist.broadcast_object_list(obj_list, src=0, group=mp_group)
|
| 193 |
+
chunk_shape = obj_list[0]
|
| 194 |
+
if chunk_shape is None:
|
| 195 |
+
# all or none ranks in the mp_group should reach here
|
| 196 |
+
print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading")
|
| 197 |
+
return
|
| 198 |
+
|
| 199 |
+
if tensor is None:
|
| 200 |
+
sync_tensor = torch.empty(
|
| 201 |
+
chunk_shape,
|
| 202 |
+
dtype=params_dtype,
|
| 203 |
+
device=torch.cuda.current_device(),
|
| 204 |
+
requires_grad=False,
|
| 205 |
+
)
|
| 206 |
+
else:
|
| 207 |
+
assert tensor.shape == chunk_shape, f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}"
|
| 208 |
+
sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)
|
| 209 |
+
|
| 210 |
+
for i in range(tp_size):
|
| 211 |
+
if torch.distributed.get_rank() == 0:
|
| 212 |
+
sync_tensor.data.copy_(tensor_chunk[i])
|
| 213 |
+
dist.broadcast(sync_tensor, src=0, group=mp_group)
|
| 214 |
+
if (i == tp_rank) and (tensor is not None):
|
| 215 |
+
tensor.data.copy_(sync_tensor)
|
| 216 |
+
|
| 217 |
+
    def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor:
        """broadcast tensor in tp shards across mp_group"""
        nonlocal state_dict
        nonlocal mp_group
        tp_rank = mpu.get_tensor_model_parallel_rank()
        tp_size = mpu.get_tensor_model_parallel_world_size()

        if torch.distributed.get_rank() == 0:
            gate_weight = state_dict[gate_name]
            up_weight = state_dict[up_name]
            new_gate_up_weight = torch.empty(config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device())
            for i in range(tp_size):
                intermediate_size_tp = config.intermediate_size // tp_size
                gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]
                up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]
                new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_(torch.cat([gate_weight_tp, up_weight_tp], dim=0))

            tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0)
            chunk_shape = tensor_chunk[0].shape
        else:
            chunk_shape = None

        obj_list = [chunk_shape]
        dist.broadcast_object_list(obj_list, src=0, group=mp_group)
        chunk_shape = obj_list[0]
        if chunk_shape is None:
            # all or none ranks in the mp_group should reach here
            print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not in state_dict, skip loading")
            return

        if tensor is None:
            sync_tensor = torch.empty(
                chunk_shape,
                dtype=params_dtype,
                device=torch.cuda.current_device(),
                requires_grad=False,
            )
        else:
            assert tensor.shape == chunk_shape, f"rank #{torch.distributed.get_rank()} tensor {gate_name, up_name} shape {tensor.shape} != {chunk_shape}"
            sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)

        for i in range(tp_size):
            if torch.distributed.get_rank() == 0:
                sync_tensor.data.copy_(tensor_chunk[i])
            dist.broadcast(sync_tensor, src=0, group=mp_group)
            if (i == tp_rank) and (tensor is not None):
                tensor.data.copy_(sync_tensor)
    def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> torch.Tensor:
        """broadcast tensor in tp shards across mp_group"""
        nonlocal state_dict
        nonlocal mp_group
        tp_rank = mpu.get_tensor_model_parallel_rank()
        tp_size = mpu.get_tensor_model_parallel_world_size()

        if torch.distributed.get_rank() == 0:
            assert q_name in state_dict and k_name in state_dict and v_name in state_dict
            full_weight_q = state_dict[q_name]
            full_weight_k = state_dict[k_name]
            full_weight_v = state_dict[v_name]

            hidden_size_per_head = config.hidden_size // config.num_attention_heads

            if config.num_key_value_heads >= tp_size:
                q_size_tp = config.hidden_size // tp_size
                kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size
                total_size = q_size_tp + 2 * kv_size_tp
                if not bias:
                    new_weight_qkv = torch.empty(total_size * tp_size, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device())
                else:
                    new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=torch.cuda.current_device())
                for i in range(tp_size):
                    q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]
                    k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp]
                    v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp]
                    new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0))

            else:
                q_size_tp = config.hidden_size // tp_size
                kv_size_tp = hidden_size_per_head
                total_size = q_size_tp + 2 * kv_size_tp
                if not bias:
                    new_weight_qkv = torch.empty(total_size * tp_size, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device())
                else:
                    new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=torch.cuda.current_device())
                for i in range(tp_size):
                    q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]
                    start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head
                    end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head
                    k_part = full_weight_k[start_idx:end_idx]
                    v_part = full_weight_v[start_idx:end_idx]
                    new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0))

            tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0)
            chunk_shape = tensor_chunk[0].shape
        else:
            chunk_shape = None

        obj_list = [chunk_shape]
        dist.broadcast_object_list(obj_list, src=0, group=mp_group)
        chunk_shape = obj_list[0]
        if chunk_shape is None:
            # all or none ranks in the mp_group should reach here
            print_rank_0(f"tp_shard tensor:[{q_name, k_name, v_name}] not in state_dict, skip loading")
            return

        if tensor is None:
            sync_tensor = torch.empty(
                chunk_shape,
                dtype=params_dtype,
                device=torch.cuda.current_device(),
                requires_grad=False,
            )
        else:
            assert tensor.shape == chunk_shape, f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}"
            sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)

        for i in range(tp_size):
            if torch.distributed.get_rank() == 0:
                sync_tensor.data.copy_(tensor_chunk[i])
            dist.broadcast(sync_tensor, src=0, group=mp_group)
            if (i == tp_rank) and (tensor is not None):
                tensor.data.copy_(sync_tensor)
    if dp_rank == 0:
        # Embeddings
        # -------------------
        print_rank_0("loading embeddings...")
        gpt_model_module = _get_gpt_model(models[0])
        embed_tokens_weight = None
        if pp_rank == 0:
            embed_tokens_weight = gpt_model_module.model.embed_tokens.weight
        _broadcast_tp_shard_tensor_vocab(embed_tokens_weight, "model.embed_tokens.weight")

        # Transformer layers
        # -------------------
        layer_map = _megatron_calc_layer_map(config)

        for layer in range(config.num_hidden_layers):
            print_rank_0(f"loading layer #{layer}...")
            layer_name = f"model.layers.{layer}"
            dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer]

            gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank])
            sync_layer = gpt_model_module.model.layers[dst_layer_idx]

            _broadcast_tensor(
                sync_layer.input_layernorm.weight if dst_pp_rank == pp_rank else None,
                f"{layer_name}.input_layernorm.weight",
            )

            _broadcast_tp_shard_tensor_qkv(
                sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else None,
                f"{layer_name}.self_attn.q_proj.weight",
                f"{layer_name}.self_attn.k_proj.weight",
                f"{layer_name}.self_attn.v_proj.weight",
            )

            _broadcast_tp_shard_tensor_qkv(
                sync_layer.self_attn.qkv_proj.bias if dst_pp_rank == pp_rank else None,
                f"{layer_name}.self_attn.q_proj.bias",
                f"{layer_name}.self_attn.k_proj.bias",
                f"{layer_name}.self_attn.v_proj.bias",
                bias=True,
            )

            _broadcast_tp_shard_tensor(
                sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None,
                f"{layer_name}.self_attn.o_proj.weight",
                chunk_dim=1,
            )

            _broadcast_tensor(
                sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None,
                f"{layer_name}.post_attention_layernorm.weight",
            )

            _broadcast_tp_shard_tensor_gate_up(
                sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None,
                f"{layer_name}.mlp.gate_proj.weight",
                f"{layer_name}.mlp.up_proj.weight",
            )

            _broadcast_tp_shard_tensor(
                sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None,
                f"{layer_name}.mlp.down_proj.weight",
                chunk_dim=1,
            )
        # Final Layernorm
        # -------------------
        print_rank_0("loading final layernorm...")
        gpt_model_module = _get_gpt_model(models[-1])
        _broadcast_tensor(
            getattr(gpt_model_module.model.norm, "weight", None),
            "model.norm.weight",
        )

        if tie_word_embeddings:
            print_rank_0("tie_word_embeddings skip load lm_head")
        else:
            print_rank_0("loading lm_head...")
            lm_head_weight = None
            if pp_rank + 1 == pp_size:
                lm_head_weight = gpt_model_module.lm_head.weight

            if is_value_model:
                if "lm_head.weight" in state_dict and state_dict["lm_head.weight"].shape[0] == 1:
                    _broadcast_tensor(lm_head_weight, "lm_head.weight")
                    print_rank_0("load lm_head from value_head weight")
                elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1:
                    _broadcast_tensor(lm_head_weight, "reward_head.weight")
                    print_rank_0("load lm_head from value_head weight")
                else:
                    _broadcast_tensor(None, "lm_head.weight")
                    print_rank_0("fail to match lm_head in value_model")

            else:
                _broadcast_tp_shard_tensor(lm_head_weight, "lm_head.weight")

    dist.barrier()
    # Broadcast weights inside data parallel groups
    for wrapped_model in wrapped_models:
        broadcast_params(wrapped_model)

    torch.cuda.empty_cache()
    print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s")
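The loader above and the saver that follows rely on the same parallel bookkeeping: ranks are laid out TP-fastest, then DP, then PP, and each global layer index is mapped to a (pp_rank, virtual_pp_rank, local layer index) triple. Below is a minimal, self-contained sketch of that mapping; it needs neither Megatron nor torch.distributed, and the toy sizes (tp=dp=pp=2, 8 layers, 2 virtual stages) and the calc_* helper names are illustrative only, not part of this commit.

# Standalone sketch of the TP-DP-PP bookkeeping used by the checkpoint utilities.
def calc_global_rank(tp_rank, dp_rank, pp_rank, tp_size=2, dp_size=2, pp_size=2):
    # Same formula as _megatron_calc_global_rank: TP varies fastest, then DP, then PP.
    return (pp_rank * dp_size + dp_rank) * tp_size + tp_rank

def calc_layer_map(num_hidden_layers=8, pp_size=2, virtual_pp_size=2):
    # Global layer index -> (pp_rank, virtual_pp_rank, local layer index),
    # mirroring _megatron_calc_layer_map with made-up sizes.
    num_layers_per_model = num_hidden_layers // pp_size // virtual_pp_size
    layer_map = {}
    for pp_rank in range(pp_size):
        for vpp_rank in range(virtual_pp_size):
            offset = vpp_rank * (num_hidden_layers // virtual_pp_size) + pp_rank * num_layers_per_model
            for local_idx in range(num_layers_per_model):
                layer_map[offset + local_idx] = (pp_rank, vpp_rank, local_idx)
    return layer_map

if __name__ == "__main__":
    print(calc_global_rank(tp_rank=1, dp_rank=0, pp_rank=1))  # -> 5
    print(calc_layer_map()[3])  # global layer 3 -> (1, 0, 1)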
verl/models/qwen2/megatron/checkpoint_utils/qwen2_saver.py
ADDED
@@ -0,0 +1,436 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time

import torch
import torch.distributed as dist
from megatron.core import mpu
from megatron.core.distributed import DistributedDataParallel as LocalDDP
from megatron.core.transformer.module import Float16Module
from torch.nn.parallel import DistributedDataParallel as torchDDP

from verl.utils.megatron_utils import print_rank_0, unwrap_model


def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0):
    """given TP,DP,PP rank to get the global rank."""

    tp_size = mpu.get_tensor_model_parallel_world_size()
    dp_size = mpu.get_data_parallel_world_size()
    pp_size = mpu.get_pipeline_model_parallel_world_size()
    assert tp_size * dp_size * pp_size == torch.distributed.get_world_size(), f"{tp_size} x {dp_size} x {pp_size} != {torch.distributed.get_world_size()}"
    # We only support TP-DP-PP grouping, for correctness when resharding
    return (pp_rank * dp_size + dp_rank) * tp_size + tp_rank

def _megatron_calc_layer_map(config):
    """Calculate the mapping of global layer_idx to local layer_idx
    Returns:
        layer_map (Dict: int -> tuple(int, int, int)):
            mapping from the global layer index to
            a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)
    """
    from megatron.core import mpu

    pp_size = mpu.get_pipeline_model_parallel_world_size()
    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1

    layer_map = dict()
    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers

    for pp_rank_idx in range(pp_size):
        for virtual_pp_rank_idx in range(virtual_pp_size):
            layer_offset = virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model
            for layer_idx in range(num_layers_per_model):
                layer_map[layer_offset + layer_idx] = (
                    pp_rank_idx,
                    virtual_pp_rank_idx,
                    layer_idx,
                )
    return layer_map

def merge_megatron_ckpt_qwen2(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False):
    """Merge sharded parameters of a Megatron module into a merged checkpoint.

    Args:
        wrapped_models (list of megatron.core.distributed.DistributedDataParallel):
            The local DDP wrapped megatron modules.
        config (str or None):
            HF config for model
        dtype: model params type
        is_value_model: if model is value model
        tie_word_embeddings: tie_word_embeddings
    Returns:
        state_dict (dict):
            The merged state_dict in rank 0, and an empty dictionary in other ranks.
    """
    start_time = time.time()

    def _get_gpt_model(model):
        return model

    dp_rank = mpu.get_data_parallel_rank()
    pp_size = mpu.get_pipeline_model_parallel_world_size()
    pp_rank = mpu.get_pipeline_model_parallel_rank()
    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
    mp_group = mpu.get_model_parallel_group()

    if dist.get_rank() == 0:
        assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0"
        assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0"
        assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0"

    if not isinstance(wrapped_models, (list, tuple)):
        wrapped_models = list(wrapped_models)

    assert len(wrapped_models) == virtual_pp_size
    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers

    models = [None] * len(wrapped_models)

    for i, wrapped_model in enumerate(wrapped_models):
        models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))
        assert len(models[i].model.layers) == num_layers_per_model, "len model layers {} not equal to num_layers_per_model {}".format(len(models[i].model.layers), num_layers_per_model)

    state_dict = dict()

    def _get_cpu_tensor(tensor: torch.Tensor):
        if tensor is None:
            return None
        if tensor.device == torch.device("cpu"):
            return tensor.detach().clone()
        return tensor.detach().cpu()

    def _broadcast_tensor(tensor, name, src_pp_rank) -> torch.Tensor:
        """broadcast tensor across mp_group"""
        nonlocal state_dict
        nonlocal mp_group
        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)

        if torch.distributed.get_rank() == src_rank:
            if tensor is None:
                weight = None
                tensor_shape = None
            else:
                weight = tensor
                tensor_shape = weight.shape
        else:
            weight = None
            tensor_shape = None

        obj_list = [tensor_shape]
        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
        tensor_shape = obj_list[0]

        if tensor_shape is None:
            # all or none ranks in the mp_group should reach here
            print_rank_0(f"tensor:[{name}] not exist, skip collect")
            return

        if weight is None:
            weight = torch.empty(
                tensor_shape,
                dtype=dtype,
                device=torch.cuda.current_device(),
                requires_grad=False,
            )

        dist.broadcast(weight, src=src_rank, group=mp_group)

        if torch.distributed.get_rank() == 0:
            state_dict[name] = _get_cpu_tensor(weight)

    def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_func=None) -> torch.Tensor:
        """broadcast tensor in tp shards across mp_group"""
        nonlocal state_dict
        nonlocal mp_group
        tp_size = mpu.get_tensor_model_parallel_world_size()
        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)

        chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None

        obj_list = [chunk_shape]
        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
        chunk_shape = obj_list[0]
        if chunk_shape is None:
            # all or none ranks in the mp_group should reach here
            print_rank_0(f"tp_shard tensor:[{name}] not exist, skip collecting")
            return

        buffer_tensor = torch.empty(
            chunk_shape,
            dtype=dtype,
            device=torch.cuda.current_device(),
            requires_grad=False,
        )

        chunk_tensors = [None] * tp_size

        for i in range(tp_size):
            cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank)
            sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor
            dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)

            if torch.distributed.get_rank() == 0:
                chunk_tensors[i] = _get_cpu_tensor(sync_tensor)

        if torch.distributed.get_rank() == 0:
            full_tensor = torch.concat(chunk_tensors, dim=concat_dim)
            if mutate_func is not None:
                full_tensor = mutate_func(full_tensor)
            state_dict[name] = full_tensor

    def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) -> torch.Tensor:
        """broadcast tensor in tp shards across mp_group"""
        nonlocal state_dict
        nonlocal mp_group
        tp_size = mpu.get_tensor_model_parallel_world_size()
        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)

        chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None

        obj_list = [chunk_shape]
        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
        chunk_shape = obj_list[0]
        if chunk_shape is None:
            # all or none ranks in the mp_group should reach here
            print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not exist, skip collecting")
            return

        buffer_tensor = torch.empty(
            chunk_shape,
            dtype=dtype,
            device=torch.cuda.current_device(),
            requires_grad=False,
        )

        chunk_tensors = [None] * tp_size

        for i in range(tp_size):
            cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank)
            sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor
            dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)

            if torch.distributed.get_rank() == 0:
                chunk_tensors[i] = _get_cpu_tensor(sync_tensor)

        if torch.distributed.get_rank() == 0:
            full_tensor = torch.concat(chunk_tensors, dim=0)
            intermediate_size_tp = config.intermediate_size // tp_size
            gate_weight_list = []
            up_weight_list = []
            for i in range(tp_size):
                gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)]
                gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp]
                up_weight_tp = gate_up_weight_tp[intermediate_size_tp:]
                gate_weight_list.append(gate_weight_tp)
                up_weight_list.append(up_weight_tp)

            state_dict[gate_name] = torch.cat(gate_weight_list, dim=0)
            state_dict[up_name] = torch.cat(up_weight_list, dim=0)

    def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank):
        """broadcast tensor in tp shards across mp_group"""
        nonlocal state_dict
        nonlocal mp_group
        tp_size = mpu.get_tensor_model_parallel_world_size()
        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)

        chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None

        obj_list = [chunk_shape]
        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
        chunk_shape = obj_list[0]
        if chunk_shape is None:
            # all or none ranks in the mp_group should reach here
            print_rank_0(f"tp_shard tensor:[{q_name}] not exist, skip collecting")
            return

        buffer_tensor = torch.empty(
            chunk_shape,
            dtype=dtype,
            device=torch.cuda.current_device(),
            requires_grad=False,
        )

        chunk_tensors = [None] * tp_size

        for i in range(tp_size):
            cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank)
            sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor
            dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)

            if torch.distributed.get_rank() == 0:
                chunk_tensors[i] = _get_cpu_tensor(sync_tensor)

        if torch.distributed.get_rank() == 0:
            full_tensor = torch.concat(chunk_tensors, dim=0)
            q_weight_list = []
            k_weight_list = []
            v_weight_list = []
            hidden_size_per_head = config.hidden_size // config.num_attention_heads

            if config.num_key_value_heads >= tp_size:
                q_size_tp = config.hidden_size // tp_size
                kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size
                total_size = q_size_tp + 2 * kv_size_tp
                for i in range(tp_size):
                    qkv_part = full_tensor[i * total_size : (i + 1) * total_size]
                    q_part = qkv_part[:q_size_tp]
                    k_part = qkv_part[q_size_tp : q_size_tp + kv_size_tp]
                    v_part = qkv_part[q_size_tp + kv_size_tp : total_size]
                    q_weight_list.append(q_part)
                    k_weight_list.append(k_part)
                    v_weight_list.append(v_part)
            else:
                q_size_tp = config.hidden_size // tp_size
                kv_size_tp = hidden_size_per_head
                total_size = q_size_tp + 2 * kv_size_tp
                for i in range(tp_size):
                    qkv_part = full_tensor[i * total_size : (i + 1) * total_size]
                    q_part = qkv_part[:q_size_tp]
                    k_part = qkv_part[q_size_tp : q_size_tp + kv_size_tp]
                    v_part = qkv_part[q_size_tp + kv_size_tp : total_size]
                    q_weight_list.append(q_part)
                    if i * config.num_key_value_heads % tp_size == 0:
                        k_weight_list.append(k_part)
                        v_weight_list.append(v_part)

            state_dict[q_name] = torch.cat(q_weight_list, dim=0)
            state_dict[k_name] = torch.cat(k_weight_list, dim=0)
            state_dict[v_name] = torch.cat(v_weight_list, dim=0)

    # empty cache before collecting weights
    torch.cuda.empty_cache()
    # Embeddings
    # -------------------
    if dp_rank == 0:
        # Embeddings
        # -------------------
        print_rank_0("collecting embeddings...")
        gpt_model_module = _get_gpt_model(models[0])
        _broadcast_tp_shard_tensor(
            gpt_model_module.model.embed_tokens.weight if pp_rank == 0 else None,
            "model.embed_tokens.weight",
            src_pp_rank=0,
        )

        # Transformer layers
        # -------------------
        layer_map = _megatron_calc_layer_map(config)
        for layer in range(config.num_hidden_layers):
            print_rank_0(f"collecting layer #{layer}...")
            layer_name = f"model.layers.{layer}"
            src_pp_rank, src_virtual_pp_rank, src_layer_idx = layer_map[layer]

            gpt_model_module = _get_gpt_model(models[src_virtual_pp_rank])
            sync_layer = gpt_model_module.model.layers[src_layer_idx]

            _broadcast_tensor(
                sync_layer.input_layernorm.weight,
                f"{layer_name}.input_layernorm.weight",
                src_pp_rank=src_pp_rank,
            )

            _broadcast_tp_shard_tensor_qkv(
                sync_layer.self_attn.qkv_proj.weight,
                f"{layer_name}.self_attn.q_proj.weight",
                f"{layer_name}.self_attn.k_proj.weight",
                f"{layer_name}.self_attn.v_proj.weight",
                src_pp_rank=src_pp_rank,
            )

            _broadcast_tp_shard_tensor_qkv(
                sync_layer.self_attn.qkv_proj.bias,
                f"{layer_name}.self_attn.q_proj.bias",
                f"{layer_name}.self_attn.k_proj.bias",
                f"{layer_name}.self_attn.v_proj.bias",
                src_pp_rank=src_pp_rank,
            )

            _broadcast_tp_shard_tensor(
                sync_layer.self_attn.o_proj.weight,
                f"{layer_name}.self_attn.o_proj.weight",
                concat_dim=1,
                src_pp_rank=src_pp_rank,
            )

            _broadcast_tensor(
                sync_layer.post_attention_layernorm.weight,
                f"{layer_name}.post_attention_layernorm.weight",
                src_pp_rank=src_pp_rank,
            )

            _broadcast_tp_shard_tensor_gate_up(
                sync_layer.mlp.gate_up_proj.weight,
                f"{layer_name}.mlp.gate_proj.weight",
                f"{layer_name}.mlp.up_proj.weight",
                src_pp_rank=src_pp_rank,
            )

            _broadcast_tp_shard_tensor(
                sync_layer.mlp.down_proj.weight,
                f"{layer_name}.mlp.down_proj.weight",
                concat_dim=1,
                src_pp_rank=src_pp_rank,
            )

        # Final Layernorm
        # -------------------
        print_rank_0("collecting final layernorm...")
        gpt_model_module = _get_gpt_model(models[-1])
        _broadcast_tensor(
            getattr(gpt_model_module.model.norm, "weight", None),
            "model.norm.weight",
            src_pp_rank=pp_size - 1,
        )

        if tie_word_embeddings:
            print_rank_0("tie word embedding skip load lm_head...")
        else:
            print_rank_0("collecting lm_head...")

            if is_value_model:
                _broadcast_tensor(
                    gpt_model_module.lm_head.weight if pp_rank == pp_size - 1 else None,
                    "lm_head.weight",
                    src_pp_rank=pp_size - 1,
                )
                _broadcast_tensor(
                    gpt_model_module.reward_head.weight if pp_rank == pp_size - 1 and getattr(gpt_model_module, "reward_weight", None) is not None else None,
                    "reward_head.weight",
                    src_pp_rank=pp_size - 1,
                )

            else:
                _broadcast_tp_shard_tensor(
                    getattr(gpt_model_module.lm_head, "weight", None) if pp_rank == pp_size - 1 else None,
                    "lm_head.weight",
                    src_pp_rank=pp_size - 1,
                )

    dist.barrier()

    torch.cuda.empty_cache()
    if torch.distributed.get_rank() == 0:
        for k, v in state_dict.items():
            if dtype != v.dtype:
                state_dict[k] = v.to(dtype)

    print_rank_0(f"merge megatron ckpt done, time elapsed {time.time() - start_time}s")
    return state_dict
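To make the fused gate/up layout handled by _broadcast_tp_shard_tensor_gate_up concrete: each tensor-parallel shard stores [gate_i; up_i] stacked along dim 0, the loader packs HF gate_proj/up_proj into that layout, and the saver slices it back apart. The following is a CPU-only sketch of that round trip using plain torch and hypothetical toy sizes; it is an illustration, not code from this commit.

# Toy round trip of the per-rank [gate_i; up_i] packing used above.
import torch

tp_size, intermediate_size, hidden_size = 2, 8, 4
inter_tp = intermediate_size // tp_size

gate = torch.arange(intermediate_size * hidden_size, dtype=torch.float32).reshape(intermediate_size, hidden_size)
up = gate + 1000

# Pack the HF weights the way the loader does: per rank, concat [gate_i; up_i].
fused = torch.cat(
    [torch.cat([gate[i * inter_tp:(i + 1) * inter_tp], up[i * inter_tp:(i + 1) * inter_tp]], dim=0) for i in range(tp_size)],
    dim=0,
)

# Unpack the way the saver does: split each rank's 2*inter_tp rows back into gate/up.
gate_parts, up_parts = [], []
for i in range(tp_size):
    shard = fused[2 * inter_tp * i: 2 * inter_tp * (i + 1)]
    gate_parts.append(shard[:inter_tp])
    up_parts.append(shard[inter_tp:])

assert torch.equal(torch.cat(gate_parts), gate) and torch.equal(torch.cat(up_parts), up)
print("gate/up round-trip OK")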
verl/models/qwen2/megatron/layers/__init__.py
ADDED
@@ -0,0 +1,20 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .parallel_attention import ParallelQwen2Attention
from .parallel_decoder import ParallelQwen2DecoderLayer, ParallelQwen2DecoderLayerRmPad
from .parallel_mlp import ParallelQwen2MLP
from .parallel_rmsnorm import ParallelQwen2RMSNorm

__all__ = ["ParallelQwen2Attention", "ParallelQwen2DecoderLayer", "ParallelQwen2DecoderLayerRmPad", "ParallelQwen2MLP", "ParallelQwen2RMSNorm"]
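The re-exports above make the parallel Qwen2 layers importable from a single package path. A purely illustrative usage line, assuming verl and its Megatron dependencies are installed and the Megatron parallel state has been initialized elsewhere:

# Illustrative only: importing the re-exported parallel layers.
from verl.models.qwen2.megatron.layers import ParallelQwen2DecoderLayer, ParallelQwen2RMSNorm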