diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..1a53453ab9aa9e36c85aba110789aa9440de1850 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +checkpoints/ \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..61865f73ec347515fab82ff764467b62f4f78b76 --- /dev/null +++ b/README.md @@ -0,0 +1,20 @@ +--- +title: Text-to-Text Generation with DeepSeek-R1-Distill-Qwen-7B +emoji: πŸš€ +colorFrom: blue +colorTo: green +sdk: gradio +sdk_version: 4.19.2 +app_file: app.py +pinned: false +--- + +# Text-to-Text Generation with DeepSeek-R1-Distill-Qwen-7B + +This is a simple Gradio interface for text-to-text generation using the `deepseek-ai/DeepSeek-R1-Distill-Qwen-7B` model. + +## How to use + +1. Enter a prompt in the text box. +2. Click the "Generate" button. +3. The model will generate a response in the "Response" text box. \ No newline at end of file diff --git a/_infer.py b/_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..dd884800d436c9f55b983eb4a4b667a126efb1f5 --- /dev/null +++ b/_infer.py @@ -0,0 +1,309 @@ +#!/usr/bin/env python +# Copyright 2024 Bytedance +# Apache-2.0 +# +# VERL + vLLM inference with runtime LoRA (no merge). +# - Wraps a LoRA .pt into a PEFT adapter and attaches via rollout.lora_modules +# - Mixed precision defaults for H100: dtype=bf16, kv_cache_dtype=fp8_e5m2 +# - Pins max_model_len, max_num_batched_tokens, sets swap_space +# - Uses OmegaConf.open_dict to add keys safely (no "not in struct" errors) +# - Prevents FSDP from trying to load LoRA .pt as a full model + +import os +import ast +import json +import hydra +import numpy as np +import ray +import torch +from pathlib import Path +from pprint import pprint + +# Quiet logs +os.environ["NCCL_DEBUG"] = os.environ.get("NCCL_DEBUG", "WARN") +os.environ["TOKENIZERS_PARALLELISM"] = os.environ.get("TOKENIZERS_PARALLELISM", "true") + +# vLLM CuMem allocator is incompatible with expandable_segments +_bad = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "") +if "expandable_segments:True" in _bad: + print(f"[fix] Removing incompatible PYTORCH_CUDA_ALLOC_CONF={_bad}") +os.environ.pop("PYTORCH_CUDA_ALLOC_CONF", None) + +import pandas as pd +from omegaconf import OmegaConf, open_dict + +from verl import DataProto +from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto +from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup +from verl.utils import hf_tokenizer +from verl.utils.fs import copy_to_local +from verl.utils.hdfs_io import makedirs +from verl.utils.model import compute_position_id_with_mask +from verl.workers.fsdp_workers import ActorRolloutRefWorker + +# ---------------- LoRA helpers ---------------- + +DEFAULT_TARGET_MODULES = [ + "q_proj","k_proj","v_proj","o_proj", + "up_proj","gate_proj","down_proj", +] + +def _infer_lengths_and_defaults(config): + """Ensure rollout/data keys exist and set reasonable H100 defaults.""" + # Ensure nested structs exist + with open_dict(config): + if "rollout" not in config: + config["rollout"] = OmegaConf.create() + if "data" not in config: + config["data"] = OmegaConf.create() + if "trainer" not in config: + config["trainer"] = OmegaConf.create() + if "ray_init" not in config: + config["ray_init"] = OmegaConf.create() + + # Defaults that work on a single H100 + with open_dict(config.rollout): + # If user didn't set these, choose H100-friendly defaults + config.rollout.setdefault("dtype", "bfloat16") # weights/activations + config.rollout.setdefault("kv_cache_dtype", "fp8_e5m2") # KV cache precision + config.rollout.setdefault("tensor_model_parallel_size", 1) + config.rollout.setdefault("enable_chunked_prefill", True) + config.rollout.setdefault("swap_space", 8) # GB of host swap for KV + config.rollout.setdefault("gpu_memory_utilization", 0.62) # adjust 0.60~0.75 if needed + + # Pin lengths to avoid vLLM over-reserving KV cache + pl = int(config.rollout.get("prompt_length", 1024)) + rl = int(config.rollout.get("response_length", 128)) + need = int(pl + rl) + config.rollout.setdefault("max_model_len", need) + config.rollout.setdefault("max_num_batched_tokens", need) + + # Users may pass +rollout.quantization={fp8|awq|gptq} to shrink weights further + # We don't force it here. + + with open_dict(config.data): + config.data.setdefault("batch_size", 1) + config.data.setdefault("n_samples", 1) + config.data.setdefault("prompt_key", "prompt") + + with open_dict(config.trainer): + config.trainer.setdefault("n_gpus_per_node", 1) + config.trainer.setdefault("nnodes", 1) + + with open_dict(config.ray_init): + config.ray_init.setdefault("num_cpus", 4) + +def _infer_lora_rank_from_state(sd): + for k, v in sd.items(): + if k.endswith("lora_A.weight") and hasattr(v, "dim") and v.dim() == 2: + return int(v.shape[0]) + return None + +def _list_target_modules_from_state(sd): + found = set() + for k in sd.keys(): + if "lora_A.weight" in k or "lora_B.weight" in k: + if ".q_proj." in k: found.add("q_proj") + if ".k_proj." in k: found.add("k_proj") + if ".v_proj." in k: found.add("v_proj") + if ".o_proj." in k: found.add("o_proj") + if ".up_proj." in k: found.add("up_proj") + if ".gate_proj." in k: found.add("gate_proj") + if ".down_proj." in k: found.add("down_proj") + return sorted(found) + +def _write_adapter_config(adapter_dir, r, alpha, target_modules, dropout=0.0): + cfg = { + "peft_type": "LORA", + "auto_mapping": None, + "base_model_name_or_path": "", + "bias": "none", + "inference_mode": True, + "lora_alpha": int(alpha), + "lora_dropout": float(dropout), + "r": int(r), + "target_modules": target_modules, + "task_type": "CAUSAL_LM", + } + with open(os.path.join(adapter_dir, "adapter_config.json"), "w", encoding="utf-8") as f: + json.dump(cfg, f, ensure_ascii=False, indent=2) + +def _wrap_lora_pt_as_peft(adapter_pt_path: str, out_dir: str, + fallback_rank=32, fallback_alpha=16): + os.makedirs(out_dir, exist_ok=True) + print(f"[lora] Loading LoRA state from: {adapter_pt_path}") + sd = torch.load(adapter_pt_path, map_location="cpu") + if isinstance(sd, dict) and "state_dict" in sd: + sd = sd["state_dict"] + + r = _infer_lora_rank_from_state(sd) or int(fallback_rank) + tmods = _list_target_modules_from_state(sd) or DEFAULT_TARGET_MODULES + print(f"[lora] inferred rank={r}, target_modules={tmods}") + + _write_adapter_config(out_dir, r=r, alpha=fallback_alpha, target_modules=tmods) + torch.save(sd, os.path.join(out_dir, "adapter_model.bin")) + return r, tmods + +def _maybe_attach_lora_adapter(config): + """Attach LoRA adapter directory to vLLM rollout (runtime LoRA).""" + # Accept either +lora.pt_path or model.load_param_path as a hint + lora_pt = None + if "lora" in config and getattr(config.lora, "pt_path", ""): + lora_pt = config.lora.pt_path + elif getattr(config.model, "load_param_path", ""): + lora_pt = config.model.load_param_path + + if not lora_pt or not Path(lora_pt).is_file(): + print("[lora] No LoRA .pt provided; running base model only.") + return + + adapter_dir = os.path.join("/tmp", "lora_adapter_vllm") + r, _ = _wrap_lora_pt_as_peft(lora_pt, adapter_dir, fallback_rank=32, fallback_alpha=16) + + # Ensure rollout keys exist and add LoRA knobs required by vLLM + with open_dict(config): + if "rollout" not in config: + config["rollout"] = OmegaConf.create() + with open_dict(config.rollout): + config.rollout.setdefault("max_loras", 1) + config.rollout.setdefault("max_lora_rank", int(r)) + config.rollout["lora_modules"] = [{"path": adapter_dir, "scale": 1.0}] + print(f"[lora] Attached PEFT adapter: {adapter_dir} (rank={r})") + + # CRITICAL: don't let FSDP try to load the LoRA .pt as a full state dict + with open_dict(config.model): + if getattr(config.model, "load_param", False): + print("[lora] Disabling model.load_param to avoid FSDP load_state_dict mismatch.") + config.model["load_param"] = False + +# ---------------- Hydra entry ---------------- + +@hydra.main(config_path="config", config_name="infer", version_base=None) +def main(config): + _infer_lengths_and_defaults(config) + + # Ray env for workers + if not ray.is_initialized(): + ray.init( + runtime_env={"env_vars": { + "TOKENIZERS_PARALLELISM": "true", + "NCCL_DEBUG": "WARN", + "PYTORCH_CUDA_ALLOC_CONF": "", # keep allocator happy for vLLM + }}, + num_cpus=config.ray_init.num_cpus, + ) + + ray.get(main_task.remote(config)) + +@ray.remote(num_cpus=1) +def main_task(config): + print("[worker] PYTORCH_CUDA_ALLOC_CONF =", os.environ.get("PYTORCH_CUDA_ALLOC_CONF")) + pprint(OmegaConf.to_container(config, resolve=True)) + OmegaConf.resolve(config) + + # Build LoRA adapter if provided + _maybe_attach_lora_adapter(config) + + # Optionally pre-gen dataset schema if your repo provides it + try: + from prompts.infer_prompt import infer_dataset + infer_dataset( + model_name=config.model.path, + data_path=os.path.dirname(os.path.dirname(config.data.path)), + ) + except Exception as e: + print(f"[info] infer_dataset() skipped: {e}") + + # ---- Tokenizer from base model + local_path = copy_to_local(config.model.path) + trust_remote_code = getattr(config.model, "trust_remote_code", False) + tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) + tokenizer.padding_side = "left" + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + # ---- Sampling checks + if float(config.rollout.temperature) == 0.0: + assert int(config.data.n_samples) == 1, "When temperature=0, n_samples must be 1." + assert int(config.data.n_samples) >= 1, "n_samples should always >= 1" + + # ---- Load dataset + dataset = pd.read_parquet(config.data.path) + prompt_key = getattr(config.data, "prompt_key", "prompt") + if prompt_key not in dataset.columns: + raise KeyError(f"Dataset missing column '{prompt_key}'") + chat_lst = dataset[prompt_key].tolist() + chat_lst = [chat.tolist() if hasattr(chat, "tolist") else chat for chat in chat_lst] + + # ---- Worker group (vLLM inside Rollout) + ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorRolloutRefWorker), config=config, role="rollout") + resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes) + print("[debug] rollout.lora_modules =", config.rollout.get("lora_modules", None)) + wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init) + wg.init_model() # vLLM spins up; adapter used if set in rollout.lora_modules + + total = len(dataset) + bs = int(config.data.batch_size) + num_batch = -(-total // bs) + slots = [[] for _ in range(int(config.data.n_samples))] + + for b in range(num_batch): + print(f"[{b+1}/{num_batch}] Start to process.") + batch_chat = chat_lst[b * bs : (b + 1) * bs] + + inputs = tokenizer.apply_chat_template( + batch_chat, + add_generation_prompt=True, + padding=True, + truncation=True, + max_length=int(config.rollout.prompt_length), + return_tensors="pt", + return_dict=True, + tokenize=True, + ) + input_ids = inputs["input_ids"] + attention_mask = inputs["attention_mask"] + position_ids = compute_position_id_with_mask(attention_mask) + batch_dict = {"input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids} + + data = DataProto.from_dict(batch_dict) + data_padded, pad_size = pad_dataproto_to_divisor(data, wg.world_size) + + print(f"[{b+1}/{num_batch}] Start to generate.") + for n in range(int(config.data.n_samples)): + output_padded = wg.generate_sequences(data_padded) + output = unpad_dataproto(output_padded, pad_size=pad_size) + texts = [] + for i in range(len(output)): + item = output[i] + pl = item.batch["prompts"].shape[-1] + valid_len = item.batch["attention_mask"][pl:].sum() + resp_ids = item.batch["responses"][:valid_len] + s = tokenizer.decode(resp_ids, skip_special_tokens=True) + print(f"[raw] Response {i}: {s!r}") + ix = s.find("") + if ix != -1: + s = s[ix + len("") :].lstrip() + print(f"Response {i}: {s!r}") + try: + texts.append(ast.literal_eval(s)) + except Exception: + texts.append(s) + slots[n].extend(texts) + + outputs = np.array(slots, dtype=object) + outputs = np.transpose(outputs, (1, 0)).tolist() + dataset["response"] = outputs + + keep = ["file_id", "vt", "gt", "response"] + cols = [c for c in keep if c in dataset.columns] + if cols: + dataset = dataset[cols] + + out_path = config.data.output_path + makedirs(os.path.dirname(out_path), exist_ok=True) + dataset.to_json(out_path, orient="records", lines=True, force_ascii=False) + print(f"[done] Wrote: {out_path}") + +if __name__ == "__main__": + main() diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..91f79c2b532f4f766f04b6c64612d2470e9dda8c --- /dev/null +++ b/app.py @@ -0,0 +1,43 @@ + +import gradio as gr +from transformers import AutoTokenizer, AutoModelForCausalLM +import torch + +# Load the model and tokenizer +tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", trust_remote_code=True) +model = AutoModelForCausalLM.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="auto") + +def generate_response(prompt): + """ + Generates a response from the model. + """ + inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + outputs = model.generate(**inputs, max_new_tokens=512) + + # Decode the generated text + generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) + + return generated_text + +# Create the Gradio interface +with gr.Blocks() as demo: + gr.Markdown("# Text-to-Text Generation with DeepSeek-R1-Distill-Qwen-7B") + gr.Markdown("Enter a prompt and the model will generate a response.") + + with gr.Row(): + prompt_input = gr.Textbox(label="Prompt", lines=4, placeholder="Enter your prompt here...") + + with gr.Row(): + generate_button = gr.Button("Generate") + + with gr.Row(): + response_output = gr.Textbox(label="Response", lines=8, interactive=False) + + generate_button.click( + fn=generate_response, + inputs=prompt_input, + outputs=response_output + ) + +if __name__ == "__main__": + demo.launch() diff --git a/config/infer.yaml b/config/infer.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b6f4e610e901b912a4a0153425eda0fbcd6e44dc --- /dev/null +++ b/config/infer.yaml @@ -0,0 +1,56 @@ +trainer: + nnodes: 1 + n_gpus_per_node: 1 + +data: + path: ./data/parquet/test.parquet + prompt_key: prompt + n_samples: 1 + output_path: ./checkpoints/grammar_generation.parquet + batch_size: 1 + +model: + path: deepseek-ai/DeepSeek-R1-Distill-Qwen-7B + external_lib: null + load_param: False + load_param_path: null + +rollout: + name: vllm + mode: sync # sync: LLM, async: AsyncLLM + temperature: 0.0 + top_k: -1 # 0 for hf rollout, -1 for vllm rollout + top_p: 1.0 + max_loras: 1 + prompt_length: 1800 + response_length: 512 + # for vllm rollout + dtype: bfloat16 # should align with FSDP + gpu_memory_utilization: 0.9 # ↑ allow cache to allocate + ignore_eos: False + enforce_eager: True + free_cache_engine: True + load_format: dummy_dtensor + tensor_model_parallel_size: 1 + max_num_batched_tokens: 8192 + max_model_len: 1800 # β‰₯ 1200 + 512 + max_num_seqs: 1024 + log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu + log_prob_micro_batch_size_per_gpu: 1 + # for fire vllm rollout + use_fire_sampling: False # enable FIRE https://arxiv.org/abs/2410.21236 + # for hf rollout + do_sample: True + disable_log_stats: False + enable_chunked_prefill: True # OK because 8192 β‰₯ 3072 + n: 1 + # if beam search activated, top_k, temperature and top_p will be ignored + +actor: + strategy: fsdp # This is for backward-compatibility + ulysses_sequence_parallel_size: 1 # sp size + fsdp_config: + fsdp_size: -1 + +ray_init: + num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then. diff --git a/prompts/__pycache__/base_instruction.cpython-311.pyc b/prompts/__pycache__/base_instruction.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5750bd44633a15b3fa2be52158fce1b498ec0d32 Binary files /dev/null and b/prompts/__pycache__/base_instruction.cpython-311.pyc differ diff --git a/prompts/__pycache__/infer_prompt.cpython-311.pyc b/prompts/__pycache__/infer_prompt.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7aae43b3e972bed54dc6eb88e53700946be2a57 Binary files /dev/null and b/prompts/__pycache__/infer_prompt.cpython-311.pyc differ diff --git a/prompts/__pycache__/sft_prompt.cpython-311.pyc b/prompts/__pycache__/sft_prompt.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6ccbd2fc6520f6199c7757a7c0091ff905286b1 Binary files /dev/null and b/prompts/__pycache__/sft_prompt.cpython-311.pyc differ diff --git a/prompts/base_instruction.py b/prompts/base_instruction.py new file mode 100644 index 0000000000000000000000000000000000000000..7e5cdce19d34cd6b47fc2df42599112e6ab70672 --- /dev/null +++ b/prompts/base_instruction.py @@ -0,0 +1,24 @@ +def basic_instruction(content, modelname): + system_instruction = ( + "당신은 ν•œκ΅­μ–΄ λ¬Έμž₯ ꡐ정 μ „λ¬Έκ°€μž…λ‹ˆλ‹€. " + "μž…λ ₯ λ¬Έμž₯은 λ‹€μ–‘ν•œ 였λ₯˜(자λͺ¨ 뢄리, 철자 였λ₯˜, 단어 λˆ„λ½ λ“±)λ₯Ό 포함할 수 μžˆμŠ΅λ‹ˆλ‹€. " + "λ‹Ήμ‹ μ˜ μž„λ¬΄λŠ” μ΄λŸ¬ν•œ 잘λͺ»λœ λ¬Έμž₯을 μ™„μ „ν•˜κ³  μ˜¬λ°”λ₯Έ ν•œκ΅­μ–΄ λ¬Έμž₯으둜 λ³΅μ›ν•˜λŠ” κ²ƒμž…λ‹ˆλ‹€.\n" + "κ·œμΉ™:\n" + "β€’μΆœλ ₯은 λ°˜λ“œμ‹œ κ΅μ •λœ ν•œκ΅­μ–΄ λ¬Έμž₯만 μž‘μ„±ν•©λ‹ˆλ‹€.\n" + "β€’λΆˆν•„μš”ν•œ μ„€λͺ…, 이유, λ”°μ˜΄ν‘œλŠ” ν¬ν•¨ν•˜μ§€ μ•ŠμŠ΅λ‹ˆλ‹€.\n" + ) + + user_instruction = ( + f"잘λͺ»λœ λ¬Έμž₯(λ…Έμ΄μ¦ˆ): {content}\n\n" + "μœ„ λ¬Έμž₯을 μ˜¬λ°”λ₯Έ ν•œκ΅­μ–΄ λ¬Έμž₯으둜 κ΅μ •ν•˜μ„Έμš”.\n" + "좜λ ₯은 λ°˜λ“œμ‹œ κ΅μ •λœ λ¬Έμž₯ ν•˜λ‚˜λ§Œ μž‘μ„±ν•˜μ„Έμš”." + ) + + return [ + {"role": "system", "content": system_instruction}, + {"role": "user", "content": user_instruction}, + ] + + +def get_instruction_func(modelname): + return lambda desc, _: basic_instruction(desc, modelname) diff --git a/prompts/infer_prompt.py b/prompts/infer_prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..85661050301b4ff1ac050519ea6d63507cfda60f --- /dev/null +++ b/prompts/infer_prompt.py @@ -0,0 +1,101 @@ +from prompts.base_instruction import get_instruction_func + +def infer_dataset( + model_name: str, + data_path: str, + ): + import os, json + from typing import Any, Dict, List + from datasets import Dataset + from transformers import AutoTokenizer + + MAX_TOKENS = 1200 # same as SFT + + jsonl_path = os.path.join(data_path, "jsonl") + parquet_path = os.path.join(data_path, "parquet") + os.makedirs(parquet_path, exist_ok=True) + + test_jsonl = os.path.join(jsonl_path, "test.jsonl") + + # --- robust load: tolerant JSONL/array/concatenated JSON + rows = [] + with open(test_jsonl, "r", encoding="utf-8") as f: + raw = f.read().strip() + + try: + obj = json.loads(raw) + if isinstance(obj, list): + rows = [x for x in obj if isinstance(x, dict)] + except Exception: + pass + + if not rows: + for ln in raw.replace("}{", "}\n{").splitlines(): + ln = ln.strip() + if not ln: + continue + try: + x = json.loads(ln) + if isinstance(x, dict): + rows.append(x) + except Exception: + continue + + test_dataset = Dataset.from_list(rows) + + instruction_func = get_instruction_func(model_name) + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + + # ─── helpers ─── + def _coerce(rec: Dict[str, Any]) -> Dict[str, Any]: + r = dict(rec) + r["vt"] = str(r.get("vt", "") or "") + return r + + def _prompt_tokens(prompt_messages) -> int: + prompt_str = tokenizer.apply_chat_template( + prompt_messages, add_generation_prompt=True, tokenize=False + ) + return len(tokenizer(prompt_str, add_special_tokens=False).input_ids) + + def make_map_fn(split: str): + def process_fn(example: Dict[str, Any], idx: int) -> Dict[str, Any]: + ex = _coerce(example) + vt = ex.get("vt", "").strip() + if not vt: + return {} + + chat_prompt = instruction_func(vt, model_name) + total_tokens = _prompt_tokens(chat_prompt) + + extra = { + "split": split, + "index": idx, + "total_tokens": int(total_tokens), + "file_id": ex.get("file_id") + } + + return { + "prompt": chat_prompt, + "extra_info": extra, + "total_tokens": int(total_tokens) + } + return process_fn + + # build prompts + token counts + test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True) + + # drop rows where prompt is empty + test_dataset = test_dataset.filter(lambda ex: bool(ex.get("prompt"))) + + # drop long prompts (> MAX_TOKENS) + n_before_len = len(test_dataset) + test_dataset = test_dataset.filter(lambda ex: ex["total_tokens"] <= MAX_TOKENS) + kept = len(test_dataset) + dropped_long = n_before_len - kept + + out_path = os.path.join(parquet_path, "test.parquet") + test_dataset.to_parquet(out_path) + + print(f"[test] kept {kept} rows, dropped_long(>{MAX_TOKENS}) {dropped_long}") + print(f"Wrote {kept} rows β†’ {out_path}") diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..ae0267d3b329808fb6cb5e298f5dfb8febba3ea7 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,4 @@ +gradio +transformers +torch +accelerate \ No newline at end of file diff --git a/requirements.txt.txt b/requirements.txt.txt new file mode 100644 index 0000000000000000000000000000000000000000..4a64143fb1cfd82f321640591db09a62393f3cb7 --- /dev/null +++ b/requirements.txt.txt @@ -0,0 +1,200 @@ +absl-py==2.3.1 +accelerate==1.6.0 +aiohappyeyeballs==2.6.1 +aiohttp==3.11.16 +aiohttp-cors==0.8.1 +aiosignal==1.3.2 +airportsdata==20250224 +alabaster==1.0.0 +annotated-types==0.7.0 +antlr4-python3-runtime==4.9.3 +anyio==4.9.0 +astor==0.8.1 +attrs==25.3.0 +babel==2.17.0 +blake3==1.0.4 +cachetools==5.5.2 +certifi==2025.1.31 +charset-normalizer==3.4.1 +click==8.1.8 +cloudpickle==3.1.1 +codetiming==1.4.0 +colorful==0.5.6 +compressed-tensors==0.9.2 +cupy-cuda12x==13.4.1 +datasets==3.5.0 +depyf==0.18.0 +dill==0.3.8 +diskcache==5.6.3 +distlib==0.3.9 +distro==1.9.0 +dnspython==2.7.0 +docker-pycreds==0.4.0 +docutils==0.21.2 +einops==0.8.1 +email_validator==2.2.0 +fastapi==0.115.12 +fastapi-cli==0.0.7 +fastrlock==0.8.3 +filelock==3.18.0 +flash_attn==2.7.4.post1 +flashinfer-python==0.2.5 +frozenlist==1.5.0 +fsspec==2024.6.1 +gguf==0.10.0 +gitdb==4.0.12 +GitPython==3.1.44 +google-api-core==2.24.2 +google-auth==2.38.0 +googleapis-common-protos==1.69.2 +grpcio==1.71.0 +h11==0.14.0 +httpcore==1.0.8 +httptools==0.6.4 +httpx==0.28.1 +huggingface-hub==0.30.2 +hydra-core==1.3.2 +idna==3.10 +imagesize==1.4.1 +importlib_metadata==8.6.1 +interegular==0.3.3 +Jinja2==3.1.6 +jiter==0.9.0 +jiwer==4.0.0 +joblib==1.5.2 +jsonschema==4.23.0 +jsonschema-specifications==2024.10.1 +lark==1.2.2 +Levenshtein==0.27.1 +liger_kernel==0.5.9 +llguidance==0.7.14 +llvmlite==0.43.0 +lm-format-enforcer==0.10.11 +markdown-it-py==3.0.0 +MarkupSafe==2.1.5 +mdurl==0.1.2 +mistral_common==1.5.4 +mpmath==1.3.0 +msgpack==1.1.0 +msgspec==0.19.0 +multidict==6.4.3 +multiprocess==0.70.16 +nest-asyncio==1.6.0 +networkx==3.3 +ninja==1.11.1.4 +nltk==3.9.1 +numba==0.60.0 +numpy==1.26.4 +nvidia-cublas-cu12==12.4.5.8 +nvidia-cuda-cupti-cu12==12.4.127 +nvidia-cuda-nvrtc-cu12==12.4.127 +nvidia-cuda-runtime-cu12==12.4.127 +nvidia-cudnn-cu12==9.1.0.70 +nvidia-cufft-cu12==11.2.1.3 +nvidia-curand-cu12==10.3.5.147 +nvidia-cusolver-cu12==11.6.1.9 +nvidia-cusparse-cu12==12.3.1.170 +nvidia-cusparselt-cu12==0.6.2 +nvidia-ml-py==12.570.86 +nvidia-nccl-cu12==2.21.5 +nvidia-nvjitlink-cu12==12.4.127 +nvidia-nvtx-cu12==12.4.127 +omegaconf==2.3.0 +openai==1.73.0 +opencensus==0.11.4 +opencensus-context==0.1.3 +opencv-python-headless==4.11.0.86 +orjson==3.10.16 +outlines==0.1.11 +outlines_core==0.1.26 +packaging==24.2 +pandas==2.2.3 +partial-json-parser==0.2.1.1.post5 +peft==0.15.1 +pillow==11.2.1 +platformdirs==4.3.7 +prometheus-fastapi-instrumentator==7.1.0 +prometheus_client==0.21.1 +propcache==0.3.1 +proto-plus==1.26.1 +protobuf==5.29.4 +psutil==7.0.0 +py-cpuinfo==9.0.0 +py-spy==0.4.0 +pyarrow==19.0.1 +pyasn1==0.6.1 +pyasn1_modules==0.4.2 +pybind11==2.13.6 +pycountry==24.6.1 +pydantic==2.11.3 +pydantic_core==2.33.1 +Pygments==2.19.1 +pylatexenc==2.10 +python-dateutil==2.9.0.post0 +python-dotenv==1.1.0 +python-json-logger==3.3.0 +python-Levenshtein==0.27.1 +python-multipart==0.0.20 +pytz==2025.2 +PyYAML==6.0.2 +pyzmq==26.4.0 +RapidFuzz==3.14.1 +ray==2.44.1 +referencing==0.36.2 +regex==2024.11.6 +requests==2.32.3 +rich==14.0.0 +rich-toolkit==0.14.1 +roman-numerals-py==3.1.0 +rouge_score==0.1.2 +rpds-py==0.24.0 +rsa==4.9 +safetensors==0.5.3 +scipy==1.15.2 +sentencepiece==0.2.0 +sentry-sdk==2.25.1 +setproctitle==1.3.5 +shellingham==1.5.4 +six==1.17.0 +smart-open==7.1.0 +smmap==5.0.2 +sniffio==1.3.1 +snowballstemmer==2.2.0 +Sphinx==8.2.3 +sphinxcontrib-applehelp==2.0.0 +sphinxcontrib-devhelp==2.0.0 +sphinxcontrib-htmlhelp==2.1.0 +sphinxcontrib-jsmath==1.0.1 +sphinxcontrib-qthelp==2.0.0 +sphinxcontrib-serializinghtml==2.0.0 +starlette==0.46.1 +sympy==1.13.1 +tensordict==0.6.2 +tiktoken==0.9.0 +timeout-decorator==0.5.0 +tokenizers==0.21.1 +torch==2.6.0 +torchaudio==2.6.0 +torchdata==0.11.0 +torchvision==0.21.0 +tqdm==4.67.1 +transformers==4.51.2 +triton==3.2.0 +typer==0.15.2 +typing-inspection==0.4.0 +typing_extensions==4.12.2 +tzdata==2025.2 +urllib3==2.4.0 +uvicorn==0.34.0 +uvloop==0.21.0 +virtualenv==20.30.0 +vllm==0.8.2 +wandb==0.19.9 +watchfiles==1.0.5 +websockets==15.0.1 +wrapt==1.17.2 +xformers==0.0.29.post2 +xgrammar==0.1.16 +xxhash==3.5.0 +yarl==1.19.0 +zipp==3.21.0 diff --git a/scripts/sft_infer_pass1.sh b/scripts/sft_infer_pass1.sh new file mode 100644 index 0000000000000000000000000000000000000000..f7047c7ccd7a75e4c88bb1c28e70bd6ef92a6d06 --- /dev/null +++ b/scripts/sft_infer_pass1.sh @@ -0,0 +1,33 @@ +#!/bin/bash + +#!/bin/bash +set -x + +python ./_infer.py \ + model.path=./checkpoints/model \ + model.load_param=False \ + data.path=./data/parquet/test.parquet \ + data.output_path=./model_output/sft_pass@1.jsonl \ + data.batch_size=32 data.n_samples=1 \ + rollout.tensor_model_parallel_size=1 \ + rollout.temperature=0.7 rollout.top_p=0.9 rollout.n=1 rollout.do_sample=True \ + rollout.prompt_length=1200 rollout.response_length=512 \ + rollout.enable_chunked_prefill=True \ + +rollout.kv_cache_dtype=fp8_e5m2 \ + rollout.max_model_len=1800 \ + rollout.max_num_batched_tokens=1800 \ + rollout.max_num_seqs=1 \ + +model.trust_remote_code=True \ + +rollout.kv_cache_block_size=16 \ + +rollout.swap_space=16 \ + rollout.gpu_memory_utilization=0.7 + +# python ./_infer.py \ +# model.load_param=True \ +# model.load_param_path="./checkpoints/merged_r1qwen14b/model.pt" \ +# data.output_path="./model_output/sft_pass@1.jsonl" \ +# data.n_samples=10\ +# data.path="./data/parquet/test.parquet" \ +# rollout.temperature=0.9\ +# rollout.top_p=0.9 \ +# rollout.n=1\ \ No newline at end of file diff --git a/verl/__init__.py b/verl/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..51fbed637e13b2972521d44397e979b74d75d525 --- /dev/null +++ b/verl/__init__.py @@ -0,0 +1,40 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os + +from .protocol import DataProto +from .utils.logging_utils import set_basic_config + +version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__))) + +with open(os.path.join(version_folder, "version/version")) as f: + __version__ = f.read().strip() + + +set_basic_config(level=logging.WARNING) + + +__all__ = ["DataProto", "__version__"] + +if os.getenv("VERL_USE_MODELSCOPE", "False").lower() == "true": + import importlib + + if importlib.util.find_spec("modelscope") is None: + raise ImportError("You are using the modelscope hub, please install modelscope by `pip install modelscope -U`") + # Patch hub to download models from modelscope to speed up. + from modelscope.utils.hf_util import patch_hub + + patch_hub() diff --git a/verl/__pycache__/__init__.cpython-311.pyc b/verl/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71dafbf2b068353bef3932c3dff3a3ebfc84f613 Binary files /dev/null and b/verl/__pycache__/__init__.cpython-311.pyc differ diff --git a/verl/__pycache__/protocol.cpython-311.pyc b/verl/__pycache__/protocol.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c8f79feb587d0a695bccb964058926acaa50cde Binary files /dev/null and b/verl/__pycache__/protocol.cpython-311.pyc differ diff --git a/verl/models/README.md b/verl/models/README.md new file mode 100644 index 0000000000000000000000000000000000000000..2930e1c7c7cf0477ee61220d1c23acc8afd3bc0c --- /dev/null +++ b/verl/models/README.md @@ -0,0 +1,35 @@ +# Models +Common modelzoo such as huggingface/transformers stuggles when using Pytorch native model parallelism. Following the design principle of vLLM, we keep a simple, parallelizable, highly-optimized with packed inputs in verl. +## Adding a New Huggingface Model +### Step 1: Copy the model file from HF to verl +- Add a new file under verl/models/hf +- Copy ONLY the model file from huggingface/transformers/models to verl/models/hf + +### Step 2: Modify the model file to use packed inputs +- Remove all the code related to inference (kv cache) +- Modify the inputs to include only + - input_ids (total_nnz,) + - cu_seqlens (total_nnz + 1,) + - max_seqlen_in_batch: int +- Note that this requires using flash attention with causal mask. + +### Step 2.5: Add tests +- Add a test to compare this version and the huggingface version +- Following the infrastructure and add tests to tests/models/hf + +### Step 3: Add a function to apply tensor parallelism +- Please follow + - https://pytorch.org/docs/stable/distributed.tensor.parallel.html + - https://pytorch.org/tutorials/intermediate/TP_tutorial.html +- General comments + - Tensor Parallelism in native Pytorch is NOT auto-parallelism. The way it works is to specify how model parameters and input/output reshards using configs. These configs are then registered as hooks to perform input/output resharding before/after model forward. + +### Step 4: Add a function to apply data parallelism +- Please use FSDP2 APIs +- See demo here https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py#L413 + +### Step 5: Add a function to apply pipeline parallelism +- Comes in Pytorch 2.4 +- Currently only in alpha in nightly version +- Check torchtitan for more details + diff --git a/verl/models/__init__.py b/verl/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c6aedde1437eddbe099b730519ef79980e2e3dc4 --- /dev/null +++ b/verl/models/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/verl/models/__pycache__/__init__.cpython-311.pyc b/verl/models/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7415bec3138d2431b5947762625380c54b90272f Binary files /dev/null and b/verl/models/__pycache__/__init__.cpython-311.pyc differ diff --git a/verl/models/__pycache__/registry.cpython-311.pyc b/verl/models/__pycache__/registry.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02284117d1f1ef05c6049ab12c4774d8e46cccd7 Binary files /dev/null and b/verl/models/__pycache__/registry.cpython-311.pyc differ diff --git a/verl/models/llama/__init__.py b/verl/models/llama/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c6aedde1437eddbe099b730519ef79980e2e3dc4 --- /dev/null +++ b/verl/models/llama/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/verl/models/llama/megatron/__init__.py b/verl/models/llama/megatron/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..104c9c50351530677d53711dc35c64e7c802bb43 --- /dev/null +++ b/verl/models/llama/megatron/__init__.py @@ -0,0 +1,34 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .modeling_llama_megatron import ( + ParallelLlamaForCausalLM, + # rmpad with megatron + ParallelLlamaForCausalLMRmPad, + # rmpad with megatron and pipeline parallelism + ParallelLlamaForCausalLMRmPadPP, + ParallelLlamaForValueRmPad, + ParallelLlamaForValueRmPadPP, + # original model with megatron + ParallelLlamaModel, +) + +__all__ = [ + "ParallelLlamaForCausalLM", + "ParallelLlamaForCausalLMRmPad", + "ParallelLlamaForCausalLMRmPadPP", + "ParallelLlamaForValueRmPad", + "ParallelLlamaForValueRmPadPP", + "ParallelLlamaModel", +] diff --git a/verl/models/llama/megatron/checkpoint_utils/__init__.py b/verl/models/llama/megatron/checkpoint_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c6aedde1437eddbe099b730519ef79980e2e3dc4 --- /dev/null +++ b/verl/models/llama/megatron/checkpoint_utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/verl/models/llama/megatron/checkpoint_utils/llama_loader.py b/verl/models/llama/megatron/checkpoint_utils/llama_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..940aff6b1ac164192dab6be9804d9c06c30c6fb8 --- /dev/null +++ b/verl/models/llama/megatron/checkpoint_utils/llama_loader.py @@ -0,0 +1,295 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import torch +import torch.distributed as dist + + +def _megatron_calc_layer_map(config): + """Calculate the mapping of global layer_idx to local layer_idx + Returns: + layer_map (Dict: int -> tuple(int, int, int)): + mapping from the global layer index to + a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model) + """ + from megatron.core import mpu + + print(f"get megatron data parallel size: {mpu.get_data_parallel_world_size()}") + + pp_size = mpu.get_pipeline_model_parallel_world_size() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + + layer_map = dict() + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers + + for pp_rank_idx in range(pp_size): + for virtual_pp_rank_idx in range(virtual_pp_size): + layer_offset = virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + for layer_idx in range(num_layers_per_model): + layer_map[layer_offset + layer_idx] = ( + pp_rank_idx, + virtual_pp_rank_idx, + layer_idx, + ) + return layer_map + + +def load_state_dict_to_megatron_llama(state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False): + """Load merged state_dict to sharded Megatron module in training.""" + from megatron.core import DistributedDataParallel as LocalDDP + from megatron.core import mpu + from megatron.core.transformer.module import Float16Module + from torch.nn.parallel import DistributedDataParallel as torchDDP + + from verl.utils.megatron_utils import print_rank_0, unwrap_model + + start_time = time.time() + + def _get_gpt_model(model): + return model + + def fetch_params(module): + for param in module.parameters(): + torch.distributed.fetch(param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group()) + + dp_rank = mpu.get_data_parallel_rank() + pp_rank = mpu.get_pipeline_model_parallel_rank() + pp_size = mpu.get_pipeline_model_parallel_world_size() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + mp_group = mpu.get_model_parallel_group() + + if torch.distributed.get_rank() == 0: + assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0" + assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" + assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" + + if not isinstance(wrapped_models, (list, tuple)): + wrapped_models = list(wrapped_models) + + assert len(wrapped_models) == virtual_pp_size + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, f"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size {virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}" + + models = [None] * len(wrapped_models) + + for i, wrapped_model in enumerate(wrapped_models): + models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) + gpt_model_module = _get_gpt_model(models[i]) + assert len(gpt_model_module.model.layers) == num_layers_per_model + + def _fetch_tensor(tensor, name) -> torch.Tensor: + """fetch tensor""" + nonlocal state_dict + if tensor is not None: + tensor.data.copy_(state_dict[name]) + + def _fetch_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: + """fetch tensor in tp shards""" + nonlocal state_dict + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + if name in state_dict: + full_weight = state_dict[name] + + if mutate_func is not None: + full_weight = mutate_func(full_weight) + tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) + if tensor is not None: + tensor.data.copy_(tensor_chunk[tp_rank]) + else: + print(f"tp_shard tensor:[{name}] not in state_dict, skip loading") + + def _fetch_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: + """fetch tensor in tp shards""" + nonlocal state_dict + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + if name in state_dict: + full_weight = state_dict[name] + + if mutate_func is not None: + full_weight = mutate_func(full_weight) + tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) + if tensor is not None: + tensor.data.copy_(tensor_chunk[tp_rank]) + else: + print(f"tp_shard tensor:[{name}] not in state_dict, skip loading") + + def _fetch_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor: + """fetch gate_up tensor in tp shards""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + if gate_name in state_dict and up_name in state_dict: + gate_weight = state_dict[gate_name] + up_weight = state_dict[up_name] + new_gate_up_weight = torch.empty(config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device()) + for i in range(tp_size): + intermediate_size_tp = config.intermediate_size // tp_size + gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] + up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] + new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_(torch.cat([gate_weight_tp, up_weight_tp], dim=0)) + + tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0) + if tensor is not None: + tensor.data.copy_(tensor_chunk[tp_rank]) + else: + print(f"tp_shard tensor:[{gate_name}, {up_name}] not in state_dict, skip loading") + + def _fetch_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name) -> torch.Tensor: + """fetch tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + assert q_name in state_dict and k_name in state_dict and v_name in state_dict + full_weight_q = state_dict[q_name] + full_weight_k = state_dict[k_name] + full_weight_v = state_dict[v_name] + + hidden_size_per_head = config.hidden_size // config.num_attention_heads + + if config.num_key_value_heads >= tp_size: + q_size_tp = config.hidden_size // tp_size + kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size + total_size = q_size_tp + 2 * kv_size_tp + new_weight_qkv = torch.empty(total_size * tp_size, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device()) + for i in range(tp_size): + q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] + k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp] + v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp] + new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0)) + + else: + q_size_tp = config.hidden_size // tp_size + kv_size_tp = hidden_size_per_head + total_size = q_size_tp + 2 * kv_size_tp + new_weight_qkv = torch.empty(total_size * tp_size, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device()) + for i in range(tp_size): + q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] + start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head + end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head + k_part = full_weight_k[start_idx:end_idx] + v_part = full_weight_v[start_idx:end_idx] + new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0)) + + tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0) + if tensor is not None: + tensor.data.copy_(tensor_chunk[tp_rank]) + + # Embeddings + # ------------------- + print_rank_0("loading embeddings...") + gpt_model_module = _get_gpt_model(models[0]) + embed_tokens_weight = None + if pp_rank == 0: + embed_tokens_weight = gpt_model_module.model.embed_tokens.weight + _fetch_tp_shard_tensor_vocab(embed_tokens_weight, "model.embed_tokens.weight") + + # Transformer layers + # ------------------- + layer_map = _megatron_calc_layer_map(config) + + pp_rank = mpu.get_pipeline_model_parallel_rank() + pp_size = mpu.get_pipeline_model_parallel_world_size() + num_layer_per_pp = config.num_hidden_layers // pp_size + vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size() + + layer_list = [] + if vpp_size is not None: + for vpp_rank in range(vpp_size): + num_layer_vpp_chunk = num_layer_per_pp // vpp_size + num_layer_this_model = num_layer_vpp_chunk + offset = vpp_rank * (config.num_hidden_layers // mpu.get_virtual_pipeline_model_parallel_world_size()) + (mpu.get_pipeline_model_parallel_rank() * num_layer_vpp_chunk) + layer_list.extend(list(range(offset, offset + num_layer_this_model))) + else: + num_layer_this_model = num_layer_per_pp + offset = pp_rank * num_layer_per_pp + layer_list.extend(list(range(offset, offset + num_layer_this_model))) + + for layer in layer_list: + print_rank_0(f"loading layer #{layer}...") + layer_name = f"model.layers.{layer}" + dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer] + + gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank]) + sync_layer = gpt_model_module.model.layers[dst_layer_idx] + + _fetch_tensor( + sync_layer.input_layernorm.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.input_layernorm.weight", + ) + + _fetch_tp_shard_tensor_qkv( + sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.q_proj.weight", + f"{layer_name}.self_attn.k_proj.weight", + f"{layer_name}.self_attn.v_proj.weight", + ) + + _fetch_tp_shard_tensor( + sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.o_proj.weight", + chunk_dim=1, + ) + + _fetch_tensor( + sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.post_attention_layernorm.weight", + ) + + _fetch_tp_shard_tensor_gate_up( + sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.mlp.gate_proj.weight", + f"{layer_name}.mlp.up_proj.weight", + ) + + _fetch_tp_shard_tensor( + sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.mlp.down_proj.weight", + chunk_dim=1, + ) + # Final Layernorm + # ------------------- + print_rank_0("loading final layernorm...") + gpt_model_module = _get_gpt_model(models[-1]) + _fetch_tensor( + getattr(gpt_model_module.model.norm, "weight", None), + "model.norm.weight", + ) + + print_rank_0("loading lm_head...") + if pp_rank + 1 == pp_size: + lm_head_weight = gpt_model_module.lm_head.weight + + if is_value_model: + if "lm_head.weight" in state_dict and state_dict["lm_head.weight"].shape[0] == 1: + _fetch_tensor(lm_head_weight, "lm_head.weight") + print_rank_0("load lm_head weight") + elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1: + _fetch_tensor(lm_head_weight, "reward_head.weight") + print_rank_0("load lm_head from value_head weight") + else: + _fetch_tensor(None, "lm_head.weight") + print_rank_0("fail to match lm_head in value_model") + else: + _fetch_tp_shard_tensor(lm_head_weight, "lm_head.weight") + + dist.barrier() + torch.cuda.empty_cache() + print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s") diff --git a/verl/models/llama/megatron/checkpoint_utils/llama_loader_depracated.py b/verl/models/llama/megatron/checkpoint_utils/llama_loader_depracated.py new file mode 100644 index 0000000000000000000000000000000000000000..39c8e25d9a6202a45d01c60c7db3bcaf9acf447b --- /dev/null +++ b/verl/models/llama/megatron/checkpoint_utils/llama_loader_depracated.py @@ -0,0 +1,425 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import torch +import torch.distributed as dist + + +def _megatron_calc_layer_map(config): + """Calculate the mapping of global layer_idx to local layer_idx + Returns: + layer_map (Dict: int -> tuple(int, int, int)): + mapping from the global layer index to + a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model) + """ + from megatron.core import mpu + + print(f"get megatron data parallel size: {mpu.get_data_parallel_world_size()}") + + pp_size = mpu.get_pipeline_model_parallel_world_size() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + + layer_map = dict() + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers + + for pp_rank_idx in range(pp_size): + for virtual_pp_rank_idx in range(virtual_pp_size): + layer_offset = virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + for layer_idx in range(num_layers_per_model): + layer_map[layer_offset + layer_idx] = ( + pp_rank_idx, + virtual_pp_rank_idx, + layer_idx, + ) + return layer_map + + +def load_state_dict_to_megatron_llama(state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False): + """Load merged state_dict to sharded Megatron module in training.""" + from megatron.core import DistributedDataParallel as LocalDDP + from megatron.core import mpu + from megatron.core.transformer.module import Float16Module + from torch.nn.parallel import DistributedDataParallel as torchDDP + + from verl.utils.megatron_utils import print_rank_0, unwrap_model + + start_time = time.time() + + def _get_gpt_model(model): + return model + + def broadcast_params(module): + for param in module.parameters(): + torch.distributed.broadcast(param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group()) + + dp_rank = mpu.get_data_parallel_rank() + pp_rank = mpu.get_pipeline_model_parallel_rank() + pp_size = mpu.get_pipeline_model_parallel_world_size() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + mp_group = mpu.get_model_parallel_group() + + if torch.distributed.get_rank() == 0: + assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0" + assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" + assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" + + if not isinstance(wrapped_models, (list, tuple)): + wrapped_models = list(wrapped_models) + + assert len(wrapped_models) == virtual_pp_size + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, f"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size {virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}" + + models = [None] * len(wrapped_models) + + for i, wrapped_model in enumerate(wrapped_models): + models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) + gpt_model_module = _get_gpt_model(models[i]) + assert len(gpt_model_module.model.layers) == num_layers_per_model + + def _broadcast_tensor(tensor, name) -> torch.Tensor: + """broadcast tensor from rank0 across mp_group""" + nonlocal state_dict + nonlocal mp_group + if torch.distributed.get_rank() == 0: + if name in state_dict: + weight = state_dict[name] + tensor_shape = weight.shape + else: + tensor_shape = None + else: + weight = None + tensor_shape = None + + obj_list = [tensor_shape] + dist.broadcast_object_list(obj_list, src=0, group=mp_group) + tensor_shape = obj_list[0] + + if tensor_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tensor:[{name}] not in state_dict, skip load") + return + + if tensor is None: + tensor = torch.empty( + tensor_shape, + dtype=params_dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + if torch.distributed.get_rank() == 0: + tensor.data.copy_(weight) + dist.broadcast(tensor, src=0, group=mp_group) + + def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + + if torch.distributed.get_rank() == 0: + if name in state_dict: + full_weight = state_dict[name] + + if mutate_func is not None: + full_weight = mutate_func(full_weight) + tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) + chunk_shape = tensor_chunk[0].shape + else: + chunk_shape = None + else: + chunk_shape = None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=0, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading") + return + + if tensor is None: + sync_tensor = torch.empty( + chunk_shape, + dtype=params_dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + else: + assert tensor.shape == chunk_shape, f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" + sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False) + + for i in range(tp_size): + if torch.distributed.get_rank() == 0: + sync_tensor.data.copy_(tensor_chunk[i]) + dist.broadcast(sync_tensor, src=0, group=mp_group) + if (i == tp_rank) and (tensor is not None): + tensor.data.copy_(sync_tensor) + + def _broadcast_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + + if torch.distributed.get_rank() == 0: + if name in state_dict: + full_weight = state_dict[name] + if mutate_func is not None: + full_weight = mutate_func(full_weight) + tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) + chunk_shape = tensor_chunk[0].shape + else: + chunk_shape = None + else: + chunk_shape = None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=0, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading") + return + + if tensor is None: + sync_tensor = torch.empty( + chunk_shape, + dtype=params_dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + else: + assert tensor.shape == chunk_shape, f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" + sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False) + + for i in range(tp_size): + if torch.distributed.get_rank() == 0: + sync_tensor.data.copy_(tensor_chunk[i]) + dist.broadcast(sync_tensor, src=0, group=mp_group) + if (i == tp_rank) and (tensor is not None): + tensor.data.copy_(sync_tensor) + + def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + + if torch.distributed.get_rank() == 0: + gate_weight = state_dict[gate_name] + up_weight = state_dict[up_name] + new_gate_up_weight = torch.empty(config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device()) + for i in range(tp_size): + intermediate_size_tp = config.intermediate_size // tp_size + gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] + up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] + new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_(torch.cat([gate_weight_tp, up_weight_tp], dim=0)) + + tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0) + chunk_shape = tensor_chunk[0].shape + else: + chunk_shape = None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=0, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not in state_dict, skip loading") + return + + if tensor is None: + sync_tensor = torch.empty( + chunk_shape, + dtype=params_dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + else: + assert tensor.shape == chunk_shape, f"rank #{torch.distributed.get_rank() == 0:} tensor {gate_name, up_name} shape {tensor.shape} != {chunk_shape}" + sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False) + + for i in range(tp_size): + if torch.distributed.get_rank() == 0: + sync_tensor.data.copy_(tensor_chunk[i]) + dist.broadcast(sync_tensor, src=0, group=mp_group) + if (i == tp_rank) and (tensor is not None): + tensor.data.copy_(sync_tensor) + + def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + + if torch.distributed.get_rank() == 0: + assert q_name in state_dict and k_name in state_dict and v_name in state_dict + full_weight_q = state_dict[q_name] + full_weight_k = state_dict[k_name] + full_weight_v = state_dict[v_name] + + hidden_size_per_head = config.hidden_size // config.num_attention_heads + + if config.num_key_value_heads >= tp_size: + q_size_tp = config.hidden_size // tp_size + kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size + total_size = q_size_tp + 2 * kv_size_tp + new_weight_qkv = torch.empty(total_size * tp_size, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device()) + for i in range(tp_size): + q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] + k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp] + v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp] + new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0)) + + else: + q_size_tp = config.hidden_size // tp_size + kv_size_tp = hidden_size_per_head + total_size = q_size_tp + 2 * kv_size_tp + new_weight_qkv = torch.empty(total_size * tp_size, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device()) + for i in range(tp_size): + q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] + start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head + end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head + k_part = full_weight_k[start_idx:end_idx] + v_part = full_weight_v[start_idx:end_idx] + new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0)) + + tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0) + chunk_shape = tensor_chunk[0].shape + else: + chunk_shape = None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=0, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{q_name, k_name, v_name}] not in state_dict, skip loading") + return + + if tensor is None: + sync_tensor = torch.empty( + chunk_shape, + dtype=params_dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + else: + assert tensor.shape == chunk_shape, f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}" + sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False) + + for i in range(tp_size): + if torch.distributed.get_rank() == 0: + sync_tensor.data.copy_(tensor_chunk[i]) + dist.broadcast(sync_tensor, src=0, group=mp_group) + if (i == tp_rank) and (tensor is not None): + tensor.data.copy_(sync_tensor) + + if dp_rank == 0: + # Embeddings + # ------------------- + print_rank_0("loading embeddings...") + gpt_model_module = _get_gpt_model(models[0]) + embed_tokens_weight = None + if pp_rank == 0: + embed_tokens_weight = gpt_model_module.model.embed_tokens.weight + _broadcast_tp_shard_tensor_vocab(embed_tokens_weight, "model.embed_tokens.weight") + + # Transformer layers + # ------------------- + layer_map = _megatron_calc_layer_map(config) + + for layer in range(config.num_hidden_layers): + print_rank_0(f"loading layer #{layer}...") + layer_name = f"model.layers.{layer}" + dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer] + + gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank]) + sync_layer = gpt_model_module.model.layers[dst_layer_idx] + + _broadcast_tensor( + sync_layer.input_layernorm.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.input_layernorm.weight", + ) + + _broadcast_tp_shard_tensor_qkv( + sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.q_proj.weight", + f"{layer_name}.self_attn.k_proj.weight", + f"{layer_name}.self_attn.v_proj.weight", + ) + + _broadcast_tp_shard_tensor( + sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.o_proj.weight", + chunk_dim=1, + ) + + _broadcast_tensor( + sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.post_attention_layernorm.weight", + ) + + _broadcast_tp_shard_tensor_gate_up( + sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.mlp.gate_proj.weight", + f"{layer_name}.mlp.up_proj.weight", + ) + + _broadcast_tp_shard_tensor( + sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.mlp.down_proj.weight", + chunk_dim=1, + ) + # Final Layernorm + # ------------------- + print_rank_0("loading final layernorm...") + gpt_model_module = _get_gpt_model(models[-1]) + _broadcast_tensor( + getattr(gpt_model_module.model.norm, "weight", None), + "model.norm.weight", + ) + + print_rank_0("loading lm_head...") + lm_head_weight = None + if pp_rank + 1 == pp_size: + lm_head_weight = gpt_model_module.lm_head.weight + + if is_value_model: + if "lm_head.weight" in state_dict and state_dict["lm_head.weight"].shape[0] == 1: + _broadcast_tensor(lm_head_weight, "lm_head.weight") + print_rank_0("load lm_head weight") + elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1: + _broadcast_tensor(lm_head_weight, "reward_head.weight") + print_rank_0("load lm_head from value_head weight") + else: + _broadcast_tensor(None, "lm_head.weight") + print_rank_0("fail to match lm_head in value_model") + else: + _broadcast_tp_shard_tensor(lm_head_weight, "lm_head.weight") + dist.barrier() + # Broadcast weights inside data parallel groups + for wrapped_model in wrapped_models: + broadcast_params(wrapped_model) + + torch.cuda.empty_cache() + print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s") diff --git a/verl/models/llama/megatron/checkpoint_utils/llama_saver.py b/verl/models/llama/megatron/checkpoint_utils/llama_saver.py new file mode 100644 index 0000000000000000000000000000000000000000..5d55bbf674bad5c9f81c2f0f2fe5352932191d13 --- /dev/null +++ b/verl/models/llama/megatron/checkpoint_utils/llama_saver.py @@ -0,0 +1,430 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import torch +import torch.distributed as dist +from megatron.core import mpu +from megatron.core.distributed import DistributedDataParallel as LocalDDP +from megatron.core.transformer.module import Float16Module +from torch.nn.parallel import DistributedDataParallel as torchDDP + +from verl.utils.megatron_utils import print_rank_0, unwrap_model + + +def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0): + """given TP,DP,PP rank to get the global rank.""" + + tp_size = mpu.get_tensor_model_parallel_world_size() + dp_size = mpu.get_data_parallel_world_size() + pp_size = mpu.get_pipeline_model_parallel_world_size() + assert tp_size * dp_size * pp_size == torch.distributed.get_world_size(), f"{tp_size} x {dp_size} x {pp_size} != {torch.distributed.get_world_size()}" + # We only support TP-DP-PP grouping, for correctness when resharding + return (pp_rank * dp_size + dp_rank) * tp_size + tp_rank + + +def _megatron_calc_layer_map(config): + """Calculate the mapping of global layer_idx to local layer_idx + Returns: + layer_map (Dict: int -> tuple(int, int, int)): + mapping from the global layer index to + a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model) + """ + from megatron.core import mpu + + pp_size = mpu.get_pipeline_model_parallel_world_size() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + + layer_map = dict() + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers + + for pp_rank_idx in range(pp_size): + for virtual_pp_rank_idx in range(virtual_pp_size): + layer_offset = virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + for layer_idx in range(num_layers_per_model): + layer_map[layer_offset + layer_idx] = ( + pp_rank_idx, + virtual_pp_rank_idx, + layer_idx, + ) + return layer_map + + +def merge_megatron_ckpt_llama(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False): + """Merge sharded parameters of a Megatron module into a merged checkpoint. + + Args: + wrapped_models (list of megatron.core.distributed.DistributedDataParallel): + The local DDP wrapped megatron modules. + config (str or None): + HF config for model + dtype: model params type + is_value_model: if model is value model + tie_word_embeddings: tie_word_embeddings, not used in llama, only to keep same interface with qwen2 + Returns: + state_dict (dict): + The merged state_dict in rank 0, and an empty dictionary in other ranks. + """ + start_time = time.time() + + def _get_gpt_model(model): + return model + + dp_rank = mpu.get_data_parallel_rank() + pp_size = mpu.get_pipeline_model_parallel_world_size() + pp_rank = mpu.get_pipeline_model_parallel_rank() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + mp_group = mpu.get_model_parallel_group() + + if dist.get_rank() == 0: + assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0" + assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" + assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" + + if not isinstance(wrapped_models, (list, tuple)): + wrapped_models = list(wrapped_models) + + assert len(wrapped_models) == virtual_pp_size + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers + + models = [None] * len(wrapped_models) + + for i, wrapped_model in enumerate(wrapped_models): + models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) + assert len(models[i].model.layers) == num_layers_per_model, "len model layers {} not equal to num_layers_per_model {}".format(len(models[i].model.layers), num_layers_per_model) + + state_dict = dict() + + def _get_cpu_tensor(tensor: torch.Tensor): + if tensor is None: + return None + if tensor.device == torch.device("cpu"): + return tensor.detach().clone() + return tensor.detach().cpu() + + def _broadcast_tensor(tensor, name, src_pp_rank) -> torch.Tensor: + """broadcast tensor across mp_group""" + nonlocal state_dict + nonlocal mp_group + src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank) + + if torch.distributed.get_rank() == src_rank: + if tensor is None: + weight = None + tensor_shape = None + else: + weight = tensor + tensor_shape = weight.shape + else: + weight = None + tensor_shape = None + + obj_list = [tensor_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + tensor_shape = obj_list[0] + + if tensor_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tensor:[{name}] not exist, skip collect") + return + + if weight is None: + weight = torch.empty( + tensor_shape, + dtype=dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + + dist.broadcast(weight, src=src_rank, group=mp_group) + + if torch.distributed.get_rank() == 0: + state_dict[name] = _get_cpu_tensor(weight) + + def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_func=None) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_size = mpu.get_tensor_model_parallel_world_size() + src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank) + + chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{name}] not exist, skip collecting") + return + + buffer_tensor = torch.empty( + chunk_shape, + dtype=dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + + chunk_tensors = [None] * tp_size + + for i in range(tp_size): + cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank) + sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor + dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) + + if torch.distributed.get_rank() == 0: + chunk_tensors[i] = _get_cpu_tensor(sync_tensor) + + if torch.distributed.get_rank() == 0: + full_tensor = torch.concat(chunk_tensors, dim=concat_dim) + if mutate_func is not None: + full_tensor = mutate_func(full_tensor) + state_dict[name] = full_tensor + + def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_size = mpu.get_tensor_model_parallel_world_size() + src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank) + + chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not exist, skip collecting") + return + + buffer_tensor = torch.empty( + chunk_shape, + dtype=dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + + chunk_tensors = [None] * tp_size + + for i in range(tp_size): + cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank) + sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor + dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) + + if torch.distributed.get_rank() == 0: + chunk_tensors[i] = _get_cpu_tensor(sync_tensor) + + if torch.distributed.get_rank() == 0: + full_tensor = torch.concat(chunk_tensors, dim=0) + intermediate_size_tp = config.intermediate_size // tp_size + gate_weight_list = [] + up_weight_list = [] + for i in range(tp_size): + gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)] + gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp] + up_weight_tp = gate_up_weight_tp[intermediate_size_tp:] + gate_weight_list.append(gate_weight_tp) + up_weight_list.append(up_weight_tp) + + state_dict[gate_name] = torch.cat(gate_weight_list, dim=0) + state_dict[up_name] = torch.cat(up_weight_list, dim=0) + + def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank): + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_size = mpu.get_tensor_model_parallel_world_size() + src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank) + + chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{q_name}] not exist, skip collecting") + return + + buffer_tensor = torch.empty( + chunk_shape, + dtype=dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + + chunk_tensors = [None] * tp_size + + for i in range(tp_size): + cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank) + sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor + dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) + + if torch.distributed.get_rank() == 0: + chunk_tensors[i] = _get_cpu_tensor(sync_tensor) + + if torch.distributed.get_rank() == 0: + full_tensor = torch.concat(chunk_tensors, dim=0) + q_weight_list = [] + k_weight_list = [] + v_weight_list = [] + hidden_size_per_head = config.hidden_size // config.num_attention_heads + + if config.num_key_value_heads >= tp_size: + q_size_tp = config.hidden_size // tp_size + kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size + total_size = q_size_tp + 2 * kv_size_tp + for i in range(tp_size): + qkv_part = full_tensor[i * total_size : (i + 1) * total_size] + q_part = qkv_part[:q_size_tp] + k_part = qkv_part[q_size_tp : q_size_tp + kv_size_tp] + v_part = qkv_part[q_size_tp + kv_size_tp : total_size] + q_weight_list.append(q_part) + k_weight_list.append(k_part) + v_weight_list.append(v_part) + else: + q_size_tp = config.hidden_size // tp_size + kv_size_tp = hidden_size_per_head + total_size = q_size_tp + 2 * kv_size_tp + for i in range(tp_size): + qkv_part = full_tensor[i * total_size : (i + 1) * total_size] + q_part = qkv_part[:q_size_tp] + k_part = qkv_part[q_size_tp : q_size_tp + kv_size_tp] + v_part = qkv_part[q_size_tp + kv_size_tp : total_size] + q_weight_list.append(q_part) + if i * config.num_key_value_heads % tp_size == 0: + k_weight_list.append(k_part) + v_weight_list.append(v_part) + + state_dict[q_name] = torch.cat(q_weight_list, dim=0) + state_dict[k_name] = torch.cat(k_weight_list, dim=0) + state_dict[v_name] = torch.cat(v_weight_list, dim=0) + + # empty cache before collecting weights + torch.cuda.empty_cache() + # Embeddings + # ------------------- + if dp_rank == 0: + # Embeddings + # ------------------- + print_rank_0("collecting embeddings...") + gpt_model_module = _get_gpt_model(models[0]) + _broadcast_tp_shard_tensor( + gpt_model_module.model.embed_tokens.weight if pp_rank == 0 else None, + "model.embed_tokens.weight", + src_pp_rank=0, + ) + + # Transformer layers + # ------------------- + layer_map = _megatron_calc_layer_map(config) + for layer in range(config.num_hidden_layers): + print_rank_0(f"collecting layer #{layer}...") + layer_name = f"model.layers.{layer}" + src_pp_rank, src_virtual_pp_rank, src_layer_idx = layer_map[layer] + + gpt_model_module = _get_gpt_model(models[src_virtual_pp_rank]) + sync_layer = gpt_model_module.model.layers[src_layer_idx] + + _broadcast_tensor( + sync_layer.input_layernorm.weight, + f"{layer_name}.input_layernorm.weight", + src_pp_rank=src_pp_rank, + ) + + _broadcast_tp_shard_tensor_qkv( + sync_layer.self_attn.qkv_proj.weight, + f"{layer_name}.self_attn.q_proj.weight", + f"{layer_name}.self_attn.k_proj.weight", + f"{layer_name}.self_attn.v_proj.weight", + src_pp_rank=src_pp_rank, + ) + + _broadcast_tp_shard_tensor( + sync_layer.self_attn.o_proj.weight, + f"{layer_name}.self_attn.o_proj.weight", + concat_dim=1, + src_pp_rank=src_pp_rank, + ) + + _broadcast_tensor( + sync_layer.post_attention_layernorm.weight, + f"{layer_name}.post_attention_layernorm.weight", + src_pp_rank=src_pp_rank, + ) + + _broadcast_tp_shard_tensor_gate_up( + sync_layer.mlp.gate_up_proj.weight, + f"{layer_name}.mlp.gate_proj.weight", + f"{layer_name}.mlp.up_proj.weight", + src_pp_rank=src_pp_rank, + ) + + _broadcast_tp_shard_tensor( + sync_layer.mlp.down_proj.weight, + f"{layer_name}.mlp.down_proj.weight", + concat_dim=1, + src_pp_rank=src_pp_rank, + ) + + # Final Layernorm + # ------------------- + print_rank_0("collecting final layernorm...") + gpt_model_module = _get_gpt_model(models[-1]) + _broadcast_tensor( + getattr(gpt_model_module.model.norm, "weight", None), + "model.norm.weight", + src_pp_rank=pp_size - 1, + ) + + print_rank_0("collecting lm_head...") + + if is_value_model: + if pp_rank == pp_size - 1: + print(f"gpt_model_module.lm_head.weight: {gpt_model_module.lm_head.weight.shape}") + _broadcast_tensor( + gpt_model_module.lm_head.weight if pp_rank == pp_size - 1 else None, + "lm_head.weight", + src_pp_rank=pp_size - 1, + ) + _broadcast_tensor( + gpt_model_module.reward_head.weight if pp_rank == pp_size - 1 and getattr(gpt_model_module, "reward_weight", None) is not None else None, + "reward_head.weight", + src_pp_rank=pp_size - 1, + ) + + else: + _broadcast_tp_shard_tensor( + getattr(gpt_model_module.lm_head, "weight", None) if pp_rank == pp_size - 1 else None, + "lm_head.weight", + src_pp_rank=pp_size - 1, + ) + + dist.barrier() + + torch.cuda.empty_cache() + if torch.distributed.get_rank() == 0: + if dtype not in [torch.float16, torch.bfloat16, torch.float32]: + print(f'Unknown/unsupported dtype to save: {dtype}"') + exit(1) + for k, v in state_dict.items(): + if dtype != v.dtype: + state_dict[k] = v.to(dtype) + + print_rank_0(f"merge megatron ckpt done, time elapsed {time.time() - start_time}s") + return state_dict diff --git a/verl/models/llama/megatron/layers/__init__.py b/verl/models/llama/megatron/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..89c5afc0d5eae1d4562705004e758e48e6ca03f1 --- /dev/null +++ b/verl/models/llama/megatron/layers/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .parallel_attention import ParallelLlamaAttention +from .parallel_decoder import ParallelLlamaDecoderLayer, ParallelLlamaDecoderLayerRmPad +from .parallel_linear import ( + LinearForLastLayer, + MergedColumnParallelLinear, + QKVParallelLinear, +) +from .parallel_mlp import ParallelLlamaMLP +from .parallel_rmsnorm import ParallelLlamaRMSNorm + +__all__ = ["LinearForLastLayer", "MergedColumnParallelLinear", "QKVParallelLinear", "ParallelLlamaAttention", "ParallelLlamaDecoderLayer", "ParallelLlamaDecoderLayerRmPad", "ParallelLlamaMLP", "ParallelLlamaRMSNorm"] diff --git a/verl/models/llama/megatron/layers/parallel_attention.py b/verl/models/llama/megatron/layers/parallel_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..6c663a282670338a17b379a025dfa109711db2ea --- /dev/null +++ b/verl/models/llama/megatron/layers/parallel_attention.py @@ -0,0 +1,425 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F +from einops import rearrange +from flash_attn.layers.rotary import apply_rotary_emb +from megatron.core import ModelParallelConfig, tensor_parallel +from megatron.core import parallel_state as mpu +from torch import nn +from transformers import LlamaConfig +from transformers.utils import is_flash_attn_2_available + +from verl.models.llama.megatron.layers.parallel_linear import QKVParallelLinear +from verl.utils.megatron import tensor_parallel as tp_utils + + +class LlamaRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache(seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = t / self.scaling_factor + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ((self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +class LlamaLlama3ScalingRotaryEmbedding(LlamaRotaryEmbedding): + def __init__(self, dim, config, max_position_embeddings=2048, base=10000, device=None): + super().__init__(dim, max_position_embeddings, base, device) + + self.factor = config.rope_scaling["factor"] # `8` in the original implementation + self.high_freq_factor = config.rope_scaling["high_freq_factor"] # `1` in the original implementation + self.low_freq_factor = config.rope_scaling["low_freq_factor"] # `4` in the original implementation + self.old_context_len = config.rope_scaling["original_max_position_embeddings"] # `8192` in the original implementation + + low_freq_wavelen = self.old_context_len / self.low_freq_factor + high_freq_wavelen = self.old_context_len / self.high_freq_factor + + wavelen = 2 * math.pi / self.inv_freq + # wavelen < high_freq_wavelen: do nothing; wavelen > low_freq_wavelen: divide by factor + inv_freq_llama = torch.where(wavelen > low_freq_wavelen, self.inv_freq / self.factor, self.inv_freq) + # otherwise: interpolate between the two, using a smooth factor + smooth_factor = (self.old_context_len / wavelen - self.low_freq_factor) / (self.high_freq_factor - self.low_freq_factor) + smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / self.factor + smooth_factor * inv_freq_llama + is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen) + inv_freq = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama) + + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache(seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class ParallelLlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig): + super().__init__() + self.config = config + self.megatron_config = megatron_config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + + # assign values after tp + tp_size = mpu.get_tensor_model_parallel_world_size() + assert self.num_heads % tp_size == 0, f"num_head must be divisible by tp_size. Got num_head={self.num_heads}, tp_size={tp_size}" + assert self.num_key_value_heads % tp_size == 0, f"num_key_value_heads must be divisible by tp_size. Got num_key_value_heads={self.num_key_value_heads}, tp_size={tp_size}" + + self.num_heads_per_tp = self.num_heads // tp_size + self.num_key_value_heads_per_tp = self.num_key_value_heads // tp_size + self.hidden_size_per_tp = self.hidden_size // tp_size + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`: {self.num_heads}).") + + column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() + row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear() + + if megatron_config is not None: + assert column_kwargs.get("config", False), "must have ModelParallelConfig" + assert row_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(column_kwargs, megatron_config) + tp_utils.update_kwargs_with_config(row_kwargs, megatron_config) + + # [self.q_size, self.k_size, self.v_size] + self.qkv_proj = QKVParallelLinear( + input_size=self.hidden_size, + num_heads=self.num_heads, + num_key_value_heads=self.num_key_value_heads, + head_dim=self.head_dim, + bias=config.attention_bias, + gather_output=False, + skip_bias_add=False, + **column_kwargs, + ) + + self.q_size = self.num_heads_per_tp * self.head_dim + self.k_size = self.num_key_value_heads_per_tp * self.head_dim + self.v_size = self.num_key_value_heads_per_tp * self.head_dim + + self.o_proj = tensor_parallel.RowParallelLinear( + input_size=self.num_heads * self.head_dim, + output_size=self.hidden_size, + bias=config.attention_bias, + input_is_parallel=True, + skip_bias_add=False, + **row_kwargs, + ) + + self._init_rope() + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + rope_type_key = "type" if "type" in self.config.rope_scaling else "rope_type" + scaling_type = self.config.rope_scaling[rope_type_key] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = LlamaLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "llama3": + self.rotary_emb = LlamaLlama3ScalingRotaryEmbedding( + self.head_dim, + self.config, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + qkv = self.qkv_proj(hidden_states)[0] + query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1) + + query_states = query_states.view(bsz, q_len, self.num_heads_per_tp, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads_per_tp, q_len, kv_seq_len): + raise ValueError(f"Attention weights should be of size {(bsz, self.num_heads_per_tp, q_len, kv_seq_len)}, but is {attn_weights.size()}") + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError(f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}") + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads_per_tp, q_len, self.head_dim): + raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads_per_tp, q_len, self.head_dim)}, but is {attn_output.size()}") + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size_per_tp) + attn_output = self.o_proj(attn_output)[0] + return attn_output + + +""" +Remove padding Attention +- Using Flash-attn 2 +- Compatible with sequence parallel +""" + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + +def apply_rotary_pos_emb_rmpad(q, k, cos, sin, position_ids, indices, sequence_length): + batch_size = position_ids.shape[0] + + q = pad_input(q, indices, batch_size, sequence_length) # (batch_size, seqlen, num_head, head_dim) + k = pad_input(k, indices, batch_size, sequence_length) + cos = cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] + sin = sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + + q_embed = index_first_axis(rearrange(q_embed, "b s ... -> (b s) ..."), indices) + k_embed = index_first_axis(rearrange(k_embed, "b s ... -> (b s) ..."), indices) + + return q_embed, k_embed + + +# use flash-attn rotary embeddings with rmpad +# cos/sin shoudl be: (seq_length, rotary_dim / 2) +def apply_rotary_pos_emb_rmpad_flash(q, k, cos, sin, cu_seqlens, max_seqlen): + q_embed = apply_rotary_emb(q, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + k_embed = apply_rotary_emb(k, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + return q_embed, k_embed + + +class ParallelLlamaAttentionRmPad(ParallelLlamaAttention): + def forward( + self, + hidden_states: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + sequence_length: int = None, + indices: torch.Tensor = None, + cu_seqlens: torch.Tensor = None, + max_seqlen_in_batch: int = None, + ): + total_nnz, _, _ = hidden_states.size() # This is the total_nnz padded after sequence parallel + + if self.megatron_config.sequence_parallel: + total_nnz = total_nnz * mpu.get_tensor_model_parallel_world_size() + + qkv = self.qkv_proj(hidden_states)[0] + query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1) # (total_nnz, 1, hidden_size) + + if self.megatron_config.sequence_parallel: + sequence_parallel_pad = total_nnz - cu_seqlens[-1] + total_nnz = cu_seqlens[-1] # total_nnz before sp padding + query_states = query_states[:total_nnz] + key_states = key_states[:total_nnz] + value_states = value_states[:total_nnz] + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dime x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(total_nnz, self.num_heads_per_tp, self.head_dim) + key_states = key_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim) + value_states = value_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim) + + cos, sin = self.rotary_emb(value_states, seq_len=sequence_length) + cos, sin = cos[:, : cos.shape[1] // 2], sin[:, : sin.shape[1] // 2] # flash attn only needs half + query_states, key_states = apply_rotary_pos_emb_rmpad_flash(query_states, key_states, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen_in_batch) + # query_states, key_states = apply_rotary_pos_emb_rmpad(query_states, key_states, cos, sin, position_ids, indices, + + # TODO: llama does not have dropout in the config?? + # It is recommended to use dropout with FA according to the docs + # when training. + dropout_rate = 0.0 # if not self.training else self.attn_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + input_dtype = query_states.dtype + if input_dtype == torch.float32: + query_states = query_states.to(torch.float16) + key_states = key_states.to(torch.float16) + value_states = value_states.to(torch.float16) + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen_in_batch, + max_seqlen_k=max_seqlen_in_batch, + dropout_p=dropout_rate, + softmax_scale=None, + causal=True, + ) + + attn_output_unpad = attn_output_unpad.to(input_dtype) + attn_output_unpad = attn_output_unpad.reshape(total_nnz, 1, self.hidden_size_per_tp).contiguous() + + # sequence parallel reduce_scatter is performed inside RowColumnParallel if enabled + # Here we need to repad + if self.megatron_config.sequence_parallel: + attn_output_unpad = F.pad(attn_output_unpad, pad=(0, 0, 0, 0, 0, sequence_parallel_pad)) + + attn_output_unpad = self.o_proj(attn_output_unpad)[0] + return attn_output_unpad diff --git a/verl/models/llama/megatron/layers/parallel_decoder.py b/verl/models/llama/megatron/layers/parallel_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..ecab612e81f67d21b9d013e7cd8134ef9729e911 --- /dev/null +++ b/verl/models/llama/megatron/layers/parallel_decoder.py @@ -0,0 +1,150 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple + +import torch +from megatron.core import ModelParallelConfig +from torch import nn +from transformers import LlamaConfig + +from verl.utils.megatron_utils import TransformerConfig, convert_config + +from .parallel_attention import ParallelLlamaAttention, ParallelLlamaAttentionRmPad +from .parallel_mlp import ParallelLlamaMLP +from .parallel_rmsnorm import ParallelLlamaRMSNorm + + +class ParallelLlamaDecoderLayer(nn.Module): + def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, layer_idx: int): + super().__init__() + self.config: TransformerConfig = convert_config(config, megatron_config) + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.self_attn = ParallelLlamaAttention(config=config, megatron_config=megatron_config) + + self.mlp = ParallelLlamaMLP(config, megatron_config=megatron_config) + self.input_layernorm = ParallelLlamaRMSNorm(config, megatron_config) + self.post_attention_layernorm = ParallelLlamaRMSNorm(config, megatron_config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Note: sequence parallel is hidden inside ColumnParallelLinear + # reduce scatter is hidden inside RowParallelLinear + + # Self Attention + hidden_states = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + ) + + # TODO: add sequence parallel operator reduce_scatter here + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + + # TODO: add sequence parallel operator all_gather here + + hidden_states = self.mlp(hidden_states) + + # TODO: add sequence parallel operator reduce_scatter here + + hidden_states = residual + hidden_states + + outputs = hidden_states + + return outputs + + +class ParallelLlamaDecoderLayerRmPad(nn.Module): + def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, layer_idx: int): + super().__init__() + self.config: TransformerConfig = convert_config(config, megatron_config) + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.self_attn = ParallelLlamaAttentionRmPad(config=config, megatron_config=megatron_config) + + self.mlp = ParallelLlamaMLP(config, megatron_config=megatron_config) + self.input_layernorm = ParallelLlamaRMSNorm(config, megatron_config) + self.post_attention_layernorm = ParallelLlamaRMSNorm(config, megatron_config) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + sequence_length: int = None, + indices: torch.Tensor = None, + cu_seqlens: int = None, + max_seqlen_in_batch: int = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states # (total_nnz // sp, 1, hidden_size) + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + # (total_nnz // sp, 1, hidden_size) -> all-gather (total_nnz, 1, hidden_size) + # -> col + row -> reduce-scatter -> (total_nnz // sp, 1, hidden_size) + hidden_states = self.self_attn( + hidden_states=hidden_states, + position_ids=position_ids, + sequence_length=sequence_length, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen_in_batch=max_seqlen_in_batch, + ) + + hidden_states = residual + hidden_states + + # Fully Connected + # shape changes same as attn + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = hidden_states + + return outputs diff --git a/verl/models/llama/megatron/layers/parallel_linear.py b/verl/models/llama/megatron/layers/parallel_linear.py new file mode 100644 index 0000000000000000000000000000000000000000..e946227490fffec182094efae588b197d6271748 --- /dev/null +++ b/verl/models/llama/megatron/layers/parallel_linear.py @@ -0,0 +1,106 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/linear.py + +import torch +from megatron.core import tensor_parallel + + +class QKVParallelLinear(tensor_parallel.ColumnParallelLinear): + def __init__( + self, + input_size, + num_heads, + num_key_value_heads, + head_dim, + *, + bias=True, + gather_output=True, + skip_bias_add=False, + **kwargs, + ): + # Keep input parameters, and already restrict the head numbers + self.input_size = input_size + self.q_output_size = num_heads * head_dim + self.kv_output_size = num_key_value_heads * head_dim + self.head_dim = head_dim + self.gather_output = gather_output + self.skip_bias_add = skip_bias_add + + input_size = self.input_size + output_size = (num_heads + 2 * num_key_value_heads) * self.head_dim + + super().__init__( + input_size=input_size, + output_size=output_size, + bias=bias, + gather_output=gather_output, + skip_bias_add=skip_bias_add, + **kwargs, + ) + + +class MergedColumnParallelLinear(tensor_parallel.ColumnParallelLinear): + def __init__( + self, + input_size, + gate_ouput_size, + up_output_size, + *, + bias=True, + gather_output=True, + skip_bias_add=False, + **kwargs, + ): + # Keep input parameters, and already restrict the head numbers + self.input_size = input_size + self.output_size = gate_ouput_size + up_output_size + self.gather_output = gather_output + self.skip_bias_add = skip_bias_add + + super().__init__( + input_size=self.input_size, + output_size=self.output_size, + bias=bias, + gather_output=gather_output, + skip_bias_add=skip_bias_add, + **kwargs, + ) + + +class LinearForLastLayer(torch.nn.Linear): + def __init__( + self, + input_size, + output_size, + *, + config, + bias=True, + ): + super().__init__(in_features=input_size, out_features=output_size, bias=bias) + self.sequence_parallel = config.sequence_parallel + if self.sequence_parallel: + self.weight.sequence_parallel = True + + def forward( + self, + input_, + weight=None, + runtime_gather_output=None, + ): + logits = super().forward(input_) + logits = logits.float() + if self.sequence_parallel: + logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False) + return logits, None diff --git a/verl/models/llama/megatron/layers/parallel_mlp.py b/verl/models/llama/megatron/layers/parallel_mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..d9a886b3c16372e3d8fa369e4296d19d005a47ee --- /dev/null +++ b/verl/models/llama/megatron/layers/parallel_mlp.py @@ -0,0 +1,74 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from megatron.core import ModelParallelConfig, tensor_parallel +from megatron.core import parallel_state as mpu +from torch import nn +from transformers.activations import ACT2FN + +from verl.models.llama.megatron.layers.parallel_linear import MergedColumnParallelLinear +from verl.utils.megatron import tensor_parallel as tp_utils + + +class ParallelLlamaMLP(nn.Module): + def __init__(self, config, megatron_config: ModelParallelConfig = None) -> None: + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + # The weight is only [hidden_size, intermediate_size // model_parallel_world_size] + + column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() + row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear() + + if megatron_config is not None: + assert column_kwargs.get("config", False), "must have ModelParallelConfig" + assert row_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(row_kwargs, megatron_config) + tp_utils.update_kwargs_with_config(column_kwargs, megatron_config) + + tp_size = mpu.get_tensor_model_parallel_world_size() + + self.gate_up_proj = MergedColumnParallelLinear( + input_size=self.hidden_size, + gate_ouput_size=self.intermediate_size, + up_output_size=self.intermediate_size, + bias=False, + gather_output=False, + skip_bias_add=False, + **column_kwargs, + ) + self.gate_size = self.intermediate_size // tp_size + + self.down_proj = tensor_parallel.RowParallelLinear( + input_size=self.intermediate_size, + output_size=self.hidden_size, + bias=False, + input_is_parallel=True, + skip_bias_add=False, + **row_kwargs, + ) + + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + gate_up = self.gate_up_proj(x)[0] + gate, up = gate_up.split(self.gate_size, dim=-1) + return self.down_proj(self.act_fn(gate) * up)[0] diff --git a/verl/models/llama/megatron/layers/parallel_rmsnorm.py b/verl/models/llama/megatron/layers/parallel_rmsnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..2a034e111a257391ca7da32c0c3add161e11bfab --- /dev/null +++ b/verl/models/llama/megatron/layers/parallel_rmsnorm.py @@ -0,0 +1,48 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numbers + +import torch +from apex.normalization.fused_layer_norm import fused_rms_norm_affine +from megatron.core import ModelParallelConfig +from torch import nn +from transformers import LlamaConfig + +from verl.utils.megatron import sequence_parallel as sp_utils + + +class ParallelLlamaRMSNorm(nn.Module): + def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + if isinstance(config.hidden_size, numbers.Integral): + normalized_shape = (config.hidden_size,) + self.normalized_shape = torch.Size(normalized_shape) + self.weight = nn.Parameter(torch.ones(self.normalized_shape)) + self.variance_epsilon = config.rms_norm_eps + + if megatron_config.sequence_parallel: + sp_utils.mark_parameter_as_sequence_parallel(self.weight) + + def forward(self, hidden_states): + return fused_rms_norm_affine( + input=hidden_states, + weight=self.weight, + normalized_shape=self.normalized_shape, + eps=self.variance_epsilon, + memory_efficient=True, + ) diff --git a/verl/models/llama/megatron/modeling_llama_megatron.py b/verl/models/llama/megatron/modeling_llama_megatron.py new file mode 100644 index 0000000000000000000000000000000000000000..44d5016010b5b1e4b8aa26640099aacb74ac6ce8 --- /dev/null +++ b/verl/models/llama/megatron/modeling_llama_megatron.py @@ -0,0 +1,662 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch LLaMA model with Megatron-style acceleration.""" + +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from megatron.core import ModelParallelConfig, mpu, tensor_parallel +from torch import nn +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import CausalLMOutputWithPast + +from verl.utils.megatron import sequence_parallel as sp_utils +from verl.utils.megatron import tensor_parallel as tp_utils +from verl.utils.megatron_utils import TransformerConfig, convert_config + +from .layers import ParallelLlamaDecoderLayer, ParallelLlamaDecoderLayerRmPad, ParallelLlamaRMSNorm + +""" +TODO: +1. Add weight initialization. Here we need to be careful on TP weight init. +2. Add sequence parallel +3. Load checkpoint from meta LLama pretrained checkpoint +""" + + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len) + + +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +class ParallelLlamaModel(nn.Module): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig): + super().__init__() + self.config: TransformerConfig = convert_config(config, megatron_config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding() + if megatron_config is not None: + assert embedding_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config) + self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs) + + self.layers = nn.ModuleList([ParallelLlamaDecoderLayer(config, megatron_config) for _ in range(config.num_hidden_layers)]) + self.norm = ParallelLlamaRMSNorm(config, megatron_config) + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device) + combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + + return combined_attention_mask + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + """ + + Args: + input_ids: input ids. shape (batch_size, seq_length) + attention_mask: attention_mask. shape (batch_size, seq_length) + position_ids: position ids. shape (batch_size, seq_length) + + Returns: + + """ + batch_size, seq_length = input_ids.shape + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + + attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds) + + hidden_states = inputs_embeds + + for idx, decoder_layer in enumerate(self.layers): + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + ) + + hidden_states = layer_outputs + + hidden_states = self.norm(hidden_states) + + return hidden_states + + +class ParallelLlamaForCausalLM(nn.Module): + def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig): + super().__init__() + self.config: TransformerConfig = convert_config(config, megatron_config) + self.model = ParallelLlamaModel(config, megatron_config=megatron_config) + self.vocab_size = config.vocab_size + + column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() + if megatron_config is not None: + assert column_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) + + self.lm_head = tensor_parallel.ColumnParallelLinear( + input_size=config.hidden_size, + output_size=config.vocab_size, + bias=False, + gather_output=False, + skip_bias_add=False, + **column_kwargs, + ) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + ```""" + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + ) + + hidden_states = outputs + logits = self.lm_head(hidden_states)[0] + + logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) + + logits = logits.float() + return CausalLMOutputWithPast( + loss=None, + logits=logits, + past_key_values=None, + hidden_states=None, + attentions=None, + ) + + +from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + +class ParallelLlamaModelRmPad(nn.Module): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig): + super().__init__() + self.config: TransformerConfig = convert_config(config, megatron_config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding() + self.megatron_config = megatron_config + if megatron_config is not None: + assert embedding_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config) + self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs) + + self.layers = nn.ModuleList([ParallelLlamaDecoderLayerRmPad(config, megatron_config) for _ in range(config.num_hidden_layers)]) + self.norm = ParallelLlamaRMSNorm(config, megatron_config) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + sequence_length: int = None, + indices: torch.Tensor = None, + cu_seqlens: int = None, + max_seqlen_in_batch: int = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + """ + + Args: + input_ids: input ids. shape (1, totol_nnz) + position_ids: position ids. shape (batch_size, seq_length) + + Returns: + + """ + inputs_embeds = self.embed_tokens(input_ids) # (1, total_nnz) -> (1, total_nnz, hidden_size) + + # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size) + inputs_embeds = inputs_embeds.transpose(0, 1) + if self.megatron_config.sequence_parallel: + inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds) + + hidden_states = inputs_embeds + for idx, decoder_layer in enumerate(self.layers): + layer_outputs = decoder_layer( + hidden_states, + position_ids=position_ids, + sequence_length=sequence_length, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen_in_batch=max_seqlen_in_batch, + ) + + hidden_states = layer_outputs + + hidden_states = self.norm(hidden_states) + + return hidden_states + + +class ParallelLlamaForCausalLMRmPad(nn.Module): + def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig): + super().__init__() + self.config: TransformerConfig = convert_config(config, megatron_config) + self.megatron_config = megatron_config + self.model = ParallelLlamaModelRmPad(config, megatron_config=megatron_config) + self.vocab_size = config.vocab_size + self._init_head(config) + + def _init_head(self, config): + column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() + if self.megatron_config is not None: + assert column_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) + self.lm_head = tensor_parallel.ColumnParallelLinear( + input_size=config.hidden_size, + output_size=config.vocab_size, + bias=False, + gather_output=False, + skip_bias_add=False, + **column_kwargs, + ) + + def _forward_head(self, hidden_states): + # all_gather from sequence parallel region is performed inside lm_head + logits = self.lm_head(hidden_states)[0] + logits = logits.float() # (total_nnz_padded, 1, vocab_size // tp) + logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) # (total_nnz_padded, 1, vocab_size) + return logits + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + ```""" + batch_size, sequence_length = input_ids.shape + + # remove padding here + input_ids, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(input_ids.unsqueeze(dim=-1), attention_mask) # (total_nnz, 1) + + # pad input_ids to multiple of tp for all tp ranks + # TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap + if self.megatron_config.sequence_parallel: + input_ids = sp_utils.pad_to_sequence_parallel(input_ids) + + input_ids = input_ids.transpose(0, 1) # (1, total_nnz+pad) + + outputs = self.model( + input_ids=input_ids, + position_ids=position_ids, + sequence_length=sequence_length, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen_in_batch=max_seqlen_in_batch, + ) + + hidden_states = outputs + + logits = self._forward_head(hidden_states) + + # remove padding from sequence parallel + if self.megatron_config.sequence_parallel: + totol_nnz = cu_seqlens[-1] + logits = logits[:totol_nnz] # (total_nnz_padded) + + logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension + # add removed padding back + logits = pad_input(logits, indices, batch_size, seqlen=sequence_length) # (batch_size, sequence_length, vocab_size) + + return CausalLMOutputWithPast( + loss=None, + logits=logits, + past_key_values=None, + hidden_states=None, + attentions=None, + ) + + +class ParallelLlamaForValueRmPad(ParallelLlamaForCausalLMRmPad): + def _init_head(self, config): + column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() + if self.megatron_config is not None: + assert column_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) + self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=1, bias=False) + # lm_head is effectively the same as sequence parallel + sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight) + + def _forward_head(self, hidden_states): + logits = self.lm_head(hidden_states) # (total_nnz_padded // tp, 1, 1) + logits = logits.float() + if self.megatron_config.sequence_parallel: + logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False) + return logits + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output = super().forward(input_ids, attention_mask, position_ids) + output.logits = torch.squeeze(output.logits, dim=-1) + return output + + +""" +Support pipeline parallelism +""" + + +class ParallelLlamaModelRmPadPP(nn.Module): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + This model definition supports pipeline parallelism. To support pp and vpp, + - This model only contains layer in this pp stage and vpp chunk + - When calling get_model in Megatron, this rank will instantiate all the vpp chunks in this pp. + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, pre_process, post_process): + super().__init__() + self.config: TransformerConfig = convert_config(config, megatron_config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.pre_process = pre_process + self.post_process = post_process + self.megatron_config = megatron_config + embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding() + if megatron_config is not None: + assert embedding_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config) + if pre_process: + self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs) + else: + self.embed_tokens = None + + pp_rank = mpu.get_pipeline_model_parallel_rank() + pp_size = megatron_config.pipeline_model_parallel_size + self.num_layer_per_pp = config.num_hidden_layers // pp_size + vpp_size = megatron_config.virtual_pipeline_model_parallel_size + vpp_rank = mpu.get_virtual_pipeline_model_parallel_rank() + + if vpp_size is not None: + self.layers = nn.ModuleList() + self.num_layer_vpp_chunk = self.num_layer_per_pp // vpp_size + self.num_layer_this_model = self.num_layer_vpp_chunk + offset = vpp_rank * (config.num_hidden_layers // vpp_size) + (pp_rank * self.num_layer_vpp_chunk) + else: + self.num_layer_this_model = self.num_layer_per_pp + offset = pp_rank * self.num_layer_per_pp + + self.layers = nn.ModuleList() + for i in range(self.num_layer_this_model): + layer = ParallelLlamaDecoderLayerRmPad(config, megatron_config, layer_idx=offset + i) + self.layers.add_module(f"{i}", layer) + + if post_process: + self.norm = ParallelLlamaRMSNorm(config, megatron_config) + else: + self.norm = None + + def set_input_tensor(self, input_tensor): + """Set input tensor to be used instead of forward()'s input. + + When doing pipeline parallelism the input from the previous + stage comes from communication, not from the input, so the + model's forward_step_func won't have it. This function is thus + used by internal code to bypass the input provided by the + forward_step_func""" + self.input_tensor = input_tensor + + def forward( + self, + input_ids: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + sequence_length: int = None, + indices: torch.Tensor = None, + cu_seqlens: int = None, + max_seqlen_in_batch: int = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + """ + + Args: + input_ids: input ids. shape (1, totol_nnz) + position_ids: position ids. shape (batch_size, seq_length) + + Returns: + + """ + if self.pre_process: + inputs_embeds = self.embed_tokens(input_ids) # (1, total_nnz) -> (1, total_nnz, hidden_size) + + # vocab parallel embedding will not do sequence parallel reduce-scatter in open source megatron + # so need to deal with it by handle here: + # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size) + inputs_embeds = inputs_embeds.transpose(0, 1) + if self.megatron_config.sequence_parallel: + inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds) + + hidden_states = inputs_embeds + else: + # self.hidden_states should be passed by Megatron + hidden_states = self.input_tensor + + for idx, decoder_layer in enumerate(self.layers): + layer_outputs = decoder_layer( + hidden_states, + position_ids=position_ids, + sequence_length=sequence_length, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen_in_batch=max_seqlen_in_batch, + ) + + hidden_states = layer_outputs + + if self.post_process: + hidden_states = self.norm(hidden_states) + + return hidden_states + + +class ParallelLlamaForCausalLMRmPadPP(nn.Module): + def __init__( + self, + config: LlamaConfig, + megatron_config: ModelParallelConfig, + pre_process, + post_process, + share_embeddings_and_output_weights=False, + ): + super().__init__() + self.config: TransformerConfig = convert_config(config, megatron_config) + self.megatron_config = megatron_config + self.model = ParallelLlamaModelRmPadPP(config, megatron_config=megatron_config, pre_process=pre_process, post_process=post_process) + assert share_embeddings_and_output_weights is False, "Llama Model not supports sharing embedding and output weights" + self.share_embeddings_and_output_weights = share_embeddings_and_output_weights + self.vocab_size = config.vocab_size + self.pre_process = pre_process + self.post_process = post_process + if post_process: + self._init_head(config) + + def set_input_tensor(self, input_tensor): + """Set input tensor to be used instead of forward()'s input. + + When doing pipeline parallelism the input from the previous + stage comes from communication, not from the input, so the + model's forward_step_func won't have it. This function is thus + used by internal code to bypass the input provided by the + forward_step_func""" + assert len(input_tensor) == 1 + self.model.set_input_tensor(input_tensor[0]) + + def _init_head(self, config): + column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() + if self.megatron_config is not None: + assert column_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) + self.lm_head = tensor_parallel.ColumnParallelLinear( + input_size=config.hidden_size, + output_size=config.vocab_size, + bias=False, + gather_output=False, + skip_bias_add=False, + **column_kwargs, + ) + + def _forward_head(self, hidden_states): + # all_gather from sequence parallel region is performed inside lm_head + # logits shape before forward_head hidden_states.shape: [4, 32, 4096] + logits = self.lm_head(hidden_states)[0] + # logits shape after forward_head logits.shape: [8, 32, 8] + logits = logits.float() # (total_nnz_padded, 1, vocab_size // tp) + return logits + + def forward( + self, + # original input + *, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + ```""" + + # Note that input_ids, attention_mask and position_ids should be passed to every pp layer. + # In the first pp, input_ids will be used, in other pp layers hidden_states will be used inside self.model + batch_size, sequence_length = input_ids.shape + # remove padding here + input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(input_ids.unsqueeze(dim=-1), attention_mask) # (total_nnz, 1) + + # pad input_ids to multiple of tp for all tp ranks + # TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap + if self.megatron_config.sequence_parallel: + input_ids_rmpad = sp_utils.pad_to_sequence_parallel(input_ids_rmpad) + + input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz+pad) + + outputs = self.model( + input_ids=input_ids_rmpad, + position_ids=position_ids, + sequence_length=sequence_length, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen_in_batch=max_seqlen_in_batch, + ) + + if self.post_process: + hidden_states = outputs + # print(f'hidden_states.shape = {hidden_states.shape}') # torch.Size([4, 32, 4096]) + logits = self._forward_head(hidden_states) + logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension # torch.Size([8, 32, 16]) + + # remove padding from sequence parallel + if self.megatron_config.sequence_parallel: + totol_nnz = cu_seqlens[-1] + logits = logits[:totol_nnz] # (total_nnz_padded) + # add removed padding back. If input is already rmpad, we let the caller pad_input + logits = pad_input(logits, indices, batch_size, seqlen=sequence_length) # (batch_size, sequence_length, vocab_size) + + return CausalLMOutputWithPast( + loss=None, + logits=logits, + past_key_values=None, + hidden_states=None, + attentions=None, + ) + else: + return outputs + + +class ParallelLlamaForValueRmPadPP(ParallelLlamaForCausalLMRmPadPP): + def _init_head(self, config): + column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() + if self.megatron_config is not None: + assert column_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) + self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=1, bias=False) + # lm_head is effectively the same as sequence parallel + sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight) + + def _forward_head(self, hidden_states): + logits = self.lm_head(hidden_states) # (total_nnz_padded // tp, 1, 1) + logits = logits.float() + if self.megatron_config.sequence_parallel: + logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False) + return logits + + def forward( + self, + *, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output = super().forward(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids) + if self.post_process: + output.logits = torch.squeeze(output.logits, dim=-1) + return output + else: + return output diff --git a/verl/models/mcore/__init__.py b/verl/models/mcore/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b436cc32b085e49f6164cf6a601a65ba711e49f4 --- /dev/null +++ b/verl/models/mcore/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .registry import get_mcore_forward_fn, get_mcore_weight_converter, hf_to_mcore_config, init_mcore_model + +__all__ = ["hf_to_mcore_config", "init_mcore_model", "get_mcore_forward_fn", "get_mcore_weight_converter"] diff --git a/verl/models/mcore/config_converter.py b/verl/models/mcore/config_converter.py new file mode 100644 index 0000000000000000000000000000000000000000..c5204232dd817b93e2baf51f770e767de7e884db --- /dev/null +++ b/verl/models/mcore/config_converter.py @@ -0,0 +1,197 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# convert huggingface config to mcore transformer config + + +import torch +import torch.nn.functional as F +from megatron.core.transformer import MLATransformerConfig, TransformerConfig +from transformers import PretrainedConfig + + +def _get_base_transformer_config(hf_config: PretrainedConfig, dtype: torch.dtype, **kwargs) -> TransformerConfig: + """ + Create a base TransformerConfig with common parameters across different model architectures. + TODO: (ycl) use dataclass or converter config? + + Args: + hf_config: HuggingFace model configuration + dtype: Data type for the model + **kwargs: Additional parameters to override defaults + + Returns: + TransformerConfig with common parameters + """ + from megatron.core import parallel_state as mpu + + # Common parallel state parameters + overlap_p2p_comm = mpu.get_virtual_pipeline_model_parallel_world_size() is not None and mpu.get_virtual_pipeline_model_parallel_world_size() > 1 + batch_p2p_comm = False + + # Base configuration with common parameters + base_config = { + # Model architecture parameters + "num_layers": hf_config.num_hidden_layers, + "hidden_size": hf_config.hidden_size, + "num_attention_heads": hf_config.num_attention_heads, + "num_query_groups": hf_config.num_key_value_heads, + "ffn_hidden_size": hf_config.intermediate_size, + "attention_dropout": hf_config.attention_dropout, + "hidden_dropout": getattr(hf_config, "hidden_dropout", 0.0), + "kv_channels": getattr(hf_config, "head_dim", None), + "layernorm_epsilon": hf_config.rms_norm_eps, + # Activation and normalization + "activation_func": F.silu, + "normalization": "RMSNorm", + "gated_linear_unit": True, + # Data types + "pipeline_dtype": dtype, + "params_dtype": dtype, + "bf16": dtype is torch.bfloat16, + # Parallel configuration + "tensor_model_parallel_size": mpu.get_tensor_model_parallel_world_size(), + "pipeline_model_parallel_size": mpu.get_pipeline_model_parallel_world_size(), + "virtual_pipeline_model_parallel_size": mpu.get_virtual_pipeline_model_parallel_world_size(), + "context_parallel_size": mpu.get_context_parallel_world_size(), + "overlap_p2p_comm": overlap_p2p_comm, + "batch_p2p_comm": batch_p2p_comm, + "sequence_parallel": mpu.get_tensor_model_parallel_world_size() > 1, + # Common settings + "variable_seq_lengths": True, + "masked_softmax_fusion": True, + "moe_token_dispatcher_type": "alltoall", + } + + # Update with any provided overrides + base_config.update(kwargs) + print(f"Overridden TF init config: {base_config}") + + return TransformerConfig(**base_config) + + +def hf_to_mcore_config_dense(hf_config: PretrainedConfig, dtype: torch.dtype) -> TransformerConfig: + # for LlamaForCausalLM or Qwen2ForCausalLM + qkv_bias = True if "Qwen2ForCausalLM" in hf_config.architectures else getattr(hf_config, "attention_bias", False) + qk_layernorm = True if "Qwen3ForCausalLM" in hf_config.architectures else False + + return _get_base_transformer_config( + hf_config=hf_config, + dtype=dtype, + use_cpu_initialization=False, + add_bias_linear=False, + add_qkv_bias=qkv_bias, + qk_layernorm=qk_layernorm, + ) + + +def hf_to_mcore_config_qwen2moe(hf_config: PretrainedConfig, dtype: torch.dtype) -> TransformerConfig: + return _get_base_transformer_config( + hf_config=hf_config, + dtype=dtype, + use_cpu_initialization=False, + add_bias_linear=False, + layernorm_epsilon=hf_config.rms_norm_eps, + # MoE specific + moe_ffn_hidden_size=hf_config.moe_intermediate_size, + moe_router_bias_update_rate=0.001, + moe_router_topk=hf_config.num_experts_per_tok, + num_moe_experts=hf_config.num_experts, + moe_shared_expert_intermediate_size=hf_config.shared_expert_intermediate_size, + moe_aux_loss_coeff=hf_config.router_aux_loss_coef, + # moe_aux_loss_coeff=0.0, + moe_router_load_balancing_type="aux_loss", + moe_shared_expert_overlap=True, + moe_grouped_gemm=True, + moe_router_score_function="softmax", + # Other optimizations + persist_layer_norm=True, + bias_activation_fusion=True, + bias_dropout_fusion=True, + # Qwen specific + moe_router_pre_softmax=True, + add_qkv_bias=True, + ) + + +def hf_to_mcore_config_mixtral(hf_config: PretrainedConfig, dtype: torch.dtype) -> TransformerConfig: + return _get_base_transformer_config( + hf_config=hf_config, + dtype=dtype, + use_cpu_initialization=False, + add_bias_linear=False, + layernorm_epsilon=hf_config.rms_norm_eps, + # MoE specific + num_moe_experts=hf_config.num_local_experts, + moe_aux_loss_coeff=hf_config.router_aux_loss_coef, + moe_router_topk=hf_config.num_experts_per_tok, + moe_router_pre_softmax=True, + moe_router_load_balancing_type="aux_loss", + moe_router_score_function="softmax", + moe_shared_expert_intermediate_size=None, # mixtral has no shared expert + moe_shared_expert_overlap=False, # mixtral has no shared expert + moe_ffn_hidden_size=hf_config.intermediate_size, + moe_router_bias_update_rate=0.001, + # moe_permute_fusion=True, # need TE 2.1+ + moe_grouped_gemm=True, + # Other optimizations + persist_layer_norm=True, + apply_rope_fusion=True, + bias_activation_fusion=True, + bias_dropout_fusion=True, + ) + + +def hf_to_mcore_config_qwen3moe(hf_config: PretrainedConfig, dtype: torch.dtype) -> TransformerConfig: + return _get_base_transformer_config( + hf_config=hf_config, + dtype=dtype, + use_cpu_initialization=False, + add_bias_linear=False, + layernorm_epsilon=hf_config.rms_norm_eps, + # MoE specific + moe_ffn_hidden_size=hf_config.moe_intermediate_size, + moe_router_bias_update_rate=0.001, + moe_router_topk=hf_config.num_experts_per_tok, + num_moe_experts=hf_config.num_experts, + moe_aux_loss_coeff=hf_config.router_aux_loss_coef, + # moe_aux_loss_coeff=0.0, + moe_router_load_balancing_type="aux_loss", + moe_grouped_gemm=True, + moe_router_score_function="softmax", + # Other optimizations + persist_layer_norm=True, + bias_activation_fusion=True, + bias_dropout_fusion=True, + # Qwen specific + moe_router_pre_softmax=True, + qk_layernorm=True, + ) + + +def hf_to_mcore_config_dpskv3(hf_config: PretrainedConfig, dtype: torch.dtype) -> MLATransformerConfig: + # DeepseekV3ForCausalLM + raise NotImplementedError("DeepseekV3ForCausalLM is not supported yet") + + +def hf_to_mcore_config_qwen2_5_vl(hf_config: PretrainedConfig, dtype: torch.dtype) -> TransformerConfig: + # Qwen2_5_VLForConditionalGeneration + raise NotImplementedError("Qwen2_5_VLForConditionalGeneration is not supported yet") + + +def hf_to_mcore_config_llama4(hf_config: PretrainedConfig, dtype: torch.dtype) -> TransformerConfig: + # Llama4ForConditionalGeneration + raise NotImplementedError("Llama4ForConditionalGeneration is not supported yet") diff --git a/verl/models/mcore/loader.py b/verl/models/mcore/loader.py new file mode 100644 index 0000000000000000000000000000000000000000..55acb0722dcc19932eb532522ebc2a3e19456bd6 --- /dev/null +++ b/verl/models/mcore/loader.py @@ -0,0 +1,468 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import torch +import torch.distributed as dist + +from .saver import _megatron_calc_global_rank + + +def _megatron_calc_layer_map(config): + """Calculate the mapping of global layer_idx to local layer_idx + Returns: + layer_map (Dict: int -> tuple(int, int, int)): + mapping from the global layer index to + a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model) + """ + from megatron.core import mpu + + pp_size = mpu.get_pipeline_model_parallel_world_size() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + + layer_map = dict() + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers + + for pp_rank_idx in range(pp_size): + for virtual_pp_rank_idx in range(virtual_pp_size): + layer_offset = virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + for layer_idx in range(num_layers_per_model): + layer_map[layer_offset + layer_idx] = ( + pp_rank_idx, + virtual_pp_rank_idx, + layer_idx, + ) + return layer_map + + +def load_state_dict_to_megatron_gptmodel(state_dict, wrapped_models, config, params_dtype, is_value_model=False): + """Load merged state_dict to sharded Megatron module in training.""" + from megatron.core import DistributedDataParallel as LocalDDP + from megatron.core import mpu + from megatron.core.transformer.module import Float16Module + from torch.nn.parallel import DistributedDataParallel as torchDDP + + from verl.utils.megatron_utils import print_rank_0, unwrap_model + + start_time = time.time() + + def _get_gpt_model(model): + return model + + def broadcast_params(module): + for param in module.parameters(): + torch.distributed.broadcast(param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group()) + + dp_rank = mpu.get_data_parallel_rank() + pp_rank = mpu.get_pipeline_model_parallel_rank() + cp_rank = mpu.get_context_parallel_rank() + src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=0, cp_rank=cp_rank) + pp_size = mpu.get_pipeline_model_parallel_world_size() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + mp_group = mpu.get_model_parallel_group() + + if torch.distributed.get_rank() == src_rank: + assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0" + assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" + assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" + + if not isinstance(wrapped_models, (list, tuple)): + wrapped_models = list(wrapped_models) + + assert len(wrapped_models) == virtual_pp_size + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers + + models = [None] * len(wrapped_models) + + for i, wrapped_model in enumerate(wrapped_models): + models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) + gpt_model_module = _get_gpt_model(models[i]) + assert len(gpt_model_module.decoder.layers) == num_layers_per_model + + def _broadcast_tensor(tensor, name) -> torch.Tensor: + """broadcast tensor from rank0 across mp_group""" + nonlocal state_dict + nonlocal mp_group + if torch.distributed.get_rank() == src_rank: + if name in state_dict: + weight = state_dict[name] + tensor_shape = weight.shape + else: + tensor_shape = None + else: + weight = None + tensor_shape = None + + obj_list = [tensor_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + tensor_shape = obj_list[0] + + if tensor_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tensor:[{name}] not in state_dict, skip load") + return + + if tensor is None: + tensor = torch.empty( + tensor_shape, + dtype=params_dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + if torch.distributed.get_rank() == src_rank: + tensor.data.copy_(weight) + dist.broadcast(tensor, src=src_rank, group=mp_group) + + def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + + if torch.distributed.get_rank() == src_rank: + if name in state_dict: + full_weight = state_dict[name] + + if mutate_func is not None: + full_weight = mutate_func(full_weight) + tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) + chunk_shape = tensor_chunk[0].shape + else: + chunk_shape = None + else: + chunk_shape = None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading") + return + + if tensor is None: + sync_tensor = torch.empty( + chunk_shape, + dtype=params_dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + else: + assert tensor.shape == chunk_shape, f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" + sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False) + + for i in range(tp_size): + if torch.distributed.get_rank() == src_rank: + sync_tensor.data.copy_(tensor_chunk[i]) + dist.broadcast(sync_tensor, src=src_rank, group=mp_group) + if (i == tp_rank) and (tensor is not None): + tensor.data.copy_(sync_tensor) + + def _broadcast_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + + if torch.distributed.get_rank() == src_rank: + if name in state_dict: + full_weight = state_dict[name] + if mutate_func is not None: + full_weight = mutate_func(full_weight) + tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) + chunk_shape = tensor_chunk[0].shape + else: + chunk_shape = None + else: + chunk_shape = None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading") + return + + if tensor is None: + sync_tensor = torch.empty( + chunk_shape, + dtype=params_dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + else: + assert tensor.shape == chunk_shape, f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" + sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False) + + for i in range(tp_size): + if torch.distributed.get_rank() == src_rank: + sync_tensor.data.copy_(tensor_chunk[i]) + dist.broadcast(sync_tensor, src=src_rank, group=mp_group) + if (i == tp_rank) and (tensor is not None): + tensor.data.copy_(sync_tensor) + + def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + + if torch.distributed.get_rank() == src_rank: + gate_weight = state_dict[gate_name] + up_weight = state_dict[up_name] + new_gate_up_weight = torch.empty(config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device()) + for i in range(tp_size): + intermediate_size_tp = config.intermediate_size // tp_size + gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] + up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] + new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_(torch.cat([gate_weight_tp, up_weight_tp], dim=0)) + + tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0) + chunk_shape = tensor_chunk[0].shape + else: + chunk_shape = None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not in state_dict, skip loading") + return + + if tensor is None: + sync_tensor = torch.empty( + chunk_shape, + dtype=params_dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + else: + assert tensor.shape == chunk_shape, f"rank #{torch.distributed.get_rank() == src_rank:} tensor {gate_name, up_name} shape {tensor.shape} != {chunk_shape}" + sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False) + + for i in range(tp_size): + if torch.distributed.get_rank() == src_rank: + sync_tensor.data.copy_(tensor_chunk[i]) + dist.broadcast(sync_tensor, src=src_rank, group=mp_group) + if (i == tp_rank) and (tensor is not None): + tensor.data.copy_(sync_tensor) + + def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + + if torch.distributed.get_rank() == src_rank: + assert q_name in state_dict and k_name in state_dict and v_name in state_dict + full_weight_q = state_dict[q_name] + full_weight_k = state_dict[k_name] + full_weight_v = state_dict[v_name] + + hidden_size_per_head = config.hidden_size // config.num_attention_heads + + if config.num_key_value_heads >= tp_size: + q_size_tp = config.hidden_size // tp_size + kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size + total_size = q_size_tp + 2 * kv_size_tp + sizes = [total_size * tp_size] + if not bias: + sizes.append(config.hidden_size) + new_weight_qkv = torch.empty(*sizes, dtype=params_dtype, device=torch.cuda.current_device()) + for i in range(tp_size): + q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] + k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp] + v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp] + num_query_groups_per_partition = models[0].config.num_query_groups // tp_size + new_weight_qkv_this_tp = new_weight_qkv[i * total_size : (i + 1) * total_size] + q_part_per_head = torch.chunk(q_part, num_query_groups_per_partition, dim=0) + k_part_per_head = torch.chunk(k_part, num_query_groups_per_partition, dim=0) + v_part_per_head = torch.chunk(v_part, num_query_groups_per_partition, dim=0) + total_size_per_head = total_size // num_query_groups_per_partition + for j in range(num_query_groups_per_partition): + new_weight_qkv_this_tp[j * total_size_per_head : (j + 1) * total_size_per_head].copy_(torch.cat([q_part_per_head[j], k_part_per_head[j], v_part_per_head[j]], dim=0)) + + else: + q_size_tp = config.hidden_size // tp_size + kv_size_tp = hidden_size_per_head + total_size = q_size_tp + 2 * kv_size_tp + sizes = [total_size * tp_size] + if not bias: + sizes.append(config.hidden_size) + new_weight_qkv = torch.empty(*sizes, dtype=params_dtype, device=torch.cuda.current_device()) + for i in range(tp_size): + q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] + start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head + end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head + k_part = full_weight_k[start_idx:end_idx] + v_part = full_weight_v[start_idx:end_idx] + new_weight_qkv_this_tp = new_weight_qkv[i * total_size : (i + 1) * total_size] + q_part_per_head = torch.chunk(q_part, config.num_attention_heads, dim=0) + k_part_per_head = torch.chunk(k_part, config.num_attention_heads, dim=0) + v_part_per_head = torch.chunk(v_part, config.num_attention_heads, dim=0) + total_size_per_head = total_size // config.num_attention_heads + for j in range(config.num_attention_heads): + new_weight_qkv_this_tp[j * total_size_per_head : (j + 1) * total_size_per_head].copy_(torch.cat([q_part_per_head[j], k_part_per_head[j], v_part_per_head[j]], dim=0)) + + tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0) + chunk_shape = tensor_chunk[0].shape + else: + chunk_shape = None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{q_name, k_name, v_name}] not in state_dict, skip loading") + return + + if tensor is None: + sync_tensor = torch.empty( + chunk_shape, + dtype=params_dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + else: + assert tensor.shape == chunk_shape, f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}" + sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False) + + for i in range(tp_size): + if torch.distributed.get_rank() == src_rank: + sync_tensor.data.copy_(tensor_chunk[i]) + dist.broadcast(sync_tensor, src=src_rank, group=mp_group) + if (i == tp_rank) and (tensor is not None): + tensor.data.copy_(sync_tensor) + + if dp_rank == 0: + # Embeddings + # ------------------- + print_rank_0("loading embeddings...") + gpt_model_module = _get_gpt_model(models[0]) + embed_tokens_weight = None + if pp_rank == 0: + embed_tokens_weight = gpt_model_module.embedding.word_embeddings.weight + _broadcast_tp_shard_tensor_vocab(embed_tokens_weight, "model.embed_tokens.weight") + + # Transformer layers + # ------------------- + layer_map = _megatron_calc_layer_map(config) + + for layer in range(config.num_hidden_layers): + print_rank_0(f"loading layer #{layer}...") + layer_name = f"model.layers.{layer}" + dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer] + + gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank]) + sync_layer = gpt_model_module.decoder.layers[dst_layer_idx] + + _broadcast_tensor( + sync_layer.self_attention.linear_qkv.layer_norm_weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.input_layernorm.weight", + ) + + if f"{layer_name}.self_attn.q_norm.weight" in state_dict: + _broadcast_tensor( + sync_layer.self_attention.q_layernorm.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.q_norm.weight", + ) + _broadcast_tensor( + sync_layer.self_attention.k_layernorm.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.k_norm.weight", + ) + + _broadcast_tp_shard_tensor_qkv( + sync_layer.self_attention.linear_qkv.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.q_proj.weight", + f"{layer_name}.self_attn.k_proj.weight", + f"{layer_name}.self_attn.v_proj.weight", + ) + if f"{layer_name}.self_attn.q_proj.bias" in state_dict: + _broadcast_tp_shard_tensor_qkv( + sync_layer.self_attention.linear_qkv.bias if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.q_proj.bias", + f"{layer_name}.self_attn.k_proj.bias", + f"{layer_name}.self_attn.v_proj.bias", + bias=True, + ) + + _broadcast_tp_shard_tensor( + sync_layer.self_attention.linear_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.o_proj.weight", + chunk_dim=1, + ) + _broadcast_tensor( + sync_layer.mlp.linear_fc1.layer_norm_weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.post_attention_layernorm.weight", + ) + + _broadcast_tp_shard_tensor_gate_up( + sync_layer.mlp.linear_fc1.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.mlp.gate_proj.weight", + f"{layer_name}.mlp.up_proj.weight", + ) + + _broadcast_tp_shard_tensor( + sync_layer.mlp.linear_fc2.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.mlp.down_proj.weight", + chunk_dim=1, + ) + # Final Layernorm + # ------------------- + print_rank_0("loading final layernorm...") + gpt_model_module = _get_gpt_model(models[-1]) + _broadcast_tensor( + getattr(gpt_model_module.decoder.final_layernorm, "weight", None), + "model.norm.weight", + ) + + print_rank_0("loading lm_head...") + lm_head_weight = None + if pp_rank + 1 == pp_size: + lm_head_weight = gpt_model_module.output_layer.weight + + if is_value_model: + # if torch.distributed.get_rank() == src_rank: + if "lm_head.weight" in state_dict and state_dict["lm_head.weight"].shape[0] == 1: + _broadcast_tensor(lm_head_weight, "lm_head.weight") + elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1: + _broadcast_tensor(lm_head_weight, "reward_head.weight") + print_rank_0("load lm_head from value_head weight") + else: + _broadcast_tensor(None, "lm_head.weight") + print_rank_0("fail to match lm_head in value_model") + # else: + + # _broadcast_tensor(lm_head_weight, "lm_head.weight") + + else: + _broadcast_tp_shard_tensor(lm_head_weight, "lm_head.weight") + dist.barrier() + # Broadcast weights inside data parallel groups + for wrapped_model in wrapped_models: + broadcast_params(wrapped_model) + pass + torch.cuda.empty_cache() + print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s") diff --git a/verl/models/mcore/model_forward.py b/verl/models/mcore/model_forward.py new file mode 100644 index 0000000000000000000000000000000000000000..1cbd15446e4ffd98896f6f78195659d6b4f44b0d --- /dev/null +++ b/verl/models/mcore/model_forward.py @@ -0,0 +1,50 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from verl.utils.megatron_utils import unwrap_model + +from .util import postprocess_packed_seqs, preprocess_packed_seqs, recover_left_padding, remove_left_padding + + +def gptmodel_forward(model, input_ids, attention_mask, position_ids, sequence_parallel, value_model=False, pack_seqs=True): + """Default forward pass for GPT models with optional sequence packing.""" + pre_process = unwrap_model(model).pre_process + post_process = unwrap_model(model).post_process + if pack_seqs: + batch_size, seq_len = attention_mask.shape[:2] + input_ids_rmpad, packed_seq_params = preprocess_packed_seqs(input_ids, attention_mask, pre_process=pre_process) + input_ids_rmpad = input_ids_rmpad.contiguous() + output_orig = model( + input_ids=input_ids_rmpad, + attention_mask=None, + position_ids=position_ids, + packed_seq_params=packed_seq_params, + ) + + output = postprocess_packed_seqs(output_orig, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process) + else: + batch_size, sequence_length = attention_mask.shape + new_input_ids, new_attention_mask, new_position_ids = remove_left_padding(input_ids, attention_mask, position_ids, sequence_parallel, pre_process=pre_process) + output = model(input_ids=new_input_ids, attention_mask=new_attention_mask, position_ids=new_position_ids) + output = recover_left_padding(output, new_attention_mask, attention_mask, sequence_length, post_process=post_process) + if value_model and post_process: + output = output[..., 0] + return output + + +def gptmodel_forward_qwen2_5_vl(*args, **kwargs): + """Forward pass for Qwen2.5 VL model (not implemented).""" + raise NotImplementedError("VLM is not supported yet") diff --git a/verl/models/mcore/model_initializer.py b/verl/models/mcore/model_initializer.py new file mode 100644 index 0000000000000000000000000000000000000000..debbf10e75335053b3f315c5110b2c0a6a7dbefe --- /dev/null +++ b/verl/models/mcore/model_initializer.py @@ -0,0 +1,160 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# use mcore transformer config to initialize the model +from abc import ABC, abstractmethod + +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec +from megatron.core.models.gpt.gpt_model import GPTModel + +from .config_converter import PretrainedConfig, TransformerConfig + + +class BaseModelInitializer(ABC): + """Base class for model initializers.""" + + def __init__(self, tfconfig: TransformerConfig, hf_config: PretrainedConfig): + self.tfconfig = tfconfig + self.hf_config = hf_config + + @abstractmethod + def get_transformer_layer_spec(self): + """Get the transformer layer specification. + https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/models/gpt/gpt_layer_specs.py""" + pass + + def get_rope_scaling_args(self) -> dict: + """Get rope scaling args.""" + rope_scaling_args = {} + if "rope_scaling" in self.hf_config: + if self.hf_config.rope_scaling is not None: + assert self.hf_config.rope_scaling["type"] == "linear", "only linear scaling is supported for now" + rope_scaling_args["seq_len_interpolation_factor"] = self.hf_config.rope_scaling["factor"] + return rope_scaling_args + + def initialize( + self, + pre_process: bool = True, + post_process: bool = True, + share_embeddings_and_output_weights: bool = False, + value: bool = False, + **extra_kwargs, + ) -> GPTModel: + """Initialize a GPT model with the given configuration. + https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/models/gpt/gpt_model.py + + Args: + pre_process (bool): include embedding layer. + post_process (bool): including an output layer. + share_embeddings_and_output_weights (bool): input embeddings and output logit weights are shared. + value (bool): add an extra linear layer for classification or regression. + + Returns: + GPTModel: An initialized GPT model instance + """ + transformer_layer_spec = self.get_transformer_layer_spec() + rope_scaling_args = self.get_rope_scaling_args() + + model = GPTModel( + config=self.tfconfig, + transformer_layer_spec=transformer_layer_spec, + vocab_size=self.hf_config.vocab_size, + max_sequence_length=self.hf_config.max_position_embeddings, + pre_process=pre_process, + post_process=post_process, + share_embeddings_and_output_weights=share_embeddings_and_output_weights, + position_embedding_type="rope", + rotary_base=self.hf_config.rope_theta, + **rope_scaling_args, + ) + + if post_process and value: + from verl.models.llama.megatron.layers.parallel_linear import LinearForLastLayer + + model.output_layer = LinearForLastLayer(input_size=self.tfconfig.hidden_size, output_size=1, config=self.tfconfig) + + return model + + +class DenseModel(BaseModelInitializer): + """Initializer for dense models like Llama and Qwen2.""" + + def get_transformer_layer_spec(self): + assert self.tfconfig.normalization == "RMSNorm", "only RMSNorm is supported for now" + return get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True) + + +class Qwen2MoEModel(BaseModelInitializer): + """Initializer for Qwen2 MoE models.""" + + def get_transformer_layer_spec(self): + assert self.tfconfig.normalization == "RMSNorm", "only RMSNorm is supported for now" + transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True) + + # Patch layer spec for shared experts + for i in range(len(transformer_layer_spec.layer_specs)): + transformer_layer_spec.layer_specs[i].submodules.mlp.submodules.shared_experts.params["gate"] = True + + return transformer_layer_spec + + def initialize(self, freeze_moe_router: bool = True, **kwargs): + # Qwen default freeze_moe_router: true + model = super().initialize(**kwargs) + if freeze_moe_router: + for layer in model.decoder.layers: + layer.mlp.router.weight.requires_grad = False + layer.mlp.shared_experts.gate_weight.requires_grad = False + return model + + +class MixtralModel(BaseModelInitializer): + """Initializer for Mixtral models.""" + + def get_transformer_layer_spec(self): + assert self.tfconfig.normalization == "RMSNorm", "only RMSNorm is supported for now" + transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True) + return transformer_layer_spec + + def initialize(self, freeze_moe_router: bool = False, **kwargs): + model = super().initialize(**kwargs) + if freeze_moe_router: + for layer in model.decoder.layers: + layer.mlp.router.weight.requires_grad = False + return model + + +class Qwen3MoEModel(BaseModelInitializer): + """Initializer for Qwen3 MoE models.""" + + def get_transformer_layer_spec(self): + assert self.tfconfig.normalization == "RMSNorm", "only RMSNorm is supported for now" + transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True) + return transformer_layer_spec + + def initialize(self, freeze_moe_router: bool = True, **kwargs): + # Qwen default freeze_moe_router: true + model = super().initialize(**kwargs) + if freeze_moe_router: + for layer in model.decoder.layers: + layer.mlp.router.weight.requires_grad = False + return model + + +class Qwen25VLModel(BaseModelInitializer): + """Initializer for Qwen2.5 VL models.""" + + def get_transformer_layer_spec(self): + raise NotImplementedError("VLM is not supported yet") diff --git a/verl/models/mcore/readme.md b/verl/models/mcore/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..e4ab7227b7720a625f28afca6b447759b35d986b --- /dev/null +++ b/verl/models/mcore/readme.md @@ -0,0 +1,99 @@ +# verl Megatron-Core Models +The earlier versions of verl use `Megatron-LM` 0.4 and workaround huggingface model classes. To better use the latest features and speedup of modern Megatron, we are migrating to `Megatron-Core`(mcore), and use the recommended `GPTModel` class for all language models. With mcore `GPTModel`, we can use the latest features like `context parallel`, `expert parallel`, `dist_checkpointing`, etc. and we can update mcore with little effort in the future for new features. + +The migration has been successful with the help of the mcore team and the community. What we have done is: +1. update `Megatron` version to `0.11.0` +2. migrate `LlamaForCausalLM` and `Qwen2ForCausalLM` to mcore `GPTModel` +3. support sequence packing/thd format. +4. support `tensor parallel`, `pipeline parallel`, `sequence parallel`, `virtual pipeline parallel`, `context parallel`. +5. support the mcore `dist_checkpointing` feature and a basic offline weighs conversion scipt from huggingface to mcore `dist_checkpointing` format. + +We are working on the following features: +- support `Qwen2MoeForCausalLM` +- support `MixtralForCausalLM` +- support `DeepseekV3ForCausalLM` +- support `expert parallel` + +Features we invite the community to contribute: +- better scipts for offline weights conversion from huggingface to mcore `dist_checkpointing` format. + - conversion of large models with multiple GPUs + - conversion of large models with single GPU +- refactor the `megatron_checkpoint_manager.py` by `dist_checkpointing` format. +- support llama4 +- support qwen2.5-vl + +To track the progress of verl mcore integration, please refer to the [mcore integration issue](https://github.com/volcengine/verl/issues/1033). + +## How things work now +To engage the community in contributing, here are the key steps in our mcore integration process and features under development. + +The huggingface `transformers` is the de facto standard of model zoo while mcore is good at computation efficiency. The main challenge is conversion between the two. +main steps: +1. modelling the huggingface model with mcore `GPTModel` + - a. convert the huggingface config to mcore `TransformerConfig` + - b. init the mcore `GPTModel` with the converted config + - c. load the huggingface model weights to the `GPTModel` +2. online weight conversion from mcore to huggingface (due the the rollout engine `vLLM` is using huggingface format) + - a. bridge the gap between mcore and huggingface weights format and name mapping + - b. online resharding the mcore weights to rollout engine + - this part is very complicated with multiple parallel strategies composition between mcore and rollout engine +3. support the mcore features in verl + - a. support `tensor parallel`, `pipeline parallel`, `sequence parallel`, `virtual pipeline parallel`, `context parallel` + - b. support recompute and other mcore speed up features + +4. checkpointing + - a. support recovering the verl training. + - b. support exporting the mcore checkpoint to huggingface format, for downstream inference. + +### Modelling the huggingface model with mcore `GPTModel` +The first step is to convert huggingface config to mcore `TransformerConfig` and init the mcore `GPTModel` with the converted config. See code in `verl/models/mcore/config_converter.py` and `verl/verl/models/mcore/models/model_initializer.py`. The corresponding model forward code is in `verl/verl/models/mcore/models/model_forward.py`. + +There are two ways of loading the huggingface model weights to the `GPTModel` +1. Runtime loading + - every rank loads the entire huggingface model weights and then shard and convert to mcore weights. + - speed is slow and memory consumption is high. + - this way is deprecated and will not support new models. +2. Offline loading + - use offline script to convert the huggingface model weights to mcore weights and save with mcore `dist_checkpointing` format. + - online loading and sharding is automatically done by mcore `dist_checkpointing` format. The speed is fast and memory consumption is low. + - the offline script is in `verl/scripts/converter_hf_to_mcore.py`. + +### online weight conversion from mcore to huggingface +See function `convert_megatron_model_to_transformers_model` in `verl/utils/megatron_utils.py` for the details. + +It should be refatored for extensibility and better performance. + +### support the mcore features in verl +Most of the features of `GPTModel` is out-of-the-box supported in verl through changing the `TransformerConfig`, except those about parallel strategies, such as `expert parallel`. +Features about parallel strategies should be supported with changes about the online weights conversion(especially the resharding part) and verl work dispatching. + +### checkpointing +The existing checkpointing code is in `verl/utils/checkpoint/megatron_checkpoint_manager.py`. And the script to convert checkpoint to huggingface format is in `verl/scripts/model_merger.py`. + +The existing checkpoint format is simplely save every rank's weights and optimizer states. It should be refactored by `dist_checkpointing` format. + + +## How to support new models +1. make sure the model is supported by vLLM +2. modelling the huggingface model with mcore `GPTModel` (The [Pai-Megatron-Path](https://github.com/alibaba/Pai-Megatron-Patch/tree/main) is a good reference) + - a. convert the huggingface config to mcore `TransformerConfig` + - b. init the mcore `GPTModel` with the converted config + - c. load the huggingface model weights to the `GPTModel` + - d. for VLM the interface might be different, it is ok to add a new model class with GPTModel as its module. +3. offline weights conversion from huggingface to mcore `dist_checkpointing` format +4. support online weights conversion from mcore to huggingface + - it is recommended to initilize a vLLM model with the converted mcore weights, and then test if the generating sequence is correct. + + +## How to scale up to larger models like deepseek-v3 or other 100B+ models +The greatest challenge for scaling up to larger models is the memory consumption. + +The necessary features under development for scaling up are +1. Training engine part + - expert parallel +2. Rollout engine part + - pipeline parallel + - expert parallel + - more efficient and general weight resharding and loading +3. Offline weights conversion + - support weights larger then single GPU memory diff --git a/verl/models/mcore/registry.py b/verl/models/mcore/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..d2590481f6621fe411998a09a7335a9640a14839 --- /dev/null +++ b/verl/models/mcore/registry.py @@ -0,0 +1,179 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Registry module for model architecture components. +""" + +from enum import Enum +from typing import Callable, Dict, Type + +import torch +import torch.nn as nn + +from .config_converter import ( + PretrainedConfig, + TransformerConfig, + hf_to_mcore_config_dense, + hf_to_mcore_config_dpskv3, + hf_to_mcore_config_llama4, + hf_to_mcore_config_mixtral, + hf_to_mcore_config_qwen2_5_vl, + hf_to_mcore_config_qwen2moe, + hf_to_mcore_config_qwen3moe, +) +from .model_forward import ( + gptmodel_forward, +) +from .model_initializer import ( + BaseModelInitializer, + DenseModel, + MixtralModel, + Qwen2MoEModel, + Qwen3MoEModel, + Qwen25VLModel, +) +from .weight_converter import ( + McoreToHFWeightConverterDense, + McoreToHFWeightConverterMixtral, + McoreToHFWeightConverterQwen2Moe, + McoreToHFWeightConverterQwen3Moe, +) + + +class SupportedModel(Enum): + LLAMA = "LlamaForCausalLM" # tested + QWEN2 = "Qwen2ForCausalLM" # tested + QWEN2_MOE = "Qwen2MoeForCausalLM" # pending + DEEPSEEK_V3 = "DeepseekV3ForCausalLM" # not tested + MIXTRAL = "MixtralForCausalLM" # tested + QWEN2_5_VL = "Qwen2_5_VLForConditionalGeneration" # not supported + LLAMA4 = "Llama4ForConditionalGeneration" # not tested + QWEN3 = "Qwen3ForCausalLM" # tested + QWEN3_MOE = "Qwen3MoeForCausalLM" # not tested + + +# Registry for model configuration converters +MODEL_CONFIG_CONVERTER_REGISTRY: Dict[SupportedModel, Callable[[PretrainedConfig, torch.dtype], TransformerConfig]] = { + SupportedModel.LLAMA: hf_to_mcore_config_dense, + SupportedModel.QWEN2: hf_to_mcore_config_dense, + SupportedModel.QWEN2_MOE: hf_to_mcore_config_qwen2moe, + SupportedModel.DEEPSEEK_V3: hf_to_mcore_config_dpskv3, + SupportedModel.MIXTRAL: hf_to_mcore_config_mixtral, + SupportedModel.QWEN2_5_VL: hf_to_mcore_config_qwen2_5_vl, + SupportedModel.LLAMA4: hf_to_mcore_config_llama4, + SupportedModel.QWEN3: hf_to_mcore_config_dense, + SupportedModel.QWEN3_MOE: hf_to_mcore_config_qwen3moe, +} + +# Registry for model initializers +MODEL_INITIALIZER_REGISTRY: Dict[SupportedModel, Type[BaseModelInitializer]] = { + SupportedModel.LLAMA: DenseModel, + SupportedModel.QWEN2: DenseModel, + SupportedModel.QWEN2_MOE: Qwen2MoEModel, + SupportedModel.MIXTRAL: MixtralModel, + SupportedModel.DEEPSEEK_V3: DenseModel, + SupportedModel.QWEN2_5_VL: Qwen25VLModel, + SupportedModel.LLAMA4: DenseModel, + SupportedModel.QWEN3: DenseModel, + SupportedModel.QWEN3_MOE: Qwen3MoEModel, +} + +# Registry for model forward functions +MODEL_FORWARD_REGISTRY: Dict[SupportedModel, Callable] = { + SupportedModel.LLAMA: gptmodel_forward, + SupportedModel.QWEN2: gptmodel_forward, + SupportedModel.QWEN2_MOE: gptmodel_forward, + SupportedModel.MIXTRAL: gptmodel_forward, + SupportedModel.DEEPSEEK_V3: gptmodel_forward, + SupportedModel.QWEN2_5_VL: gptmodel_forward, + SupportedModel.LLAMA4: gptmodel_forward, + SupportedModel.QWEN3: gptmodel_forward, + SupportedModel.QWEN3_MOE: gptmodel_forward, +} + +# Registry for model weight converters +MODEL_WEIGHT_CONVERTER_REGISTRY: Dict[SupportedModel, Type] = { + SupportedModel.LLAMA: McoreToHFWeightConverterDense, + SupportedModel.QWEN2: McoreToHFWeightConverterDense, + SupportedModel.QWEN2_MOE: McoreToHFWeightConverterQwen2Moe, + SupportedModel.MIXTRAL: McoreToHFWeightConverterMixtral, + SupportedModel.QWEN3: McoreToHFWeightConverterDense, + SupportedModel.QWEN3_MOE: McoreToHFWeightConverterQwen3Moe, +} + + +def get_supported_model(model_type: str) -> SupportedModel: + try: + return SupportedModel(model_type) + except ValueError as err: + supported_models = [e.value for e in SupportedModel] + raise NotImplementedError(f"Model Type: {model_type} not supported. Supported models: {supported_models}") from err + + +def hf_to_mcore_config(hf_config: PretrainedConfig, dtype: torch.dtype) -> TransformerConfig: + assert len(hf_config.architectures) == 1, "Only one architecture is supported for now" + model = get_supported_model(hf_config.architectures[0]) + return MODEL_CONFIG_CONVERTER_REGISTRY[model](hf_config, dtype) + + +def init_mcore_model( + tfconfig: TransformerConfig, + hf_config: PretrainedConfig, + pre_process: bool = True, + post_process: bool = None, + *, + share_embeddings_and_output_weights: bool = False, + value: bool = False, + **extra_kwargs, # may be used for vlm and moe +) -> nn.Module: + """ + Initialize a Mcore model. + + Args: + tfconfig: The transformer config. + hf_config: The HuggingFace config. + pre_process: Optional pre-processing function. + post_process: Optional post-processing function. + share_embeddings_and_output_weights: Whether to share embeddings and output weights. + value: Whether to use value. + **extra_kwargs: Additional keyword arguments. + + Returns: + The initialized model. + """ + assert len(hf_config.architectures) == 1, "Only one architecture is supported for now" + model = get_supported_model(hf_config.architectures[0]) + initializer_cls = MODEL_INITIALIZER_REGISTRY[model] + initializer = initializer_cls(tfconfig, hf_config) + return initializer.initialize(pre_process=pre_process, post_process=post_process, share_embeddings_and_output_weights=share_embeddings_and_output_weights, value=value, **extra_kwargs) + + +def get_mcore_forward_fn(hf_config: PretrainedConfig) -> Callable: + """ + Get the forward function for given model architecture. + """ + assert len(hf_config.architectures) == 1, "Only one architecture is supported for now" + model = get_supported_model(hf_config.architectures[0]) + return MODEL_FORWARD_REGISTRY[model] + + +def get_mcore_weight_converter(hf_config: PretrainedConfig, dtype: torch.dtype) -> Callable: + """ + Get the weight converter for given model architecture. + """ + assert len(hf_config.architectures) == 1, "Only one architecture is supported for now" + model = get_supported_model(hf_config.architectures[0]) + tfconfig = hf_to_mcore_config(hf_config, dtype) + return MODEL_WEIGHT_CONVERTER_REGISTRY[model](hf_config, tfconfig) diff --git a/verl/models/mcore/saver.py b/verl/models/mcore/saver.py new file mode 100644 index 0000000000000000000000000000000000000000..bb8bdfbfada3ef6cf51b1849ed4635aa7dc7ef88 --- /dev/null +++ b/verl/models/mcore/saver.py @@ -0,0 +1,459 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import torch +import torch.distributed as dist +from megatron.core import mpu +from megatron.core.distributed import DistributedDataParallel as LocalDDP +from megatron.core.transformer.module import Float16Module +from torch.nn.parallel import DistributedDataParallel as torchDDP + +from verl.utils.megatron_utils import print_rank_0, unwrap_model + + +def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0, cp_rank: int = 0, ep_rank: int = 0): + """Calculate global rank with support for CP/EP parallelism""" + + # Get parallel sizes for each dimension + tp_size = mpu.get_tensor_model_parallel_world_size() + dp_size = mpu.get_data_parallel_world_size() + pp_size = mpu.get_pipeline_model_parallel_world_size() + cp_size = mpu.get_context_parallel_world_size() + # ep_size = mpu.get_expert_model_parallel_world_size() + + # Verify total GPU count matches (must be consistent with parallel_state.py) + total_size = tp_size * dp_size * pp_size * cp_size + assert total_size == torch.distributed.get_world_size(), f"{tp_size}x{dp_size}x{pp_size}x{cp_size} != {torch.distributed.get_world_size()}" + + # Core calculation logic (corresponds to RankGenerator order parameter) + # Assumes default order is "tp-cp-ep-dp-pp" + return ((pp_rank * dp_size + dp_rank) * cp_size + cp_rank) * tp_size + tp_rank + + +def _megatron_calc_layer_map(config): + """Calculate the mapping of global layer_idx to local layer_idx + Returns: + layer_map (Dict: int -> tuple(int, int, int)): + mapping from the global layer index to + a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model) + """ + from megatron.core import mpu + + pp_size = mpu.get_pipeline_model_parallel_world_size() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + + layer_map = dict() + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers + + for pp_rank_idx in range(pp_size): + for virtual_pp_rank_idx in range(virtual_pp_size): + layer_offset = virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + for layer_idx in range(num_layers_per_model): + layer_map[layer_offset + layer_idx] = ( + pp_rank_idx, + virtual_pp_rank_idx, + layer_idx, + ) + return layer_map + + +def merge_megatron_ckpt_gptmodel(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False): + """Merge sharded parameters of a Megatron module into a merged checkpoint. + + Args: + wrapped_models (list of megatron.core.distributed.DistributedDataParallel): + The local DDP wrapped megatron modules. + config (str or None): + HF config for model + dtype: model params type + is_value_model: if model is value model + tie_word_embeddings: tie_word_embeddings + Returns: + state_dict (dict): + The merged state_dict in rank 0, and an empty dictionary in other ranks. + """ + start_time = time.time() + + def _get_gpt_model(model): + return model + + dp_rank = mpu.get_data_parallel_rank() + pp_size = mpu.get_pipeline_model_parallel_world_size() + pp_rank = mpu.get_pipeline_model_parallel_rank() + cp_rank = mpu.get_context_parallel_rank() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + mp_group = mpu.get_model_parallel_group() + + if dist.get_rank() == 0: + assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0" + assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" + assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" + + if not isinstance(wrapped_models, (list, tuple)): + wrapped_models = list(wrapped_models) + + assert len(wrapped_models) == virtual_pp_size + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers + + models = [None] * len(wrapped_models) + + for i, wrapped_model in enumerate(wrapped_models): + models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) + assert len(models[i].decoder.layers) == num_layers_per_model, "len model layers {} not equal to num_layers_per_model {}".format(len(models[i].decoder.layers), num_layers_per_model) + + state_dict = dict() + + def _get_cpu_tensor(tensor: torch.Tensor): + if tensor is None: + return None + if tensor.device == torch.device("cpu"): + return tensor.detach().clone() + return tensor.detach().cpu() + + def _broadcast_tensor(tensor, name, src_pp_rank) -> torch.Tensor: + """broadcast tensor across mp_group""" + nonlocal state_dict + nonlocal mp_group + src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank) + + if torch.distributed.get_rank() == src_rank: + if tensor is None: + weight = None + tensor_shape = None + else: + weight = tensor + tensor_shape = weight.shape + else: + weight = None + tensor_shape = None + + obj_list = [tensor_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + tensor_shape = obj_list[0] + + if tensor_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tensor:[{name}] not exist, skip collect") + return + + if weight is None: + weight = torch.empty( + tensor_shape, + dtype=dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + + dist.broadcast(weight, src=src_rank, group=mp_group) + + if torch.distributed.get_rank() == 0: + state_dict[name] = _get_cpu_tensor(weight) + + def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_func=None) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + # tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank) + + chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{name}] not exist, skip collecting") + return + + buffer_tensor = torch.empty( + chunk_shape, + dtype=dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + + chunk_tensors = [None] * tp_size + + for i in range(tp_size): + cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank) + sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor + dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) + + if torch.distributed.get_rank() == 0: + chunk_tensors[i] = _get_cpu_tensor(sync_tensor) + + if torch.distributed.get_rank() == 0: + full_tensor = torch.concat(chunk_tensors, dim=concat_dim) + if mutate_func is not None: + full_tensor = mutate_func(full_tensor) + state_dict[name] = full_tensor + + def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + # tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank) + + chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not exist, skip collecting") + return + + buffer_tensor = torch.empty( + chunk_shape, + dtype=dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + + chunk_tensors = [None] * tp_size + + for i in range(tp_size): + cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank) + sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor + dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) + + if torch.distributed.get_rank() == 0: + chunk_tensors[i] = _get_cpu_tensor(sync_tensor) + + if torch.distributed.get_rank() == 0: + full_tensor = torch.concat(chunk_tensors, dim=0) + intermediate_size_tp = config.intermediate_size // tp_size + gate_weight_list = [] + up_weight_list = [] + for i in range(tp_size): + gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)] + gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp] + up_weight_tp = gate_up_weight_tp[intermediate_size_tp:] + gate_weight_list.append(gate_weight_tp) + up_weight_list.append(up_weight_tp) + + state_dict[gate_name] = torch.cat(gate_weight_list, dim=0) + state_dict[up_name] = torch.cat(up_weight_list, dim=0) + + def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank): + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + # tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank) + + chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{q_name}] not exist, skip collecting") + return + + buffer_tensor = torch.empty( + chunk_shape, + dtype=dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + + chunk_tensors = [None] * tp_size + + for i in range(tp_size): + cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank) + sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor + dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) + + if torch.distributed.get_rank() == 0: + chunk_tensors[i] = _get_cpu_tensor(sync_tensor) + + if torch.distributed.get_rank() == 0: + full_tensor = torch.concat(chunk_tensors, dim=0) + q_weight_list = [] + k_weight_list = [] + v_weight_list = [] + hidden_size_per_head = config.hidden_size // config.num_attention_heads + + if config.num_key_value_heads >= tp_size: + q_size_tp = config.hidden_size // tp_size + kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size + total_size = q_size_tp + 2 * kv_size_tp + for i in range(tp_size): + num_query_groups_per_partition = wrapped_models[0].config.num_query_groups // tp_size + qkv_part = full_tensor[i * total_size : (i + 1) * total_size] + q_size_chunk = q_size_tp // num_query_groups_per_partition + kv_size_chunk = kv_size_tp // num_query_groups_per_partition + for qkv_part_chunk in qkv_part.chunk(num_query_groups_per_partition): + q_part = qkv_part_chunk[:q_size_chunk] + k_part = qkv_part_chunk[q_size_chunk : q_size_chunk + kv_size_chunk] + v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk :] + q_weight_list.append(q_part) + k_weight_list.append(k_part) + v_weight_list.append(v_part) + else: + q_size_tp = config.hidden_size // tp_size + kv_size_tp = hidden_size_per_head + total_size = q_size_tp + 2 * kv_size_tp + for i in range(tp_size): + num_query_groups_per_partition = wrapped_models[0].config.num_query_groups // tp_size + qkv_part = full_tensor[i * total_size : (i + 1) * total_size] + q_size_chunk = q_size_tp // num_query_groups_per_partition + kv_size_chunk = kv_size_tp // num_query_groups_per_partition + for qkv_part_chunk in qkv_part.chunk(num_query_groups_per_partition): + q_part = qkv_part_chunk[:q_size_chunk] + k_part = qkv_part_chunk[q_size_chunk : q_size_chunk + kv_size_chunk] + v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk :] + q_weight_list.append(q_part) + if i * config.num_key_value_heads % tp_size == 0: + k_weight_list.append(k_part) + v_weight_list.append(v_part) + + state_dict[q_name] = torch.cat(q_weight_list, dim=0) + state_dict[k_name] = torch.cat(k_weight_list, dim=0) + state_dict[v_name] = torch.cat(v_weight_list, dim=0) + + # empty cache before collecting weights + torch.cuda.empty_cache() + # Embeddings + # ------------------- + if dp_rank == 0 and cp_rank == 0: # models are identical across cp ranks + # Embeddings + # ------------------- + print_rank_0("collecting embeddings...") + gpt_model_module = _get_gpt_model(models[0]) + _broadcast_tp_shard_tensor( + gpt_model_module.embedding.word_embeddings.weight if pp_rank == 0 else None, + "model.embed_tokens.weight", + src_pp_rank=0, + ) + + # Transformer layers + # ------------------- + layer_map = _megatron_calc_layer_map(config) + for layer in range(config.num_hidden_layers): + print_rank_0(f"collecting layer #{layer}...") + layer_name = f"model.layers.{layer}" + src_pp_rank, src_virtual_pp_rank, src_layer_idx = layer_map[layer] + + gpt_model_module = _get_gpt_model(models[src_virtual_pp_rank]) + sync_layer = gpt_model_module.decoder.layers[src_layer_idx] + + _broadcast_tensor( + sync_layer.self_attention.linear_qkv.layer_norm_weight, + f"{layer_name}.input_layernorm.weight", + src_pp_rank=src_pp_rank, + ) + + _broadcast_tp_shard_tensor_qkv( + sync_layer.self_attention.linear_qkv.weight, + f"{layer_name}.self_attn.q_proj.weight", + f"{layer_name}.self_attn.k_proj.weight", + f"{layer_name}.self_attn.v_proj.weight", + src_pp_rank=src_pp_rank, + ) + + if getattr(sync_layer.self_attention.linear_qkv, "bias", None) is not None and sync_layer.self_attention.linear_qkv.bias.numel() > 0: + _broadcast_tp_shard_tensor_qkv( + sync_layer.self_attention.linear_qkv.bias, + f"{layer_name}.self_attn.q_proj.bias", + f"{layer_name}.self_attn.k_proj.bias", + f"{layer_name}.self_attn.v_proj.bias", + src_pp_rank=src_pp_rank, + ) + + _broadcast_tp_shard_tensor( + sync_layer.self_attention.linear_proj.weight, + f"{layer_name}.self_attn.o_proj.weight", + concat_dim=1, + src_pp_rank=src_pp_rank, + ) + + _broadcast_tensor( + sync_layer.mlp.linear_fc1.layer_norm_weight, + f"{layer_name}.post_attention_layernorm.weight", + src_pp_rank=src_pp_rank, + ) + + _broadcast_tp_shard_tensor_gate_up( + sync_layer.mlp.linear_fc1.weight, + f"{layer_name}.mlp.gate_proj.weight", + f"{layer_name}.mlp.up_proj.weight", + src_pp_rank=src_pp_rank, + ) + + _broadcast_tp_shard_tensor( + sync_layer.mlp.linear_fc2.weight, + f"{layer_name}.mlp.down_proj.weight", + concat_dim=1, + src_pp_rank=src_pp_rank, + ) + + # Final Layernorm + # ------------------- + print_rank_0("collecting final layernorm...") + gpt_model_module = _get_gpt_model(models[-1]) + _broadcast_tensor( + getattr(gpt_model_module.decoder.final_layernorm, "weight", None), + "model.norm.weight", + src_pp_rank=pp_size - 1, + ) + + if tie_word_embeddings: + print_rank_0("tie word embedding skip load lm_head...") + else: + print_rank_0("collecting lm_head...") + + if is_value_model: + lm_head_weight = None + if pp_rank == pp_size - 1: + lm_head_weight = getattr(gpt_model_module.output_layer, "weight", None) + _broadcast_tensor(lm_head_weight, "lm_head.weight", src_pp_rank=pp_size - 1) + + else: + _broadcast_tp_shard_tensor( + getattr(gpt_model_module.output_layer, "weight", None) if pp_rank == pp_size - 1 else None, + "lm_head.weight", + src_pp_rank=pp_size - 1, + ) + + dist.barrier() + torch.cuda.empty_cache() + if torch.distributed.get_rank() == 0: + for k, v in state_dict.items(): + if dtype != v.dtype: + state_dict[k] = v.to(dtype) + + print_rank_0(f"merge megatron ckpt done, time elapsed {time.time() - start_time}s") + return state_dict + + +def merge_megatron_ckpt_gptmodel_qwen_moe(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False): + raise NotImplementedError("merge_megatron_ckpt_gptmodel_qwen_moe is not implemented") + + +def merge_megatron_ckpt_gptmodel_mixtral(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False): + raise NotImplementedError("merge_megatron_ckpt_gptmodel_mixtral is not implemented") diff --git a/verl/models/mcore/util.py b/verl/models/mcore/util.py new file mode 100644 index 0000000000000000000000000000000000000000..38ea60575b2b70c0fa2aaec4ed59e5d5b24e798e --- /dev/null +++ b/verl/models/mcore/util.py @@ -0,0 +1,190 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from megatron.core import parallel_state as mpu +from megatron.core.packed_seq_params import PackedSeqParams + + +def preprocess_packed_seqs(input_ids: torch.Tensor, attention_mask: torch.Tensor, pre_process: bool = True) -> tuple[torch.Tensor, PackedSeqParams]: + """ + Preprocess packed sequences + CP splits sequence into CP*2 chunks, and each GPU gets 2 chunks (GPU0 gets first and last chunks, GPU1 gets second and second last chunks, and so on), this is for load balancing with causal masking. + See https://github.com/NVIDIA/TransformerEngine/issues/1368 + """ + batch_size = input_ids.shape[0] + + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + tp_size = mpu.get_tensor_model_parallel_world_size() + cp_size = mpu.get_context_parallel_world_size() + cp_rank = mpu.get_context_parallel_rank() + align_size = tp_size * cp_size * 2 if cp_size > 1 else tp_size + + pad_size = (align_size - seqlens_in_batch % align_size) % align_size + seqlens_in_batch_padded = seqlens_in_batch + pad_size + cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device) + cu_seqlens[1:] = torch.cumsum(seqlens_in_batch, dim=0) + cu_seqlens_padded = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device) + cu_seqlens_padded[1:] = torch.cumsum(seqlens_in_batch_padded, dim=0) + max_seqlen_in_batch = seqlens_in_batch_padded.max().item() + + shape = list(input_ids.shape[1:]) + shape[0] = seqlens_in_batch_padded.sum().item() // cp_size + if pre_process: + input_ids_rmpad = torch.zeros(shape, dtype=input_ids.dtype, device=input_ids.device) + for i in range(batch_size): + if cp_size <= 1: + seqlen = seqlens_in_batch[i] + input_ids_rmpad[cu_seqlens_padded[i] : cu_seqlens_padded[i] + seqlen] = input_ids[i, attention_mask[i]] + continue + seqlen = seqlens_in_batch_padded[i] // cp_size + half_seqlen = seqlen // 2 + start_idx = cu_seqlens_padded[i] // cp_size + # split to 2 chunks + d = input_ids[i, attention_mask[i]] + input_ids_rmpad[start_idx : start_idx + half_seqlen] = d[half_seqlen * cp_rank : half_seqlen * (cp_rank + 1)] + + remain_start = seqlens_in_batch_padded[i] - half_seqlen * (cp_rank + 1) + remain_end = seqlens_in_batch_padded[i] - half_seqlen * cp_rank + remain_end = min(remain_end, d.shape[0]) + remain_len = remain_end - remain_start + if remain_len > 0: + input_ids_rmpad[start_idx + half_seqlen : start_idx + half_seqlen + remain_len] = d[remain_start:remain_end] + + packed_seq_params = PackedSeqParams( + qkv_format="thd", + cu_seqlens_q=cu_seqlens_padded, + max_seqlen_q=max_seqlen_in_batch, + cu_seqlens_kv=cu_seqlens_padded, + max_seqlen_kv=max_seqlen_in_batch, + cu_seqlens_q_padded=cu_seqlens_padded, + cu_seqlens_kv_padded=cu_seqlens_padded, + ) + if pre_process: + return input_ids_rmpad.unsqueeze(0), packed_seq_params + else: + return input_ids, packed_seq_params + + +def postprocess_packed_seqs( + output: torch.Tensor, + packed_seq_params: PackedSeqParams, + attention_mask: torch.Tensor, + batch_size: int, + seq_len: int, + post_process: bool = True, +) -> torch.Tensor: + """ + Postprocess packed sequences + """ + if not post_process: + return output + shape = [batch_size, seq_len] + list(output.shape[2:]) # 1,packed, dim -> batch_size, seq_len, dim + output_new = torch.zeros(shape, dtype=output.dtype, device=output.device) + + cp_size = mpu.get_context_parallel_world_size() + # all gather output across context parallel group + if cp_size > 1: + # output shape: [1, packed_len, hidden_dim] + # need to gather across cp group and concatenate in sequence dimension + output_list = [torch.empty_like(output) for _ in range(cp_size)] + torch.distributed.all_gather(output_list, output.detach(), group=mpu.get_context_parallel_group()) + output_list[mpu.get_context_parallel_rank()] = output + else: + output_list = [output] + for i in range(batch_size): + if cp_size <= 1: + s = attention_mask[i].sum().item() + output_new[i, attention_mask[i]] = output[0][packed_seq_params.cu_seqlens_q_padded[i] : packed_seq_params.cu_seqlens_q_padded[i] + s] + continue + s_len_padded_chunk = (packed_seq_params.cu_seqlens_q_padded[i + 1] - packed_seq_params.cu_seqlens_q_padded[i]) // cp_size + half_seqlen = s_len_padded_chunk // 2 + s_len = attention_mask[i].sum().item() + s_len_padded = s_len_padded_chunk * cp_size + tmp = torch.empty(s_len_padded, *output.shape[2:], device=output.device) + for j in range(cp_size): + o = output_list[j][0] + # split to 2 chunks + packed_start_idx = packed_seq_params.cu_seqlens_q_padded[i] // cp_size + o0, o1 = ( + o[packed_start_idx : packed_start_idx + half_seqlen], + o[packed_start_idx + half_seqlen : packed_start_idx + s_len_padded_chunk], + ) + tmp[j * half_seqlen : (j + 1) * half_seqlen] = o0 + tmp[s_len_padded - (j + 1) * half_seqlen : s_len_padded - j * half_seqlen] = o1 + output_new[i, attention_mask[i]] = tmp[:s_len] + + return output_new + + +def remove_left_padding( + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + position_ids: torch.Tensor, + sequence_parallel: bool = False, + pre_process: bool = True, +): + """ + Remove left padding from input_ids, attention_mask and position_ids + return new_input_ids, new_attention_mask, new_position_ids + """ + assert attention_mask.ndim == 2 + assert position_ids.ndim == 2 + cp_size = mpu.get_context_parallel_world_size() + assert cp_size == 1, "Context parallel size without seq_pack is not supported" + batch_size = input_ids.shape[0] + shape = list(input_ids.shape) # batch_size, seq_len,... + seq_lens = attention_mask.sum(dim=1) + seq_len = seq_lens.max().item() + if sequence_parallel: + sp_world_size = mpu.get_tensor_model_parallel_world_size() + pad_size = (sp_world_size - seq_len % sp_world_size) % sp_world_size + seq_len = seq_len + pad_size + shape[1] = seq_len + if pre_process: + new_input_ids = torch.zeros(dtype=input_ids.dtype, device=input_ids.device, size=shape) + new_attention_mask = torch.zeros(dtype=attention_mask.dtype, device=attention_mask.device, size=(batch_size, seq_len)) + new_position_ids = torch.zeros(dtype=position_ids.dtype, device=position_ids.device, size=(batch_size, seq_len)) + for i in range(batch_size): + if pre_process: + new_input_ids[i, : seq_lens[i]] = input_ids[i, attention_mask[i]] + new_attention_mask[i, : seq_lens[i]] = attention_mask[i, attention_mask[i]] + new_position_ids[i, : seq_lens[i]] = position_ids[i, attention_mask[i]] + if pre_process: + return new_input_ids, new_attention_mask, new_position_ids + else: + return input_ids, new_attention_mask, new_position_ids + + +def recover_left_padding( + result, + attention_mask: torch.Tensor, + original_attention_mask: torch.Tensor, + origin_seqlen: int, + post_process: bool = True, +): + """ + Recover left padding from result + return result + """ + if not post_process: + return result + shape = list(result.shape) + batch_size = shape[0] + shape[1] = origin_seqlen + new_result = torch.zeros(dtype=result.dtype, device=result.device, size=shape) + for i in range(batch_size): + new_result[i, original_attention_mask[i]] = result[i, attention_mask[i]] + return new_result diff --git a/verl/models/mcore/weight_converter.py b/verl/models/mcore/weight_converter.py new file mode 100644 index 0000000000000000000000000000000000000000..8e93ebe8f6225ff2045871c449739a1844c9b06b --- /dev/null +++ b/verl/models/mcore/weight_converter.py @@ -0,0 +1,207 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# online convert mcore weight to pure huggingface weight, no any fusion +# including format conversion and name mapping +# not including resharding +import torch +from megatron.core.transformer import TransformerConfig +from transformers import PretrainedConfig + + +class McoreToHFWeightConverterBase: + def __init__(self, hf_config: PretrainedConfig, mcore_config: TransformerConfig): + self.hf_config = hf_config + self.mcore_config = mcore_config + + def convert_param(self, name: str, params_one_group: list[torch.Tensor]) -> torch.Tensor: + raise NotImplementedError + + +class McoreToHFWeightConverterDense(McoreToHFWeightConverterBase): + def _convert_attention_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + # 'decoder.layers.0.self_attention.linear_proj.weight' + # 'decoder.layers.0.self_attention.linear_qkv.layer_norm_weight' + # 'decoder.layers.0.self_attention.linear_qkv.weight' + # 'decoder.layers.0.self_attention.linear_qkv.bias' + layer_number = name.split(".")[2] + convert_names = [] + if "self_attention.linear_qkv.bias" in name or "self_attention.linear_qkv.weight" in name: + param_type = name.split(".")[-1] + assert param_type == "bias" or param_type == "weight" + convert_names.append(f"model.layers.{layer_number}.self_attn.q_proj.{param_type}") + convert_names.append(f"model.layers.{layer_number}.self_attn.k_proj.{param_type}") + convert_names.append(f"model.layers.{layer_number}.self_attn.v_proj.{param_type}") + assert len(params) == 3 + elif "self_attention.linear_proj.weight" in name: + convert_names.append(f"model.layers.{layer_number}.self_attn.o_proj.weight") + assert len(params) == 1 + elif "self_attention.linear_qkv.layer_norm_weight" in name: + convert_names.append(f"model.layers.{layer_number}.input_layernorm.weight") + assert len(params) == 1 + elif "self_attention.q_layernorm.weight" in name: + convert_names.append(f"model.layers.{layer_number}.self_attn.q_norm.weight") + assert len(params) == 1 + elif "self_attention.k_layernorm.weight" in name: + convert_names.append(f"model.layers.{layer_number}.self_attn.k_norm.weight") + assert len(params) == 1 + else: + raise NotImplementedError(f"Unsupported parameter name: {name}") + return convert_names, params + + def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + # 'decoder.layers.0.mlp.linear_fc1.layer_norm_weight' + # 'decoder.layers.0.mlp.linear_fc1.weight' + # 'decoder.layers.0.mlp.linear_fc2.weight' + layer_number = name.split(".")[2] + convert_names = [] + if "mlp.linear_fc1.weight" in name: + # split gate_proj and up_proj + convert_names.append(f"model.layers.{layer_number}.mlp.gate_proj.weight") + convert_names.append(f"model.layers.{layer_number}.mlp.up_proj.weight") + assert len(params) == 2 + elif "mlp.linear_fc1.layer_norm_weight" in name: + convert_names.append(f"model.layers.{layer_number}.post_attention_layernorm.weight") + assert len(params) == 1 + elif "mlp.linear_fc2.weight" in name: + convert_names.append(f"model.layers.{layer_number}.mlp.down_proj.weight") + assert len(params) == 1 + else: + raise NotImplementedError(f"Unsupported parameter name: {name}") + return convert_names, params + + def convert_param(self, name: str, params_one_group: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + direct_name_mapping = { + "embedding.word_embeddings.weight": "model.embed_tokens.weight", + "decoder.final_layernorm.weight": "model.norm.weight", + "output_layer.weight": "lm_head.weight", + } + if name in direct_name_mapping: + return [direct_name_mapping[name]], [params_one_group[0]] + + if "self_attention" in name: + return self._convert_attention_param(name, params_one_group) + elif "mlp" in name: + return self._convert_mlp_param(name, params_one_group) + else: + raise NotImplementedError(f"Unsupported parameter name: {name}") + + +class McoreToHFWeightConverterQwen2Moe(McoreToHFWeightConverterDense): + def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + # 'decoder.layers.0.pre_mlp_layernorm.weight', + # 'decoder.layers.0.mlp.router.weight', + # 'decoder.layers.0.mlp.shared_experts.gate_weight', + # 'decoder.layers.0.mlp.shared_experts.linear_fc1.weight', + # 'decoder.layers.0.mlp.shared_experts.linear_fc2.weight' + # moe1 + # 'decoder.layers.0.mlp.experts.linear_fc1.weight0', + # 'decoder.layers.0.mlp.experts.linear_fc1.weight1', + # 'decoder.layers.0.mlp.experts.linear_fc1.weight2', + # 'decoder.layers.0.mlp.experts.linear_fc1.weight3', + # moe2 + # 'decoder.layers.0.mlp.experts.linear_fc2.weight0', + # 'decoder.layers.0.mlp.experts.linear_fc2.weight1', + layer_number = name.split(".")[2] + convert_names = [] + if "pre_mlp_layernorm" in name: + convert_names.append(f"model.layers.{layer_number}.post_attention_layernorm.weight") + assert len(params) == 1 + elif "mlp.router.weight" in name: + convert_names.append(f"model.layers.{layer_number}.mlp.gate.weight") + assert len(params) == 1 + elif "shared_experts.gate_weight" in name: + convert_names.append(f"model.layers.{layer_number}.mlp.shared_expert_gate.weight") + assert len(params) == 1 + elif "shared_experts.linear_fc1.weight" in name: # split gate_proj and up_proj + convert_names.append(f"model.layers.{layer_number}.mlp.shared_expert.gate_proj.weight") + convert_names.append(f"model.layers.{layer_number}.mlp.shared_expert.up_proj.weight") + assert len(params) == 2 + elif "shared_experts.linear_fc2.weight" in name: + convert_names.append(f"model.layers.{layer_number}.mlp.shared_expert.down_proj.weight") + assert len(params) == 1 + elif "mlp.experts.linear_fc1" in name: # split gate_proj and up_proj + expert_id = name.split("weight")[-1] + convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.gate_proj.weight") + convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.up_proj.weight") + assert len(params) == 2 + elif "mlp.experts.linear_fc2" in name: + expert_id = name.split("weight")[-1] + convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.down_proj.weight") + assert len(params) == 1 + else: + raise NotImplementedError(f"Unsupported parameter name: {name}") + return convert_names, params + + +class McoreToHFWeightConverterMixtral(McoreToHFWeightConverterDense): + def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + # decoder.layers.0.mlp.router.weight + # decoder.layers.0.mlp.experts.linear_fc1.weight0 - weight7 + # decoder.layers.0.mlp.experts.linear_fc2.weight0 - weight7 + + layer_number = name.split(".")[2] + convert_names = [] + if "pre_mlp_layernorm" in name: + convert_names.append(f"model.layers.{layer_number}.post_attention_layernorm.weight") + elif "mlp.router.weight" in name: + convert_names.append(f"model.layers.{layer_number}.block_sparse_moe.gate.weight") + elif "mlp.experts.linear_fc1.weight" in name: + expert_id = name.split("weight")[-1] + convert_names.append(f"model.layers.{layer_number}.block_sparse_moe.experts.{expert_id}.w1.weight") + convert_names.append(f"model.layers.{layer_number}.block_sparse_moe.experts.{expert_id}.w3.weight") + elif "mlp.experts.linear_fc2.weight" in name: + expert_id = name.split("weight")[-1] + convert_names.append(f"model.layers.{layer_number}.block_sparse_moe.experts.{expert_id}.w2.weight") + else: + raise NotImplementedError(f"Unsupported parameter name: {name}") + return convert_names, params + + +class McoreToHFWeightConverterQwen3Moe(McoreToHFWeightConverterDense): + def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + # qwen3 moe no share expert + + # 'decoder.layers.0.pre_mlp_layernorm.weight', + # 'decoder.layers.0.mlp.router.weight', + # moe1 + # 'decoder.layers.0.mlp.experts.linear_fc1.weight0', + # 'decoder.layers.0.mlp.experts.linear_fc1.weight1', + # 'decoder.layers.0.mlp.experts.linear_fc1.weight2', + # 'decoder.layers.0.mlp.experts.linear_fc1.weight3', + # moe2 + # 'decoder.layers.0.mlp.experts.linear_fc2.weight0', + # 'decoder.layers.0.mlp.experts.linear_fc2.weight1', + layer_number = name.split(".")[2] + convert_names = [] + if "pre_mlp_layernorm" in name: + convert_names.append(f"model.layers.{layer_number}.post_attention_layernorm.weight") + assert len(params) == 1 + elif "mlp.router.weight" in name: + convert_names.append(f"model.layers.{layer_number}.mlp.gate.weight") + assert len(params) == 1 + elif "mlp.experts.linear_fc1" in name: # split gate_proj and up_proj + expert_id = name.split("weight")[-1] + convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.gate_proj.weight") + convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.up_proj.weight") + assert len(params) == 2 + elif "mlp.experts.linear_fc2" in name: + expert_id = name.split("weight")[-1] + convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.down_proj.weight") + assert len(params) == 1 + else: + raise NotImplementedError(f"Unsupported parameter name: {name}") + return convert_names, params diff --git a/verl/models/qwen2/__init__.py b/verl/models/qwen2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c6aedde1437eddbe099b730519ef79980e2e3dc4 --- /dev/null +++ b/verl/models/qwen2/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/verl/models/qwen2/megatron/__init__.py b/verl/models/qwen2/megatron/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a03fde797ed7ab9f916d73267095049e5b0f7586 --- /dev/null +++ b/verl/models/qwen2/megatron/__init__.py @@ -0,0 +1,34 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .modeling_qwen2_megatron import ( + ParallelQwen2ForCausalLM, + # rmpad with megatron + ParallelQwen2ForCausalLMRmPad, + # rmpad with megatron and pipeline parallelism + ParallelQwen2ForCausalLMRmPadPP, + ParallelQwen2ForValueRmPad, + ParallelQwen2ForValueRmPadPP, + # original model with megatron + ParallelQwen2Model, +) + +__all__ = [ + "ParallelQwen2ForCausalLM", + "ParallelQwen2ForCausalLMRmPad", + "ParallelQwen2ForCausalLMRmPadPP", + "ParallelQwen2ForValueRmPad", + "ParallelQwen2ForValueRmPadPP", + "ParallelQwen2Model", +] diff --git a/verl/models/qwen2/megatron/checkpoint_utils/__init__.py b/verl/models/qwen2/megatron/checkpoint_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c6aedde1437eddbe099b730519ef79980e2e3dc4 --- /dev/null +++ b/verl/models/qwen2/megatron/checkpoint_utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader.py b/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..479cdc07529df193eb06875881cf4b63deeb6e8f --- /dev/null +++ b/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader.py @@ -0,0 +1,312 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import torch +import torch.distributed as dist + + +def _megatron_calc_layer_map(config): + """Calculate the mapping of global layer_idx to local layer_idx + Returns: + layer_map (Dict: int -> tuple(int, int, int)): + mapping from the global layer index to + a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model) + """ + from megatron.core import mpu + + pp_size = mpu.get_pipeline_model_parallel_world_size() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + + layer_map = dict() + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers + + for pp_rank_idx in range(pp_size): + for virtual_pp_rank_idx in range(virtual_pp_size): + layer_offset = virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + for layer_idx in range(num_layers_per_model): + layer_map[layer_offset + layer_idx] = ( + pp_rank_idx, + virtual_pp_rank_idx, + layer_idx, + ) + return layer_map + + +def load_state_dict_to_megatron_qwen2(state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False): + """Load merged state_dict to sharded Megatron module in training.""" + from megatron.core import DistributedDataParallel as LocalDDP + from megatron.core import mpu + from megatron.core.transformer.module import Float16Module + from torch.nn.parallel import DistributedDataParallel as torchDDP + + from verl.utils.megatron_utils import print_rank_0, unwrap_model + + start_time = time.time() + + def _get_gpt_model(model): + return model + + def fetch_params(module): + for param in module.parameters(): + torch.distributed.fetch(param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group()) + + dp_rank = mpu.get_data_parallel_rank() + pp_rank = mpu.get_pipeline_model_parallel_rank() + pp_size = mpu.get_pipeline_model_parallel_world_size() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + mp_group = mpu.get_model_parallel_group() + + if torch.distributed.get_rank() == 0: + assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0" + assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" + assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" + + if not isinstance(wrapped_models, (list, tuple)): + wrapped_models = list(wrapped_models) + + assert len(wrapped_models) == virtual_pp_size + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, f"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size: {virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}" + + models = [None] * len(wrapped_models) + + for i, wrapped_model in enumerate(wrapped_models): + models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) + gpt_model_module = _get_gpt_model(models[i]) + assert len(gpt_model_module.model.layers) == num_layers_per_model + + def _fetch_tensor(tensor, name) -> torch.Tensor: + """fetch tensor""" + nonlocal state_dict + if tensor is not None: + tensor = tensor.data.copy_(state_dict[name], non_blocking=True) + + def _fetch_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: + """fetch tensor in tp shards""" + nonlocal state_dict + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + if name in state_dict: + full_weight = state_dict[name] + + if mutate_func is not None: + full_weight = mutate_func(full_weight) + tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) + if tensor is not None: + tensor = tensor.data.copy_(tensor_chunk[tp_rank], non_blocking=True) + else: + print(f"tp_shard tensor:[{name}] not in state_dict, skip loading") + + def _fetch_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: + """fetch tensor in tp shards""" + nonlocal state_dict + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + if name in state_dict: + full_weight = state_dict[name] + + if mutate_func is not None: + full_weight = mutate_func(full_weight) + tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) + if tensor is not None: + tensor = tensor.data.copy_(tensor_chunk[tp_rank], non_blocking=True) + else: + print(f"tp_shard tensor:[{name}] not in state_dict, skip loading") + + def _fetch_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor: + """fetch gate_up tensor in tp shards""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + if gate_name in state_dict and up_name in state_dict: + gate_weight = state_dict[gate_name] + up_weight = state_dict[up_name] + new_gate_up_weight = torch.empty(config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device()) + for i in range(tp_size): + intermediate_size_tp = config.intermediate_size // tp_size + gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] + up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] + new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_(torch.cat([gate_weight_tp, up_weight_tp], dim=0)) + + tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0) + if tensor is not None: + tensor = tensor.data.copy_(tensor_chunk[tp_rank], non_blocking=True) + else: + print(f"tp_shard tensor:[{gate_name}, {up_name}] not in state_dict, skip loading") + + def _fetch_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> torch.Tensor: + """fetch tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + assert q_name in state_dict and k_name in state_dict and v_name in state_dict + full_weight_q = state_dict[q_name] + full_weight_k = state_dict[k_name] + full_weight_v = state_dict[v_name] + + hidden_size_per_head = config.hidden_size // config.num_attention_heads + + if config.num_key_value_heads >= tp_size: + q_size_tp = config.hidden_size // tp_size + kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size + total_size = q_size_tp + 2 * kv_size_tp + if not bias: + new_weight_qkv = torch.empty(total_size * tp_size, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device()) + else: + new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=torch.cuda.current_device()) + for i in range(tp_size): + q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] + k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp] + v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp] + new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0)) + + else: + q_size_tp = config.hidden_size // tp_size + kv_size_tp = hidden_size_per_head + total_size = q_size_tp + 2 * kv_size_tp + if not bias: + new_weight_qkv = torch.empty(total_size * tp_size, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device()) + else: + new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=torch.cuda.current_device()) + for i in range(tp_size): + q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] + start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head + end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head + k_part = full_weight_k[start_idx:end_idx] + v_part = full_weight_v[start_idx:end_idx] + new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0)) + + tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0) + if tensor is not None: + tensor = tensor.data.copy_(tensor_chunk[tp_rank], non_blocking=True) + + # Embeddings + # ------------------- + print_rank_0("loading embeddings...") + gpt_model_module = _get_gpt_model(models[0]) + if pp_rank == 0: + embed_tokens_weight = gpt_model_module.model.embed_tokens.weight + _fetch_tp_shard_tensor_vocab(embed_tokens_weight, "model.embed_tokens.weight") + + # Transformer layers + # ------------------- + layer_map = _megatron_calc_layer_map(config) + + pp_rank = mpu.get_pipeline_model_parallel_rank() + pp_size = mpu.get_pipeline_model_parallel_world_size() + num_layer_per_pp = config.num_hidden_layers // pp_size + vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size() + + layer_list = [] + if vpp_size is not None: + for vpp_rank in range(vpp_size): + num_layer_vpp_chunk = num_layer_per_pp // vpp_size + num_layer_this_model = num_layer_vpp_chunk + offset = vpp_rank * (config.num_hidden_layers // mpu.get_virtual_pipeline_model_parallel_world_size()) + (mpu.get_pipeline_model_parallel_rank() * num_layer_vpp_chunk) + layer_list.extend(list(range(offset, offset + num_layer_this_model))) + else: + num_layer_this_model = num_layer_per_pp + offset = pp_rank * num_layer_per_pp + layer_list.extend(list(range(offset, offset + num_layer_this_model))) + + for layer in layer_list: + print(f"{torch.distributed.get_rank()} loading layer #{layer}...") + layer_name = f"model.layers.{layer}" + dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer] + + print(f"{torch.distributed.get_rank()} offset: {offset}, num_layer_this_model: {num_layer_this_model}, layer_name: {layer_name}, layer_map[layer]: {layer_map[layer]}") + + gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank]) + sync_layer = gpt_model_module.model.layers[dst_layer_idx] + + _fetch_tensor( + sync_layer.input_layernorm.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.input_layernorm.weight", + ) + + _fetch_tp_shard_tensor_qkv( + sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.q_proj.weight", + f"{layer_name}.self_attn.k_proj.weight", + f"{layer_name}.self_attn.v_proj.weight", + ) + + _fetch_tp_shard_tensor_qkv( + sync_layer.self_attn.qkv_proj.bias if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.q_proj.bias", + f"{layer_name}.self_attn.k_proj.bias", + f"{layer_name}.self_attn.v_proj.bias", + bias=True, + ) + + _fetch_tp_shard_tensor( + sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.o_proj.weight", + chunk_dim=1, + ) + + _fetch_tensor( + sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.post_attention_layernorm.weight", + ) + + _fetch_tp_shard_tensor_gate_up( + sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.mlp.gate_proj.weight", + f"{layer_name}.mlp.up_proj.weight", + ) + + _fetch_tp_shard_tensor( + sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.mlp.down_proj.weight", + chunk_dim=1, + ) + # Final Layernorm + # ------------------- + print_rank_0("loading final layernorm...") + gpt_model_module = _get_gpt_model(models[-1]) + _fetch_tensor( + getattr(gpt_model_module.model.norm, "weight", None), + "model.norm.weight", + ) + + if tie_word_embeddings: + print_rank_0("tie_word_embeddings skip load lm_head") + else: + print_rank_0("loading lm_head...") + if pp_rank + 1 == pp_size: + lm_head_weight = gpt_model_module.lm_head.weight + + if is_value_model: + if "lm_head.weight" in state_dict and state_dict["lm_head.weight"].shape[0] == 1: + _fetch_tensor(lm_head_weight, "lm_head.weight") + print_rank_0("load lm_head from value_head weight") + elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1: + _fetch_tensor(lm_head_weight, "reward_head.weight") + print_rank_0("load lm_head from value_head weight") + else: + _fetch_tensor(None, "lm_head.weight") + print_rank_0("fail to match lm_head in value_model") + + else: + _fetch_tp_shard_tensor(lm_head_weight, "lm_head.weight") + + dist.barrier() + torch.cuda.empty_cache() + print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s") diff --git a/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader_depracated.py b/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader_depracated.py new file mode 100644 index 0000000000000000000000000000000000000000..d7cc076506b4680995e364aa2a7b5115daa02479 --- /dev/null +++ b/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader_depracated.py @@ -0,0 +1,442 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import torch +import torch.distributed as dist + + +def _megatron_calc_layer_map(config): + """Calculate the mapping of global layer_idx to local layer_idx + Returns: + layer_map (Dict: int -> tuple(int, int, int)): + mapping from the global layer index to + a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model) + """ + from megatron.core import mpu + + pp_size = mpu.get_pipeline_model_parallel_world_size() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + + layer_map = dict() + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers + + for pp_rank_idx in range(pp_size): + for virtual_pp_rank_idx in range(virtual_pp_size): + layer_offset = virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + for layer_idx in range(num_layers_per_model): + layer_map[layer_offset + layer_idx] = ( + pp_rank_idx, + virtual_pp_rank_idx, + layer_idx, + ) + return layer_map + + +def load_state_dict_to_megatron_qwen2(state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False): + """Load merged state_dict to sharded Megatron module in training.""" + from megatron.core import DistributedDataParallel as LocalDDP + from megatron.core import mpu + from megatron.core.transformer.module import Float16Module + from torch.nn.parallel import DistributedDataParallel as torchDDP + + from verl.utils.megatron_utils import print_rank_0, unwrap_model + + start_time = time.time() + + def _get_gpt_model(model): + return model + + def broadcast_params(module): + for param in module.parameters(): + torch.distributed.broadcast(param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group()) + + dp_rank = mpu.get_data_parallel_rank() + pp_rank = mpu.get_pipeline_model_parallel_rank() + pp_size = mpu.get_pipeline_model_parallel_world_size() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + mp_group = mpu.get_model_parallel_group() + + if torch.distributed.get_rank() == 0: + assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0" + assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" + assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" + + if not isinstance(wrapped_models, (list, tuple)): + wrapped_models = list(wrapped_models) + + assert len(wrapped_models) == virtual_pp_size + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, f"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size: {virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}" + + models = [None] * len(wrapped_models) + + for i, wrapped_model in enumerate(wrapped_models): + models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) + gpt_model_module = _get_gpt_model(models[i]) + assert len(gpt_model_module.model.layers) == num_layers_per_model + + def _broadcast_tensor(tensor, name) -> torch.Tensor: + """broadcast tensor from rank0 across mp_group""" + nonlocal state_dict + nonlocal mp_group + if torch.distributed.get_rank() == 0: + if name in state_dict: + weight = state_dict[name] + tensor_shape = weight.shape + else: + tensor_shape = None + else: + weight = None + tensor_shape = None + + obj_list = [tensor_shape] + dist.broadcast_object_list(obj_list, src=0, group=mp_group) + tensor_shape = obj_list[0] + + if tensor_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tensor:[{name}] not in state_dict, skip load") + return + + if tensor is None: + tensor = torch.empty( + tensor_shape, + dtype=params_dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + if torch.distributed.get_rank() == 0: + tensor.data.copy_(weight) + dist.broadcast(tensor, src=0, group=mp_group) + + def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + + if torch.distributed.get_rank() == 0: + if name in state_dict: + full_weight = state_dict[name] + + if mutate_func is not None: + full_weight = mutate_func(full_weight) + tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) + chunk_shape = tensor_chunk[0].shape + else: + chunk_shape = None + else: + chunk_shape = None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=0, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading") + return + + if tensor is None: + sync_tensor = torch.empty( + chunk_shape, + dtype=params_dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + else: + assert tensor.shape == chunk_shape, f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" + sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False) + + for i in range(tp_size): + if torch.distributed.get_rank() == 0: + sync_tensor.data.copy_(tensor_chunk[i]) + dist.broadcast(sync_tensor, src=0, group=mp_group) + if (i == tp_rank) and (tensor is not None): + tensor.data.copy_(sync_tensor) + + def _broadcast_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + + if torch.distributed.get_rank() == 0: + if name in state_dict: + full_weight = state_dict[name] + if mutate_func is not None: + full_weight = mutate_func(full_weight) + tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) + chunk_shape = tensor_chunk[0].shape + else: + chunk_shape = None + else: + chunk_shape = None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=0, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading") + return + + if tensor is None: + sync_tensor = torch.empty( + chunk_shape, + dtype=params_dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + else: + assert tensor.shape == chunk_shape, f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" + sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False) + + for i in range(tp_size): + if torch.distributed.get_rank() == 0: + sync_tensor.data.copy_(tensor_chunk[i]) + dist.broadcast(sync_tensor, src=0, group=mp_group) + if (i == tp_rank) and (tensor is not None): + tensor.data.copy_(sync_tensor) + + def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + + if torch.distributed.get_rank() == 0: + gate_weight = state_dict[gate_name] + up_weight = state_dict[up_name] + new_gate_up_weight = torch.empty(config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device()) + for i in range(tp_size): + intermediate_size_tp = config.intermediate_size // tp_size + gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] + up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] + new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_(torch.cat([gate_weight_tp, up_weight_tp], dim=0)) + + tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0) + chunk_shape = tensor_chunk[0].shape + else: + chunk_shape = None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=0, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not in state_dict, skip loading") + return + + if tensor is None: + sync_tensor = torch.empty( + chunk_shape, + dtype=params_dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + else: + assert tensor.shape == chunk_shape, f"rank #{torch.distributed.get_rank() == 0:} tensor {gate_name, up_name} shape {tensor.shape} != {chunk_shape}" + sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False) + + for i in range(tp_size): + if torch.distributed.get_rank() == 0: + sync_tensor.data.copy_(tensor_chunk[i]) + dist.broadcast(sync_tensor, src=0, group=mp_group) + if (i == tp_rank) and (tensor is not None): + tensor.data.copy_(sync_tensor) + + def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + + if torch.distributed.get_rank() == 0: + assert q_name in state_dict and k_name in state_dict and v_name in state_dict + full_weight_q = state_dict[q_name] + full_weight_k = state_dict[k_name] + full_weight_v = state_dict[v_name] + + hidden_size_per_head = config.hidden_size // config.num_attention_heads + + if config.num_key_value_heads >= tp_size: + q_size_tp = config.hidden_size // tp_size + kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size + total_size = q_size_tp + 2 * kv_size_tp + if not bias: + new_weight_qkv = torch.empty(total_size * tp_size, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device()) + else: + new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=torch.cuda.current_device()) + for i in range(tp_size): + q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] + k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp] + v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp] + new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0)) + + else: + q_size_tp = config.hidden_size // tp_size + kv_size_tp = hidden_size_per_head + total_size = q_size_tp + 2 * kv_size_tp + if not bias: + new_weight_qkv = torch.empty(total_size * tp_size, config.hidden_size, dtype=params_dtype, device=torch.cuda.current_device()) + else: + new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=torch.cuda.current_device()) + for i in range(tp_size): + q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] + start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head + end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head + k_part = full_weight_k[start_idx:end_idx] + v_part = full_weight_v[start_idx:end_idx] + new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0)) + + tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0) + chunk_shape = tensor_chunk[0].shape + else: + chunk_shape = None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=0, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{q_name, k_name, v_name}] not in state_dict, skip loading") + return + + if tensor is None: + sync_tensor = torch.empty( + chunk_shape, + dtype=params_dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + else: + assert tensor.shape == chunk_shape, f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}" + sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False) + + for i in range(tp_size): + if torch.distributed.get_rank() == 0: + sync_tensor.data.copy_(tensor_chunk[i]) + dist.broadcast(sync_tensor, src=0, group=mp_group) + if (i == tp_rank) and (tensor is not None): + tensor.data.copy_(sync_tensor) + + if dp_rank == 0: + # Embeddings + # ------------------- + print_rank_0("loading embeddings...") + gpt_model_module = _get_gpt_model(models[0]) + embed_tokens_weight = None + if pp_rank == 0: + embed_tokens_weight = gpt_model_module.model.embed_tokens.weight + _broadcast_tp_shard_tensor_vocab(embed_tokens_weight, "model.embed_tokens.weight") + + # Transformer layers + # ------------------- + layer_map = _megatron_calc_layer_map(config) + + for layer in range(config.num_hidden_layers): + print_rank_0(f"loading layer #{layer}...") + layer_name = f"model.layers.{layer}" + dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer] + + gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank]) + sync_layer = gpt_model_module.model.layers[dst_layer_idx] + + _broadcast_tensor( + sync_layer.input_layernorm.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.input_layernorm.weight", + ) + + _broadcast_tp_shard_tensor_qkv( + sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.q_proj.weight", + f"{layer_name}.self_attn.k_proj.weight", + f"{layer_name}.self_attn.v_proj.weight", + ) + + _broadcast_tp_shard_tensor_qkv( + sync_layer.self_attn.qkv_proj.bias if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.q_proj.bias", + f"{layer_name}.self_attn.k_proj.bias", + f"{layer_name}.self_attn.v_proj.bias", + bias=True, + ) + + _broadcast_tp_shard_tensor( + sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.self_attn.o_proj.weight", + chunk_dim=1, + ) + + _broadcast_tensor( + sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.post_attention_layernorm.weight", + ) + + _broadcast_tp_shard_tensor_gate_up( + sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.mlp.gate_proj.weight", + f"{layer_name}.mlp.up_proj.weight", + ) + + _broadcast_tp_shard_tensor( + sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None, + f"{layer_name}.mlp.down_proj.weight", + chunk_dim=1, + ) + # Final Layernorm + # ------------------- + print_rank_0("loading final layernorm...") + gpt_model_module = _get_gpt_model(models[-1]) + _broadcast_tensor( + getattr(gpt_model_module.model.norm, "weight", None), + "model.norm.weight", + ) + + if tie_word_embeddings: + print_rank_0("tie_word_embeddings skip load lm_head") + else: + print_rank_0("loading lm_head...") + lm_head_weight = None + if pp_rank + 1 == pp_size: + lm_head_weight = gpt_model_module.lm_head.weight + + if is_value_model: + if "lm_head.weight" in state_dict and state_dict["lm_head.weight"].shape[0] == 1: + _broadcast_tensor(lm_head_weight, "lm_head.weight") + print_rank_0("load lm_head from value_head weight") + elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1: + _broadcast_tensor(lm_head_weight, "reward_head.weight") + print_rank_0("load lm_head from value_head weight") + else: + _broadcast_tensor(None, "lm_head.weight") + print_rank_0("fail to match lm_head in value_model") + + else: + _broadcast_tp_shard_tensor(lm_head_weight, "lm_head.weight") + + dist.barrier() + # Broadcast weights inside data parallel groups + for wrapped_model in wrapped_models: + broadcast_params(wrapped_model) + + torch.cuda.empty_cache() + print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s") diff --git a/verl/models/qwen2/megatron/checkpoint_utils/qwen2_saver.py b/verl/models/qwen2/megatron/checkpoint_utils/qwen2_saver.py new file mode 100644 index 0000000000000000000000000000000000000000..9d6999fa90538de6883a522d31a0c37d751c7851 --- /dev/null +++ b/verl/models/qwen2/megatron/checkpoint_utils/qwen2_saver.py @@ -0,0 +1,436 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import torch +import torch.distributed as dist +from megatron.core import mpu +from megatron.core.distributed import DistributedDataParallel as LocalDDP +from megatron.core.transformer.module import Float16Module +from torch.nn.parallel import DistributedDataParallel as torchDDP + +from verl.utils.megatron_utils import print_rank_0, unwrap_model + + +def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0): + """given TP,DP,PP rank to get the global rank.""" + + tp_size = mpu.get_tensor_model_parallel_world_size() + dp_size = mpu.get_data_parallel_world_size() + pp_size = mpu.get_pipeline_model_parallel_world_size() + assert tp_size * dp_size * pp_size == torch.distributed.get_world_size(), f"{tp_size} x {dp_size} x {pp_size} != {torch.distributed.get_world_size()}" + # We only support TP-DP-PP grouping, for correctness when resharding + return (pp_rank * dp_size + dp_rank) * tp_size + tp_rank + + +def _megatron_calc_layer_map(config): + """Calculate the mapping of global layer_idx to local layer_idx + Returns: + layer_map (Dict: int -> tuple(int, int, int)): + mapping from the global layer index to + a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model) + """ + from megatron.core import mpu + + pp_size = mpu.get_pipeline_model_parallel_world_size() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + + layer_map = dict() + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers + + for pp_rank_idx in range(pp_size): + for virtual_pp_rank_idx in range(virtual_pp_size): + layer_offset = virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + for layer_idx in range(num_layers_per_model): + layer_map[layer_offset + layer_idx] = ( + pp_rank_idx, + virtual_pp_rank_idx, + layer_idx, + ) + return layer_map + + +def merge_megatron_ckpt_qwen2(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False): + """Merge sharded parameters of a Megatron module into a merged checkpoint. + + Args: + wrapped_models (list of megatron.core.distributed.DistributedDataParallel): + The local DDP wrapped megatron modules. + config (str or None): + HF config for model + dtype: model params type + is_value_model: if model is value model + tie_word_embeddings: tie_word_embeddings + Returns: + state_dict (dict): + The merged state_dict in rank 0, and an empty dictionary in other ranks. + """ + start_time = time.time() + + def _get_gpt_model(model): + return model + + dp_rank = mpu.get_data_parallel_rank() + pp_size = mpu.get_pipeline_model_parallel_world_size() + pp_rank = mpu.get_pipeline_model_parallel_rank() + virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 + mp_group = mpu.get_model_parallel_group() + + if dist.get_rank() == 0: + assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0" + assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" + assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" + + if not isinstance(wrapped_models, (list, tuple)): + wrapped_models = list(wrapped_models) + + assert len(wrapped_models) == virtual_pp_size + num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size + assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers + + models = [None] * len(wrapped_models) + + for i, wrapped_model in enumerate(wrapped_models): + models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) + assert len(models[i].model.layers) == num_layers_per_model, "len model layers {} not equal to num_layers_per_model {}".format(len(models[i].model.layers), num_layers_per_model) + + state_dict = dict() + + def _get_cpu_tensor(tensor: torch.Tensor): + if tensor is None: + return None + if tensor.device == torch.device("cpu"): + return tensor.detach().clone() + return tensor.detach().cpu() + + def _broadcast_tensor(tensor, name, src_pp_rank) -> torch.Tensor: + """broadcast tensor across mp_group""" + nonlocal state_dict + nonlocal mp_group + src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank) + + if torch.distributed.get_rank() == src_rank: + if tensor is None: + weight = None + tensor_shape = None + else: + weight = tensor + tensor_shape = weight.shape + else: + weight = None + tensor_shape = None + + obj_list = [tensor_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + tensor_shape = obj_list[0] + + if tensor_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tensor:[{name}] not exist, skip collect") + return + + if weight is None: + weight = torch.empty( + tensor_shape, + dtype=dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + + dist.broadcast(weight, src=src_rank, group=mp_group) + + if torch.distributed.get_rank() == 0: + state_dict[name] = _get_cpu_tensor(weight) + + def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_func=None) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_size = mpu.get_tensor_model_parallel_world_size() + src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank) + + chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{name}] not exist, skip collecting") + return + + buffer_tensor = torch.empty( + chunk_shape, + dtype=dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + + chunk_tensors = [None] * tp_size + + for i in range(tp_size): + cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank) + sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor + dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) + + if torch.distributed.get_rank() == 0: + chunk_tensors[i] = _get_cpu_tensor(sync_tensor) + + if torch.distributed.get_rank() == 0: + full_tensor = torch.concat(chunk_tensors, dim=concat_dim) + if mutate_func is not None: + full_tensor = mutate_func(full_tensor) + state_dict[name] = full_tensor + + def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) -> torch.Tensor: + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_size = mpu.get_tensor_model_parallel_world_size() + src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank) + + chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not exist, skip collecting") + return + + buffer_tensor = torch.empty( + chunk_shape, + dtype=dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + + chunk_tensors = [None] * tp_size + + for i in range(tp_size): + cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank) + sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor + dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) + + if torch.distributed.get_rank() == 0: + chunk_tensors[i] = _get_cpu_tensor(sync_tensor) + + if torch.distributed.get_rank() == 0: + full_tensor = torch.concat(chunk_tensors, dim=0) + intermediate_size_tp = config.intermediate_size // tp_size + gate_weight_list = [] + up_weight_list = [] + for i in range(tp_size): + gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)] + gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp] + up_weight_tp = gate_up_weight_tp[intermediate_size_tp:] + gate_weight_list.append(gate_weight_tp) + up_weight_list.append(up_weight_tp) + + state_dict[gate_name] = torch.cat(gate_weight_list, dim=0) + state_dict[up_name] = torch.cat(up_weight_list, dim=0) + + def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank): + """broadcast tensor in tp shards across mp_group""" + nonlocal state_dict + nonlocal mp_group + tp_size = mpu.get_tensor_model_parallel_world_size() + src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank) + + chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None + + obj_list = [chunk_shape] + dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) + chunk_shape = obj_list[0] + if chunk_shape is None: + # all or none ranks in the mp_group should reach here + print_rank_0(f"tp_shard tensor:[{q_name}] not exist, skip collecting") + return + + buffer_tensor = torch.empty( + chunk_shape, + dtype=dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + + chunk_tensors = [None] * tp_size + + for i in range(tp_size): + cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank) + sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor + dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) + + if torch.distributed.get_rank() == 0: + chunk_tensors[i] = _get_cpu_tensor(sync_tensor) + + if torch.distributed.get_rank() == 0: + full_tensor = torch.concat(chunk_tensors, dim=0) + q_weight_list = [] + k_weight_list = [] + v_weight_list = [] + hidden_size_per_head = config.hidden_size // config.num_attention_heads + + if config.num_key_value_heads >= tp_size: + q_size_tp = config.hidden_size // tp_size + kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size + total_size = q_size_tp + 2 * kv_size_tp + for i in range(tp_size): + qkv_part = full_tensor[i * total_size : (i + 1) * total_size] + q_part = qkv_part[:q_size_tp] + k_part = qkv_part[q_size_tp : q_size_tp + kv_size_tp] + v_part = qkv_part[q_size_tp + kv_size_tp : total_size] + q_weight_list.append(q_part) + k_weight_list.append(k_part) + v_weight_list.append(v_part) + else: + q_size_tp = config.hidden_size // tp_size + kv_size_tp = hidden_size_per_head + total_size = q_size_tp + 2 * kv_size_tp + for i in range(tp_size): + qkv_part = full_tensor[i * total_size : (i + 1) * total_size] + q_part = qkv_part[:q_size_tp] + k_part = qkv_part[q_size_tp : q_size_tp + kv_size_tp] + v_part = qkv_part[q_size_tp + kv_size_tp : total_size] + q_weight_list.append(q_part) + if i * config.num_key_value_heads % tp_size == 0: + k_weight_list.append(k_part) + v_weight_list.append(v_part) + + state_dict[q_name] = torch.cat(q_weight_list, dim=0) + state_dict[k_name] = torch.cat(k_weight_list, dim=0) + state_dict[v_name] = torch.cat(v_weight_list, dim=0) + + # empty cache before collecting weights + torch.cuda.empty_cache() + # Embeddings + # ------------------- + if dp_rank == 0: + # Embeddings + # ------------------- + print_rank_0("collecting embeddings...") + gpt_model_module = _get_gpt_model(models[0]) + _broadcast_tp_shard_tensor( + gpt_model_module.model.embed_tokens.weight if pp_rank == 0 else None, + "model.embed_tokens.weight", + src_pp_rank=0, + ) + + # Transformer layers + # ------------------- + layer_map = _megatron_calc_layer_map(config) + for layer in range(config.num_hidden_layers): + print_rank_0(f"collecting layer #{layer}...") + layer_name = f"model.layers.{layer}" + src_pp_rank, src_virtual_pp_rank, src_layer_idx = layer_map[layer] + + gpt_model_module = _get_gpt_model(models[src_virtual_pp_rank]) + sync_layer = gpt_model_module.model.layers[src_layer_idx] + + _broadcast_tensor( + sync_layer.input_layernorm.weight, + f"{layer_name}.input_layernorm.weight", + src_pp_rank=src_pp_rank, + ) + + _broadcast_tp_shard_tensor_qkv( + sync_layer.self_attn.qkv_proj.weight, + f"{layer_name}.self_attn.q_proj.weight", + f"{layer_name}.self_attn.k_proj.weight", + f"{layer_name}.self_attn.v_proj.weight", + src_pp_rank=src_pp_rank, + ) + + _broadcast_tp_shard_tensor_qkv( + sync_layer.self_attn.qkv_proj.bias, + f"{layer_name}.self_attn.q_proj.bias", + f"{layer_name}.self_attn.k_proj.bias", + f"{layer_name}.self_attn.v_proj.bias", + src_pp_rank=src_pp_rank, + ) + + _broadcast_tp_shard_tensor( + sync_layer.self_attn.o_proj.weight, + f"{layer_name}.self_attn.o_proj.weight", + concat_dim=1, + src_pp_rank=src_pp_rank, + ) + + _broadcast_tensor( + sync_layer.post_attention_layernorm.weight, + f"{layer_name}.post_attention_layernorm.weight", + src_pp_rank=src_pp_rank, + ) + + _broadcast_tp_shard_tensor_gate_up( + sync_layer.mlp.gate_up_proj.weight, + f"{layer_name}.mlp.gate_proj.weight", + f"{layer_name}.mlp.up_proj.weight", + src_pp_rank=src_pp_rank, + ) + + _broadcast_tp_shard_tensor( + sync_layer.mlp.down_proj.weight, + f"{layer_name}.mlp.down_proj.weight", + concat_dim=1, + src_pp_rank=src_pp_rank, + ) + + # Final Layernorm + # ------------------- + print_rank_0("collecting final layernorm...") + gpt_model_module = _get_gpt_model(models[-1]) + _broadcast_tensor( + getattr(gpt_model_module.model.norm, "weight", None), + "model.norm.weight", + src_pp_rank=pp_size - 1, + ) + + if tie_word_embeddings: + print_rank_0("tie word embedding skip load lm_head...") + else: + print_rank_0("collecting lm_head...") + + if is_value_model: + _broadcast_tensor( + gpt_model_module.lm_head.weight if pp_rank == pp_size - 1 else None, + "lm_head.weight", + src_pp_rank=pp_size - 1, + ) + _broadcast_tensor( + gpt_model_module.reward_head.weight if pp_rank == pp_size - 1 and getattr(gpt_model_module, "reward_weight", None) is not None else None, + "reward_head.weight", + src_pp_rank=pp_size - 1, + ) + + else: + _broadcast_tp_shard_tensor( + getattr(gpt_model_module.lm_head, "weight", None) if pp_rank == pp_size - 1 else None, + "lm_head.weight", + src_pp_rank=pp_size - 1, + ) + + dist.barrier() + + torch.cuda.empty_cache() + if torch.distributed.get_rank() == 0: + for k, v in state_dict.items(): + if dtype != v.dtype: + state_dict[k] = v.to(dtype) + + print_rank_0(f"merge megatron ckpt done, time elapsed {time.time() - start_time}s") + return state_dict diff --git a/verl/models/qwen2/megatron/layers/__init__.py b/verl/models/qwen2/megatron/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b3ebc41b73e1fc0f905d555b4c0ca312319cc79 --- /dev/null +++ b/verl/models/qwen2/megatron/layers/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .parallel_attention import ParallelQwen2Attention +from .parallel_decoder import ParallelQwen2DecoderLayer, ParallelQwen2DecoderLayerRmPad +from .parallel_mlp import ParallelQwen2MLP +from .parallel_rmsnorm import ParallelQwen2RMSNorm + +__all__ = ["ParallelQwen2Attention", "ParallelQwen2DecoderLayer", "ParallelQwen2DecoderLayerRmPad", "ParallelQwen2MLP", "ParallelQwen2RMSNorm"] diff --git a/verl/models/qwen2/megatron/layers/parallel_attention.py b/verl/models/qwen2/megatron/layers/parallel_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..4c70fbf18eed2a922e02bbee23582c97453ce2ea --- /dev/null +++ b/verl/models/qwen2/megatron/layers/parallel_attention.py @@ -0,0 +1,370 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional, Tuple + +import torch.nn.functional as F +from einops import rearrange +from transformers.utils import is_flash_attn_2_available + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa +import torch +from flash_attn.layers.rotary import apply_rotary_emb +from megatron.core import ModelParallelConfig, tensor_parallel +from megatron.core import parallel_state as mpu +from torch import nn +from transformers import Qwen2Config + +from verl.models.qwen2.megatron.layers.parallel_linear import QKVParallelLinear +from verl.utils.megatron import tensor_parallel as tp_utils + + +class Qwen2RotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache(seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +class Qwen2LinearScalingRotaryEmbedding(Qwen2RotaryEmbedding): + """Qwen2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = t / self.scaling_factor + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +class Qwen2DynamicNTKScalingRotaryEmbedding(Qwen2RotaryEmbedding): + """Qwen2RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ((self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class ParallelQwen2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig): + super().__init__() + self.config = config + self.megatron_config = megatron_config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + + # assign values after tp + tp_size = mpu.get_tensor_model_parallel_world_size() + assert self.num_heads % tp_size == 0, f"num_head must be divisible by tp_size. Got num_head={self.num_heads}, tp_size={tp_size}" + assert self.num_key_value_heads % tp_size == 0, f"num_key_value_heads must be divisible by tp_size. Got num_key_value_heads={self.num_key_value_heads}, tp_size={tp_size}" + + self.num_heads_per_tp = self.num_heads // tp_size + self.num_key_value_heads_per_tp = self.num_key_value_heads // tp_size + self.hidden_size_per_tp = self.hidden_size // tp_size + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`: {self.num_heads}).") + + column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() + row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear() + + if megatron_config is not None: + assert column_kwargs.get("config", False), "must have ModelParallelConfig" + assert row_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(column_kwargs, megatron_config) + tp_utils.update_kwargs_with_config(row_kwargs, megatron_config) + + # [self.q_size, self.k_size, self.v_size] + self.qkv_proj = QKVParallelLinear( + input_size=self.hidden_size, + num_heads=self.num_heads, + num_key_value_heads=self.num_key_value_heads, + head_dim=self.head_dim, + # bias=config.attention_bias, + bias=True, + gather_output=False, + skip_bias_add=False, + **column_kwargs, + ) + + self.q_size = self.num_heads_per_tp * self.head_dim + self.k_size = self.num_key_value_heads_per_tp * self.head_dim + self.v_size = self.num_key_value_heads_per_tp * self.head_dim + + self.o_proj = tensor_parallel.RowParallelLinear( + input_size=self.num_heads * self.head_dim, + output_size=self.hidden_size, + # bias=config.attention_bias, + bias=False, + input_is_parallel=True, + skip_bias_add=False, + **row_kwargs, + ) + + self._init_rope() + + def _init_rope(self): + self.rotary_emb = Qwen2RotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + qkv = self.qkv_proj(hidden_states)[0] + query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1) + + query_states = query_states.view(bsz, q_len, self.num_heads_per_tp, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads_per_tp, q_len, kv_seq_len): + raise ValueError(f"Attention weights should be of size {(bsz, self.num_heads_per_tp, q_len, kv_seq_len)}, but is {attn_weights.size()}") + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError(f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}") + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads_per_tp, q_len, self.head_dim): + raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads_per_tp, q_len, self.head_dim)}, but is {attn_output.size()}") + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size_per_tp) + attn_output = self.o_proj(attn_output)[0] + return attn_output + + +""" +Remove padding Attention +- Using Flash-attn 2 +- Compatible with sequence parallel +""" + + +def apply_rotary_pos_emb_rmpad(q, k, cos, sin, position_ids, indices, sequence_length): + batch_size = position_ids.shape[0] + + q = pad_input(q, indices, batch_size, sequence_length) # (batch_size, seqlen, num_head, head_dim) + k = pad_input(k, indices, batch_size, sequence_length) + cos = cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] + sin = sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + + q_embed = index_first_axis(rearrange(q_embed, "b s ... -> (b s) ..."), indices) + k_embed = index_first_axis(rearrange(k_embed, "b s ... -> (b s) ..."), indices) + + return q_embed, k_embed + + +# use flash-attn rotary embeddings with rmpad +# cos/sin shoudl be: (seq_length, rotary_dim / 2) +def apply_rotary_pos_emb_rmpad_flash(q, k, cos, sin, cu_seqlens, max_seqlen): + q_embed = apply_rotary_emb(q, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + k_embed = apply_rotary_emb(k, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + return q_embed, k_embed + + +class ParallelQwen2AttentionRmPad(ParallelQwen2Attention): + def forward( + self, + hidden_states: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + sequence_length: int = None, + indices: torch.Tensor = None, + cu_seqlens: torch.Tensor = None, + max_seqlen_in_batch: int = None, + ): + total_nnz, _, _ = hidden_states.size() # This is the total_nnz padded after sequence parallel + + if self.megatron_config.sequence_parallel: + total_nnz = total_nnz * mpu.get_tensor_model_parallel_world_size() + + qkv = self.qkv_proj(hidden_states)[0] + query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1) # (total_nnz, 1, hidden_size) + + if self.megatron_config.sequence_parallel: + sequence_parallel_pad = total_nnz - cu_seqlens[-1] + total_nnz = cu_seqlens[-1] # total_nnz before sp padding + query_states = query_states[:total_nnz] + key_states = key_states[:total_nnz] + value_states = value_states[:total_nnz] + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dime x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(total_nnz, self.num_heads_per_tp, self.head_dim) + key_states = key_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim) + value_states = value_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim) + + cos, sin = self.rotary_emb(value_states, seq_len=sequence_length) + cos, sin = cos[:, : cos.shape[1] // 2], sin[:, : sin.shape[1] // 2] # flash attn only needs half + query_states, key_states = apply_rotary_pos_emb_rmpad_flash(query_states, key_states, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen_in_batch) + # query_states, key_states = apply_rotary_pos_emb_rmpad(query_states, key_states, cos, sin, position_ids, indices, + + # It is recommended to use dropout with FA according to the docs + # when training. + dropout_rate = 0.0 # if not self.training else self.attn_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (Qwen2RMSNorm handles it correctly) + input_dtype = query_states.dtype + if input_dtype == torch.float32: + query_states = query_states.to(torch.float16) + key_states = key_states.to(torch.float16) + value_states = value_states.to(torch.float16) + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen_in_batch, + max_seqlen_k=max_seqlen_in_batch, + dropout_p=dropout_rate, + softmax_scale=None, + causal=True, + ) + + attn_output_unpad = attn_output_unpad.to(input_dtype) + attn_output_unpad = attn_output_unpad.reshape(total_nnz, 1, self.hidden_size_per_tp).contiguous() + + # sequence parallel reduce_scatter is performed inside RowColumnParallel if enabled + # Here we need to repad + if self.megatron_config.sequence_parallel: + attn_output_unpad = F.pad(attn_output_unpad, pad=(0, 0, 0, 0, 0, sequence_parallel_pad)) + + attn_output_unpad = self.o_proj(attn_output_unpad)[0] + return attn_output_unpad diff --git a/verl/models/qwen2/megatron/layers/parallel_decoder.py b/verl/models/qwen2/megatron/layers/parallel_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..4208032f79f2eb645db3c3a422f2e95b7b2d8d84 --- /dev/null +++ b/verl/models/qwen2/megatron/layers/parallel_decoder.py @@ -0,0 +1,150 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple + +import torch +from megatron.core import ModelParallelConfig +from torch import nn +from transformers import Qwen2Config + +from verl.utils.megatron_utils import TransformerConfig, convert_config + +from .parallel_attention import ParallelQwen2Attention, ParallelQwen2AttentionRmPad +from .parallel_mlp import ParallelQwen2MLP +from .parallel_rmsnorm import ParallelQwen2RMSNorm + + +class ParallelQwen2DecoderLayer(nn.Module): + def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig, layer_idx: int): + super().__init__() + self.config: TransformerConfig = convert_config(config, megatron_config) + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.self_attn = ParallelQwen2Attention(config=config, megatron_config=megatron_config) + + self.mlp = ParallelQwen2MLP(config, megatron_config=megatron_config) + self.input_layernorm = ParallelQwen2RMSNorm(config, megatron_config) + self.post_attention_layernorm = ParallelQwen2RMSNorm(config, megatron_config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Note: sequence parallel is hidden inside ColumnParallelLinear + # reduce scatter is hidden inside RowParallelLinear + + # Self Attention + hidden_states = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + ) + + # TODO: add sequence parallel operator reduce_scatter here + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + + # TODO: add sequence parallel operator all_gather here + + hidden_states = self.mlp(hidden_states) + + # TODO: add sequence parallel operator reduce_scatter here + + hidden_states = residual + hidden_states + + outputs = hidden_states + + return outputs + + +class ParallelQwen2DecoderLayerRmPad(nn.Module): + def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig, layer_idx: int): + super().__init__() + self.config: TransformerConfig = convert_config(config, megatron_config) + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + self.self_attn = ParallelQwen2AttentionRmPad(config=config, megatron_config=megatron_config) + + self.mlp = ParallelQwen2MLP(config, megatron_config=megatron_config) + self.input_layernorm = ParallelQwen2RMSNorm(config, megatron_config) + self.post_attention_layernorm = ParallelQwen2RMSNorm(config, megatron_config) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + sequence_length: int = None, + indices: torch.Tensor = None, + cu_seqlens: int = None, + max_seqlen_in_batch: int = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states # (total_nnz // sp, 1, hidden_size) + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + # (total_nnz // sp, 1, hidden_size) -> all-gather (total_nnz, 1, hidden_size) + # -> col + row -> reduce-scatter -> (total_nnz // sp, 1, hidden_size) + hidden_states = self.self_attn( + hidden_states=hidden_states, + position_ids=position_ids, + sequence_length=sequence_length, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen_in_batch=max_seqlen_in_batch, + ) + + hidden_states = residual + hidden_states + + # Fully Connected + # shape changes same as attn + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = hidden_states + + return outputs diff --git a/verl/models/qwen2/megatron/layers/parallel_linear.py b/verl/models/qwen2/megatron/layers/parallel_linear.py new file mode 100644 index 0000000000000000000000000000000000000000..c1dac77cceaa6cc1fd441886f089c1899a308261 --- /dev/null +++ b/verl/models/qwen2/megatron/layers/parallel_linear.py @@ -0,0 +1,79 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/linear.py + + +from megatron.core import tensor_parallel + + +class QKVParallelLinear(tensor_parallel.ColumnParallelLinear): + def __init__( + self, + input_size, + num_heads, + num_key_value_heads, + head_dim, + *, + bias=True, + gather_output=True, + skip_bias_add=False, + **kwargs, + ): + # Keep input parameters, and already restrict the head numbers + self.input_size = input_size + self.q_output_size = num_heads * head_dim + self.kv_output_size = num_key_value_heads * head_dim + self.head_dim = head_dim + self.gather_output = gather_output + self.skip_bias_add = skip_bias_add + + input_size = self.input_size + output_size = (num_heads + 2 * num_key_value_heads) * self.head_dim + + super().__init__( + input_size=input_size, + output_size=output_size, + bias=bias, + gather_output=gather_output, + skip_bias_add=skip_bias_add, + **kwargs, + ) + + +class MergedColumnParallelLinear(tensor_parallel.ColumnParallelLinear): + def __init__( + self, + input_size, + gate_ouput_size, + up_output_size, + *, + bias=True, + gather_output=True, + skip_bias_add=False, + **kwargs, + ): + # Keep input parameters, and already restrict the head numbers + self.input_size = input_size + self.output_size = gate_ouput_size + up_output_size + self.gather_output = gather_output + self.skip_bias_add = skip_bias_add + + super().__init__( + input_size=self.input_size, + output_size=self.output_size, + bias=bias, + gather_output=gather_output, + skip_bias_add=skip_bias_add, + **kwargs, + ) diff --git a/verl/models/qwen2/megatron/layers/parallel_mlp.py b/verl/models/qwen2/megatron/layers/parallel_mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..981b24ebe29cbc8f58fc0cce15b62336ab46f00e --- /dev/null +++ b/verl/models/qwen2/megatron/layers/parallel_mlp.py @@ -0,0 +1,74 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from megatron.core import ModelParallelConfig, tensor_parallel +from megatron.core import parallel_state as mpu +from torch import nn +from transformers.activations import ACT2FN + +from verl.models.qwen2.megatron.layers.parallel_linear import MergedColumnParallelLinear +from verl.utils.megatron import tensor_parallel as tp_utils + + +class ParallelQwen2MLP(nn.Module): + def __init__(self, config, megatron_config: ModelParallelConfig = None) -> None: + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + # The weight is only [hidden_size, intermediate_size // model_parallel_world_size] + + column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() + row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear() + + if megatron_config is not None: + assert column_kwargs.get("config", False), "must have ModelParallelConfig" + assert row_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(row_kwargs, megatron_config) + tp_utils.update_kwargs_with_config(column_kwargs, megatron_config) + + tp_size = mpu.get_tensor_model_parallel_world_size() + + self.gate_up_proj = MergedColumnParallelLinear( + input_size=self.hidden_size, + gate_ouput_size=self.intermediate_size, + up_output_size=self.intermediate_size, + bias=False, + gather_output=False, + skip_bias_add=False, + **column_kwargs, + ) + self.gate_size = self.intermediate_size // tp_size + + self.down_proj = tensor_parallel.RowParallelLinear( + input_size=self.intermediate_size, + output_size=self.hidden_size, + bias=False, + input_is_parallel=True, + skip_bias_add=False, + **row_kwargs, + ) + + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + gate_up = self.gate_up_proj(x)[0] + gate, up = gate_up.split(self.gate_size, dim=-1) + return self.down_proj(self.act_fn(gate) * up)[0] diff --git a/verl/models/qwen2/megatron/layers/parallel_rmsnorm.py b/verl/models/qwen2/megatron/layers/parallel_rmsnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..831b73fd6b2ca5f96dd1f5cfcca4ae12274c0e7b --- /dev/null +++ b/verl/models/qwen2/megatron/layers/parallel_rmsnorm.py @@ -0,0 +1,48 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numbers + +import torch +from apex.normalization.fused_layer_norm import fused_rms_norm_affine +from megatron.core import ModelParallelConfig +from torch import nn +from transformers import Qwen2Config + +from verl.utils.megatron import sequence_parallel as sp_utils + + +class ParallelQwen2RMSNorm(nn.Module): + def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig): + """ + Qwen2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + if isinstance(config.hidden_size, numbers.Integral): + normalized_shape = (config.hidden_size,) + self.normalized_shape = torch.Size(normalized_shape) + self.weight = nn.Parameter(torch.ones(self.normalized_shape)) + self.variance_epsilon = config.rms_norm_eps + + if megatron_config.sequence_parallel: + sp_utils.mark_parameter_as_sequence_parallel(self.weight) + + def forward(self, hidden_states): + return fused_rms_norm_affine( + input=hidden_states, + weight=self.weight, + normalized_shape=self.normalized_shape, + eps=self.variance_epsilon, + memory_efficient=True, + ) diff --git a/verl/models/qwen2/megatron/modeling_qwen2_megatron.py b/verl/models/qwen2/megatron/modeling_qwen2_megatron.py new file mode 100644 index 0000000000000000000000000000000000000000..95fe66d5e5ebdb77d2005a912b42c45286feb5d9 --- /dev/null +++ b/verl/models/qwen2/megatron/modeling_qwen2_megatron.py @@ -0,0 +1,711 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Qwen2 model.""" + +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from megatron.core import ModelParallelConfig, mpu, parallel_state, tensor_parallel +from torch import nn +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.models.qwen2.configuration_qwen2 import Qwen2Config +from transformers.models.qwen2.modeling_qwen2 import CausalLMOutputWithPast + +from verl.utils.megatron import sequence_parallel as sp_utils +from verl.utils.megatron import tensor_parallel as tp_utils +from verl.utils.megatron_utils import TransformerConfig, convert_config + +from .layers import ParallelQwen2DecoderLayer, ParallelQwen2DecoderLayerRmPad, ParallelQwen2RMSNorm + +""" +TODO: +1. Add weight initialization. Here we need to be careful on TP weight init. +2. Add sequence parallel +3. Load checkpoint from Qwen2 pretrained checkpoint +""" + + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len) + + +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +class ParallelQwen2Model(nn.Module): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`] + + Args: + config: Qwen2Config + """ + + def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig): + super().__init__() + self.config: TransformerConfig = convert_config(config, megatron_config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding() + if megatron_config is not None: + assert embedding_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(embedding_kwargs, megatron_config) + self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs) + + self.layers = nn.ModuleList([ParallelQwen2DecoderLayer(config, megatron_config) for _ in range(config.num_hidden_layers)]) + self.norm = ParallelQwen2RMSNorm(config, megatron_config) + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device) + combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + + return combined_attention_mask + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + """ + + Args: + input_ids: input ids. shape (batch_size, seq_length) + attention_mask: attention_mask. shape (batch_size, seq_length) + position_ids: position ids. shape (batch_size, seq_length) + + Returns: + + """ + batch_size, seq_length = input_ids.shape + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + + attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds) + + hidden_states = inputs_embeds + + for idx, decoder_layer in enumerate(self.layers): + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + ) + + hidden_states = layer_outputs + + hidden_states = self.norm(hidden_states) + + return hidden_states + + +class ParallelQwen2ForCausalLM(nn.Module): + def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig): + super().__init__() + self.config: TransformerConfig = convert_config(config, megatron_config) + self.model = ParallelQwen2Model(config, megatron_config=megatron_config) + self.vocab_size = config.vocab_size + + column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() + if megatron_config is not None: + assert column_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) + + self.lm_head = tensor_parallel.ColumnParallelLinear( + input_size=config.hidden_size, + output_size=config.vocab_size, + bias=False, + gather_output=False, + skip_bias_add=False, + **column_kwargs, + ) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + ```""" + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + ) + + hidden_states = outputs + logits = self.lm_head(hidden_states)[0] + + logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) + + logits = logits.float() + return CausalLMOutputWithPast( + loss=None, + logits=logits, + past_key_values=None, + hidden_states=None, + attentions=None, + ) + + +from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + +class ParallelQwen2ModelRmPad(nn.Module): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`] + + Args: + config: Qwen2Config + """ + + def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig): + super().__init__() + self.config: TransformerConfig = convert_config(config, megatron_config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding() + self.megatron_config = megatron_config + if megatron_config is not None: + assert embedding_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config) + self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs) + + self.layers = nn.ModuleList([ParallelQwen2DecoderLayerRmPad(config, megatron_config) for _ in range(config.num_hidden_layers)]) + self.norm = ParallelQwen2RMSNorm(config, megatron_config) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + sequence_length: int = None, + indices: torch.Tensor = None, + cu_seqlens: int = None, + max_seqlen_in_batch: int = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + """ + + Args: + input_ids: input ids. shape (1, totol_nnz) + position_ids: position ids. shape (batch_size, seq_length) + + Returns: + + """ + inputs_embeds = self.embed_tokens(input_ids) # (1, total_nnz) -> (1, total_nnz, hidden_size) + + # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size) + inputs_embeds = inputs_embeds.transpose(0, 1) + if self.megatron_config.sequence_parallel: + inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds) + + hidden_states = inputs_embeds + for idx, decoder_layer in enumerate(self.layers): + layer_outputs = decoder_layer( + hidden_states, + position_ids=position_ids, + sequence_length=sequence_length, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen_in_batch=max_seqlen_in_batch, + ) + + hidden_states = layer_outputs + + hidden_states = self.norm(hidden_states) + + return hidden_states + + +class ParallelQwen2ForCausalLMRmPad(nn.Module): + def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig): + super().__init__() + self.config: TransformerConfig = convert_config(config, megatron_config) + self.megatron_config = megatron_config + self.model = ParallelQwen2ModelRmPad(config, megatron_config=megatron_config) + self.vocab_size = config.vocab_size + self._init_head(config) + + def _init_head(self, config: Qwen2Config): + column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() + if self.megatron_config is not None: + assert column_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) + self.lm_head = tensor_parallel.ColumnParallelLinear( + input_size=config.hidden_size, + output_size=config.vocab_size, + bias=False, + gather_output=False, + skip_bias_add=False, + **column_kwargs, + ) + + def _forward_head(self, hidden_states): + # all_gather from sequence parallel region is performed inside lm_head + logits = self.lm_head(hidden_states)[0] + logits = logits.float() # (total_nnz_padded, 1, vocab_size // tp) + logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) # (total_nnz_padded, 1, vocab_size) + return logits + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + ```""" + batch_size, sequence_length = input_ids.shape + + # remove padding here + input_ids, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(input_ids.unsqueeze(dim=-1), attention_mask) # (total_nnz, 1) + + # pad input_ids to multiple of tp for all tp ranks + # TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap + if self.megatron_config.sequence_parallel: + input_ids = sp_utils.pad_to_sequence_parallel(input_ids) + + input_ids = input_ids.transpose(0, 1) # (1, total_nnz+pad) + + outputs = self.model( + input_ids=input_ids, + position_ids=position_ids, + sequence_length=sequence_length, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen_in_batch=max_seqlen_in_batch, + ) + + hidden_states = outputs + + logits = self._forward_head(hidden_states) + + # remove padding from sequence parallel + if self.megatron_config.sequence_parallel: + totol_nnz = cu_seqlens[-1] + logits = logits[:totol_nnz] # (total_nnz_padded) + + logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension + # add removed padding back + logits = pad_input(logits, indices, batch_size, seqlen=sequence_length) # (batch_size, sequence_length, vocab_size) + + return CausalLMOutputWithPast( + loss=None, + logits=logits, + past_key_values=None, + hidden_states=None, + attentions=None, + ) + + +class ParallelQwen2ForValueRmPad(ParallelQwen2ForCausalLMRmPad): + def _init_head(self, config): + column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() + if self.megatron_config is not None: + assert column_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) + self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=1, bias=False) + # lm_head is effectively the same as sequence parallel + sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight) + + def _forward_head(self, hidden_states): + logits = self.lm_head(hidden_states) # (total_nnz_padded // tp, 1, 1) + logits = logits.float() + if self.megatron_config.sequence_parallel: + logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False) + return logits + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output = super().forward(input_ids, attention_mask, position_ids) + output.logits = torch.squeeze(output.logits, dim=-1) + return output + + +""" +Support pipeline parallelism +""" + + +class ParallelQwen2ModelRmPadPP(nn.Module): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`] + This model definition supports pipeline parallelism. To support pp and vpp, + - This model only contains layer in this pp stage and vpp chunk + - When calling get_model in Megatron, this rank will instantiate all the vpp chunks in this pp. + Args: + config: Qwen2Config + """ + + def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig, pre_process, post_process): + super().__init__() + self.config: TransformerConfig = convert_config(config, megatron_config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.pre_process = pre_process + self.post_process = post_process + self.megatron_config = megatron_config + embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding() + if megatron_config is not None: + assert embedding_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config) + if pre_process: + self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs) + else: + self.embed_tokens = None + + pp_rank = mpu.get_pipeline_model_parallel_rank() + pp_size = megatron_config.pipeline_model_parallel_size + self.num_layer_per_pp = config.num_hidden_layers // pp_size + vpp_size = megatron_config.virtual_pipeline_model_parallel_size + vpp_rank = mpu.get_virtual_pipeline_model_parallel_rank() + + if vpp_size is not None: + self.num_layer_vpp_chunk = self.num_layer_per_pp // vpp_size + self.num_layer_this_model = self.num_layer_vpp_chunk + offset = vpp_rank * (config.num_hidden_layers // vpp_size) + (pp_rank * self.num_layer_vpp_chunk) + else: + self.num_layer_this_model = self.num_layer_per_pp + offset = pp_rank * self.num_layer_per_pp + + self.layers = nn.ModuleList() + for i in range(self.num_layer_this_model): + layer = ParallelQwen2DecoderLayerRmPad(config, megatron_config, layer_idx=i + offset) + self.layers.add_module(f"{i}", layer) + + if post_process: + self.norm = ParallelQwen2RMSNorm(config, megatron_config) + else: + self.norm = None + + def set_input_tensor(self, input_tensor): + """Set input tensor to be used instead of forward()'s input. + + When doing pipeline parallelism the input from the previous + stage comes from communication, not from the input, so the + model's forward_step_func won't have it. This function is thus + used by internal code to bypass the input provided by the + forward_step_func""" + self.input_tensor = input_tensor + + def forward( + self, + input_ids: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + sequence_length: int = None, + indices: torch.Tensor = None, + cu_seqlens: int = None, + max_seqlen_in_batch: int = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + """ + + Args: + input_ids: input ids. shape (1, totol_nnz) + position_ids: position ids. shape (batch_size, seq_length) + + Returns: + + """ + if self.pre_process: + inputs_embeds = self.embed_tokens(input_ids) # (1, total_nnz) -> (1, total_nnz, hidden_size) + + # vocab parallel embedding will not do sequence parallel reduce-scatter in open source megatron + # so need to deal with it by handle here: + # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size) + inputs_embeds = inputs_embeds.transpose(0, 1) + if self.megatron_config.sequence_parallel: + inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds) + + hidden_states = inputs_embeds + else: + # self.hidden_states should be passed by Megatron + hidden_states = self.input_tensor + + for idx, decoder_layer in enumerate(self.layers): + layer_outputs = decoder_layer( + hidden_states, + position_ids=position_ids, + sequence_length=sequence_length, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen_in_batch=max_seqlen_in_batch, + ) + + hidden_states = layer_outputs + + if self.post_process: + hidden_states = self.norm(hidden_states) + + return hidden_states + + +class ParallelQwen2ForCausalLMRmPadPP(nn.Module): + def __init__( + self, + config: Qwen2Config, + megatron_config: ModelParallelConfig, + pre_process, + post_process, + share_embeddings_and_output_weights, + ): + super().__init__() + self.config: TransformerConfig = convert_config(config, megatron_config) + self.megatron_config = megatron_config + self.model = ParallelQwen2ModelRmPadPP(config, megatron_config=megatron_config, pre_process=pre_process, post_process=post_process) + self.share_embeddings_and_output_weights = share_embeddings_and_output_weights + self.vocab_size = config.vocab_size + self.pre_process = pre_process + self.post_process = post_process + if post_process: + self._init_head(config) + if pre_process or post_process: + self.setup_embeddings_and_output_layer() + + def set_input_tensor(self, input_tensor): + """Set input tensor to be used instead of forward()'s input. + + When doing pipeline parallelism the input from the previous + stage comes from communication, not from the input, so the + model's forward_step_func won't have it. This function is thus + used by internal code to bypass the input provided by the + forward_step_func""" + assert len(input_tensor) == 1 + self.model.set_input_tensor(input_tensor[0]) + + def _init_head(self, config): + column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() + if self.megatron_config is not None: + assert column_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) + self.lm_head = tensor_parallel.ColumnParallelLinear( + input_size=config.hidden_size, + output_size=config.vocab_size, + bias=False, + gather_output=False, + skip_bias_add=False, + skip_weight_param_allocation=self.pre_process and self.share_embeddings_and_output_weights, + **column_kwargs, + ) + + def setup_embeddings_and_output_layer(self) -> None: + """Sets up embedding layer in first stage and output layer in last stage. + + This function initalizes word embeddings in the final stage when we are + using pipeline parallelism and sharing word embeddings, and sets up param + attributes on the embedding and output layers. + """ + # Set `is_embedding_or_output_parameter` attribute. + if self.pre_process: + self.model.embed_tokens.weight.is_embedding_or_output_parameter = True + if self.post_process and self.lm_head.weight is not None: + self.lm_head.weight.is_embedding_or_output_parameter = True + + if not self.share_embeddings_and_output_weights: + return + + if parallel_state.get_pipeline_model_parallel_world_size() == 1: + # Zero out wgrad if sharing embeddings between two layers on same + # pipeline stage to make sure grad accumulation into main_grad is + # correct and does not include garbage values (e.g., from torch.empty). + self.shared_embedding_or_output_weight().zero_out_wgrad = True + return + + if parallel_state.is_pipeline_first_stage() and self.pre_process and not self.post_process: + self.shared_embedding_or_output_weight().shared_embedding = True + + if self.post_process and not self.pre_process: + assert not parallel_state.is_pipeline_first_stage() + # set word_embeddings weights to 0 here, then copy first + # stage's weights using all_reduce below. + self.lm_head.weight.data.fill_(0) + self.lm_head.weight.shared = True + self.lm_head.weight.shared_embedding = True + + if torch.distributed.is_initialized() and parallel_state.is_rank_in_embedding_group(): + weight = self.shared_embedding_or_output_weight() + weight.data = weight.data.cuda() + torch.distributed.all_reduce(weight.data, group=parallel_state.get_embedding_group()) + + def shared_embedding_or_output_weight(self) -> torch.Tensor: + if self.pre_process: + return self.model.embed_tokens.weight + elif self.post_process: + return self.lm_head.weight + return None + + def _forward_head(self, hidden_states): + # all_gather from sequence parallel region is performed inside lm_head + # print(f'logits shape before forward_head: {hidden_states.shape}, vocab_size = {self.config.vocab_size}') # [4, 32, 4096] + output_weight = None + if self.share_embeddings_and_output_weights: + output_weight = self.shared_embedding_or_output_weight() + logits = self.lm_head(hidden_states, weight=output_weight)[0] + # print(f'logits shape after forward_head: {logits.shape}') # [8, 32, 8] + logits = logits.float() # (total_nnz_padded, 1, vocab_size // tp) + return logits + + def forward( + self, + # original input + *, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + ```""" + + # Note that input_ids, attention_mask and position_ids should be passed to every pp layer. + # In the first pp, input_ids will be used, in other pp layers hidden_states will be used inside self.model + batch_size, sequence_length = input_ids.shape + # remove padding here + input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(input_ids.unsqueeze(dim=-1), attention_mask) # (total_nnz, 1) + + # pad input_ids to multiple of tp for all tp ranks + # TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap + if self.megatron_config.sequence_parallel: + input_ids_rmpad = sp_utils.pad_to_sequence_parallel(input_ids_rmpad) + + input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz+pad) + + outputs = self.model( + input_ids=input_ids_rmpad, + position_ids=position_ids, + sequence_length=sequence_length, + indices=indices, + cu_seqlens=cu_seqlens, + max_seqlen_in_batch=max_seqlen_in_batch, + ) + + if self.post_process: + hidden_states = outputs + logits = self._forward_head(hidden_states) + logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension # torch.Size([8, 32, 16]) + + # remove padding from sequence parallel + if self.megatron_config.sequence_parallel: + totol_nnz = cu_seqlens[-1] + logits = logits[:totol_nnz] # (total_nnz_padded) + # add removed padding back. If input is already rmpad, we let the caller pad_input + logits = pad_input(logits, indices, batch_size, seqlen=sequence_length) # (batch_size, sequence_length, vocab_size) + + return CausalLMOutputWithPast( + loss=None, + logits=logits, + past_key_values=None, + hidden_states=None, + attentions=None, + ) + else: + return outputs + + +class ParallelQwen2ForValueRmPadPP(ParallelQwen2ForCausalLMRmPadPP): + def _init_head(self, config): + column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() + if self.megatron_config is not None: + assert column_kwargs.get("config", False), "must have ModelParallelConfig" + tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) + self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=1, bias=False) + # lm_head is effectively the same as sequence parallel + sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight) + + def _forward_head(self, hidden_states): + logits = self.lm_head(hidden_states) # (total_nnz_padded // tp, 1, 1) + logits = logits.float() + if self.megatron_config.sequence_parallel: + logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False) + return logits + + def forward( + self, + *, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output = super().forward(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids) + if self.post_process: + output.logits = torch.squeeze(output.logits, dim=-1) + return output + else: + return output diff --git a/verl/models/registry.py b/verl/models/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..5439c96848a29db1df8f352e83be78ae0b99147a --- /dev/null +++ b/verl/models/registry.py @@ -0,0 +1,58 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +from typing import List, Optional, Type + +import torch.nn as nn + +# Supported models in Megatron-LM +# Architecture -> (module, class). +_MODELS = { + "LlamaForCausalLM": ( + "llama", + ("ParallelLlamaForCausalLMRmPadPP", "ParallelLlamaForValueRmPadPP", "ParallelLlamaForCausalLMRmPad"), + ), + "Qwen2ForCausalLM": ( + "qwen2", + ("ParallelQwen2ForCausalLMRmPadPP", "ParallelQwen2ForValueRmPadPP", "ParallelQwen2ForCausalLMRmPad"), + ), + "MistralForCausalLM": ( + "mistral", + ("ParallelMistralForCausalLMRmPadPP", "ParallelMistralForValueRmPadPP", "ParallelMistralForCausalLMRmPad"), + ), +} + + +# return model class +class ModelRegistry: + @staticmethod + def load_model_cls(model_arch: str, value=False) -> Optional[Type[nn.Module]]: + if model_arch not in _MODELS: + return None + + megatron = "megatron" + + module_name, model_cls_name = _MODELS[model_arch] + if not value: # actor/ref + model_cls_name = model_cls_name[0] + elif value: # critic/rm + model_cls_name = model_cls_name[1] + + module = importlib.import_module(f"verl.models.{module_name}.{megatron}.modeling_{module_name}_megatron") + return getattr(module, model_cls_name, None) + + @staticmethod + def get_supported_archs() -> List[str]: + return list(_MODELS.keys()) diff --git a/verl/models/transformers/__init__.py b/verl/models/transformers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c6aedde1437eddbe099b730519ef79980e2e3dc4 --- /dev/null +++ b/verl/models/transformers/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/verl/models/transformers/__pycache__/__init__.cpython-311.pyc b/verl/models/transformers/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2803decc4865ea58e84333ae6aef0b5bd946f9a8 Binary files /dev/null and b/verl/models/transformers/__pycache__/__init__.cpython-311.pyc differ diff --git a/verl/models/transformers/__pycache__/monkey_patch.cpython-311.pyc b/verl/models/transformers/__pycache__/monkey_patch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd91a84ddf5efcbb2e9c4411e98613a7826c8ee0 Binary files /dev/null and b/verl/models/transformers/__pycache__/monkey_patch.cpython-311.pyc differ diff --git a/verl/models/transformers/llama.py b/verl/models/transformers/llama.py new file mode 100644 index 0000000000000000000000000000000000000000..f13da734f17d980b06a5ad3868db3252203ab401 --- /dev/null +++ b/verl/models/transformers/llama.py @@ -0,0 +1,230 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from typing import Callable, Optional, Tuple + +import torch + +if sys.version_info >= (3, 11): + pass +else: + pass + +from transformers.cache_utils import Cache +from transformers.modeling_flash_attention_utils import _flash_attention_forward +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb +from transformers.utils import logging + +from verl.utils.ulysses import ( + gather_heads_scatter_seq, + gather_seq_scatter_heads, + get_ulysses_sequence_parallel_world_size, + validate_ulysses_config, +) + +logger = logging.get_logger(__name__) + + +def llama_flash_attn_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Adapted from transformers 4.47.1 to support Ulysses sequence parallelism. + + NOTE: This function is used for transformers versions in the range [4.45.0, 4.47.1]. + """ + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + # trade off: repeat first and then all to all + # key_states = repeat_kv(key_states, self.num_key_value_groups) + # value_states = repeat_kv(value_states, self.num_key_value_groups) + + ########## AlltoAll for Ulysses ########## + ulysses_sp_size = get_ulysses_sequence_parallel_world_size() + + if ulysses_sp_size > 1: + validate_ulysses_config(self.num_heads, ulysses_sp_size) + + # (bsz, n_head, seq_len/n, head_dim) -> (bsz, n_head/n, seq_len, head_dim) + query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1) + key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1) + value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1) + + full_q_len = query_states.size(2) # full seq length + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once(f"The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in {target_dtype}.") + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + full_q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous() + ########## AlltoAll for Ulysses ########## + if ulysses_sp_size > 1: + attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2) + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +def llama_attn_forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Adapted from transformers 4.49.0 to support Ulysses sequence parallelism for transformers >= 4.48.0. + + NOTE: This function has been tested only on transformers versions between 4.48.0 and 4.50.0. + """ + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + from transformers.models.llama.modeling_llama import eager_attention_forward + + bsz, q_len, _ = hidden_states.shape + + query_states = self.q_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + ########## AlltoAll for Ulysses ########## + ulysses_sp_size = get_ulysses_sequence_parallel_world_size() + + if ulysses_sp_size > 1: + validate_ulysses_config(self.config.num_attention_heads, ulysses_sp_size) + + query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1) + key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1) + value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1) + + full_q_len = query_states.size(2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once('`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.') + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous() + ########## AlltoAll for Ulysses ########## + if ulysses_sp_size > 1: + attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2) + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights diff --git a/verl/models/transformers/monkey_patch.py b/verl/models/transformers/monkey_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..92c50880f84b0ec8a6292ab8633322328982eab3 --- /dev/null +++ b/verl/models/transformers/monkey_patch.py @@ -0,0 +1,151 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Apply monkey-patch function to models +""" + +import importlib.metadata +import sys +from functools import lru_cache +from typing import Optional + +import torch +from packaging import version +from transformers.modeling_flash_attention_utils import _flash_attention_forward +from transformers.modeling_utils import PreTrainedModel + +from verl.utils.ulysses import ( + gather_heads_scatter_seq, + gather_seq_scatter_heads, + get_ulysses_sequence_parallel_group, + get_ulysses_sequence_parallel_world_size, +) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=2, repeats=n_rep). The hidden states go from (batch, + seqlen, num_key_value_heads, head_dim) to (batch, seqlen, num_attention_heads, head_dim) + """ + batch, slen, num_key_value_heads, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, :, None, :].expand(batch, slen, num_key_value_heads, n_rep, head_dim) + return hidden_states.reshape(batch, slen, num_key_value_heads * n_rep, head_dim) + + +def _ulysses_flash_attention_forward( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + *args, + position_ids: Optional[torch.Tensor] = None, + **kwargs, +): + """Insert all-to-all before and after flash attention. + DeepSpeed-Ulysses: https://arxiv.org/pdf/2309.14509 + + Args: + query_states (torch.Tensor): (batch_size, seqlen/sp_size, nheads, head_dim) + key_states (torch.Tensor): (batch_size, seqlen/sp_size, nheads_k, head_dim) + value_states (torch.Tensor): (batch_size, seqlen/sp_size, nheads_k, head_dim) + position_ids (torch.Tensor, optional): (batch_size, seqlen/sp_size) + + Returns: + torch.Tensor: (batch_size, seqlen/sp_size, nheads, head_dim) + """ + ulysses_sp_size = get_ulysses_sequence_parallel_world_size() + + ########## AlltoAll for Ulysses ########## + if ulysses_sp_size > 1: + assert position_ids is not None, "position_ids is required for Ulysses sequence parallelism" + + # NOTE: repeat kv heads to be divided by sequence parallel. Instead of repeating nheads_q//nheads_k, + # we choose to repeat sp_size//nheads_k, since flash_attention supports MQA/GQA. + # For example: + # - nheads_k=4, sp=8, repeats=2 + # - nheads_k=8, sp=8, repeats=1 + # - nheads_k=16, sp=8, repeats=1 + repeats = max(ulysses_sp_size // key_states.size(2), 1) + key_states = repeat_kv(key_states, repeats) + value_states = repeat_kv(value_states, repeats) + + # (bsz, seq_len/n, n_head, head_dim) -> (bsz, seq_len, n_head/n, head_dim) + query_states = gather_seq_scatter_heads(query_states, seq_dim=1, head_dim=2) + key_states = gather_seq_scatter_heads(key_states, seq_dim=1, head_dim=2) + value_states = gather_seq_scatter_heads(value_states, seq_dim=1, head_dim=2) + + # TODO: all_gather position_ids because `prepare_fa2_from_position_ids` needs it, we can eliminate + # this all_gather by passing cu_seq_lens_q, cu_seq_lens_k, max_length_k, max_length_q explicitly. + # https://github.com/huggingface/transformers/pull/33932 + + # (bsz, seq_len/n) -> (bsz, seq_len) + position_ids_list = [torch.empty_like(position_ids) for _ in range(ulysses_sp_size)] + torch.distributed.all_gather(position_ids_list, position_ids, group=get_ulysses_sequence_parallel_group()) + position_ids = torch.concat(position_ids_list, dim=-1) + + # (bsz, seq_len, n_head/n, head_dim) + attn_output = _flash_attention_forward(query_states, key_states, value_states, *args, position_ids=position_ids, **kwargs) + + ########## AlltoAll for Ulysses ########## + if ulysses_sp_size > 1: + # (bsz, seq_len, n_head/n, head_dim) -> (bsz, seq_len/n, n_head, head_dim) + attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2) + + return attn_output + + +def apply_monkey_patch(model: PreTrainedModel, ulysses_sp_size: int): + """Replace _flash_attention_forward to _ulysses_flash_attention_forward""" + module = sys.modules[model.__module__] + + num_attention_heads, num_key_value_heads = model.config.num_attention_heads, model.config.num_key_value_heads + assert num_attention_heads % ulysses_sp_size == 0, f"num_attention_heads {num_attention_heads} must be divisible by ulysses_sp_size {ulysses_sp_size}" + assert num_key_value_heads % ulysses_sp_size == 0 or ulysses_sp_size % num_key_value_heads == 0, ( + f"num_key_value_heads {num_key_value_heads} must be divisible by ulysses_sp_size {ulysses_sp_size}or vise versa. Upon ulysses_sp_size % num_key_value_heads == 0,kv heads are repeated to ensure correctness." + ) + # TODO: VLM models only, unify monkey patch to LLM models. + if model.config.model_type in ("qwen2_vl", "qwen2_5_vl"): # patch remove padding for qwen2vl mrope + from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLFlashAttention2 + from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2 + + from verl.models.transformers.qwen2_vl import ulysses_flash_attn_forward + + Qwen2VLFlashAttention2.forward = ulysses_flash_attn_forward + Qwen2_5_VLFlashAttention2.forward = ulysses_flash_attn_forward + print("Monkey patch FlashAttention2.forward in Qwen2VL") + return + + # transformers<=4.47.1 + if hasattr(module, "_flash_attention_forward"): + module._flash_attention_forward = _ulysses_flash_attention_forward + print(f"Monkey patch _flash_attention_forward in {model.__module__}") + else: + # transformers>=4.48.0 + from transformers.integrations import flash_attention + + flash_attention._flash_attention_forward = _ulysses_flash_attention_forward + print(f"Monkey patch _flash_attention_forward in {flash_attention.__name__}") + + +@lru_cache +def is_transformers_version_in_range(min_version: str, max_version: str) -> bool: + try: + # Get the installed version of the transformers library + transformers_version = importlib.metadata.version("transformers") + except importlib.metadata.PackageNotFoundError as e: + raise ModuleNotFoundError("The `transformers` package is not installed.") from e + + # Check if the version is within the specified range + return version.parse(min_version) <= version.parse(transformers_version) <= version.parse(max_version) diff --git a/verl/models/transformers/qwen2.py b/verl/models/transformers/qwen2.py new file mode 100644 index 0000000000000000000000000000000000000000..98c8c88fb985239359084f6fb1dead21bf74a066 --- /dev/null +++ b/verl/models/transformers/qwen2.py @@ -0,0 +1,225 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Optional, Tuple + +import torch +from transformers.cache_utils import Cache +from transformers.modeling_flash_attention_utils import _flash_attention_forward +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv +from transformers.utils import logging + +from verl.utils.ulysses import ( + gather_heads_scatter_seq, + gather_seq_scatter_heads, + get_ulysses_sequence_parallel_world_size, + validate_ulysses_config, +) + +logger = logging.get_logger(__name__) + + +def qwen2_flash_attn_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 +): + """ + Adapted from transformers 4.47.1 to support Ulysses sequence parallelism. + + NOTE: This function is only tested on transformers versions between 4.45.0 and 4.47.1. + """ + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + ########## AlltoAll for Ulysses ########## + ulysses_sp_size = get_ulysses_sequence_parallel_world_size() + + if ulysses_sp_size > 1: + validate_ulysses_config(self.num_heads, ulysses_sp_size) + + # (bsz, n_head, seq_len/n, head_dim) -> (bsz, n_head/n, seq_len, head_dim) + query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1) + key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1) + value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1) + + full_q_len = query_states.size(2) # full seq length + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once(f"The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in {target_dtype}.") + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + if self.config.use_sliding_window and getattr(self.config, "sliding_window", None) is not None and self.layer_idx >= self.config.max_window_layers: + sliding_window = self.config.sliding_window + else: + sliding_window = None + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + full_q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=sliding_window, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) + + # use full_q_len to reshape + attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous() + ########## AlltoAll for Ulysses ########## + if ulysses_sp_size > 1: + attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2) + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +def qwen2_attn_forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Adapted from transformers 4.49.0 to support Ulysses sequence parallelism for transformers >= 4.48.0. + + NOTE: This function has been tested only on transformers versions between 4.48.0 and 4.50.0. + """ + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + + bsz, q_len, _ = hidden_states.shape + hidden_shape = (bsz, q_len, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + ########## AlltoAll for Ulysses ########## + ulysses_sp_size = get_ulysses_sequence_parallel_world_size() + + if ulysses_sp_size > 1: + validate_ulysses_config(self.config.num_attention_heads, ulysses_sp_size) + + # (bsz, n_head, seq_len/n, head_dim) -> (bsz, n_head/n, seq_len, head_dim) + query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1) + key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1) + value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1) + + full_q_len = query_states.size(2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + sliding_window = None + if self.config.use_sliding_window and getattr(self.config, "sliding_window", None) is not None and self.layer_idx >= self.config.max_window_layers: + sliding_window = self.config.sliding_window + + from transformers.models.qwen2.modeling_qwen2 import eager_attention_forward + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once('`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.') + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=sliding_window, # main diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous() + ########## AlltoAll for Ulysses ########## + if ulysses_sp_size > 1: + # (bsz, seq_len, n_head/n, head_dim) -> (bsz, seq_len/n, n_head, head_dim) + attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2) + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights diff --git a/verl/models/transformers/qwen2_vl.py b/verl/models/transformers/qwen2_vl.py new file mode 100644 index 0000000000000000000000000000000000000000..7e61d86e3eeb160657f0d48774fda9abf575e80c --- /dev/null +++ b/verl/models/transformers/qwen2_vl.py @@ -0,0 +1,287 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import os +from typing import Optional, Tuple + +import torch +from transformers.modeling_flash_attention_utils import _flash_attention_forward +from transformers.utils import is_flash_attn_greater_or_equal + +from verl.utils.ulysses import ( + gather_heads_scatter_seq, + gather_seq_scatter_heads, + get_ulysses_sequence_parallel_world_size, + validate_ulysses_config, +) + +try: + from flash_attn import flash_attn_func, flash_attn_varlen_func + + _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) +except ImportError: + flash_attn_varlen_func = None + + +def get_rope_index( + processor, + input_ids: torch.Tensor, + image_grid_thw: Optional[torch.Tensor] = None, + video_grid_thw: Optional[torch.Tensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + Gets the position ids for Qwen2-VL, it should be generated before sharding the sequence. + The batch dim has been removed and the input_ids should be a 1D tensor representing a single example. + https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1546 + """ + spatial_merge_size = processor.image_processor.merge_size + tokens_per_second = 2 + image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>") + video_token_id = processor.tokenizer.convert_tokens_to_ids("<|video_pad|>") + vision_start_token_id = processor.tokenizer.convert_tokens_to_ids("<|vision_start|>") + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + + position_ids = torch.ones(3, input_ids.size(0), dtype=input_ids.dtype, device=input_ids.device) # (3, seqlen) + image_index, video_index = 0, 0 + input_ids = input_ids[attention_mask == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + second_per_grid_t = 0 + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + second_per_grid_t = second_per_grid_ts[video_index] if second_per_grid_ts is not None else 1.0 + + video_index += 1 + remain_videos -= 1 + ed = ed_video + + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w) + t_index = (t_index * second_per_grid_t * tokens_per_second).long().flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., attention_mask == 1] = llm_positions.to(position_ids.device) + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1).to(input_ids.device) + else: + position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).view(1, -1).expand(3, -1) + + return position_ids + + +def prepare_fa2_from_position_ids(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, position_ids: torch.Tensor): + query = query.view(-1, query.size(-2), query.size(-1)) + key = key.view(-1, key.size(-2), key.size(-1)) + value = value.view(-1, value.size(-2), value.size(-1)) + position_ids = position_ids.flatten() + indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32) + cu_seqlens = torch.cat( + ( + indices_q[position_ids == 0], + torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32), + ) + ) + max_length = cu_seqlens.diff().max() # use cu_seqlens to infer max_length for qwen2vl mrope + return (query, key, value, indices_q, (cu_seqlens, cu_seqlens), (max_length, max_length)) + + +def flash_attention_forward( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: torch.Tensor, + query_length: int, + is_causal: bool = True, + position_ids: Optional[torch.Tensor] = None, + sliding_window: Optional[int] = None, + use_top_left_mask: bool = False, + deterministic: Optional[bool] = None, + **kwargs, +): + """ + Patches flash attention forward to handle 3D position ids in mrope. (3, batch_size, seq_length) + """ + causal = is_causal if not use_top_left_mask else is_causal and query_length != 1 + + # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length). + use_sliding_windows = _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window + flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {} + + if is_flash_attn_greater_or_equal("2.4.1"): + if deterministic is None: + deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" + flash_kwargs["deterministic"] = deterministic + + if position_ids is not None and query_length != 1 and not (torch.diff(position_ids[0], dim=-1) >= 0).all(): + batch_size = query_states.size(0) + query_states, key_states, value_states, _, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(query_states, key_states, value_states, position_ids[0]) # remove channel dimension + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + attn_output = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=kwargs.pop("dropout", 0.0), + softmax_scale=kwargs.pop("softmax_scale", None), + causal=causal, + **flash_kwargs, + ) + attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1)) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + query_length, + is_causal=is_causal, + sliding_window=sliding_window, + use_top_left_mask=use_top_left_mask, + deterministic=deterministic, + **kwargs, + ) # do not pass position_ids to old flash_attention_forward + + return attn_output + + +def ulysses_flash_attn_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, +) -> Tuple[torch.Tensor, None, None]: + from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_multimodal_rotary_pos_emb, repeat_kv + + bsz, q_len, _ = hidden_states.size() # q_len = seq_length / sp_size + query_states = self.q_proj(hidden_states) # (batch_size, seq_length / sp_size, num_heads * head_size) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + ulysses_sp_size = get_ulysses_sequence_parallel_world_size() + + if ulysses_sp_size > 1: + validate_ulysses_config(self.num_heads, ulysses_sp_size) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1) + key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1) + value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1) + # (batch_size, num_head / sp_size, seq_length, head_size) + full_q_len = query_states.size(2) # full_q_len = seq_length + else: + full_q_len = q_len + + # Because the input can be padded, the absolute sequence length depends on the max position id. + if position_embeddings is None: + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + + query_states, key_states = apply_multimodal_rotary_pos_emb(query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + if self.config.use_sliding_window and getattr(self.config, "sliding_window", None) is not None and self.layer_idx >= self.config.max_window_layers: + sliding_window = self.config.sliding_window + else: + sliding_window = None + + attn_output = flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + full_q_len, + dropout=dropout_rate, + sliding_window=sliding_window, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + position_ids=position_ids, # important: pass position ids + ) # (batch_size, seq_length, num_head / sp_size, head_size) + if ulysses_sp_size > 1: + attn_output = gather_heads_scatter_seq(attn_output, head_dim=2, seq_dim=1) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, None, None diff --git a/verl/models/weight_loader_registry.py b/verl/models/weight_loader_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..94b4d442d20e411e814d855b0cdf9b30212f05be --- /dev/null +++ b/verl/models/weight_loader_registry.py @@ -0,0 +1,42 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def get_weight_loader(arch: str): + from verl.models.mcore.loader import load_state_dict_to_megatron_gptmodel + + _MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY = { + "LlamaForCausalLM": load_state_dict_to_megatron_gptmodel, + "Qwen2ForCausalLM": load_state_dict_to_megatron_gptmodel, + } + + if arch in _MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY: + return _MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY[arch] + raise ValueError(f"Model architectures {arch} loader are not supported for now. Supported architectures: {_MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY.keys()}") + + +def get_weight_saver(arch: str): + from verl.models.mcore.saver import merge_megatron_ckpt_gptmodel, merge_megatron_ckpt_gptmodel_mixtral, merge_megatron_ckpt_gptmodel_qwen_moe + + _MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY = { + "LlamaForCausalLM": merge_megatron_ckpt_gptmodel, + "Qwen2ForCausalLM": merge_megatron_ckpt_gptmodel, + "MixtralForCausalLM": merge_megatron_ckpt_gptmodel_mixtral, + "Qwen2MoeForCausalLM": merge_megatron_ckpt_gptmodel_qwen_moe, + "Qwen3ForCausalLM": merge_megatron_ckpt_gptmodel, + "Qwen3MoeForCausalLM": merge_megatron_ckpt_gptmodel_qwen_moe, + } + if arch in _MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY: + return _MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY[arch] + raise ValueError(f"Model architectures {arch} saver are not supported for now. Supported architectures: {_MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY.keys()}") diff --git a/verl/protocol.py b/verl/protocol.py new file mode 100644 index 0000000000000000000000000000000000000000..f6fef3ec8700af08c21252e240b575b8a85d0a21 --- /dev/null +++ b/verl/protocol.py @@ -0,0 +1,811 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Implement base data transfer protocol between any two functions, modules. +We can subclass Protocol to define more detailed batch info with specific keys +""" + +import contextlib +import copy +import logging +import os +import pickle +from dataclasses import dataclass, field +from typing import Callable, Dict, List, Union + +import numpy as np +import pandas as pd +import ray +import tensordict +import torch +import torch.distributed +from packaging import version +from tensordict import TensorDict +from torch.utils.data import DataLoader + +from verl.utils.py_functional import union_two_dict +from verl.utils.torch_functional import allgather_dict_tensors + +__all__ = ["DataProto", "union_tensor_dict"] + +with contextlib.suppress(Exception): + tensordict.set_lazy_legacy(False).set() + + +class _DataProtoConfigMeta(type): + _config = {} + + auto_padding_key = "_verl_auto_padding" + + @property + def auto_padding(cls): + enabled_by_env = os.getenv("VERL_AUTO_PADDING", "FALSE").upper() in ["TRUE", "1"] + return enabled_by_env or cls._config.get(cls.auto_padding_key, False) + + @auto_padding.setter + def auto_padding(cls, enabled: bool): + assert isinstance(enabled, bool), f"enabled must be a boolean, got {enabled} as {type(enabled)}" + cls._config[cls.auto_padding_key] = enabled + + +class DataProtoConfig(metaclass=_DataProtoConfigMeta): + pass + + +_padding_size_key = "_padding_size_key_x123d" + + +def pad_dataproto_to_divisor(data: "DataProto", size_divisor: int): + """Pad a DataProto to size divisible by size_divisor + + Args: + size_divisor (int): size divisor + + Returns: + data: (DataProto): the padded DataProto + pad_size (int) + """ + assert isinstance(data, DataProto), "data must be a DataProto" + if len(data) % size_divisor != 0: + pad_size = size_divisor - len(data) % size_divisor + padding_protos = [] + remaining_pad = pad_size + while remaining_pad > 0: + take_size = min(remaining_pad, len(data)) + padding_protos.append(data[:take_size]) + remaining_pad -= take_size + data_padded = DataProto.concat([data] + padding_protos) + else: + if len(data) == 0: + logging.warning("padding a DataProto with no item, no changed made") + pad_size = 0 + data_padded = data + return data_padded, pad_size + + +def unpad_dataproto(data: "DataProto", pad_size): + if pad_size != 0: + data = data[:-pad_size] + return data + + +def union_tensor_dict(tensor_dict1: TensorDict, tensor_dict2: TensorDict) -> TensorDict: + """Union two tensordicts.""" + assert tensor_dict1.batch_size == tensor_dict2.batch_size, f"Two tensor dict must have identical batch size. Got {tensor_dict1.batch_size} and {tensor_dict2.batch_size}" + for key in tensor_dict2.keys(): + if key not in tensor_dict1.keys(): + tensor_dict1[key] = tensor_dict2[key] + else: + assert tensor_dict1[key].equal(tensor_dict2[key]), f"{key} in tensor_dict1 and tensor_dict2 are not the same object" + + return tensor_dict1 + + +def union_numpy_dict(tensor_dict1: dict[str, np.ndarray], tensor_dict2: dict[str, np.ndarray]) -> dict[str, np.ndarray]: + for key, val in tensor_dict2.items(): + if key in tensor_dict1: + assert isinstance(tensor_dict2[key], np.ndarray) + assert isinstance(tensor_dict1[key], np.ndarray) + # to properly deal with nan and object type + assert pd.DataFrame(tensor_dict2[key]).equals(pd.DataFrame(tensor_dict1[key])), f"{key} in tensor_dict1 and tensor_dict2 are not the same object" + tensor_dict1[key] = val + + return tensor_dict1 + + +def list_of_dict_to_dict_of_list(list_of_dict: list[dict]): + if len(list_of_dict) == 0: + return {} + keys = list_of_dict[0].keys() + output = {key: [] for key in keys} + for data in list_of_dict: + for key, item in data.items(): + assert key in output + output[key].append(item) + return output + + +def fold_batch_dim(data: "DataProto", new_batch_size): + """ + Fold a batch dim from [bsz, xxx] into [new_bsz, bsz // new_bsz, xxx] + """ + batch_size = data.batch.batch_size[0] + + assert batch_size % new_batch_size == 0 + + tensor: TensorDict = data.batch + non_tensor = data.non_tensor_batch + + tensor = tensor.view(new_batch_size, -1) + tensor.auto_batch_size_(batch_dims=1) + + for key, val in non_tensor.items(): + non_tensor[key] = np.reshape(val, newshape=(new_batch_size, -1, *val.shape[1:])) + + return type(data)(batch=tensor, non_tensor_batch=non_tensor, meta_info=data.meta_info) + + +def unfold_batch_dim(data: "DataProto", batch_dims=2): + """ + Unfold the first n dims as new batch dim + """ + tensor: TensorDict = data.batch + non_tensor = data.non_tensor_batch + tensor.auto_batch_size_(batch_dims=batch_dims) + tensor = tensor.view(-1) + + batch_size = tensor.batch_size[0] + + non_tensor_new = {} + + for key, val in non_tensor.items(): + non_tensor_new[key] = np.reshape(val, newshape=(batch_size, *val.shape[batch_dims:])) + + return type(data)(batch=tensor, non_tensor_batch=non_tensor_new, meta_info=data.meta_info) + + +def collate_fn(x: list["DataProtoItem"]): + batch = [] + non_tensor_batch = [] + for data in x: + batch.append(data.batch) + non_tensor_batch.append(data.non_tensor_batch) + batch = torch.stack(batch).contiguous() + non_tensor_batch = list_of_dict_to_dict_of_list(non_tensor_batch) + for key, val in non_tensor_batch.items(): + non_tensor_batch[key] = np.array(val, dtype=object) + return DataProto(batch=batch, non_tensor_batch=non_tensor_batch) + + +@dataclass +class DataProtoItem: + # TODO(zhangchi.usc1992) add consistency check + batch: TensorDict = None + non_tensor_batch: Dict = field(default_factory=dict) + meta_info: Dict = field(default_factory=dict) + + +@dataclass +class DataProto: + """ + A DataProto is a data structure that aims to provide a standard protocol for data exchange between functions. + It contains a batch (TensorDict) and a meta_info (Dict). The batch is a TensorDict https://pytorch.org/tensordict/. + TensorDict allows you to manipulate a dictionary of Tensors like a single Tensor. Ideally, the tensors with the + same batch size should be put inside batch. + """ + + batch: TensorDict = None + non_tensor_batch: Dict = field(default_factory=dict) + meta_info: Dict = field(default_factory=dict) + + def __post_init__(self): + # perform necessary checking + self.check_consistency() + + def __len__(self): + if self.batch is not None: + return self.batch.batch_size[0] + elif self.non_tensor_batch is not None and len(self.non_tensor_batch) > 0: + random_key = list(self.non_tensor_batch.keys())[0] + return self.non_tensor_batch[random_key].shape[0] + else: + return 0 + + def __getitem__(self, item): + """ + Enhanced indexing for DataProto objects. + + Args: + item: Can be one of: + - int: A single index + - slice: A slice object (start:stop:step) + - list: A list of indices + - numpy.ndarray: An array of indices + - torch.Tensor: A tensor of indices + + Returns: + DataProto: For all indexing types except single integers + DataProtoItem: Only for single integer indices + """ + # Case 1: Slice object - use the slice method + if isinstance(item, slice): + return self.slice(item.start, item.stop, item.step) + + # Case 2: List, numpy array, or torch tensor - use sel_idxs + elif isinstance(item, (list, np.ndarray, torch.Tensor)): + return self.select_idxs(item) + + # Case 3: Single integer - return DataProtoItem for backward compatibility + elif isinstance(item, (int, np.integer)): + tensor_data = self.batch[item] + non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()} + return DataProtoItem(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info) + + # # Case 4: Unsupported type + else: + raise TypeError(f"Indexing with {type(item)} is not supported") + + def __getstate__(self): + import io + + buffer = io.BytesIO() + if version.parse(tensordict.__version__) >= version.parse("0.5.0") and self.batch is not None: + self.batch = self.batch.contiguous() + self.batch = self.batch.consolidate() + torch.save(self.batch, buffer) + buffer_bytes = buffer.getvalue() + return buffer_bytes, self.non_tensor_batch, self.meta_info + + def __setstate__(self, data): + import io + + batch_deserialized_bytes, non_tensor_batch, meta_info = data + batch_deserialized = io.BytesIO(initial_bytes=batch_deserialized_bytes) + batch = torch.load(batch_deserialized, weights_only=False, map_location="cpu" if not torch.cuda.is_available() else None) + self.batch = batch + self.non_tensor_batch = non_tensor_batch + self.meta_info = meta_info + + def save_to_disk(self, filepath): + with open(filepath, "wb") as f: + pickle.dump(self, f) + + @staticmethod + def load_from_disk(filepath) -> "DataProto": + with open(filepath, "rb") as f: + data = pickle.load(f) + return data + + def print_size(self, prefix=""): + size_of_tensordict = 0 + for key, tensor in self.batch.items(): + size_of_tensordict += tensor.element_size() * tensor.numel() + size_of_numpy_array = 0 + for key, numpy_array in self.non_tensor_batch.items(): + size_of_numpy_array += numpy_array.nbytes + + size_of_numpy_array /= 1024**3 + size_of_tensordict /= 1024**3 + + message = f"Size of tensordict: {size_of_tensordict} GB, size of non_tensor_batch: {size_of_numpy_array} GB" + + if prefix: + message = f"{prefix}, " + message + print(message) + + def check_consistency(self): + """Check the consistency of the DataProto. Mainly for batch and non_tensor_batch + We expose this function as a public one so that user can call themselves directly + """ + if self.batch is not None: + assert len(self.batch.batch_size) == 1, "only support num_batch_dims=1" + + if self.non_tensor_batch is not None: + for key, val in self.non_tensor_batch.items(): + assert isinstance(val, np.ndarray) + + if self.batch is not None and len(self.non_tensor_batch) != 0: + # TODO: we can actually lift this restriction if needed + assert len(self.batch.batch_size) == 1, "only support num_batch_dims=1 when non_tensor_batch is not empty." + + batch_size = self.batch.batch_size[0] + for key, val in self.non_tensor_batch.items(): + assert isinstance(val, np.ndarray), f"data in the non_tensor_batch must be a numpy.array with dtype=object, but for {key=}, got {type(val)=}" + assert val.shape[0] == batch_size, f"key {key} length {len(val)} is not equal to batch size {batch_size}" + + @classmethod + def from_single_dict(cls, data: Dict[str, Union[torch.Tensor, np.ndarray]], meta_info=None, auto_padding=False): + """Create a DataProto from a dict of tensors and non_tensors""" + tensors = {} + non_tensors = {} + + for key, val in data.items(): + if isinstance(val, torch.Tensor): + tensors[key] = val + elif isinstance(val, np.ndarray): + non_tensors[key] = val + else: + raise ValueError(f"Unsupported type in data {type(val)}") + + return cls.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info, auto_padding=auto_padding) + + @classmethod + def from_dict(cls, tensors: Dict[str, torch.Tensor], non_tensors=None, meta_info=None, num_batch_dims=1, auto_padding=False): + """Create a DataProto from a dict of tensors. This assumes that + 1. All the tensor in tensors have the same dim0 + 2. Only dim0 is the batch dim + """ + assert len(tensors) > 0, "tensors must not be empty" + assert num_batch_dims > 0, "num_batch_dims must be greater than zero" + if non_tensors is not None: + assert num_batch_dims == 1, "only support num_batch_dims=1 when non_tensors is not None." + + if meta_info is None: + meta_info = {} + if non_tensors is None: + non_tensors = {} + + assert isinstance(non_tensors, dict) + + # get and check batch size + batch_size = None + pivot_key = None + for key, tensor in tensors.items(): + if batch_size is None: + batch_size = tensor.shape[:num_batch_dims] + pivot_key = key + else: + current_batch = tensor.shape[:num_batch_dims] + assert batch_size == current_batch, f"Not all the tensor in tensors have the same batch size with batch_dims={num_batch_dims}. Got {pivot_key} has {batch_size}, {key} has {current_batch}" + + for key, val in non_tensors.items(): + non_tensors[key] = np.array(val, dtype=object) + + tensor_dict = TensorDict(source=tensors, batch_size=batch_size) + if auto_padding: + meta_info[DataProtoConfig.auto_padding_key] = True + return cls(batch=tensor_dict, non_tensor_batch=non_tensors, meta_info=meta_info) + + def to(self, device) -> "DataProto": + """move the batch to device + + Args: + device (torch.device, str): torch device + + Returns: + DataProto: the current DataProto + + """ + if self.batch is not None: + self.batch = self.batch.to(device) + return self + + def select(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None, deepcopy=False) -> "DataProto": + """Select a subset of the DataProto via batch_keys and meta_info_keys + + Args: + batch_keys (list, optional): a list of strings indicating the keys in batch to select + meta_info_keys (list, optional): a list of keys indicating the meta info to select + + Returns: + DataProto: the DataProto with the selected batch_keys and meta_info_keys + """ + # TODO (zhangchi.usc1992) whether to copy + if batch_keys is not None: + batch_keys = tuple(batch_keys) + sub_batch = self.batch.select(*batch_keys) + else: + sub_batch = self.batch + + if non_tensor_batch_keys is not None: + non_tensor_batch = {key: val for key, val in self.non_tensor_batch.items() if key in non_tensor_batch_keys} + else: + non_tensor_batch = self.non_tensor_batch + + if deepcopy: + non_tensor_batch = copy.deepcopy(non_tensor_batch) + + if meta_info_keys is not None: + sub_meta_info = {key: val for key, val in self.meta_info.items() if key in meta_info_keys} + else: + sub_meta_info = self.meta_info + + if deepcopy: + sub_meta_info = copy.deepcopy(sub_meta_info) + + return type(self)(batch=sub_batch, non_tensor_batch=non_tensor_batch, meta_info=sub_meta_info) + + def select_idxs(self, idxs): + """ + Select specific indices from the DataProto. + + Args: + idxs (torch.Tensor or numpy.ndarray or list): Indices to select + + Returns: + DataProto: A new DataProto containing only the selected indices + """ + if isinstance(idxs, list): + idxs = torch.tensor(idxs) + if idxs.dtype != torch.bool: + idxs = idxs.type(torch.int32) + + if isinstance(idxs, np.ndarray): + idxs_np = idxs + idxs_torch = torch.from_numpy(idxs) + else: # torch.Tensor + idxs_torch = idxs + idxs_np = idxs.detach().cpu().numpy() + + batch_size = idxs_np.sum() if idxs_np.dtype == bool else idxs_np.shape[0] + + if self.batch is not None: + # Use TensorDict's built-in indexing capabilities + selected_batch = TensorDict(source={key: tensor[idxs_torch] for key, tensor in self.batch.items()}, batch_size=(batch_size,)) + else: + selected_batch = None + + selected_non_tensor = {} + for key, val in self.non_tensor_batch.items(): + selected_non_tensor[key] = val[idxs_np] + + return type(self)(batch=selected_batch, non_tensor_batch=selected_non_tensor, meta_info=self.meta_info) + + def slice(self, start=None, end=None, step=None): + """ + Slice the DataProto and return a new DataProto object. + This is an improved version of direct slicing which returns a DataProtoItem. + + Args: + start (int, optional): Start index. Defaults to None (start from beginning). + end (int, optional): End index (exclusive). Defaults to None (go to end). + step (int, optional): Step size. Defaults to None (step=1). + + Returns: + DataProto: A new DataProto containing the sliced data + + Examples: + # Using the slice method directly + sliced_data = data_proto.slice(10, 20) + + # Using enhanced indexing (returns DataProto) + sliced_data = data_proto[10:20] + sliced_data = data_proto[::2] # Every other element + + # Using list indexing (returns DataProto) + indices = [1, 5, 10] + selected_data = data_proto[indices] + + # Single index still returns DataProtoItem + single_item = data_proto[5] + """ + # Create a slice object + slice_obj = slice(start, end, step) + + # Handle the batch data + if self.batch is not None: + # Use TensorDict's built-in slicing capabilities + sliced_batch = self.batch[slice_obj] + else: + sliced_batch = None + + # Handle the non-tensor batch data + sliced_non_tensor = {} + for key, val in self.non_tensor_batch.items(): + sliced_non_tensor[key] = val[slice_obj] + + # Return a new DataProto object + return type(self)(batch=sliced_batch, non_tensor_batch=sliced_non_tensor, meta_info=self.meta_info) + + def pop(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None) -> "DataProto": + """Pop a subset of the DataProto via `batch_keys` and `meta_info_keys` + + Args: + batch_keys (list, optional): a list of strings indicating the keys in batch to pop + meta_info_keys (list, optional): a list of keys indicating the meta info to pop + + Returns: + DataProto: the DataProto with the poped batch_keys and meta_info_keys + """ + assert batch_keys is not None + if meta_info_keys is None: + meta_info_keys = [] + if non_tensor_batch_keys is None: + non_tensor_batch_keys = [] + + tensors = {} + # tensor batch + for key in batch_keys: + assert key in self.batch.keys() + tensors[key] = self.batch.pop(key) + non_tensors = {} + # non tensor batch + for key in non_tensor_batch_keys: + assert key in self.non_tensor_batch.keys() + non_tensors[key] = self.non_tensor_batch.pop(key) + meta_info = {} + for key in meta_info_keys: + assert key in self.meta_info.keys() + meta_info[key] = self.meta_info.pop(key) + return DataProto.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info) + + def rename(self, old_keys=None, new_keys=None) -> "DataProto": + """ + Note that this function only rename the key in the batch + """ + + def validate_input(keys): + if keys is not None: + if isinstance(keys, str): + keys = [keys] + elif isinstance(keys, list): + pass + else: + raise TypeError(f"keys must be a list or a string, but got {type(keys)}") + return keys + + old_keys = validate_input(old_keys) + new_keys = validate_input(new_keys) + + if len(new_keys) != len(old_keys): + raise ValueError(f"new_keys and old_keys must have the same length, but got {len(new_keys)} and {len(old_keys)}") + + self.batch.rename_key_(tuple(old_keys), tuple(new_keys)) + + return self + + def union(self, other: "DataProto") -> "DataProto": + """Union with another DataProto. Union batch and meta_info separately. + Throw an error if + + - there are conflict keys in batch and they are not equal + - the batch size of two data batch is not the same + - there are conflict keys in meta_info and they are not the same. + + Args: + other (DataProto): another DataProto to union + + Returns: + DataProto: the DataProto after union + """ + self.batch = union_tensor_dict(self.batch, other.batch) + self.non_tensor_batch = union_numpy_dict(self.non_tensor_batch, other.non_tensor_batch) + self.meta_info = union_two_dict(self.meta_info, other.meta_info) + return self + + def make_iterator(self, mini_batch_size, epochs, seed=None, dataloader_kwargs=None): + r"""Make an iterator from the DataProto. This is built upon that TensorDict can be used as a normal Pytorch + dataset. See https://pytorch.org/tensordict/tutorials/data_fashion for more details. + + + Args: + mini_batch_size (int): mini-batch size when iterating the dataset. We require that ``batch.batch_size[0] % mini_batch_size == 0``. + epochs (int): number of epochs when iterating the dataset. + dataloader_kwargs (Any): internally, it returns a DataLoader over the batch. The dataloader_kwargs is the kwargs passed to the DataLoader. + + Returns: + Iterator: an iterator that yields a mini-batch data at a time. The total number of iteration steps is ``self.batch.batch_size * epochs // mini_batch_size`` + """ + assert self.batch.batch_size[0] % mini_batch_size == 0, f"{self.batch.batch_size[0]} % {mini_batch_size} != 0" + # we can directly create a dataloader from TensorDict + if dataloader_kwargs is None: + dataloader_kwargs = {} + + if seed is not None: + generator = torch.Generator() + generator.manual_seed(seed) + else: + generator = None + + assert isinstance(dataloader_kwargs, Dict) + train_dataloader = DataLoader(dataset=self, batch_size=mini_batch_size, collate_fn=collate_fn, generator=generator, **dataloader_kwargs) + + def get_data(): + for _ in range(epochs): + for d in train_dataloader: + d.meta_info = self.meta_info + yield d + + return iter(get_data()) + + def is_padding_enabled(self): + """ + Check if padding is enabled for the DataProto. + Returns: + bool: True if padding is enabled, False otherwise. + """ + dataproto_specific_padding = self.meta_info.get(DataProtoConfig.auto_padding_key, False) + return dataproto_specific_padding or DataProtoConfig.auto_padding + + def padding(self, padding_size, padding_candidate=""): + """Pad the DataProto by concating with padding_candidate.repeat(padding_size) + + Args: + padding_size (int): the number of repeated padding_candidate + padding_candidate: the item to be repeated and appended to the DataProto, only supporting ["first", "last"] + """ + if padding_size == 0: + return + padding_candidate = self.select_idxs([0 if padding_candidate == "first" else len(self) - 1]) + padding_part = padding_candidate.repeat(padding_size) + padded_dp = DataProto.concat([self, padding_part]) + self.batch = padded_dp.batch + self.non_tensor_batch = padded_dp.non_tensor_batch + + def chunk(self, chunks: int) -> List["DataProto"]: + """Split the batch among dim=0 into chunks. The meta_info is passed to each DataProto after split. + + Args: + chunks (int): the number of chunks to split on dim=0 + + Returns: + List[DataProto]: a list of DataProto after splitting + """ + if not self.is_padding_enabled(): + assert len(self) % chunks == 0, f"only support equal chunk. Got size of DataProto {len(self)} and chunk {chunks}." + + bsz_in_batch = None + if self.batch is not None: + batch_lst = self.batch.chunk(chunks=chunks, dim=0) + bsz_in_batch = np.array([batch.batch_size[0] for batch in batch_lst]) + chunk_indices = np.cumsum(bsz_in_batch)[:-1] + else: + batch_lst = [None for _ in range(chunks)] + + non_tensor_batch_lst = [{} for _ in range(chunks)] + for key, val in self.non_tensor_batch.items(): + assert isinstance(val, np.ndarray) + if bsz_in_batch is not None: + non_tensor_lst = np.array_split(val, chunk_indices.tolist()) + else: + non_tensor_lst = np.array_split(val, chunks) + assert len(non_tensor_lst) == chunks + for i in range(chunks): + non_tensor_batch_lst[i][key] = non_tensor_lst[i] + + output = [] + for i in range(chunks): + output.append(type(self)(batch=batch_lst[i], non_tensor_batch=non_tensor_batch_lst[i], meta_info=self.meta_info)) + + return output + + @staticmethod + def concat(data: List["DataProto"]) -> "DataProto": + """Concat a list of DataProto. The batch is concatenated among dim=0. + The meta_info is assumed to be identical and will use the first one. + + Args: + data (List[DataProto]): list of DataProto + + Returns: + DataProto: concatenated DataProto + """ + batch_lst = [] + for batch in data: + batch_lst.append(batch.batch) + new_batch = torch.cat(batch_lst, dim=0) if batch_lst[0] is not None else None + + non_tensor_batch = list_of_dict_to_dict_of_list(list_of_dict=[d.non_tensor_batch for d in data]) + for key, val in non_tensor_batch.items(): + non_tensor_batch[key] = np.concatenate(val, axis=0) + + cls = type(data[0]) if len(data) > 0 else DataProto + return cls(batch=new_batch, non_tensor_batch=non_tensor_batch, meta_info=data[0].meta_info) + + def reorder(self, indices): + """ + Note that this operation is in-place + """ + indices_np = indices.detach().numpy() + self.batch = self.batch[indices] + self.non_tensor_batch = {key: val[indices_np] for key, val in self.non_tensor_batch.items()} + + def repeat(self, repeat_times=2, interleave=True): + """ + Repeat the batch data a specified number of times. + + Args: + repeat_times (int): Number of times to repeat the data. + interleave (bool): Whether to interleave the repeated data. + + Returns: + DataProto: A new DataProto with repeated data. + """ + if self.batch is not None: + if interleave: + # Interleave the data + repeated_tensors = {key: tensor.repeat_interleave(repeat_times, dim=0) for key, tensor in self.batch.items()} + else: + # Stack the data + repeated_tensors = {key: tensor.unsqueeze(0).expand(repeat_times, *tensor.shape).reshape(-1, *tensor.shape[1:]) for key, tensor in self.batch.items()} + + repeated_batch = TensorDict( + source=repeated_tensors, + batch_size=(self.batch.batch_size[0] * repeat_times,), + ) + else: + repeated_batch = None + + repeated_non_tensor_batch = {} + for key, val in self.non_tensor_batch.items(): + if interleave: + repeated_non_tensor_batch[key] = np.repeat(val, repeat_times, axis=0) + else: + repeated_non_tensor_batch[key] = np.tile(val, (repeat_times,) + (1,) * (val.ndim - 1)) + + return type(self)( + batch=repeated_batch, + non_tensor_batch=repeated_non_tensor_batch, + meta_info=self.meta_info, + ) + + +@dataclass +class DataProtoFuture: + """ + DataProtoFuture aims to eliminate actual data fetching on driver. By doing so, the driver doesn't have to wait + for data so that asynchronous execution becomes possible. + DataProtoFuture contains a list of futures from another WorkerGroup of size world_size. + - collect_fn is a Callable that reduces the list of futures to a DataProto + - dispatch_fn is a Callable that partitions the DataProto into a list of DataProto of size world_size and then select + + Potential issue: we can optimize dispatch_fn(collect_fn) such that only needed data is fetched on destination + - DataProtoFuture only supports directly passing from the output of a method to another input. You can't perform any + operation on the DataProtoFuture in driver. + """ + + collect_fn: Callable + futures: List[ray.ObjectRef] + dispatch_fn: Callable = None + + @staticmethod + def concat(data: List[ray.ObjectRef]) -> "DataProtoFuture": + output = DataProtoFuture(collect_fn=DataProto.concat, futures=data) + return output + + def chunk(self, chunks: int) -> List["DataProtoFuture"]: + from functools import partial + + arg_future_lst = [] + for i in range(chunks): + # note that we can't directly pass i and chunks + def dispatch_fn(x, i, chunks): + return x.chunk(chunks=chunks)[i] + + arg_future = DataProtoFuture(collect_fn=self.collect_fn, dispatch_fn=partial(dispatch_fn, i=i, chunks=chunks), futures=self.futures) + arg_future_lst.append(arg_future) + return arg_future_lst + + def get(self): + output = ray.get(self.futures) # dp_size. + for o in output: + assert isinstance(o, DataProto) + output = self.collect_fn(output) # select dp, concat + if self.dispatch_fn is not None: + output = self.dispatch_fn(output) # split in batch dim, select using dp + return output + + +def all_gather_data_proto(data: DataProto, process_group): + # Note that this is an inplace operator just like torch.distributed.all_gather + group_size = torch.distributed.get_world_size(group=process_group) + assert isinstance(data, DataProto) + prev_device = data.batch.device + data.batch = data.batch.cuda(device=torch.cuda.current_device()) + data.batch = allgather_dict_tensors(data.batch.contiguous(), size=group_size, group=process_group, dim=0) + data.batch = data.batch.to(prev_device) + # all gather non_tensor_batch + all_non_tensor_batch = [None for _ in range(group_size)] + torch.distributed.all_gather_object(all_non_tensor_batch, data.non_tensor_batch, group=process_group) + data.non_tensor_batch = {k: np.concatenate([d[k] for d in all_non_tensor_batch]) for k in data.non_tensor_batch} diff --git a/verl/single_controller/__init__.py b/verl/single_controller/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1129c7147b3e6482f78bea7e2896dd0ddfb18a02 --- /dev/null +++ b/verl/single_controller/__init__.py @@ -0,0 +1,26 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +from . import base +from .base import * + +version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__))) + +# Note(haibin.lin): single_controller.__version__ is deprecated +with open(os.path.join(os.path.join(version_folder, os.pardir), "version/version")) as f: + __version__ = f.read().strip() + + +__all__ = base.__all__ diff --git a/verl/single_controller/__pycache__/__init__.cpython-311.pyc b/verl/single_controller/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6154ffb0bee833defe2ab704039090f031a7eb21 Binary files /dev/null and b/verl/single_controller/__pycache__/__init__.cpython-311.pyc differ diff --git a/verl/single_controller/base/__init__.py b/verl/single_controller/base/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e8ae40b349d94b4900b0280bf59558ccbd67a479 --- /dev/null +++ b/verl/single_controller/base/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .worker import Worker +from .worker_group import ClassWithInitArgs, ResourcePool, WorkerGroup + +__all__ = ["Worker", "WorkerGroup", "ClassWithInitArgs", "ResourcePool"] diff --git a/verl/single_controller/base/__pycache__/__init__.cpython-311.pyc b/verl/single_controller/base/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4cf98529d2be806dbcc35ce492b96a3a5b4a2c98 Binary files /dev/null and b/verl/single_controller/base/__pycache__/__init__.cpython-311.pyc differ diff --git a/verl/single_controller/base/__pycache__/decorator.cpython-311.pyc b/verl/single_controller/base/__pycache__/decorator.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..56e466a269c991a13168b0578181ca76bf6f7e39 Binary files /dev/null and b/verl/single_controller/base/__pycache__/decorator.cpython-311.pyc differ diff --git a/verl/single_controller/base/__pycache__/worker.cpython-311.pyc b/verl/single_controller/base/__pycache__/worker.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..203cca9419e1c86798bb1d5b598c69e65ed61e0f Binary files /dev/null and b/verl/single_controller/base/__pycache__/worker.cpython-311.pyc differ diff --git a/verl/single_controller/base/__pycache__/worker_group.cpython-311.pyc b/verl/single_controller/base/__pycache__/worker_group.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2682aa21579aa04c753f9815718d1bc5016374da Binary files /dev/null and b/verl/single_controller/base/__pycache__/worker_group.cpython-311.pyc differ diff --git a/verl/single_controller/base/decorator.py b/verl/single_controller/base/decorator.py new file mode 100644 index 0000000000000000000000000000000000000000..0216e369afca420e674bc8d93443d3af62d18c80 --- /dev/null +++ b/verl/single_controller/base/decorator.py @@ -0,0 +1,514 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from functools import wraps +from types import FunctionType +from typing import Dict, List, Tuple + +from verl.protocol import DataProtoFuture, _padding_size_key +from verl.utils.py_functional import DynamicEnum + +# here we add a magic number of avoid user-defined function already have this attribute +MAGIC_ATTR = "attrs_3141562937" + + +class Dispatch(DynamicEnum): + _registry = {} + _next_value = 0 + + +def init_predefined_dispatch_mode(): + Dispatch.register("RANK_ZERO") + Dispatch.register("ONE_TO_ALL") + Dispatch.register("ALL_TO_ALL") + Dispatch.register("MEGATRON_COMPUTE") + Dispatch.register("MEGATRON_PP_AS_DP") + Dispatch.register("MEGATRON_PP_ONLY") + Dispatch.register("MEGATRON_COMPUTE_PROTO") + Dispatch.register("MEGATRON_PP_AS_DP_PROTO") + Dispatch.register("DP_COMPUTE") + Dispatch.register("DP_COMPUTE_PROTO") + Dispatch.register("DP_COMPUTE_PROTO_WITH_FUNC") + Dispatch.register("DP_COMPUTE_METRIC") + # This is a special dispatch mode for vllm ExternalRayDistributedExecutor + Dispatch.register("DIRECT_ROLLOUT_METHOD") + + +class Execute(DynamicEnum): + _registry = {} + _next_value = 0 + + +def init_predefined_execute_mode(): + Execute.register("ALL") + Execute.register("RANK_ZERO") + + +# Initialize the two Dynamic Enum Classes +init_predefined_dispatch_mode() +init_predefined_execute_mode() + + +def _split_args_kwargs_data_proto(chunks, *args, **kwargs): + from verl.protocol import DataProto, DataProtoFuture + + splitted_args = [] + for arg in args: + assert isinstance(arg, (DataProto, DataProtoFuture)) + splitted_args.append(arg.chunk(chunks=chunks)) + + splitted_kwargs = {} + for key, val in kwargs.items(): + assert isinstance(val, (DataProto, DataProtoFuture)) + splitted_kwargs[key] = val.chunk(chunks=chunks) + + return splitted_args, splitted_kwargs + + +def _split_args_kwargs_data_proto_with_auto_padding(chunks, *args, **kwargs): + from verl.protocol import DataProto, DataProtoFuture + + splitted_args = [] + splitted_kwargs = {} + + data_proto_len = None + padding_size = None + for arg in args: + assert isinstance(arg, (DataProto, DataProtoFuture)) + if isinstance(arg, DataProto) and arg.is_padding_enabled(): + # for padding, we only support DataProto with same length + if data_proto_len is None: + data_proto_len = len(arg) + padding_size = (chunks - (data_proto_len % chunks)) if (data_proto_len % chunks > 0) else 0 + splitted_kwargs[_padding_size_key] = padding_size + else: + assert data_proto_len == len(arg), f"expecting all arg share same length of {data_proto_len}, but got {len(arg)}" + data_proto_len = len(arg) + arg.padding(padding_size=padding_size) + + splitted_args.append(arg.chunk(chunks=chunks)) + + for key, val in kwargs.items(): + assert isinstance(val, (DataProto, DataProtoFuture)) + if isinstance(val, DataProto) and val.is_padding_enabled(): + # for padding, we only support DataProto with same length + if data_proto_len is None: + data_proto_len = len(val) + padding_size = chunks - (data_proto_len % chunks) + splitted_kwargs[_padding_size_key] = padding_size + else: + assert data_proto_len == len(val), f"expecting all arg share same length of {data_proto_len}, but got {len(val)}" + data_proto_len = len(val) + splitted_kwargs[key] = val.chunk(chunks=chunks) + + return splitted_args, splitted_kwargs + + +def dispatch_one_to_all(worker_group, *args, **kwargs): + args = tuple([arg] * worker_group.world_size for arg in args) + kwargs = {k: [v] * worker_group.world_size for k, v in kwargs.items()} + return args, kwargs + + +def dummy_direct_rollout_call(worker_group, *args, **kwargs): + raise NotImplementedError("Direct rollout call is forbidden.") + + +def dispatch_all_to_all(worker_group, *args, **kwargs): + return args, kwargs + + +def collect_all_to_all(worker_group, output): + return output + + +def dispatch_megatron_compute(worker_group, *args, **kwargs): + """ + User passes in dp data. The data is dispatched to all tp/pp ranks with the same dp + """ + from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup + + assert isinstance(worker_group, MegatronWorkerGroup), f"worker_group must be MegatronWorkerGroup, Got {type(worker_group)}" + + all_args = [] + for arg in args: + assert isinstance(arg, (Tuple, List)) and len(arg) == worker_group.dp_size + transformed_args = [] + for i in range(worker_group.world_size): + local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank + transformed_args.append(arg[local_dp_rank]) + all_args.append(transformed_args) + all_args = tuple(all_args) + + all_kwargs = {} + for k, v in kwargs.items(): + assert isinstance(v, (Tuple, List)) and len(v) == worker_group.dp_size + transformed_v = [] + for i in range(worker_group.world_size): + local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank + transformed_v.append(v[local_dp_rank]) + all_kwargs[k] = transformed_v + return all_args, all_kwargs + + +def collect_megatron_compute(worker_group, output): + """ + Only collect the data from the tp=0 and pp=last and every dp ranks + """ + from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup + + assert isinstance(worker_group, MegatronWorkerGroup) + output_in_dp = [] + pp_size = worker_group.get_megatron_global_info().pp_size + for global_rank in range(worker_group.world_size): + local_rank_info = worker_group.get_megatron_rank_info(rank=global_rank) + if local_rank_info.tp_rank == 0 and local_rank_info.pp_rank == pp_size - 1 and local_rank_info.cp_rank == 0: + output_in_dp.append(output[global_rank]) + return output_in_dp + + +def dispatch_megatron_compute_data_proto(worker_group, *args, **kwargs): + """ + All the args and kwargs must be DataProto. The batch will be chunked by dp_size and passed to each rank + """ + from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup + + assert isinstance(worker_group, MegatronWorkerGroup) + + splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.dp_size, *args, **kwargs) + return dispatch_megatron_compute(worker_group, *splitted_args, **splitted_kwargs) + + +def _concat_data_proto_or_future(output: List): + import ray + + from verl.protocol import DataProto, DataProtoFuture + + # make sure all the elements in output has the same type + for o in output: + assert type(o) is type(output[0]) + + o = output[0] + + if isinstance(o, DataProto): + return DataProto.concat(output) + elif isinstance(o, ray.ObjectRef): + return DataProtoFuture.concat(output) + else: + raise NotImplementedError + + +def collect_megatron_compute_data_proto(worker_group, output): + """ + Each output must be a DataProto. We concat the dim=0 of output + """ + import ray + + from verl.protocol import DataProto + + output = collect_megatron_compute(worker_group, output) + for o in output: + assert isinstance(o, (DataProto, ray.ObjectRef)), f"expecting {o} to be DataProto, but got {type(o)}" + + return _concat_data_proto_or_future(output) + + +def dispatch_megatron_pp_as_dp(worker_group, *args, **kwargs): + """ + treat pp as dp. + """ + from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup + + assert isinstance(worker_group, MegatronWorkerGroup) + + pp_size = worker_group.pp_size + dp_size = worker_group.dp_size + cp_size = worker_group.cp_size + pp_dp_cp_size = pp_size * dp_size * cp_size + + all_args = [] + for arg in args: + assert isinstance(arg, (List, Tuple)) and len(arg) == pp_dp_cp_size + transformed_args = [] + for i in range(worker_group.world_size): + local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank + local_pp_rank = worker_group.get_megatron_rank_info(rank=i).pp_rank + local_cp_rank = worker_group.get_megatron_rank_info(rank=i).cp_rank + # compute the rank in arg. Note that the order is dp then cp then pp + # Also note that the outputs within a pp group will be firstly allgathered, then only the output of pp0 will be collected. + # For pp=2 dp=4, a batch of data "ABCDEFGH" should be dispatched and collected in below order: + # dispatch: pp_allgther: collect: + # dp 0 1 2 3 dp 0 1 2 3 + # pp +---------+ pp +-------------+ + # 0 | A C E G | 0 | AB CD EF GH | ABCDEFGH + # 1 | B D F H | 1 | AB CD EF GH | + # +---------+ +-------------+ + dp_cp_rank = local_cp_rank * dp_size + local_dp_rank + arg_rank = dp_cp_rank * pp_size + local_pp_rank + + transformed_args.append(arg[arg_rank]) + all_args.append(transformed_args) + all_args = tuple(all_args) + + all_kwargs = {} + for k, v in kwargs.items(): + assert isinstance(v, (List, Tuple)) and len(v) == pp_dp_cp_size, f"expect len(v)=={pp_dp_cp_size}, got {len(v)}" + transformed_v = [] + for i in range(worker_group.world_size): + local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank + local_pp_rank = worker_group.get_megatron_rank_info(rank=i).pp_rank + local_cp_rank = worker_group.get_megatron_rank_info(rank=i).cp_rank + # compute the rank in arg. Note that the order is dp then cp then pp + dp_cp_rank = local_cp_rank * dp_size + local_dp_rank + arg_rank = dp_cp_rank * pp_size + local_pp_rank + transformed_v.append(v[arg_rank]) + all_kwargs[k] = transformed_v + return all_args, all_kwargs + + +def collect_megatron_pp_as_dp(worker_group, output): + """ + treat pp as dp. Only collect data on tp=0 + """ + from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup + + assert isinstance(worker_group, MegatronWorkerGroup) + output_in_dp = [] + for global_rank in range(worker_group.world_size): + local_rank_info = worker_group.get_megatron_rank_info(rank=global_rank) + if local_rank_info.tp_rank == 0: + output_in_dp.append(output[global_rank]) + return output_in_dp + + +def collect_megatron_pp_only(worker_group, output): + """ + Only collect output of megatron pp. This is useful when examine weight names as they are identical in tp/dp + """ + from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup + + assert isinstance(worker_group, MegatronWorkerGroup) + output_in_pp = [] + for global_rank in range(worker_group.world_size): + local_rank_info = worker_group.get_megatron_rank_info(rank=global_rank) + if local_rank_info.tp_rank == 0 and local_rank_info.dp_rank == 0: + output_in_pp.append(output[global_rank]) + return output_in_pp + + +def dispatch_megatron_pp_as_dp_data_proto(worker_group, *args, **kwargs): + from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup + + assert isinstance(worker_group, MegatronWorkerGroup) + + pp_dp_cp_size = worker_group.dp_size * worker_group.pp_size * worker_group.cp_size + splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(pp_dp_cp_size, *args, **kwargs) + ret = dispatch_megatron_pp_as_dp(worker_group, *splitted_args, **splitted_kwargs) + return ret + + +def collect_megatron_pp_as_dp_data_proto(worker_group, output): + from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup + + assert isinstance(worker_group, MegatronWorkerGroup) + + output = collect_megatron_pp_as_dp(worker_group, output) + return _concat_data_proto_or_future(output) + + +def dispatch_dp_compute(worker_group, *args, **kwargs): + from verl.single_controller.base.worker_group import WorkerGroup + + assert isinstance(worker_group, WorkerGroup) + for arg in args: + assert isinstance(arg, (Tuple, List)) and len(arg) == worker_group.world_size + for k, v in kwargs.items(): + assert isinstance(v, (Tuple, List)) and len(v) == worker_group.world_size + return args, kwargs + + +def collect_dp_compute(worker_group, output): + from verl.single_controller.base.worker_group import WorkerGroup + + assert isinstance(worker_group, WorkerGroup) + assert len(output) == worker_group.world_size + return output + + +def dispatch_dp_compute_data_proto(worker_group, *args, **kwargs): + from verl.single_controller.base.worker_group import WorkerGroup + + assert isinstance(worker_group, WorkerGroup) + # Note: enable auto padding for dp compute DatapProto + splitted_args, splitted_kwargs = _split_args_kwargs_data_proto_with_auto_padding( + worker_group.world_size, + *args, + **kwargs, + ) + return splitted_args, splitted_kwargs + + +def dispatch_dp_compute_data_proto_with_func(worker_group, *args, **kwargs): + from verl.single_controller.base.worker_group import WorkerGroup + + assert isinstance(worker_group, WorkerGroup) + assert isinstance(args[0], FunctionType) # NOTE: The first one args is a function! + + splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args[1:], **kwargs) + splitted_args_with_func = [[args[0]] * worker_group.world_size] + splitted_args + return splitted_args_with_func, splitted_kwargs + + +def collect_dp_compute_data_proto(worker_group, output): + import ray + + from verl.protocol import DataProto + + for o in output: + assert isinstance(o, (DataProto, ray.ObjectRef)), f"expecting {o} to be DataProto, but got {type(o)}" + + output = collect_dp_compute(worker_group, output) + return _concat_data_proto_or_future(output) + + +# Global registry for dispatch mode. +DISPATCH_MODE_FN_REGISTRY = { + Dispatch.ONE_TO_ALL: { + "dispatch_fn": dispatch_one_to_all, + "collect_fn": collect_all_to_all, + }, + Dispatch.ALL_TO_ALL: { + "dispatch_fn": dispatch_all_to_all, + "collect_fn": collect_all_to_all, + }, + Dispatch.MEGATRON_COMPUTE: { + "dispatch_fn": dispatch_megatron_compute, + "collect_fn": collect_megatron_compute, + }, + Dispatch.MEGATRON_PP_AS_DP: { + "dispatch_fn": dispatch_megatron_pp_as_dp, + "collect_fn": collect_megatron_pp_as_dp, + }, + Dispatch.MEGATRON_PP_ONLY: {"dispatch_fn": dispatch_one_to_all, "collect_fn": collect_megatron_pp_only}, + Dispatch.MEGATRON_COMPUTE_PROTO: { + "dispatch_fn": dispatch_megatron_compute_data_proto, + "collect_fn": collect_megatron_compute_data_proto, + }, + Dispatch.MEGATRON_PP_AS_DP_PROTO: { + "dispatch_fn": dispatch_megatron_pp_as_dp_data_proto, + "collect_fn": collect_megatron_pp_as_dp_data_proto, + }, + Dispatch.DP_COMPUTE: {"dispatch_fn": dispatch_dp_compute, "collect_fn": collect_dp_compute}, + Dispatch.DP_COMPUTE_PROTO: { + "dispatch_fn": dispatch_dp_compute_data_proto, + "collect_fn": collect_dp_compute_data_proto, + }, + Dispatch.DP_COMPUTE_PROTO_WITH_FUNC: { + "dispatch_fn": dispatch_dp_compute_data_proto_with_func, + "collect_fn": collect_dp_compute_data_proto, + }, + Dispatch.DP_COMPUTE_METRIC: {"dispatch_fn": dispatch_dp_compute_data_proto, "collect_fn": collect_dp_compute}, + Dispatch.DIRECT_ROLLOUT_METHOD: { + "dispatch_fn": dummy_direct_rollout_call, + "collect_fn": dummy_direct_rollout_call, + }, +} + + +def get_predefined_dispatch_fn(dispatch_mode): + return DISPATCH_MODE_FN_REGISTRY[dispatch_mode] + + +def register_dispatch_mode(dispatch_mode_name, dispatch_fn, collect_fn): + """ + Register a new dispatch mode. + """ + dispatch_mode = Dispatch.register(dispatch_mode_name) + _check_dispatch_mode(dispatch_mode) + assert dispatch_mode not in DISPATCH_MODE_FN_REGISTRY, f"dispatch_mode_name {dispatch_mode_name} already exists" + DISPATCH_MODE_FN_REGISTRY[dispatch_mode] = {"dispatch_fn": dispatch_fn, "collect_fn": collect_fn} + + +def update_dispatch_mode(dispatch_mode, dispatch_fn, collect_fn): + """ + Update the dispatch mode. + """ + _check_dispatch_mode(dispatch_mode) + assert dispatch_mode in DISPATCH_MODE_FN_REGISTRY, f"dispatch_mode {dispatch_mode} not found" + DISPATCH_MODE_FN_REGISTRY[dispatch_mode] = {"dispatch_fn": dispatch_fn, "collect_fn": collect_fn} + + +def get_predefined_execute_fn(execute_mode): + """ + Note that here we only asks execute_all and execute_rank_zero to be implemented + Leave the choice of how these two functions handle argument 'blocking' to users + """ + predefined_execute_mode_fn = { + Execute.ALL: {"execute_fn_name": "execute_all"}, + Execute.RANK_ZERO: {"execute_fn_name": "execute_rank_zero"}, + } + return predefined_execute_mode_fn[execute_mode] + + +def _check_dispatch_mode(dispatch_mode): + assert isinstance(dispatch_mode, (Dispatch, Dict)), f"dispatch_mode must be a Dispatch or a Dict. Got {dispatch_mode}" + if isinstance(dispatch_mode, Dict): + necessary_keys = ["dispatch_fn", "collect_fn"] + for key in necessary_keys: + assert key in dispatch_mode, f"key {key} should be in dispatch_mode if it is a dictionary" + + +def _check_execute_mode(execute_mode): + assert isinstance(execute_mode, Execute), f"execute_mode must be a Execute. Got {execute_mode}" + + +def _materialize_futures(*args, **kwargs): + new_args = [] + for arg in args: + if isinstance(arg, DataProtoFuture): + arg = arg.get() + # add more type to materialize + new_args.append(arg) + for k, v in kwargs.items(): + if isinstance(v, DataProtoFuture): + kwargs[k] = v.get() + + new_args = tuple(new_args) + return new_args, kwargs + + +def register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.ALL, blocking=True, materialize_futures=True): + _check_dispatch_mode(dispatch_mode=dispatch_mode) + _check_execute_mode(execute_mode=execute_mode) + + def decorator(func): + @wraps(func) + def inner(*args, **kwargs): + if materialize_futures: + args, kwargs = _materialize_futures(*args, **kwargs) + return func(*args, **kwargs) + + @wraps(func) + async def async_inner(*args, **kwargs): + if materialize_futures: + args, kwargs = _materialize_futures(*args, **kwargs) + return await func(*args, **kwargs) + + wrapper = async_inner if inspect.iscoroutinefunction(func) else inner + attrs = {"dispatch_mode": dispatch_mode, "execute_mode": execute_mode, "blocking": blocking} + setattr(wrapper, MAGIC_ATTR, attrs) + return wrapper + + return decorator diff --git a/verl/single_controller/base/megatron/__init__.py b/verl/single_controller/base/megatron/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c6aedde1437eddbe099b730519ef79980e2e3dc4 --- /dev/null +++ b/verl/single_controller/base/megatron/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/verl/single_controller/base/megatron/worker.py b/verl/single_controller/base/megatron/worker.py new file mode 100644 index 0000000000000000000000000000000000000000..59641de5fb50b36cff20e997f97ccd74ef60da5f --- /dev/null +++ b/verl/single_controller/base/megatron/worker.py @@ -0,0 +1,83 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from verl.single_controller.base.worker import DistGlobalInfo, DistRankInfo, Worker + + +class MegatronWorker(Worker): + def __init__(self, cuda_visible_devices=None) -> None: + super().__init__(cuda_visible_devices) + + def get_megatron_global_info(self): + from megatron.core import parallel_state as mpu + + tp_size = mpu.get_tensor_model_parallel_world_size() + dp_size = mpu.get_data_parallel_world_size() + pp_size = mpu.get_pipeline_model_parallel_world_size() + cp_size = mpu.get_context_parallel_world_size() + info = DistGlobalInfo(tp_size=tp_size, dp_size=dp_size, pp_size=pp_size, cp_size=cp_size) + return info + + def get_megatron_rank_info(self): + from megatron.core import parallel_state as mpu + + tp_rank = mpu.get_tensor_model_parallel_rank() + dp_rank = mpu.get_data_parallel_rank() + pp_rank = mpu.get_pipeline_model_parallel_rank() + cp_rank = mpu.get_context_parallel_rank() + info = DistRankInfo(tp_rank=tp_rank, dp_rank=dp_rank, pp_rank=pp_rank, cp_rank=cp_rank) + return info + + def _init_hf_config_and_tf_config(self, model_path, dtype, override_model_config): + from transformers import AutoConfig + + from verl.models.mcore import hf_to_mcore_config + from verl.utils import hf_tokenizer + from verl.utils.fs import copy_to_local + from verl.utils.model import update_model_config + + # Step 1: initialize the tokenizer + self.local_path = copy_to_local(model_path) + self.tokenizer = hf_tokenizer(self.local_path) + + # Step 2: get the hf + hf_config = AutoConfig.from_pretrained(self.local_path) + + # Step 3: override the hf config + override_config_kwargs = { + "bos_token_id": self.tokenizer.bos_token_id, + "eos_token_id": self.tokenizer.eos_token_id, + "pad_token_id": self.tokenizer.pad_token_id, + } + override_config_kwargs.update(override_model_config) + self.share_embeddings_and_output_weights = getattr(hf_config, "tie_word_embeddings", False) + update_model_config(hf_config, override_config_kwargs=override_config_kwargs) + self.architectures = getattr(hf_config, "architectures", None) + if self.rank == 0: + print(f"Model config after override: {hf_config}") + tf_config = hf_to_mcore_config(hf_config, dtype) + + def add_optimization_config_to_tf_config(tf_config, verl_model_config): + # add optimization config to tf_config, e.g. checkpointing + if verl_model_config.get("enable_gradient_checkpointing", False): + gradient_checkpointing_cfg = dict(verl_model_config.get("gradient_checkpointing_kwargs", dict())) + tf_config.recompute_method = gradient_checkpointing_cfg.get("activations_checkpoint_method", "full") + tf_config.recompute_granularity = gradient_checkpointing_cfg.get("activations_checkpoint_granularity", "full") + tf_config.recompute_num_layers = gradient_checkpointing_cfg.get("activations_checkpoint_num_layers", -1) + + add_optimization_config_to_tf_config(tf_config, self.config.model) + + print(f"TF config: {tf_config}") + self.hf_config = hf_config + self.tf_config = tf_config diff --git a/verl/single_controller/base/megatron/worker_group.py b/verl/single_controller/base/megatron/worker_group.py new file mode 100644 index 0000000000000000000000000000000000000000..17a29a22ef42c83577b473b9e33cd81d2ba65c93 --- /dev/null +++ b/verl/single_controller/base/megatron/worker_group.py @@ -0,0 +1,56 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict + +from verl.single_controller.base import ResourcePool, WorkerGroup + +from .worker import DistGlobalInfo, DistRankInfo + + +class MegatronWorkerGroup(WorkerGroup): + def __init__(self, resource_pool: ResourcePool, **kwargs): + super().__init__(resource_pool=resource_pool, **kwargs) + self._megatron_rank_info = None + self._megatron_global_info: DistGlobalInfo = None + + def init_megatron(self, default_megatron_kwargs: Dict = None): + raise NotImplementedError("MegatronWorkerGroup.init_megatron should be overwritten") + + def get_megatron_rank_info(self, rank: int) -> DistRankInfo: + assert 0 <= rank < self.world_size, f"rank must be from [0, world_size), Got {rank}" + return self._megatron_rank_info[rank] + + @property + def tp_size(self): + assert self._megatron_global_info is not None, "MegatronWorkerGroup._megatron_global_info must be initialized" + return self._megatron_global_info.tp_size + + @property + def dp_size(self): + assert self._megatron_global_info is not None, "MegatronWorkerGroup._megatron_global_info must be initialized" + return self._megatron_global_info.dp_size + + @property + def pp_size(self): + assert self._megatron_global_info is not None, "MegatronWorkerGroup._megatron_global_info must be initialized" + return self._megatron_global_info.pp_size + + @property + def cp_size(self): + assert self._megatron_global_info is not None, "MegatronWorkerGroup._megatron_global_info must be initialized" + return self._megatron_global_info.cp_size + + def get_megatron_global_info(self): + return self._megatron_global_info diff --git a/verl/single_controller/base/register_center/__init__.py b/verl/single_controller/base/register_center/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c6aedde1437eddbe099b730519ef79980e2e3dc4 --- /dev/null +++ b/verl/single_controller/base/register_center/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/verl/single_controller/base/register_center/__pycache__/__init__.cpython-311.pyc b/verl/single_controller/base/register_center/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ccadf1c2e813984c67361af3579b548824208ff7 Binary files /dev/null and b/verl/single_controller/base/register_center/__pycache__/__init__.cpython-311.pyc differ diff --git a/verl/single_controller/base/register_center/__pycache__/ray.cpython-311.pyc b/verl/single_controller/base/register_center/__pycache__/ray.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..270cdc7cfc1baa50155458fb85fc387d3d31dbca Binary files /dev/null and b/verl/single_controller/base/register_center/__pycache__/ray.cpython-311.pyc differ diff --git a/verl/single_controller/base/register_center/ray.py b/verl/single_controller/base/register_center/ray.py new file mode 100644 index 0000000000000000000000000000000000000000..af53987134c1ca1c56377cec5c59f5be92ffd0ad --- /dev/null +++ b/verl/single_controller/base/register_center/ray.py @@ -0,0 +1,38 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict + +import ray + + +@ray.remote +class WorkerGroupRegisterCenter: + def __init__(self, rank_zero_info): + self.rank_zero_info = rank_zero_info + # rank -> node_id + self.workers_info: Dict[int, str] = {} + + def get_rank_zero_info(self): + return self.rank_zero_info + + def set_worker_info(self, rank, node_id) -> None: + self.workers_info[rank] = node_id + + def get_worker_info(self) -> Dict[int, str]: + return self.workers_info + + +def create_worker_group_register_center(name, info): + return WorkerGroupRegisterCenter.options(name=name).remote(info) diff --git a/verl/single_controller/base/worker.py b/verl/single_controller/base/worker.py new file mode 100644 index 0000000000000000000000000000000000000000..70f063763df539cce5d2caa2fe62fcb2f2c5cda1 --- /dev/null +++ b/verl/single_controller/base/worker.py @@ -0,0 +1,219 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +the class for Worker +""" + +import os +import socket +from dataclasses import dataclass +from typing import Dict + +import ray + +from .decorator import Dispatch, Execute, register + + +@dataclass +class DistRankInfo: + tp_rank: int + dp_rank: int + pp_rank: int + cp_rank: int + + +@dataclass +class DistGlobalInfo: + tp_size: int + dp_size: int + pp_size: int + cp_size: int + + +class WorkerHelper: + def _get_node_ip(self): + def get_node_ip_by_sdk(): + if os.getenv("WG_BACKEND", None) == "ray": + import ray + + return ray._private.services.get_node_ip_address() + else: + raise NotImplementedError("WG_BACKEND now just support ray mode.") + + host_ipv4 = os.getenv("MY_HOST_IP", None) + host_ipv6 = os.getenv("MY_HOST_IPV6", None) + host_ip_by_env = host_ipv4 or host_ipv6 + host_ip_by_sdk = get_node_ip_by_sdk() + + host_ip = host_ip_by_env or host_ip_by_sdk + return host_ip + + def _get_free_port(self): + with socket.socket() as sock: + sock.bind(("", 0)) + return sock.getsockname()[1] + + def get_availale_master_addr_port(self): + return self._get_node_ip(), str(self._get_free_port()) + + def _get_pid(self): + return os.getpid() + + +# we assume that in each WorkerGroup, there is a Master Worker +class Worker(WorkerHelper): + """A (distributed) worker.""" + + fused_worker_attr_name = "fused_worker_dict" + + def __new__(cls, *args, **kwargs): + instance = super().__new__(cls) + + # note that here we use int to distinguish + disable_worker_init = int(os.environ.get("DISABLE_WORKER_INIT", 0)) + if disable_worker_init: + return instance + + rank = os.environ.get("RANK", None) + worker_group_prefix = os.environ.get("WG_PREFIX", None) + + # when decorator @ray.remote applies, __new__ will be called while we don't want to apply _configure_before_init + if None not in [rank, worker_group_prefix] and "ActorClass(" not in cls.__name__: + instance._configure_before_init(f"{worker_group_prefix}_register_center", int(rank)) + + return instance + + def _configure_before_init(self, register_center_name: str, rank: int): + assert isinstance(rank, int), f"rank must be int, instead of {type(rank)}" + + if rank == 0: + master_addr, master_port = self.get_availale_master_addr_port() + rank_zero_info = { + "MASTER_ADDR": master_addr, + "MASTER_PORT": master_port, + } + + if os.getenv("WG_BACKEND", None) == "ray": + from verl.single_controller.base.register_center.ray import create_worker_group_register_center + + self.register_center = create_worker_group_register_center(name=register_center_name, info=rank_zero_info) + + os.environ.update(rank_zero_info) + else: + self.register_center = ray.get_actor(register_center_name) + + # set worker info for node affinity scheduling + ray.get(self.register_center.set_worker_info.remote(rank, ray.get_runtime_context().get_node_id())) + + @classmethod + def env_keys(cls): + """The keys of the environment variables that are used to configure the Worker.""" + return ["WORLD_SIZE", "RANK", "LOCAL_WORLD_SIZE", "LOCAL_RANK", "MASTER_ADDR", "MASTER_PORT", "CUDA_VISIBLE_DEVICES"] + + def __init__(self, cuda_visible_devices=None) -> None: + # construct a meta from environment variable. Note that the import must be inside the class because it is executed remotely + import os + + import torch + from packaging import version + + ### + # [SUPPORT AMD: torch] + if torch.cuda.is_available() and "AMD" in torch.cuda.get_device_name() and version.parse(ray.__version__) < version.parse("2.45.0"): + os.environ["CUDA_VISIBLE_DEVICES"] = os.environ.get("ROCR_VISIBLE_DEVICES") + os.environ["LOCAL_RANK"] = os.environ.get("RAY_LOCAL_RANK") + ### + + world_size = int(os.environ["WORLD_SIZE"]) + rank = int(os.environ["RANK"]) + self._rank = rank + self._world_size = world_size + + master_addr = os.environ["MASTER_ADDR"] + master_port = os.environ["MASTER_PORT"] + + local_world_size = int(os.getenv("LOCAL_WORLD_SIZE", "1")) + local_rank = int(os.getenv("LOCAL_RANK", "0")) + + ### + # [SUPPORT AMD: torch] + if torch.cuda.is_available() and "AMD" in torch.cuda.get_device_name() and version.parse(ray.__version__) < version.parse("2.45.0"): + self.local_rank = int(os.environ["LOCAL_RANK"]) + cuda_visible_devices = str(local_rank) + ### + + store = { + "_world_size": world_size, + "_rank": rank, + "_local_world_size": local_world_size, + "_local_rank": local_rank, + "_master_addr": master_addr, + "_master_port": master_port, + } + if cuda_visible_devices is not None: + store["_cuda_visible_devices"] = cuda_visible_devices + + self._configure_with_store(store=store) + + ### + # [SUPPORT AMD: torch] + if torch.cuda.is_available() and "AMD" in torch.cuda.get_device_name() and version.parse(ray.__version__) < version.parse("2.45.0"): + torch.cuda.set_device(int(cuda_visible_devices)) + ### + + self.fused_worker_dict = {} + + def get_fused_worker_by_name(self, worker_name: str): + return self.fused_worker_dict.get(worker_name, None) + + def _configure_with_store(self, store: Dict): + """ + This function should only be called inside by WorkerGroup + """ + store_env_dict = {f"_{key.lower()}": store.get(f"_{key.lower()}", None) for key in type(self).env_keys()} + self.__dict__.update(store_env_dict) # this is hacky + # print(f"__dict__: {self.__dict__}") + for key in type(self).env_keys(): + val = self.__dict__.get(f"_{key.lower()}", None) + if val is not None: + # print(f"set {key} to {val}") + os.environ[key] = str(val) + os.environ["REDIS_STORE_SERVER_HOST"] = str(self._master_addr).replace("[", "").replace("]", "") if self._master_addr else "" + + def get_master_addr_port(self): + return self._master_addr, self._master_port + + def get_cuda_visible_devices(self): + import os + + cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "not set") + return cuda_visible_devices + + @property + def world_size(self): + return self._world_size + + @property + def rank(self): + return self._rank + + @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO_WITH_FUNC) + def execute_with_func_generator(self, func, *args, **kwargs): + ret_proto = func(self, *args, **kwargs) + return ret_proto + + @register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.RANK_ZERO) + def execute_func_rank_zero(self, func, *args, **kwargs): + result = func(*args, **kwargs) + return result diff --git a/verl/single_controller/base/worker_group.py b/verl/single_controller/base/worker_group.py new file mode 100644 index 0000000000000000000000000000000000000000..afd2583c2928abcd9bc62dec555049d12a0cb11d --- /dev/null +++ b/verl/single_controller/base/worker_group.py @@ -0,0 +1,208 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +the class of WorkerGroup +""" + +import logging +import signal +import threading +import time +from typing import Any, Callable, Dict, List + +from .decorator import MAGIC_ATTR, Dispatch, get_predefined_dispatch_fn, get_predefined_execute_fn + + +class ResourcePool: + """The resource pool with meta info such as world_size.""" + + def __init__(self, process_on_nodes=None, max_colocate_count: int = 10, n_gpus_per_node=8) -> None: + if process_on_nodes is None: + process_on_nodes = [] + self._store = process_on_nodes + self.max_colocate_count = max_colocate_count + self.n_gpus_per_node = n_gpus_per_node # this is left for future huawei GPU that contains 16 GPUs per node + + def add_node(self, process_count): + self._store.append(process_count) + + @property + def world_size(self): + return sum(self._store) + + def __call__(self) -> Any: + return self._store + + @property + def store(self): + return self._store + + def local_world_size_list(self) -> List[int]: + nested_local_world_size_list = [[local_world_size for _ in range(local_world_size)] for local_world_size in self._store] + return [item for row in nested_local_world_size_list for item in row] + + def local_rank_list(self) -> List[int]: + nested_local_rank_list = [[i for i in range(local_world_size)] for local_world_size in self._store] + return [item for row in nested_local_rank_list for item in row] + + +class ClassWithInitArgs: + """ + This class stores a class constructor and the args/kwargs to construct the class. + It is used to instantiate the remote class. + """ + + def __init__(self, cls, *args, **kwargs) -> None: + self.cls = cls + self.args = args + self.kwargs = kwargs + + self.fused_worker_used = False + + # def add_arg(self, arg): + # self.args += (arg,) + + # def add_kwarg(self, key, value): + # self.kwargs[key] = value + + def __call__(self) -> Any: + return self.cls(*self.args, **self.kwargs) + + +def check_workers_alive(workers: List, is_alive: Callable, gap_time: float = 1) -> None: + import time + + while True: + for worker in workers: + if not is_alive(worker): + logging.warning(f"worker {worker} is not alive sending signal to main thread") + signal.raise_signal(signal.SIGABRT) + time.sleep(gap_time) + + +class WorkerGroup: + """A group of workers""" + + fused_worker_execute_fn_name = "_fuw_execute" + + def __init__(self, resource_pool: ResourcePool, **kwargs) -> None: + self._is_init_with_detached_workers = resource_pool is None + + self.fused_worker_used = False + + if resource_pool is not None: + # handle the case when WorkGroup is attached to an existing one + self._procecss_dispatch_config = resource_pool() + else: + self._procecss_dispatch_config = None + + self._workers = [] + self._worker_names = [] + + self._master_addr = None + self._master_port = None + + self._checker_thread: threading.Thread = None + + def _is_worker_alive(self, worker): + raise NotImplementedError("WorkerGroup._is_worker_alive called, should be implemented in derived class.") + + def _block_until_all_workers_alive(self) -> None: + while True: + all_state = [self._is_worker_alive(worker) for worker in self._workers] + if False in all_state: + time.sleep(1) + else: + break + + def start_worker_aliveness_check(self, every_n_seconds=1) -> None: + # before starting checking worker aliveness, make sure all workers are already alive + self._block_until_all_workers_alive() + + self._checker_thread = threading.Thread(target=check_workers_alive, args=(self._workers, self._is_worker_alive, every_n_seconds)) + self._checker_thread.start() + + @property + def world_size(self): + return len(self._workers) + + # execute_all_async and execute_rank_zero_async should be implemented by RayWorkerGroup, TorchRPCWorkerGroup, + # MegatronWorkerGroup, XperfWorkerGroup should skip + + def _bind_worker_method(self, user_defined_cls, func_generator): + """ + Bind the worker method to the WorkerGroup + """ + + method_names = [] + for method_name in dir(user_defined_cls): + try: + method = getattr(user_defined_cls, method_name) + assert callable(method), f"{method_name} in {user_defined_cls} is not callable" + except Exception: + # if it is a property, it will fail because Class doesn't have instance property + continue + + if hasattr(method, MAGIC_ATTR): + # this method is decorated by register + attribute = getattr(method, MAGIC_ATTR) + assert isinstance(attribute, Dict), f"attribute must be a dictionary. Got {type(attribute)}" + assert "dispatch_mode" in attribute, "attribute must contain dispatch_mode in its key" + + dispatch_mode = attribute["dispatch_mode"] + execute_mode = attribute["execute_mode"] + blocking = attribute["blocking"] + + # get dispatch fn + if isinstance(dispatch_mode, Dispatch): + # get default dispatch fn + fn = get_predefined_dispatch_fn(dispatch_mode=dispatch_mode) + dispatch_fn = fn["dispatch_fn"] + collect_fn = fn["collect_fn"] + else: + assert isinstance(dispatch_mode, dict) + assert "dispatch_fn" in dispatch_mode + assert "collect_fn" in dispatch_mode + dispatch_fn = dispatch_mode["dispatch_fn"] + collect_fn = dispatch_mode["collect_fn"] + + # get execute_fn_name + execute_mode = get_predefined_execute_fn(execute_mode=execute_mode) + wg_execute_fn_name = execute_mode["execute_fn_name"] + + # get execute_fn from string + try: + execute_fn = getattr(self, wg_execute_fn_name) + assert callable(execute_fn), "execute_fn must be callable" + except Exception: + print(f"execute_fn {wg_execute_fn_name} is invalid") + raise + + # bind a new method to the RayWorkerGroup + func = func_generator( + self, + method_name, + dispatch_fn=dispatch_fn, + collect_fn=collect_fn, + execute_fn=execute_fn, + blocking=blocking, + ) + + try: + setattr(self, method_name, func) + method_names.append(method_name) + except Exception as e: + raise ValueError(f"Fail to set method_name {method_name}") from e + + return method_names diff --git a/verl/single_controller/ray/__init__.py b/verl/single_controller/ray/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6522c90518ca205ca6e8d0046dd302a508791e2e --- /dev/null +++ b/verl/single_controller/ray/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup, create_colocated_worker_cls, create_colocated_worker_cls_fused + +__all__ = ["RayClassWithInitArgs", "RayResourcePool", "RayWorkerGroup", "create_colocated_worker_cls", "create_colocated_worker_cls_fused"] diff --git a/verl/single_controller/ray/__pycache__/__init__.cpython-311.pyc b/verl/single_controller/ray/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a3b5b625e5bd735dd8f4c69a2fcc12f36e5594e Binary files /dev/null and b/verl/single_controller/ray/__pycache__/__init__.cpython-311.pyc differ diff --git a/verl/single_controller/ray/__pycache__/base.cpython-311.pyc b/verl/single_controller/ray/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4be5d899867ef3ca26fa2e3071c65d36f0d59fc1 Binary files /dev/null and b/verl/single_controller/ray/__pycache__/base.cpython-311.pyc differ diff --git a/verl/single_controller/ray/base.py b/verl/single_controller/ray/base.py new file mode 100644 index 0000000000000000000000000000000000000000..8c6869d883c28bfb644b7b291b76af38494b14e3 --- /dev/null +++ b/verl/single_controller/ray/base.py @@ -0,0 +1,634 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import time +from copy import deepcopy +from typing import Any, Dict, List, Optional, Tuple +from unittest.mock import patch + +import ray +from ray.experimental.state.api import get_actor +from ray.util import list_named_actors +from ray.util.placement_group import PlacementGroup, placement_group +from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy, PlacementGroupSchedulingStrategy + +from verl.protocol import DataProto, _padding_size_key +from verl.single_controller.base import ClassWithInitArgs, ResourcePool, Worker, WorkerGroup +from verl.single_controller.base.decorator import MAGIC_ATTR, Dispatch + +__all__ = ["Worker"] + + +def get_random_string(length: int) -> str: + import random + import string + + letters_digits = string.ascii_letters + string.digits + return "".join(random.choice(letters_digits) for _ in range(length)) + + +def func_generator(self, method_name, dispatch_fn, collect_fn, execute_fn, blocking): + def func(*args, **kwargs): + args, kwargs = dispatch_fn(self, *args, **kwargs) + padding_count = kwargs.pop(_padding_size_key, 0) + output = execute_fn(method_name, *args, **kwargs) + if blocking: + output = ray.get(output) + output = collect_fn(self, output) + if padding_count > 0: + if isinstance(output, DataProto): + indices = [i for i in range(len(output))][:-padding_count] + output = output.select_idxs(indices) + elif isinstance(output, list): + output = output[:-padding_count] + return output + + return func + + +def sort_placement_group_by_node_ip(pgs: List[PlacementGroup]) -> List[PlacementGroup]: + """ + Sort the placement groups by node ip, all bundles in a single placement group should be on the same node. + + FSDPCheckpointManager saves sharded model states and optimizer states in local storage, which requires RANK + to be consistent across nodes when resume from checkpoint. + + With this function, if there's only one resource pool and there's no node change, RANK should be consistent + across nodes in multiple ray jobs, even if the whole ray cluster is restarted. + """ + node_ip = {node["NodeID"]: node["NodeManagerAddress"] for node in ray.nodes()} + pg_ip = {} + for pg in pgs: + specs = ray._private.state.state.placement_group_table(pg.id) + # all bunles should be on the same node + node_id = specs["bundles_to_node_id"][0] + pg_ip[pg.id] = node_ip[node_id] + return sorted(pgs, key=lambda pg: pg_ip[pg.id]) + + +class RayResourcePool(ResourcePool): + def __init__( + self, + process_on_nodes: Optional[List[int]] = None, + use_gpu: bool = True, + name_prefix: str = "", + max_colocate_count: int = 10, + detached=False, + ) -> None: + super().__init__(process_on_nodes, max_colocate_count) + self.use_gpu = use_gpu + # print(f"in RayProcessDispatchConfiguration: name_prefix = {name_prefix}") + self.name_prefix = name_prefix + self.pgs = None + self.detached = detached + + def get_placement_groups(self, strategy="STRICT_PACK", name=None): + if self.pgs is not None: + return self.pgs + + pg_name_prefix = name if name else f"{self.name_prefix}verl_group_{'_'.join([str(count) for count in self._store])}:" + # print(f"pg_name_prefix = {pg_name_prefix}") + pg_scheme = [[{"CPU": self.max_colocate_count, "GPU": 1} if self.use_gpu else {"CPU": self.max_colocate_count} for _ in range(process_count)] for process_count in self._store] + + lifetime = "detached" if self.detached else None + + pgs = [placement_group(bundles=bundles, strategy=strategy, name=pg_name_prefix + str(idx), lifetime=lifetime) for idx, bundles in enumerate(pg_scheme)] + + ray.get([pg.ready() for pg in pgs]) + + self.pgs = pgs + return pgs + + +def extract_pg_from_exist(resource_pools: Dict[str, RayResourcePool], src_role_names: List[str], resource_pool: RayResourcePool) -> List: + src_pgs = [pg for role_name, resource_pool in resource_pools.items() for pg in resource_pool.get_placement_groups() if role_name in src_role_names] + + sorted_src_pgs = sorted(src_pgs, key=lambda pg: pg.bundle_count, reverse=True) + sorted_process_on_nodes = sorted([(val, idx) for idx, val in enumerate(resource_pool.store)], reverse=True) + + unsorted_pgs: List[Tuple[int, PlacementGroup]] = [] + searching_idx = 0 + for request_process, original_idx in sorted_process_on_nodes: + assert searching_idx < len(sorted_src_pgs), f"no enough nodes for request: searching {searching_idx} th node" + assert request_process <= sorted_src_pgs[searching_idx].bundle_count, f"requesting {request_process} processes, bundle count cannot satisfy" + unsorted_pgs.append((original_idx, sorted_src_pgs[searching_idx])) + searching_idx += 1 + + return [pg for _, pg in sorted(unsorted_pgs)] + + +def merge_resource_pool(rp1: RayResourcePool, rp2: RayResourcePool) -> RayResourcePool: + assert rp1.use_gpu == rp2.use_gpu, "Both RayResourcePool must either use_gpu or not" + assert rp1.max_colocate_count == rp2.max_colocate_count, "Both RayResourcePool must has the same max_colocate_count" + assert rp1.n_gpus_per_node == rp2.n_gpus_per_node, "Both RayResourcePool must has the same n_gpus_per_node" + assert rp1.detached == rp2.detached, "Detached ResourcePool cannot be merged with non-detached ResourcePool" + + new_store = rp1.store + rp2.store + + merged = type(rp1)(new_store, rp1.use_gpu, f"{rp1.name_prefix}_{rp2.name_prefix}") + merged.pgs = rp1.get_placement_groups() + rp2.get_placement_groups() + + return merged + + +class RayClassWithInitArgs(ClassWithInitArgs): + def __init__(self, cls, *args, **kwargs) -> None: + # self._options = kwargs.pop('options', dict()) + super().__init__(cls, *args, **kwargs) + self._options = {} + self._additional_resource = {} + + def set_additional_resource(self, additional_resource): + self._additional_resource = additional_resource + + def update_options(self, options: Dict): + self._options.update(options) + + def __call__(self, placement_group, placement_group_bundle_idx, use_gpu: bool = True, num_gpus=1, sharing_with=None) -> Any: + if sharing_with is not None: + target_node_id = ray.get(sharing_with.get_node_id.remote()) + cuda_visible_devices = ray.get(sharing_with.get_cuda_visible_devices.remote()) + options = {"scheduling_strategy": NodeAffinitySchedulingStrategy(node_id=target_node_id, soft=False)} + return self.cls.options(**options).remote(*self.args, cuda_visible_devices=cuda_visible_devices, **self.kwargs) + + options = {"scheduling_strategy": PlacementGroupSchedulingStrategy(placement_group=placement_group, placement_group_bundle_index=placement_group_bundle_idx)} + options.update(self._options) + + if use_gpu: + options["num_gpus"] = num_gpus + + if len(self._additional_resource) > 1: + for k, v in self._additional_resource.items(): + options[k] = v + + # print("cls:", self.cls) + # print("args: ", self.args) + # print("kwargs: ", self.kwargs) + return self.cls.options(**options).remote(*self.args, **self.kwargs) + + +class RayWorkerGroup(WorkerGroup): + def __init__( + self, + resource_pool: RayResourcePool = None, + ray_cls_with_init: RayClassWithInitArgs = None, + bin_pack: bool = True, + name_prefix: str = None, + detached=False, + worker_names=None, + worker_handles: List[ray.actor.ActorHandle] = None, + ray_wait_register_center_timeout: int = 300, + **kwargs, + ) -> None: + super().__init__(resource_pool=resource_pool, **kwargs) + self.ray_cls_with_init = ray_cls_with_init + self.name_prefix = get_random_string(length=6) if name_prefix is None else name_prefix + self._ray_wait_register_center_timeout = ray_wait_register_center_timeout + # Whether the WorkerGroup is a Colocate WorkerGroup created by FusedWorker. + self.fused_worker_used = ray_cls_with_init.fused_worker_used + # if a WorkerGroup is spawned from Colocate WorkerGroup, this indicates which sub-class is binded to this WorkerGroup. + self.sub_cls_name = "" + + if worker_names is not None and (not self.fused_worker_used): + assert self._is_init_with_detached_workers + self._worker_names = worker_names + + if self._is_init_with_detached_workers: + self._init_with_detached_workers(worker_names=worker_names, worker_handles=worker_handles) + else: + self._init_with_resource_pool(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init, bin_pack=bin_pack, detached=detached) + + if ray_cls_with_init is not None: + self._bind_worker_method(self.ray_cls_with_init.cls, func_generator) + + self.wg_dict = None + self.method_names = [] + + def _is_worker_alive(self, worker: ray.actor.ActorHandle): + worker_state_dict = get_actor(worker._actor_id.hex()) + return worker_state_dict.get("state", "undefined") == "ALIVE" if worker_state_dict is not None else False + + def _init_with_detached_workers(self, worker_names, worker_handles): + # ray.get_actor holds a weak reference to the actor, which causes actors garbage collected unexpectedly + # if we only hold spawn RayWorkerGroup. By passing actor handle explicitly, spawn RayWorkerGroup have + # strong reference to these actors. + # https://github.com/ray-project/ray/pull/45699 + workers = worker_handles if worker_handles else [ray.get_actor(name=name) for name in worker_names] + self._workers = workers + self._world_size = len(worker_names) + + def _init_with_resource_pool(self, resource_pool, ray_cls_with_init, bin_pack, detached): + use_gpu = resource_pool.use_gpu + + strategy = "PACK" + if bin_pack: + strategy = "STRICT_PACK" + pgs = resource_pool.get_placement_groups(strategy=strategy) + world_size = resource_pool.world_size + self._world_size = world_size + # cia.add_kwarg("_world_size", world_size) + num_gpus = 1 / resource_pool.max_colocate_count + + rank = -1 + local_world_size = resource_pool.store[0] + for pg_idx, pg in enumerate(sort_placement_group_by_node_ip(pgs)): + assert local_world_size <= pg.bundle_count, f"when generating for {self.name_prefix}, for the " + for local_rank in range(local_world_size): + rank += 1 + + # we pass in environment variable at option so that Worker can use environment variable to set + env_vars = { + "WORLD_SIZE": str(world_size), + "RANK": str(rank), + "WG_PREFIX": self.name_prefix, + "WG_BACKEND": "ray", + "RAY_LOCAL_WORLD_SIZE": str(local_world_size), + "RAY_LOCAL_RANK": str(local_rank), + } + if rank != 0: + env_vars["MASTER_ADDR"] = self._master_addr + env_vars["MASTER_PORT"] = self._master_port + + import re + + cia_name = type(ray_cls_with_init.cls).__name__ + match = re.search(r"ActorClass\(([^)]+)\)", cia_name) # ray.remote(Obj) -> "ActorClass(Obj)" + cia_name = match.group(1) if match else cia_name # "ActorClass(Obj)" -> "Obj" + name = f"{self.name_prefix}{cia_name}_{pg_idx}:{local_rank}" # e.g. Worker_2:5 + + ray_cls_with_init.update_options({"runtime_env": {"env_vars": env_vars}, "name": name}) + + if detached: + ray_cls_with_init.update_options({"lifetime": "detached"}) + + # create a worker + worker = ray_cls_with_init(placement_group=pg, placement_group_bundle_idx=local_rank, use_gpu=use_gpu, num_gpus=num_gpus) + self._workers.append(worker) + self._worker_names.append(name) + + if rank == 0: + register_center_actor = None + actor_name = f"{self.name_prefix}_register_center" + start_time = time.time() + + while time.time() - start_time < self._ray_wait_register_center_timeout: + if actor_name in list_named_actors(): + register_center_actor = ray.get_actor(actor_name) + break + + elapsed = int(time.time() - start_time) + if elapsed % 30 == 0: + logging.warning( + "Waiting for register center actor %s to be ready. Elapsed time: %s seconds out of %s seconds.", + actor_name, + elapsed, + self._ray_wait_register_center_timeout, + ) + time.sleep(1) + + if register_center_actor is None: + raise TimeoutError( + f"Failed to get register_center_actor {actor_name} " + f"in {list_named_actors(all_namespaces=True)} " + f"for {self._ray_wait_register_center_timeout} seconds. " + "Ensure that any lingering Ray resources from previous " + "runs are cleaned up (e.g., by restarting the Ray cluster), " + "or adjust the waiting time by modifying the config " + "`trainer.ray_wait_register_center_timeout`." + ) + + rank_zero_info = ray.get(register_center_actor.get_rank_zero_info.remote()) + self._master_addr, self._master_port = rank_zero_info["MASTER_ADDR"], rank_zero_info["MASTER_PORT"] + # print(f"rank_zero_info: {rank_zero_info}") + # print(f"master_addr: {self._master_addr}, master_port: {self._master_port}") + + @property + def worker_names(self): + return self._worker_names + + @classmethod + def from_detached( + cls, + name_prefix, + worker_names=None, + worker_handles=None, + ray_cls_with_init=None, + ): + worker_group = cls(resource_pool=None, ray_cls_with_init=ray_cls_with_init, name_prefix=name_prefix, worker_names=worker_names, worker_handles=worker_handles) + return worker_group + + def spawn(self, prefix_set): + """ + spawn to a dictionary of worker groups, each with a subset of method with prefix. + + """ + if self.fused_worker_used: + return self.spawn_fused(prefix_set) + + def _rebind_actor_methods(worker_group, actor_name): + """ + bind the method with actor_prefix to its original name + """ + prefix: str = actor_name + "_" + for method_name in dir(worker_group): + if method_name.startswith(prefix): + # only valid when Python >= 3.9 + original_method_name = method_name.removeprefix(prefix) + method = getattr(worker_group, method_name) + setattr(worker_group, original_method_name, method) + + new_worker_group_dict = {} + for prefix in prefix_set: + new_worker_group = self.from_detached( + name_prefix=self.name_prefix, + worker_names=self._worker_names, + worker_handles=self._workers, + ray_cls_with_init=self.ray_cls_with_init, + ) + + _rebind_actor_methods(new_worker_group, prefix) + new_worker_group_dict[prefix] = new_worker_group + return new_worker_group_dict + + def spawn_fused(self, prefix_set): + wg_dict = dict() + for key in prefix_set: + new_wg = deepcopy(self) + new_wg._bind_worker_method(self.ray_cls_with_init.cls.raw_cls_dict[key], func_generator) + new_wg.sub_cls_name = key + wg_dict[key] = new_wg + return wg_dict + + def fuse(self, prefix_set): + if self.wg_dict is None: + self.wg_dict = self.spawn(prefix_set) + for role_name, role_wg in self.wg_dict.items(): + setattr(self, role_name, role_wg) + self.method_names = self._bind_worker_method(self.ray_cls_with_init.cls, func_generator) + + def _execute_remote_single_worker(self, worker, method_name: str, *args, **kwargs): + if self.fused_worker_used and method_name not in self.method_names: + remote_call = getattr(worker, self.fused_worker_execute_fn_name) + return remote_call.remote(f"{self.sub_cls_name}_fwmn_{method_name}", *args, **kwargs) + # fused worker not used + remote_call = getattr(worker, method_name) + return remote_call.remote(*args, **kwargs) + + def execute_rank_zero_sync(self, method_name: str, *args, **kwargs): + return ray.get(self.execute_rank_zero_async(method_name, *args, **kwargs)) + + def execute_rank_zero_async(self, method_name: str, *args, **kwargs): + return self._execute_remote_single_worker(self._workers[0], method_name, *args, **kwargs) + + def execute_rank_zero(self, method_name: str, *args, **kwargs): + return self.execute_rank_zero_async(method_name, *args, **kwargs) + + def execute_all(self, method_name: str, *args, **kwargs): + return self.execute_all_async(method_name, *args, **kwargs) + + def execute_all_sync(self, method_name: str, *args, **kwargs): + return ray.get(self.execute_all_async(method_name, *args, **kwargs)) + + def execute_all_async(self, method_name: str, *args, **kwargs): + # Here, we assume that if all arguments in args and kwargs are lists, + # and their lengths match len(self._workers), we'll distribute each + # element in these lists to the corresponding worker + # print(f"execute_all_async: method {method_name}({args}, {kwargs})") + length = len(self._workers) + if all(isinstance(arg, list) for arg in args) and all(isinstance(kwarg, list) for kwarg in kwargs.values()): + if all(len(arg) == length for arg in args) and all(len(kwarg) == length for kwarg in kwargs.values()): + # print(f"splitting args and kwargs into {length} shards") + result = [] + for i in range(length): + sliced_args = tuple(arg[i] for arg in args) + sliced_kwargs = {k: v[i] for k, v in kwargs.items()} + result.append(self._execute_remote_single_worker(self._workers[i], method_name, *sliced_args, **sliced_kwargs)) + return result + + return [self._execute_remote_single_worker(worker, method_name, *args, **kwargs) for worker in self._workers] + + @property + def master_address(self): + return self._master_addr + + @property + def master_port(self): + return self._master_port + + @property + def workers(self): + return self._workers + + @property + def world_size(self): + return self._world_size + + +""" +Utilities that enables creating workers inside the same ray.Actor, +with code written in separate ray.Actors. +""" + + +# deprecated, switching to FusedWorker +def _bind_workers_method_to_parent(cls, key, user_defined_cls): + """ + Binds the methods of each worker to the WorkerDict. + Note that we only bind public methods that are decorated by register + """ + + for method_name in dir(user_defined_cls): + try: + method = getattr(user_defined_cls, method_name) + assert callable(method), f"{method_name} in {user_defined_cls} is not callable" + except Exception: + # if it is a property, it will fail because Class doesn't have instance property + continue + + if hasattr(method, MAGIC_ATTR): + + def generate_function(name, key=key): + def func(self, *args, **kwargs): + # dispatch to the actual worker + return getattr(self.worker_dict[key], name)(*args, **kwargs) + + return func # noqa: B023 + + func = generate_function(method_name) + # pass MAGIC_ATTR for outer worker group + attrs = getattr(method, MAGIC_ATTR) + setattr(func, MAGIC_ATTR, attrs) + try: + # bind direct rollout method to class without prefix + if attrs["dispatch_mode"] == Dispatch.DIRECT_ROLLOUT_METHOD and "rollout" in key: + assert not hasattr(cls, method_name), f"conflict direct rollout method {method_name} with role {key}" + setattr(cls, method_name, func) + print(f"bind role {key} method {method_name} to class {cls}") + else: + method_name_with_prefix = key + "_" + method_name + setattr(cls, method_name_with_prefix, func) + except Exception as e: + raise ValueError(f"Fail to set method_name {method_name}") from e + + +def _unwrap_ray_remote(cls): + if hasattr(cls, "__ray_actor_class__"): + cls = cls.__ray_actor_class__ + return cls + + +def _determine_fsdp_megatron_base_class(mros: List): + """ + - megatron: base class should be MegatronWorker + - fsdp: base class should be Worker + """ + for cls in mros[0]: + if cls.__name__ == "MegatronWorker": + return cls + if cls.__name__ == "Worker": + return cls + raise ValueError(f"Cannot determine base class for {mros}") + + +# deprecated, switching to FusedWorker +def create_colocated_worker_cls(class_dict: dict[str, RayClassWithInitArgs]): + """ + This function should return a class instance that delegates the calls to every + cls in cls_dict + """ + cls_dict = {} + init_args_dict = {} + worker_cls = _determine_fsdp_megatron_base_class([cls.cls.__ray_actor_class__.__mro__ for cls in class_dict.values()]) + assert issubclass(worker_cls, Worker), f"worker_cls {worker_cls} should be a subclass of Worker" + print(f"colocated worker base class {worker_cls}") + + for key, cls in class_dict.items(): + cls_dict[key] = cls.cls + init_args_dict[key] = {"args": cls.args, "kwargs": cls.kwargs} + + assert cls_dict.keys() == init_args_dict.keys() + + # TODO: create a class with customizable name + class WorkerDict(worker_cls): + def __init__(self): + super().__init__() + self.worker_dict = {} + for key, user_defined_cls in cls_dict.items(): + user_defined_cls = _unwrap_ray_remote(user_defined_cls) + # directly instantiate the class without remote + # in worker class, e.g. + # when DISABLE_WORKER_INIT == 1 it will return immediately + with patch.dict(os.environ, {"DISABLE_WORKER_INIT": "1"}): + self.worker_dict[key] = user_defined_cls(*init_args_dict[key].get("args", ()), **init_args_dict[key].get("kwargs", {})) + + # now monkey-patch the methods from inner class to WorkerDict + for key, user_defined_cls in cls_dict.items(): + user_defined_cls = _unwrap_ray_remote(user_defined_cls) + _bind_workers_method_to_parent(WorkerDict, key, user_defined_cls) + + remote_cls = ray.remote(WorkerDict) + remote_cls = RayClassWithInitArgs(cls=remote_cls) + return remote_cls + + +FusedWorkerCLSName = "FusedWorker" + + +def create_colocated_worker_raw_cls(class_dict: dict[str, RayClassWithInitArgs]): + """ + This function returns a FusedWorker class. + + `FusedWorker.{class_name}` -> FusedClass + Use `class_name` as a param to directly access the underlying class. + + `FusedWorker._fuw_execute("{class_name}_fwmn_{method_name}", *args, **kwargs)` + First param must be "{class_name}_fwmn_{method_name}" in order to access `method_name` + of underlying class `{class_name}`. + + `FusedWorker.fused_worker_dict` -> {"class_name": FusedClass} + Stores all underlying classes. + + `FusedClass.fused_worker_dict` -> {"class_name": FusedClass} + The same as `FusedWorker.fused_worker_dict`, enables underlying class to access other + underlying classes. + """ + raw_cls_dict = {cls_name: _unwrap_ray_remote(cia.cls) for cls_name, cia in class_dict.items()} + init_args_dict = {cls_name: cia.args for cls_name, cia in class_dict.items()} + init_kwargs_dict = {cls_name: cia.kwargs for cls_name, cia in class_dict.items()} + cls_names = list(class_dict.keys()) + + # FusedWorker_Actor_Critic + class_name_renamed = "_".join([FusedWorkerCLSName] + cls_names) + + class FusedWorker(Worker): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.cls_names = cls_names + self.raw_cls_dict = raw_cls_dict + self.init_args_dict = init_args_dict + self.init_kwargs_dict = init_kwargs_dict + + for cls_name, udc, ud_args, ud_kwargs in zip(self.cls_names, self.raw_cls_dict.values(), self.init_args_dict.values(), self.init_kwargs_dict.values()): + with patch.dict(os.environ, {"DISABLE_WORKER_INIT": "1"}): + udc._get_ray_actor_cls_name = lambda x, name_renamed=class_name_renamed: name_renamed + udc._get_ray_method_prefix = lambda x, name_prefixed=cls_name: f"{name_prefixed}_" + # cls_name = "actor", "critic", udc = ActorWorker, CriticWorker + self.fused_worker_dict[cls_name] = udc(*ud_args, **ud_kwargs) + setattr(self, cls_name, self.fused_worker_dict[cls_name]) + + # injecting fused_worker to each sub worker so they can be aware of existence of each other + for _, worker in self.fused_worker_dict.items(): + setattr(worker, Worker.fused_worker_attr_name, self.fused_worker_dict) + + def _fuw_execute(self, method_name: str, *args, **kwargs): + # for fused_worker, method_name is in a form of "{cls_name}_fwmn_{method_name}" + # where fwmn stands "fused worker method name" + names = method_name.split("_fwmn_") + cls_name = names[0] + method_name = names[1] + + assert cls_name in self.fused_worker_dict, f"calling {cls_name}'s {method_name}, but {cls_name} not in fused_worker_dict" + udc_method = getattr(self.fused_worker_dict[cls_name], method_name) + return udc_method(*args, **kwargs) + + renamed_fused_worker_cls = type(class_name_renamed, (FusedWorker,), {}) + renamed_fused_worker_cls.is_fused_worker = True + renamed_fused_worker_cls.raw_cls_dict = raw_cls_dict + + return renamed_fused_worker_cls + + +def create_colocated_worker_cls_fused(class_dict: dict[str, RayClassWithInitArgs]): + """ + This function returns a RayClassWithInitArgs instance of FusedWorker, which is an replacement + of `create_colocated_worker_cls`. WorkerGroup constructed using this class will be a colocated + WorkerGroup, which will be referenced as `ColocateWorkerGroup` below. + + `ColocateWorkerGroup.spawn(prefix_set)` + returns a dict of WorkerGroup {"class_name": WorkerGroup}, WorkerGroup in this dict will + have methods of underlying class `class_name` attached. + + `ColocateWorkerGroup.fuse(prefix_set)` + After executing this function, `ColocateWorkerGroup.{class_name}` will return WorkerGroup + with methods of underlying class `class_name` attached. + """ + raw_colocated_worker_cls = create_colocated_worker_raw_cls(class_dict) + + remote_cls = ray.remote(raw_colocated_worker_cls) + cia = RayClassWithInitArgs(cls=remote_cls) + cia.fused_worker_used = True + + return cia diff --git a/verl/single_controller/ray/megatron.py b/verl/single_controller/ray/megatron.py new file mode 100644 index 0000000000000000000000000000000000000000..339cf9dad0897fc2516f796cfdb41e617fd3ef09 --- /dev/null +++ b/verl/single_controller/ray/megatron.py @@ -0,0 +1,65 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Optional + +import ray + +from verl.single_controller.base.megatron.worker import DistGlobalInfo, DistRankInfo +from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup + +from .base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup + + +# NOTE(sgm): for open-source megatron-core +class NVMegatronRayWorkerGroup(RayWorkerGroup, MegatronWorkerGroup): + """ + MegatronWorkerGroup will query each worker of its megatron rank info and store it inside the WorkerGroup + so that the dispatcher can use it to dispatch data. + """ + + def __init__(self, resource_pool: RayResourcePool, ray_cls_with_init: RayClassWithInitArgs, **kwargs): + super().__init__(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init, **kwargs) + self._megatron_rank_info: DistRankInfo = self.execute_all_sync(method_name="get_megatron_rank_info") + self._megatron_global_info: DistGlobalInfo = ray.get(self.execute_rank_zero_async(method_name="get_megatron_global_info")) + + +class MegatronRayWorkerGroup(RayWorkerGroup, MegatronWorkerGroup): + """ + MegatronWorkerGroup will query each worker of its megatron rank info and store it inside the WorkerGroup + so that the dispatcher can use it to dispatch data. + """ + + def __init__( + self, + resource_pool: RayResourcePool, + ray_cls_with_init: RayClassWithInitArgs, + default_megatron_kwargs: Dict = None, + **kwargs, + ): + super().__init__( + resource_pool=resource_pool, + ray_cls_with_init=ray_cls_with_init, + default_megatron_kwargs=default_megatron_kwargs, + **kwargs, + ) + self.init_megatron(default_megatron_kwargs=default_megatron_kwargs) + self._megatron_rank_info: DistRankInfo = self.execute_all_sync(method_name="get_megatron_rank_info") + self._megatron_global_info: DistGlobalInfo = ray.get(self.execute_rank_zero_async(method_name="get_megatron_global_info")) + + def init_megatron(self, default_megatron_kwargs: Optional[Dict] = None): + # after super, we will call init of each worker + if not self._is_init_with_detached_workers: + # only init_megatron if the WorkerGroup is created from scratch + self.execute_all_sync(method_name="init_megatron", default_megatron_kwargs=default_megatron_kwargs) diff --git a/verl/third_party/__init__.py b/verl/third_party/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c6aedde1437eddbe099b730519ef79980e2e3dc4 --- /dev/null +++ b/verl/third_party/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/verl/third_party/__pycache__/__init__.cpython-311.pyc b/verl/third_party/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb2a4f553dc335e3edf637dd705af0e390c7df7c Binary files /dev/null and b/verl/third_party/__pycache__/__init__.cpython-311.pyc differ diff --git a/verl/third_party/sglang/__init__.py b/verl/third_party/sglang/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..64356bbeeece46e89490e276545769fc07614e59 --- /dev/null +++ b/verl/third_party/sglang/__init__.py @@ -0,0 +1,26 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/verl/third_party/sglang/parallel_state.py b/verl/third_party/sglang/parallel_state.py new file mode 100644 index 0000000000000000000000000000000000000000..a6bed5bfb206505a6fdf021053fd9154455be449 --- /dev/null +++ b/verl/third_party/sglang/parallel_state.py @@ -0,0 +1,322 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The SGlang team. +# Adapted from +# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +"""Model and data parallel groups.""" + +import os +from typing import Optional + +import sglang.srt.distributed.parallel_state as ps +import torch +import torch.distributed +from sglang.srt.distributed.parallel_state import ( + get_pp_group, + get_world_group, + init_distributed_environment, + init_model_parallel_group, +) + +""" +This version is strongly tied with Megatron to implement HybridEngine and weight sharing between vllm and Megatron. +- We assume the Megatron tp+dp+pp world is already established before calling this function. + +""" + +# Device mesh for using DTensor +_DEVICE_MESH = None + +# Tensor model parallel group that the current rank belongs to. +_TP = None +# Pipeline model parallel group that the current rank belongs to. +_PP = None + + +# This method is for initializing the ParallelGroup when using HybridEngine +# NOTE(linjunrong): this function is for megatron +def initialize_parallel_state( + distributed_init_method: str = "env://", + backend: str = "nccl", + tensor_model_parallel_size: int = 1, + num_tp_per_train_tp: int = 1, + pipeline_model_parallel_size: int = 1, +): + # torch.distributed.all_reduce does not free the input tensor until + # the synchronization point. This causes the memory usage to grow + # as the number of all_reduce calls increases. This env var disables + # this behavior. + # Related issue: + # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + + # NOTE(sgm): Modify for verl, Env vars will be set by TORCHRUN. + rank = int(os.getenv("RANK", "-1")) + local_rank = int(os.getenv("LOCAL_RANK", "0")) + + # Use the world_size set by TORCHRUN + world_size = int(os.getenv("WORLD_SIZE", "-1")) + assert world_size != -1, "The world_size is set to -1, not initialized by TORCHRUN" + init_distributed_environment(world_size, rank, distributed_init_method, local_rank, backend) + if torch.distributed.get_world_size() > 1: + # NOTE: build a sepearate inference group with infer tp & micro dp + initialize_model_parallel_for_sglang( + tensor_model_parallel_size=tensor_model_parallel_size, + num_tensor_model_parallel_groups_per_train_tp=num_tp_per_train_tp, + ) + else: + initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, backend) + + +# NOTE(linjunrong): After init SGLang rollout using class EngineFragment, user should always remember to call +# this function to sync the _TP, _PP define at the beginning of this file. Otherwise, only the conterparts +# inside sglang.srt.distributed are init as ProcessGroup, the symbols defined in this file remain as None. +# It could be weird to maintain two _TP and _PP, I follow the same way to maintain an extra ones for +# verl itself as how it was done in verl.third_party.vllm.parallel_state. Note that the process is a little +# bit different +def ensure_model_parallel_initialized( + tensor_model_parallel_size: int, + pipeline_model_parallel_size: int = 1, + backend: Optional[str] = None, +) -> None: + """Helper to initialize model parallel groups if they are not initialized, + or ensure tensor-parallel and pipeline-parallel sizes are equal to expected + values if the model parallel groups are initialized. + """ + # get the backend of _DEVICE_WORLD_GROUP + backend = backend or torch.distributed.get_backend(get_world_group().device_group) + if not model_parallel_is_initialized(): + initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, backend) + return + + assert get_tensor_model_parallel_world_size() == tensor_model_parallel_size, f"tensor parallel group already initialized, but of unexpected size: {get_tensor_model_parallel_world_size()=} vs. {tensor_model_parallel_size=}" + pp_world_size = get_pp_group().world_size + assert pp_world_size == pipeline_model_parallel_size, f"pipeline parallel group already initialized, but of unexpected size: {pp_world_size=} vs. {pipeline_model_parallel_size=}" + + +# TODO(sgm): deviate from the v0.5.4, not pp now +# NOTE(linjunrong): the SGLang version using _TP instead of ps._TP +def model_parallel_is_initialized(): + """Check if tensor and pipeline parallel groups are initialized.""" + return _TP is not None + # and _PIPELINE_MODEL_PARALLEL_GROUP is not None) + + +def initialize_model_parallel_for_sglang( + tensor_model_parallel_size: int, + num_tensor_model_parallel_groups_per_train_tp: int = 1, + pipeline_model_parallel_size: int = 1, +) -> None: + pass + + # Get world size and rank. Ensure some consistencies. + assert torch.distributed.is_initialized() + + assert isinstance(tensor_model_parallel_size, int) + + # assert num_tensor_model_parallel_groups_per_train_tp == 1 and not different_tp_group + # assert num_tensor_model_parallel_groups_per_train_tp > 1 and different_tp_group + + # Build the tensor model-parallel groups. + assert ps._TP is None, "tensor model parallel group is already initialized" + + global _TP + + world_size: int = torch.distributed.get_world_size() + + backend = torch.distributed.get_backend() + + num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size + + if num_tensor_model_parallel_groups_per_train_tp == 1: + # if tensor_model_parallel_size == train_tensor_parallel_size: + # using the same tp group as Megatron/vllm + assert _TP is None, "tensor model parallel group is already initialized" + group_ranks = [] + for i in range(num_tensor_model_parallel_groups): + ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size) + group_ranks.append(ranks) + _TP = init_model_parallel_group( + group_ranks=group_ranks, + local_rank=get_world_group().local_rank, + backend=backend, + use_custom_allreduce=False, # TODO: check why True is not work in Ray trainer + use_message_queue_broadcaster=True, + ) + ps._TP = _TP + # _MICRO_DATA_PARALLEL_GROUP is move to hybrid engine + else: + # initialize a micro_dp group and a tp group + # assume training tp=4, infer tp=2, then, weight is partitioned as + # [1], [2], [3], [4] for training and [1,2], [1,2], [3,4], [3,4] for inference + + # Build the inference tp groups + # train_tp = train_tensor_parallel_size + train_tp = num_tensor_model_parallel_groups_per_train_tp * tensor_model_parallel_size + # num_tensor_model_parallel_groups_per_train_tp = train_tp // tensor_model_parallel_size + assert _TP is None, "tensor model parallel group is already initialized" + group_ranks = [] + for i in range(num_tensor_model_parallel_groups // num_tensor_model_parallel_groups_per_train_tp): + start = train_tp * i + end = train_tp * (i + 1) + for j in range(num_tensor_model_parallel_groups_per_train_tp): + ranks = list(range(start, end, num_tensor_model_parallel_groups_per_train_tp)) + for i in range(len(ranks)): + ranks[i] += j + group_ranks.append(ranks) + _TP = init_model_parallel_group( + group_ranks=group_ranks, + local_rank=get_world_group().local_rank, + backend=backend, + use_custom_allreduce=False, # TODO: check why True is not work in Ray trainer + use_message_queue_broadcaster=True, + ) + ps._TP = _TP + + # Build the pipeline model-parallel groups. + # global _PIPELINE_MODEL_PARALLEL_GROUP + # global _PIPELINE_GLOBAL_RANKS + # assert ps._PIPELINE_MODEL_PARALLEL_GROUP is None, ("pipeline model parallel group is already initialized") + + # ps._PIPELINE_MODEL_PARALLEL_GROUP = mpu.get_pipeline_model_parallel_group() + # ps._PIPELINE_GLOBAL_RANKS = mpu.get_pipeline_model_parallel_ranks() + + # TODO: init using device mesh (not support hybrid engine now) + # Build the pipeline model-parallel groups. + num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size + global _PP + assert _PP is None, "pipeline model parallel group is already initialized" + group_ranks = [] + for i in range(num_pipeline_model_parallel_groups): + ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) + group_ranks.append(ranks) + # pipeline parallel does not need custom allreduce + _PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, use_custom_allreduce=False) + ps._PP = _PP # for verl + + +def initialize_model_parallel( + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + backend: Optional[str] = None, +) -> None: + """ + NOTE: This method is a hack from the open-sourced version without + asertion of world_size = tp * pp + + Initialize model parallel groups. + + Arguments: + tensor_model_parallel_size: number of GPUs used for tensor model + parallelism. + pipeline_model_parallel_size: number of GPUs used for pipeline model + parallelism. + + Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we + use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize + the model pipeline. The present function will + create 4 tensor model-parallel groups and 2 pipeline model-parallel groups: + 4 tensor model-parallel groups: + [g0, g1], [g2, g3], [g4, g5], [g6, g7] + 2 pipeline model-parallel groups: + [g0, g2, g4, g6], [g1, g3, g5, g7] + Note that for efficiency, the caller should make sure adjacent ranks + are on the same DGX box. For example if we are using 2 DGX-1 boxes + with a total of 16 GPUs, rank 0 to 7 belong to the first box and + ranks 8 to 15 belong to the second box. + """ + # Get world size and rank. Ensure some consistencies. + assert torch.distributed.is_initialized() + world_size: int = torch.distributed.get_world_size() + backend = backend or torch.distributed.get_backend(ps.get_world_group().device_group) + + # NOTE(sgm) we don't assert world_size == tp * pp + # DP is not managed by vllm but by the VeRL WorkerGroup + # if (world_size != + # tensor_model_parallel_size * pipeline_model_parallel_size): + # raise RuntimeError( + # f"world_size ({world_size}) is not equal to " + # f"tensor_model_parallel_size ({tensor_model_parallel_size}) x " + # f"pipeline_model_parallel_size ({pipeline_model_parallel_size})") + + num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size + + global _TP + assert _TP is None, "tensor model parallel group is already initialized" + group_ranks = [] + for i in range(num_tensor_model_parallel_groups): + ranks = list(range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)) + group_ranks.append(ranks) + + # message queue broadcaster is only used in tensor model parallel group + if ps._TP is not None: + _TP = ps._TP + else: + _TP = init_model_parallel_group( + group_ranks, + get_world_group().local_rank, + backend, + use_custom_allreduce=False, # TODO: check why True is not work in Ray trainer + use_message_queue_broadcaster=True, + ) + ps._TP = _TP + + # TODO: init using device mesh (not support hybrid engine now) + # Build the pipeline model-parallel groups. + num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size + global _PP + assert _PP is None, "pipeline model parallel group is already initialized" + group_ranks = [] + for i in range(num_pipeline_model_parallel_groups): + ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) + group_ranks.append(ranks) + # pipeline parallel does not need custom allreduce + if ps._TP is not None: + _PP = ps._TP + else: + _PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, use_custom_allreduce=False) + ps._PP = _PP + + +""" +Device mesh utilities +""" + + +def get_device_mesh(): + assert _DEVICE_MESH is not None, "device mesh is not initialized" + return _DEVICE_MESH + + +""" +Tensor model parallel utilities +""" + + +# NOTE(linjunrong): In the vllm version parallel_state.py. verl created its own _TP and _PP as verl want to use +# the process group for some extra purpose. Under the hood, there is no difference between them and the original +# one in vllm.distributed.parallel_state. However, the implementation need to hack the init process of inference +# engine, as we do not maintain another SGLang here, I just use the original _TP and _PP directly. +def get_tensor_model_parallel_group(): + """Get the tensor model parallel group the caller rank belongs to.""" + + assert _TP is not None, "tensor model parallel group is not initialized" + return _TP.device_group + + +def get_tensor_model_parallel_world_size(): + """Return world size for the tensor model parallel group.""" + return torch.distributed.get_world_size(group=get_tensor_model_parallel_group()) + + +def get_tensor_model_parallel_rank(): + """Return my rank for the tensor model parallel group.""" + return torch.distributed.get_rank(group=get_tensor_model_parallel_group()) + + +def get_tensor_model_parallel_src_rank(): + """Calculate the global rank corresponding to the first local rank + in the tensor model parallel group.""" + global_rank = torch.distributed.get_rank() + local_world_size = get_tensor_model_parallel_world_size() + return (global_rank // local_world_size) * local_world_size diff --git a/verl/third_party/vllm/__init__.py b/verl/third_party/vllm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c447b1775f56737b09a65ccad5c2044950ac150c --- /dev/null +++ b/verl/third_party/vllm/__init__.py @@ -0,0 +1,53 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from importlib.metadata import PackageNotFoundError, version + +from packaging import version as vs + +from verl.utils.import_utils import is_sglang_available + + +def get_version(pkg): + try: + return version(pkg) + except PackageNotFoundError: + return None + + +package_name = "vllm" +package_version = get_version(package_name) +vllm_version = None + + +if package_version == "0.5.4": + vllm_version = "0.5.4" + from .vllm_v_0_5_4 import parallel_state + from .vllm_v_0_5_4.llm import LLM, LLMEngine +elif package_version == "0.6.3" or package_version.startswith("0.6.3"): + # rocm version: "0.6.3+rocmxxx" + vllm_version = "0.6.3" + from .vllm_v_0_6_3 import parallel_state + from .vllm_v_0_6_3.llm import LLM, LLMEngine +elif vs.parse(package_version) >= vs.parse("0.7.0"): + # From 0.6.6.post2 on, vllm supports SPMD inference + # See https://github.com/vllm-project/vllm/pull/12071 + + from vllm import LLM + from vllm.distributed import parallel_state +else: + if not is_sglang_available(): + raise ValueError(f"vllm version {package_version} not supported and SGLang also not Found. Currently supported vllm versions are 0.6.3 and 0.7.0+") + +__all__ = ["LLM", "LLMEngine", "parallel_state"] diff --git a/verl/third_party/vllm/__pycache__/__init__.cpython-311.pyc b/verl/third_party/vllm/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a70884593cb8a01bbef23add282ff9bd804b4692 Binary files /dev/null and b/verl/third_party/vllm/__pycache__/__init__.cpython-311.pyc differ diff --git a/verl/third_party/vllm/vllm_v_0_5_4/__init__.py b/verl/third_party/vllm/vllm_v_0_5_4/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c6aedde1437eddbe099b730519ef79980e2e3dc4 --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_5_4/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/verl/third_party/vllm/vllm_v_0_5_4/arg_utils.py b/verl/third_party/vllm/vllm_v_0_5_4/arg_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d83bbd7e784618ee49c66ffdeee581595177e984 --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_5_4/arg_utils.py @@ -0,0 +1,447 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/engine/arg_utils.py + +import argparse +import dataclasses +import os +from dataclasses import dataclass +from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union + +from transformers import PretrainedConfig +from vllm.config import ( + CacheConfig, + DecodingConfig, + DeviceConfig, + EngineConfig, + LoRAConfig, + MultiModalConfig, + ObservabilityConfig, + ParallelConfig, + PromptAdapterConfig, + SchedulerConfig, + SpeculativeConfig, + TokenizerPoolConfig, +) +from vllm.executor.executor_base import ExecutorBase +from vllm.logger import init_logger + +from .config import LoadConfig, ModelConfig + +if TYPE_CHECKING: + from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import BaseTokenizerGroup + +logger = init_logger(__name__) + + +def nullable_str(val: str): + if not val or val == "None": + return None + return val + + +@dataclass +class EngineArgs: + """Arguments for vLLM engine.""" + + model_hf_config: PretrainedConfig = None # for verl + served_model_name = None # TODO(sgm): check this + # tokenizer: Optional[str] = None # TODO(sgm): check this + skip_tokenizer_init: bool = False + tokenizer_mode: str = "auto" + trust_remote_code: bool = False + download_dir: Optional[str] = None + load_format: str = "auto" + dtype: str = "auto" + kv_cache_dtype: str = "auto" + quantization_param_path: Optional[str] = None + seed: int = 0 + max_model_len: Optional[int] = None + worker_use_ray: bool = False + # Note: Specifying a custom executor backend by passing a class + # is intended for expert use only. The API may change without + # notice. + distributed_executor_backend: Optional[Union[str, Type[ExecutorBase]]] = None + pipeline_parallel_size: int = 1 + tensor_parallel_size: int = 1 + max_parallel_loading_workers: Optional[int] = None + block_size: int = 16 + enable_prefix_caching: bool = False + disable_sliding_window: bool = False + use_v2_block_manager: bool = False + swap_space: int = 4 # GiB + cpu_offload_gb: int = 0 # GiB + gpu_memory_utilization: float = 0.90 + max_num_batched_tokens: Optional[int] = None + max_num_seqs: int = 256 + max_logprobs: int = 20 # Default value for OpenAI Chat Completions API + disable_log_stats: bool = False + revision: Optional[str] = None + code_revision: Optional[str] = None + rope_scaling: Optional[dict] = None + rope_theta: Optional[float] = None + tokenizer_revision: Optional[str] = None + quantization: Optional[str] = None + enforce_eager: bool = False + max_context_len_to_capture: Optional[int] = None + max_seq_len_to_capture: int = 8192 + disable_custom_all_reduce: bool = False + tokenizer_pool_size: int = 0 + # Note: Specifying a tokenizer pool by passing a class + # is intended for expert use only. The API may change without + # notice. + tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray" + tokenizer_pool_extra_config: Optional[dict] = None + enable_lora: bool = False + max_loras: int = 1 + max_lora_rank: int = 16 + enable_prompt_adapter: bool = False + max_prompt_adapters: int = 1 + max_prompt_adapter_token: int = 0 + fully_sharded_loras: bool = False + lora_extra_vocab_size: int = 256 + long_lora_scaling_factors: Optional[Tuple[float]] = None + lora_dtype: str = "auto" + max_cpu_loras: Optional[int] = None + device: str = "auto" + ray_workers_use_nsight: bool = False + num_gpu_blocks_override: Optional[int] = None + num_lookahead_slots: int = 0 + model_loader_extra_config: Optional[dict] = None + ignore_patterns: Optional[Union[str, List[str]]] = None + preemption_mode: Optional[str] = None + + scheduler_delay_factor: float = 0.0 + enable_chunked_prefill: Optional[bool] = None + + guided_decoding_backend: str = "outlines" + # Speculative decoding configuration. + speculative_model: Optional[str] = None + speculative_draft_tensor_parallel_size: Optional[int] = None + num_speculative_tokens: Optional[int] = None + speculative_max_model_len: Optional[int] = None + speculative_disable_by_batch_size: Optional[int] = None + ngram_prompt_lookup_max: Optional[int] = None + ngram_prompt_lookup_min: Optional[int] = None + spec_decoding_acceptance_method: str = "rejection_sampler" + typical_acceptance_sampler_posterior_threshold: Optional[float] = None + typical_acceptance_sampler_posterior_alpha: Optional[float] = None + qlora_adapter_name_or_path: Optional[str] = None + disable_logprobs_during_spec_decoding: Optional[bool] = None + + otlp_traces_endpoint: Optional[str] = None + + @staticmethod + def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + """Shared CLI arguments for vLLM engine.""" + # Model arguments + # TODO(shengguangming): delete the unused args + parser.add_argument("--model", type=str, default="facebook/opt-125m", help="name or path of the huggingface model to use") + parser.add_argument( + "--tokenizer", + type=str, + default=EngineArgs.tokenizer, + help="name or path of the huggingface tokenizer to use", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + help="the specific model version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version.", + ) + parser.add_argument( + "--tokenizer-revision", + type=str, + default=None, + help="the specific tokenizer version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version.", + ) + parser.add_argument( + "--tokenizer-mode", + type=str, + default=EngineArgs.tokenizer_mode, + choices=["auto", "slow"], + help='tokenizer mode. "auto" will use the fast tokenizer if available, and "slow" will always use the slow tokenizer.', + ) + parser.add_argument("--trust-remote-code", action="store_true", help="trust remote code from huggingface") + parser.add_argument( + "--download-dir", + type=str, + default=EngineArgs.download_dir, + help="directory to download and load the weights, default to the default cache dir of huggingface", + ) + parser.add_argument( + "--load-format", + type=str, + default=EngineArgs.load_format, + choices=["auto", "pt", "safetensors", "npcache", "dummy"], + help="The format of the model weights to load. " + '"auto" will try to load the weights in the safetensors format ' + "and fall back to the pytorch bin format if safetensors format " + "is not available. " + '"pt" will load the weights in the pytorch bin format. ' + '"safetensors" will load the weights in the safetensors format. ' + '"npcache" will load the weights in pytorch format and store ' + "a numpy cache to speed up the loading. " + '"dummy" will initialize the weights with random values, ' + "which is mainly for profiling.", + ) + parser.add_argument( + "--dtype", + type=str, + default=EngineArgs.dtype, + choices=["auto", "half", "float16", "bfloat16", "float", "float32"], + help='data type for model weights and activations. The "auto" option will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models.', + ) + parser.add_argument( + "--max-model-len", + type=int, + default=None, + help="model context length. If unspecified, will be automatically derived from the model.", + ) + # Parallel arguments + parser.add_argument( + "--worker-use-ray", + action="store_true", + help="use Ray for distributed serving, will be automatically set when using more than 1 GPU", + ) + parser.add_argument( + "--pipeline-parallel-size", + "-pp", + type=int, + default=EngineArgs.pipeline_parallel_size, + help="number of pipeline stages", + ) + parser.add_argument( + "--tensor-parallel-size", + "-tp", + type=int, + default=EngineArgs.tensor_parallel_size, + help="number of tensor parallel replicas", + ) + # KV cache arguments + parser.add_argument("--block-size", type=int, default=EngineArgs.block_size, choices=[8, 16, 32], help="token block size") + # TODO(woosuk): Support fine-grained seeds (e.g., seed per request). + parser.add_argument("--seed", type=int, default=EngineArgs.seed, help="random seed") + parser.add_argument("--swap-space", type=int, default=EngineArgs.swap_space, help="CPU swap space size (GiB) per GPU") + parser.add_argument( + "--gpu-memory-utilization", + type=float, + default=EngineArgs.gpu_memory_utilization, + help="the percentage of GPU memory to be used forthe model executor", + ) + parser.add_argument( + "--max-num-batched-tokens", + type=int, + default=EngineArgs.max_num_batched_tokens, + help="maximum number of batched tokens per iteration", + ) + parser.add_argument( + "--max-num-seqs", + type=int, + default=EngineArgs.max_num_seqs, + help="maximum number of sequences per iteration", + ) + parser.add_argument("--disable-log-stats", action="store_true", help="disable logging statistics") + # Quantization settings. + parser.add_argument( + "--quantization", + "-q", + type=str, + choices=["awq", None], + default=None, + help="Method used to quantize the weights", + ) + return parser + + @classmethod + def from_cli_args(cls, args: argparse.Namespace) -> "EngineArgs": + # Get the list of attributes of this dataclass. + attrs = [attr.name for attr in dataclasses.fields(cls)] + # Set the attributes from the parsed arguments. + engine_args = cls(**{attr: getattr(args, attr) for attr in attrs}) + return engine_args + + def create_engine_config( + self, + ) -> EngineConfig: + # bitsandbytes quantization needs a specific model loader + # so we make sure the quant method and the load format are consistent + if (self.quantization == "bitsandbytes" or self.qlora_adapter_name_or_path is not None) and self.load_format != "bitsandbytes": + raise ValueError(f"BitsAndBytes quantization and QLoRA adapter only support 'bitsandbytes' load format, but got {self.load_format}") + + if (self.load_format == "bitsandbytes" or self.qlora_adapter_name_or_path is not None) and self.quantization != "bitsandbytes": + raise ValueError(f"BitsAndBytes load format and QLoRA adapter only support 'bitsandbytes' quantization, but got {self.quantization}") + + assert self.cpu_offload_gb >= 0, f"CPU offload space must be non-negative, but got {self.cpu_offload_gb}" + + multimodal_config = MultiModalConfig() + device_config = DeviceConfig(self.device) + # NOTE(sgm): we only modify ModelConfig, other configs are import from vllm + model_config = ModelConfig( + hf_config=self.model_hf_config, + tokenizer_mode=self.tokenizer_mode, + trust_remote_code=self.trust_remote_code, + dtype=self.dtype, + seed=self.seed, + revision=self.revision, + code_revision=self.code_revision, + rope_scaling=self.rope_scaling, + rope_theta=self.rope_theta, + tokenizer_revision=self.tokenizer_revision, + max_model_len=self.max_model_len, + quantization=self.quantization, + quantization_param_path=self.quantization_param_path, + enforce_eager=self.enforce_eager, + max_context_len_to_capture=self.max_context_len_to_capture, + max_seq_len_to_capture=self.max_seq_len_to_capture, + max_logprobs=self.max_logprobs, + disable_sliding_window=self.disable_sliding_window, + skip_tokenizer_init=self.skip_tokenizer_init, + served_model_name=self.served_model_name, + multimodal_config=multimodal_config, + ) + cache_config = CacheConfig( + block_size=self.block_size, + gpu_memory_utilization=self.gpu_memory_utilization, + swap_space=self.swap_space, + cache_dtype=self.kv_cache_dtype, + num_gpu_blocks_override=self.num_gpu_blocks_override, + sliding_window=model_config.get_sliding_window(), + enable_prefix_caching=self.enable_prefix_caching, + cpu_offload_gb=self.cpu_offload_gb, + ) + parallel_config = ParallelConfig( + pipeline_parallel_size=self.pipeline_parallel_size, + tensor_parallel_size=self.tensor_parallel_size, + worker_use_ray=self.worker_use_ray, + max_parallel_loading_workers=self.max_parallel_loading_workers, + disable_custom_all_reduce=self.disable_custom_all_reduce, + tokenizer_pool_config=TokenizerPoolConfig.create_config( + self.tokenizer_pool_size, + self.tokenizer_pool_type, + self.tokenizer_pool_extra_config, + ), + ray_workers_use_nsight=self.ray_workers_use_nsight, + distributed_executor_backend=self.distributed_executor_backend, + ) + + # NOTE[VERL]: Use the world_size set by TORCHRUN + world_size = int(os.getenv("WORLD_SIZE", "-1")) + assert world_size != -1, "The world_size is set to -1, not initialized by TORCHRUN" + parallel_config.world_size = world_size + + max_model_len = model_config.max_model_len + use_long_context = max_model_len > 32768 + if self.enable_chunked_prefill is None: + # If not explicitly set, enable chunked prefill by default for + # long context (> 32K) models. This is to avoid OOM errors in the + # initial memory profiling phase. + if use_long_context: + is_gpu = device_config.device_type == "cuda" + use_sliding_window = model_config.get_sliding_window() is not None + use_spec_decode = self.speculative_model is not None + has_seqlen_agnostic_layers = model_config.contains_seqlen_agnostic_layers(parallel_config) + if is_gpu and not use_sliding_window and not use_spec_decode and not self.enable_lora and not self.enable_prompt_adapter and not self.enable_prefix_caching and not has_seqlen_agnostic_layers: + self.enable_chunked_prefill = True + logger.warning("Chunked prefill is enabled by default for models with max_model_len > 32K. Currently, chunked prefill might not work with some features or models. If you encounter any issues, please disable chunked prefill by setting --enable-chunked-prefill=False.") + if self.enable_chunked_prefill is None: + self.enable_chunked_prefill = False + + if not self.enable_chunked_prefill and use_long_context: + logger.warning( + "The model has a long context length (%s). This may cause OOM errors during the initial memory profiling phase, or result in low performance due to small KV cache space. Consider setting --max-model-len to a smaller value.", + max_model_len, + ) + + # TODO: spec config + speculative_config = SpeculativeConfig.maybe_create_spec_config( + target_model_config=model_config, + target_parallel_config=parallel_config, + target_dtype=self.dtype, + speculative_model=self.speculative_model, + speculative_draft_tensor_parallel_size=self.speculative_draft_tensor_parallel_size, + num_speculative_tokens=self.num_speculative_tokens, + speculative_disable_by_batch_size=self.speculative_disable_by_batch_size, + speculative_max_model_len=self.speculative_max_model_len, + enable_chunked_prefill=self.enable_chunked_prefill, + use_v2_block_manager=self.use_v2_block_manager, + disable_log_stats=self.disable_log_stats, + ngram_prompt_lookup_max=self.ngram_prompt_lookup_max, + ngram_prompt_lookup_min=self.ngram_prompt_lookup_min, + draft_token_acceptance_method=self.spec_decoding_acceptance_method, + typical_acceptance_sampler_posterior_threshold=self.typical_acceptance_sampler_posterior_threshold, + typical_acceptance_sampler_posterior_alpha=self.typical_acceptance_sampler_posterior_alpha, + disable_logprobs=self.disable_logprobs_during_spec_decoding, + ) + + scheduler_config = SchedulerConfig( + max_num_batched_tokens=self.max_num_batched_tokens, + max_num_seqs=self.max_num_seqs, + max_model_len=model_config.max_model_len, + use_v2_block_manager=self.use_v2_block_manager, + num_lookahead_slots=(self.num_lookahead_slots if speculative_config is None else speculative_config.num_lookahead_slots), + delay_factor=self.scheduler_delay_factor, + enable_chunked_prefill=self.enable_chunked_prefill, + embedding_mode=model_config.embedding_mode, + preemption_mode=self.preemption_mode, + ) + lora_config = ( + LoRAConfig( + max_lora_rank=self.max_lora_rank, + max_loras=self.max_loras, + fully_sharded_loras=self.fully_sharded_loras, + lora_extra_vocab_size=self.lora_extra_vocab_size, + long_lora_scaling_factors=self.long_lora_scaling_factors, + lora_dtype=self.lora_dtype, + max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras and self.max_cpu_loras > 0 else None, + ) + if self.enable_lora + else None + ) + + if self.qlora_adapter_name_or_path is not None and self.qlora_adapter_name_or_path != "": + if self.model_loader_extra_config is None: + self.model_loader_extra_config = {} + self.model_loader_extra_config["qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path + + load_config = LoadConfig( + load_format=self.load_format, + download_dir=self.download_dir, + model_loader_extra_config=self.model_loader_extra_config, + ignore_patterns=self.ignore_patterns, + ) + + prompt_adapter_config = PromptAdapterConfig(max_prompt_adapters=self.max_prompt_adapters, max_prompt_adapter_token=self.max_prompt_adapter_token) if self.enable_prompt_adapter else None + + decoding_config = DecodingConfig(guided_decoding_backend=self.guided_decoding_backend) + + observability_config = ObservabilityConfig(otlp_traces_endpoint=self.otlp_traces_endpoint) + + if model_config.get_sliding_window() is not None and scheduler_config.chunked_prefill_enabled and not scheduler_config.use_v2_block_manager: + raise ValueError("Chunked prefill is not supported with sliding window. Set --disable-sliding-window to disable sliding window.") + + return EngineConfig( + model_config=model_config, + cache_config=cache_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + device_config=device_config, + lora_config=lora_config, + multimodal_config=multimodal_config, + speculative_config=speculative_config, + load_config=load_config, + decoding_config=decoding_config, + observability_config=observability_config, + prompt_adapter_config=prompt_adapter_config, + ) diff --git a/verl/third_party/vllm/vllm_v_0_5_4/config.py b/verl/third_party/vllm/vllm_v_0_5_4/config.py new file mode 100644 index 0000000000000000000000000000000000000000..62922a4ba20a49c42a01986cc26e1e97cd9c30dd --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_5_4/config.py @@ -0,0 +1,247 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py + +import enum +import json +from dataclasses import dataclass, field +from typing import List, Optional, Union + +import torch +from transformers import PretrainedConfig + +# Add for verl +from vllm.config import ( + ModelConfig, + MultiModalConfig, + _get_and_verify_dtype, + _get_and_verify_max_len, + get_served_model_name, +) +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization import get_quantization_config +from vllm.model_executor.model_loader import BaseModelLoader +from vllm.transformers_utils.config import get_hf_text_config +from vllm.utils import is_hip, print_warning_once + +GPTQMarlinConfig = get_quantization_config("gptq_marlin") + +logger = init_logger(__name__) + +_GB = 1 << 30 + + +class ModelConfig(ModelConfig): + """Configuration for the model. + + Args: + model: Name or path of the huggingface model to use. + tokenizer: Name or path of the huggingface tokenizer to use. + tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if + available, and "slow" will always use the slow tokenizer. + trust_remote_code: Trust remote code (e.g., from HuggingFace) when + downloading the model and tokenizer. + download_dir: Directory to download and load the weights, default to the + default cache directory of huggingface. + load_format: The format of the model weights to load: + "auto" will try to load the weights in the safetensors format and + fall back to the pytorch bin format if safetensors format is + not available. + "pt" will load the weights in the pytorch bin format. + "safetensors" will load the weights in the safetensors format. + "npcache" will load the weights in pytorch format and store + a numpy cache to speed up the loading. + "dummy" will initialize the weights with random values, which is + mainly for profiling. + dtype: Data type for model weights and activations. The "auto" option + will use FP16 precision for FP32 and FP16 models, and BF16 precision + for BF16 models. + seed: Random seed for reproducibility. + revision: The specific model version to use. It can be a branch name, + a tag name, or a commit id. If unspecified, will use the default + version. + code_revision: The specific revision to use for the model code on + Hugging Face Hub. It can be a branch name, a tag name, or a + commit id. If unspecified, will use the default version. + tokenizer_revision: The specific tokenizer version to use. It can be a + branch name, a tag name, or a commit id. If unspecified, will use + the default version. + max_model_len: Maximum length of a sequence (including prompt and + output). If None, will be derived from the model. + quantization: Quantization method that was used to quantize the model + weights. If None, we assume the model weights are not quantized. + quantization_param_path: Path to JSON file containing scaling factors. + Used to load KV cache scaling factors into the model when KV cache + type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also + be used to load activation and weight scaling factors when the + model dtype is FP8_E4M3 on ROCm. + enforce_eager: Whether to enforce eager execution. If True, we will + disable CUDA graph and always execute the model in eager mode. + If False, we will use CUDA graph and eager execution in hybrid. + max_context_len_to_capture: Maximum context len covered by CUDA graphs. + When a sequence has context length larger than this, we fall back + to eager mode (DEPRECATED. Use max_seq_len_to_capture instead). + max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs. + When a sequence has context length larger than this, we fall back + to eager mode + skip_tokenizer_init: If true, skip initialization of tokenizer and + detokenizer. + served_model_name: The model name used in metrics tag `model_name`, + matches the model name exposed via the APIs. If multiple model + names provided, the first name will be used. If not specified, + the model name will be the same as `model`. + """ + + def __init__( + self, + hf_config: PretrainedConfig, + tokenizer_mode: str, + trust_remote_code: bool, + dtype: Union[str, torch.dtype], + seed: int, + revision: Optional[str] = None, + code_revision: Optional[str] = None, + rope_scaling: Optional[dict] = None, + rope_theta: Optional[float] = None, + tokenizer_revision: Optional[str] = None, + max_model_len: Optional[int] = None, + quantization: Optional[str] = None, + quantization_param_path: Optional[str] = None, + enforce_eager: bool = False, + max_context_len_to_capture: Optional[int] = None, + max_seq_len_to_capture: Optional[int] = None, + max_logprobs: int = 20, + disable_sliding_window: bool = False, + skip_tokenizer_init: bool = False, + served_model_name: Optional[Union[str, List[str]]] = None, + multimodal_config: Optional[MultiModalConfig] = None, + ) -> None: + self.model = hf_config._name_or_path + self.tokenizer = hf_config._name_or_path + # NOTE(sgm): same as open-sourced + self.tokenizer_mode = tokenizer_mode + self.trust_remote_code = trust_remote_code + self.seed = seed + self.revision = revision + self.code_revision = code_revision + self.rope_scaling = rope_scaling + self.rope_theta = rope_theta + # The tokenizer version is consistent with the model version by default. + if tokenizer_revision is None: + self.tokenizer_revision = revision + else: + self.tokenizer_revision = tokenizer_revision + self.quantization = quantization + self.quantization_param_path = quantization_param_path + self.enforce_eager = enforce_eager + if max_context_len_to_capture is not None: + raise ValueError("`max_context_len_to_capture` is deprecated. Use `max_seq_len_to_capture` instead.") + self.max_seq_len_to_capture = max_seq_len_to_capture + self.max_logprobs = max_logprobs + self.disable_sliding_window = disable_sliding_window + self.skip_tokenizer_init = skip_tokenizer_init + + # self.hf_config = get_config(model, trust_remote_code, revision) + self.hf_config = hf_config + self.hf_text_config = get_hf_text_config(hf_config) + self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) + # self.served_model_name = get_served_model_name(model, + # served_model_name) + # self._verify_load_format() + # self._verify_tokenizer_mode() + if not self.disable_sliding_window and self.hf_text_config.model_type == "gemma2" and self.hf_text_config.sliding_window is not None: + print_warning_once(f"Gemma 2 uses sliding window attention for every odd layer, which is currently not supported by vLLM. Disabling sliding window and capping the max length to the sliding window size ({self.hf_text_config.sliding_window}).") + self.disable_sliding_window = True + + self.max_model_len = _get_and_verify_max_len( + hf_config=self.hf_text_config, + max_model_len=max_model_len, + disable_sliding_window=self.disable_sliding_window, + sliding_window_len=self.get_hf_config_sliding_window(), + ) + self.served_model_name = get_served_model_name( + self.model, # str + served_model_name, + ) + self.multimodal_config = multimodal_config + + if not self.skip_tokenizer_init: + self._verify_tokenizer_mode() + self._verify_embedding_mode() + self._verify_quantization() + self._verify_cuda_graph() + + +class LoadFormat(str, enum.Enum): + AUTO = "auto" + MEGATRON = "megatron" + HF = "hf" + DTENSOR = "dtensor" + DUMMY_HF = "dummy_hf" + DUMMY_MEGATRON = "dummy_megatron" + DUMMY_DTENSOR = "dummy_dtensor" + + +# TODO: check whether this is necessary +@dataclass +class LoadConfig: + """ + download_dir: Directory to download and load the weights, default to the + default cache directory of huggingface. + load_format: The format of the model weights to load: + "auto" will try to load the weights in the safetensors format and + fall back to the pytorch bin format if safetensors format is + not available. + "pt" will load the weights in the pytorch bin format. + "safetensors" will load the weights in the safetensors format. + "npcache" will load the weights in pytorch format and store + a numpy cache to speed up the loading. + "dummy" will initialize the weights with random values, which is + mainly for profiling. + "tensorizer" will use CoreWeave's tensorizer library for + fast weight loading. + "bitsandbytes" will load nf4 type weights. + ignore_patterns: The list of patterns to ignore when loading the model. + Default to "original/**/*" to avoid repeated loading of llama's + checkpoints. + + """ + + load_format: Union[str, LoadFormat, BaseModelLoader] = LoadFormat.AUTO + download_dir: Optional[str] = None + model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict) + ignore_patterns: Optional[Union[List[str], str]] = None + + def __post_init__(self): + model_loader_extra_config = self.model_loader_extra_config or {} + if isinstance(model_loader_extra_config, str): + self.model_loader_extra_config = json.loads(model_loader_extra_config) + self._verify_load_format() + + if self.ignore_patterns is not None and len(self.ignore_patterns) > 0: + logger.info("Ignoring the following patterns when downloading weights: %s", self.ignore_patterns) + else: + self.ignore_patterns = ["original/**/*"] + + def _verify_load_format(self) -> None: + if not isinstance(self.load_format, str): + return + + load_format = self.load_format.lower() + self.load_format = LoadFormat(load_format) + + rocm_not_supported_load_format: List[str] = [] + if is_hip() and load_format in rocm_not_supported_load_format: + rocm_supported_load_format = [f for f in LoadFormat.__members__ if (f not in rocm_not_supported_load_format)] + raise ValueError(f"load format '{load_format}' is not supported in ROCm. Supported load formats are {rocm_supported_load_format}") diff --git a/verl/third_party/vllm/vllm_v_0_5_4/dtensor_weight_loaders.py b/verl/third_party/vllm/vllm_v_0_5_4/dtensor_weight_loaders.py new file mode 100644 index 0000000000000000000000000000000000000000..65ba3bf4f3ae6a72de2afb7988d2f70aaedf6415 --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_5_4/dtensor_weight_loaders.py @@ -0,0 +1,337 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models + +from typing import Dict + +import torch.nn as nn +from torch.distributed._tensor import DTensor +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.linear import * +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.utils import is_pp_missing_parameter + + +def gemma_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(vllm_model.named_parameters()) + for name, loaded_weight in actor_weights.items(): + for param_name, shard_name, shard_id in stacked_params_mapping: + if shard_name not in name: + continue + stacked_name = name.replace(shard_name, param_name) + # Skip loading extra bias for GPTQ models. + if stacked_name.endswith(".bias") and stacked_name not in params_dict: + continue + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + param = params_dict[stacked_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id) + break + else: + # lm_head is not used in vllm as it is tied with embed_token. + # To prevent errors, skip loading lm_head.weight. + if "lm_head.weight" in name: + continue + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight.to(dtype=param.dtype)) + + +def gptbigcode_dtensor_load_weights(actor_weights: Dict, vllm_model: nn.Module): + params_dict = dict(vllm_model.named_parameters(remove_duplicate=False)) + for name, loaded_weight in actor_weights.items(): + if "lm_head.weight" in name: + continue + if ".attn.bias" in name: + # Skip attention mask. + # NOTE: "c_attn.bias" should not be skipped. + continue + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight.to(dtype=param.dtype)) + + +def starcoder2_dtensor_load_weights(actor_weights: Dict, vllm_model: nn.Module): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + params_dict = dict(vllm_model.named_parameters(remove_duplicate=False)) + for name, loaded_weight in actor_weights.items(): + if "rotary_emb.inv_freq" in name: + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id) + break + else: + if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name: + continue + param = params_dict[name] + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight.to(dtype=param.dtype)) + + +def llama_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(vllm_model.named_parameters()) + for name, loaded_weight in actor_weights.items(): + if "rotary_emb.inv_freq" in name: + continue + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + # With tie_word_embeddings, we can skip lm_head.weight + # The weight might appear unnecessarily in the files if the model is + # processed with quantization, LoRA, fine-tuning, etc. + if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight) + + +def qwen2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(vllm_model.named_parameters(remove_duplicate=False)) + for name, loaded_weight in actor_weights.items(): + if "rotary_emb.inv_freq" in name: + continue + if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight.to(dtype=param.dtype)) + + +def deepseekv2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=vllm_model.config.n_routed_experts, + ) + + params_dict = dict(vllm_model.named_parameters(remove_duplicate=False)) + for name, loaded_weight in actor_weights.items(): + if "rotary_emb.inv_freq" in name: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if ("mlp.experts." in name) and name not in params_dict: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, vllm_model): + continue + + param = params_dict[name] + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name, vllm_model): + continue + + param = params_dict[name] + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader( + param, + local_loaded_weight.to(dtype=param.dtype), + weight_name, + shard_id=shard_id, + expert_id=expert_id, + ) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, vllm_model): + continue + + param = params_dict[name] + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight.to(dtype=param.dtype)) + + +def gpt2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + pass + + +def redistribute_dtensor(param_name: str, loaded_weights: DTensor, parallelize_plan: Dict = None): + param_name = _process_parameter_names(name=param_name) + if parallelize_plan is not None: + assert param_name in parallelize_plan, f"param name: {param_name} not in parallelize_plan :{parallelize_plan.keys()}" + placement = parallelize_plan[param_name] + local_loaded_weights = loaded_weights.redistribute(device_mesh=loaded_weights.device_mesh, placements=placement).to_local() + else: + local_loaded_weights = loaded_weights.full_tensor() + return local_loaded_weights + + +def _process_parameter_names(name): + # Remove '.weight' if it exists at the end of the string + if name.endswith(".weight"): + name = name[:-7] + + # Remove 'model.layers.x.' or 'model.' prefix + if "model.layers" in name: + parts = name.split(".") + # Reconstruct the string without 'model.layers.x.' + name = ".".join(parts[3:]) # parts[0] is 'model', parts[1] is 'layers', parts[2] is 'x' + elif name.startswith("model."): + name = name[6:] # Remove 'model.' + + return name + + +__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__ = { + "GPT2LMHeadModel": gpt2_dtensor_weight_loader, + "LlamaForCausalLM": llama_dtensor_weight_loader, + "LLaMAForCausalLM": llama_dtensor_weight_loader, + "MistralForCausalLM": llama_dtensor_weight_loader, # mistral is the same as llama in vLLM + "InternLMForCausalLM": llama_dtensor_weight_loader, + "AquilaModel": llama_dtensor_weight_loader, + "AquilaForCausalLM": llama_dtensor_weight_loader, + "Phi3ForCausalLM": llama_dtensor_weight_loader, + "GemmaForCausalLM": gemma_dtensor_weight_loader, + "Gemma2ForCausalLM": gemma_dtensor_weight_loader, + "GPTBigCodeForCausalLM": gptbigcode_dtensor_load_weights, + "Starcoder2ForCausalLM": starcoder2_dtensor_load_weights, + "Qwen2ForCausalLM": qwen2_dtensor_weight_loader, + "DeepseekV2ForCausalLM": deepseekv2_dtensor_weight_loader, +} + + +# the actor model is .state_dict() +# Load dtensor weights +def load_dtensor_weights(actor_weights: Dict, vllm_model: nn.Module): + weight_loader = _get_model_weight_loader(vllm_model.__class__.__name__) + weight_loader(actor_weights, vllm_model) + # NOTE(sgm) to reduce peak memory usage, we offload vllm model to cpu + # after init, and we need this after sync model weights for in first iter. + vllm_model = vllm_model.cuda() + + +def _get_model_weight_loader(arch: str): + if arch in __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__: + return __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__[arch] + raise ValueError(f"Model architectures {arch} are not supported for now. Supported architectures: {__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__.keys()}") + + +# NOTE(sgm): we use per-parameter weight loader in each vllm sub +def update_dtensor_weight_loader(): + pass diff --git a/verl/third_party/vllm/vllm_v_0_5_4/hf_weight_loader.py b/verl/third_party/vllm/vllm_v_0_5_4/hf_weight_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..273d4ea32ce456cdf512fd8252ff350ab7e97906 --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_5_4/hf_weight_loader.py @@ -0,0 +1,41 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models + +from typing import Dict + +import torch.nn as nn +from vllm.model_executor.model_loader.utils import set_default_torch_dtype + + +def update_hf_weight_loader(): + print("no hf weight loader need to be updated") + return + + +def load_hf_weights(actor_weights: Dict, vllm_model: nn.Module): + assert isinstance(actor_weights, Dict) + with set_default_torch_dtype(next(vllm_model.parameters()).dtype): # TODO + if vllm_model.config.tie_word_embeddings and "lm_head.weight" in actor_weights: + del actor_weights["lm_head.weight"] + vllm_model.load_weights(actor_weights.items()) + for _, module in vllm_model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + quant_method.process_weights_after_loading(module) + # FIXME: Remove this after Mixtral is updated + # to use quant_method. + if hasattr(module, "process_weights_after_loading"): + module.process_weights_after_loading() + vllm_model = vllm_model.cuda() diff --git a/verl/third_party/vllm/vllm_v_0_5_4/llm.py b/verl/third_party/vllm/vllm_v_0_5_4/llm.py new file mode 100644 index 0000000000000000000000000000000000000000..5a7afdb6fb47dbb2aecb46e7f870f6bbc3ca70fa --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_5_4/llm.py @@ -0,0 +1,224 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py + +from typing import Dict, Iterable, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch.nn.utils.rnn import pad_sequence +from tqdm import tqdm +from transformers import PretrainedConfig, PreTrainedTokenizer, PreTrainedTokenizerFast +from vllm import LLM +from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.utils import Counter + +from verl.workers.rollout.tokenizer import HybridEngineBaseTokenizer + +from .arg_utils import EngineArgs +from .llm_engine_sp import LLMEngine + + +class LLM(LLM): + """An LLM for generating texts from given prompts and sampling parameters. + + This class includes a tokenizer, a language model (possibly distributed + across multiple GPUs), and GPU memory space allocated for intermediate + states (aka KV cache). Given a batch of prompts and sampling parameters, + this class generates texts from the model, using an intelligent batching + mechanism and efficient memory management. + + NOTE: This class is intended to be used for offline inference. For online + serving, use the `AsyncLLMEngine` class instead. + NOTE: For the comprehensive list of arguments, see `EngineArgs`. + + Args: + model: A HuggingFace Transformers model instance. + tokenizer: A HuggingFace Transformers tokenizer instance. + tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer + if available, and "slow" will always use the slow tokenizer. + trust_remote_code: Trust remote code (e.g., from HuggingFace) when + downloading the model and tokenizer. + tensor_parallel_size: The number of GPUs to use for distributed + execution with tensor parallelism. + dtype: The data type for the model weights and activations. Currently, + we support `float32`, `float16`, and `bfloat16`. If `auto`, we use + the `torch_dtype` attribute specified in the model config file. + However, if the `torch_dtype` in the config is `float32`, we will + use `float16` instead. + quantization: The method used to quantize the model weights. Currently, + we support "awq". If None, we assume the model weights are not + quantized and use `dtype` to determine the data type of the weights. + revision: The specific model version to use. It can be a branch name, + a tag name, or a commit id. + tokenizer_revision: The specific tokenizer version to use. It can be a + branch name, a tag name, or a commit id. + seed: The seed to initialize the random number generator for sampling. + gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to + reserve for the model weights, activations, and KV cache. Higher + values will increase the KV cache size and thus improve the model's + throughput. However, if the value is too high, it may cause out-of- + memory (OOM) errors. + swap_space: The size (GiB) of CPU memory per GPU to use as swap space. + This can be used for temporarily storing the states of the requests + when their `best_of` sampling parameters are larger than 1. If all + requests will have `best_of=1`, you can safely set this to 0. + Otherwise, too small values may cause out-of-memory (OOM) errors. + enforce_eager: Whether to enforce eager execution. If True, we will + disable CUDA graph and always execute the model in eager mode. + If False, we will use CUDA graph and eager execution in hybrid. + max_context_len_to_capture: Maximum context len covered by CUDA graphs. + When a sequence has context length larger than this, we fall back + to eager mode. + disable_custom_all_reduce: See ParallelConfig + """ + + def __init__( + self, + model: Union[nn.Module, Dict], # model itself or its parameter dict + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast, HybridEngineBaseTokenizer], + model_hf_config: PretrainedConfig, + tokenizer_mode: str = "auto", + trust_remote_code: bool = False, + skip_tokenizer_init: bool = False, + tensor_parallel_size: int = 1, + dtype: str = "auto", + quantization: Optional[str] = None, + revision: Optional[str] = None, + tokenizer_revision: Optional[str] = None, + seed: int = 0, + gpu_memory_utilization: float = 0.9, + swap_space: int = 4, + cpu_offload_gb: float = 0, + enforce_eager: bool = False, + max_context_len_to_capture: Optional[int] = None, + max_seq_len_to_capture: int = 8192, + disable_custom_all_reduce: bool = False, + load_format="auto", + **kwargs, + ) -> None: + if "disable_log_stats" not in kwargs: + kwargs["disable_log_stats"] = True + engine_args = EngineArgs( + model_hf_config=model_hf_config, + tensor_parallel_size=tensor_parallel_size, + dtype=dtype, + quantization=quantization, + revision=revision, + tokenizer_revision=tokenizer_revision, + seed=seed, + gpu_memory_utilization=gpu_memory_utilization, + swap_space=swap_space, + cpu_offload_gb=cpu_offload_gb, + enforce_eager=enforce_eager, + max_context_len_to_capture=max_context_len_to_capture, + max_seq_len_to_capture=max_seq_len_to_capture, + disable_custom_all_reduce=disable_custom_all_reduce, + load_format=load_format, + skip_tokenizer_init=skip_tokenizer_init, + **kwargs, + ) + tokenizer_cls = (PreTrainedTokenizer, PreTrainedTokenizerFast, HybridEngineBaseTokenizer) + if not isinstance(tokenizer, tokenizer_cls): + raise ValueError(f"Unexpected tokenizer type: {type(tokenizer)}. Must beone of the following: PreTrainedTokenizer, PreTrainedTokenizerFast, verl.workers.rollout.HybridEngineBaseTokenizer") + self.llm_engine = LLMEngine.from_engine_args(model, tokenizer, engine_args) # TODO: check usagecontext + self.request_counter = Counter() + + def init_cache_engine(self): + self.llm_engine.init_cache_engine() + + def free_cache_engine(self): + self.llm_engine.free_cache_engine() + + def get_tokenizer(self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + return self.llm_engine.tokenizer + + def set_tokenizer( + self, + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + ) -> None: + self.llm_engine.tokenizer = tokenizer + + def _run_engine(self, *, use_tqdm: bool) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: + # Initialize tqdm. + if use_tqdm: + num_requests = self.llm_engine.get_num_unfinished_requests() + pbar = tqdm( + total=num_requests, + desc="Processed prompts", + dynamic_ncols=True, + postfix=(f"est. speed input: {0:.2f} toks/s, output: {0:.2f} toks/s"), + ) + # Run the engine. + outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = [] + total_in_toks = 0 + total_out_toks = 0 + while self.llm_engine.has_unfinished_requests(): + step_outputs = self.llm_engine.step() + for output in step_outputs: + if output.finished: + outputs.append(output) + if use_tqdm: + if isinstance(output, RequestOutput): + # Calculate tokens only for RequestOutput + total_in_toks += len(output.prompt_token_ids) + in_spd = total_in_toks / pbar.format_dict["elapsed"] + total_out_toks += sum(len(stp.token_ids) for stp in output.outputs) + out_spd = total_out_toks / pbar.format_dict["elapsed"] + pbar.postfix = f"est. speed input: {in_spd:.2f} toks/s, output: {out_spd:.2f} toks/s" + pbar.update(1) + if use_tqdm: + pbar.close() + # Sort the outputs by request ID. + # This is necessary because some requests may be finished earlier than + # its previous requests. + outputs = sorted(outputs, key=lambda x: int(x.request_id)) + return self._post_process_outputs(outputs) + + # # NOTE(shengguangming): add for verl + # # TODO(sgm): we can optimize it by making the dataloader yield List[int] without padding. + # def _pre_process_inputs(self, prompt_token_ids: torch.Tensor) -> List[int]: + # # remove the left padding in the prompt token_id + # pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id + # non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0] + # token_ids = prompt_token_ids[non_pad_index:].tolist() + # return token_ids + + # NOTE(shengguangming): add for verl + def _post_process_outputs(self, request_outputs: List[RequestOutput]) -> Tuple[torch.Tensor, torch.Tensor]: + output_token_ids = [] + logprobs = [] + for request_output in request_outputs: # List[RequestOutput] + outputs = request_output.outputs + for output in outputs: # List[CompletionOutput], usually len == 1 + output_token_ids.append(torch.tensor(output.token_ids)) + # TODO(shengguangming): can be optimzied by rewrite the Sampler._get_logprobs() logits + logprobs_dicts = output.logprobs + if logprobs_dicts is not None: + logprob = [] + for logprobs_dict, id in zip(logprobs_dicts, output.token_ids): + logprob.append(logprobs_dict[id].logprob) + logprobs.append(torch.tensor(logprob)) + + pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id + output_token_ids = pad_sequence(output_token_ids, batch_first=True, padding_value=pad_token_id) + if len(logprobs) > 0: + logprobs = pad_sequence(logprobs, batch_first=True, padding_value=pad_token_id) + return output_token_ids, logprobs + + def sync_model_weights(self, actor_weights: Iterable, load_format: str) -> None: + self.llm_engine.sync_model_weights(actor_weights=actor_weights, load_format=load_format) + + def offload_model_weights(self) -> None: + self.llm_engine.offload_model_weights() diff --git a/verl/third_party/vllm/vllm_v_0_5_4/llm_engine_sp.py b/verl/third_party/vllm/vllm_v_0_5_4/llm_engine_sp.py new file mode 100644 index 0000000000000000000000000000000000000000..beede755d6152c808fc6b75d7e7d1f024e8dea7d --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_5_4/llm_engine_sp.py @@ -0,0 +1,331 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/engine/llm_engine.py + +from typing import Dict, Iterable, Optional, Type, Union + +from torch import nn +from vllm.config import ( + CacheConfig, + DecodingConfig, + DeviceConfig, + EngineConfig, + LoRAConfig, + MultiModalConfig, + ObservabilityConfig, + ParallelConfig, + PromptAdapterConfig, + SchedulerConfig, + SpeculativeConfig, +) +from vllm.core.scheduler import Scheduler +from vllm.engine.llm_engine import LLMEngine, _load_generation_config_dict +from vllm.engine.metrics import LoggingStatLogger, PrometheusStatLogger, StatLoggerBase +from vllm.engine.output_processor.interfaces import SequenceGroupOutputProcessor +from vllm.engine.output_processor.stop_checker import StopChecker +from vllm.executor.executor_base import ExecutorBase +from vllm.inputs import INPUT_REGISTRY +from vllm.logger import init_logger +from vllm.tracing import init_tracer +from vllm.transformers_utils.detokenizer import Detokenizer +from vllm.usage.usage_lib import UsageContext, is_usage_stats_enabled, usage_message +from vllm.utils import Counter +from vllm.version import __version__ as VLLM_VERSION + +from .arg_utils import EngineArgs +from .config import LoadConfig, ModelConfig +from .tokenizer import TokenizerGroup + +logger = init_logger(__name__) +_LOCAL_LOGGING_INTERVAL_SEC = 5 + + +class LLMEngine(LLMEngine): + """An LLM engine that receives requests and generates texts. + + This is the main class for the vLLM engine. It receives requests + from clients and generates texts from the LLM. It includes a tokenizer, a + language model (possibly distributed across multiple GPUs), and GPU memory + space allocated for intermediate states (aka KV cache). This class utilizes + iteration-level scheduling and efficient memory management to maximize the + serving throughput. + + The `LLM` class wraps this class for offline batched inference and the + `AsyncLLMEngine` class wraps this class for online serving. + + NOTE: The config arguments are derived from the `EngineArgs` class. For the + comprehensive list of arguments, see `EngineArgs`. + + Args: + model: the actor model initialize outside vllm (add for verl) + tokenizer: the initialized tokenizer (add for verl) + model_config: The configuration related to the LLM model. + cache_config: The configuration related to the KV cache memory + management. + parallel_config: The configuration related to distributed execution. + scheduler_config: The configuration related to the request scheduler. + distributed_init_method: The initialization method for distributed + execution. See `torch.distributed.init_process_group` for details. + placement_group: Ray placement group for distributed execution. + Required for distributed execution. + log_stats: Whether to log statistics. + """ + + def __init__( + self, + # NOTE(sgm): first two arguments are added for verl + model: Union[nn.Module, Dict], # model itself or its parameter dict + tokenizer: nn.Module, + # NOTE(sgm): vllm original arguments + model_config: ModelConfig, + cache_config: CacheConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + multimodal_config: Optional[MultiModalConfig], + speculative_config: Optional[SpeculativeConfig], + decoding_config: Optional[DecodingConfig], + observability_config: Optional[ObservabilityConfig], + prompt_adapter_config: Optional[PromptAdapterConfig], + executor_class: Type[ExecutorBase], + log_stats: bool, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, + ) -> None: + logger.info( + "Initializing an LLM engine (v%s) with config: " + "model=%r, speculative_config=%r, tokenizer=%r, " + "skip_tokenizer_init=%s, revision=%s, " + "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, " + "trust_remote_code=%s, dtype=%s, max_seq_len=%d, " + "download_dir=%r, load_format=%s, tensor_parallel_size=%d, " + "pipeline_parallel_size=%d, " + "disable_custom_all_reduce=%s, quantization=%s, " + "enforce_eager=%s, kv_cache_dtype=%s, " + "quantization_param_path=%s, device_config=%s, " + "decoding_config=%r, observability_config=%r, " + "seed=%d, served_model_name=%s, use_v2_block_manager=%s, " + "enable_prefix_caching=%s)", + VLLM_VERSION, + model_config.model, + speculative_config, + model_config.tokenizer, + model_config.skip_tokenizer_init, + model_config.revision, + model_config.rope_scaling, + model_config.rope_theta, + model_config.tokenizer_revision, + model_config.trust_remote_code, + model_config.dtype, + model_config.max_model_len, + load_config.download_dir, + load_config.load_format, + parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size, + parallel_config.disable_custom_all_reduce, + model_config.quantization, + model_config.enforce_eager, + cache_config.cache_dtype, + model_config.quantization_param_path, + device_config.device, + decoding_config, + observability_config, + model_config.seed, + model_config.served_model_name, + scheduler_config.use_v2_block_manager, + cache_config.enable_prefix_caching, + ) + # TODO(woosuk): Print more configs in debug mode. + + self.model_config = model_config + self.cache_config = cache_config + self.lora_config = lora_config + self.multimodal_config = multimodal_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.speculative_config = speculative_config + self.load_config = load_config + self.decoding_config = decoding_config or DecodingConfig() + self.prompt_adapter_config = prompt_adapter_config + self.observability_config = observability_config or ObservabilityConfig() + self.log_stats = log_stats + + # self.model = model # should not store the model, it should be deleted + # TODO(shengguangming): maybe we can choose init here or from arguments + if not self.model_config.skip_tokenizer_init: + self.tokenizer = self._init_tokenizer(tokenizer) + self.detokenizer = Detokenizer(self.tokenizer) + else: + self.tokenizer = None + self.detokenizer = None + + self.seq_counter = Counter() + self.generation_config_fields = _load_generation_config_dict(model_config) + + self.input_processor = INPUT_REGISTRY.create_input_processor(self.model_config) + + self.model_executor = executor_class( + model=model, # add for spmd_gpu_executor + model_config=model_config, + cache_config=cache_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + device_config=device_config, + lora_config=lora_config, + multimodal_config=multimodal_config, + speculative_config=speculative_config, + load_config=load_config, + prompt_adapter_config=prompt_adapter_config, + ) + + # Profile the memory usage and initialize the cache. + if not self.model_config.embedding_mode: + self._initialize_kv_caches() + + # If usage stat is enabled, collect relevant info. + if is_usage_stats_enabled(): + from vllm.model_executor.model_loader import get_architecture_class_name + + usage_message.report_usage( + get_architecture_class_name(model_config), + usage_context, + extra_kvs={ + # Common configuration + "dtype": str(model_config.dtype), + "tensor_parallel_size": parallel_config.tensor_parallel_size, + "block_size": cache_config.block_size, + "gpu_memory_utilization": cache_config.gpu_memory_utilization, + # Quantization + "quantization": model_config.quantization, + "kv_cache_dtype": str(cache_config.cache_dtype), + # Feature flags + "enable_lora": bool(lora_config), + "enable_prompt_adapter": bool(prompt_adapter_config), + "enable_prefix_caching": cache_config.enable_prefix_caching, + "enforce_eager": model_config.enforce_eager, + "disable_custom_all_reduce": parallel_config.disable_custom_all_reduce, + }, + ) + + if self.tokenizer: + # Ping the tokenizer to ensure liveness if it runs in a + # different process. + self.tokenizer.ping() + + # Create the scheduler. + # NOTE: the cache_config here have been updated with the numbers of + # GPU and CPU blocks, which are profiled in the distributed executor. + self.scheduler = [Scheduler(scheduler_config, cache_config, lora_config, parallel_config.pipeline_parallel_size) for _ in range(parallel_config.pipeline_parallel_size)] + + # Metric Logging. + if self.log_stats: + if stat_loggers is not None: + self.stat_loggers = stat_loggers + else: + self.stat_loggers = { + "logging": LoggingStatLogger(local_interval=_LOCAL_LOGGING_INTERVAL_SEC), + "prometheus": PrometheusStatLogger( + local_interval=_LOCAL_LOGGING_INTERVAL_SEC, + labels=dict(model_name=model_config.served_model_name), + max_model_len=self.model_config.max_model_len, + ), + } + self.stat_loggers["prometheus"].info("cache_config", self.cache_config) + + self.tracer = None + if self.observability_config.otlp_traces_endpoint: + self.tracer = init_tracer("vllm.llm_engine", self.observability_config.otlp_traces_endpoint) + + # Create sequence output processor, e.g. for beam search or + # speculative decoding. + self.output_processor = SequenceGroupOutputProcessor.create_output_processor( + self.scheduler_config, + self.detokenizer, + self.scheduler, + self.seq_counter, + self.get_tokenizer_for_seq, + stop_checker=StopChecker( + self.scheduler_config.max_model_len, + self.get_tokenizer_for_seq, + ), + ) + + # TODO(sgm): add for verl but we may not tokenizer in Rollout + def _init_tokenizer(self, tokenizer, **tokenizer_init_kwargs): + init_kwargs = dict(enable_lora=bool(self.lora_config), max_num_seqs=self.scheduler_config.max_num_seqs, max_input_length=None) + init_kwargs.update(tokenizer_init_kwargs) + return TokenizerGroup(tokenizer, **init_kwargs) + + def init_cache_engine(self): + # TODO: check whether we should rebuild the CUDAGraph every iter when offload/load KVCache + # Re-capture CUDAGraph would be time-consuming + self.model_executor.init_cache_engine() + + def free_cache_engine(self): + self.model_executor.free_cache_engine() + + # NOTE(sgm): currently, we only support GPU executor + # The GPUExecutor remove the Ray dependency + @classmethod + def _get_executor_cls(cls, engine_config: EngineConfig) -> Type[ExecutorBase]: + assert engine_config.device_config.device_type == "cuda", "Currently, the vllm in verl only support running on GPU" + + if engine_config.parallel_config.world_size == 1: + engine_config.load_config.load_format = "dummy_hf" + + from .spmd_gpu_executor import SPMDGPUExecutor + + executor_class = SPMDGPUExecutor + return executor_class + + @classmethod + def from_engine_args( + cls, + model, + tokenizer, + engine_args: EngineArgs, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, + ) -> "LLMEngine": + """Creates an LLM engine from the engine arguments.""" + # Create the engine configs. + engine_config = engine_args.create_engine_config() + executor_class = cls._get_executor_cls(engine_config) + # Initialize the cluster and specify the executor class. + assert engine_config.device_config.device_type == "cuda", "Currently, the vllm in verl only support running on GPU" + + from .spmd_gpu_executor import SPMDGPUExecutor + + executor_class = SPMDGPUExecutor + + # Create the LLM engine. + engine = cls( + model, + tokenizer, + **engine_config.to_dict(), + executor_class=executor_class, + log_stats=not engine_args.disable_log_stats, + usage_context=usage_context, + stat_loggers=stat_loggers, + ) + return engine + + def sync_model_weights(self, actor_weights: Iterable, load_format: str) -> None: + self.model_executor.sync_model_weights(actor_weights=actor_weights, load_format=load_format) + + def offload_model_weights(self) -> None: + self.model_executor.offload_model_weights() diff --git a/verl/third_party/vllm/vllm_v_0_5_4/megatron_weight_loaders.py b/verl/third_party/vllm/vllm_v_0_5_4/megatron_weight_loaders.py new file mode 100644 index 0000000000000000000000000000000000000000..c50b4ee3aa35da77287819fabeb27257e6fbfa04 --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_5_4/megatron_weight_loaders.py @@ -0,0 +1,219 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models + +from typing import Dict, Iterable + +import torch +import torch.nn as nn +from vllm.model_executor.layers.linear import * +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead, VocabParallelEmbedding +from vllm.model_executor.models import ModelRegistry + + +# NOTE(shengguangming): replace the origin weight loader function in the class +def parallel_weight_loader(self, param: torch.Tensor, loaded_weight: torch.Tensor) -> None: + """Parallel Linear weight loader.""" + assert param.size() == loaded_weight.size(), "the parameter size is not align with the loaded weight size, param size: {}, loaded_weight size: {}".format(param.size(), loaded_weight.size()) + assert param.data.dtype == loaded_weight.data.dtype, "if we want to shared weights, the data type should also be the same" + + param.data = loaded_weight.data + + +def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: + """Default weight loader.""" + assert param.size() == loaded_weight.size() + assert param.data.dtype == loaded_weight.data.dtype, "if we want to shared weights, the data type should also be the same" + + param.data = loaded_weight.data + + +def gpt2_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + params_dict = dict(vllm_model.named_parameters(remove_duplicate=False)) + for name, loaded_weight in actor_weights.items(): + if "lm_head.weight" in name: + # GPT-2 ties the weights of the embedding layer and the final + # linear layer. + continue + if ".attn.bias" in name or ".attn.masked_bias" in name: + # Skip attention mask. + # NOTE: "c_attn.bias" should not be skipped. + continue + if not name.startswith("transformer."): + name = "transformer." + name + param = params_dict[name] + # The HF's GPT-2 implementation uses Conv1D instead of Linear. + # Because of this, we need to transpose the weights. + # Note(zhuohan): the logic below might break quantized models. + for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]: + if conv1d_weight_name not in name: + continue + if not name.endswith(".weight"): + continue + # TODO: check megatron + loaded_weight = loaded_weight.t() + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +def llama_megatron_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + # NOTE(shengguangming): the megatron llama may have this prefix + params_dict = dict(vllm_model.named_parameters()) + for name, loaded_weight in actor_weights.items(): + if "rotary_emb.inv_freq" in name: + continue + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +def _replace_name(megatron_name, name_mapping): + for m_name, v_name in name_mapping: + if m_name not in megatron_name: + continue + if "layers" in megatron_name: # deal with decoder layers + megatron_name = megatron_name.replace("decoder", "model") + megatron_name_list = megatron_name.split(".") + if "layer_norm_weight" in megatron_name_list or "layer_norm_bias" in megatron_name_list: + param_name_list = megatron_name_list[:3] + param_name_list.append(v_name) + param_name = ".".join(param_name_list) + else: + param_name_list = megatron_name_list[:3] + weight_or_bias = megatron_name_list[-1] + param_name_list.append(v_name) + param_name_list.append(weight_or_bias) + param_name = ".".join(param_name_list) + return param_name + else: + param_name = megatron_name.replace(m_name, v_name) + return param_name + + +def llama_megatron_core_te_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + params_mapping = [ + # (megatron core gpt model name, vllm model name) + ("embedding.word_embeddings", "model.embed_tokens"), + ("self_attention.linear_qkv.layer_norm_weight", "input_layernorm.weight"), + ("self_attention.linear_qkv.layer_norm_bias", "input_layernorm.bias"), + ("self_attention.linear_qkv", "self_attn.qkv_proj"), + ("self_attention.linear_qkv", "self_attn.qkv_proj"), + ("self_attention.linear_proj", "self_attn.o_proj"), + ("pre_mlp_layernorm", "post_attention_layernorm"), + ("mlp.linear_fc1.layer_norm_weight", "post_attention_layernorm.weight"), + ("mlp.linear_fc1.layer_norm_bias", "post_attention_layernorm.bias"), + ("mlp.linear_fc1", "mlp.gate_up_proj"), + ("mlp.linear_fc2", "mlp.down_proj"), + ("decoder.final_layernorm", "model.norm"), + ("output_layer", "lm_head"), + ] + # NOTE(shengguangming): the megatron llama may have this prefix + params_dict = dict(vllm_model.named_parameters()) + for name, loaded_weight in actor_weights.items(): + name = _replace_name(name, params_mapping) + if name.endswith(".bias") and name not in params_dict: + continue + if "rotary_emb.inv_freq" in name: + continue + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +def llama_megatron_core_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + params_mapping = [ + # (megatron core gpt model name, vllm model name) + ("embedding.word_embeddings", "model.embed_tokens"), + ("self_attention.linear_qkv", "self_attn.qkv_proj"), + ("self_attention.linear_proj", "self_attn.o_proj"), + ( + "input_layernorm", + "input_layernorm", + ), + ("pre_mlp_layernorm", "post_attention_layernorm"), + ("mlp.linear_fc1", "mlp.gate_up_proj"), + ("mlp.linear_fc2", "mlp.down_proj"), + ("decoder.final_layernorm", "model.norm"), + ("output_layer", "lm_head"), + ] + # NOTE(shengguangming): the megatron llama may have this prefix + params_dict = dict(vllm_model.named_parameters()) + for name, loaded_weight in actor_weights.items(): + name = _replace_name(name, params_mapping) + if name.endswith(".bias") and name not in params_dict: + continue + if "rotary_emb.inv_freq" in name: + continue + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +def mistral_megatron_weight_loader(actor_weights: Iterable, vllm_model: nn.Module) -> nn.Module: + # TODO: need to implement a general way to deal with prefix + params_dict = dict(vllm_model.named_parameters()) + for name, loaded_weight in actor_weights.items(): + if "rotary_emb.inv_freq" in name: + continue + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +__LAYER_WEIGHT_MEGATRON_LOADER_REGISTRY__ = { + ColumnParallelLinear: parallel_weight_loader, + MergedColumnParallelLinear: parallel_weight_loader, + QKVParallelLinear: parallel_weight_loader, + RowParallelLinear: parallel_weight_loader, + VocabParallelEmbedding: parallel_weight_loader, + ParallelLMHead: parallel_weight_loader, + # "ScaledActivation.weight_loader": ScaledActivation, # TODO(shengguangming): latest commit in vllm fix awq for this function and add load_weights + # "default_weight_loader": default_weight_loader +} + +# for layer_class, weight_loader in __LAYER_WEIGHT_MEGATRON_LOADER_REGISTRY__.items(): +# # setattr(layer_class, 'megatron_weight_loader', weight_loader) +# layer_class.weight_loader = weight_loader + +__MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__ = { + "GPT2LMHeadModel": gpt2_weight_loader, + "LlamaForCausalLM": llama_megatron_weight_loader, # use te backend for open-source megatron + "LLaMAForCausalLM": llama_megatron_weight_loader, + "MistralForCausalLM": mistral_megatron_weight_loader, +} + + +# the actor model is .state_dict() +# Load megatron weights +def load_megatron_weights(actor_weights: Iterable, vllm_model: nn.Module): + weight_loader = _get_model_weight_loader(vllm_model.__class__.__name__) + weight_loader(actor_weights, vllm_model) + # NOTE(sgm) to reduce peak memory usage, we offload vllm model to cpu + # after init, and we need this after sync model weights for in first iter. + vllm_model = vllm_model.cuda() + + +def _get_model_weight_loader(arch: str): + if arch in __MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__: + return __MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__[arch] + raise ValueError(f"Model architectures {arch} are not supported for now. Supported architectures: {ModelRegistry.get_supported_archs()}") + + +def update_megatron_weight_loader(): + for layer_class, weight_loader in __LAYER_WEIGHT_MEGATRON_LOADER_REGISTRY__.items(): + layer_class.weight_loader = weight_loader diff --git a/verl/third_party/vllm/vllm_v_0_5_4/model_loader.py b/verl/third_party/vllm/vllm_v_0_5_4/model_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..30da15a265de22dbd625298c3dbad7f9eb9b64bd --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_5_4/model_loader.py @@ -0,0 +1,329 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/model_loader + +from typing import Dict, Optional, Union + +import torch +import torch.nn as nn +from transformers import PreTrainedModel +from vllm.config import ( + CacheConfig, + DeviceConfig, + LoRAConfig, + MultiModalConfig, + ParallelConfig, + SchedulerConfig, +) +from vllm.distributed.communication_op import tensor_model_parallel_all_gather +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.model_loader import BaseModelLoader +from vllm.model_executor.model_loader.loader import _initialize_model +from vllm.model_executor.model_loader.utils import set_default_torch_dtype + +from .config import LoadConfig, LoadFormat, ModelConfig +from .dtensor_weight_loaders import load_dtensor_weights, update_dtensor_weight_loader +from .hf_weight_loader import update_hf_weight_loader +from .megatron_weight_loaders import load_megatron_weights, update_megatron_weight_loader + + +def get_model( + actor_model: Union[PreTrainedModel, Dict], + model_config: ModelConfig, + load_config: LoadConfig, + device_config: DeviceConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + lora_config: Optional[LoRAConfig], + multimodal_config: Optional[MultiModalConfig], + cache_config: CacheConfig = None, +) -> nn.Module: + loader = get_model_loader(load_config) + if load_config.load_format.startswith("dummy"): + return loader.load_model( + model_config=model_config, + device_config=device_config, + lora_config=lora_config, + multimodal_config=multimodal_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + cache_config=cache_config, + ) + else: + return loader.load_model( + actor_model=actor_model, + model_config=model_config, + device_config=device_config, + lora_config=lora_config, + multimodal_config=multimodal_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + cache_config=cache_config, + ) + + +def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: + """Get a model loader based on the load format.""" + + if isinstance(load_config.load_format, type): + return load_config.load_format(load_config) + + if load_config.load_format == LoadFormat.AUTO: + update_megatron_weight_loader() + return MegatronLoader(load_config) + + # NOTE(sgm): change the weight_loader function in runtime + if load_config.load_format == LoadFormat.MEGATRON: + update_megatron_weight_loader() + return MegatronLoader(load_config) + + if load_config.load_format == LoadFormat.HF: + update_hf_weight_loader() + return HFLoader(load_config) + + if load_config.load_format == LoadFormat.DTENSOR: + update_dtensor_weight_loader() + return DTensorLoader(load_config) + + if load_config.load_format == LoadFormat.DUMMY_HF: + update_hf_weight_loader() + return DummyModelLoader(load_config) + + if load_config.load_format == LoadFormat.DUMMY_MEGATRON: + update_megatron_weight_loader() + return DummyModelLoader(load_config) + + if load_config.load_format == LoadFormat.DUMMY_DTENSOR: + update_dtensor_weight_loader() + return DummyModelLoader(load_config) + + raise ValueError("load format not supported in verl: {}, only support {} and {}".format(load_config.load_format, LoadFormat.MEGATRON, LoadFormat.HF)) + + +class DummyModelLoader(BaseModelLoader): + """Model loader that will set model weights to random values.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError(f"Model loader extra config is not supported for load format {load_config.load_format}") + + def load_model( + self, + *, + model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + multimodal_config: Optional[MultiModalConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + ) -> nn.Module: + with set_default_torch_dtype(model_config.dtype), torch.device(device_config.device): + model = _initialize_model(model_config, self.load_config, lora_config, multimodal_config, cache_config, scheduler_config) + # NOTE(woosuk): For accurate performance evaluation, we assign + # random values to the weights. + # initialize_dummy_weights(model) + return model.eval() + + +class MegatronLoader(BaseModelLoader): + """Model loader that can load the model weights from partitioned megatron model.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError(f"Model loader extra config is not supported for load format {load_config.load_format}") + + def _get_weights_iterator(actor_model: Union[PreTrainedModel, Dict]): + # NOTE(shengguangming) Load the weights from the actor model + pass + # if isinstance(actor_model, nn.Module): + # load_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), vllm_model=model) + # else: + # load_weights(actor_weights=actor_model, vllm_model=model) + # return actor_model + + def load_model( + self, + actor_model: Union[PreTrainedModel, Dict], + model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + multimodal_config: Optional[MultiModalConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + ) -> nn.Module: + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model(model_config, self.load_config, lora_config, multimodal_config, cache_config, scheduler_config) + + # TODO(sgm): This is a hack, we need to register the load_weight() func for each model in vllm + if isinstance(actor_model, nn.Module): + load_megatron_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), vllm_model=model) + else: + load_megatron_weights(actor_weights=actor_model, vllm_model=model) + + for _, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + quant_method.process_weights_after_loading(module) + # FIXME: Remove this after Mixtral is updated + # to use quant_method. + if hasattr(module, "process_weights_after_loading"): + module.process_weights_after_loading() + # NOTE(sgm) Some weights are point to gpu, but still need this. + model = model.cuda() # NOTE (zhangchi.usc1992) We need this for vllm to profile memory usage + return model.eval() + + +class HFLoader(BaseModelLoader): + """Model loader that can load the model weights from model's full params.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError(f"Model loader extra config is not supported for load format {load_config.load_format}") + + def _get_weights_iterator(self, actor_model: Union[PreTrainedModel, Dict]): + if isinstance(actor_model, Dict): + return actor_model.items() + elif isinstance(actor_model, nn.Module): + return dict(actor_model.named_parameters()).items() + else: + raise ValueError(f"actor model should be Dict or nn.Module, but get {type(actor_model)}") + + def load_model( + self, + actor_model: Union[PreTrainedModel, Dict], + model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + multimodal_config: Optional[MultiModalConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + ) -> nn.Module: + with set_default_torch_dtype(model_config.dtype): + # with torch.device(device_config.device): + # NOTE(sgm): init the model in cpu + model = _initialize_model(model_config, self.load_config, lora_config, multimodal_config, cache_config, scheduler_config) + model.load_weights(self._get_weights_iterator(actor_model)) + for _, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + quant_method.process_weights_after_loading(module) + # FIXME: Remove this after Mixtral is updated + # to use quant_method. + if hasattr(module, "process_weights_after_loading"): + module.process_weights_after_loading() + # NOTE(sgm) Some weights are point to gpu, but still need this. + model = model.cuda() # NOTE (zhangchi.usc1992) We need this for vllm to profile memory usage + return model.eval() + + +class DTensorLoader(BaseModelLoader): + """Model loader that can load the model weights from partitioned megatron model.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError(f"Model loader extra config is not supported for load format {load_config.load_format}") + + def _get_weights_iterator(actor_model: Union[PreTrainedModel, Dict]): + # NOTE(shengguangming) Load the weights from the actor model + pass + # if isinstance(actor_model, nn.Module): + # load_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), vllm_model=model) + # else: + # load_weights(actor_weights=actor_model, vllm_model=model) + # return actor_model + + def load_model( + self, + actor_model: Union[PreTrainedModel, Dict], + model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + multimodal_config: Optional[MultiModalConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + ) -> nn.Module: + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model(model_config, self.load_config, lora_config, multimodal_config, cache_config, scheduler_config) + + # TODO(sgm): This is a hack, we need to register the load_weight() func for each model in vllm + if isinstance(actor_model, nn.Module): + load_dtensor_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), vllm_model=model) + else: + load_dtensor_weights(actor_weights=actor_model, vllm_model=model) + + for _, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + quant_method.process_weights_after_loading(module) + # FIXME: Remove this after Mixtral is updated + # to use quant_method. + if hasattr(module, "process_weights_after_loading"): + module.process_weights_after_loading() + # NOTE(sgm) Some weights are point to gpu, but still need this. + model = model.cuda() # NOTE (zhangchi.usc1992) We need this for vllm to profile memory usage + return model.eval() + + +# FIXME(sgm): hack the _get_logits function in vllm v0.4.2 +# as they use ray, the _get_logits result will only need to return to the driver node, +# therefore gather is enough. However, we use SPMD instead of a central scheduler, +# all_gather is required (aligned with v0.2.6) +def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor, embedding_bias: Optional[torch.Tensor]) -> torch.Tensor: + # Get the logits for the next tokens. + logits = torch.matmul(hidden_states, embedding.t()) + if embedding_bias is not None: + logits += embedding_bias + logits = tensor_model_parallel_all_gather(logits) + # Remove paddings in vocab (if any). + if logits is not None: + logits = logits[:, : self.org_vocab_size] + return logits + + +def logitsprocessor_init( + self, + vocab_size: int, + org_vocab_size: Optional[int] = None, + scale: float = 1.0, + logits_as_input: bool = False, + soft_cap: Optional[float] = None, +) -> None: + """ + Args: + scale: A scaling factor to apply to the logits. + """ + super(LogitsProcessor, self).__init__() + self.scale = scale + self.vocab_size = vocab_size + # Whether the input is logits (default is hidden states). + self.logits_as_input = logits_as_input + # original vocabulary size (without LoRA). + self.org_vocab_size = org_vocab_size or vocab_size + # Soft cap the logits. Used in Gemma 2. + self.soft_cap = soft_cap + # Whether to use gather or all-gather to gather the logits. + self.use_gather = False + + +LogitsProcessor.__init__ = logitsprocessor_init # use all_gather diff --git a/verl/third_party/vllm/vllm_v_0_5_4/model_runner.py b/verl/third_party/vllm/vllm_v_0_5_4/model_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..71645634aa32e586502bca30212bf39e3a2ed9e3 --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_5_4/model_runner.py @@ -0,0 +1,155 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/worker/model_runner.py + +import warnings +from enum import IntEnum +from typing import Dict, Optional, Union + +import torch +import torch.nn as nn +import vllm.envs as envs +from vllm.config import ( + CacheConfig, + DeviceConfig, + LoRAConfig, + MultiModalConfig, + ParallelConfig, + PromptAdapterConfig, + SchedulerConfig, +) +from vllm.logger import init_logger +from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager +from vllm.model_executor.models.interfaces import supports_lora, supports_vision +from vllm.prompt_adapter.worker_manager import LRUCacheWorkerPromptAdapterManager +from vllm.utils import CudaMemoryProfiler, is_hip +from vllm.worker.model_runner import ModelRunner + +from .config import LoadConfig, ModelConfig +from .model_loader import get_model + +logger = init_logger(__name__) + + +# How batches are constructed. +class BatchType(IntEnum): + # Every batch is prefill. + PREFILL = 0 + # Every batch is decode. + DECODE = 1 + # Batch is a mixture of prefill and decode. + MIXED = 2 + + +class ModelRunner(ModelRunner): + def __init__( + self, + model: Union[nn.Module, Dict], # [verl] model itself or its parameter dict + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + prompt_adapter_config: Optional[PromptAdapterConfig] = None, + multimodal_config: Optional[MultiModalConfig] = None, + return_hidden_states: bool = False, + ): + super().__init__( + model_config, + parallel_config, + scheduler_config, + device_config, + cache_config, + load_config, + lora_config, + kv_cache_dtype, + is_driver_worker=True, # a hack + prompt_adapter_config=prompt_adapter_config, + multimodal_config=multimodal_config, + return_hidden_states=return_hidden_states, + ) + + # NOTE(sgm): add for verl + self.model = model # this will be replaced by get_model() + + # NOTE(sgm): initialize model using the actor model + def load_model(self) -> None: + logger.info("Starting to load model %s...", self.model_config.model) + with CudaMemoryProfiler() as m: + self.model = get_model( + actor_model=self.model, + model_config=self.model_config, + device_config=self.device_config, + lora_config=self.lora_config, + load_config=self.load_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + multimodal_config=self.multimodal_config, + cache_config=self.cache_config, + ) + self.model_memory_usage = m.consumed_memory + logger.info("Loading model weights took %.4f GB", self.model_memory_usage / float(2**30)) + + if self.lora_config: + assert supports_lora(self.model), "Model does not support LoRA" + assert not supports_vision(self.model), "To be tested: vision language model with LoRA settings." + + self.lora_manager = LRUCacheWorkerLoRAManager( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens, + self.vocab_size, + self.lora_config, + self.device, + self.model.embedding_modules, + self.model.embedding_padding_modules, + max_position_embeddings=self.model.config.max_position_embeddings, + ) + self.model = self.lora_manager.create_lora_manager(self.model) + + if self.prompt_adapter_config: + self.prompt_adapter_manager = LRUCacheWorkerPromptAdapterManager( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens, + self.device, + self.prompt_adapter_config, + ) + self.model = self.prompt_adapter_manager.create_prompt_adapter_manager(self.model) + + if self.kv_cache_dtype == "fp8" and is_hip(): + # Currently only ROCm accepts kv-cache scaling factors + # via quantization_param_path and this will be deprecated + # in the future. + if self.model_config.quantization_param_path is not None: + if callable(getattr(self.model, "load_kv_cache_scales", None)): + warnings.warn( + "Loading kv cache scaling factor from JSON is deprecated and will be removed. Please include kv cache scaling factors in the model checkpoint.", + FutureWarning, + stacklevel=2, + ) + self.model.load_kv_cache_scales(self.model_config.quantization_param_path) + logger.info("Loaded KV cache scaling factors from %s", self.model_config.quantization_param_path) + else: + raise RuntimeError( + "Using FP8 KV cache and scaling factors provided but model %s does not support loading scaling factors.", + self.model.__class__, + ) + else: + logger.warning("Using FP8 KV cache but no scaling factors provided. Defaulting to scaling factors of 1.0. This may lead to less accurate results!") + + if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE: + self.model = torch.compile(self.model, fullgraph=True, backend="eager") diff --git a/verl/third_party/vllm/vllm_v_0_5_4/parallel_state.py b/verl/third_party/vllm/vllm_v_0_5_4/parallel_state.py new file mode 100644 index 0000000000000000000000000000000000000000..6888c1e9b895abeba51319d8fca4e21a74bd2a08 --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_5_4/parallel_state.py @@ -0,0 +1,302 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Adapted from +# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +"""Model and data parallel groups.""" + +import os +from typing import Optional + +import torch +import torch.distributed +import vllm.distributed.parallel_state as ps +from vllm.distributed.parallel_state import ( + get_pp_group, + get_world_group, + init_distributed_environment, + init_model_parallel_group, +) +from vllm.logger import init_logger + +logger = init_logger(__name__) +""" +This version is strongly tied with Megatron to implement HybridEngine and weight sharing between vllm and Megatron. +- We assume the Megatron tp+dp+pp world is already established before calling this function. + +""" + +# Device mesh for using DTensor +_DEVICE_MESH = None + +# Tensor model parallel group that the current rank belongs to. +_TP = None +# Pipeline model parallel group that the current rank belongs to. +_PP = None + + +# This method is for initializing the ParallelGroup when using HybridEngine +def initialize_parallel_state( + distributed_init_method: str = "env://", + backend: str = "nccl", + tensor_model_parallel_size: int = 1, + num_tp_per_train_tp: int = 1, + pipeline_model_parallel_size: int = 1, +): + # torch.distributed.all_reduce does not free the input tensor until + # the synchronization point. This causes the memory usage to grow + # as the number of all_reduce calls increases. This env var disables + # this behavior. + # Related issue: + # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + + # NOTE(sgm): Modify for verl, Env vars will be set by TORCHRUN. + rank = int(os.getenv("RANK", "-1")) + local_rank = int(os.getenv("LOCAL_RANK", "0")) + + # Use the world_size set by TORCHRUN + world_size = int(os.getenv("WORLD_SIZE", "-1")) + assert world_size != -1, "The world_size is set to -1, not initialized by TORCHRUN" + init_distributed_environment(world_size, rank, distributed_init_method, local_rank, backend) + if torch.distributed.get_world_size() > 1: + # NOTE: build a sepearate inference group with infer tp & micro dp + initialize_model_parallel_for_vllm( + tensor_model_parallel_size=tensor_model_parallel_size, + num_tensor_model_parallel_groups_per_train_tp=num_tp_per_train_tp, + ) + else: + initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, backend) + + +def ensure_model_parallel_initialized( + tensor_model_parallel_size: int, + pipeline_model_parallel_size: int = 1, + backend: Optional[str] = None, +) -> None: + """Helper to initialize model parallel groups if they are not initialized, + or ensure tensor-parallel and pipeline-parallel sizes are equal to expected + values if the model parallel groups are initialized. + """ + # get the backend of _DEVICE_WORLD_GROUP + backend = backend or torch.distributed.get_backend(get_world_group().device_group) + if not model_parallel_is_initialized(): + initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, backend) + return + + assert get_tensor_model_parallel_world_size() == tensor_model_parallel_size, f"tensor parallel group already initialized, but of unexpected size: {get_tensor_model_parallel_world_size()=} vs. {tensor_model_parallel_size=}" + pp_world_size = get_pp_group().world_size + assert pp_world_size == pipeline_model_parallel_size, f"pipeline parallel group already initialized, but of unexpected size: {pp_world_size=} vs. {pipeline_model_parallel_size=}" + + +# TODO(sgm): deviate from the v0.5.4, not pp now +def model_parallel_is_initialized(): + """Check if tensor and pipeline parallel groups are initialized.""" + return ps._TP is not None + # and _PIPELINE_MODEL_PARALLEL_GROUP is not None) + + +def initialize_model_parallel_for_vllm( + tensor_model_parallel_size: int, + num_tensor_model_parallel_groups_per_train_tp: int = 1, + pipeline_model_parallel_size: int = 1, +) -> None: + # Get world size and rank. Ensure some consistencies. + assert torch.distributed.is_initialized() + + assert isinstance(tensor_model_parallel_size, int) + + # assert num_tensor_model_parallel_groups_per_train_tp == 1 and not different_tp_group + # assert num_tensor_model_parallel_groups_per_train_tp > 1 and different_tp_group + + # Build the tensor model-parallel groups. + assert ps._TP is None, "tensor model parallel group is already initialized" + + global _TP + + world_size: int = torch.distributed.get_world_size() + + backend = torch.distributed.get_backend() + + num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size + + if num_tensor_model_parallel_groups_per_train_tp == 1: + # if tensor_model_parallel_size == train_tensor_parallel_size: + # using the same tp group as Megatron/vllm + assert _TP is None, "tensor model parallel group is already initialized" + group_ranks = [] + for i in range(num_tensor_model_parallel_groups): + ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size) + group_ranks.append(ranks) + _TP = init_model_parallel_group( + group_ranks=group_ranks, + local_rank=get_world_group().local_rank, + backend=backend, + use_custom_allreduce=False, # TODO: check why True is not work in Ray trainer + use_message_queue_broadcaster=True, + ) + ps._TP = _TP + # _MICRO_DATA_PARALLEL_GROUP is move to hybrid engine + else: + # initialize a micro_dp group and a tp group + # assume training tp=4, infer tp=2, then, weight is partitioned as + # [1], [2], [3], [4] for training and [1,2], [1,2], [3,4], [3,4] for inference + + # Build the inference tp groups + # train_tp = train_tensor_parallel_size + train_tp = num_tensor_model_parallel_groups_per_train_tp * tensor_model_parallel_size + # num_tensor_model_parallel_groups_per_train_tp = train_tp // tensor_model_parallel_size + assert _TP is None, "tensor model parallel group is already initialized" + group_ranks = [] + for i in range(num_tensor_model_parallel_groups // num_tensor_model_parallel_groups_per_train_tp): + start = train_tp * i + end = train_tp * (i + 1) + for j in range(num_tensor_model_parallel_groups_per_train_tp): + ranks = list(range(start, end, num_tensor_model_parallel_groups_per_train_tp)) + for i in range(len(ranks)): + ranks[i] += j + group_ranks.append(ranks) + _TP = init_model_parallel_group( + group_ranks=group_ranks, + local_rank=get_world_group().local_rank, + backend=backend, + use_custom_allreduce=False, # TODO: check why True is not work in Ray trainer + use_message_queue_broadcaster=True, + ) + ps._TP = _TP + + # Build the pipeline model-parallel groups. + # global _PIPELINE_MODEL_PARALLEL_GROUP + # global _PIPELINE_GLOBAL_RANKS + # assert ps._PIPELINE_MODEL_PARALLEL_GROUP is None, ("pipeline model parallel group is already initialized") + + # ps._PIPELINE_MODEL_PARALLEL_GROUP = mpu.get_pipeline_model_parallel_group() + # ps._PIPELINE_GLOBAL_RANKS = mpu.get_pipeline_model_parallel_ranks() + + # TODO: init using device mesh (not support hybrid engine now) + # Build the pipeline model-parallel groups. + num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size + global _PP + assert _PP is None, "pipeline model parallel group is already initialized" + group_ranks = [] + for i in range(num_pipeline_model_parallel_groups): + ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) + group_ranks.append(ranks) + # pipeline parallel does not need custom allreduce + _PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, use_custom_allreduce=False) + ps._PP = _PP # for verl + + +def initialize_model_parallel( + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + backend: Optional[str] = None, +) -> None: + """ + NOTE: This method is a hack from the open-sourced version without + asertion of world_size = tp * pp + + Initialize model parallel groups. + + Arguments: + tensor_model_parallel_size: number of GPUs used for tensor model + parallelism. + pipeline_model_parallel_size: number of GPUs used for pipeline model + parallelism. + + Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we + use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize + the model pipeline. The present function will + create 4 tensor model-parallel groups and 2 pipeline model-parallel groups: + 4 tensor model-parallel groups: + [g0, g1], [g2, g3], [g4, g5], [g6, g7] + 2 pipeline model-parallel groups: + [g0, g2, g4, g6], [g1, g3, g5, g7] + Note that for efficiency, the caller should make sure adjacent ranks + are on the same DGX box. For example if we are using 2 DGX-1 boxes + with a total of 16 GPUs, rank 0 to 7 belong to the first box and + ranks 8 to 15 belong to the second box. + """ + # Get world size and rank. Ensure some consistencies. + assert torch.distributed.is_initialized() + world_size: int = torch.distributed.get_world_size() + backend = backend or torch.distributed.get_backend(ps.get_world_group().device_group) + + # NOTE(sgm) we don't assert world_size == tp * pp + # DP is not managed by vllm but by the verl WorkerGroup + # if (world_size != + # tensor_model_parallel_size * pipeline_model_parallel_size): + # raise RuntimeError( + # f"world_size ({world_size}) is not equal to " + # f"tensor_model_parallel_size ({tensor_model_parallel_size}) x " + # f"pipeline_model_parallel_size ({pipeline_model_parallel_size})") + + num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size + global _TP + assert _TP is None, "tensor model parallel group is already initialized" + group_ranks = [] + for i in range(num_tensor_model_parallel_groups): + ranks = list(range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)) + group_ranks.append(ranks) + + # message queue broadcaster is only used in tensor model parallel group + _TP = init_model_parallel_group( + group_ranks, + get_world_group().local_rank, + backend, + use_custom_allreduce=False, # TODO: check why True is not work in Ray trainer + use_message_queue_broadcaster=True, + ) + ps._TP = _TP + + # TODO: init using device mesh (not support hybrid engine now) + # Build the pipeline model-parallel groups. + num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size + global _PP + assert _PP is None, "pipeline model parallel group is already initialized" + group_ranks = [] + for i in range(num_pipeline_model_parallel_groups): + ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) + group_ranks.append(ranks) + # pipeline parallel does not need custom allreduce + _PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, use_custom_allreduce=False) + ps._PP = _PP # for verl + + +""" +Device mesh utilities +""" + + +def get_device_mesh(): + assert _DEVICE_MESH is not None, "device mesh is not initialized" + return _DEVICE_MESH + + +""" +Tensor model parallel utilities +""" + + +def get_tensor_model_parallel_group(): + """Get the tensor model parallel group the caller rank belongs to.""" + assert _TP is not None, "tensor model parallel group is not initialized" + return _TP.device_group + + +def get_tensor_model_parallel_world_size(): + """Return world size for the tensor model parallel group.""" + return torch.distributed.get_world_size(group=get_tensor_model_parallel_group()) + + +def get_tensor_model_parallel_rank(): + """Return my rank for the tensor model parallel group.""" + return torch.distributed.get_rank(group=get_tensor_model_parallel_group()) + + +def get_tensor_model_parallel_src_rank(): + """Calculate the global rank corresponding to the first local rank + in the tensor model parallel group.""" + global_rank = torch.distributed.get_rank() + local_world_size = get_tensor_model_parallel_world_size() + return (global_rank // local_world_size) * local_world_size diff --git a/verl/third_party/vllm/vllm_v_0_5_4/spmd_gpu_executor.py b/verl/third_party/vllm/vllm_v_0_5_4/spmd_gpu_executor.py new file mode 100644 index 0000000000000000000000000000000000000000..3647bbaa2576c7c6d4622e8d359408b97dcb3584 --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_5_4/spmd_gpu_executor.py @@ -0,0 +1,250 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/executor/gpu_executor.py + +import os +import socket +from typing import Iterable, List, Optional, Set, Tuple + +import torch +from vllm.config import ( + CacheConfig, + DeviceConfig, + LoRAConfig, + MultiModalConfig, + ParallelConfig, + PromptAdapterConfig, + SchedulerConfig, + SpeculativeConfig, +) +from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.sequence import ExecuteModelRequest, SamplerOutput + +from .config import LoadConfig, ModelConfig + +logger = init_logger(__name__) + + +class SPMDGPUExecutor(ExecutorBase): + """SPMD-based multi-GPU executor implementations.""" + + def __init__( + self, + model, # pytorch model itself or its parameter dict + model_config: ModelConfig, + cache_config: CacheConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + multimodal_config: Optional[MultiModalConfig], + speculative_config: Optional[SpeculativeConfig], + prompt_adapter_config: Optional[PromptAdapterConfig], + ) -> None: + self.model_config = model_config + self.cache_config = cache_config + self.lora_config = lora_config + self.load_config = load_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.multimodal_config = multimodal_config + self.speculative_config = speculative_config + self.prompt_adapter_config = prompt_adapter_config + + distributed_init_method = initialize_cluster(parallel_config) + self._init_executor(model, distributed_init_method) + + # TODO(sgm): verl not support speculative decode now + def _init_executor(self, model, distributed_init_method) -> None: + assert not self.speculative_config, "Speculative decoding not yet supported for multi-GPU backend." + + # Create the parallel worker for each GPU. + self._init_workers_sp(model, distributed_init_method) + + def _init_workers_sp(self, model, distributed_init_method: str): + # Lazy import the Worker to avoid importing torch.cuda/xformers + # before CUDA_VISIBLE_DEVICES is set in the Worker + from .worker import Worker + + rank = int(os.getenv("RANK")) + local_rank = int(os.getenv("LOCAL_RANK")) + print(f"local rank {local_rank}") + + # see https://github.com/NVIDIA/nccl/issues/1234 + os.environ["NCCL_CUMEM_ENABLE"] = "0" + + self.worker = Worker( + model, + self.model_config, + self.parallel_config, + self.scheduler_config, + self.device_config, + self.cache_config, + self.load_config, + local_rank, + rank, + distributed_init_method, + lora_config=self.lora_config, + multimodal_config=self.multimodal_config, + speculative_config=None, + prompt_adapter_config=self.speculative_config, + is_driver_worker=True, + model_runner_cls=None, # use the default one + ) + + # NOTE(shengguangming): torch.distributed.init_process_group will be called inside the init_model() + self.worker.init_device() + self.worker.load_model() + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of available KV blocks. + + This invokes `determine_num_available_blocks` on each worker and takes + the min of the results, guaranteeing that the selected cache sizes are + compatible with all workers. + + Returns: + - tuple[num_gpu_blocks, num_cpu_blocks] + """ + # Get the maximum number of blocks that can be allocated on GPU and CPU. + num_blocks = self.worker.determine_num_available_blocks() + + # NOTE(shengguangming): Now we don't use a shared centralized controler but each process will + # have its own scheduler + num_gpu_blocks = num_blocks[0] + num_cpu_blocks = num_blocks[1] + + return num_gpu_blocks, num_cpu_blocks + + def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: + """Initialize the KV cache in all workers.""" + + # NOTE: We log here to avoid multiple logs when number of workers is + # greater than one. We could log in the engine, but not all executors + # have GPUs. + logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks, num_cpu_blocks) + + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + + if torch.distributed.get_rank() == 0: + print(f"before init cache memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB") + self.worker.initialize_cache(num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks) + if torch.distributed.get_rank() == 0: + print(f"after init cache memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB") + + # NOTE(sgm): This will not profile & capture the model(CUDAGraph) when rebuilding KVCache + def init_cache_engine(self) -> None: + self.worker._init_cache_engine() + + def free_cache_engine(self) -> None: + self.worker.free_cache_engine() + + def execute_model(self, execute_model_req) -> List[SamplerOutput]: + all_outputs = self.worker.execute_model(execute_model_req=execute_model_req) + + # NOTE(sgm): + # Each GPU in vllm under verl has its own spmd_gpu_executor, therefore all GPUs should return the outputs + # In vllm with ray, only the driver worker returns the sampling results. + return all_outputs + + def add_lora(self, lora_request: LoRARequest) -> bool: + assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." + return self.worker.add_lora(lora_request=lora_request) + + def remove_lora(self, lora_id: int) -> bool: + assert lora_id > 0, "lora_id must be greater than 0." + return self.worker.remove_lora(lora_id=lora_id) + + def list_loras(self) -> Set[int]: + return self.worker.list_loras() + + def check_health(self) -> None: + # SPMDExecutor will always be healthy as long as + # it's running. + return + + # NOTE(sgm) add for verl to pass the abstract class test, not used + from vllm.prompt_adapter.request import PromptAdapterRequest + + def add_prompt_adapter(self, prompt_adapter_request: PromptAdapterRequest) -> bool: + assert prompt_adapter_request.prompt_adapter_id > 0, "prompt_adapter_id must be greater than 0." + return self.worker.add_prompt_adapter(prompt_adapter_request) + + def list_prompt_adapters(self) -> Set[int]: + return self.worker.list_prompt_adapters() + + def pin_lora(self, lora_id: int) -> bool: + assert lora_id > 0, "lora_id must be greater than 0." + return self.worker.pin_lora(lora_id) + + def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: + assert prompt_adapter_id > 0, "prompt_adapter_id must be greater than 0." + return self.worker.pin_prompt_adapter(prompt_adapter_id) + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + assert prompt_adapter_id > 0, "prompt_adapter_id must be greater than 0." + return self.worker.remove_prompt_adapter(prompt_adapter_id) + + # NOTE(sgm): add for verl + def offload_model_weights(self) -> None: + self.worker.offload_model_weights() + + def sync_model_weights(self, actor_weights: Iterable, load_format: str) -> None: + self.worker.sync_model_weights(actor_weights=actor_weights, load_format=load_format) + + +def initialize_cluster( + parallel_config: ParallelConfig, + engine_use_ray: bool = False, + ray_address: Optional[str] = None, +) -> Tuple[str, Optional[None]]: + """Initialize the distributed cluster probably with Ray. + + Args: + parallel_config: The configurations for parallel execution. + + Returns: + The `distributed_init_method` is the address for initializing the + distributed backend. + """ + + # Initialize cluster locally. + # We need to setup the distributed init method to make sure + # the distributed megatron code (e.g., get world size) works correctly. + # distributed_init_method = f"tcp://localhost:{port}" + distributed_init_method = "env://" + return distributed_init_method + + +def get_open_port(): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +# TODO(sgm): not implemented async executor yet +class SPMDGPUExecutorAsync(SPMDGPUExecutor, ExecutorAsyncBase): + async def execute_model_async(self, execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + """Executes one model step on the given sequences.""" + raise NotImplementedError + + async def check_health_async(self) -> None: + """Checks if the executor is healthy. If not, it should raise an + exception.""" + self.check_health() diff --git a/verl/third_party/vllm/vllm_v_0_5_4/tokenizer.py b/verl/third_party/vllm/vllm_v_0_5_4/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..fcd77e9296b2ec3312dcbff47b949f7c154094fd --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_5_4/tokenizer.py @@ -0,0 +1,69 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/tokenizer_group/tokenizer_group.py + +from typing import List, Optional + +from transformers import PreTrainedTokenizer +from vllm.lora.request import LoRARequest +from vllm.transformers_utils.tokenizers import * +from vllm.utils import LRUCache + + +class TokenizerGroup: + """A group of tokenizers that can be used for LoRA adapters.""" + + def __init__(self, tokenizer: PreTrainedTokenizer, enable_lora: bool, max_num_seqs: int, max_input_length: Optional[int]): + self.enable_lora = enable_lora + self.max_input_length = max_input_length + self.tokenizer = tokenizer + self.lora_tokenizers = LRUCache[PreTrainedTokenizer](capacity=max_num_seqs) if enable_lora else None + + def ping(self) -> bool: + """Check if the tokenizer group is alive.""" + return True + + def get_max_input_len(self, lora_request: Optional[LoRARequest] = None) -> Optional[int]: + """Get the maximum input length for the LoRA request.""" + return self.max_input_length + + def encode(self, prompt: str, request_id: Optional[str] = None, lora_request: Optional[LoRARequest] = None) -> List[int]: + tokenizer = self.get_lora_tokenizer(lora_request) + return tokenizer.encode(prompt) + + async def encode_async(self, prompt: str, request_id: Optional[str] = None, lora_request: Optional[LoRARequest] = None) -> List[int]: + tokenizer = await self.get_lora_tokenizer_async(lora_request) + return tokenizer.encode(prompt) + + def get_lora_tokenizer(self, lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer": + if not lora_request or not self.enable_lora: + return self.tokenizer + if lora_request.lora_int_id not in self.lora_tokenizers: + # TODO(sgm): the lora tokenizer is also passed, but may be different + tokenizer = self.tokenizer + # tokenizer = (get_lora_tokenizer( + # lora_request, **self.tokenizer_config) or self.tokenizer) + self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer) + return tokenizer + else: + return self.lora_tokenizers.get(lora_request.lora_int_id) + + # FIXME(sgm): for simplicity, we assign the special token here + @property + def pad_token_id(self): + return self.tokenizer.pad_token_id + + @property + def eos_token_id(self): + return self.tokenizer.eos_token_id diff --git a/verl/third_party/vllm/vllm_v_0_5_4/worker.py b/verl/third_party/vllm/vllm_v_0_5_4/worker.py new file mode 100644 index 0000000000000000000000000000000000000000..6ec8d718a16fa193d7fc6c70ceb066c137eeb730 --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_5_4/worker.py @@ -0,0 +1,323 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/worker/worker.py +"""A GPU worker class.""" + +import gc +import os +from typing import Dict, List, Optional, Tuple, Type, Union + +import torch +import torch.distributed +import torch.nn as nn +from vllm.config import ( + CacheConfig, + DeviceConfig, + LoRAConfig, + MultiModalConfig, + ParallelConfig, + PromptAdapterConfig, + SchedulerConfig, + SpeculativeConfig, +) + +# TODO(sgm): check why vllm has similar file in vllm.model_executor.parallel_utils.parallel_state +from vllm.distributed import get_tensor_model_parallel_group, init_distributed_environment, set_custom_all_reduce +from vllm.model_executor import set_random_seed +from vllm.sequence import ExecuteModelRequest, IntermediateTensors, SamplerOutput +from vllm.worker.cache_engine import CacheEngine +from vllm.worker.embedding_model_runner import EmbeddingModelRunner +from vllm.worker.model_runner import GPUModelRunnerBase +from vllm.worker.model_runner_base import ModelRunnerInputBase +from vllm.worker.worker import Worker, _check_if_gpu_supports_dtype +from vllm.worker.worker_base import WorkerInput + +from .config import LoadConfig, LoadFormat, ModelConfig +from .dtensor_weight_loaders import load_dtensor_weights +from .hf_weight_loader import load_hf_weights +from .megatron_weight_loaders import load_megatron_weights +from .model_runner import ModelRunner +from .parallel_state import ensure_model_parallel_initialized + + +class Worker(Worker): + """A worker class that executes (a partition of) the model on a GPU. + + Each worker is associated with a single GPU. The worker is responsible for + maintaining the KV cache and executing the model on the GPU. In case of + distributed inference, each worker is assigned a partition of the model. + """ + + def __init__( + self, + model: Union[nn.Module, Dict], # model itself or its parameter dict + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + lora_config: Optional[LoRAConfig] = None, + multimodal_config: Optional[MultiModalConfig] = None, + speculative_config: Optional[SpeculativeConfig] = None, + prompt_adapter_config: Optional[PromptAdapterConfig] = None, + is_driver_worker: bool = False, + model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None, + ) -> None: + # self.model = model # will be replaced in the init_model + self.model_config = model_config + self.parallel_config = parallel_config + self.parallel_config.rank = rank + self.scheduler_config = scheduler_config + self.device_config = device_config + self.cache_config = cache_config + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + self.lora_config = lora_config + self.load_config = load_config + self.prompt_adapter_config = prompt_adapter_config + self.is_driver_worker = is_driver_worker # TODO: we don't need driver + # if parallel_config and is_driver_worker: + # assert rank % parallel_config.tensor_parallel_size == 0, \ + # "Driver worker should be rank 0 of tensor parallel group." + if self.model_config.trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + + init_cached_hf_modules() + self.multimodal_config = multimodal_config + + # Return hidden states from target model if the draft model is an + # mlp_speculator + speculative_args = {} if speculative_config is None or (speculative_config.draft_model_config.model == model_config.model) or (speculative_config.draft_model_config.hf_config.model_type not in ["medusa", "mlp_speculator"]) else {"return_hidden_states": True} + + # TODO(sgm): set correct model runner class + ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner + if model_runner_cls is not None: + ModelRunnerClass = model_runner_cls + elif self.model_config.embedding_mode: + ModelRunnerClass = EmbeddingModelRunner + self.model_runner: GPUModelRunnerBase = ModelRunnerClass( + model, # [VERL]: add for verl + model_config, + parallel_config, + scheduler_config, + device_config, + cache_config, + load_config=load_config, + lora_config=self.lora_config, + kv_cache_dtype=self.cache_config.cache_dtype, + is_driver_worker=is_driver_worker, + prompt_adapter_config=prompt_adapter_config, + multimodal_config=multimodal_config, + **speculative_args, + ) + + # Uninitialized cache engine. Will be initialized by + # initialize_cache. + self.cache_engine: List[CacheEngine] = None + # Initialize gpu_cache as embedding models don't initialize kv_caches + self.gpu_cache: Optional[List[List[torch.Tensor]]] = None + + # NOTE(sgm): [VERL] For offloading inference engine params + self.cpu_model = None + + def init_device(self) -> None: + if self.device_config.device.type == "cuda": + # torch.distributed.all_reduce does not free the input tensor until + # the synchronization point. This causes the memory usage to grow + # as the number of all_reduce calls increases. This env var disables + # this behavior. + # Related issue: + # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + + # NOTE(sgm): Modify for verl, Env vars will be set by TORCHRUN. + self.rank = self.rank if self.rank is not None else int(os.getenv("RANK", "-1")) + local_rank = int(os.getenv("LOCAL_RANK", "0")) + self.device = torch.device(f"cuda:{local_rank}") + if self.rank < 0: + raise ValueError("Invalid or unspecified rank.") + torch.cuda.set_device(self.device) + + # Use the world_size set by TORCHRUN + world_size = int(os.getenv("WORLD_SIZE", "-1")) + assert world_size != -1, "The world_size is set to -1, not initialized by TORCHRUN" + self.parallel_config.world_size = world_size + + _check_if_gpu_supports_dtype(self.model_config.dtype) + torch.cuda.empty_cache() + self.init_gpu_memory = torch.cuda.mem_get_info()[0] + else: + raise RuntimeError(f"Not support device type: {self.device_config.device}") + + # Initialize the distributed environment. + init_worker_distributed_environment(self.parallel_config, self.rank, self.distributed_init_method, self.local_rank) + # Set random seed. + set_random_seed(self.model_config.seed) + # self.model = get_model(actor_model=self.model, model_config=self.model_config) + + @torch.inference_mode() + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Profiles the peak memory usage of the model to determine how many + KV blocks may be allocated without OOMs. + + The engine will first conduct a profiling of the existing memory usage. + Then, it calculate the maximum possible number of GPU and CPU blocks + that can be allocated with the remaining free memory. + + .. tip:: + You may limit the usage of GPU memory + by adjusting the `gpu_memory_utilization` parameter. + """ + # Profile the memory usage of the model and get the maximum number of + # cache blocks that can be allocated with the remaining free memory. + torch.cuda.empty_cache() + # torch.cuda.reset_peak_memory_stats() + + # Execute a forward pass with dummy inputs to profile the memory usage + # of the model. + self.model_runner.profile_run() + + # Calculate the number of blocks that can be allocated with the + # profiled peak memory. + torch.cuda.synchronize() + free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() + peak_memory = total_gpu_memory - free_gpu_memory + + assert peak_memory > 0, "Error in memory profiling. This happens when the GPU memory was not properly cleaned up before initializing the vLLM instance." + + cache_block_size = self.get_cache_block_size_bytes() + + # NOTE(sgm) [VERL] use the remaining memory + num_gpu_blocks = int((free_gpu_memory * self.cache_config.gpu_memory_utilization) // cache_block_size) + # num_gpu_blocks = int((total_gpu_memory * self.cache_config.gpu_memory_utilization - peak_memory) // cache_block_size) + + num_cpu_blocks = int(self.cache_config.swap_space_bytes // cache_block_size) + num_gpu_blocks = max(num_gpu_blocks, 0) + num_cpu_blocks = max(num_cpu_blocks, 0) + if self.model_runner.lora_manager: + self.model_runner.remove_all_loras() + + # NOTE(sgm): Add for [VERL], synchronize number of blocks with all the rank + num_gpu_blocks = torch.tensor([num_gpu_blocks], device="cuda") + num_cpu_blocks = torch.tensor([num_cpu_blocks], device="cuda") + + torch.distributed.all_reduce(num_gpu_blocks, op=torch.distributed.ReduceOp.MIN, group=get_tensor_model_parallel_group().device_group) + torch.distributed.all_reduce(num_cpu_blocks, op=torch.distributed.ReduceOp.MIN, group=get_tensor_model_parallel_group().device_group) + num_gpu_blocks = num_gpu_blocks.item() + num_cpu_blocks = num_cpu_blocks.item() + gc.collect() + torch.cuda.empty_cache() + return num_gpu_blocks, num_cpu_blocks + + def _init_cache_engine(self): + if self.cache_engine is None and self.gpu_cache is None: + super()._init_cache_engine() + + def free_cache_engine(self): + # ensure `enforce_eager=True` + self.cache_engine = None + self.gpu_cache = None + + # NOTE(sgm): [VERL]: adapt from _execute_model_spmd() + def execute_model(self, execute_model_req: ExecuteModelRequest, intermediate_tensors: Optional[IntermediateTensors] = None) -> Optional[List[SamplerOutput]]: + """ + Execute model in Single Program Multiple Data (SPMD) fashion. + All workers take the same request, prepare the input and + execute the model. + """ + assert execute_model_req is not None, "_execute_model_spmd() requires each worker to take in an ExecuteModelRequest" + worker_input: WorkerInput = self.prepare_worker_input(execute_model_req=execute_model_req) + model_input: ModelRunnerInputBase = self.model_runner.prepare_model_input(execute_model_req.seq_group_metadata_list) + + # verl.worker.workerbase.WorkerBase + # swap cache + super().execute_worker(worker_input) + + # If there is no input, we don't need to execute the model. + if worker_input.num_seq_groups == 0: + return [] + + return self.model_runner.execute_model( + model_input, + self.kv_cache[worker_input.virtual_engine] if self.kv_cache is not None else None, + intermediate_tensors, + ) + + # assume the input is .state_dict() + def sync_model_weights(self, actor_weights: Dict, load_format: str): + if load_format in [LoadFormat.MEGATRON, LoadFormat.AUTO]: + load_megatron_weights(actor_weights, self.model_runner.model) + elif load_format == LoadFormat.HF: + # full model state dict without no sharding + load_hf_weights(actor_weights, self.model_runner.model) + elif load_format == LoadFormat.DTENSOR: + load_dtensor_weights(actor_weights, self.model_runner.model) + + def offload_model_weights(self) -> None: + if self.cpu_model is None: + self.cpu_model = {} + for name, params in self.model_runner.model.named_parameters(): + self.cpu_model[name] = torch.empty_like(params, device="cpu") + params.data = self.cpu_model[name] + else: + for name, params in self.model_runner.model.named_parameters(): + params.data = self.cpu_model[name] + + +def init_worker_distributed_environment( + parallel_config: ParallelConfig, + rank: int, + distributed_init_method: Optional[str] = "env://", + local_rank: int = -1, +) -> None: + """Initialize the distributed environment.""" + set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) + + # NOTE(sgm) use tcp://localhost:xxxx will hang in HF setting without megatron + init_distributed_environment(parallel_config.world_size, rank, distributed_init_method, local_rank) + + ensure_model_parallel_initialized( + tensor_model_parallel_size=parallel_config.tensor_parallel_size, + pipeline_model_parallel_size=parallel_config.pipeline_parallel_size, + ) + + # TODO(sgm): check whether need this + # if pynccl_utils.is_initialized(): + # pynccl_world_size = pynccl_utils.get_world_size() + # if pynccl_world_size != parallel_config.world_size: + # raise RuntimeError( + # "pynccl is already initialized but the pynccl world " + # "size does not match parallel_config.world_size " + # f"({pynccl_world_size} vs. {parallel_config.world_size}).") + # elif parallel_config.world_size > 1: + # # NOTE(woosuk): We don't initialize pynccl process group when world size + # # is 1. + # # NOTE(kaichao): By default, pynccl is initialized for tp group. + # pynccl_utils.init_process_group( + # group=get_tensor_model_parallel_cpu_group()) + + # # Initialize a custom fast all-reduce implementation. + # if not parallel_config.disable_custom_all_reduce: + # init_custom_ar() + + # A small all_reduce for warmup. + torch.distributed.all_reduce(torch.zeros(1).cuda()) + # if pynccl_utils.is_initialized(): + # pynccl_utils.all_reduce(torch.zeros(1).cuda()) diff --git a/verl/third_party/vllm/vllm_v_0_6_3/__init__.py b/verl/third_party/vllm/vllm_v_0_6_3/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c6aedde1437eddbe099b730519ef79980e2e3dc4 --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_3/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/verl/third_party/vllm/vllm_v_0_6_3/arg_utils.py b/verl/third_party/vllm/vllm_v_0_6_3/arg_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..53510a565360d11ff42291a68ef5f51fecec1993 --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_3/arg_utils.py @@ -0,0 +1,78 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/engine/arg_utils.py + +import os +from dataclasses import dataclass + +from transformers import PretrainedConfig +from vllm.config import EngineConfig +from vllm.engine.arg_utils import EngineArgs + +from .config import LoadConfig, ModelConfig + + +@dataclass +class EngineArgs(EngineArgs): + model_hf_config: PretrainedConfig = None # for verl + + def __post_init__(self): + pass + + def create_model_config(self) -> ModelConfig: + return ModelConfig( + hf_config=self.model_hf_config, + tokenizer_mode=self.tokenizer_mode, + trust_remote_code=self.trust_remote_code, + dtype=self.dtype, + seed=self.seed, + revision=self.revision, + code_revision=self.code_revision, + rope_scaling=self.rope_scaling, + rope_theta=self.rope_theta, + tokenizer_revision=self.tokenizer_revision, + max_model_len=self.max_model_len, + quantization=self.quantization, + quantization_param_path=self.quantization_param_path, + enforce_eager=self.enforce_eager, + max_context_len_to_capture=self.max_context_len_to_capture, + max_seq_len_to_capture=self.max_seq_len_to_capture, + max_logprobs=self.max_logprobs, + disable_sliding_window=self.disable_sliding_window, + skip_tokenizer_init=self.skip_tokenizer_init, + served_model_name=self.served_model_name, + limit_mm_per_prompt=self.limit_mm_per_prompt, + use_async_output_proc=not self.disable_async_output_proc, + override_neuron_config=self.override_neuron_config, + config_format=self.config_format, + mm_processor_kwargs=self.mm_processor_kwargs, + ) + + def create_load_config(self) -> LoadConfig: + return LoadConfig( + load_format=self.load_format, + download_dir=self.download_dir, + model_loader_extra_config=self.model_loader_extra_config, + ignore_patterns=self.ignore_patterns, + ) + + def create_engine_config(self) -> EngineConfig: + engine_config = super().create_engine_config() + + # NOTE[VERL]: Use the world_size set by torchrun + world_size = int(os.getenv("WORLD_SIZE", "-1")) + assert world_size != -1, "The world_size is set to -1, not initialized by TORCHRUN" + engine_config.parallel_config.world_size = world_size + + return engine_config diff --git a/verl/third_party/vllm/vllm_v_0_6_3/config.py b/verl/third_party/vllm/vllm_v_0_6_3/config.py new file mode 100644 index 0000000000000000000000000000000000000000..c22413fbcdbd1fdbcf50291043f10ac7fbf53eca --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_3/config.py @@ -0,0 +1,100 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py + +import enum +import json +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, List, Optional, Union + +from transformers import PretrainedConfig + +# Add for verl +from vllm.config import ModelConfig +from vllm.logger import init_logger +from vllm.utils import is_hip + +if TYPE_CHECKING: + from vllm.model_executor.model_loader.loader import BaseModelLoader + +logger = init_logger(__name__) + + +class LoadFormat(str, enum.Enum): + AUTO = "auto" + MEGATRON = "megatron" + HF = "hf" + DTENSOR = "dtensor" + DUMMY_HF = "dummy_hf" + DUMMY_MEGATRON = "dummy_megatron" + DUMMY_DTENSOR = "dummy_dtensor" + + +class ModelConfig(ModelConfig): + def __init__(self, hf_config: PretrainedConfig, *args, **kwargs) -> None: + super().__init__(model=hf_config._name_or_path, tokenizer=hf_config._name_or_path, *args, **kwargs) # noqa: B026 + self.hf_config = hf_config + + +@dataclass +class LoadConfig: + """ + download_dir: Directory to download and load the weights, default to the + default cache directory of huggingface. + load_format: The format of the model weights to load: + "auto" will try to load the weights in the safetensors format and + fall back to the pytorch bin format if safetensors format is + not available. + "pt" will load the weights in the pytorch bin format. + "safetensors" will load the weights in the safetensors format. + "npcache" will load the weights in pytorch format and store + a numpy cache to speed up the loading. + "dummy" will initialize the weights with random values, which is + mainly for profiling. + "tensorizer" will use CoreWeave's tensorizer library for + fast weight loading. + "bitsandbytes" will load nf4 type weights. + ignore_patterns: The list of patterns to ignore when loading the model. + Default to "original/**/*" to avoid repeated loading of llama's + checkpoints. + + """ + + load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO + download_dir: Optional[str] = None + model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict) + ignore_patterns: Optional[Union[List[str], str]] = None + + def __post_init__(self): + model_loader_extra_config = self.model_loader_extra_config or {} + if isinstance(model_loader_extra_config, str): + self.model_loader_extra_config = json.loads(model_loader_extra_config) + self._verify_load_format() + + if self.ignore_patterns is not None and len(self.ignore_patterns) > 0: + logger.info("Ignoring the following patterns when downloading weights: %s", self.ignore_patterns) + else: + self.ignore_patterns = ["original/**/*"] + + def _verify_load_format(self) -> None: + if not isinstance(self.load_format, str): + return + + load_format = self.load_format.lower() + self.load_format = LoadFormat(load_format) + + rocm_not_supported_load_format: List[str] = [] + if is_hip() and load_format in rocm_not_supported_load_format: + rocm_supported_load_format = [f for f in LoadFormat.__members__ if (f not in rocm_not_supported_load_format)] + raise ValueError(f"load format '{load_format}' is not supported in ROCm. Supported load formats are {rocm_supported_load_format}") diff --git a/verl/third_party/vllm/vllm_v_0_6_3/dtensor_weight_loaders.py b/verl/third_party/vllm/vllm_v_0_6_3/dtensor_weight_loaders.py new file mode 100644 index 0000000000000000000000000000000000000000..feff1270abce6a8afaed8936ae04c230ad166dda --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_3/dtensor_weight_loaders.py @@ -0,0 +1,374 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/model_loader + +from typing import Dict + +import torch.nn as nn +from torch.distributed._tensor import DTensor +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.utils import is_pp_missing_parameter + + +def gemma_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(vllm_model.named_parameters()) + for name, loaded_weight in actor_weights.items(): + for param_name, shard_name, shard_id in stacked_params_mapping: + if shard_name not in name: + continue + stacked_name = name.replace(shard_name, param_name) + # Skip loading extra bias for GPTQ models. + if stacked_name.endswith(".bias") and stacked_name not in params_dict: + continue + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + param = params_dict[stacked_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id) + break + else: + # lm_head is not used in vllm as it is tied with embed_token. + # To prevent errors, skip loading lm_head.weight. + if "lm_head.weight" in name: + continue + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight.to(dtype=param.dtype)) + + +def gptbigcode_dtensor_load_weights(actor_weights: Dict, vllm_model: nn.Module): + params_dict = dict(vllm_model.named_parameters(remove_duplicate=False)) + for name, loaded_weight in actor_weights.items(): + if "lm_head.weight" in name: + continue + if ".attn.bias" in name: + # Skip attention mask. + # NOTE: "c_attn.bias" should not be skipped. + continue + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight.to(dtype=param.dtype)) + + +def starcoder2_dtensor_load_weights(actor_weights: Dict, vllm_model: nn.Module): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + params_dict = dict(vllm_model.named_parameters(remove_duplicate=False)) + for name, loaded_weight in actor_weights.items(): + if "rotary_emb.inv_freq" in name: + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id) + break + else: + if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name: + continue + param = params_dict[name] + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight.to(dtype=param.dtype)) + + +def llama_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(vllm_model.named_parameters()) + for name, loaded_weight in actor_weights.items(): + if "rotary_emb.inv_freq" in name: + continue + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + # With tie_word_embeddings, we can skip lm_head.weight + # The weight might appear unnecessarily in the files if the model is + # processed with quantization, LoRA, fine-tuning, etc. + if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight) + + +def qwen2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(vllm_model.named_parameters(remove_duplicate=False)) + for name, loaded_weight in actor_weights.items(): + if "rotary_emb.inv_freq" in name: + continue + if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight.to(dtype=param.dtype)) + + +def qwen2vl_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(vllm_model.named_parameters(remove_duplicate=False)) + for name, loaded_weight in actor_weights.items(): + if "rotary_emb.inv_freq" in name: + continue + if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight.to(dtype=param.dtype)) + + +def deepseekv2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=vllm_model.config.n_routed_experts, + ) + + params_dict = dict(vllm_model.named_parameters(remove_duplicate=False)) + for name, loaded_weight in actor_weights.items(): + if "rotary_emb.inv_freq" in name: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if ("mlp.experts." in name) and name not in params_dict: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, vllm_model): + continue + + param = params_dict[name] + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name, vllm_model): + continue + + param = params_dict[name] + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader( + param, + local_loaded_weight.to(dtype=param.dtype), + weight_name, + shard_id=shard_id, + expert_id=expert_id, + ) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, vllm_model): + continue + + param = params_dict[name] + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight.to(dtype=param.dtype)) + + +def gpt2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + pass + + +def redistribute_dtensor(param_name: str, loaded_weights: DTensor, parallelize_plan: Dict = None): + param_name = _process_parameter_names(name=param_name) + if parallelize_plan is not None: + assert param_name in parallelize_plan, f"param name: {param_name} not in parallelize_plan :{parallelize_plan.keys()}" + placement = parallelize_plan[param_name] + local_loaded_weights = loaded_weights.redistribute(device_mesh=loaded_weights.device_mesh, placements=placement).to_local() + else: + local_loaded_weights = loaded_weights.full_tensor() + return local_loaded_weights + + +def _process_parameter_names(name): + # Remove '.weight' if it exists at the end of the string + if name.endswith(".weight"): + name = name[:-7] + + # Remove 'model.layers.x.' or 'model.' prefix + if "model.layers" in name: + parts = name.split(".") + # Reconstruct the string without 'model.layers.x.' + name = ".".join(parts[3:]) # parts[0] is 'model', parts[1] is 'layers', parts[2] is 'x' + elif name.startswith("model."): + name = name[6:] # Remove 'model.' + + return name + + +__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__ = { + "GPT2LMHeadModel": gpt2_dtensor_weight_loader, + "LlamaForCausalLM": llama_dtensor_weight_loader, + "LLaMAForCausalLM": llama_dtensor_weight_loader, + "MistralForCausalLM": llama_dtensor_weight_loader, # mistral is the same as llama in vLLM + "InternLMForCausalLM": llama_dtensor_weight_loader, + "AquilaModel": llama_dtensor_weight_loader, + "AquilaForCausalLM": llama_dtensor_weight_loader, + "Phi3ForCausalLM": llama_dtensor_weight_loader, + "GemmaForCausalLM": gemma_dtensor_weight_loader, + "Gemma2ForCausalLM": gemma_dtensor_weight_loader, + "GPTBigCodeForCausalLM": gptbigcode_dtensor_load_weights, + "Starcoder2ForCausalLM": starcoder2_dtensor_load_weights, + "Qwen2ForCausalLM": qwen2_dtensor_weight_loader, + "DeepseekV2ForCausalLM": deepseekv2_dtensor_weight_loader, + "Qwen2VLForConditionalGeneration": qwen2vl_dtensor_weight_loader, +} + + +# the actor model is .state_dict() +# Load dtensor weights +def load_dtensor_weights(actor_weights: Dict, vllm_model: nn.Module): + weight_loader = _get_model_weight_loader(vllm_model.__class__.__name__) + weight_loader(actor_weights, vllm_model) + # NOTE(sgm) to reduce peak memory usage, we offload vllm model to cpu + # after init, and we need this after sync model weights for in first iter. + vllm_model = vllm_model.cuda() + + +def _get_model_weight_loader(arch: str): + if arch in __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__: + return __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__[arch] + raise ValueError(f"Model architectures {arch} are not supported for now. Supported architectures: {__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__.keys()}") + + +# NOTE(sgm): we use per-parameter weight loader in each vllm sub +def update_dtensor_weight_loader(): + pass diff --git a/verl/third_party/vllm/vllm_v_0_6_3/hf_weight_loader.py b/verl/third_party/vllm/vllm_v_0_6_3/hf_weight_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..e2279fd88d4960b414ae453b43472f4f63409f2e --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_3/hf_weight_loader.py @@ -0,0 +1,41 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/model_loader + +from typing import Dict + +import torch.nn as nn +from vllm.model_executor.model_loader.utils import set_default_torch_dtype + + +def update_hf_weight_loader(): + print("no hf weight loader need to be updated") + return + + +def load_hf_weights(actor_weights: Dict, vllm_model: nn.Module): + assert isinstance(actor_weights, Dict) + with set_default_torch_dtype(next(vllm_model.parameters()).dtype): # TODO + if vllm_model.config.tie_word_embeddings and "lm_head.weight" in actor_weights: + del actor_weights["lm_head.weight"] + vllm_model.load_weights(actor_weights.items()) + for _, module in vllm_model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + quant_method.process_weights_after_loading(module) + # FIXME: Remove this after Mixtral is updated + # to use quant_method. + if hasattr(module, "process_weights_after_loading"): + module.process_weights_after_loading() + vllm_model = vllm_model.cuda() diff --git a/verl/third_party/vllm/vllm_v_0_6_3/llm.py b/verl/third_party/vllm/vllm_v_0_6_3/llm.py new file mode 100644 index 0000000000000000000000000000000000000000..5197caccfed5ebc789f533ca2f5f1e128d6fd07a --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_3/llm.py @@ -0,0 +1,197 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py + +from typing import Dict, Iterable, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch.nn.utils.rnn import pad_sequence +from transformers import PretrainedConfig, PreTrainedTokenizer, PreTrainedTokenizerFast +from vllm import LLM +from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.utils import Counter + +from verl.workers.rollout.tokenizer import HybridEngineBaseTokenizer + +from .arg_utils import EngineArgs +from .llm_engine_sp import LLMEngine + + +class LLM(LLM): + """An LLM for generating texts from given prompts and sampling parameters. + + This class includes a tokenizer, a language model (possibly distributed + across multiple GPUs), and GPU memory space allocated for intermediate + states (aka KV cache). Given a batch of prompts and sampling parameters, + this class generates texts from the model, using an intelligent batching + mechanism and efficient memory management. + + NOTE: This class is intended to be used for offline inference. For online + serving, use the `AsyncLLMEngine` class instead. + NOTE: For the comprehensive list of arguments, see `EngineArgs`. + + Args: + model: A HuggingFace Transformers model instance. + tokenizer: A HuggingFace Transformers tokenizer instance. + tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer + if available, and "slow" will always use the slow tokenizer. + trust_remote_code: Trust remote code (e.g., from HuggingFace) when + downloading the model and tokenizer. + tensor_parallel_size: The number of GPUs to use for distributed + execution with tensor parallelism. + dtype: The data type for the model weights and activations. Currently, + we support `float32`, `float16`, and `bfloat16`. If `auto`, we use + the `torch_dtype` attribute specified in the model config file. + However, if the `torch_dtype` in the config is `float32`, we will + use `float16` instead. + quantization: The method used to quantize the model weights. Currently, + we support "awq". If None, we assume the model weights are not + quantized and use `dtype` to determine the data type of the weights. + revision: The specific model version to use. It can be a branch name, + a tag name, or a commit id. + tokenizer_revision: The specific tokenizer version to use. It can be a + branch name, a tag name, or a commit id. + seed: The seed to initialize the random number generator for sampling. + gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to + reserve for the model weights, activations, and KV cache. Higher + values will increase the KV cache size and thus improve the model's + throughput. However, if the value is too high, it may cause out-of- + memory (OOM) errors. + swap_space: The size (GiB) of CPU memory per GPU to use as swap space. + This can be used for temporarily storing the states of the requests + when their `best_of` sampling parameters are larger than 1. If all + requests will have `best_of=1`, you can safely set this to 0. + Otherwise, too small values may cause out-of-memory (OOM) errors. + enforce_eager: Whether to enforce eager execution. If True, we will + disable CUDA graph and always execute the model in eager mode. + If False, we will use CUDA graph and eager execution in hybrid. + max_context_len_to_capture: Maximum context len covered by CUDA graphs. + When a sequence has context length larger than this, we fall back + to eager mode. + disable_custom_all_reduce: See ParallelConfig + """ + + def __init__( + self, + model: Union[nn.Module, Dict], # model itself or its parameter dict + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast, HybridEngineBaseTokenizer], + model_hf_config: PretrainedConfig, + tokenizer_mode: str = "auto", + trust_remote_code: bool = False, + skip_tokenizer_init: bool = False, + tensor_parallel_size: int = 1, + dtype: str = "auto", + quantization: Optional[str] = None, + revision: Optional[str] = None, + tokenizer_revision: Optional[str] = None, + seed: int = 0, + gpu_memory_utilization: float = 0.9, + swap_space: int = 4, + cpu_offload_gb: float = 0, + enforce_eager: bool = False, + max_context_len_to_capture: Optional[int] = None, + max_seq_len_to_capture: int = 8192, + disable_custom_all_reduce: bool = False, + load_format="auto", + **kwargs, + ) -> None: + if "disable_log_stats" not in kwargs: + kwargs["disable_log_stats"] = True + removed_vision_keys = ("image_token_id", "image_feature_size", "image_input_shape", "image_input_type") + if any(k in kwargs for k in removed_vision_keys): + raise TypeError("There is no need to pass vision-related arguments anymore.") + engine_args = EngineArgs( + model_hf_config=model_hf_config, + # tokenizer=tokenizer, + tokenizer_mode=tokenizer_mode, + skip_tokenizer_init=skip_tokenizer_init, + trust_remote_code=trust_remote_code, + tensor_parallel_size=tensor_parallel_size, + dtype=dtype, + quantization=quantization, + revision=revision, + tokenizer_revision=tokenizer_revision, + seed=seed, + gpu_memory_utilization=gpu_memory_utilization, + swap_space=swap_space, + cpu_offload_gb=cpu_offload_gb, + enforce_eager=enforce_eager, + max_context_len_to_capture=max_context_len_to_capture, + max_seq_len_to_capture=max_seq_len_to_capture, + disable_custom_all_reduce=disable_custom_all_reduce, + load_format=load_format, + **kwargs, + ) + tokenizer_cls = (PreTrainedTokenizer, PreTrainedTokenizerFast, HybridEngineBaseTokenizer) + if not isinstance(tokenizer, tokenizer_cls): + raise ValueError(f"Unexpected tokenizer type: {type(tokenizer)}. Must beone of the following: PreTrainedTokenizer, PreTrainedTokenizerFast, verl.workers.rollout.HybridEngineBaseTokenizer") + self.llm_engine = LLMEngine.from_engine_args(model, tokenizer, engine_args) # TODO: check usagecontext + self.request_counter = Counter() + + def init_cache_engine(self): + self.llm_engine.init_cache_engine() + + def free_cache_engine(self): + self.llm_engine.free_cache_engine() + + def get_tokenizer(self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + return self.llm_engine.tokenizer + + def set_tokenizer( + self, + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + ) -> None: + self.llm_engine.tokenizer = tokenizer + + def _run_engine(self, *, use_tqdm: bool) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: + outputs = super()._run_engine(use_tqdm=use_tqdm) + return self._post_process_outputs(outputs) + + # # NOTE(shengguangming): add for verl + # # TODO(sgm): we can optimize it by making the dataloader yield List[int] without padding. + # def _pre_process_inputs(self, prompt_token_ids: torch.Tensor) -> List[int]: + # # remove the left padding in the prompt token_id + # pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id + # non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0] + # token_ids = prompt_token_ids[non_pad_index:].tolist() + # return token_ids + + # NOTE(shengguangming): add for verl + def _post_process_outputs(self, request_outputs: List[RequestOutput]) -> Tuple[torch.Tensor, torch.Tensor]: + output_token_ids = [] + logprobs = [] + for request_output in request_outputs: # List[RequestOutput] + outputs = request_output.outputs + for output in outputs: # List[CompletionOutput], usually len == 1 + output_token_ids.append(torch.tensor(output.token_ids)) + # TODO(shengguangming): can be optimzied by rewrite the Sampler._get_logprobs() logits + logprobs_dicts = output.logprobs + if logprobs_dicts is not None: + logprob = [] + for logprobs_dict, id in zip(logprobs_dicts, output.token_ids): + logprob.append(logprobs_dict[id].logprob) + logprobs.append(torch.tensor(logprob)) + + pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id + output_token_ids = pad_sequence(output_token_ids, batch_first=True, padding_value=pad_token_id) + if len(logprobs) > 0: + logprobs = pad_sequence(logprobs, batch_first=True, padding_value=pad_token_id) + return output_token_ids, logprobs + + def sync_model_weights(self, actor_weights: Iterable, load_format: str) -> None: + self.llm_engine.sync_model_weights(actor_weights=actor_weights, load_format=load_format) + + def offload_model_weights(self) -> None: + self.llm_engine.offload_model_weights() diff --git a/verl/third_party/vllm/vllm_v_0_6_3/llm_engine_sp.py b/verl/third_party/vllm/vllm_v_0_6_3/llm_engine_sp.py new file mode 100644 index 0000000000000000000000000000000000000000..855564e534af8e6b76285b8f7c9653c49019a069 --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_3/llm_engine_sp.py @@ -0,0 +1,390 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/engine/llm_engine.py + +from functools import partial +from typing import Callable, Dict, Iterable, Optional, Type, Union + +import torch.nn as nn +from vllm.config import ( + CacheConfig, + DecodingConfig, + DeviceConfig, + EngineConfig, + LoRAConfig, + ObservabilityConfig, + ParallelConfig, + PromptAdapterConfig, + SchedulerConfig, + SpeculativeConfig, +) +from vllm.core.scheduler import Scheduler +from vllm.engine.llm_engine import LLMEngine, SchedulerContext, SchedulerOutputState, _load_generation_config_dict +from vllm.engine.metrics_types import StatLoggerBase +from vllm.engine.output_processor.interfaces import SequenceGroupOutputProcessor +from vllm.engine.output_processor.stop_checker import StopChecker +from vllm.executor.executor_base import ExecutorBase +from vllm.inputs import INPUT_REGISTRY, InputRegistry +from vllm.inputs.preprocess import InputPreprocessor +from vllm.logger import init_logger +from vllm.sequence import Sequence +from vllm.tracing import init_tracer +from vllm.transformers_utils.detokenizer import Detokenizer +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.usage.usage_lib import UsageContext, is_usage_stats_enabled, usage_message +from vllm.utils import Counter, weak_bind +from vllm.version import __version__ as VLLM_VERSION + +from .arg_utils import EngineArgs +from .config import LoadConfig, ModelConfig +from .tokenizer import TokenizerGroup + +logger = init_logger(__name__) +_LOCAL_LOGGING_INTERVAL_SEC = 5 + + +class LLMEngine(LLMEngine): + """An LLM engine that receives requests and generates texts. + + This is the main class for the vLLM engine. It receives requests + from clients and generates texts from the LLM. It includes a tokenizer, a + language model (possibly distributed across multiple GPUs), and GPU memory + space allocated for intermediate states (aka KV cache). This class utilizes + iteration-level scheduling and efficient memory management to maximize the + serving throughput. + + The :class:`~vllm.LLM` class wraps this class for offline batched inference + and the :class:`AsyncLLMEngine` class wraps this class for online serving. + + The config arguments are derived from :class:`~vllm.EngineArgs`. (See + :ref:`engine_args`) + + Args: + model_config: The configuration related to the LLM model. + cache_config: The configuration related to the KV cache memory + management. + parallel_config: The configuration related to distributed execution. + scheduler_config: The configuration related to the request scheduler. + device_config: The configuration related to the device. + lora_config (Optional): The configuration related to serving multi-LoRA. + speculative_config (Optional): The configuration related to speculative + decoding. + executor_class: The model executor class for managing distributed + execution. + prompt_adapter_config (Optional): The configuration related to serving + prompt adapters. + log_stats: Whether to log statistics. + usage_context: Specified entry point, used for usage info collection. + """ + + def __init__( + self, + # NOTE(sgm): first two arguments are added for verl + model: Union[nn.Module, Dict], # model itself or its parameter dict + tokenizer: nn.Module, + # NOTE(sgm): vllm original arguments + model_config: ModelConfig, + cache_config: CacheConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + speculative_config: Optional[SpeculativeConfig], + decoding_config: Optional[DecodingConfig], + observability_config: Optional[ObservabilityConfig], + prompt_adapter_config: Optional[PromptAdapterConfig], + executor_class: Type[ExecutorBase], + log_stats: bool, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, + input_registry: InputRegistry = INPUT_REGISTRY, + use_cached_outputs: bool = False, + ) -> None: + logger.info( + "Initializing an LLM engine (v%s) with config: " + "model=%r, speculative_config=%r, tokenizer=%r, " + "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " + "override_neuron_config=%s, " + "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, " + "trust_remote_code=%s, dtype=%s, max_seq_len=%d, " + "download_dir=%r, load_format=%s, tensor_parallel_size=%d, " + "pipeline_parallel_size=%d, " + "disable_custom_all_reduce=%s, quantization=%s, " + "enforce_eager=%s, kv_cache_dtype=%s, " + "quantization_param_path=%s, device_config=%s, " + "decoding_config=%r, observability_config=%r, " + "seed=%d, served_model_name=%s, use_v2_block_manager=%s, " + "num_scheduler_steps=%d, chunked_prefill_enabled=%s " + "multi_step_stream_outputs=%s, enable_prefix_caching=%s, " + "use_async_output_proc=%s, use_cached_outputs=%s, " + "mm_processor_kwargs=%s)", + VLLM_VERSION, + model_config.model, + speculative_config, + model_config.tokenizer, + model_config.skip_tokenizer_init, + model_config.tokenizer_mode, + model_config.revision, + model_config.override_neuron_config, + model_config.rope_scaling, + model_config.rope_theta, + model_config.tokenizer_revision, + model_config.trust_remote_code, + model_config.dtype, + model_config.max_model_len, + load_config.download_dir, + load_config.load_format, + parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size, + parallel_config.disable_custom_all_reduce, + model_config.quantization, + model_config.enforce_eager, + cache_config.cache_dtype, + model_config.quantization_param_path, + device_config.device, + decoding_config, + observability_config, + model_config.seed, + model_config.served_model_name, + scheduler_config.use_v2_block_manager, + scheduler_config.num_scheduler_steps, + scheduler_config.chunked_prefill_enabled, + scheduler_config.multi_step_stream_outputs, + cache_config.enable_prefix_caching, + model_config.use_async_output_proc, + use_cached_outputs, + model_config.mm_processor_kwargs, + ) + # TODO(woosuk): Print more configs in debug mode. + self.model_config = model_config + self.cache_config = cache_config + self.lora_config = lora_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.speculative_config = speculative_config + self.load_config = load_config + self.decoding_config = decoding_config or DecodingConfig() + self.prompt_adapter_config = prompt_adapter_config + self.observability_config = observability_config or ObservabilityConfig() + self.log_stats = log_stats + self.use_cached_outputs = use_cached_outputs + + if not self.model_config.skip_tokenizer_init: + self.tokenizer = self._init_tokenizer(tokenizer) + self.detokenizer = Detokenizer(self.tokenizer) + tokenizer_group = self.get_tokenizer_group() + else: + self.tokenizer = None + self.detokenizer = None + tokenizer_group = None + + # Ensure that the function doesn't contain a reference to self, + # to avoid engine GC issues + def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: + assert tokenizer_group, "tokenizer_group cannot be None, make sure skip_tokenizer_init is False" + return tokenizer_group.get_lora_tokenizer(sequence.lora_request) + + self.seq_counter = Counter() + self.generation_config_fields = _load_generation_config_dict(model_config) + + self.input_preprocessor = InputPreprocessor(model_config, self.tokenizer) + + self.input_registry = input_registry + self.input_processor = input_registry.create_input_processor(model_config) + + self.model_executor = executor_class( + model=model, # add for spmd_gpu_executor + model_config=model_config, + cache_config=cache_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + device_config=device_config, + lora_config=lora_config, + speculative_config=speculative_config, + load_config=load_config, + prompt_adapter_config=prompt_adapter_config, + observability_config=self.observability_config, + ) + + if not self.model_config.embedding_mode: + self._initialize_kv_caches() + + # If usage stat is enabled, collect relevant info. + if is_usage_stats_enabled(): + from vllm.model_executor.model_loader import get_architecture_class_name + + usage_message.report_usage( + get_architecture_class_name(model_config), + usage_context, + extra_kvs={ + # Common configuration + "dtype": str(model_config.dtype), + "tensor_parallel_size": parallel_config.tensor_parallel_size, + "block_size": cache_config.block_size, + "gpu_memory_utilization": cache_config.gpu_memory_utilization, + # Quantization + "quantization": model_config.quantization, + "kv_cache_dtype": str(cache_config.cache_dtype), + # Feature flags + "enable_lora": bool(lora_config), + "enable_prompt_adapter": bool(prompt_adapter_config), + "enable_prefix_caching": cache_config.enable_prefix_caching, + "enforce_eager": model_config.enforce_eager, + "disable_custom_all_reduce": parallel_config.disable_custom_all_reduce, + }, + ) + + if self.tokenizer: + # Ping the tokenizer to ensure liveness if it runs in a + # different process. + self.tokenizer.ping() + + self.cached_scheduler_outputs = [SchedulerOutputState() for _ in range(self.parallel_config.pipeline_parallel_size)] + + self.scheduler_contexts = [SchedulerContext(multi_step_stream_outputs=self.scheduler_config.multi_step_stream_outputs) for _ in range(self.parallel_config.pipeline_parallel_size)] + + if model_config.use_async_output_proc: + process_model_outputs = weak_bind(self._process_model_outputs) + + self.async_callbacks = [partial(process_model_outputs, ctx=self.scheduler_contexts[v_id]) for v_id in range(self.parallel_config.pipeline_parallel_size)] + else: + self.async_callbacks = [] + + # Currently used by AsyncLLMEngine to ensure quick append + # of request outputs to asyncio queues + self.process_request_outputs_callback: Optional[Callable] = None + + # Create the scheduler. + # NOTE: the cache_config here have been updated with the numbers of + # GPU and CPU blocks, which are profiled in the distributed executor. + self.scheduler = [ + Scheduler( + scheduler_config, + cache_config, + lora_config, + parallel_config.pipeline_parallel_size, + self.async_callbacks[v_id] if model_config.use_async_output_proc else None, + ) + for v_id in range(parallel_config.pipeline_parallel_size) + ] + + # Metric Logging. + if self.log_stats: + if stat_loggers is not None: + self.stat_loggers = stat_loggers + else: + # Lazy import for prometheus multiprocessing. + # We need to set PROMETHEUS_MULTIPROC_DIR environment variable + # before prometheus_client is imported. + # See https://prometheus.github.io/client_python/multiprocess/ + from vllm.engine.metrics import LoggingStatLogger, PrometheusStatLogger + + self.stat_loggers = { + "logging": LoggingStatLogger(local_interval=_LOCAL_LOGGING_INTERVAL_SEC), + "prometheus": PrometheusStatLogger( + local_interval=_LOCAL_LOGGING_INTERVAL_SEC, + labels=dict(model_name=model_config.served_model_name), + max_model_len=self.model_config.max_model_len, + ), + } + self.stat_loggers["prometheus"].info("cache_config", self.cache_config) + + self.tracer = None + if self.observability_config.otlp_traces_endpoint: + self.tracer = init_tracer("vllm.llm_engine", self.observability_config.otlp_traces_endpoint) + + # Create sequence output processor, e.g. for beam search or + # speculative decoding. + self.output_processor = SequenceGroupOutputProcessor.create_output_processor( + self.scheduler_config, + self.detokenizer, + self.scheduler, + self.seq_counter, + get_tokenizer_for_seq, + stop_checker=StopChecker( + self.scheduler_config.max_model_len, + get_tokenizer_for_seq, + ), + ) + + # TODO(sgm): add for verl but we may not tokenizer in Rollout + def _init_tokenizer(self, tokenizer, **tokenizer_init_kwargs): + init_kwargs = dict(enable_lora=bool(self.lora_config), max_num_seqs=self.scheduler_config.max_num_seqs, max_input_length=None) + init_kwargs.update(tokenizer_init_kwargs) + return TokenizerGroup(tokenizer, **init_kwargs) + + def init_cache_engine(self): + # TODO: check whether we should rebuild the CUDAGraph every iter when offload/load KVCache + # Re-capture CUDAGraph would be time-consuming + self.model_executor.init_cache_engine() + + def free_cache_engine(self): + self.model_executor.free_cache_engine() + + # NOTE(sgm): currently, we only support GPU executor + # The GPUExecutor remove the Ray dependency + @classmethod + def _get_executor_cls(cls, engine_config: EngineConfig) -> Type[ExecutorBase]: + # Initialize the cluster and specify the executor class.] + assert engine_config.device_config.device_type == "cuda", "Currently, the vllm in verl only support running on GPU" + + # print('Waiting for debugger'); import os,debugpy; debugpy.listen(('localhost', 5678 + int(os.getenv('RANK', '0')))); debugpy.wait_for_client() + if engine_config.parallel_config.world_size == 1: + engine_config.load_config.load_format = "dummy_hf" + + from .spmd_gpu_executor import SPMDGPUExecutor + + executor_class = SPMDGPUExecutor + + return executor_class + + @classmethod + def from_engine_args( + cls, + model, + tokenizer, + engine_args: EngineArgs, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, + ) -> "LLMEngine": + """Creates an LLM engine from the engine arguments.""" + # Create the engine configs. + engine_config = engine_args.create_engine_config() + executor_class = cls._get_executor_cls(engine_config) + # Initialize the cluster and specify the executor class. + assert engine_config.device_config.device_type == "cuda", "Currently, the vllm in verl only support running on GPU" + + from .spmd_gpu_executor import SPMDGPUExecutor + + executor_class = SPMDGPUExecutor + + # Create the LLM engine. + engine = cls( + model, + tokenizer, + **engine_config.to_dict(), + executor_class=executor_class, + log_stats=not engine_args.disable_log_stats, + usage_context=usage_context, + stat_loggers=stat_loggers, + ) + return engine + + def sync_model_weights(self, actor_weights: Iterable, load_format: str) -> None: + self.model_executor.sync_model_weights(actor_weights=actor_weights, load_format=load_format) + + def offload_model_weights(self) -> None: + self.model_executor.offload_model_weights() diff --git a/verl/third_party/vllm/vllm_v_0_6_3/megatron_weight_loaders.py b/verl/third_party/vllm/vllm_v_0_6_3/megatron_weight_loaders.py new file mode 100644 index 0000000000000000000000000000000000000000..68139b6b0db465bd24b3a9ffb5ebf8fd8d2f367d --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_3/megatron_weight_loaders.py @@ -0,0 +1,241 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/model_loader + +from typing import Dict, Iterable + +import torch +import torch.nn as nn +from vllm.model_executor.layers.linear import * +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead, VocabParallelEmbedding +from vllm.model_executor.models import ModelRegistry + + +# NOTE(shengguangming): replace the origin weight loader function in the class +def parallel_weight_loader(self, param: torch.Tensor, loaded_weight: torch.Tensor) -> None: + """Parallel Linear weight loader.""" + assert param.size() == loaded_weight.size(), "the parameter size is not align with the loaded weight size, param size: {}, loaded_weight size: {}".format(param.size(), loaded_weight.size()) + assert param.data.dtype == loaded_weight.data.dtype, "if we want to shared weights, the data type should also be the same" + + param.data = loaded_weight.data + + +def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: + """Default weight loader.""" + assert param.size() == loaded_weight.size() + assert param.data.dtype == loaded_weight.data.dtype, "if we want to shared weights, the data type should also be the same" + + param.data = loaded_weight.data + + +def gpt2_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + params_dict = dict(vllm_model.named_parameters(remove_duplicate=False)) + for name, loaded_weight in actor_weights.items(): + if "lm_head.weight" in name: + # GPT-2 ties the weights of the embedding layer and the final + # linear layer. + continue + if ".attn.bias" in name or ".attn.masked_bias" in name: + # Skip attention mask. + # NOTE: "c_attn.bias" should not be skipped. + continue + if not name.startswith("transformer."): + name = "transformer." + name + param = params_dict[name] + # The HF's GPT-2 implementation uses Conv1D instead of Linear. + # Because of this, we need to transpose the weights. + # Note(zhuohan): the logic below might break quantized models. + for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]: + if conv1d_weight_name not in name: + continue + if not name.endswith(".weight"): + continue + # TODO: check megatron + loaded_weight = loaded_weight.t() + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +def llama_megatron_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + # NOTE(shengguangming): the megatron llama may have this prefix + params_dict = dict(vllm_model.named_parameters()) + for name, loaded_weight in actor_weights.items(): + if "rotary_emb.inv_freq" in name: + continue + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +def qwen2_megatron_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + params_dict = dict(vllm_model.named_parameters()) + for name, loaded_weight in actor_weights.items(): + if "rotary_emb.inv_freq" in name: + continue + if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +def llama_megatron_core_te_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + params_mapping = [ + # (megatron core gpt model name, vllm model name) + ("embedding.word_embeddings", "model.embed_tokens"), + ("self_attention.linear_qkv.layer_norm_weight", "input_layernorm.weight"), + ("self_attention.linear_qkv.layer_norm_bias", "input_layernorm.bias"), + ("self_attention.linear_qkv", "self_attn.qkv_proj"), + ("self_attention.linear_qkv", "self_attn.qkv_proj"), + ("self_attention.linear_proj", "self_attn.o_proj"), + ("pre_mlp_layernorm", "post_attention_layernorm"), + ("mlp.linear_fc1.layer_norm_weight", "post_attention_layernorm.weight"), + ("mlp.linear_fc1.layer_norm_bias", "post_attention_layernorm.bias"), + ("mlp.linear_fc1", "mlp.gate_up_proj"), + ("mlp.linear_fc2", "mlp.down_proj"), + ("decoder.final_layernorm", "model.norm"), + ("output_layer", "lm_head"), + ] + # NOTE(shengguangming): the megatron llama may have this prefix + params_dict = dict(vllm_model.named_parameters()) + for name, loaded_weight in actor_weights.items(): + name = _replace_name(name, params_mapping) + if name.endswith(".bias") and name not in params_dict: + continue + if "rotary_emb.inv_freq" in name: + continue + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +def llama_megatron_core_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + params_mapping = [ + # (megatron core gpt model name, vllm model name) + ("embedding.word_embeddings", "model.embed_tokens"), + ("self_attention.linear_qkv", "self_attn.qkv_proj"), + ("self_attention.linear_proj", "self_attn.o_proj"), + ( + "input_layernorm", + "input_layernorm", + ), + ("pre_mlp_layernorm", "post_attention_layernorm"), + ("mlp.linear_fc1", "mlp.gate_up_proj"), + ("mlp.linear_fc2", "mlp.down_proj"), + ("decoder.final_layernorm", "model.norm"), + ("output_layer", "lm_head"), + ] + # NOTE(shengguangming): the megatron llama may have this prefix + params_dict = dict(vllm_model.named_parameters()) + for name, loaded_weight in actor_weights.items(): + name = _replace_name(name, params_mapping) + if name.endswith(".bias") and name not in params_dict: + continue + if "rotary_emb.inv_freq" in name: + continue + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +def _replace_name(megatron_name, name_mapping): + for m_name, v_name in name_mapping: + if m_name not in megatron_name: + continue + if "layers" in megatron_name: # deal with decoder layers + megatron_name = megatron_name.replace("decoder", "model") + megatron_name_list = megatron_name.split(".") + if "layer_norm_weight" in megatron_name_list or "layer_norm_bias" in megatron_name_list: + param_name_list = megatron_name_list[:3] + param_name_list.append(v_name) + param_name = ".".join(param_name_list) + else: + param_name_list = megatron_name_list[:3] + weight_or_bias = megatron_name_list[-1] + param_name_list.append(v_name) + param_name_list.append(weight_or_bias) + param_name = ".".join(param_name_list) + return param_name + else: + param_name = megatron_name.replace(m_name, v_name) + return param_name + + +def mistral_megatron_weight_loader(actor_weights: Iterable, vllm_model: nn.Module) -> nn.Module: + # TODO: need to implement a general way to deal with prefix + params_dict = dict(vllm_model.named_parameters()) + for name, weight in actor_weights: + if "rotary_emb.inv_freq" in name: + continue + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, weight) + + +def megatron_core_te_weight_loader(actor_weights: Iterable, vllm_model: nn.Module) -> nn.Module: + # NOTE(shengguangming): the megatron llama may have this prefix + params_dict = dict(vllm_model.named_parameters()) + for name, weight in actor_weights: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, weight) + + +__LAYER_WEIGHT_MEGATRON_LOADER_REGISTRY__ = { + ColumnParallelLinear: parallel_weight_loader, + MergedColumnParallelLinear: parallel_weight_loader, + QKVParallelLinear: parallel_weight_loader, + RowParallelLinear: parallel_weight_loader, + VocabParallelEmbedding: parallel_weight_loader, + ParallelLMHead: parallel_weight_loader, + # "ScaledActivation.weight_loader": ScaledActivation, # TODO(shengguangming): latest commit in vllm fix awq for this function and add load_weights + # "default_weight_loader": default_weight_loader +} + +# for layer_class, weight_loader in __LAYER_WEIGHT_MEGATRON_LOADER_REGISTRY__.items(): +# # setattr(layer_class, 'megatron_weight_loader', weight_loader) +# layer_class.weight_loader = weight_loader + +__MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__ = { + "GPT2LMHeadModel": gpt2_weight_loader, + "LlamaForCausalLM": megatron_core_te_weight_loader, # use te backend for open-source megatron + "LLaMAForCausalLM": megatron_core_te_weight_loader, + "MistralForCausalLM": mistral_megatron_weight_loader, + "Qwen2ForCausalLM": megatron_core_te_weight_loader, +} + + +# the actor model is .state_dict() +# Load megatron weights +def load_megatron_weights(actor_weights: Iterable, vllm_model: nn.Module): + weight_loader = _get_model_weight_loader(vllm_model.__class__.__name__) + weight_loader(actor_weights, vllm_model) + # NOTE(sgm) to reduce peak memory usage, we offload vllm model to cpu + # after init, and we need this after sync model weights for in first iter. + vllm_model = vllm_model.cuda() + + +def _get_model_weight_loader(arch: str): + if arch in __MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__: + return __MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__[arch] + raise ValueError(f"Model architectures {arch} are not supported for now. Supported architectures: {ModelRegistry.get_supported_archs()}") + + +def update_megatron_weight_loader(): + for layer_class, weight_loader in __LAYER_WEIGHT_MEGATRON_LOADER_REGISTRY__.items(): + layer_class.weight_loader = weight_loader diff --git a/verl/third_party/vllm/vllm_v_0_6_3/model_loader.py b/verl/third_party/vllm/vllm_v_0_6_3/model_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..211f715207a889ae5314230c97335dee0139b04f --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_3/model_loader.py @@ -0,0 +1,328 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models +"""Utilities for selecting and loading models.""" + +from typing import Dict, Optional, Union + +import torch +import torch.nn as nn +from transformers import PreTrainedModel +from vllm.config import CacheConfig, DeviceConfig, LoRAConfig, ParallelConfig, SchedulerConfig +from vllm.distributed.communication_op import tensor_model_parallel_all_gather +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.model_loader import BaseModelLoader +from vllm.model_executor.model_loader.loader import _initialize_model +from vllm.model_executor.model_loader.utils import set_default_torch_dtype + +from .config import LoadConfig, LoadFormat, ModelConfig +from .dtensor_weight_loaders import load_dtensor_weights, update_dtensor_weight_loader +from .hf_weight_loader import update_hf_weight_loader +from .megatron_weight_loaders import load_megatron_weights, update_megatron_weight_loader + + +def get_model( + actor_model: Union[PreTrainedModel, Dict], + model_config: ModelConfig, + load_config: LoadConfig, + device_config: DeviceConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + lora_config: Optional[LoRAConfig], + cache_config: CacheConfig = None, +) -> nn.Module: + loader = get_model_loader(load_config) + if load_config.load_format.startswith("dummy"): + return loader.load_model( + model_config=model_config, + device_config=device_config, + lora_config=lora_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + cache_config=cache_config, + ) + else: + return loader.load_model( + actor_model=actor_model, + model_config=model_config, + device_config=device_config, + lora_config=lora_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + cache_config=cache_config, + ) + + +def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: + """Get a model loader based on the load format.""" + + if isinstance(load_config.load_format, type): + return load_config.load_format(load_config) + + if load_config.load_format == LoadFormat.AUTO: + update_megatron_weight_loader() + return MegatronLoader(load_config) + + # NOTE(sgm): change the weight_loader function in runtime + if load_config.load_format == LoadFormat.MEGATRON: + update_megatron_weight_loader() + return MegatronLoader(load_config) + + if load_config.load_format == LoadFormat.HF: + update_hf_weight_loader() + return HFLoader(load_config) + + if load_config.load_format == LoadFormat.DTENSOR: + update_dtensor_weight_loader() + return DTensorLoader(load_config) + + if load_config.load_format == LoadFormat.DUMMY_HF: + update_hf_weight_loader() + return DummyModelLoader(load_config) + + if load_config.load_format == LoadFormat.DUMMY_MEGATRON: + update_megatron_weight_loader() + return DummyModelLoader(load_config) + + if load_config.load_format == LoadFormat.DUMMY_DTENSOR: + update_dtensor_weight_loader() + return DummyModelLoader(load_config) + + raise ValueError("load format not supported in verl: {}, only support {} and {}".format(load_config.load_format, LoadFormat.MEGATRON, LoadFormat.HF)) + + +class DummyModelLoader(BaseModelLoader): + """Model loader that will set model weights to random values.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError(f"Model loader extra config is not supported for load format {load_config.load_format}") + + def download_model(self, model_config: ModelConfig) -> None: + pass + + def load_model( + self, + *, + model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + ) -> nn.Module: + with set_default_torch_dtype(model_config.dtype), torch.device(device_config.device): + model = _initialize_model(model_config, self.load_config, lora_config, cache_config, scheduler_config) + # NOTE(woosuk): For accurate performance evaluation, we assign + # random values to the weights. + # initialize_dummy_weights(model) + return model.eval() + + +class MegatronLoader(BaseModelLoader): + """Model loader that can load the model weights from partitioned megatron model.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError(f"Model loader extra config is not supported for load format {load_config.load_format}") + + def download_model(self, model_config: ModelConfig) -> None: + pass # Nothing to download + + def _get_weights_iterator(actor_model: Union[PreTrainedModel, Dict]): + # NOTE(shengguangming) Load the weights from the actor model + pass + # if isinstance(actor_model, nn.Module): + # load_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), vllm_model=model) + # else: + # load_weights(actor_weights=actor_model, vllm_model=model) + # return actor_model + + def load_model( + self, + actor_model: Union[PreTrainedModel, Dict], + model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + ) -> nn.Module: + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model(model_config, self.load_config, lora_config, cache_config, scheduler_config) + + # TODO(sgm): This is a hack, we need to register the load_weight() func for each model in vllm + if isinstance(actor_model, nn.Module): + load_megatron_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), vllm_model=model) + else: + load_megatron_weights(actor_weights=actor_model, vllm_model=model) + + for _, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + quant_method.process_weights_after_loading(module) + # FIXME: Remove this after Mixtral is updated + # to use quant_method. + if hasattr(module, "process_weights_after_loading"): + module.process_weights_after_loading() + # NOTE(sgm) Some weights are point to gpu, but still need this. + model = model.cuda() # NOTE (zhangchi.usc1992) We need this for vllm to profile memory usage + return model.eval() + + +class HFLoader(BaseModelLoader): + """Model loader that can load the model weights from model's full params.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError(f"Model loader extra config is not supported for load format {load_config.load_format}") + + def download_model(self, model_config: ModelConfig) -> None: + pass # Nothing to download + + def _get_weights_iterator(self, actor_model: Union[PreTrainedModel, Dict]): + if isinstance(actor_model, Dict): + return actor_model.items() + elif isinstance(actor_model, nn.Module): + return dict(actor_model.named_parameters()).items() + else: + raise ValueError(f"actor model should be Dict or nn.Module, but get {type(actor_model)}") + + def load_model( + self, + actor_model: Union[PreTrainedModel, Dict], + model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + ) -> nn.Module: + with set_default_torch_dtype(model_config.dtype): + # with torch.device(device_config.device): + # NOTE(sgm): init the model in cpu + model = _initialize_model(model_config, self.load_config, lora_config, cache_config, scheduler_config) + model.load_weights(self._get_weights_iterator(actor_model)) + for _, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + quant_method.process_weights_after_loading(module) + # FIXME: Remove this after Mixtral is updated + # to use quant_method. + if hasattr(module, "process_weights_after_loading"): + module.process_weights_after_loading() + # NOTE(sgm) Some weights are point to gpu, but still need this. + model = model.cuda() # NOTE (zhangchi.usc1992) We need this for vllm to profile memory usage + return model.eval() + + +class DTensorLoader(BaseModelLoader): + """Model loader that can load the model weights from partitioned megatron model.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError(f"Model loader extra config is not supported for load format {load_config.load_format}") + + def download_model(self, model_config: ModelConfig) -> None: + pass # Nothing to download + + def _get_weights_iterator(actor_model: Union[PreTrainedModel, Dict]): + # NOTE(shengguangming) Load the weights from the actor model + pass + # if isinstance(actor_model, nn.Module): + # load_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), vllm_model=model) + # else: + # load_weights(actor_weights=actor_model, vllm_model=model) + # return actor_model + + def load_model( + self, + actor_model: Union[PreTrainedModel, Dict], + model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + ) -> nn.Module: + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model(model_config, self.load_config, lora_config, cache_config, scheduler_config) + + # TODO(sgm): This is a hack, we need to register the load_weight() func for each model in vllm + if isinstance(actor_model, nn.Module): + load_dtensor_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), vllm_model=model) + else: + load_dtensor_weights(actor_weights=actor_model, vllm_model=model) + + for _, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + quant_method.process_weights_after_loading(module) + # FIXME: Remove this after Mixtral is updated + # to use quant_method. + if hasattr(module, "process_weights_after_loading"): + module.process_weights_after_loading() + # NOTE(sgm) Some weights are point to gpu, but still need this. + model = model.cuda() # NOTE (zhangchi.usc1992) We need this for vllm to profile memory usage + return model.eval() + + +# FIXME(sgm): hack the _get_logits function in vllm v0.4.2 +# as they use ray, the _get_logits result will only need to return to the driver node, +# therefore gather is enough. However, we use SPMD instead of a central scheduler, +# all_gather is required (aligned with v0.2.6) +def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor, embedding_bias: Optional[torch.Tensor]) -> torch.Tensor: + # Get the logits for the next tokens. + logits = torch.matmul(hidden_states, embedding.t()) + if embedding_bias is not None: + logits += embedding_bias + logits = tensor_model_parallel_all_gather(logits) + # Remove paddings in vocab (if any). + if logits is not None: + logits = logits[:, : self.org_vocab_size] + return logits + + +def logitsprocessor_init( + self, + vocab_size: int, + org_vocab_size: Optional[int] = None, + scale: float = 1.0, + logits_as_input: bool = False, + soft_cap: Optional[float] = None, +) -> None: + """ + Args: + scale: A scaling factor to apply to the logits. + """ + super(LogitsProcessor, self).__init__() + self.scale = scale + self.vocab_size = vocab_size + # Whether the input is logits (default is hidden states). + self.logits_as_input = logits_as_input + # original vocabulary size (without LoRA). + self.org_vocab_size = org_vocab_size or vocab_size + # Soft cap the logits. Used in Gemma 2. + self.soft_cap = soft_cap + # Whether to use gather or all-gather to gather the logits. + self.use_gather = False + + +LogitsProcessor.__init__ = logitsprocessor_init # use all_gather diff --git a/verl/third_party/vllm/vllm_v_0_6_3/model_runner.py b/verl/third_party/vllm/vllm_v_0_6_3/model_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..028a4aff7515287f7ab7624d30af93a7ff70bf26 --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_3/model_runner.py @@ -0,0 +1,174 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/worker/model_runner.py + +import warnings +from enum import IntEnum +from typing import Dict, Optional, Union + +import torch +import torch.nn as nn +import vllm.envs as envs +from vllm.compilation.levels import CompilationLevel +from vllm.config import ( + CacheConfig, + DeviceConfig, + LoRAConfig, + ObservabilityConfig, + ParallelConfig, + PromptAdapterConfig, + SchedulerConfig, +) +from vllm.inputs import INPUT_REGISTRY, InputRegistry +from vllm.logger import init_logger +from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager +from vllm.model_executor.models.interfaces import supports_lora +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.prompt_adapter.worker_manager import LRUCacheWorkerPromptAdapterManager +from vllm.utils import DeviceMemoryProfiler, is_hip, supports_dynamo +from vllm.worker.model_runner import ModelRunner + +from .config import LoadConfig, ModelConfig +from .model_loader import get_model + +logger = init_logger(__name__) + + +# How batches are constructed. +class BatchType(IntEnum): + # Every batch is prefill. + PREFILL = 0 + # Every batch is decode. + DECODE = 1 + # Batch is a mixture of prefill and decode. + MIXED = 2 + + +class ModelRunner(ModelRunner): + def __init__( + self, + model: Union[nn.Module, Dict], # [verl] model itself or its parameter dict + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + prompt_adapter_config: Optional[PromptAdapterConfig] = None, + return_hidden_states: bool = False, + observability_config: Optional[ObservabilityConfig] = None, + input_registry: InputRegistry = INPUT_REGISTRY, + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, + ): + super().__init__( + model_config, + parallel_config, + scheduler_config, + device_config, + cache_config, + load_config, + lora_config, + kv_cache_dtype, + is_driver_worker=True, # a hack + prompt_adapter_config=prompt_adapter_config, + return_hidden_states=return_hidden_states, + observability_config=observability_config, + input_registry=input_registry, + mm_registry=mm_registry, + ) + + # NOTE(sgm): add for verl + self.model = model # this will be replaced by get_model() + + def load_model(self) -> None: + logger.info("Starting to load model %s...", self.model_config.model) + with DeviceMemoryProfiler() as m: + self.model = get_model( + self.model, + model_config=self.model_config, + device_config=self.device_config, + load_config=self.load_config, + lora_config=self.lora_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + cache_config=self.cache_config, + ) + + self.model_memory_usage = m.consumed_memory + logger.info("Loading model weights took %.4f GB", self.model_memory_usage / float(2**30)) + + if self.lora_config: + assert supports_lora(self.model), f"{self.model.__class__.__name__} does not support LoRA yet." + + # if supports_multimodal(self.model): + # logger.warning( + # "Regarding multimodal models, vLLM currently only supports adding LoRA to language model." + # ) + # It's necessary to distinguish between the max_position_embeddings + # of VLMs and LLMs. + if hasattr(self.model.config, "max_position_embeddings"): + max_pos_embeddings = self.model.config.max_position_embeddings + else: + max_pos_embeddings = self.model.config.text_config.max_position_embeddings + + self.lora_manager = LRUCacheWorkerLoRAManager( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens, + self.vocab_size, + self.lora_config, + self.device, + self.model.embedding_modules, + self.model.embedding_padding_modules, + max_position_embeddings=max_pos_embeddings, + ) + self.model = self.lora_manager.create_lora_manager(self.model) + + if self.prompt_adapter_config: + self.prompt_adapter_manager = LRUCacheWorkerPromptAdapterManager( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens, + self.device, + self.prompt_adapter_config, + ) + self.model = self.prompt_adapter_manager.create_prompt_adapter_manager(self.model) + + if self.kv_cache_dtype == "fp8" and is_hip(): + # Currently only ROCm accepts kv-cache scaling factors + # via quantization_param_path and this will be deprecated + # in the future. + if self.model_config.quantization_param_path is not None: + if callable(getattr(self.model, "load_kv_cache_scales", None)): + warnings.warn( + "Loading kv cache scaling factor from JSON is deprecated and will be removed. Please include kv cache scaling factors in the model checkpoint.", + FutureWarning, + stacklevel=2, + ) + self.model.load_kv_cache_scales(self.model_config.quantization_param_path) + logger.info("Loaded KV cache scaling factors from %s", self.model_config.quantization_param_path) + else: + raise RuntimeError( + "Using FP8 KV cache and scaling factors provided but model %s does not support loading scaling factors.", + self.model.__class__, + ) + else: + logger.warning("Using FP8 KV cache but no scaling factors provided. Defaulting to scaling factors of 1.0. This may lead to less accurate results!") + + if envs.VLLM_TORCH_COMPILE_LEVEL == CompilationLevel.DYNAMO_AS_IS and supports_dynamo(): + from vllm.plugins import get_torch_compile_backend + + backend = get_torch_compile_backend() or "eager" + self.model = torch.compile(self.model, fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, backend=backend) diff --git a/verl/third_party/vllm/vllm_v_0_6_3/parallel_state.py b/verl/third_party/vllm/vllm_v_0_6_3/parallel_state.py new file mode 100644 index 0000000000000000000000000000000000000000..f07aee0fdf39a205ff605f2a00ce86fe5d2956f4 --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_3/parallel_state.py @@ -0,0 +1,304 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Adapted from +# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +"""Model and data parallel groups.""" + +import os +from typing import Optional + +import torch +import torch.distributed +import vllm.distributed.parallel_state as ps +from vllm.distributed.parallel_state import ( + get_pp_group, + get_world_group, + init_distributed_environment, + init_model_parallel_group, +) +from vllm.logger import init_logger + +logger = init_logger(__name__) +""" +This version is strongly tied with Megatron to implement HybridEngine and weight sharing between vllm and Megatron. +- We assume the Megatron tp+dp+pp world is already established before calling this function. + +""" + +# Device mesh for using DTensor +_DEVICE_MESH = None + +# Tensor model parallel group that the current rank belongs to. +_TP = None +# Pipeline model parallel group that the current rank belongs to. +_PP = None + + +# This method is for initializing the ParallelGroup when using HybridEngine +def initialize_parallel_state( + distributed_init_method: str = "env://", + backend: str = "nccl", + tensor_model_parallel_size: int = 1, + num_tp_per_train_tp: int = 1, + pipeline_model_parallel_size: int = 1, +): + # torch.distributed.all_reduce does not free the input tensor until + # the synchronization point. This causes the memory usage to grow + # as the number of all_reduce calls increases. This env var disables + # this behavior. + # Related issue: + # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + + # NOTE(sgm): Modify for verl, Env vars will be set by TORCHRUN. + rank = int(os.getenv("RANK", "-1")) + local_rank = int(os.getenv("LOCAL_RANK", "0")) + + # Use the world_size set by TORCHRUN + world_size = int(os.getenv("WORLD_SIZE", "-1")) + assert world_size != -1, "The world_size is set to -1, not initialized by TORCHRUN" + init_distributed_environment(world_size, rank, distributed_init_method, local_rank, backend) + if torch.distributed.get_world_size() > 1: + # NOTE: build a sepearate inference group with infer tp & micro dp + initialize_model_parallel_for_vllm( + tensor_model_parallel_size=tensor_model_parallel_size, + num_tensor_model_parallel_groups_per_train_tp=num_tp_per_train_tp, + ) + else: + initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, backend) + + +def ensure_model_parallel_initialized( + tensor_model_parallel_size: int, + pipeline_model_parallel_size: int = 1, + backend: Optional[str] = None, +) -> None: + """Helper to initialize model parallel groups if they are not initialized, + or ensure tensor-parallel and pipeline-parallel sizes are equal to expected + values if the model parallel groups are initialized. + """ + # get the backend of _DEVICE_WORLD_GROUP + backend = backend or torch.distributed.get_backend(get_world_group().device_group) + if not model_parallel_is_initialized(): + initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, backend) + return + + assert get_tensor_model_parallel_world_size() == tensor_model_parallel_size, f"tensor parallel group already initialized, but of unexpected size: {get_tensor_model_parallel_world_size()=} vs. {tensor_model_parallel_size=}" + pp_world_size = get_pp_group().world_size + assert pp_world_size == pipeline_model_parallel_size, f"pipeline parallel group already initialized, but of unexpected size: {pp_world_size=} vs. {pipeline_model_parallel_size=}" + + +# TODO(sgm): deviate from the v0.5.4, not pp now +def model_parallel_is_initialized(): + """Check if tensor and pipeline parallel groups are initialized.""" + return ps._TP is not None + # and _PIPELINE_MODEL_PARALLEL_GROUP is not None) + + +def initialize_model_parallel_for_vllm( + tensor_model_parallel_size: int, + num_tensor_model_parallel_groups_per_train_tp: int = 1, + pipeline_model_parallel_size: int = 1, +) -> None: + pass + + # Get world size and rank. Ensure some consistencies. + assert torch.distributed.is_initialized() + + assert isinstance(tensor_model_parallel_size, int) + + # assert num_tensor_model_parallel_groups_per_train_tp == 1 and not different_tp_group + # assert num_tensor_model_parallel_groups_per_train_tp > 1 and different_tp_group + + # Build the tensor model-parallel groups. + assert ps._TP is None, "tensor model parallel group is already initialized" + + global _TP + + world_size: int = torch.distributed.get_world_size() + + backend = torch.distributed.get_backend() + + num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size + + if num_tensor_model_parallel_groups_per_train_tp == 1: + # if tensor_model_parallel_size == train_tensor_parallel_size: + # using the same tp group as Megatron/vllm + assert _TP is None, "tensor model parallel group is already initialized" + group_ranks = [] + for i in range(num_tensor_model_parallel_groups): + ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size) + group_ranks.append(ranks) + _TP = init_model_parallel_group( + group_ranks=group_ranks, + local_rank=get_world_group().local_rank, + backend=backend, + use_custom_allreduce=False, # TODO: check why True is not work in Ray trainer + use_message_queue_broadcaster=True, + ) + ps._TP = _TP + # _MICRO_DATA_PARALLEL_GROUP is move to hybrid engine + else: + # initialize a micro_dp group and a tp group + # assume training tp=4, infer tp=2, then, weight is partitioned as + # [1], [2], [3], [4] for training and [1,2], [1,2], [3,4], [3,4] for inference + + # Build the inference tp groups + # train_tp = train_tensor_parallel_size + train_tp = num_tensor_model_parallel_groups_per_train_tp * tensor_model_parallel_size + # num_tensor_model_parallel_groups_per_train_tp = train_tp // tensor_model_parallel_size + assert _TP is None, "tensor model parallel group is already initialized" + group_ranks = [] + for i in range(num_tensor_model_parallel_groups // num_tensor_model_parallel_groups_per_train_tp): + start = train_tp * i + end = train_tp * (i + 1) + for j in range(num_tensor_model_parallel_groups_per_train_tp): + ranks = list(range(start, end, num_tensor_model_parallel_groups_per_train_tp)) + for i in range(len(ranks)): + ranks[i] += j + group_ranks.append(ranks) + _TP = init_model_parallel_group( + group_ranks=group_ranks, + local_rank=get_world_group().local_rank, + backend=backend, + use_custom_allreduce=False, # TODO: check why True is not work in Ray trainer + use_message_queue_broadcaster=True, + ) + ps._TP = _TP + + # Build the pipeline model-parallel groups. + # global _PIPELINE_MODEL_PARALLEL_GROUP + # global _PIPELINE_GLOBAL_RANKS + # assert ps._PIPELINE_MODEL_PARALLEL_GROUP is None, ("pipeline model parallel group is already initialized") + + # ps._PIPELINE_MODEL_PARALLEL_GROUP = mpu.get_pipeline_model_parallel_group() + # ps._PIPELINE_GLOBAL_RANKS = mpu.get_pipeline_model_parallel_ranks() + + # TODO: init using device mesh (not support hybrid engine now) + # Build the pipeline model-parallel groups. + num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size + global _PP + assert _PP is None, "pipeline model parallel group is already initialized" + group_ranks = [] + for i in range(num_pipeline_model_parallel_groups): + ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) + group_ranks.append(ranks) + # pipeline parallel does not need custom allreduce + _PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, use_custom_allreduce=False) + ps._PP = _PP # for verl + + +def initialize_model_parallel( + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + backend: Optional[str] = None, +) -> None: + """ + NOTE: This method is a hack from the open-sourced version without + asertion of world_size = tp * pp + + Initialize model parallel groups. + + Arguments: + tensor_model_parallel_size: number of GPUs used for tensor model + parallelism. + pipeline_model_parallel_size: number of GPUs used for pipeline model + parallelism. + + Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we + use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize + the model pipeline. The present function will + create 4 tensor model-parallel groups and 2 pipeline model-parallel groups: + 4 tensor model-parallel groups: + [g0, g1], [g2, g3], [g4, g5], [g6, g7] + 2 pipeline model-parallel groups: + [g0, g2, g4, g6], [g1, g3, g5, g7] + Note that for efficiency, the caller should make sure adjacent ranks + are on the same DGX box. For example if we are using 2 DGX-1 boxes + with a total of 16 GPUs, rank 0 to 7 belong to the first box and + ranks 8 to 15 belong to the second box. + """ + # Get world size and rank. Ensure some consistencies. + assert torch.distributed.is_initialized() + world_size: int = torch.distributed.get_world_size() + backend = backend or torch.distributed.get_backend(ps.get_world_group().device_group) + + # NOTE(sgm) we don't assert world_size == tp * pp + # DP is not managed by vllm but by the VeRL WorkerGroup + # if (world_size != + # tensor_model_parallel_size * pipeline_model_parallel_size): + # raise RuntimeError( + # f"world_size ({world_size}) is not equal to " + # f"tensor_model_parallel_size ({tensor_model_parallel_size}) x " + # f"pipeline_model_parallel_size ({pipeline_model_parallel_size})") + + num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size + global _TP + assert _TP is None, "tensor model parallel group is already initialized" + group_ranks = [] + for i in range(num_tensor_model_parallel_groups): + ranks = list(range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)) + group_ranks.append(ranks) + + # message queue broadcaster is only used in tensor model parallel group + _TP = init_model_parallel_group( + group_ranks, + get_world_group().local_rank, + backend, + use_custom_allreduce=False, # TODO: check why True is not work in Ray trainer + use_message_queue_broadcaster=True, + ) + ps._TP = _TP + + # TODO: init using device mesh (not support hybrid engine now) + # Build the pipeline model-parallel groups. + num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size + global _PP + assert _PP is None, "pipeline model parallel group is already initialized" + group_ranks = [] + for i in range(num_pipeline_model_parallel_groups): + ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) + group_ranks.append(ranks) + # pipeline parallel does not need custom allreduce + _PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, use_custom_allreduce=False) + ps._PP = _PP # for verl + + +""" +Device mesh utilities +""" + + +def get_device_mesh(): + assert _DEVICE_MESH is not None, "device mesh is not initialized" + return _DEVICE_MESH + + +""" +Tensor model parallel utilities +""" + + +def get_tensor_model_parallel_group(): + """Get the tensor model parallel group the caller rank belongs to.""" + assert _TP is not None, "tensor model parallel group is not initialized" + return _TP.device_group + + +def get_tensor_model_parallel_world_size(): + """Return world size for the tensor model parallel group.""" + return torch.distributed.get_world_size(group=get_tensor_model_parallel_group()) + + +def get_tensor_model_parallel_rank(): + """Return my rank for the tensor model parallel group.""" + return torch.distributed.get_rank(group=get_tensor_model_parallel_group()) + + +def get_tensor_model_parallel_src_rank(): + """Calculate the global rank corresponding to the first local rank + in the tensor model parallel group.""" + global_rank = torch.distributed.get_rank() + local_world_size = get_tensor_model_parallel_world_size() + return (global_rank // local_world_size) * local_world_size diff --git a/verl/third_party/vllm/vllm_v_0_6_3/spmd_gpu_executor.py b/verl/third_party/vllm/vllm_v_0_6_3/spmd_gpu_executor.py new file mode 100644 index 0000000000000000000000000000000000000000..3abbd4705fa444cbeddaee6139c0b1dd02629ef6 --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_3/spmd_gpu_executor.py @@ -0,0 +1,250 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/executor/gpu_executor.py + +import os +import socket +from typing import Iterable, List, Optional, Set, Tuple + +import torch +from vllm.config import ( + CacheConfig, + DeviceConfig, + LoRAConfig, + ObservabilityConfig, + ParallelConfig, + PromptAdapterConfig, + SchedulerConfig, + SpeculativeConfig, +) +from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest + +from .config import LoadConfig, ModelConfig + +logger = init_logger(__name__) + + +class SPMDGPUExecutor(ExecutorBase): + """SPMD-based multi-GPU executor implementations.""" + + def __init__( + self, + model, # pytorch model itself or its parameter dict + model_config: ModelConfig, + cache_config: CacheConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + speculative_config: Optional[SpeculativeConfig], + prompt_adapter_config: Optional[PromptAdapterConfig], + observability_config: Optional[ObservabilityConfig], + ) -> None: + self.model_config = model_config + self.cache_config = cache_config + self.lora_config = lora_config + self.load_config = load_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.speculative_config = speculative_config + self.prompt_adapter_config = prompt_adapter_config + self.observability_config = observability_config + + distributed_init_method = initialize_cluster(parallel_config) + self._init_executor(model, distributed_init_method) + + # TODO(sgm): verl not support speculative decode now + def _init_executor(self, model, distributed_init_method) -> None: + assert not self.speculative_config, "Speculative decoding not yet supported for multi-GPU backend." + + # Create the parallel worker for each GPU. + self._init_workers_sp(model, distributed_init_method) + + def _init_workers_sp(self, model, distributed_init_method: str): + # Lazy import the Worker to avoid importing torch.cuda/xformers + # before CUDA_VISIBLE_DEVICES is set in the Worker + from .worker import Worker + + rank = int(os.getenv("RANK")) + local_rank = int(os.getenv("LOCAL_RANK")) + print(f"local rank {local_rank}") + + # see https://github.com/NVIDIA/nccl/issues/1234 + os.environ["NCCL_CUMEM_ENABLE"] = "0" + + self.worker = Worker( + model, + self.model_config, + self.parallel_config, + self.scheduler_config, + self.device_config, + self.cache_config, + self.load_config, + local_rank, + rank, + distributed_init_method, + lora_config=self.lora_config, + speculative_config=None, + prompt_adapter_config=self.speculative_config, + is_driver_worker=True, + model_runner_cls=None, # use the default one + ) + + # NOTE(shengguangming): torch.distributed.init_process_group will be called inside the init_model() + self.worker.init_device() + self.worker.load_model() + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of available KV blocks. + + This invokes `determine_num_available_blocks` on each worker and takes + the min of the results, guaranteeing that the selected cache sizes are + compatible with all workers. + + Returns: + - tuple[num_gpu_blocks, num_cpu_blocks] + """ + # Get the maximum number of blocks that can be allocated on GPU and CPU. + num_blocks = self.worker.determine_num_available_blocks() + + # NOTE(shengguangming): Now we don't use a shared centralized controler but each process will + # have its own scheduler + num_gpu_blocks = num_blocks[0] + num_cpu_blocks = num_blocks[1] + + return num_gpu_blocks, num_cpu_blocks + + def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: + """Initialize the KV cache in all workers.""" + + # NOTE: We log here to avoid multiple logs when number of workers is + # greater than one. We could log in the engine, but not all executors + # have GPUs. + logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks, num_cpu_blocks) + + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + + if torch.distributed.get_rank() == 0: + print(f"before init cache memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB") + self.worker.initialize_cache(num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks) + if torch.distributed.get_rank() == 0: + print(f"after init cache memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB") + + # NOTE(sgm): This will not profile & capture the model(CUDAGraph) when rebuilding KVCache + def init_cache_engine(self) -> None: + self.worker._init_cache_engine() + + def free_cache_engine(self) -> None: + self.worker.free_cache_engine() + + def execute_model(self, execute_model_req) -> List[SamplerOutput]: + all_outputs = self.worker.execute_model(execute_model_req=execute_model_req) + + # NOTE(sgm): + # Each GPU in vllm under verl has its own spmd_gpu_executor, therefore all GPUs should return the outputs + # In vllm with ray, only the driver worker returns the sampling results. + return all_outputs + + def add_lora(self, lora_request: LoRARequest) -> bool: + assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." + return self.worker.add_lora(lora_request=lora_request) + + def remove_lora(self, lora_id: int) -> bool: + assert lora_id > 0, "lora_id must be greater than 0." + return self.worker.remove_lora(lora_id=lora_id) + + def list_loras(self) -> Set[int]: + return self.worker.list_loras() + + def check_health(self) -> None: + # SPMDExecutor will always be healthy as long as + # it's running. + return + + # NOTE(sgm) add for verl to pass the abstract class test, not used + from vllm.prompt_adapter.request import PromptAdapterRequest + + def add_prompt_adapter(self, prompt_adapter_request: PromptAdapterRequest) -> bool: + assert prompt_adapter_request.prompt_adapter_id > 0, "prompt_adapter_id must be greater than 0." + return self.worker.add_prompt_adapter(prompt_adapter_request) + + def list_prompt_adapters(self) -> Set[int]: + return self.worker.list_prompt_adapters() + + def pin_lora(self, lora_id: int) -> bool: + assert lora_id > 0, "lora_id must be greater than 0." + return self.worker.pin_lora(lora_id) + + def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: + assert prompt_adapter_id > 0, "prompt_adapter_id must be greater than 0." + return self.worker.pin_prompt_adapter(prompt_adapter_id) + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + assert prompt_adapter_id > 0, "prompt_adapter_id must be greater than 0." + return self.worker.remove_prompt_adapter(prompt_adapter_id) + + # NOTE(sgm): add for verl + def offload_model_weights(self) -> None: + self.worker.offload_model_weights() + + def sync_model_weights(self, actor_weights: Iterable, load_format: str) -> None: + self.worker.sync_model_weights(actor_weights=actor_weights, load_format=load_format) + + +def initialize_cluster( + parallel_config: ParallelConfig, + engine_use_ray: bool = False, + ray_address: Optional[str] = None, +) -> Tuple[str, Optional[None]]: + """Initialize the distributed cluster probably with Ray. + + Args: + parallel_config: The configurations for parallel execution. + + Returns: + The `distributed_init_method` is the address for initializing the + distributed backend. + """ + + # Initialize cluster locally. + # We need to setup the distributed init method to make sure + # the distributed megatron code (e.g., get world size) works correctly. + # distributed_init_method = f"tcp://localhost:{port}" + distributed_init_method = "env://" + return distributed_init_method + + +def get_open_port(): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +# TODO(sgm): not implemented async executor yet +class SPMDGPUExecutorAsync(SPMDGPUExecutor, ExecutorAsyncBase): + async def execute_model_async(self, execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + """Executes one model step on the given sequences.""" + raise NotImplementedError + + async def check_health_async(self) -> None: + """Checks if the executor is healthy. If not, it should raise an + exception.""" + self.check_health() diff --git a/verl/third_party/vllm/vllm_v_0_6_3/tokenizer.py b/verl/third_party/vllm/vllm_v_0_6_3/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..ef7e115bea7197f4c9eaf999ec972f6c4a432077 --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_3/tokenizer.py @@ -0,0 +1,39 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/tokenizer_group/tokenizer_group.py + +from typing import Optional + +from transformers import PreTrainedTokenizer +from vllm.transformers_utils.tokenizer_group import TokenizerGroup +from vllm.utils import LRUCache + + +class TokenizerGroup(TokenizerGroup): + """A group of tokenizers that can be used for LoRA adapters.""" + + def __init__(self, tokenizer: PreTrainedTokenizer, enable_lora: bool, max_num_seqs: int, max_input_length: Optional[int]): + self.enable_lora = enable_lora + self.max_input_length = max_input_length + self.tokenizer = tokenizer + self.lora_tokenizers = LRUCache[PreTrainedTokenizer](capacity=max_num_seqs) if enable_lora else None + + # FIXME(sgm): for simplicity, we assign the special token here + @property + def pad_token_id(self): + return self.tokenizer.pad_token_id + + @property + def eos_token_id(self): + return self.tokenizer.eos_token_id diff --git a/verl/third_party/vllm/vllm_v_0_6_3/worker.py b/verl/third_party/vllm/vllm_v_0_6_3/worker.py new file mode 100644 index 0000000000000000000000000000000000000000..a251700a011862a516836345d25e511c0afc1b7a --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_3/worker.py @@ -0,0 +1,320 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/worker/worker.py +"""A GPU worker class.""" + +import gc +import os +from typing import Dict, Iterable, List, Optional, Tuple, Type, Union + +import torch +import torch.distributed +import torch.nn as nn +from vllm.config import ( + CacheConfig, + DeviceConfig, + LoRAConfig, + ParallelConfig, + PromptAdapterConfig, + SchedulerConfig, + SpeculativeConfig, +) + +# TODO(sgm): check why vllm has similar file in vllm.model_executor.parallel_utils.parallel_state +from vllm.distributed import get_tensor_model_parallel_group, init_distributed_environment, set_custom_all_reduce +from vllm.model_executor import set_random_seed +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest, IntermediateTensors +from vllm.worker.cache_engine import CacheEngine +from vllm.worker.embedding_model_runner import EmbeddingModelRunner +from vllm.worker.model_runner import GPUModelRunnerBase +from vllm.worker.model_runner_base import ModelRunnerInputBase +from vllm.worker.worker import Worker, _check_if_gpu_supports_dtype +from vllm.worker.worker_base import WorkerInput + +from .config import LoadConfig, LoadFormat, ModelConfig +from .dtensor_weight_loaders import load_dtensor_weights +from .hf_weight_loader import load_hf_weights +from .megatron_weight_loaders import load_megatron_weights +from .model_runner import ModelRunner +from .parallel_state import ensure_model_parallel_initialized + + +class Worker(Worker): + """A worker class that executes (a partition of) the model on a GPU. + + Each worker is associated with a single GPU. The worker is responsible for + maintaining the KV cache and executing the model on the GPU. In case of + distributed inference, each worker is assigned a partition of the model. + """ + + def __init__( + self, + model: Union[nn.Module, Dict], # model itself or its parameter dict + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + lora_config: Optional[LoRAConfig] = None, + speculative_config: Optional[SpeculativeConfig] = None, + prompt_adapter_config: Optional[PromptAdapterConfig] = None, + is_driver_worker: bool = False, + model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None, + ) -> None: + # self.model = model # will be replaced in the init_model + self.model_config = model_config + self.parallel_config = parallel_config + self.parallel_config.rank = rank + self.scheduler_config = scheduler_config + self.device_config = device_config + self.cache_config = cache_config + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + self.lora_config = lora_config + self.load_config = load_config + self.prompt_adapter_config = prompt_adapter_config + self.is_driver_worker = is_driver_worker # TODO: we don't need driver + # if parallel_config and is_driver_worker: + # assert rank % parallel_config.tensor_parallel_size == 0, \ + # "Driver worker should be rank 0 of tensor parallel group." + if self.model_config.trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + + init_cached_hf_modules() + + # Return hidden states from target model if the draft model is an + # mlp_speculator + speculative_args = {} if speculative_config is None or (speculative_config.draft_model_config.model == model_config.model) or (speculative_config.draft_model_config.hf_config.model_type not in ["medusa", "mlp_speculator"]) else {"return_hidden_states": True} + + # TODO(sgm): set correct model runner class + ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner + if model_runner_cls is not None: + ModelRunnerClass = model_runner_cls + elif self.model_config.embedding_mode: + ModelRunnerClass = EmbeddingModelRunner + self.model_runner: GPUModelRunnerBase = ModelRunnerClass( + model, # [VERL]: add for verl + model_config, + parallel_config, + scheduler_config, + device_config, + cache_config, + load_config=load_config, + lora_config=self.lora_config, + kv_cache_dtype=self.cache_config.cache_dtype, + is_driver_worker=is_driver_worker, + prompt_adapter_config=prompt_adapter_config, + **speculative_args, + ) + + # Uninitialized cache engine. Will be initialized by + # initialize_cache. + self.cache_engine: List[CacheEngine] = None + # Initialize gpu_cache as embedding models don't initialize kv_caches + self.gpu_cache: Optional[List[List[torch.Tensor]]] = None + + # NOTE(sgm): [VERL] For offloading inference engine params + self.cpu_model = None + + def init_device(self) -> None: + if self.device_config.device.type == "cuda": + # torch.distributed.all_reduce does not free the input tensor until + # the synchronization point. This causes the memory usage to grow + # as the number of all_reduce calls increases. This env var disables + # this behavior. + # Related issue: + # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + + # NOTE(sgm): Modify for verl, Env vars will be set by TORCHRUN. + self.rank = self.rank if self.rank is not None else int(os.getenv("RANK", "-1")) + local_rank = int(os.getenv("LOCAL_RANK", "0")) + self.device = torch.device(f"cuda:{local_rank}") + if self.rank < 0: + raise ValueError("Invalid or unspecified rank.") + torch.cuda.set_device(self.device) + + # Use the world_size set by TORCHRUN + world_size = int(os.getenv("WORLD_SIZE", "-1")) + assert world_size != -1, "The world_size is set to -1, not initialized by TORCHRUN" + self.parallel_config.world_size = world_size + + _check_if_gpu_supports_dtype(self.model_config.dtype) + torch.cuda.empty_cache() + self.init_gpu_memory = torch.cuda.mem_get_info()[0] + else: + raise RuntimeError(f"Not support device type: {self.device_config.device}") + + # Initialize the distributed environment. + init_worker_distributed_environment(self.parallel_config, self.rank, self.distributed_init_method, self.local_rank) + # Set random seed. + set_random_seed(self.model_config.seed) + # self.model = get_model(actor_model=self.model, model_config=self.model_config) + + @torch.inference_mode() + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Profiles the peak memory usage of the model to determine how many + KV blocks may be allocated without OOMs. + + The engine will first conduct a profiling of the existing memory usage. + Then, it calculate the maximum possible number of GPU and CPU blocks + that can be allocated with the remaining free memory. + + .. tip:: + You may limit the usage of GPU memory + by adjusting the `gpu_memory_utilization` parameter. + """ + # Profile the memory usage of the model and get the maximum number of + # cache blocks that can be allocated with the remaining free memory. + torch.cuda.empty_cache() + # torch.cuda.reset_peak_memory_stats() + + # Execute a forward pass with dummy inputs to profile the memory usage + # of the model. + self.model_runner.profile_run() + + # Calculate the number of blocks that can be allocated with the + # profiled peak memory. + torch.cuda.synchronize() + free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() + peak_memory = total_gpu_memory - free_gpu_memory + + assert peak_memory > 0, "Error in memory profiling. This happens when the GPU memory was not properly cleaned up before initializing the vLLM instance." + + cache_block_size = self.get_cache_block_size_bytes() + + # NOTE(sgm) [VERL] use the remaining memory + num_gpu_blocks = int((free_gpu_memory * self.cache_config.gpu_memory_utilization) // cache_block_size) + # num_gpu_blocks = int((total_gpu_memory * self.cache_config.gpu_memory_utilization - peak_memory) // cache_block_size) + + num_cpu_blocks = int(self.cache_config.swap_space_bytes // cache_block_size) + num_gpu_blocks = max(num_gpu_blocks, 0) + num_cpu_blocks = max(num_cpu_blocks, 0) + if self.model_runner.lora_manager: + self.model_runner.remove_all_loras() + + # NOTE(sgm): Add for [VERL], synchronize number of blocks with all the rank + num_gpu_blocks = torch.tensor([num_gpu_blocks], device="cuda") + num_cpu_blocks = torch.tensor([num_cpu_blocks], device="cuda") + + torch.distributed.all_reduce(num_gpu_blocks, op=torch.distributed.ReduceOp.MIN, group=get_tensor_model_parallel_group().device_group) + torch.distributed.all_reduce(num_cpu_blocks, op=torch.distributed.ReduceOp.MIN, group=get_tensor_model_parallel_group().device_group) + num_gpu_blocks = num_gpu_blocks.item() + num_cpu_blocks = num_cpu_blocks.item() + gc.collect() + torch.cuda.empty_cache() + return num_gpu_blocks, num_cpu_blocks + + def _init_cache_engine(self): + if self.cache_engine is None and self.gpu_cache is None: + super()._init_cache_engine() + + def free_cache_engine(self): + # ensure `enforce_eager=True` + self.cache_engine = None + self.gpu_cache = None + + # NOTE(sgm): [VERL]: adapt from _execute_model_spmd() + def execute_model(self, execute_model_req: ExecuteModelRequest, intermediate_tensors: Optional[IntermediateTensors] = None) -> Optional[List[SamplerOutput]]: + """ + Execute model in Single Program Multiple Data (SPMD) fashion. + All workers take the same request, prepare the input and + execute the model. + """ + assert execute_model_req is not None, "_execute_model_spmd() requires each worker to take in an ExecuteModelRequest" + worker_input: WorkerInput = self.prepare_worker_input(execute_model_req=execute_model_req) + model_input: ModelRunnerInputBase = self.model_runner.prepare_model_input(execute_model_req.seq_group_metadata_list) + + # verl.worker.workerbase.WorkerBase + # swap cache + super().execute_worker(worker_input) + + # If there is no input, we don't need to execute the model. + if worker_input.num_seq_groups == 0: + return [] + + return self.model_runner.execute_model( + model_input, + self.kv_cache[worker_input.virtual_engine] if self.kv_cache is not None else None, + intermediate_tensors, + ) + + # assume the input is .state_dict() + def sync_model_weights(self, actor_weights: Iterable, load_format: str): + if load_format in [LoadFormat.MEGATRON, LoadFormat.AUTO]: + load_megatron_weights(actor_weights, self.model_runner.model) + elif load_format == LoadFormat.HF: + # full model state iterable without no sharding + load_hf_weights(actor_weights, self.model_runner.model) + elif load_format == LoadFormat.DTENSOR: + load_dtensor_weights(actor_weights, self.model_runner.model) + + def offload_model_weights(self) -> None: + if self.cpu_model is None: + self.cpu_model = {} + for name, params in self.model_runner.model.named_parameters(): + self.cpu_model[name] = torch.empty_like(params, device="cpu") + params.data = self.cpu_model[name] + else: + for name, params in self.model_runner.model.named_parameters(): + params.data = self.cpu_model[name] + + +def init_worker_distributed_environment( + parallel_config: ParallelConfig, + rank: int, + distributed_init_method: Optional[str] = "env://", + local_rank: int = -1, +) -> None: + """Initialize the distributed environment.""" + set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) + + # NOTE(sgm) use tcp://localhost:xxxx will hang in HF setting without megatron + init_distributed_environment(parallel_config.world_size, rank, distributed_init_method, local_rank) + + ensure_model_parallel_initialized( + tensor_model_parallel_size=parallel_config.tensor_parallel_size, + pipeline_model_parallel_size=parallel_config.pipeline_parallel_size, + ) + + # TODO(sgm): check whether need this + # if pynccl_utils.is_initialized(): + # pynccl_world_size = pynccl_utils.get_world_size() + # if pynccl_world_size != parallel_config.world_size: + # raise RuntimeError( + # "pynccl is already initialized but the pynccl world " + # "size does not match parallel_config.world_size " + # f"({pynccl_world_size} vs. {parallel_config.world_size}).") + # elif parallel_config.world_size > 1: + # # NOTE(woosuk): We don't initialize pynccl process group when world size + # # is 1. + # # NOTE(kaichao): By default, pynccl is initialized for tp group. + # pynccl_utils.init_process_group( + # group=get_tensor_model_parallel_cpu_group()) + + # # Initialize a custom fast all-reduce implementation. + # if not parallel_config.disable_custom_all_reduce: + # init_custom_ar() + + # A small all_reduce for warmup. + torch.distributed.all_reduce(torch.zeros(1).cuda()) + # if pynccl_utils.is_initialized(): + # pynccl_utils.all_reduce(torch.zeros(1).cuda()) diff --git a/verl/tools/__init__.py b/verl/tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2736ffff7a85a4a89469c147a0c6fb007ddaf09f --- /dev/null +++ b/verl/tools/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/verl/tools/base_tool.py b/verl/tools/base_tool.py new file mode 100644 index 0000000000000000000000000000000000000000..254a426f133aa63bc4e62290a6a2f211b874f427 --- /dev/null +++ b/verl/tools/base_tool.py @@ -0,0 +1,86 @@ +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Optional, Tuple +from uuid import uuid4 + +from .schemas import OpenAIFunctionToolSchema + + +class BaseTool: + """Base class for tools. + + A tool should support the following methods: + + - `to_openai_function_tool_schema`: return the tool schema in OpenAI format. + - `create`: create a tool instance for a trajectory. + - `execute`: execute the tool. + - `calc_reward`: calculate the reward respect to tool state. + - `release`: release the tool instance. + """ + + def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): + self.config = config + self.name = tool_schema.function.name + self.tool_schema = tool_schema + + def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: + return self.tool_schema + + async def create(self, instance_id: Optional[str] = None, **kwargs) -> str: + """Create a tool instance. + + Args: + instance_id: The instance id of the tool. + + Returns: + The instance id of the tool. + """ + if instance_id is None: + return str(uuid4()) + else: + return instance_id + + async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]: + """Execute the tool. + + Args: + instance_id: The instance id of the tool. + parameters: The json string of the parameters of the tool. + + Returns: tool_response, tool_reward_score, tool_metrics + tool_response: The response str of the tool. + tool_reward_score: The step reward score of the tool. + tool_metrics: The metrics of the tool. + """ + return "Updated the tool state.", 0.0, {} + + async def calc_reward(self, instance_id: str, **kwargs) -> float: + """Calculate the reward of the tool. + + Args: + instance_id: The instance id of the tool. + + Returns: + The reward of the tool. + """ + return 0.0 + + async def release(self, instance_id: str, **kwargs) -> None: + """Release the tool instance. + + Args: + instance_id: The instance id of the tool. + """ + pass diff --git a/verl/tools/gsm8k_tool.py b/verl/tools/gsm8k_tool.py new file mode 100644 index 0000000000000000000000000000000000000000..aa43065a8711c3283ea8833e2e2455c929d967b0 --- /dev/null +++ b/verl/tools/gsm8k_tool.py @@ -0,0 +1,104 @@ +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +from typing import Any, Optional, Tuple +from uuid import uuid4 + +from verl.utils.reward_score import gsm8k + +from .base_tool import BaseTool +from .schemas import OpenAIFunctionToolSchema + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +class Gsm8kTool(BaseTool): + """A demo tool for calculating the reward of gsm8k. + + - `to_openai_function_tool_schema`: return the tool schema in OpenAI format. + - `create`: create a tool instance for a trajectory. + - `execute`: execute the tool. + - `calc_reward`: calculate the reward respect to tool state. + - `release`: release the tool instance. + """ + + def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): + """ + _tool_schema = OpenAIFunctionToolSchema.model_validate({ + "type": "function", + "function": { + "name": "calc_gsm8k_reward", + "description": "A tool for calculating the reward of gsm8k", + "parameters": { + "type": "object", + "properties": { + "answer": { + "type": "string", + "description": "The answer to the question", + }, + }, + "required": ["answer"], + }, + } + }) + """ + super().__init__(config, tool_schema) + self._instance_dict = {} + + def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: + return self.tool_schema + + async def create(self, instance_id: Optional[str] = None, ground_truth: Optional[str] = None, **kwargs) -> str: + if instance_id is None: + instance_id = str(uuid4()) + self._instance_dict[instance_id] = { + "response": "", + "ground_truth": ground_truth, + "reward": 0.0, + } + return instance_id + + async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]: + answer = parameters.get("answer", "") + if not isinstance(answer, str): + answer = str(answer) + + if answer.startswith("#### "): + self._instance_dict[instance_id]["response"] = answer + else: + self._instance_dict[instance_id]["response"] = "#### " + answer + + reward = await self.calc_reward(instance_id) + # penalty for non improved answer submission + tool_reward = 0.0 if reward > self._instance_dict[instance_id]["reward"] else -0.05 + # update the reward + self._instance_dict[instance_id]["reward"] = reward + + return f"Current parsed {answer=} {reward=}", tool_reward, {} + + async def calc_reward(self, instance_id: str, **kwargs) -> float: + return gsm8k.compute_score( + self._instance_dict[instance_id]["response"], + self._instance_dict[instance_id]["ground_truth"], + method="flexible", + format_score=0.0, + score=1.0, + ) + + async def release(self, instance_id: str, **kwargs) -> None: + del self._instance_dict[instance_id] diff --git a/verl/tools/schemas.py b/verl/tools/schemas.py new file mode 100644 index 0000000000000000000000000000000000000000..78124391291acc87c0946708e8e6140a6b5a39f6 --- /dev/null +++ b/verl/tools/schemas.py @@ -0,0 +1,87 @@ +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +from typing import Any, Literal + +from pydantic import BaseModel + + +class OpenAIFunctionPropertySchema(BaseModel): + """The schema of a parameter in OpenAI format.""" + + type: str + description: str | None = None + enum: list[str] | None = None + + +class OpenAIFunctionParametersSchema(BaseModel): + """The schema of parameters in OpenAI format.""" + + type: str + properties: dict[str, OpenAIFunctionPropertySchema] + required: list[str] + + +class OpenAIFunctionSchema(BaseModel): + """The schema of a function in OpenAI format.""" + + name: str + description: str + parameters: OpenAIFunctionParametersSchema + strict: bool = False + + +class OpenAIFunctionToolSchema(BaseModel): + """The schema of a tool in OpenAI format.""" + + type: str + function: OpenAIFunctionSchema + + +class OpenAIFunctionParsedSchema(BaseModel): + """The parsed schema of a tool in OpenAI format.""" + + name: str + arguments: str # JSON string + + +class OpenAIFunctionCallSchema(BaseModel): + """The parsed schema of a tool in OpenAI format.""" + + name: str + arguments: dict[str, Any] + + @staticmethod + def from_openai_function_parsed_schema(parsed_schema: OpenAIFunctionParsedSchema) -> tuple["OpenAIFunctionCallSchema", bool]: + has_decode_error = False + try: + arguments = json.loads(parsed_schema.arguments) + except json.JSONDecodeError: + arguments = {} + has_decode_error = True + # If the arguments is not a dict, it means the arguments is not a valid JSON string + if not isinstance(arguments, dict): + arguments = {} + has_decode_error = True + + return OpenAIFunctionCallSchema(name=parsed_schema.name, arguments=arguments), has_decode_error + + +class OpenAIFunctionToolCall(BaseModel): + """The tool call in OpenAI format.""" + + id: str + type: Literal["function"] = "function" + function: OpenAIFunctionCallSchema diff --git a/verl/trainer/__init__.py b/verl/trainer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c6aedde1437eddbe099b730519ef79980e2e3dc4 --- /dev/null +++ b/verl/trainer/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/verl/trainer/__pycache__/__init__.cpython-311.pyc b/verl/trainer/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ca1d9a7c1d9ff10c6673b3865cb13781f636c8f Binary files /dev/null and b/verl/trainer/__pycache__/__init__.cpython-311.pyc differ diff --git a/verl/trainer/config/evaluation.yaml b/verl/trainer/config/evaluation.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7ca735c4baddc2cc8bf1e97bb4ed08cc07316023 --- /dev/null +++ b/verl/trainer/config/evaluation.yaml @@ -0,0 +1,13 @@ +data: + path: /tmp/math_Qwen2-7B-Instruct.parquet + prompt_key: prompt + response_key: responses + data_source_key: data_source + reward_model_key: reward_model + +custom_reward_function: + path: null + name: compute_score + +ray_init: + num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then. \ No newline at end of file diff --git a/verl/trainer/config/generation.yaml b/verl/trainer/config/generation.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3f5a6931ec35e56b8c7ece97ae0b66c638fd2cfd --- /dev/null +++ b/verl/trainer/config/generation.yaml @@ -0,0 +1,50 @@ +trainer: + nnodes: 1 + n_gpus_per_node: 8 + +data: + path: ~/data/rlhf/math/test.parquet + prompt_key: prompt + n_samples: 5 + output_path: /opt/tiger/math_Qwen2-7B-Instruct.parquet + batch_size: 128 + +model: + path: ~/models/Qwen2-7B-Instruct + external_lib: null +rollout: + name: vllm + mode: sync # sync: LLM, async: AsyncLLM + temperature: 1.0 + top_k: 50 # 0 for hf rollout, -1 for vllm rollout + top_p: 0.7 + prompt_length: 1536 + response_length: 512 + # for vllm rollout + dtype: bfloat16 # should align with FSDP + gpu_memory_utilization: 0.5 + ignore_eos: False + enforce_eager: True + free_cache_engine: True + load_format: dummy_dtensor + tensor_model_parallel_size: 1 + max_num_batched_tokens: 8192 + max_model_len: null + max_num_seqs: 1024 + log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu + log_prob_micro_batch_size_per_gpu: 8 + # for fire vllm rollout + use_fire_sampling: False # enable FIRE https://arxiv.org/abs/2410.21236 + # for hf rollout + do_sample: True + disable_log_stats: True + enable_chunked_prefill: True + n: 1 +actor: + strategy: fsdp # This is for backward-compatibility + ulysses_sequence_parallel_size: 1 # sp size + fsdp_config: + fsdp_size: -1 + +ray_init: + num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then. diff --git a/verl/trainer/config/ppo_megatron_trainer.yaml b/verl/trainer/config/ppo_megatron_trainer.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0f7d45cd4462e11f961777622f7ec8adb1d66adc --- /dev/null +++ b/verl/trainer/config/ppo_megatron_trainer.yaml @@ -0,0 +1,276 @@ +data: + tokenizer: null + train_files: ~/data/rlhf/gsm8k/train.parquet + val_files: ~/data/rlhf/gsm8k/test.parquet + prompt_key: prompt + reward_fn_key: data_source + max_prompt_length: 512 + max_response_length: 512 + train_batch_size: 1024 + val_batch_size: null # DEPRECATED: Validation datasets are sent to inference engines as a whole batch, which will schedule the memory themselves + return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs + return_raw_chat: False + shuffle: True + filter_overlong_prompts: False # for large-scale dataset, filtering overlong prompts could be timeconsuming. You cat set the filter_overlong_prompts_workers to use multiprocessing to speed up. + filter_overlong_prompts_workers: 1 + truncation: error + custom_cls: + path: null + name: null + +actor_rollout_ref: + hybrid_engine: True + model: + path: ~/models/deepseek-llm-7b-chat + external_lib: null + override_config: {} + enable_gradient_checkpointing: False + gradient_checkpointing_kwargs: + ## Activation Checkpointing + activations_checkpoint_method: null # 'uniform', 'block'; not used with 'selective' + # 'uniform' divides the total number of transformer layers and checkpoints the input activation of each chunk + # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity + activations_checkpoint_granularity: null # 'selective' or 'full' + # 'full' will checkpoint the entire transformer layer and 'selective' only checkpoints memory intensive part of attention + activations_checkpoint_num_layers: null # not used with 'selective' + actor: + strategy: megatron # This is for backward-compatibility + ppo_mini_batch_size: 256 + ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu + ppo_micro_batch_size_per_gpu: null + use_dynamic_bsz: False + use_torch_compile: True # False to disable torch compile + # pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high) + clip_ratio: 0.2 # default value if clip_ratio_low and clip_ratio_high are not specified + clip_ratio_low: 0.2 + clip_ratio_high: 0.2 + clip_ratio_c: 3.0 # lower bound of the value for Dual-clip PPO from https://arxiv.org/pdf/1912.09729 + loss_agg_mode: "token-mean" # / "seq-mean-token-sum" / "seq-mean-token-mean" + # NOTE: "token-mean" is the default behavior + entropy_coeff: 0 + use_kl_loss: False # True for GRPO + kl_loss_coef: 0.001 # for grpo + kl_loss_type: low_var_kl # for grpo + ppo_epochs: 1 + data_loader_seed: null + shuffle: False + optim: + lr: 1e-6 + clip_grad: 1.0 + lr_warmup_steps: -1 # Prioritized. Negative values mean delegating to lr_warmup_steps_ratio. + lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime + min_lr_ratio: null # only useful for warmup with cosine + warmup_style: constant # select from constant/cosine + total_training_steps: -1 # must be override by program + weight_decay: 0.01 + megatron: + param_offload: False + grad_offload: False + optimizer_offload: False + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + virtual_pipeline_model_parallel_size: null # change VPP interface for parallelism tests + context_parallel_size: 1 + sequence_parallel: True + use_distributed_optimizer: True + use_dist_checkpointing: False + dist_checkpointing_path: null + seed: 1 + profile: # profile the actor model in `update_policy` + use_profile: False # open it when you want to profile the actor model + profile_ranks: null # list, you can specify the ranks to profile + step_start: -1 # start step in update_policy + step_end: -1 # end step + save_path: null # the path to save the profile result + load_weight: True + checkpoint: + contents: ['model', 'optimizer', 'extra'] # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space + ref: + strategy: megatron + use_torch_compile: ${actor_rollout_ref.actor.use_torch_compile} + megatron: + param_offload: False + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + virtual_pipeline_model_parallel_size: null # change VPP interface for parallelism tests + context_parallel_size: 1 + sequence_parallel: True + use_distributed_optimizer: False + use_dist_checkpointing: False + dist_checkpointing_path: null + seed: 1 + profile: + use_profile: False + profile_ranks: null + step_start: -1 + step_end: -1 + save_path: null + load_weight: True + log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu + log_prob_micro_batch_size_per_gpu: null + rollout: + name: vllm + mode: sync # sync: LLM, async: AsyncLLM + temperature: 1.0 + top_k: -1 # 0 for hf rollout, -1 for vllm rollout + top_p: 1 + prompt_length: ${data.max_prompt_length} # for xperf_gpt + response_length: ${data.max_response_length} + # for vllm rollout + dtype: bfloat16 # should align with FSDP + gpu_memory_utilization: 0.5 + ignore_eos: False + enforce_eager: True + free_cache_engine: True + load_format: dummy_megatron + tensor_model_parallel_size: 1 + max_num_batched_tokens: 8192 + max_model_len: null + max_num_seqs: 1024 + log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu + log_prob_micro_batch_size_per_gpu: null + disable_log_stats: True + enable_chunked_prefill: False # could get higher throughput + # for hf rollout + do_sample: True + layer_name_map: + qkv_layer_name: qkv + gate_proj_layer_name: gate_up + # number of responses (i.e. num sample times) + n: 1 + engine_kwargs: # inference engine parameters + swap_space: null # null means "use the engine default value" (usually 4 GB), setting it to, e.g., 32 means 32 GB + val_kwargs: + # sampling parameters for validation + top_k: -1 # 0 for hf rollout, -1 for vllm rollout + top_p: 1.0 + temperature: 0 + n: 1 + do_sample: False # default eager for validation + multi_turn: + enable: False # should set rollout.name to sglang_async if True + max_turns: null # null for no limit (default max_length // 3) + tool_config_path: null # null for no tool + format: chatml # chatml, more formats will be supported in the future + +critic: + rollout_n: ${actor_rollout_ref.rollout.n} + strategy: megatron + optim: + lr: 1e-5 + clip_grad: 1.0 + lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime + min_lr_ratio: null # only useful for warmup with cosine + warmup_style: constant # select from constant/cosine + total_training_steps: -1 # must be override by program + weight_decay: 0.01 + model: + path: ~/models/deepseek-llm-7b-chat + tokenizer_path: ${actor_rollout_ref.model.path} + override_config: {} + external_lib: ${actor_rollout_ref.model.external_lib} + enable_gradient_checkpointing: False + gradient_checkpointing_kwargs: + ## Activation Checkpointing + activations_checkpoint_method: null + activations_checkpoint_granularity: null + activations_checkpoint_num_layers: null + megatron: + param_offload: False + grad_offload: False + optimizer_offload: False + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + virtual_pipeline_model_parallel_size: null # change VPP interface for parallelism tests + context_parallel_size: 1 + sequence_parallel: True + use_distributed_optimizer: True + use_dist_checkpointing: False + dist_checkpointing_path: null + seed: 1 + load_weight: True + ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size} + ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu + ppo_micro_batch_size_per_gpu: null + use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs} + data_loader_seed: ${actor_rollout_ref.actor.data_loader_seed} + shuffle: ${actor_rollout_ref.actor.shuffle} + cliprange_value: 0.5 + kl_ctrl: + type: fixed + kl_coef: 0.001 + checkpoint: + contents: ['model', 'optimizer', 'extra'] # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space + +reward_model: + enable: False + strategy: megatron + megatron: + param_offload: False + grad_offload: False + optimizer_offload: False + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + virtual_pipeline_model_parallel_size: null # change VPP interface for parallelism tests + context_parallel_size: 1 + sequence_parallel: True + use_distributed_optimizer: False + use_dist_checkpointing: False + dist_checkpointing_path: null + seed: 1 + model: + input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical + path: ~/models/FsfairX-LLaMA3-RM-v0.1 + external_lib: ${actor_rollout_ref.model.external_lib} + load_weight: True + micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu + micro_batch_size_per_gpu: null + use_dynamic_bsz: ${critic.use_dynamic_bsz} + max_length: null + launch_reward_fn_async: False # custom reward function executed async on CPU, during log_prob + +custom_reward_function: + path: null + name: compute_score + +algorithm: + gamma: 1.0 + lam: 1.0 + adv_estimator: gae + norm_adv_by_std_in_grpo: True + use_kl_in_reward: False + kl_penalty: kl # how to estimate kl divergence + kl_ctrl: + type: fixed + kl_coef: 0.001 + horizon: 10000 + target_kl: 0.1 + +trainer: + balance_batch: True + total_epochs: 30 + total_training_steps: null + project_name: verl_examples + experiment_name: gsm8k + logger: ['console', 'wandb'] + log_val_generations: 0 + nnodes: 1 + n_gpus_per_node: 8 + save_freq: -1 + # auto: find the last ckpt to resume. If can't find, start from scratch + resume_mode: auto # or disable or resume_path if resume_from_path is set + resume_from_path: null + del_local_ckpt_after_load: False + val_before_train: True + test_freq: 2 + critic_warmup: 0 + default_hdfs_dir: null + default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} + max_actor_ckpt_to_keep: null + max_critic_ckpt_to_keep: null + # The timeout for ray worker group to wait for the register center to be ready + ray_wait_register_center_timeout: 300 + +ray_init: + num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then. diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml new file mode 100644 index 0000000000000000000000000000000000000000..367a386d8855a266e3d5304f8d600e2781b9ad6e --- /dev/null +++ b/verl/trainer/config/ppo_trainer.yaml @@ -0,0 +1,246 @@ +data: + tokenizer: null + train_files: ~/data/rlhf/gsm8k/train.parquet + val_files: ~/data/rlhf/gsm8k/test.parquet + prompt_key: prompt + reward_fn_key: data_source + max_prompt_length: 512 + max_response_length: 512 + train_batch_size: 1024 + val_batch_size: null + return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs + return_raw_chat: False + shuffle: True + filter_overlong_prompts: False # for large-scale dataset, filtering overlong prompts could be timeconsuming. You cat set the filter_overlong_prompts_workers to use multiprocessing to speed up. + filter_overlong_prompts_workers: 1 + truncation: error + image_key: images + video_key: videos + custom_cls: + path: null + name: null + +actor_rollout_ref: + hybrid_engine: True + model: + path: ~/models/deepseek-llm-7b-chat + external_lib: null + override_config: { } + enable_gradient_checkpointing: True + use_remove_padding: False + use_liger: False + trust_remote_code: False + actor: + strategy: fsdp # [fsdp, fsdp2], This is for backward-compatibility + ppo_mini_batch_size: 256 + ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu + ppo_micro_batch_size_per_gpu: null + use_dynamic_bsz: False + ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} + grad_clip: 1.0 + # pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high) + clip_ratio: 0.2 # default value if clip_ratio_low and clip_ratio_high are not specified + clip_ratio_low: 0.2 + clip_ratio_high: 0.2 + clip_ratio_c: 3.0 # lower bound of the value for Dual-clip PPO from https://arxiv.org/pdf/1912.09729 + loss_agg_mode: "token-mean" # / "seq-mean-token-sum" / "seq-mean-token-mean" + entropy_coeff: 0 + use_kl_loss: False # True for GRPO + use_torch_compile: True # False to disable torch compile + kl_loss_coef: 0.001 # for grpo + kl_loss_type: low_var_kl # for grpo + ppo_epochs: 1 + shuffle: False + ulysses_sequence_parallel_size: 1 # sp size + checkpoint: + contents: ['model', 'optimizer', 'extra'] # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space + optim: + lr: 1e-6 + lr_warmup_steps: -1 # Prioritized. Negative values mean delegating to lr_warmup_steps_ratio. + lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime + min_lr_ratio: null # only useful for warmup with cosine + warmup_style: constant # select from constant/cosine + total_training_steps: -1 # must be override by program + weight_decay: 0.01 + fsdp_config: + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + param_offload: False + optimizer_offload: False + offload_policy: False # only for fsdp2, offload param\grad\optimizer during train + reshard_after_forward: True # only for fsdp2, [True, False, int between 1 and fsdp_size] + fsdp_size: -1 + ref: + strategy: fsdp + fsdp_config: + param_offload: False + reshard_after_forward: True # only for fsdp2, [True, False, int between 1 and fsdp_size] + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + use_torch_compile: ${actor_rollout_ref.actor.use_torch_compile} + log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu + log_prob_micro_batch_size_per_gpu: null + log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} + ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size + rollout: + name: vllm + mode: sync # sync: LLM, async: AsyncLLM + chat_scheduler: null # async chat scheduler, e.g examples.ppo_trainer.naive_chat_scheduler.NaiveChatCompletionScheduler + temperature: 1.0 + top_k: -1 # 0 for hf rollout, -1 for vllm rollout + top_p: 1 + use_fire_sampling: False # https://arxiv.org/abs/2410.21236 + prompt_length: ${data.max_prompt_length} # not use for opensource + response_length: ${data.max_response_length} + # for vllm rollout + dtype: bfloat16 # should align with FSDP + gpu_memory_utilization: 0.5 + ignore_eos: False + enforce_eager: True + free_cache_engine: True + load_format: dummy_dtensor + tensor_model_parallel_size: 2 + max_num_batched_tokens: 8192 + max_model_len: null + max_num_seqs: 1024 + log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu + log_prob_micro_batch_size_per_gpu: null + log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} + disable_log_stats: True + enable_chunked_prefill: True # may get higher throughput when set to True. When activated, Please increase max_num_batched_tokens or decrease max_model_len. + # for hf rollout + do_sample: True + # number of responses (i.e. num sample times) + n: 1 # > 1 for grpo + engine_kwargs: # inference engine parameters + swap_space: null # null means "use the engine default value" (usually 4 GB), setting it to, e.g., 32 means 32 GB + val_kwargs: + # sampling parameters for validation + top_k: -1 # 0 for hf rollout, -1 for vllm rollout + top_p: 1.0 + temperature: 0 + n: 1 + do_sample: False # default eager for validation + multi_turn: + enable: False # should set rollout.name to sglang_async if True + max_turns: null # null for no limit (default max_length // 3) + tool_config_path: null # null for no tool + format: chatml # chatml, more formats will be supported in the future + +critic: + rollout_n: ${actor_rollout_ref.rollout.n} + strategy: fsdp # [fsdp, fsdp2] + optim: + lr: 1e-5 + lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime + min_lr_ratio: null # only useful for warmup with cosine + warmup_style: constant # select from constant/cosine + total_training_steps: -1 # must be override by program + weight_decay: 0.01 + model: + path: ~/models/deepseek-llm-7b-chat + tokenizer_path: ${actor_rollout_ref.model.path} + override_config: { } + external_lib: ${actor_rollout_ref.model.external_lib} + enable_gradient_checkpointing: True + use_remove_padding: False + trust_remote_code: ${actor_rollout_ref.model.trust_remote_code} + fsdp_config: + param_offload: False + optimizer_offload: False + offload_policy: False # only for fsdp2, offload param\grad\optimizer during train + reshard_after_forward: True # only for fsdp2, [True, False, int between 1 and fsdp_size] + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + fsdp_size: -1 + ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size} + ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu + ppo_micro_batch_size_per_gpu: null + forward_micro_batch_size: ${critic.ppo_micro_batch_size} + forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu} + use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + ppo_max_token_len_per_gpu: 32768 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2 + forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu} + ulysses_sequence_parallel_size: 1 # sp size + ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs} + shuffle: ${actor_rollout_ref.actor.shuffle} + grad_clip: 1.0 + cliprange_value: 0.5 + checkpoint: + contents: ['model', 'optimizer', 'extra'] # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space + +reward_model: + enable: False + strategy: fsdp + model: + input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical + path: ~/models/FsfairX-LLaMA3-RM-v0.1 + external_lib: ${actor_rollout_ref.model.external_lib} + use_remove_padding: False + trust_remote_code: False + fsdp_config: + wrap_policy: + min_num_params: 0 + param_offload: False + reshard_after_forward: True # only for fsdp2, [True, False, int between 1 and fsdp_size] + fsdp_size: -1 + micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu + micro_batch_size_per_gpu: null # set a number + max_length: null + ulysses_sequence_parallel_size: 1 # sp size + use_dynamic_bsz: ${critic.use_dynamic_bsz} + forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu} + reward_manager: naive + launch_reward_fn_async: False # custom reward function executed async on CPU, during log_prob + +custom_reward_function: + path: null + name: compute_score + +algorithm: + gamma: 1.0 + lam: 1.0 + adv_estimator: gae + norm_adv_by_std_in_grpo: True + use_kl_in_reward: False + kl_penalty: kl # how to estimate kl divergence + kl_ctrl: + type: fixed + kl_coef: 0.001 + horizon: 10000 + target_kl: 0.1 + +trainer: + balance_batch: True + total_epochs: 30 + total_training_steps: null + project_name: verl_examples + experiment_name: gsm8k + logger: [ 'console', 'wandb' ] + log_val_generations: 0 + rollout_data_dir: null # directory for logging the rollout data, no dump if null + validation_data_dir: null # directory for logging the validation data, no dump if null + nnodes: 1 + n_gpus_per_node: 8 + save_freq: -1 + # auto: find the last ckpt to resume. If can't find, start from scratch + resume_mode: auto # or disable or resume_path if resume_from_path is set + resume_from_path: null + val_before_train: True + test_freq: -1 + critic_warmup: 0 + default_hdfs_dir: null + del_local_ckpt_after_load: False + default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} + max_actor_ckpt_to_keep: null + max_critic_ckpt_to_keep: null + # The timeout for ray worker group to wait for the register center to be ready + ray_wait_register_center_timeout: 300 + +ray_init: + num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then. \ No newline at end of file diff --git a/verl/trainer/config/sft_trainer.yaml b/verl/trainer/config/sft_trainer.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7b7818d815c3cf15a6e0bfa6f67a0765c20f7fb1 --- /dev/null +++ b/verl/trainer/config/sft_trainer.yaml @@ -0,0 +1,56 @@ +data: + train_batch_size: 256 + micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu + micro_batch_size_per_gpu: 4 # this is also val batch size + train_files: ~/data/gsm8k/train.parquet + val_files: ~/data/gsm8k/test.parquet + # Single-turn settings + prompt_key: question + response_key: answer + prompt_dict_keys: ['question'] + response_dict_keys: ['answer'] + # Multi-turn settings + multiturn: + enable: false # Set to true to use multi-turn dataset + messages_key: messages # Key for messages list in multi-turn mode + max_length: 1024 + truncation: error + balance_dp_token: False + chat_template: null + custom_cls: + path: null + name: null +model: + partial_pretrain: ~/models/gemma-1.1-7b-it + fsdp_config: + wrap_policy: + min_num_params: 0 + cpu_offload: False + offload_params: False + external_lib: null + enable_gradient_checkpointing: False + trust_remote_code: False + lora_rank: 0 # Set to positive value to enable LoRA (e.g., 32) + lora_alpha: 16 # LoRA scaling factor + target_modules: all-linear # Target modules for LoRA adaptation + use_liger: False +optim: + lr: 1e-5 + betas: [0.9, 0.95] + weight_decay: 0.01 + warmup_steps_ratio: 0.1 + clip_grad: 1.0 + lr_scheduler: cosine +ulysses_sequence_parallel_size: 1 +use_remove_padding: False +trainer: + default_local_dir: /tmp/sft_model + default_hdfs_dir: hdfs://tmp/experiments/gsm8k/gemma-1.1-7b-it/ # change the hdfs path here + resume_path: null + project_name: gsm8k-sft + experiment_name: test + total_epochs: 4 + total_training_steps: null + logger: ['console'] + seed: 1 + diff --git a/verl/trainer/fsdp_sft_trainer.py b/verl/trainer/fsdp_sft_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..f04221e037a19e7434307e2120c731786734db29 --- /dev/null +++ b/verl/trainer/fsdp_sft_trainer.py @@ -0,0 +1,544 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +A lightweight one-file FSDP SFT Trainer +TODO(zhangchi.usc1992) +- Add calculation of mfu +- Add validation +""" + +import os + +os.environ["NCCL_DEBUG"] = "WARN" +os.environ["TOKENIZERS_PARALLELISM"] = "true" + +import logging +import re +from contextlib import nullcontext + +import hydra +import torch +import torch.distributed +from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input +from peft import LoraConfig, TaskType, get_peft_model +from tensordict import TensorDict +from torch import nn, optim +from torch.distributed.device_mesh import DeviceMesh, init_device_mesh +from torch.distributed.fsdp import CPUOffload, MixedPrecision, ShardingStrategy +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.utils.data import DataLoader, Dataset, DistributedSampler +from tqdm import tqdm +from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel + +import verl.utils.hdfs_io as hdfs_io +from verl.utils.dataset import SFTDataset +from verl.utils.dataset.multiturn_sft_dataset import MultiTurnSFTDataset +from verl.utils.debug import log_gpu_memory_usage +from verl.utils.distributed import initialize_global_process_group +from verl.utils.fs import copy_to_local +from verl.utils.fsdp_utils import get_fsdp_wrap_policy, get_init_weight_context_manager, init_fn +from verl.utils.torch_functional import get_cosine_schedule_with_warmup, get_wsd_schedule_with_warmup +from verl.utils.tracking import Tracking +from verl.utils.ulysses import ( + gather_outpus_and_unpad, + get_ulysses_sequence_parallel_world_size, + ulysses_pad_and_slice_inputs, +) +from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_SFT_LOGGING_LEVEL", "WARN")) + + +def extract_step(path): + match = re.search(r"global_step_(\d+)", path) + if match: + return int(match.group(1)) + return None + + +def convert_to_regular_types(obj): + """Convert Hydra configs and other special types to regular Python types.""" + from omegaconf import DictConfig, ListConfig + + if isinstance(obj, (ListConfig, DictConfig)): + return {k: convert_to_regular_types(v) for k, v in obj.items()} if isinstance(obj, DictConfig) else list(obj) + elif isinstance(obj, (list, tuple)): + return [convert_to_regular_types(x) for x in obj] + elif isinstance(obj, dict): + return {k: convert_to_regular_types(v) for k, v in obj.items()} + return obj + + +class FSDPSFTTrainer: + def __init__(self, config, device_mesh: DeviceMesh, ulysses_device_mesh: DeviceMesh, tokenizer, train_dataset: Dataset, val_dataset: Dataset): + self.config = config + self.device_mesh = device_mesh + self.ulysses_device_mesh = ulysses_device_mesh + self.sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) + self.tokenizer = tokenizer + if self.config.data.chat_template is not None: + raise ValueError("Apply Chat template from config is not supported yet.") + + # normalize dp size + self._normalize_config_bsz() + + # Set sequence parallel size + self.config.ulysses_sequence_parallel_size = getattr(self.config, "ulysses_sequence_parallel_size", 1) + self.use_remove_padding = getattr(self.config, "use_remove_padding", False) + if self.device_mesh.get_rank() == 0: + print(f"Using sequence parallel size: {self.config.ulysses_sequence_parallel_size}") + print(f"Using remove padding: {self.use_remove_padding}") + + self._build_dataloader(train_dataset, val_dataset) + # build model + self._build_model_optimizer() + + # TODO: add checkpoint manager + if self.device_mesh.get_rank() == 0: + print(self.config) + + def _normalize_config_bsz(self): + dp_size = self.device_mesh.size(0) if not self.ulysses_device_mesh else self.ulysses_device_mesh.size(0) + if self.device_mesh.get_rank() == 0: + print(f"Normalize batch size by dp {dp_size}") + + assert self.config.data.train_batch_size % dp_size == 0, f"Global batch size {self.config.data.train_batch_size} is not divisible by dp size {dp_size}" + + self.config.data.train_batch_size //= dp_size + + assert self.config.data.train_batch_size % self.config.data.micro_batch_size_per_gpu == 0 + + def _build_dataloader(self, train_dataset, val_dataset): + # build dataset + config = self.config + self.train_dataset, self.val_dataset = train_dataset, val_dataset + + # build dataloader + # Use data parallel rank and size instead of global rank and world size + + # If doing SP, we need to use the local rank and size + if self.config.ulysses_sequence_parallel_size > 1: + rank = self.ulysses_device_mesh.get_local_rank("dp") + world_size = self.ulysses_device_mesh.size(0) + if self.ulysses_device_mesh.get_rank() == 0: + print(f"Using SP rank {rank} and size {world_size} for data distribution") + print("Each SP rank gets different data, but the same data WITHIN the same rank") + else: + rank = self.device_mesh.get_rank() + world_size = self.device_mesh.size() + if self.device_mesh.get_rank() == 0: + print(f"Using FSDP rank {rank} and size {world_size} for data distribution") + + self.train_sampler = DistributedSampler(self.train_dataset, shuffle=True, num_replicas=world_size, rank=rank, drop_last=True) + self.train_dataloader = DataLoader( + dataset=self.train_dataset, + batch_size=config.data.train_batch_size, + sampler=self.train_sampler, + num_workers=8, + pin_memory=True, + drop_last=True, + ) + + self.val_sampler = DistributedSampler(self.val_dataset, shuffle=False, num_replicas=world_size, rank=rank, drop_last=True) + self.val_dataloader = DataLoader( + dataset=self.val_dataset, + batch_size=config.data.micro_batch_size_per_gpu, + sampler=self.val_sampler, + num_workers=8, + pin_memory=True, + drop_last=True, + ) + + def _build_model_optimizer(self): + # TODO (zhangchi.usc1992): + # 1. support pretrain from random weights + # 2. support init directly from sharded weights + local_model_path = copy_to_local(src=self.config.model.partial_pretrain, verbose=True) + + if self.config.model.get("external_lib", None) is not None: + # This is used to import external_lib into the huggingface systems + import importlib + + importlib.import_module(self.config.model.external_lib) + + log_gpu_memory_usage("Before model allocation", logger=logger) + + trust_remote_code = self.config.model.trust_remote_code + # load config first + config = AutoConfig.from_pretrained(local_model_path, trust_remote_code=trust_remote_code) + if self.config.ulysses_sequence_parallel_size > 1: + assert self.use_remove_padding, "Sequence parallel is only supported when remove_padding is enabled" + + # This may be very large + init_context = get_init_weight_context_manager(use_meta_tensor=not config.tie_word_embeddings, mesh=self.device_mesh) + + with init_context(): + self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained( + local_model_path, + config=config, + torch_dtype=torch.float32, + attn_implementation="flash_attention_2", + trust_remote_code=trust_remote_code, + ) + + if self.use_remove_padding or self.config.ulysses_sequence_parallel_size > 1: + from verl.models.transformers.monkey_patch import apply_monkey_patch + + apply_monkey_patch(model=self.model, ulysses_sp_size=self.config.ulysses_sequence_parallel_size) + + # Apply Liger kernel if use_liger is enabled + if self.config.model.get("use_liger", False): + from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance + + _apply_liger_kernel_to_instance(model=self.model) + + if self.config.model.get("lora_rank", 0) > 0: + self.model.enable_input_require_grads() + # Convert config to regular Python types before creating PEFT model + lora_config = { + "task_type": TaskType.CAUSAL_LM, + "r": self.config.model.lora_rank, + "lora_alpha": self.config.model.lora_alpha, + "target_modules": convert_to_regular_types(self.config.model.target_modules), + "bias": "none", + } + self.model = get_peft_model(self.model, LoraConfig(**lora_config)) + + if self.config.model.enable_gradient_checkpointing: + self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) + + log_gpu_memory_usage("After model allocation", logger=logger) + + mixed_precision = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32) + + auto_wrap_policy = get_fsdp_wrap_policy( + self.model, + config=self.config.model.fsdp_config.wrap_policy, + is_lora=self.config.model.get("lora_rank", 0) > 0, + ) + if self.device_mesh.get_rank() == 0: + print(auto_wrap_policy) + + if not self.config.model.fsdp_config.cpu_offload: + cpu_offload = None + else: + cpu_offload = CPUOffload(offload_params=self.config.model.fsdp_config.offload_params) + + self.fsdp_model = FSDP( + module=self.model, + auto_wrap_policy=auto_wrap_policy, + param_init_fn=init_fn, + sharding_strategy=ShardingStrategy.FULL_SHARD, + mixed_precision=mixed_precision, + device_mesh=self.device_mesh, + sync_module_states=True, + device_id=torch.cuda.current_device(), + cpu_offload=cpu_offload, + use_orig_params=False, + ) + + log_gpu_memory_usage("After FSDP wrapping", logger=logger) + + self.optimizer = optim.AdamW( + self.fsdp_model.parameters(), + lr=self.config.optim.lr, + betas=self.config.optim.betas, + weight_decay=self.config.optim.weight_decay, + ) + + log_gpu_memory_usage("After initialize optimizer", logger=logger) + + self.steps_per_epoch = len(self.train_dataloader) + self.total_steps = self.steps_per_epoch * self.config.trainer.total_epochs + + if self.device_mesh.get_rank() == 0: + print(f"Number of steps/epoch {self.steps_per_epoch}, number of epochs {self.config.trainer.total_epochs}, total number of steps {self.total_steps}") + + num_warmup_steps = int(self.total_steps * self.config.optim.warmup_steps_ratio) + + if not hasattr(self.config.optim, "lr_scheduler") or self.config.optim.lr_scheduler == "cosine": + self.lr_scheduler = get_cosine_schedule_with_warmup(optimizer=self.optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=self.total_steps) + elif self.config.optim.lr_scheduler == "wsd": + self.lr_scheduler = get_wsd_schedule_with_warmup(optimizer=self.optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=self.total_steps) + else: + raise ValueError(f"Unknown lr scheduler: {self.config.optim.lr_scheduler}") + + def _compute_loss_and_backward(self, batch, do_backward=True): + """Compute loss with optional sequence parallelism and remove padding features""" + use_sp = self.use_remove_padding and self.config.ulysses_sequence_parallel_size > 1 + + # Move inputs to GPU and prepare loss mask + input_ids = batch["input_ids"].cuda() + attention_mask = batch["attention_mask"].cuda() + position_ids = batch["position_ids"].cuda() + loss_mask = batch.pop("loss_mask")[:, :-1].reshape(-1).cuda() + loss_fct = nn.CrossEntropyLoss(reduction="none") + + # Context manager for sequence parallel if needed + context = self.sharding_manager if use_sp else nullcontext() + with context, torch.autocast(device_type="cuda", dtype=torch.bfloat16): + if not use_sp: + # Standard forward pass without sequence parallel + labels = input_ids[:, 1:].contiguous() + output = self.fsdp_model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False) + logits = output.logits + + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels.contiguous() + # Flatten the tokens + shift_logits = shift_logits.view(-1, self.model.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + loss = loss * loss_mask.to(loss.device) + else: + # IMPORTANT: We have a big assumption here, so we can shard the SAME sequence across SP ranks + # i.e., each GPU has <1 sequence, and each SP group has 1 sequence + # 1. All SP ranks will receive the *SAME* batch + # 2. Different SP groups will receive *DIFFERENT* batches + # This is implemented by the DistributedSampler + + batch_size, seqlen = input_ids.shape + # Remove padding + input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask) # input_ids_rmpad (total_nnz, ...) + input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) + + # Unpad position_ids to align rotary + position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices).transpose(0, 1) + + # Pad and slice inputs for sequence parallelism + input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs(input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size()) + # For computing loss + input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1) # (1, total_nnz) + input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(input_ids_rmpad_rolled, None, get_ulysses_sequence_parallel_world_size()) + input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0) # ((total_nnz / sp) + pad) + + # Forward pass + output = self.fsdp_model( + input_ids=input_ids_rmpad_sliced, + attention_mask=None, # Not needed with flash attention varlen + position_ids=position_ids_rmpad_padded, + use_cache=False, + ) + + # Compute loss locally then aggregate + logits_rmpad = output.logits.squeeze(0) + input_ids_rmpad_rolled = input_ids_rmpad_rolled.to(logits_rmpad.device) + loss = loss_fct(logits_rmpad, input_ids_rmpad_rolled) + # Gather and unpad for sequence parallelism + loss = gather_outpus_and_unpad(loss, gather_dim=0, unpad_dim=0, padding_size=pad_size) + + # This is the loss collected from all ulysses ranks + full_loss = pad_input(hidden_states=loss.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen) + full_loss = full_loss.squeeze(-1)[:, :-1] # Remove last token's loss + full_loss = full_loss.reshape(-1) + loss_mask = loss_mask.to(full_loss.device) + loss = full_loss * loss_mask + + valid_token_this_rank = torch.sum(loss_mask) + + if self.config.data.balance_dp_token: + torch.distributed.all_reduce(valid_token_this_rank) + dp_size = self.ulysses_device_mesh.size("dp") if use_sp else torch.distributed.get_world_size() + else: + dp_size = 1 + + loss = torch.sum(loss) / (valid_token_this_rank + 1e-8) * dp_size + + if do_backward: + loss.backward() + return loss + + def training_step(self, batch: TensorDict): + self.fsdp_model.train() + + log_gpu_memory_usage("Before optimizer zero_grad", logger=logger) + + self.optimizer.zero_grad() + + log_gpu_memory_usage("After optimizer zero_grad", logger=logger) + + micro_batches = batch.split(self.config.data.micro_batch_size_per_gpu) + n_micro_batches = len(micro_batches) + step_loss = 0 + for micro_batch in micro_batches: + loss = self._compute_loss_and_backward(batch=micro_batch) / n_micro_batches + step_loss += loss.item() + + grad_norm = self.fsdp_model.clip_grad_norm_(max_norm=self.config.optim.clip_grad) + + log_gpu_memory_usage("Before optimizer step", logger=logger) + + # if grad_norm is not finite, skip the update + if not torch.isfinite(grad_norm): + print(f"WARN: grad_norm is not finite: {grad_norm}") + self.optimizer.zero_grad() + else: + self.optimizer.step() + + log_gpu_memory_usage("After optimizer step", logger=logger) + + self.lr_scheduler.step() + + # reduce loss across dp ranks + lr = self.lr_scheduler.get_last_lr()[0] + + log_gpu_memory_usage("After offload weights", logger=logger) + + step_loss = torch.tensor(step_loss).cuda() + torch.distributed.all_reduce(step_loss, op=torch.distributed.ReduceOp.AVG) + return {"train/loss": step_loss.detach().item(), "train/lr(1e-3)": lr * 1e3} + + def validation_step(self, batch: TensorDict): + self.fsdp_model.eval() + with torch.no_grad(): + loss = self._compute_loss_and_backward(batch, do_backward=False) + torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG) + return loss + + def save_checkpoint(self, step): + # save checkpoint + from torch.distributed.fsdp import FullStateDictConfig, StateDictType + + cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + with FSDP.state_dict_type(self.fsdp_model, StateDictType.FULL_STATE_DICT, cfg): + state_dict = self.fsdp_model.state_dict() + + path = os.path.join(self.config.trainer.default_local_dir, f"global_step_{step}") + # save huggingface model + if self.device_mesh.get_rank() == 0: + os.makedirs(path, exist_ok=True) + self.model.save_pretrained(path, state_dict=state_dict) + self.tokenizer.save_pretrained(path) + if self.config.trainer.default_hdfs_dir: + hdfs_io.makedirs(self.config.trainer.default_hdfs_dir, exist_ok=True) + hdfs_io.copy(src=path, dst=self.config.trainer.default_hdfs_dir, dirs_exist_ok=True) + torch.distributed.barrier() + + def fit(self): + rank = self.device_mesh.get_rank() + + # TODO: add a unified tracking + if rank == 0: + tracking = Tracking( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + default_backend=self.config.trainer.logger, + ) + + global_step = 0 + # compute the total training steps. + # the total training steps in SFT is mainly for early exit + total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs + + if self.config.trainer.total_training_steps is not None: + total_training_steps = self.config.trainer.total_training_steps + + self.total_training_steps = total_training_steps + print(f"Total training steps: {self.total_training_steps}") + + # TODO (zhangchi.usc1992) add back checkpoint manager. + # Currently, it blocks when uploading to hdfs. So very slow. + + for epoch in range(self.config.trainer.total_epochs): + self.train_sampler.set_epoch(epoch=epoch) + for data in tqdm( + self.train_dataloader, + total=self.steps_per_epoch, + desc=f"Epoch {epoch + 1}/{self.config.trainer.total_epochs}", + ): + global_step += 1 + data = TensorDict(data, batch_size=self.config.data.train_batch_size).cuda() + metric = self.training_step(data) + if rank == 0: + tracking.log(data=metric, step=global_step) + + # for early exit validation + if global_step >= self.total_training_steps: + # Perform final validation + val_losses = [] + for val_data in self.val_dataloader: + val_data = TensorDict(val_data, batch_size=self.config.data.micro_batch_size_per_gpu).cuda() + val_loss = self.validation_step(val_data) + val_losses.append(val_loss) + if rank == 0: + avg_val_loss = torch.mean(torch.stack(val_losses)) + metric = {"val/loss": avg_val_loss.detach().item()} + tracking.log(data=metric, step=global_step) + torch.distributed.barrier() + + # Save final checkpoint + self.save_checkpoint(step=global_step) + return + + # validation + val_losses = [] + for data in self.val_dataloader: + data = TensorDict(data, batch_size=self.config.data.micro_batch_size_per_gpu).cuda() + val_loss = self.validation_step(data) + val_losses.append(val_loss) + if rank == 0: + val_loss = torch.mean(torch.stack(val_losses)) + metric = {"val/loss": val_loss.detach().item()} + tracking.log(data=metric, step=global_step) + torch.distributed.barrier() + + # save checkpoint + self.save_checkpoint(step=global_step) + + +@hydra.main(config_path="config", config_name="sft_trainer", version_base=None) +def main(config): + local_rank, rank, world_size = initialize_global_process_group() + + device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,), mesh_dim_names=("fsdp",)) + dp_size = world_size // config.ulysses_sequence_parallel_size + ulysses_device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(dp_size, config.ulysses_sequence_parallel_size), mesh_dim_names=("dp", "sp")) + # build tokenizer and datasets first + from verl.utils import hf_tokenizer + + local_model_path = copy_to_local(src=config.model.partial_pretrain, verbose=True) + tokenizer = hf_tokenizer(local_model_path, trust_remote_code=config.model.trust_remote_code) + train_dataset = create_sft_dataset(config.data.train_files, config.data, tokenizer) + val_dataset = create_sft_dataset(config.data.val_files, config.data, tokenizer) + + trainer = FSDPSFTTrainer(config=config, device_mesh=device_mesh, ulysses_device_mesh=ulysses_device_mesh, tokenizer=tokenizer, train_dataset=train_dataset, val_dataset=val_dataset) + + trainer.fit() + + +def create_sft_dataset(data_paths, data_config, tokenizer): + """Create a dataset.""" + # build dataset + # First check if a custom dataset class is specified + if data_config.custom_cls.get("path", None): + from verl.utils.import_utils import load_extern_type + + dataset_cls = load_extern_type(data_config.custom_cls.path, data_config.custom_cls.name) + # Then check if multi-turn dataset should be used + elif data_config.get("multiturn", {}).get("enable", False): + dataset_cls = MultiTurnSFTDataset + # Default to single-turn dataset + else: + dataset_cls = SFTDataset + + # Create datasets based on the selected class + dataset = dataset_cls(parquet_files=data_paths, tokenizer=tokenizer, config=data_config) + return dataset + + +if __name__ == "__main__": + main() diff --git a/verl/trainer/main_eval.py b/verl/trainer/main_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..0e6e57f9e8146feef91fe5896db485394efa9596 --- /dev/null +++ b/verl/trainer/main_eval.py @@ -0,0 +1,113 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Offline evaluate the performance of a generated file using reward model and ground truth verifier. +The input is a parquet file that contains N generated sequences and (optional) the ground truth. + +""" + +from collections import defaultdict + +import hydra +import numpy as np +import pandas as pd +import ray +from tqdm import tqdm + +from verl.utils.fs import copy_to_local + + +def get_custom_reward_fn(config): + import importlib.util + import os + import sys + + reward_fn_config = config.get("custom_reward_function") or {} + file_path = reward_fn_config.get("path") + if not file_path: + return None + + if not os.path.exists(file_path): + raise FileNotFoundError(f"Reward function file '{file_path}' not found.") + + spec = importlib.util.spec_from_file_location("custom_module", file_path) + module = importlib.util.module_from_spec(spec) + try: + sys.modules["custom_module"] = module + spec.loader.exec_module(module) + except Exception as e: + raise RuntimeError(f"Error loading module from '{file_path}'") from e + + function_name = reward_fn_config.get("name") + if not hasattr(module, function_name): + raise AttributeError(f"Reward function '{function_name}' not found in '{file_path}'.") + + print(f"using customized reward function '{function_name}' from '{file_path}'") + raw_fn = getattr(module, function_name) + + reward_kwargs = dict(reward_fn_config.get("reward_kwargs", {})) + + def wrapped_fn(*args, **kwargs): + return raw_fn(*args, **kwargs, **reward_kwargs) + + return wrapped_fn + + +@ray.remote +def process_item(reward_fn, data_source, response_lst, reward_data): + ground_truth = reward_data["ground_truth"] + score_lst = [reward_fn(data_source, r, ground_truth) for r in response_lst] + return data_source, np.mean(score_lst) + + +@hydra.main(config_path="config", config_name="evaluation", version_base=None) +def main(config): + local_path = copy_to_local(config.data.path) + dataset = pd.read_parquet(local_path) + responses = dataset[config.data.response_key] + data_sources = dataset[config.data.data_source_key] + reward_model_data = dataset[config.data.reward_model_key] + + total = len(dataset) + + # Initialize Ray + if not ray.is_initialized(): + ray.init(num_cpus=config.ray_init.num_cpus) + + # evaluate test_score based on data source + data_source_reward = defaultdict(list) + compute_score = get_custom_reward_fn(config) + + # Create remote tasks + remote_tasks = [process_item.remote(compute_score, data_sources[i], responses[i], reward_model_data[i]) for i in range(total)] + + # Process results as they come in + with tqdm(total=total) as pbar: + while len(remote_tasks) > 0: + # Use ray.wait to get completed tasks + done_ids, remote_tasks = ray.wait(remote_tasks) + for result_id in done_ids: + data_source, score = ray.get(result_id) + data_source_reward[data_source].append(score) + pbar.update(1) + + metric_dict = {} + for data_source, rewards in data_source_reward.items(): + metric_dict[f"test_score/{data_source}"] = np.mean(rewards) + + print(metric_dict) + + +if __name__ == "__main__": + main() diff --git a/verl/trainer/main_generation.py b/verl/trainer/main_generation.py new file mode 100644 index 0000000000000000000000000000000000000000..133397038d5456f927b35cd56db8b0d4a2860107 --- /dev/null +++ b/verl/trainer/main_generation.py @@ -0,0 +1,144 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Generate responses given a dataset of prompts +""" + +import os + +import hydra +import numpy as np +import ray + +os.environ["NCCL_DEBUG"] = "WARN" +os.environ["TOKENIZERS_PARALLELISM"] = "true" +# os.environ['TORCH_COMPILE_DISABLE'] = '1' + +from pprint import pprint + +import pandas as pd +from omegaconf import OmegaConf + +from verl import DataProto +from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto +from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup +from verl.utils import hf_tokenizer +from verl.utils.fs import copy_to_local +from verl.utils.hdfs_io import makedirs +from verl.utils.model import compute_position_id_with_mask +from verl.workers.fsdp_workers import ActorRolloutRefWorker + + +@hydra.main(config_path="config", config_name="generation", version_base=None) +def main(config): + run_generation(config) + + +def run_generation(config) -> None: + if not ray.is_initialized(): + # this is for local ray cluster + ray.init( + runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}}, + num_cpus=config.ray_init.num_cpus, + ) + + ray.get(main_task.remote(config)) + + +@ray.remote(num_cpus=1) +def main_task(config): + pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values + OmegaConf.resolve(config) + + local_path = copy_to_local(config.model.path) + trust_remote_code = config.data.get("trust_remote_code", False) + tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) + + if config.rollout.temperature == 0.0: + assert config.data.n_samples == 1, "When temperature=0, n_samples must be 1." + assert config.data.n_samples >= 1, "n_samples should always >= 1" + + # read dataset. Note that the dataset should directly contain chat template format (e.g., a list of dictionary) + dataset = pd.read_parquet(config.data.path) + chat_lst = dataset[config.data.prompt_key].tolist() + + chat_lst = [chat.tolist() for chat in chat_lst] + + tokenizer.padding_side = "left" + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorRolloutRefWorker), config=config, role="rollout") + resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes) + wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init) + wg.init_model() + + total_samples = len(dataset) + config_batch_size = config.data.batch_size + num_batch = -(-total_samples // config_batch_size) + output_lst = [[] for _ in range(config.data.n_samples)] + + for batch_idx in range(num_batch): + print(f"[{batch_idx + 1}/{num_batch}] Start to process.") + batch_chat_lst = chat_lst[batch_idx * config_batch_size : (batch_idx + 1) * config_batch_size] + inputs = tokenizer.apply_chat_template( + batch_chat_lst, + add_generation_prompt=True, + padding=True, + truncation=True, + max_length=config.rollout.prompt_length, + return_tensors="pt", + return_dict=True, + tokenize=True, + ) + input_ids = inputs["input_ids"] + attention_mask = inputs["attention_mask"] + position_ids = compute_position_id_with_mask(attention_mask) + batch_dict = {"input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids} + + data = DataProto.from_dict(batch_dict) + data_padded, pad_size = pad_dataproto_to_divisor(data, wg.world_size) + + # START TO GENERATE FOR n_samples TIMES + print(f"[{batch_idx + 1}/{num_batch}] Start to generate.") + for n_sample in range(config.data.n_samples): + output_padded = wg.generate_sequences(data_padded) + output = unpad_dataproto(output_padded, pad_size=pad_size) + + output_texts = [] + for i in range(len(output)): + data_item = output[i] + prompt_length = data_item.batch["prompts"].shape[-1] + valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum() + valid_response_ids = data_item.batch["responses"][:valid_response_length] + response_str = tokenizer.decode(valid_response_ids, skip_special_tokens=True) + output_texts.append(response_str) + + output_lst[n_sample].extend(output_texts) + + # convert output_lst from (n_samples, n_data) to (n_data, n_sampels) + output_lst = np.array(output_lst, dtype=object) + output_lst = np.transpose(output_lst, axes=(1, 0)).tolist() + + # add to the data frame + dataset["responses"] = output_lst + + # write to a new parquet + output_dir = os.path.dirname(config.data.output_path) + makedirs(output_dir, exist_ok=True) + dataset.to_parquet(config.data.output_path) + + +if __name__ == "__main__": + main() diff --git a/verl/trainer/main_ppo.py b/verl/trainer/main_ppo.py new file mode 100644 index 0000000000000000000000000000000000000000..5bd7ee32bfe9c754f305cecc1cb70b139396a85d --- /dev/null +++ b/verl/trainer/main_ppo.py @@ -0,0 +1,246 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Note that we don't combine the main with ray_trainer as ray_trainer is used by other main. +""" + +import os + +import hydra +import ray + +from verl.trainer.ppo.ray_trainer import RayPPOTrainer +from verl.trainer.ppo.reward import load_reward_manager + + +def get_custom_reward_fn(config): + import importlib.util + import sys + + reward_fn_config = config.get("custom_reward_function") or {} + file_path = reward_fn_config.get("path") + if not file_path: + return None + + if not os.path.exists(file_path): + raise FileNotFoundError(f"Reward function file '{file_path}' not found.") + + spec = importlib.util.spec_from_file_location("custom_module", file_path) + module = importlib.util.module_from_spec(spec) + try: + sys.modules["custom_module"] = module + spec.loader.exec_module(module) + except Exception as e: + raise RuntimeError(f"Error loading module from '{file_path}': {e}") from e + + function_name = reward_fn_config.get("name") + if not hasattr(module, function_name): + raise AttributeError(f"Reward function '{function_name}' not found in '{file_path}'.") + + print(f"using customized reward function '{function_name}' from '{file_path}'") + raw_fn = getattr(module, function_name) + + reward_kwargs = dict(reward_fn_config.get("reward_kwargs", {})) + + def wrapped_fn(*args, **kwargs): + return raw_fn(*args, **kwargs, **reward_kwargs) + + return wrapped_fn + + +@hydra.main(config_path="config", config_name="ppo_trainer", version_base=None) +def main(config): + run_ppo(config) + + +def run_ppo(config) -> None: + if not ray.is_initialized(): + # this is for local ray cluster + ray.init( + runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "WARN"}}, + num_cpus=config.ray_init.num_cpus, + ) + + runner = TaskRunner.remote() + ray.get(runner.run.remote(config)) + + +@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head +class TaskRunner: + def run(self, config): + # print initial config + from pprint import pprint + + from omegaconf import OmegaConf + + from verl.utils.fs import copy_to_local + + pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values + OmegaConf.resolve(config) + + # download the checkpoint from hdfs + local_path = copy_to_local(config.actor_rollout_ref.model.path) + + # instantiate tokenizer + from verl.utils import hf_processor, hf_tokenizer + + trust_remote_code = config.data.get("trust_remote_code", False) + tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) + processor = hf_processor(local_path, use_fast=True) # used for multimodal LLM, could be none + + # define worker classes + if config.actor_rollout_ref.actor.strategy in ["fsdp", "fsdp2"]: + assert config.critic.strategy in ["fsdp", "fsdp2"] + from verl.single_controller.ray import RayWorkerGroup + from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker + + actor_rollout_cls = AsyncActorRolloutRefWorker if config.actor_rollout_ref.rollout.mode == "async" else ActorRolloutRefWorker + ray_worker_group_cls = RayWorkerGroup + + elif config.actor_rollout_ref.actor.strategy == "megatron": + assert config.actor_rollout_ref.actor.strategy == config.critic.strategy + from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup + from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker + + actor_rollout_cls = ActorRolloutRefWorker + ray_worker_group_cls = NVMegatronRayWorkerGroup + + else: + raise NotImplementedError + + from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role + + role_worker_mapping = { + Role.ActorRollout: ray.remote(actor_rollout_cls), + Role.Critic: ray.remote(CriticWorker), + } + + global_pool_id = "global_pool" + resource_pool_spec = { + global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, + } + mapping = { + Role.ActorRollout: global_pool_id, + Role.Critic: global_pool_id, + } + + # we should adopt a multi-source reward function here + # - for rule-based rm, we directly call a reward score + # - for model-based rm, we call a model + # - for code related prompt, we send to a sandbox if there are test cases + # - finally, we combine all the rewards together + # - The reward type depends on the tag of the data + if config.reward_model.enable: + if config.reward_model.strategy in ["fsdp", "fsdp2"]: + from verl.workers.fsdp_workers import RewardModelWorker + elif config.reward_model.strategy == "megatron": + from verl.workers.megatron_workers import RewardModelWorker + else: + raise NotImplementedError + role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker) + mapping[Role.RewardModel] = global_pool_id + + # use reference model + if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss: + role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker) + mapping[Role.RefPolicy] = global_pool_id + + reward_fn = load_reward_manager(config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {})) + val_reward_fn = load_reward_manager(config, tokenizer, num_examine=1) + resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) + + from verl.utils.dataset.rl_dataset import collate_fn + + train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor) + val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor) + train_sampler = create_rl_sampler(config.data, train_dataset) + trainer = RayPPOTrainer( + config=config, + tokenizer=tokenizer, + processor=processor, + role_worker_mapping=role_worker_mapping, + resource_pool_manager=resource_pool_manager, + ray_worker_group_cls=ray_worker_group_cls, + reward_fn=reward_fn, + val_reward_fn=val_reward_fn, + train_dataset=train_dataset, + val_dataset=val_dataset, + collate_fn=collate_fn, + train_sampler=train_sampler, + ) + trainer.init_workers() + trainer.fit() + + +def create_rl_dataset(data_paths, data_config, tokenizer, processor): + """Create a dataset. + + Arguments: + data_config: The data config. + tokenizer (Tokenizer): The tokenizer. + processor (Processor): The processor. + + Returns: + dataset (Dataset): The dataset. + """ + from torch.utils.data import Dataset + + from verl.utils.dataset.rl_dataset import RLHFDataset + + if "custom_cls" in data_config and data_config.custom_cls.get("path", None) is not None: + from verl.utils.import_utils import load_extern_type + + dataset_cls = load_extern_type(data_config.custom_cls.path, data_config.custom_cls.name) + if not issubclass(dataset_cls, Dataset): + raise TypeError(f"The custom dataset class '{data_config.custom_cls.name}' from '{data_config.custom_cls.path}' must inherit from torch.utils.data.Dataset") + else: + dataset_cls = RLHFDataset + print(f"Using dataset class: {dataset_cls.__name__}") + + dataset = dataset_cls( + data_files=data_paths, + tokenizer=tokenizer, + processor=processor, + config=data_config, + ) + + return dataset + + +def create_rl_sampler(data_config, dataset): + """Create a sampler for the dataset. + + Arguments: + data_config: The data config. + dataset (Dataset): The dataset. + + Returns: + sampler (Sampler): The sampler. + """ + import torch + from torch.utils.data import RandomSampler, SequentialSampler + + # use sampler for better ckpt resume + if data_config.shuffle: + train_dataloader_generator = torch.Generator() + train_dataloader_generator.manual_seed(data_config.get("seed", 1)) + sampler = RandomSampler(data_source=dataset, generator=train_dataloader_generator) + else: + sampler = SequentialSampler(data_source=dataset) + + return sampler + + +if __name__ == "__main__": + main() diff --git a/verl/trainer/ppo/__init__.py b/verl/trainer/ppo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c6aedde1437eddbe099b730519ef79980e2e3dc4 --- /dev/null +++ b/verl/trainer/ppo/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/verl/trainer/ppo/__pycache__/__init__.cpython-311.pyc b/verl/trainer/ppo/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14b83a7aa2226c8796e1585c9193976b255fd50c Binary files /dev/null and b/verl/trainer/ppo/__pycache__/__init__.cpython-311.pyc differ diff --git a/verl/trainer/ppo/__pycache__/core_algos.cpython-311.pyc b/verl/trainer/ppo/__pycache__/core_algos.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74f89a060bf5898ed6c891b54e7a1825f38f417f Binary files /dev/null and b/verl/trainer/ppo/__pycache__/core_algos.cpython-311.pyc differ diff --git a/verl/trainer/ppo/core_algos.py b/verl/trainer/ppo/core_algos.py new file mode 100644 index 0000000000000000000000000000000000000000..3818fab75b0b4aa0e9fb12815f3a76c52cfdd686 --- /dev/null +++ b/verl/trainer/ppo/core_algos.py @@ -0,0 +1,501 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Core functions to implement PPO algorithms. +The function implemented in this file should be used by trainer with different distributed strategies to +implement PPO +""" + +from collections import defaultdict + +import numpy as np +import torch + +import verl.utils.torch_functional as verl_F + + +class AdaptiveKLController: + """ + Adaptive KL controller described in the paper: + https://arxiv.org/pdf/1909.08593.pdf + """ + + def __init__(self, init_kl_coef, target_kl, horizon): + self.value = init_kl_coef + self.target = target_kl + self.horizon = horizon + + def update(self, current_kl, n_steps): + target = self.target + proportional_error = np.clip(current_kl / target - 1, -0.2, 0.2) + mult = 1 + proportional_error * n_steps / self.horizon + self.value *= mult + + +class FixedKLController: + """Fixed KL controller.""" + + def __init__(self, kl_coef): + self.value = kl_coef + + def update(self, current_kl, n_steps): + pass + + +def get_kl_controller(kl_ctrl): + if kl_ctrl.type == "fixed": + return FixedKLController(kl_coef=kl_ctrl.kl_coef) + elif kl_ctrl.type == "adaptive": + assert kl_ctrl.horizon > 0, f"horizon must be larger than 0. Got {kl_ctrl.horizon}" + return AdaptiveKLController(init_kl_coef=kl_ctrl.kl_coef, target_kl=kl_ctrl.target_kl, horizon=kl_ctrl.horizon) + else: + raise NotImplementedError + + +def compute_gae_advantage_return( + token_level_rewards: torch.Tensor, + values: torch.Tensor, + response_mask: torch.Tensor, + gamma: torch.Tensor, + lam: torch.Tensor, +): + """Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py + + Args: + token_level_rewards: `(torch.Tensor)` + shape: (bs, response_length) + values: `(torch.Tensor)` + shape: (bs, response_length) + response_mask: `(torch.Tensor)` + shape: (bs, response_length). [EOS] mask. The token after [EOS] have mask zero. + gamma: `(float)` + discounted factor used in RL + lam: `(float)` + lambda value when computing Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438) + + Returns: + advantages: `(torch.Tensor)` + shape: (bs, response_length) + Returns: `(torch.Tensor)` + shape: (bs, response_length) + + """ + with torch.no_grad(): + lastgaelam = 0 + advantages_reversed = [] + gen_len = token_level_rewards.shape[-1] + + for t in reversed(range(gen_len)): + nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0 + delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t] + lastgaelam = delta + gamma * lam * lastgaelam + advantages_reversed.append(lastgaelam) + advantages = torch.stack(advantages_reversed[::-1], dim=1) + + returns = advantages + values + advantages = verl_F.masked_whiten(advantages, response_mask) + return advantages, returns + + +# NOTE(sgm): this implementation only consider outcome supervision, where the reward is a scalar. +def compute_grpo_outcome_advantage( + token_level_rewards: torch.Tensor, + response_mask: torch.Tensor, + index: np.ndarray, + epsilon: float = 1e-6, + norm_adv_by_std_in_grpo: str = True, +): + """ + Compute advantage for GRPO, operating only on Outcome reward + (with only one scalar reward for each response). + Args: + token_level_rewards: `(torch.Tensor)` + shape: (bs, response_length) + response_mask: `(torch.Tensor)` + shape: (bs, response_length) + norm_adv_by_std_in_grpo: (bool) + whether to scale the GRPO advantage. + If True, the advantage is scaled by the std, as in the original GRPO. + If False, the advantage is not scaled, as in Dr.GRPO (https://arxiv.org/abs/2503.20783). + + Returns: + advantages: `(torch.Tensor)` + shape: (bs, response_length) + Returns: `(torch.Tensor)` + shape: (bs, response_length) + """ + scores = token_level_rewards.sum(dim=-1) + + id2score = defaultdict(list) + id2mean = {} + id2std = {} + + with torch.no_grad(): + bsz = scores.shape[0] + for i in range(bsz): + id2score[index[i]].append(scores[i]) + for idx in id2score: + if len(id2score[idx]) == 1: + id2mean[idx] = torch.tensor(0.0) + id2std[idx] = torch.tensor(1.0) + elif len(id2score[idx]) > 1: + id2mean[idx] = torch.mean(torch.tensor(id2score[idx])) + id2std[idx] = torch.std(torch.tensor([id2score[idx]])) + else: + raise ValueError(f"no score in prompt index: {idx}") + for i in range(bsz): + if norm_adv_by_std_in_grpo: + scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon) + else: + scores[i] = scores[i] - id2mean[index[i]] + scores = scores.unsqueeze(-1) * response_mask + + return scores, scores + + +def compute_reinforce_plus_plus_baseline_outcome_advantage(token_level_rewards: torch.Tensor, response_mask: torch.Tensor, index: torch.Tensor, epsilon: float = 1e-6): + """ + Compute advantage for RF++-baseline (https://arxiv.org/abs/2501.03262), operating only on Outcome reward + (with only one scalar reward for each response). + Args: + token_level_rewards: `(torch.Tensor)` + shape: (bs, response_length) + response_mask: `(torch.Tensor)` + shape: (bs, response_length) + + Returns: + advantages: `(torch.Tensor)` + shape: (bs, response_length) + Returns: `(torch.Tensor)` + shape: (bs, response_length) + """ + response_length = token_level_rewards.shape[-1] + scores = token_level_rewards.sum(dim=-1) + + id2score = defaultdict(list) + id2mean = {} + + with torch.no_grad(): + bsz = scores.shape[0] + for i in range(bsz): + id2score[index[i]].append(scores[i]) + for idx in id2score: + if len(id2score[idx]) == 1: + id2mean[idx] = torch.tensor(0.0) + elif len(id2score[idx]) > 1: + id2mean[idx] = torch.mean(torch.tensor(id2score[idx])) + else: + raise ValueError(f"no score in prompt index: {idx}") + for i in range(bsz): + scores[i] = scores[i] - id2mean[index[i]] + + scores = scores.unsqueeze(-1).tile([1, response_length]) * response_mask + scores = verl_F.masked_whiten(scores, response_mask) + + return scores, scores + + +def compute_rloo_outcome_advantage(token_level_rewards: torch.Tensor, response_mask: torch.Tensor, index: np.ndarray, epsilon: float = 1e-6): + """ + Compute advantage for RLOO based on https://arxiv.org/abs/2402.14740 + Args: + token_level_rewards: `(torch.Tensor)` + shape: (bs, response_length) + response_mask: `(torch.Tensor)` + shape: (bs, response_length) + + Returns: + advantages: `(torch.Tensor)` + shape: (bs, response_length) + Returns: `(torch.Tensor)` + shape: (bs, response_length) + """ + scores = token_level_rewards.sum(dim=-1) + + id2score = defaultdict(list) + id2mean = {} + + with torch.no_grad(): + bsz = scores.shape[0] + for i in range(bsz): + id2score[index[i]].append(scores[i]) + for idx in id2score: + if len(id2score[idx]) == 1: + id2mean[idx] = torch.tensor(0.0) + elif len(id2score[idx]) > 1: + id2mean[idx] = torch.mean(torch.tensor(id2score[idx])) + else: + raise ValueError(f"no score in prompt index: {idx}") + for i in range(bsz): + response_num = len(id2score[index[i]]) + if response_num > 1: + scores[i] = scores[i] * response_num / (response_num - 1) - id2mean[index[i]] * response_num / (response_num - 1) + scores = scores.unsqueeze(-1) * response_mask + + return scores, scores + + +def compute_reinforce_plus_plus_outcome_advantage(token_level_rewards: torch.Tensor, response_mask: torch.Tensor, gamma: torch.Tensor): + """ + Compute advantage for REINFORCE++. + This implementation is based on the paper: https://arxiv.org/abs/2501.03262 + Args: + token_level_rewards: `(torch.Tensor)` + shape: (bs, response_length) + response_mask: `(torch.Tensor)` + shape: (bs, response_length) + + Returns: + advantages: `(torch.Tensor)` + shape: (bs, response_length) + Returns: `(torch.Tensor)` + shape: (bs, response_length) + """ + + with torch.no_grad(): + returns = torch.zeros_like(token_level_rewards) + running_return = 0 + + for t in reversed(range(token_level_rewards.shape[1])): + running_return = token_level_rewards[:, t] + gamma * running_return + returns[:, t] = running_return + # Reset after EOS + running_return = running_return * response_mask[:, t] + + advantages = verl_F.masked_whiten(returns, response_mask) + advantages = advantages * response_mask + + return advantages, returns + + +def compute_remax_outcome_advantage(token_level_rewards: torch.Tensor, reward_baselines: torch.Tensor, response_mask: torch.Tensor): + """ + Compute advantage for ReMax, operating only on Outcome reward + This implementation is based on the paper: https://arxiv.org/abs/2310.10505 + + (with only one scalar reward for each response). + Args: + token_level_rewards: `(torch.Tensor)` + shape: (bs, response_length) + reward_baselines: `(torch.Tensor)` + shape: (bs,) + response_mask: `(torch.Tensor)` + shape: (bs, response_length) + + Returns: + advantages: `(torch.Tensor)` + shape: (bs, response_length) + Returns: `(torch.Tensor)` + shape: (bs, response_length) + """ + + with torch.no_grad(): + returns = (token_level_rewards * response_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1]) + advantages = returns - reward_baselines.unsqueeze(-1) * response_mask + + return advantages, returns + + +def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio): + kl = old_log_prob - ref_log_prob + return token_level_scores - kl * kl_ratio + + +def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str): + """ + Aggregate the loss matrix into a scalar. + Args: + loss_mat: `(torch.Tensor)` + shape: (bs, response_length) + loss_mask: `(torch.Tensor)` + shape: (bs, response_length) + loss_agg_mode: (str) choices: "token-mean" / + "seq-mean-token-sum" / + "seq-mean-token-mean" / + "seq-mean-token-sum-norm" / + "token-mean" is the default behavior + Returns: + loss: `a scalar torch.Tensor` + aggregated loss + """ + if loss_agg_mode == "token-mean": + loss = verl_F.masked_mean(loss_mat, loss_mask) + elif loss_agg_mode == "seq-mean-token-sum": + seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) # token-sum + loss = torch.mean(seq_losses) # seq-mean + elif loss_agg_mode == "seq-mean-token-mean": + seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / torch.sum(loss_mask, dim=-1) # token-mean + loss = torch.mean(seq_losses) # seq-mean + elif loss_agg_mode == "seq-mean-token-sum-norm": + seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) + loss = torch.sum(seq_losses) / loss_mask.shape[-1] # The divisor + # (loss_mask.shape[-1]) should ideally be constant + # throughout training to well-replicate the DrGRPO paper. + # TODO: Perhaps add user-defined normalizer argument to + # agg_loss to ensure divisor stays constant throughout. + else: + raise ValueError(f"Invalid loss_agg_mode: {loss_agg_mode}") + + return loss + + +def compute_policy_loss( + old_log_prob, + log_prob, + advantages, + response_mask, + cliprange=None, + cliprange_low=None, + cliprange_high=None, + clip_ratio_c=3.0, + loss_agg_mode="token-mean", +): + """Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122 + Args: + old_log_prob: `(torch.Tensor)` + shape: (bs, response_length) + log_prob: `(torch.Tensor)` + shape: (bs, response_length) + advantages: `(torch.Tensor)` + shape: (bs, response_length) + response_mask: `(torch.Tensor)` + shape: (bs, response_length) + cliprange: (float) + The clip range used in PPO. See https://arxiv.org/abs/1707.06347 + cliprange_low: (float) + The lower clip range used in PPO. + cliprange_high: (float) + The higher clip range used in PPO. + clip_ratio_c: (float) default: 3.0 + The lower bound of the ratio for dual-clip PPO, See https://arxiv.org/pdf/1912.09729 + loss_agg_mode: (str) choices: "token-mean" / + "seq-mean-token-sum" / + "seq-mean-token-mean" / + "seq-mean-token-sum-norm" / + "token-mean" is the default behavior + + Returns: + pg_loss: `a scalar torch.Tensor` + policy gradient loss computed via PPO + pg_clipfrac: (float) + the fraction of policy gradient loss being clipped + ppo_kl: (float) + the estimated KL divergence between the latest updating policy and the old sampling policy + pg_clipfrac_lower: (float) + the fraction of policy gradient loss being clipped when the advantage is negative + """ + assert clip_ratio_c > 1.0, "The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0," + f" but get the value: {clip_ratio_c}." + + negative_approx_kl = log_prob - old_log_prob + ratio = torch.exp(negative_approx_kl) + ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask) + + pg_losses1 = -advantages * ratio + if cliprange_low is None: + cliprange_low = cliprange + if cliprange_high is None: + cliprange_high = cliprange + pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high) # - clip(ratio, 1-cliprange, 1+cliprange) * A + clip_pg_losses1 = torch.maximum(pg_losses1, pg_losses2) # max(-ratio * A, -clip(ratio, 1-cliprange, 1+cliprange) * A) + pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask) + + pg_losses3 = -advantages * clip_ratio_c + clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1) + pg_clipfrac_lower = verl_F.masked_mean(torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), response_mask) + + pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1) + pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) + + return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower + + +def compute_entropy_loss(logits, response_mask): + """Compute Categorical entropy loss + + Args: + logits: `(torch.Tensor)` + shape: (bs, response_length, vocab_size) + response_mask: `(torch.Tensor)` + shape: (bs, response_length) + + Returns: + entropy: a scalar torch.Tensor + + """ + # compute entropy + entropy = verl_F.entropy_from_logits(logits) # (bs, response_len) + entropy_loss = verl_F.masked_mean(entropy, mask=response_mask) + return entropy_loss + + +def compute_value_loss(vpreds, returns, values, response_mask, cliprange_value): + """Compute the value loss. Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1151 + + Args: + vpreds (`torch.FloatTensor`): + Predicted values of the value head, shape (`batch_size`, `response_length`) + values (`torch.FloatTensor`): + Old values of value head, shape (`batch_size`, `response_length`) + returns: (`torch.FloatTensor`): + Ground truth returns, shape (`batch_size`, `response_length`) + + Returns: + vf_loss: a scalar (`torch.FloatTensor`): + value function loss + vf_clipfrac: a float + The ratio of vf being clipped + + """ + vpredclipped = verl_F.clip_by_value(vpreds, values - cliprange_value, values + cliprange_value) + vf_losses1 = (vpreds - returns) ** 2 + vf_losses2 = (vpredclipped - returns) ** 2 + vf_loss = 0.5 * verl_F.masked_mean(torch.max(vf_losses1, vf_losses2), response_mask) + vf_clipfrac = verl_F.masked_mean(torch.gt(vf_losses2, vf_losses1).float(), response_mask) + return vf_loss, vf_clipfrac + + +def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.FloatTensor: + """Compute KL divergence given logprob and ref_logprob. + Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1104 + + Args: + logprob: + ref_logprob: + + Returns: + + """ + if kl_penalty == "kl": + return logprob - ref_logprob + + if kl_penalty == "abs": + return (logprob - ref_logprob).abs() + + if kl_penalty == "mse": + return 0.5 * (logprob - ref_logprob).square() + + # J. Schulman. Approximating kl divergence, 2020. + # # URL http://joschu.net/blog/kl-approx.html. + if kl_penalty == "low_var_kl": + kl = ref_logprob - logprob + ratio = torch.exp(kl) + kld = (ratio - kl - 1).contiguous() + return torch.clamp(kld, min=-10, max=10) + + if kl_penalty == "full": + # so, here logprob and ref_logprob should contain the logits for every token in vocabulary + raise NotImplementedError + + raise NotImplementedError diff --git a/verl/trainer/ppo/metric_utils.py b/verl/trainer/ppo/metric_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..949a838c897a87f3153c7c45b36239bfb95a8302 --- /dev/null +++ b/verl/trainer/ppo/metric_utils.py @@ -0,0 +1,426 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Metrics related to the PPO trainer. +""" + +from collections import defaultdict +from functools import partial +from typing import Any, Callable, Dict, List + +import numpy as np +import torch + +from verl import DataProto +from verl.utils.import_utils import deprecated + +@deprecated("verl.utils.metric.reduce_metrics") +def reduce_metrics(metrics: Dict[str, List[Any]]) -> Dict[str, Any]: + """ + Reduces a dictionary of metric lists by computing the mean of each list. + + Args: + metrics: A dictionary mapping metric names to lists of metric values. + + Returns: + A dictionary with the same keys but with each list replaced by its mean value. + + Example: + >>> metrics = {"loss": [1.0, 2.0, 3.0], "accuracy": [0.8, 0.9, 0.7]} + >>> reduce_metrics(metrics) + {"loss": 2.0, "accuracy": 0.8} + """ + from verl.utils.metric import reduce_metrics + + return reduce_metrics(metrics) + + +def _compute_response_info(batch: DataProto) -> Dict[str, Any]: + """ + Computes information about prompts and responses from a batch. + + This is an internal helper function that extracts masks and lengths for prompts and responses. + + Args: + batch: A DataProto object containing batch data with responses and attention masks. + + Returns: + A dictionary containing: + - response_mask: Attention mask for the response tokens + - prompt_length: Tensor of prompt lengths for each item in the batch + - response_length: Tensor of response lengths for each item in the batch + """ + response_length = batch.batch["responses"].shape[-1] + + prompt_mask = batch.batch["attention_mask"][:, :-response_length] + response_mask = batch.batch["attention_mask"][:, -response_length:] + + prompt_length = prompt_mask.sum(-1).float() + response_length = response_mask.sum(-1).float() # (batch_size,) + + return dict( + response_mask=response_mask, + prompt_length=prompt_length, + response_length=response_length, + ) + + +def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> Dict[str, Any]: + """ + Computes various metrics from a batch of data for PPO training. + + This function calculates metrics related to scores, rewards, advantages, returns, values, + and sequence lengths from a batch of data. It provides statistical information (mean, max, min) + for each metric category. + + Args: + batch: A DataProto object containing batch data with token-level scores, rewards, advantages, etc. + use_critic: Whether to include critic-specific metrics. Defaults to True. + + Returns: + A dictionary of metrics including: + - critic/score/mean, max, min: Statistics about sequence scores + - critic/rewards/mean, max, min: Statistics about sequence rewards + - critic/advantages/mean, max, min: Statistics about advantages + - critic/returns/mean, max, min: Statistics about returns + - critic/values/mean, max, min: Statistics about critic values (if use_critic=True) + - critic/vf_explained_var: Explained variance of the value function (if use_critic=True) + - response_length/mean, max, min, clip_ratio: Statistics about response lengths + - prompt_length/mean, max, min, clip_ratio: Statistics about prompt lengths + """ + sequence_score = batch.batch["token_level_scores"].sum(-1) + sequence_reward = batch.batch["token_level_rewards"].sum(-1) + + advantages = batch.batch["advantages"] + returns = batch.batch["returns"] + + max_response_length = batch.batch["responses"].shape[-1] + + prompt_mask = batch.batch["attention_mask"][:, :-max_response_length].bool() + response_mask = batch.batch["attention_mask"][:, -max_response_length:].bool() + + max_prompt_length = prompt_mask.size(-1) + + response_info = _compute_response_info(batch) + prompt_length = response_info["prompt_length"] + response_length = response_info["response_length"] + + valid_adv = torch.masked_select(advantages, response_mask) + valid_returns = torch.masked_select(returns, response_mask) + + if use_critic: + values = batch.batch["values"] + valid_values = torch.masked_select(values, response_mask) + return_diff_var = torch.var(valid_returns - valid_values) + return_var = torch.var(valid_returns) + + metrics = { + # score + "critic/score/mean": torch.mean(sequence_score).detach().item(), + "critic/score/max": torch.max(sequence_score).detach().item(), + "critic/score/min": torch.min(sequence_score).detach().item(), + # reward + "critic/rewards/mean": torch.mean(sequence_reward).detach().item(), + "critic/rewards/max": torch.max(sequence_reward).detach().item(), + "critic/rewards/min": torch.min(sequence_reward).detach().item(), + # adv + "critic/advantages/mean": torch.mean(valid_adv).detach().item(), + "critic/advantages/max": torch.max(valid_adv).detach().item(), + "critic/advantages/min": torch.min(valid_adv).detach().item(), + # returns + "critic/returns/mean": torch.mean(valid_returns).detach().item(), + "critic/returns/max": torch.max(valid_returns).detach().item(), + "critic/returns/min": torch.min(valid_returns).detach().item(), + **( + { + # values + "critic/values/mean": torch.mean(valid_values).detach().item(), + "critic/values/max": torch.max(valid_values).detach().item(), + "critic/values/min": torch.min(valid_values).detach().item(), + # vf explained var + "critic/vf_explained_var": (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(), + } + if use_critic + else {} + ), + # response length + "response_length/mean": torch.mean(response_length).detach().item(), + "response_length/max": torch.max(response_length).detach().item(), + "response_length/min": torch.min(response_length).detach().item(), + "response_length/clip_ratio": torch.mean(torch.eq(response_length, max_response_length).float()).detach().item(), + # prompt length + "prompt_length/mean": torch.mean(prompt_length).detach().item(), + "prompt_length/max": torch.max(prompt_length).detach().item(), + "prompt_length/min": torch.min(prompt_length).detach().item(), + "prompt_length/clip_ratio": torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(), + } + return metrics + + +def compute_timing_metrics(batch: DataProto, timing_raw: Dict[str, float]) -> Dict[str, Any]: + """ + Computes timing metrics for different processing stages in PPO training. + + This function calculates both raw timing metrics (in seconds) and per-token timing metrics + (in milliseconds) for various processing stages like generation, reference computation, + value computation, advantage computation, and model updates. + + Args: + batch: A DataProto object containing batch data with responses and attention masks. + timing_raw: A dictionary mapping stage names to their execution times in seconds. + + Returns: + A dictionary containing: + - timing_s/{name}: Raw timing in seconds for each stage + - timing_per_token_ms/{name}: Per-token timing in milliseconds for each stage + + Note: + Different stages use different token counts for normalization: + - "gen" uses only response tokens + - Other stages ("ref", "values", "adv", "update_critic", "update_actor") use all tokens + (prompt + response) + """ + response_info = _compute_response_info(batch) + num_prompt_tokens = torch.sum(response_info["prompt_length"]).item() + num_response_tokens = torch.sum(response_info["response_length"]).item() + num_overall_tokens = num_prompt_tokens + num_response_tokens + + num_tokens_of_section = { + "gen": num_response_tokens, + **{name: num_overall_tokens for name in ["ref", "values", "adv", "update_critic", "update_actor"]}, + } + + return { + **{f"timing_s/{name}": value for name, value in timing_raw.items()}, + **{f"timing_per_token_ms/{name}": timing_raw[name] * 1000 / num_tokens_of_section[name] for name in set(num_tokens_of_section.keys()) & set(timing_raw.keys())}, + } + + +def compute_throughout_metrics(batch: DataProto, timing_raw: Dict[str, float], n_gpus: int) -> Dict[str, Any]: + """ + Computes throughput metrics for PPO training. + + This function calculates performance metrics related to token processing speed, + including the total number of tokens processed, time per step, and throughput + (tokens per second per GPU). + + Args: + batch: A DataProto object containing batch data with meta information about token counts. + timing_raw: A dictionary mapping stage names to their execution times in seconds. + Must contain a "step" key with the total step time. + n_gpus: Number of GPUs used for training. + + Returns: + A dictionary containing: + - perf/total_num_tokens: Total number of tokens processed in the batch + - perf/time_per_step: Time taken for the step in seconds + - perf/throughput: Tokens processed per second per GPU + + Note: + The throughput is calculated as total_tokens / (time * n_gpus) to normalize + across different GPU counts. + """ + total_num_tokens = sum(batch.meta_info["global_token_num"]) + time = timing_raw["step"] + # estimated_flops, promised_flops = flops_function.estimate_flops(num_tokens, time) + # f'Actual TFLOPs/s/GPU​': estimated_flops/(n_gpus), + # f'Theoretical TFLOPs/s/GPU​': promised_flops, + return { + "perf/total_num_tokens": total_num_tokens, + "perf/time_per_step": time, + "perf/throughput": total_num_tokens / (time * n_gpus), + } + + +def bootstrap_metric( + data: list[Any], + subset_size: int, + reduce_fns: list[Callable[[np.ndarray], float]], + n_bootstrap: int = 1000, + seed: int = 42, +) -> list[tuple[float, float]]: + """ + Performs bootstrap resampling to estimate statistics of metrics. + + This function uses bootstrap resampling to estimate the mean and standard deviation + of metrics computed by the provided reduction functions on random subsets of the data. + + Args: + data: List of data points to bootstrap from. + subset_size: Size of each bootstrap sample. + reduce_fns: List of functions that compute a metric from a subset of data. + n_bootstrap: Number of bootstrap iterations. Defaults to 1000. + seed: Random seed for reproducibility. Defaults to 42. + + Returns: + A list of tuples, where each tuple contains (mean, std) for a metric + corresponding to each reduction function in reduce_fns. + + Example: + >>> data = [1, 2, 3, 4, 5] + >>> reduce_fns = [np.mean, np.max] + >>> bootstrap_metric(data, 3, reduce_fns) + [(3.0, 0.5), (4.5, 0.3)] # Example values + """ + np.random.seed(seed) + + bootstrap_metric_lsts = [[] for _ in range(len(reduce_fns))] + for _ in range(n_bootstrap): + bootstrap_idxs = np.random.choice(len(data), size=subset_size, replace=True) + bootstrap_data = [data[i] for i in bootstrap_idxs] + for i, reduce_fn in enumerate(reduce_fns): + bootstrap_metric_lsts[i].append(reduce_fn(bootstrap_data)) + return [(np.mean(lst), np.std(lst)) for lst in bootstrap_metric_lsts] + + +def calc_maj_val(data: list[dict[str, Any]], vote_key: str, val_key: str) -> float: + """ + Calculate a value based on majority voting. + + This function identifies the most common value for a specified vote key + in the data, then returns the corresponding value for that majority vote. + + Args: + data: List of dictionaries, where each dictionary contains both vote_key and val_key. + vote_key: The key in each dictionary used for voting/counting. + val_key: The key in each dictionary whose value will be returned for the majority vote. + + Returns: + The value associated with the most common vote. + + Example: + >>> data = [ + ... {"pred": "A", "val": 0.9}, + ... {"pred": "B", "val": 0.8}, + ... {"pred": "A", "val": 0.7} + ... ] + >>> calc_maj_val(data, vote_key="pred", val_key="val") + 0.9 # Returns the first "val" for the majority vote "A" + """ + vote2vals = defaultdict(list) + for d in data: + vote2vals[d[vote_key]].append(d[val_key]) + + vote2cnt = {k: len(v) for k, v in vote2vals.items()} + maj_vote = max(vote2cnt, key=vote2cnt.get) + + maj_val = vote2vals[maj_vote][0] + + return maj_val + + +def process_validation_metrics(data_sources: list[str], sample_inputs: list[str], infos_dict: dict[str, list[Any]], seed: int = 42) -> dict[str, dict[str, dict[str, float]]]: + """ + Process validation metrics into a structured format with statistical analysis. + + This function organizes validation metrics by data source and prompt, then computes + various statistical measures including means, standard deviations, best/worst values, + and majority voting results. It also performs bootstrap sampling to estimate statistics + for different sample sizes. + + Args: + data_sources: List of data source identifiers for each sample. + sample_inputs: List of input prompts corresponding to each sample. + infos_dict: Dictionary mapping variable names to lists of values for each sample. + seed: Random seed for bootstrap sampling. Defaults to 42. + + Returns: + A nested dictionary with the structure: + { + data_source: { + variable_name: { + metric_name: value + } + } + } + + Where metric_name includes: + - "mean@N": Mean value across N samples + - "std@N": Standard deviation across N samples + - "best@N/mean": Mean of the best values in bootstrap samples of size N + - "best@N/std": Standard deviation of the best values in bootstrap samples + - "worst@N/mean": Mean of the worst values in bootstrap samples + - "worst@N/std": Standard deviation of the worst values in bootstrap samples + - "maj@N/mean": Mean of majority voting results in bootstrap samples (if "pred" exists) + - "maj@N/std": Standard deviation of majority voting results (if "pred" exists) + + Example: + >>> data_sources = ["source1", "source1", "source2"] + >>> sample_inputs = ["prompt1", "prompt1", "prompt2"] + >>> infos_dict = {"score": [0.8, 0.9, 0.7], "pred": ["A", "A", "B"]} + >>> result = process_validation_metrics(data_sources, sample_inputs, infos_dict) + >>> # result will contain statistics for each data source and variable + """ + # Group metrics by data source, prompt and variable + data_src2prompt2var2vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) + for sample_idx, data_source in enumerate(data_sources): + prompt = sample_inputs[sample_idx] + var2vals = data_src2prompt2var2vals[data_source][prompt] + for var_name, var_vals in infos_dict.items(): + var2vals[var_name].append(var_vals[sample_idx]) + + # Calculate metrics for each group + data_src2prompt2var2metric = defaultdict(lambda: defaultdict(lambda: defaultdict(dict))) + for data_source, prompt2var2vals in data_src2prompt2var2vals.items(): + print(prompt2var2vals) + for prompt, var2vals in prompt2var2vals.items(): + for var_name, var_vals in var2vals.items(): + if isinstance(var_vals[0], str): + continue + + metric = {} + n_resps = len(var_vals) + metric[f"mean@{n_resps}"] = np.mean(var_vals) + + if n_resps > 1: + metric[f"std@{n_resps}"] = np.std(var_vals) + + ns = [] + n = 2 + while n < n_resps: + ns.append(n) + n *= 2 + ns.append(n_resps) + + for n in ns: + [(bon_mean, bon_std), (won_mean, won_std)] = bootstrap_metric(data=var_vals, subset_size=n, reduce_fns=[np.max, np.min], seed=seed) + metric[f"best@{n}/mean"], metric[f"best@{n}/std"] = bon_mean, bon_std + metric[f"worst@{n}/mean"], metric[f"worst@{n}/std"] = won_mean, won_std + if var2vals.get("pred", None) is not None: + vote_data = [{"val": val, "pred": pred} for val, pred in zip(var_vals, var2vals["pred"])] + [(maj_n_mean, maj_n_std)] = bootstrap_metric( + data=vote_data, + subset_size=n, + reduce_fns=[partial(calc_maj_val, vote_key="pred", val_key="val")], + seed=seed, + ) + metric[f"maj@{n}/mean"], metric[f"maj@{n}/std"] = maj_n_mean, maj_n_std + + data_src2prompt2var2metric[data_source][prompt][var_name] = metric + + # Aggregate metrics across prompts + data_src2var2metric2prompt_vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) + for data_source, prompt2var2metric in data_src2prompt2var2metric.items(): + for prompt, var2metric in prompt2var2metric.items(): + for var_name, metric in var2metric.items(): + for metric_name, metric_val in metric.items(): + data_src2var2metric2prompt_vals[data_source][var_name][metric_name].append(metric_val) + + data_src2var2metric2val = defaultdict(lambda: defaultdict(lambda: defaultdict(float))) + for data_source, var2metric2prompt_vals in data_src2var2metric2prompt_vals.items(): + for var_name, metric2prompt_vals in var2metric2prompt_vals.items(): + for metric_name, prompt_vals in metric2prompt_vals.items(): + data_src2var2metric2val[data_source][var_name][metric_name] = np.mean(prompt_vals) + + return data_src2var2metric2val diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..7b1b781286639bcafb0b5e71ac1b589640a800da --- /dev/null +++ b/verl/trainer/ppo/ray_trainer.py @@ -0,0 +1,1084 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +FSDP PPO Trainer with Ray-based single controller. +This trainer supports model-agonistic model initialization with huggingface +""" + +import json +import os +import uuid +from collections import defaultdict +from contextlib import contextmanager +from copy import deepcopy +from dataclasses import dataclass, field +from enum import Enum +from pprint import pprint +from typing import Dict, Optional, Type + +import numpy as np +import ray +import torch +from codetiming import Timer +from omegaconf import OmegaConf, open_dict +from torch.utils.data import Dataset, Sampler +from torchdata.stateful_dataloader import StatefulDataLoader +from tqdm import tqdm + +from verl import DataProto +from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto +from verl.single_controller.base import Worker +from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup +from verl.single_controller.ray.base import create_colocated_worker_cls +from verl.trainer.ppo import core_algos +from verl.trainer.ppo.core_algos import agg_loss +from verl.trainer.ppo.metric_utils import ( + compute_data_metrics, + compute_throughout_metrics, + compute_timing_metrics, + process_validation_metrics, +) +from verl.trainer.ppo.reward import compute_reward, compute_reward_async +from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path +from verl.utils.metric import ( + reduce_metrics, +) +from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance +from verl.utils.torch_functional import masked_mean +from verl.utils.tracking import ValidationGenerationsLogger +from verl.workers.rollout.async_server import AsyncLLMServerManager + +WorkerType = Type[Worker] + + +class Role(Enum): + """ + To create more roles dynamically, you can subclass Role and add new members + """ + + Actor = 0 + Rollout = 1 + ActorRollout = 2 + Critic = 3 + RefPolicy = 4 + RewardModel = 5 + ActorRolloutRef = 6 + + +class AdvantageEstimator(str, Enum): + """ + Using an enumeration class to avoid spelling errors in adv_estimator + """ + + GAE = "gae" + GRPO = "grpo" + REINFORCE_PLUS_PLUS = "reinforce_plus_plus" + REINFORCE_PLUS_PLUS_BASELINE = "reinforce_plus_plus_baseline" + REMAX = "remax" + RLOO = "rloo" + + +@dataclass +class ResourcePoolManager: + """ + Define a resource pool specification. Resource pool will be initialized first. + """ + + resource_pool_spec: dict[str, list[int]] + mapping: dict[Role, str] + resource_pool_dict: dict[str, RayResourcePool] = field(default_factory=dict) + + def create_resource_pool(self): + for resource_pool_name, process_on_nodes in self.resource_pool_spec.items(): + # max_colocate_count means the number of WorkerGroups (i.e. processes) in each RayResourcePool + # For FSDP backend, we recommend using max_colocate_count=1 that merge all WorkerGroups into one. + # For Megatron backend, we recommend using max_colocate_count>1 + # that can utilize different WorkerGroup for differnt models + resource_pool = RayResourcePool(process_on_nodes=process_on_nodes, use_gpu=True, max_colocate_count=1, name_prefix=resource_pool_name) + self.resource_pool_dict[resource_pool_name] = resource_pool + + self._check_resource_available() + + def get_resource_pool(self, role: Role) -> RayResourcePool: + """Get the resource pool of the worker_cls""" + return self.resource_pool_dict[self.mapping[role]] + + def get_n_gpus(self) -> int: + """Get the number of gpus in this cluster.""" + return sum([n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes]) + + def _check_resource_available(self): + """Check if the resource pool can be satisfied in this ray cluster.""" + node_available_resources = ray.state.available_resources_per_node() + node_available_gpus = {node: node_info.get("GPU", 0) for node, node_info in node_available_resources.items()} + + # check total required gpus can be satisfied + total_available_gpus = sum(node_available_gpus.values()) + total_required_gpus = sum([n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes]) + if total_available_gpus < total_required_gpus: + raise ValueError(f"Total available GPUs {total_available_gpus} is less than total desired GPUs {total_required_gpus}") + + # check each resource pool can be satisfied, O(#resource_pools * #nodes) + for resource_pool_name, process_on_nodes in self.resource_pool_spec.items(): + num_gpus, num_nodes = process_on_nodes[0], len(process_on_nodes) + for node, available_gpus in node_available_gpus.items(): + if available_gpus >= num_gpus: + node_available_gpus[node] -= num_gpus + num_nodes -= 1 + if num_nodes == 0: + break + if num_nodes > 0: + raise ValueError(f"Resource pool {resource_pool_name}: {num_gpus}*{num_nodes}" + "cannot be satisfied in this ray cluster") + + +def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty="kl", multi_turn=False): + responses = data.batch["responses"] + response_length = responses.size(1) + token_level_scores = data.batch["token_level_scores"] + batch_size = data.batch.batch_size[0] + + if multi_turn: + loss_mask = data.batch["loss_mask"] + response_mask = loss_mask[:, -response_length:] + else: + attention_mask = data.batch["attention_mask"] + response_mask = attention_mask[:, -response_length:] + + # compute kl between ref_policy and current policy + # When apply_kl_penalty, algorithm.use_kl_in_reward=True, so the reference model has been enabled. + kld = core_algos.kl_penalty(data.batch["old_log_probs"], data.batch["ref_log_prob"], kl_penalty=kl_penalty) # (batch_size, response_length) + kld = kld * response_mask + beta = kl_ctrl.value + + token_level_rewards = token_level_scores - beta * kld + + current_kl = masked_mean(kld, mask=response_mask, axis=-1) # average over sequence + current_kl = torch.mean(current_kl, dim=0).item() + + # according to https://github.com/huggingface/trl/blob/951ca1841f29114b969b57b26c7d3e80a39f75a0/trl/trainer/ppo_trainer.py#L837 + kl_ctrl.update(current_kl=current_kl, n_steps=batch_size) + data.batch["token_level_rewards"] = token_level_rewards + + metrics = {"actor/reward_kl_penalty": current_kl, "actor/reward_kl_penalty_coeff": beta} + + return data, metrics + + +def compute_response_mask(data: DataProto): + responses = data.batch["responses"] + response_length = responses.size(1) + attention_mask = data.batch["attention_mask"] + return attention_mask[:, -response_length:] + + +def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_repeat=1, multi_turn=False, norm_adv_by_std_in_grpo=True): + # Back-compatible with trainers that do not compute response mask in fit + if "response_mask" not in data.batch: + data.batch["response_mask"] = compute_response_mask(data) + # prepare response group + # TODO: add other ways to estimate advantages + if adv_estimator == AdvantageEstimator.GAE: + advantages, returns = core_algos.compute_gae_advantage_return( + token_level_rewards=data.batch["token_level_rewards"], + values=data.batch["values"], + response_mask=data.batch["response_mask"], + gamma=gamma, + lam=lam, + ) + data.batch["advantages"] = advantages + data.batch["returns"] = returns + elif adv_estimator == AdvantageEstimator.GRPO: + # TODO: test on more adv estimator type + grpo_calculation_mask = data.batch["response_mask"] + if multi_turn: + # If multi-turn, replace the mask with the relevant part of loss_mask + response_length = grpo_calculation_mask.size(1) # Get length from the initial response mask + grpo_calculation_mask = data.batch["loss_mask"][:, -response_length:] # This mask is the one intended for GRPO + # Call compute_grpo_outcome_advantage with parameters matching its definition + advantages, returns = core_algos.compute_grpo_outcome_advantage( + token_level_rewards=data.batch["token_level_rewards"], + response_mask=grpo_calculation_mask, + index=data.non_tensor_batch["uid"], + norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, + ) + data.batch["advantages"] = advantages + data.batch["returns"] = returns + elif adv_estimator == AdvantageEstimator.REINFORCE_PLUS_PLUS_BASELINE: + advantages, returns = core_algos.compute_reinforce_plus_plus_baseline_outcome_advantage( + token_level_rewards=data.batch["token_level_rewards"], + response_mask=data.batch["response_mask"], + index=data.non_tensor_batch["uid"], + ) + data.batch["advantages"] = advantages + data.batch["returns"] = returns + elif adv_estimator == AdvantageEstimator.REINFORCE_PLUS_PLUS: + advantages, returns = core_algos.compute_reinforce_plus_plus_outcome_advantage( + token_level_rewards=data.batch["token_level_rewards"], + response_mask=data.batch["response_mask"], + gamma=gamma, + ) + data.batch["advantages"] = advantages + data.batch["returns"] = returns + elif adv_estimator == AdvantageEstimator.REMAX: + advantages, returns = core_algos.compute_remax_outcome_advantage( + token_level_rewards=data.batch["token_level_rewards"], + reward_baselines=data.batch["reward_baselines"], + response_mask=data.batch["response_mask"], + ) + + data.batch["advantages"] = advantages + data.batch["returns"] = returns + elif adv_estimator == AdvantageEstimator.RLOO: + advantages, returns = core_algos.compute_rloo_outcome_advantage( + token_level_rewards=data.batch["token_level_rewards"], + response_mask=data.batch["response_mask"], + index=data.non_tensor_batch["uid"], + ) + data.batch["advantages"] = advantages + data.batch["returns"] = returns + else: + raise NotImplementedError + return data + + +@contextmanager +def _timer(name: str, timing_raw: Dict[str, float]): + with Timer(name=name, logger=None) as timer: + yield + if name not in timing_raw: + timing_raw[name] = 0 + timing_raw[name] += timer.last + + +class RayPPOTrainer: + """ + Note that this trainer runs on the driver process on a single CPU/GPU node. + """ + + # TODO: support each role have individual ray_worker_group_cls, + # i.e., support different backend of different role + def __init__( + self, + config, + tokenizer, + role_worker_mapping: dict[Role, WorkerType], + resource_pool_manager: ResourcePoolManager, + ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup, + processor=None, + reward_fn=None, + val_reward_fn=None, + train_dataset: Optional[Dataset] = None, + val_dataset: Optional[Dataset] = None, + collate_fn=None, + train_sampler: Optional[Sampler] = None, + ): + # assert torch.cuda.is_available(), 'cuda must be available on driver' + + self.tokenizer = tokenizer + self.processor = processor + self.config = config + self.reward_fn = reward_fn + self.val_reward_fn = val_reward_fn + + self.hybrid_engine = config.actor_rollout_ref.hybrid_engine + assert self.hybrid_engine, "Currently, only support hybrid engine" + + if self.hybrid_engine: + assert Role.ActorRollout in role_worker_mapping, f"{role_worker_mapping.keys()=}" + + self.role_worker_mapping = role_worker_mapping + self.resource_pool_manager = resource_pool_manager + self.use_reference_policy = Role.RefPolicy in role_worker_mapping + self.use_rm = Role.RewardModel in role_worker_mapping + self.ray_worker_group_cls = ray_worker_group_cls + self.validation_generations_logger = ValidationGenerationsLogger() + + # define in-reward KL control + # kl loss control currently not suppoorted + if config.algorithm.use_kl_in_reward: + self.kl_ctrl_in_reward = core_algos.get_kl_controller(config.algorithm.kl_ctrl) + + if self.config.algorithm.adv_estimator == AdvantageEstimator.GAE: + self.use_critic = True + elif self.config.algorithm.adv_estimator in [ + AdvantageEstimator.GRPO, + AdvantageEstimator.REINFORCE_PLUS_PLUS, + AdvantageEstimator.REMAX, + AdvantageEstimator.RLOO, + AdvantageEstimator.REINFORCE_PLUS_PLUS_BASELINE, + ]: + self.use_critic = False + else: + raise NotImplementedError + + self._validate_config() + self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler) + + def _validate_config(self): + config = self.config + # number of GPUs total + n_gpus = config.trainer.n_gpus_per_node * config.trainer.nnodes + + # 1. Check total batch size for data correctness + real_train_batch_size = config.data.train_batch_size * config.actor_rollout_ref.rollout.n + assert real_train_batch_size % n_gpus == 0, f"real_train_batch_size ({real_train_batch_size}) must be divisible by total n_gpus ({n_gpus})." + + # A helper function to check "micro_batch_size" vs "micro_batch_size_per_gpu" + # We throw an error if the user sets both. The new convention is "..._micro_batch_size_per_gpu". + def check_mutually_exclusive(mbs, mbs_per_gpu, name: str): + settings = { + "actor_rollout_ref.actor": "micro_batch_size", + "critic": "micro_batch_size", + "reward_model": "micro_batch_size", + "actor_rollout_ref.ref": "log_prob_micro_batch_size", + "actor_rollout_ref.rollout": "log_prob_micro_batch_size", + } + + if name in settings: + param = settings[name] + param_per_gpu = f"{param}_per_gpu" + + if mbs is None and mbs_per_gpu is None: + raise ValueError(f"[{name}] Please set at least one of '{name}.{param}' or '{name}.{param_per_gpu}'.") + + if mbs is not None and mbs_per_gpu is not None: + raise ValueError(f"[{name}] You have set both '{name}.{param}' AND '{name}.{param_per_gpu}'. Please remove '{name}.{param}' because only '*_{param_per_gpu}'" + "is supported (the former is deprecated).") + + if not config.actor_rollout_ref.actor.use_dynamic_bsz: + # actor: ppo_micro_batch_size vs. ppo_micro_batch_size_per_gpu + check_mutually_exclusive( + config.actor_rollout_ref.actor.ppo_micro_batch_size, + config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu, + "actor_rollout_ref.actor", + ) + + if self.use_reference_policy: + # reference: log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu + check_mutually_exclusive( + config.actor_rollout_ref.ref.log_prob_micro_batch_size, + config.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu, + "actor_rollout_ref.ref", + ) + + # The rollout section also has log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu + check_mutually_exclusive( + config.actor_rollout_ref.rollout.log_prob_micro_batch_size, + config.actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu, + "actor_rollout_ref.rollout", + ) + + if self.use_critic and not config.critic.use_dynamic_bsz: + # Check for critic micro-batch size conflicts + check_mutually_exclusive(config.critic.ppo_micro_batch_size, config.critic.ppo_micro_batch_size_per_gpu, "critic") + + # Check for reward model micro-batch size conflicts + if config.reward_model.enable and not config.reward_model.use_dynamic_bsz: + check_mutually_exclusive(config.reward_model.micro_batch_size, config.reward_model.micro_batch_size_per_gpu, "reward_model") + + # Actor + # check if train_batch_size is larger than ppo_mini_batch_size + # if NOT dynamic_bsz, we must ensure: + # ppo_mini_batch_size is divisible by ppo_micro_batch_size + # ppo_micro_batch_size * sequence_parallel_size >= n_gpus + if not config.actor_rollout_ref.actor.use_dynamic_bsz: + assert config.data.train_batch_size >= config.actor_rollout_ref.actor.ppo_mini_batch_size + sp_size = config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1) + if config.actor_rollout_ref.actor.ppo_micro_batch_size is not None: + assert config.actor_rollout_ref.actor.ppo_mini_batch_size % config.actor_rollout_ref.actor.ppo_micro_batch_size == 0 + assert config.actor_rollout_ref.actor.ppo_micro_batch_size * sp_size >= n_gpus + + assert config.actor_rollout_ref.actor.loss_agg_mode in [ + "token-mean", + "seq-mean-token-sum", + "seq-mean-token-mean", + "seq-mean-token-sum-norm", + ], f"Invalid loss_agg_mode: {config.actor_rollout_ref.actor.loss_agg_mode}" + + if config.algorithm.use_kl_in_reward and config.actor_rollout_ref.actor.use_kl_loss: + print("NOTICE: You have both enabled in-reward kl and kl loss.") + + # critic + if self.use_critic and not config.critic.use_dynamic_bsz: + assert config.data.train_batch_size >= config.critic.ppo_mini_batch_size + sp_size = config.critic.get("ulysses_sequence_parallel_size", 1) + if config.critic.ppo_micro_batch_size is not None: + assert config.critic.ppo_mini_batch_size % config.critic.ppo_micro_batch_size == 0 + assert config.critic.ppo_micro_batch_size * sp_size >= n_gpus + + # Check if use_remove_padding is enabled when using sequence parallelism for fsdp + if config.actor_rollout_ref.actor.strategy == "fsdp" and (config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1) > 1 or config.actor_rollout_ref.ref.get("ulysses_sequence_parallel_size", 1) > 1): + assert config.actor_rollout_ref.model.use_remove_padding, "When using sequence parallelism for actor/ref policy, you must enable `use_remove_padding`." + + if self.use_critic and config.critic.strategy == "fsdp": + if config.critic.get("ulysses_sequence_parallel_size", 1) > 1: + assert config.critic.model.use_remove_padding, "When using sequence parallelism for critic, you must enable `use_remove_padding`." + + if config.data.get("val_batch_size", None) is not None: + print("WARNING: val_batch_size is deprecated." + " Validation datasets are sent to inference engines as a whole batch," + " which will schedule the memory themselves.") + + # check eval config + if config.actor_rollout_ref.rollout.val_kwargs.do_sample: + assert config.actor_rollout_ref.rollout.temperature > 0, "validation gen temperature should be greater than 0 when enabling do_sample" + + # check multi_turn with tool config + if config.actor_rollout_ref.rollout.multi_turn.enable: + assert config.actor_rollout_ref.rollout.multi_turn.tool_config_path is not None, "tool_config_path must be set when enabling multi_turn with tool, due to no role-playing support" + assert config.algorithm.adv_estimator in [AdvantageEstimator.GRPO], "only GRPO is tested for multi-turn with tool" + + print("[validate_config] All configuration checks passed successfully!") + + def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler): + """ + Creates the train and validation dataloaders. + """ + # TODO: we have to make sure the batch size is divisible by the dp size + from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler + + if train_dataset is None: + train_dataset = create_rl_dataset(self.config.data.train_files, self.config.data, self.tokenizer, self.processor) + if val_dataset is None: + val_dataset = create_rl_dataset(self.config.data.val_files, self.config.data, self.tokenizer, self.processor) + self.train_dataset, self.val_dataset = train_dataset, val_dataset + + if train_sampler is None: + train_sampler = create_rl_sampler(self.config.data, self.train_dataset) + if collate_fn is None: + from verl.utils.dataset.rl_dataset import collate_fn as default_collate_fn + + collate_fn = default_collate_fn + + self.train_dataloader = StatefulDataLoader( + dataset=self.train_dataset, + batch_size=self.config.data.get("gen_batch_size", self.config.data.train_batch_size), + num_workers=self.config.data.get("dataloader_num_workers", 8), + drop_last=True, + collate_fn=collate_fn, + sampler=train_sampler, + ) + + val_batch_size = self.config.data.val_batch_size # Prefer config value if set + if val_batch_size is None: + val_batch_size = len(self.val_dataset) + + self.val_dataloader = StatefulDataLoader( + dataset=self.val_dataset, + batch_size=val_batch_size, + num_workers=self.config.data.get("dataloader_num_workers", 8), + shuffle=False, + drop_last=False, + collate_fn=collate_fn, + ) + + assert len(self.train_dataloader) >= 1, "Train dataloader is empty!" + assert len(self.val_dataloader) >= 1, "Validation dataloader is empty!" + + print(f"Size of train dataloader: {len(self.train_dataloader)}, Size of val dataloader: {len(self.val_dataloader)}") + + total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs + + if self.config.trainer.total_training_steps is not None: + total_training_steps = self.config.trainer.total_training_steps + + self.total_training_steps = total_training_steps + print(f"Total training steps: {self.total_training_steps}") + + try: + OmegaConf.set_struct(self.config, True) + with open_dict(self.config): + if OmegaConf.select(self.config, "actor_rollout_ref.actor.optim"): + self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps + if OmegaConf.select(self.config, "critic.optim"): + self.config.critic.optim.total_training_steps = total_training_steps + except Exception as e: + print(f"Warning: Could not set total_training_steps in config. Structure missing? Error: {e}") + + def _dump_generations(self, inputs, outputs, scores, reward_extra_infos_dict, dump_path): + """Dump rollout/validation samples as JSONL.""" + os.makedirs(dump_path, exist_ok=True) + filename = os.path.join(dump_path, f"{self.global_steps}.jsonl") + + n = len(inputs) + base_data = { + "input": inputs, + "output": outputs, + "score": scores, + "step": [self.global_steps] * n, + } + + for k, v in reward_extra_infos_dict.items(): + if len(v) == n: + base_data[k] = v + + with open(filename, "w") as f: + for i in range(n): + entry = {k: v[i] for k, v in base_data.items()} + f.write(json.dumps(entry, ensure_ascii=False) + "\n") + + print(f"Dumped generations to {filename}") + + def _maybe_log_val_generations(self, inputs, outputs, scores): + """Log a table of validation samples to the configured logger (wandb or swanlab)""" + + generations_to_log = self.config.trainer.log_val_generations + + if generations_to_log == 0: + return + + import numpy as np + + # Create tuples of (input, output, score) and sort by input text + samples = list(zip(inputs, outputs, scores)) + samples.sort(key=lambda x: x[0]) # Sort by input text + + # Use fixed random seed for deterministic shuffling + rng = np.random.RandomState(42) + rng.shuffle(samples) + + # Take first N samples after shuffling + samples = samples[:generations_to_log] + + # Log to each configured logger + self.validation_generations_logger.log(self.config.trainer.logger, samples, self.global_steps) + + def _validate(self): + data_source_lst = [] + reward_extra_infos_dict: dict[str, list] = defaultdict(list) + + # Lists to collect samples for the table + sample_inputs = [] + sample_outputs = [] + sample_scores = [] + + for test_data in self.val_dataloader: + test_batch = DataProto.from_single_dict(test_data) + + # repeat test batch + test_batch = test_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True) + # we only do validation on rule-based rm + if self.config.reward_model.enable and test_batch[0].non_tensor_batch["reward_model"]["style"] == "model": + return {} + + # Store original inputs + input_ids = test_batch.batch["input_ids"] + # TODO: Can we keep special tokens except for padding tokens? + input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids] + sample_inputs.extend(input_texts) + + batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"] + non_tensor_batch_keys_to_pop = ["raw_prompt_ids"] + if "multi_modal_inputs" in test_batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.extend(["multi_modal_data", "multi_modal_inputs"]) + if "raw_prompt" in test_batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("raw_prompt") + if "tools_kwargs" in test_batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("tools_kwargs") + test_gen_batch = test_batch.pop( + batch_keys=batch_keys_to_pop, + non_tensor_batch_keys=non_tensor_batch_keys_to_pop, + ) + + test_gen_batch.meta_info = { + "eos_token_id": self.tokenizer.eos_token_id, + "pad_token_id": self.tokenizer.pad_token_id, + "recompute_log_prob": False, + "do_sample": self.config.actor_rollout_ref.rollout.val_kwargs.do_sample, + "validate": True, + } + print(f"test_gen_batch meta info: {test_gen_batch.meta_info}") + + # pad to be divisible by dp_size + test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, self.actor_rollout_wg.world_size) + if not self.async_rollout_mode: + test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded) + else: + self.async_rollout_manager.wake_up() + test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences(test_gen_batch_padded) + self.async_rollout_manager.sleep() + + # unpad + test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size) + print("validation generation end") + + # Store generated outputs + output_ids = test_output_gen_batch.batch["responses"] + output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids] + sample_outputs.extend(output_texts) + + test_batch = test_batch.union(test_output_gen_batch) + + # evaluate using reward_function + result = self.val_reward_fn(test_batch, return_dict=True) + reward_tensor = result["reward_tensor"] + scores = reward_tensor.sum(-1).cpu().tolist() + sample_scores.extend(scores) + + print("val-scores in batch: ", scores) + + reward_extra_infos_dict["reward"].extend(scores) + if "reward_extra_info" in result: + for key, lst in result["reward_extra_info"].items(): + reward_extra_infos_dict[key].extend(lst) + + data_source_lst.append(test_batch.non_tensor_batch.get("data_source", ["unknown"] * reward_tensor.shape[0])) + + self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores) + + # dump generations + val_data_dir = self.config.trainer.get("validation_data_dir", None) + if val_data_dir: + self._dump_generations( + inputs=sample_inputs, + outputs=sample_outputs, + scores=sample_scores, + reward_extra_infos_dict=reward_extra_infos_dict, + dump_path=val_data_dir, + ) + + for key_info, lst in reward_extra_infos_dict.items(): + assert len(lst) == 0 or len(lst) == len(sample_scores), f"{key_info}: {len(lst)=}, {len(sample_scores)=}" + + data_sources = np.concatenate(data_source_lst, axis=0) + + data_src2var2metric2val = process_validation_metrics(data_sources, sample_inputs, reward_extra_infos_dict) + metric_dict = {} + for data_source, var2metric2val in data_src2var2metric2val.items(): + core_var = "acc" if "acc" in var2metric2val else "reward" + for var_name, metric2val in var2metric2val.items(): + n_max = max([int(name.split("@")[-1].split("/")[0]) for name in metric2val.keys()]) + for metric_name, metric_val in metric2val.items(): + if (var_name == core_var) and any(metric_name.startswith(pfx) for pfx in ["mean", "maj", "best"]) and (f"@{n_max}" in metric_name): + metric_sec = "val-core" + else: + metric_sec = "val-aux" + pfx = f"{metric_sec}/{data_source}/{var_name}/{metric_name}" + metric_dict[pfx] = metric_val + + return metric_dict + + def init_workers(self): + """Init resource pool and worker group""" + self.resource_pool_manager.create_resource_pool() + + self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()} + + # create actor and rollout + if self.hybrid_engine: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout) + actor_rollout_cls = RayClassWithInitArgs( + cls=self.role_worker_mapping[Role.ActorRollout], + config=self.config.actor_rollout_ref, + role="actor_rollout", + ) + self.resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls + else: + raise NotImplementedError + + # create critic + if self.use_critic: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic) + critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=self.config.critic) + self.resource_pool_to_cls[resource_pool]["critic"] = critic_cls + + # create reference policy if needed + if self.use_reference_policy: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy) + ref_policy_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RefPolicy], config=self.config.actor_rollout_ref, role="ref") + self.resource_pool_to_cls[resource_pool]["ref"] = ref_policy_cls + + # create a reward model if reward_fn is None + if self.use_rm: + # we create a RM here + resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) + rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model) + self.resource_pool_to_cls[resource_pool]["rm"] = rm_cls + + # initialize WorkerGroup + # NOTE: if you want to use a different resource pool for each role, which can support different parallel size, + # you should not use `create_colocated_worker_cls`. + # Instead, directly pass different resource pool to different worker groups. + # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information. + all_wg = {} + wg_kwargs = {} # Setting up kwargs for RayWorkerGroup + if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None: + wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout + + for resource_pool, class_dict in self.resource_pool_to_cls.items(): + worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) + wg_dict = self.ray_worker_group_cls(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls, **wg_kwargs) + spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) + all_wg.update(spawn_wg) + + if self.use_critic: + self.critic_wg = all_wg["critic"] + self.critic_wg.init_model() + + if self.use_reference_policy: + self.ref_policy_wg = all_wg["ref"] + self.ref_policy_wg.init_model() + + if self.use_rm: + self.rm_wg = all_wg["rm"] + self.rm_wg.init_model() + + # we should create rollout at the end so that vllm can have a better estimation of kv cache memory + self.actor_rollout_wg = all_wg["actor_rollout"] + self.actor_rollout_wg.init_model() + + # create async rollout manager and request scheduler + self.async_rollout_mode = False + if self.config.actor_rollout_ref.rollout.mode == "async": + self.async_rollout_mode = True + self.async_rollout_manager = AsyncLLMServerManager( + config=self.config.actor_rollout_ref, + worker_group=self.actor_rollout_wg, + ) + + def _save_checkpoint(self): + # path: given_path + `/global_step_{global_steps}` + `/actor` + local_global_step_folder = os.path.join(self.config.trainer.default_local_dir, f"global_step_{self.global_steps}") + + print(f"local_global_step_folder: {local_global_step_folder}") + + actor_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "actor") + + remove_previous_ckpt_in_save = self.config.trainer.get("remove_previous_ckpt_in_save", False) + if remove_previous_ckpt_in_save: + print("Warning: remove_previous_ckpt_in_save is deprecated," + " set max_actor_ckpt_to_keep=1 and max_critic_ckpt_to_keep=1 instead") + max_actor_ckpt_to_keep = self.config.trainer.get("max_actor_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1 + max_critic_ckpt_to_keep = self.config.trainer.get("max_critic_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1 + + self.actor_rollout_wg.save_checkpoint(local_global_step_folder, actor_remote_path, self.global_steps, max_ckpt_to_keep=max_actor_ckpt_to_keep) + + if self.use_critic: + critic_local_path = os.path.join(local_global_step_folder, "critic") + critic_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "critic") + self.critic_wg.save_checkpoint(critic_local_path, critic_remote_path, self.global_steps, max_ckpt_to_keep=max_critic_ckpt_to_keep) + + # save dataloader + dataloader_local_path = os.path.join(local_global_step_folder, "data.pt") + dataloader_state_dict = self.train_dataloader.state_dict() + torch.save(dataloader_state_dict, dataloader_local_path) + + # # latest checkpointed iteration tracker (for atomic usage) + # local_latest_checkpointed_iteration = os.path.join(self.config.trainer.default_local_dir, "latest_checkpointed_iteration.txt") + # with open(local_latest_checkpointed_iteration, "w") as f: + # f.write(str(self.global_steps)) + + def _load_checkpoint(self): + if self.config.trainer.resume_mode == "disable": + return 0 + + # load from hdfs + if self.config.trainer.default_hdfs_dir is not None: + raise NotImplementedError("load from hdfs is not implemented yet") + else: + checkpoint_folder = self.config.trainer.default_local_dir # TODO: check path + if not os.path.isabs(checkpoint_folder): + working_dir = os.getcwd() + checkpoint_folder = os.path.join(working_dir, checkpoint_folder) + global_step_folder = find_latest_ckpt_path(checkpoint_folder) # None if no latest + + # find global_step_folder + if self.config.trainer.resume_mode == "auto": + if global_step_folder is None: + print("Training from scratch") + return 0 + else: + if self.config.trainer.resume_mode == "resume_path": + assert isinstance(self.config.trainer.resume_from_path, str), "resume ckpt must be str type" + assert "global_step_" in self.config.trainer.resume_from_path, "resume ckpt must specify the global_steps" + global_step_folder = self.config.trainer.resume_from_path + if not os.path.isabs(global_step_folder): + working_dir = os.getcwd() + global_step_folder = os.path.join(working_dir, global_step_folder) + print(f"Load from checkpoint folder: {global_step_folder}") + # set global step + self.global_steps = int(global_step_folder.split("global_step_")[-1]) + + print(f"Setting global step to {self.global_steps}") + print(f"Resuming from {global_step_folder}") + + actor_path = os.path.join(global_step_folder, "actor") + critic_path = os.path.join(global_step_folder, "critic") + # load actor + self.actor_rollout_wg.load_checkpoint(actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load) + # load critic + if self.use_critic: + self.critic_wg.load_checkpoint(critic_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load) + + # load dataloader, + # TODO: from remote not implemented yet + dataloader_local_path = os.path.join(global_step_folder, "data.pt") + if os.path.exists(dataloader_local_path): + dataloader_state_dict = torch.load(dataloader_local_path, weights_only=False) + self.train_dataloader.load_state_dict(dataloader_state_dict) + else: + print(f"Warning: No dataloader state found at {dataloader_local_path}, will start from scratch") + + def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqlen"): + """Reorder the data on single controller such that each dp rank gets similar total tokens""" + attention_mask = batch.batch["attention_mask"] + batch_size = attention_mask.shape[0] + global_seqlen_lst = batch.batch["attention_mask"].view(batch_size, -1).sum(-1).tolist() # (train_batch_size,) + world_size = self.actor_rollout_wg.world_size + global_partition_lst = get_seqlen_balanced_partitions(global_seqlen_lst, k_partitions=world_size, equal_size=True) + # reorder based on index. The data will be automatically equally partitioned by dispatch function + global_idx = torch.tensor([j for partition in global_partition_lst for j in partition]) + batch.reorder(global_idx) + global_balance_stats = log_seqlen_unbalance(seqlen_list=global_seqlen_lst, partitions=global_partition_lst, prefix=logging_prefix) + metrics.update(global_balance_stats) + + def fit(self): + """ + The training loop of PPO. + The driver process only need to call the compute functions of the worker group through RPC + to construct the PPO dataflow. + The light-weight advantage computation is done on the driver process. + """ + from omegaconf import OmegaConf + + from verl.utils.tracking import Tracking + + logger = Tracking( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + default_backend=self.config.trainer.logger, + config=OmegaConf.to_container(self.config, resolve=True), + ) + + self.global_steps = 0 + # load checkpoint before doing anything + # perform validation before training + # currently, we only support validation using the reward_function. + if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): + val_metrics = self._validate() + assert val_metrics, f"{val_metrics=}" + pprint(f"Initial validation metrics: {val_metrics}") + logger.log(data=val_metrics, step=self.global_steps) + + if self.config.trainer.get("val_only", False): + return + + # add tqdm + progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress") + + # we start from step 1 + self.global_steps += 1 + last_val_metrics = None + best_val_rewards = float("-inf") + for epoch in range(self.config.trainer.total_epochs): + for batch_dict in self.train_dataloader: + metrics = {} + timing_raw = {} + batch: DataProto = DataProto.from_single_dict(batch_dict) + + # pop those keys for generation + batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"] + non_tensor_batch_keys_to_pop = ["raw_prompt_ids"] + if "multi_modal_inputs" in batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.extend(["multi_modal_data", "multi_modal_inputs"]) + if "raw_prompt" in batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("raw_prompt") + if "tools_kwargs" in batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("tools_kwargs") + gen_batch = batch.pop( + batch_keys=batch_keys_to_pop, + non_tensor_batch_keys=non_tensor_batch_keys_to_pop, + ) + + is_last_step = self.global_steps >= self.total_training_steps + + with _timer("step", timing_raw): + # generate a batch + with _timer("gen", timing_raw): + if not self.async_rollout_mode: + gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) + else: + self.async_rollout_manager.wake_up() + gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch) + self.async_rollout_manager.sleep() + + if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: + with _timer("gen_max", timing_raw): + gen_baseline_batch = deepcopy(gen_batch) + gen_baseline_batch.meta_info["do_sample"] = False + gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) + + batch = batch.union(gen_baseline_output) + reward_baseline_tensor = self.reward_fn(batch) + reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1) + + batch.pop(batch_keys=list(gen_baseline_output.batch.keys())) + + batch.batch["reward_baselines"] = reward_baseline_tensor + + del gen_baseline_batch, gen_baseline_output + + batch.non_tensor_batch["uid"] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object) + # repeat to align with repeated responses in rollout + batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + batch = batch.union(gen_batch_output) + + batch.batch["response_mask"] = compute_response_mask(batch) + # balance the number of valid tokens on each dp rank. + # Note that this breaks the order of data inside the batch. + # Please take care when you implement group based adv computation such as GRPO and rloo + if self.config.trainer.balance_batch: + self._balance_batch(batch, metrics=metrics) + + # compute global_valid tokens + batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() + + with _timer("reward", timing_raw): + # compute reward model score + if self.use_rm: + reward_tensor = self.rm_wg.compute_rm_score(batch) + batch = batch.union(reward_tensor) + + if self.config.reward_model.launch_reward_fn_async: + future_reward = compute_reward_async.remote(batch, self.config, self.tokenizer) + else: + reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn) + + # recompute old_log_probs + with _timer("old_log_prob", timing_raw): + old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) + entropys = old_log_prob.batch["entropys"] + response_masks = batch.batch["response_mask"] + loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode + entropy_loss = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode) + old_log_prob_metrics = {"actor/entropy_loss": entropy_loss.detach().item()} + metrics.update(old_log_prob_metrics) + old_log_prob.batch.pop("entropys") + batch = batch.union(old_log_prob) + + if self.use_reference_policy: + # compute reference log_prob + with _timer("ref", timing_raw): + ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) + batch = batch.union(ref_log_prob) + + # compute values + if self.use_critic: + with _timer("values", timing_raw): + values = self.critic_wg.compute_values(batch) + batch = batch.union(values) + + with _timer("adv", timing_raw): + # we combine with rule-based rm + reward_extra_infos_dict: dict[str, list] + if self.config.reward_model.launch_reward_fn_async: + reward_tensor, reward_extra_infos_dict = ray.get(future_reward) + batch.batch["token_level_scores"] = reward_tensor + + print(f"{list(reward_extra_infos_dict.keys())=}") + if reward_extra_infos_dict: + batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()}) + + # compute rewards. apply_kl_penalty if available + if self.config.algorithm.use_kl_in_reward: + batch, kl_metrics = apply_kl_penalty(batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty) + metrics.update(kl_metrics) + else: + batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] + + # compute advantages, executed on the driver process + + norm_adv_by_std_in_grpo = self.config.algorithm.get("norm_adv_by_std_in_grpo", True) # GRPO adv normalization factor + + batch = compute_advantage( + batch, + adv_estimator=self.config.algorithm.adv_estimator, + gamma=self.config.algorithm.gamma, + lam=self.config.algorithm.lam, + num_repeat=self.config.actor_rollout_ref.rollout.n, + norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, + multi_turn=self.config.actor_rollout_ref.rollout.multi_turn.enable, + ) + + # update critic + if self.use_critic: + with _timer("update_critic", timing_raw): + critic_output = self.critic_wg.update_critic(batch) + critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) + metrics.update(critic_output_metrics) + + # implement critic warmup + if self.config.trainer.critic_warmup <= self.global_steps: + # update actor + with _timer("update_actor", timing_raw): + batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable + actor_output = self.actor_rollout_wg.update_actor(batch) + actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) + metrics.update(actor_output_metrics) + + # Log rollout generations if enabled + rollout_data_dir = self.config.trainer.get("rollout_data_dir", None) + if rollout_data_dir: + with _timer("dump_rollout_generations", timing_raw): + print(batch.batch.keys()) + inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True) + outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True) + scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist() + self._dump_generations( + inputs=inputs, + outputs=outputs, + scores=scores, + reward_extra_infos_dict=reward_extra_infos_dict, + dump_path=rollout_data_dir, + ) + + # validate + if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and \ + (is_last_step or self.global_steps % self.config.trainer.test_freq == 0): + with _timer("testing", timing_raw): + val_metrics: dict = self._validate() + if is_last_step: + last_val_metrics = val_metrics + metrics.update(val_metrics) + + for key in val_metrics: + if "val-aux" in key and "mean@" in key: + if val_metrics[key] > best_val_rewards: + best_val_rewards = metrics[key] + if self.config.trainer.save_freq > 0: + with _timer("save_checkpoint", timing_raw): + self._save_checkpoint() + + # training metrics + metrics.update( + { + "training/global_step": self.global_steps, + "training/epoch": epoch, + } + ) + # collect metrics + metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) + metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) + # TODO: implement actual tflpo and theoretical tflpo + n_gpus = self.resource_pool_manager.get_n_gpus() + metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)) + + # TODO: make a canonical logger that supports various backend + logger.log(data=metrics, step=self.global_steps) + + if is_last_step: + pprint(f"Final validation metrics: {last_val_metrics}") + progress_bar.close() + return + + progress_bar.update(1) + self.global_steps += 1 diff --git a/verl/trainer/ppo/reward.py b/verl/trainer/ppo/reward.py new file mode 100644 index 0000000000000000000000000000000000000000..987c539152374f958f882c9731a6f3fe389a691a --- /dev/null +++ b/verl/trainer/ppo/reward.py @@ -0,0 +1,116 @@ +# Copyright 2025 Individual Contributor: Thibaut Barroyer +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import ray + +from verl import DataProto + + +def get_custom_reward_fn(config): + import importlib.util + import sys + + reward_fn_config = config.get("custom_reward_function") or {} + file_path = reward_fn_config.get("path") + if not file_path: + return None + + if not os.path.exists(file_path): + raise FileNotFoundError(f"Reward function file '{file_path}' not found.") + + spec = importlib.util.spec_from_file_location("custom_module", file_path) + module = importlib.util.module_from_spec(spec) + try: + sys.modules["custom_module"] = module + spec.loader.exec_module(module) + except Exception as e: + raise RuntimeError(f"Error loading module from '{file_path}': {e}") from e + + function_name = reward_fn_config.get("name") + if not hasattr(module, function_name): + raise AttributeError(f"Reward function '{function_name}' not found in '{file_path}'.") + + print(f"using customized reward function '{function_name}' from '{file_path}'") + raw_fn = getattr(module, function_name) + + reward_kwargs = dict(reward_fn_config.get("reward_kwargs", {})) + + def wrapped_fn(*args, **kwargs): + return raw_fn(*args, **kwargs, **reward_kwargs) + + return wrapped_fn + + +def load_reward_manager(config, tokenizer, num_examine, **reward_kwargs): + reward_manager_name = config.reward_model.get("reward_manager", "naive") + if reward_manager_name == "naive": + from verl.workers.reward_manager import NaiveRewardManager + + reward_manager_cls = NaiveRewardManager + elif reward_manager_name == "prime": + from verl.workers.reward_manager import PrimeRewardManager + + reward_manager_cls = PrimeRewardManager + elif reward_manager_name == "batch": + from verl.workers.reward_manager import BatchRewardManager + + reward_manager_cls = BatchRewardManager + elif reward_manager_name == "dapo": + from verl.workers.reward_manager import DAPORewardManager + + reward_manager_cls = DAPORewardManager + else: + raise NotImplementedError + + compute_score = get_custom_reward_fn(config) + return reward_manager_cls( + tokenizer=tokenizer, + num_examine=num_examine, + compute_score=compute_score, + reward_fn_key=config.data.reward_fn_key, + **reward_kwargs, + ) + + +def compute_reward(data: DataProto, reward_fn): + """ + Compute reward for a batch of data. + Args: + data: DataProto object containing the input data. + reward_fn: Reward function to compute the reward. + Returns: + Tuple of reward tensor and extra info dictionary. + """ + try: + reward_result = reward_fn(data, return_dict=True) + reward_tensor = reward_result["reward_tensor"] + reward_extra_infos_dict = reward_result["reward_extra_info"] + except Exception as e: + print(f"Error in reward_fn: {e}") + reward_tensor = reward_fn(data) + reward_extra_infos_dict = {} + + return reward_tensor, reward_extra_infos_dict + + +@ray.remote(num_cpus=1) +def compute_reward_async(data: DataProto, config, tokenizer): + """ + Load the reward manager and compute the reward for a batch of data. + This is meant to be run in a separate Ray worker. + """ + reward_fn = load_reward_manager(config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {})) + return compute_reward(data, reward_fn) diff --git a/verl/trainer/runtime_env.yaml b/verl/trainer/runtime_env.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fc40fe9300dad33610849d44f834080e36a65e0e --- /dev/null +++ b/verl/trainer/runtime_env.yaml @@ -0,0 +1,6 @@ +working_dir: ./ +excludes: ["/.git/"] +env_vars: + TORCH_NCCL_AVOID_RECORD_STREAMS: "1" + # If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs: + # VLLM_ATTENTION_BACKEND: "XFORMERS" \ No newline at end of file diff --git a/verl/utils/__init__.py b/verl/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..833f90d394821cb8b00ce9cdb88f9cfbf3cb52f3 --- /dev/null +++ b/verl/utils/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import tokenizer +from .tokenizer import hf_processor, hf_tokenizer + +__all__ = tokenizer.__all__ + ["hf_processor", "hf_tokenizer"] diff --git a/verl/utils/__pycache__/__init__.cpython-311.pyc b/verl/utils/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ce36d4edcd332024c615809f12e08a9931023ea Binary files /dev/null and b/verl/utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/verl/utils/__pycache__/distributed.cpython-311.pyc b/verl/utils/__pycache__/distributed.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03a5fd37696caded6d8cde4307207151105529dd Binary files /dev/null and b/verl/utils/__pycache__/distributed.cpython-311.pyc differ diff --git a/verl/utils/__pycache__/flops_counter.cpython-311.pyc b/verl/utils/__pycache__/flops_counter.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..058e821ac88af8abdb9395a03e3fa51c21870cda Binary files /dev/null and b/verl/utils/__pycache__/flops_counter.cpython-311.pyc differ diff --git a/verl/utils/__pycache__/fs.cpython-311.pyc b/verl/utils/__pycache__/fs.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0c60b726cf43d3ec344014b08ed8af23ab8c2c9 Binary files /dev/null and b/verl/utils/__pycache__/fs.cpython-311.pyc differ diff --git a/verl/utils/__pycache__/fsdp_utils.cpython-311.pyc b/verl/utils/__pycache__/fsdp_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0472cf91eadbd1303346e442e25d0f6af504f22c Binary files /dev/null and b/verl/utils/__pycache__/fsdp_utils.cpython-311.pyc differ diff --git a/verl/utils/__pycache__/hdfs_io.cpython-311.pyc b/verl/utils/__pycache__/hdfs_io.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e2029b1847823438c4459c1ae1fa05e9d8fc76d Binary files /dev/null and b/verl/utils/__pycache__/hdfs_io.cpython-311.pyc differ diff --git a/verl/utils/__pycache__/import_utils.cpython-311.pyc b/verl/utils/__pycache__/import_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7efb68d840c47b6f3f825dd0b6a853923f621ba5 Binary files /dev/null and b/verl/utils/__pycache__/import_utils.cpython-311.pyc differ diff --git a/verl/utils/__pycache__/logging_utils.cpython-311.pyc b/verl/utils/__pycache__/logging_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..81f25096c71d9a779cdb8ebdb3b3308788c1252d Binary files /dev/null and b/verl/utils/__pycache__/logging_utils.cpython-311.pyc differ diff --git a/verl/utils/__pycache__/model.cpython-311.pyc b/verl/utils/__pycache__/model.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..10a2514d79918c048df3229a5a81c3da777a9c4b Binary files /dev/null and b/verl/utils/__pycache__/model.cpython-311.pyc differ diff --git a/verl/utils/__pycache__/py_functional.cpython-311.pyc b/verl/utils/__pycache__/py_functional.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..89aab5bfebc1cf02d38a4a64fe14e40756975dbf Binary files /dev/null and b/verl/utils/__pycache__/py_functional.cpython-311.pyc differ diff --git a/verl/utils/__pycache__/seqlen_balancing.cpython-311.pyc b/verl/utils/__pycache__/seqlen_balancing.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9fd06352e3d9f7d1f15976f42760b92657d33a94 Binary files /dev/null and b/verl/utils/__pycache__/seqlen_balancing.cpython-311.pyc differ diff --git a/verl/utils/__pycache__/tokenizer.cpython-311.pyc b/verl/utils/__pycache__/tokenizer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..235813c2b58ebfcbbc74c2a619cbaa80c32805c9 Binary files /dev/null and b/verl/utils/__pycache__/tokenizer.cpython-311.pyc differ diff --git a/verl/utils/__pycache__/torch_dtypes.cpython-311.pyc b/verl/utils/__pycache__/torch_dtypes.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..361e0d7465f40c9c6c979f8046ff2db83bfe3ebb Binary files /dev/null and b/verl/utils/__pycache__/torch_dtypes.cpython-311.pyc differ diff --git a/verl/utils/__pycache__/torch_functional.cpython-311.pyc b/verl/utils/__pycache__/torch_functional.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..721e69ebf8b365b8b226c19b12bf012e53b610ed Binary files /dev/null and b/verl/utils/__pycache__/torch_functional.cpython-311.pyc differ diff --git a/verl/utils/__pycache__/tracking.cpython-311.pyc b/verl/utils/__pycache__/tracking.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..def359e1336fb03ef75952b1d7ac0dba5d531fee Binary files /dev/null and b/verl/utils/__pycache__/tracking.cpython-311.pyc differ diff --git a/verl/utils/__pycache__/ulysses.cpython-311.pyc b/verl/utils/__pycache__/ulysses.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0624b3e6d664f2cb5da7c42d17aec17c9e0e46f Binary files /dev/null and b/verl/utils/__pycache__/ulysses.cpython-311.pyc differ diff --git a/verl/utils/__pycache__/vllm_utils.cpython-311.pyc b/verl/utils/__pycache__/vllm_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4869e15a3a78a5c7e934c5ae8a53eecad6b3dce Binary files /dev/null and b/verl/utils/__pycache__/vllm_utils.cpython-311.pyc differ diff --git a/verl/utils/checkpoint/__init__.py b/verl/utils/checkpoint/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c6aedde1437eddbe099b730519ef79980e2e3dc4 --- /dev/null +++ b/verl/utils/checkpoint/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/verl/utils/checkpoint/__pycache__/__init__.cpython-311.pyc b/verl/utils/checkpoint/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db44c6446850cdc7e735379bdbda16950070b7b0 Binary files /dev/null and b/verl/utils/checkpoint/__pycache__/__init__.cpython-311.pyc differ diff --git a/verl/utils/checkpoint/__pycache__/checkpoint_manager.cpython-311.pyc b/verl/utils/checkpoint/__pycache__/checkpoint_manager.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aaad5efc391e6c228724a70745716d7b1c34ce27 Binary files /dev/null and b/verl/utils/checkpoint/__pycache__/checkpoint_manager.cpython-311.pyc differ diff --git a/verl/utils/checkpoint/__pycache__/fsdp_checkpoint_manager.cpython-311.pyc b/verl/utils/checkpoint/__pycache__/fsdp_checkpoint_manager.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..047f7ee1c2d033d6db9e94866efb22e708c7ba7a Binary files /dev/null and b/verl/utils/checkpoint/__pycache__/fsdp_checkpoint_manager.cpython-311.pyc differ diff --git a/verl/utils/checkpoint/checkpoint_manager.py b/verl/utils/checkpoint/checkpoint_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..7177ccc5a46207d2402555eae7d1e041b0e36fd8 --- /dev/null +++ b/verl/utils/checkpoint/checkpoint_manager.py @@ -0,0 +1,148 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import random +import shutil +import tempfile +from typing import Optional, Union + +import numpy as np +import torch +import torch.distributed +from filelock import FileLock +from transformers import PreTrainedTokenizer, ProcessorMixin + + +class BaseCheckpointManager: + """ + A checkpoint manager that saves and loads + - model + - optimizer + - lr_scheduler + - extra_states + in a SPMD way. + + We save + - sharded model states and optimizer states + - full lr_scheduler states + - huggingface tokenizer and config for ckpt merge + """ + + def __init__( + self, + model, + optimizer: torch.optim.Optimizer, + lr_scheduler: torch.optim.lr_scheduler.LRScheduler = None, + processing_class: Union[PreTrainedTokenizer, ProcessorMixin] = None, + checkpoint_contents: Optional[list] = None, + ): + if checkpoint_contents is None: + checkpoint_contents = ["model", "optimizer", "extra"] + self.previous_global_step = None + self.previous_saved_paths = [] + + self.model = model + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + self.processing_class = processing_class + self.checkpoint_contents = checkpoint_contents + + self.rank = torch.distributed.get_rank() + self.world_size = torch.distributed.get_world_size() + + def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_after_load: bool = False): + raise NotImplementedError + + def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: int = 0, max_ckpt_to_keep: int = None): + raise NotImplementedError + + @staticmethod + def checkpath(local_path: str, hdfs_path: str): + assert local_path is not None or hdfs_path is not None, "local_path and hdfs_path cannot be both None" + return local_path is not None, local_path if local_path is not None else hdfs_path + + def remove_previous_save_local_path(self, path): + if isinstance(path, str): + path = [path] + for p in path: + abs_path = os.path.abspath(p) + print(f"Checkpoint manager remove previous save local path: {abs_path}") + if not os.path.exists(abs_path): + continue + shutil.rmtree(abs_path, ignore_errors=True) + + @staticmethod + def local_mkdir(path): + if not os.path.isabs(path): + working_dir = os.getcwd() + path = os.path.join(working_dir, path) + + # Using hash value of path as lock file name to avoid long file name + lock_filename = f"ckpt_{hash(path) & 0xFFFFFFFF:08x}.lock" + lock_path = os.path.join(tempfile.gettempdir(), lock_filename) + + try: + with FileLock(lock_path, timeout=60): # Add timeout + # make a new dir + os.makedirs(path, exist_ok=True) + except Exception as e: + print(f"Warning: Failed to acquire lock for {path}: {e}") + # Even if the lock is not acquired, try to create the directory + os.makedirs(path, exist_ok=True) + + return path + + @staticmethod + def get_rng_state(): + rng_state = { + "cpu": torch.get_rng_state(), + "cuda": torch.cuda.get_rng_state(), + "numpy": np.random.get_state(), + "random": random.getstate(), + } + return rng_state + + @staticmethod + def load_rng_state(rng_state): + torch.set_rng_state(rng_state["cpu"]) + torch.cuda.set_rng_state(rng_state["cuda"]) + np.random.set_state(rng_state["numpy"]) + random.setstate(rng_state["random"]) + + +def find_latest_ckpt_path(path, directory_format="global_step_{}"): + if path is None: + return None + + tracker_file = get_checkpoint_tracker_filename(path) + if not os.path.exists(tracker_file): + print("Checkpoint tracker file does not exist: %s", tracker_file) + return None + + with open(tracker_file, "rb") as f: + iteration = int(f.read().decode()) + ckpt_path = os.path.join(path, directory_format.format(iteration)) + if not os.path.exists(ckpt_path): + print("Checkpoint does not exist: %s", ckpt_path) + return None + + print("Found checkpoint: %s", ckpt_path) + return ckpt_path + + +def get_checkpoint_tracker_filename(root_path: str): + """ + Tracker file rescords the latest chckpoint during training to restart from. + """ + return os.path.join(root_path, "latest_checkpointed_iteration.txt") diff --git a/verl/utils/checkpoint/fsdp_checkpoint_manager.py b/verl/utils/checkpoint/fsdp_checkpoint_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..36b2c56afe6ab4472ddb678d164d5ff19918b837 --- /dev/null +++ b/verl/utils/checkpoint/fsdp_checkpoint_manager.py @@ -0,0 +1,147 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import shutil +import warnings +from typing import Optional, Union + +import torch +import torch.distributed +from torch.distributed.fsdp import FullStateDictConfig, ShardedOptimStateDictConfig, ShardedStateDictConfig, StateDictType +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from transformers import GenerationConfig, PreTrainedTokenizer, ProcessorMixin + +from verl.utils.fs import copy_to_local, is_non_local +from verl.utils.fsdp_utils import fsdp_version, get_fsdp_state_ctx + +from .checkpoint_manager import BaseCheckpointManager + + +class FSDPCheckpointManager(BaseCheckpointManager): + """ + A checkpoint manager that saves and loads + - model + - optimizer + - lr_scheduler + - extra_states + in a SPMD way. + + We save + - sharded model states and optimizer states + - full lr_scheduler states + - huggingface tokenizer/processor and config for ckpt merge + """ + + def __init__( + self, + model: FSDP, + optimizer: torch.optim.Optimizer, + lr_scheduler: torch.optim.lr_scheduler.LRScheduler, + processing_class: Union[PreTrainedTokenizer, ProcessorMixin] = None, + checkpoint_contents: Optional[list] = None, + **kwargs, + ): + if checkpoint_contents is None: + checkpoint_contents = ["model", "optimizer", "extra"] + if processing_class is None: + assert "tokenizer" in kwargs, "tokenizer or processor must be provided" + warnings.warn("`tokenizer` is deprecated. use `processing_class` instead.", DeprecationWarning, stacklevel=2) + processing_class = kwargs.pop("tokenizer") + assert "model" in checkpoint_contents and "optimizer" in checkpoint_contents and "extra" in checkpoint_contents, f"FSDPCheckpointManager must include ['model', 'optimizer', 'extra'], got {checkpoint_contents}" + + super().__init__( + model, + optimizer, + lr_scheduler=lr_scheduler, + processing_class=processing_class, + checkpoint_contents=checkpoint_contents, + ) + + def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_after_load=False): + if local_path is None: + return + + # every rank download its own checkpoint + remote_model_path = os.path.join(local_path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt") + remote_optim_path = os.path.join(local_path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt") + remote_extra_state_path = os.path.join(local_path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt") + print(f"[rank-{self.rank}]: Loading from {remote_model_path} and {remote_optim_path} and {remote_extra_state_path}") + local_model_path = copy_to_local(remote_model_path) + local_optim_path = copy_to_local(remote_optim_path) + local_extra_state_path = copy_to_local(remote_extra_state_path) + + model_state_dict = torch.load(local_model_path, weights_only=False) + optimizer_state_dict = torch.load(local_optim_path, weights_only=False) + extra_state_dict = torch.load(local_extra_state_path, weights_only=False) + + if del_local_after_load: + try: + os.remove(local_model_path) if is_non_local(local_model_path) else None + os.remove(local_optim_path) if is_non_local(local_optim_path) else None + os.remove(local_extra_state_path) if is_non_local(local_extra_state_path) else None + except Exception as e: + print(f"[rank-{self.rank}]: remove local resume ckpt file after loading failed, exception {e} will be ignored") + + lr_scheduler_state_dict = extra_state_dict["lr_scheduler"] + + state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True) + optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True) + with get_fsdp_state_ctx(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg): + self.model.load_state_dict(model_state_dict) + if self.optimizer is not None: + self.optimizer.load_state_dict(optimizer_state_dict) + # recover random state + if "rng" in extra_state_dict: + # 'rng' may not exist for backward compatibility + self.load_rng_state(extra_state_dict["rng"]) + + if self.lr_scheduler is not None: + self.lr_scheduler.load_state_dict(lr_scheduler_state_dict) + + def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: int = 0, max_ckpt_to_keep=None): + # record the previous global step + self.previous_global_step = global_step + # only support save and load ckpt for actor + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + experiment_dir = os.path.dirname(local_path) + if self.rank == 0: + if os.path.exists(experiment_dir): + subdirs = [name for name in os.listdir(experiment_dir) if os.path.isdir(os.path.join(experiment_dir, name))] + for name in subdirs: + full_path = os.path.join(experiment_dir, name) + shutil.rmtree(full_path) + + os.makedirs(local_path, exist_ok=True) + from torch.distributed.fsdp import FullStateDictConfig, StateDictType + torch.distributed.barrier() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + + with FSDP.state_dict_type( + self.model, + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + ): + state_dict = self.model.state_dict() + model_path = os.path.join(local_path, f'model.pt') + if self.rank == 0: + torch.save(state_dict, model_path) + + print("\n" + "="*60) + print(f"βœ…βœ…βœ… SUCCESS: Model saved βœ…βœ…βœ…") + print("="*60 + "\n") + + torch.distributed.barrier() + diff --git a/verl/utils/checkpoint/megatron_checkpoint_manager.py b/verl/utils/checkpoint/megatron_checkpoint_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..06bcbf4785cee51e1c731752072069a30e746273 --- /dev/null +++ b/verl/utils/checkpoint/megatron_checkpoint_manager.py @@ -0,0 +1,313 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import random +from typing import Optional + +import numpy as np +import torch +import torch.distributed +from megatron.core import mpu, tensor_parallel +from megatron.core.dist_checkpointing.mapping import ShardedObject + +from verl.models.weight_loader_registry import get_weight_saver +from verl.utils.fs import is_non_local +from verl.utils.megatron_utils import ( + get_hf_model_checkpoint_path, + get_model_checkpoint_path, + get_optimizer_checkpoint_path, + get_rng_states_checkpoint_path, +) + +from .checkpoint_manager import BaseCheckpointManager + + +class MegatronCheckpointManager(BaseCheckpointManager): + """ + A checkpoint manager that saves and loads + - model + - optimizer + - lr_scheduler + - extra_states + in a SPMD way. + + We save + - sharded model states and optimizer states + - full lr_scheduler states + - huggingface tokenizer/processor and config for ckpt merge + """ + + def __init__( + self, + config, + model_config, + role, + model: torch.nn.ModuleList, + arch: str, + hf_config, + param_dtype: torch.dtype, + share_embeddings_and_output_weights: bool, + tokenizer, + optimizer, + use_distributed_optimizer: bool, + checkpoint_contents: Optional[list] = None, + **kwargs, + ): + if checkpoint_contents is None: + checkpoint_contents = ["model", "optimizer", "extra"] + super().__init__( + model, + optimizer=optimizer, + lr_scheduler=None, + processing_class=tokenizer, + checkpoint_contents=checkpoint_contents, + ) + self.arch = arch + self.config = config + self.role = role + self.is_value_model = False + if self.role in ["reward", "critic"]: + self.is_value_model = True + self.model_config = model_config + self.hf_config = hf_config + self.param_dtype = param_dtype + self.share_embeddings_and_output_weights = share_embeddings_and_output_weights + self.model_path = self.config.model.path + self.use_distributed_optimizer = use_distributed_optimizer + + self.rank = torch.distributed.get_rank() + + self.weight_saver = get_weight_saver(self.arch) + + def get_rng_state(self, use_dist_ckpt: bool = False, data_parallel_random_init: bool = False): + """collect rng state across data parallel ranks""" + rng_state = { + "random_rng_state": random.getstate(), + "np_rng_state": np.random.get_state(), + "torch_rng_state": torch.get_rng_state(), + "cuda_rng_state": torch.cuda.get_rng_state(), + "rng_tracker_states": tensor_parallel.get_cuda_rng_tracker().get_states(), + } + + rng_state_list = None + if torch.distributed.is_initialized() and mpu.get_data_parallel_world_size() > 1 and data_parallel_random_init: + rng_state_list = [None for i in range(mpu.get_data_parallel_world_size())] + torch.distributed.all_gather_object(rng_state_list, rng_state, group=mpu.get_data_parallel_group()) + else: + rng_state_list = [rng_state] + + if use_dist_ckpt: + pp_rank = mpu.get_pipeline_model_parallel_rank() + pp_size = mpu.get_pipeline_model_parallel_world_size() + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + cp_rank = mpu.get_context_parallel_rank() + cp_size = mpu.get_context_parallel_world_size() + rng_state_list = ShardedObject( + "rng_state", + rng_state_list, + (pp_size, tp_size, cp_size), + (pp_rank, tp_rank, cp_rank), + replica_id=mpu.get_data_parallel_rank(with_context_parallel=True), + ) + + return rng_state_list + + def get_checkpoint_name( + self, + checkpoints_path, + pipeline_parallel=None, + tensor_rank=None, + pipeline_rank=None, + cp_rank=None, + expert_parallel=None, + expert_rank=None, + return_base_dir=True, + basename="model.pt", + ): + """Determine the directory name for this rank's checkpoint.""" + # Use both the tensor and pipeline MP rank. + if pipeline_parallel is None: + pipeline_parallel = mpu.get_pipeline_model_parallel_world_size() > 1 + if tensor_rank is None: + tensor_rank = mpu.get_tensor_model_parallel_rank() + if pipeline_rank is None: + pipeline_rank = mpu.get_pipeline_model_parallel_rank() + if cp_rank is None: + cp_rank = mpu.get_context_parallel_rank() + if expert_parallel is None: + expert_parallel = mpu.get_expert_model_parallel_world_size() > 1 + if expert_rank is None: + expert_rank = mpu.get_expert_model_parallel_rank() + + # Use both the tensor and pipeline MP rank. If using the distributed + # optimizer, then the optimizer's path must additionally include the + # data parallel rank. + + # due to the fact that models are identical across cp ranks, cp rank is not used in the checkpoint path + if not pipeline_parallel: + common_path = os.path.join(checkpoints_path, f"mp_rank_{tensor_rank:02d}") + else: + common_path = os.path.join(checkpoints_path, f"mp_rank_{tensor_rank:02d}_{pipeline_rank:03d}") + + if expert_parallel: + common_path = common_path + f"_{expert_rank:03d}" + + os.makedirs(common_path, exist_ok=True) + + if return_base_dir: + return common_path + return os.path.join(common_path, basename) + + def load_optimizer(self, ckpt_path): + # TODO: Check Optimizer format and distributed optimizer + optimizer_path = get_optimizer_checkpoint_path(ckpt_path) + print(f"Loading optimizer from {optimizer_path}") + self.optimizer.load_parameter_state(optimizer_path) + + def load_rng_states(self, ckpt_path, data_parallel_random_init=False, use_dist_ckpt=False): + rng_state_path = get_rng_states_checkpoint_path(ckpt_path, only_rank0_save=False) + print(f"Loading rng states from {rng_state_path}") + rng_state = torch.load(rng_state_path, weights_only=False) + # access rng_state for data parallel rank + if not use_dist_ckpt: + rng_state = rng_state[mpu.get_data_parallel_rank()] if data_parallel_random_init else rng_state[0] + random.setstate(rng_state["random_rng_state"]) + np.random.set_state(rng_state["np_rng_state"]) + torch.set_rng_state(rng_state["torch_rng_state"]) + torch.cuda.set_rng_state(rng_state["cuda_rng_state"]) + # Check for empty states array + if not rng_state["rng_tracker_states"]: + raise KeyError + tensor_parallel.get_cuda_rng_tracker().set_states(rng_state["rng_tracker_states"]) + + def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_after_load=False): + if local_path is None: + return + + if "model" in self.checkpoint_contents: + model_path = get_model_checkpoint_path(local_path) + ckpt_name = self.get_checkpoint_name(model_path, return_base_dir=False) + state_dicts = torch.load(os.path.join(ckpt_name), weights_only=False) + assert len(state_dicts) == len(self.model), f"state_dicts length: {len(state_dicts)} mismatch with model length: {len(self.model)}" + for vpp_rank, (state_dict, model) in enumerate(zip(state_dicts, self.model)): + model.load_state_dict(state_dict) + print(f"Loaded sharded model checkpoint from {model_path}") + + if "optimizer" in self.checkpoint_contents: + self.load_optimizer(local_path) + + if "extra" in self.checkpoint_contents: + self.load_rng_states(local_path) + + if del_local_after_load: + try: + os.remove(local_path) if is_non_local(local_path) else None + except Exception as e: + print(f"[rank-{self.rank}]: remove local resume ckpt file after loading failed, exception {e} will be ignored") + + def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: int = 0, max_ckpt_to_keep=None): + # record the previous global step + self.previous_global_step = global_step + + # remove previous local_path + if max_ckpt_to_keep and isinstance(max_ckpt_to_keep, int) and max_ckpt_to_keep > 0 and len(self.previous_saved_paths) >= max_ckpt_to_keep: + keep_start = len(self.previous_saved_paths) - max_ckpt_to_keep + 1 + self.remove_previous_save_local_path(self.previous_saved_paths[:keep_start]) + self.previous_saved_paths = self.previous_saved_paths[keep_start:] + + local_path = self.local_mkdir(local_path) + + # Save Model + if "model" in self.checkpoint_contents and mpu.get_data_parallel_rank() == 0: + state_dicts = [] + + for vpp_rank, model in enumerate(self.model): + state_dict = model.state_dict() + state_dicts.append(state_dict) + + print(f"Saving sharded model checkpoint to {local_path}") + model_ckpt_path = get_model_checkpoint_path(local_path) + hf_model_ckpt_path = get_hf_model_checkpoint_path(local_path) + ckpt_name = self.get_checkpoint_name(model_ckpt_path, return_base_dir=False) + torch.save(state_dicts, os.path.join(ckpt_name)) + self.processing_class.save_pretrained(hf_model_ckpt_path) # tokenizer will be saved to hf_model_ckpt_path + print(f"Saved checkpoint to {model_ckpt_path}") + if hdfs_path is not None: + print(f"Uploading checkpoint to {hdfs_path}") + from verl.utils import hdfs_io + + hdfs_io.makedirs(hdfs_path, exist_ok=True) + hdfs_io.copy(src=model_ckpt_path, dst=hdfs_path, dirs_exist_ok=True) + + if "hf_model" in self.checkpoint_contents: + # wait for everyone to dump to local + state_dict = self.weight_saver( + self.model, + self.hf_config, + dtype=self.param_dtype, + is_value_model=self.is_value_model, + tie_word_embeddings=self.share_embeddings_and_output_weights, + ) + + torch.distributed.barrier() + print(f"self.param_dtype: {self.param_dtype}") + for key in state_dict.keys(): + print(f"state_dict[key].dtype: {key} {state_dict[key].dtype}") + torch.distributed.barrier() + if self.rank == 0: + hf_model_ckpt_path = get_hf_model_checkpoint_path(local_path) + import warnings + + from accelerate import init_empty_weights + + with init_empty_weights(), warnings.catch_warnings(): + warnings.simplefilter("ignore") + if "mistral7b-rm" in self.config.model.path: + from transformers import MistralForSequenceClassification + + model = MistralForSequenceClassification.from_pretrained(self.config.model.path) # use score head instead of lm_head + state_dict["score.weight"] = state_dict["score.weight"] + else: + from transformers import AutoModelForCausalLM + + model = AutoModelForCausalLM.from_pretrained(self.config.model.path, torch_dtype="auto") + model.save_pretrained(hf_model_ckpt_path, state_dict=state_dict) + if hdfs_path is not None: + print(f"Uploading checkpoint to {hdfs_path}") + from verl.utils import hdfs_io + + hdfs_io.makedirs(hdfs_path, exist_ok=True) + hdfs_io.copy(src=hf_model_ckpt_path, dst=hdfs_path, dirs_exist_ok=True) + + # Save Optimizer + if "optimizer" in self.checkpoint_contents: + torch.distributed.barrier() + + optimizer_path = get_optimizer_checkpoint_path(local_path) + self.optimizer.save_parameter_state(optimizer_path) + if self.rank == 0: + print(f"saving optimizer state to {optimizer_path}") + + # Save RNG States + if "extra" in self.checkpoint_contents: + torch.distributed.barrier() + + rng_state_path = get_rng_states_checkpoint_path(local_path, only_rank0_save=False) + rng_state = self.get_rng_state() + torch.save(rng_state, rng_state_path) + print(f"Rank {self.rank} saving rng states to {rng_state_path}") + + self.previous_saved_paths.append(local_path) diff --git a/verl/utils/config.py b/verl/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..1d27f94b2012db2fc9366e44025b3b547512ef7e --- /dev/null +++ b/verl/utils/config.py @@ -0,0 +1,23 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict + +from omegaconf import DictConfig + + +def update_dict_with_config(dictionary: Dict, config: DictConfig): + for key in dictionary: + if hasattr(config, key): + dictionary[key] = getattr(config, key) diff --git a/verl/utils/dataset/README.md b/verl/utils/dataset/README.md new file mode 100644 index 0000000000000000000000000000000000000000..989ac37e00b6bfd46364c9d9d329d1d104cb4dd0 --- /dev/null +++ b/verl/utils/dataset/README.md @@ -0,0 +1,16 @@ +# Dataset Format +## RLHF dataset +We combine all the data sources into a single parquet files. We directly organize the prompt into the chat format so that multi-turn chats can be easily incorporated. In the prompt, we may add instruction following texts to guide the model output the answers in a particular format so that we can extract the answers. + +Math problems +```json +{ + "data_source": "openai/gsm8k", + "prompt": [{"role": "user", "content": "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? Let's think step by step and output the final answer after \"####\""}], + "ability": "math", + "reward_model": { + "style": "rule", + "ground_truth": ["72"] + }, +} +``` diff --git a/verl/utils/dataset/__init__.py b/verl/utils/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2fbeb5822defc3417594be9b7102b0d44db2ab71 --- /dev/null +++ b/verl/utils/dataset/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .rl_dataset import RLHFDataset +from .rm_dataset import RMDataset +from .sft_dataset import SFTDataset + +__all__ = ["RLHFDataset", "RMDataset", "SFTDataset"] diff --git a/verl/utils/dataset/__pycache__/__init__.cpython-311.pyc b/verl/utils/dataset/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4fd78a1eff1a06c73232953e485d0c2db89d6f59 Binary files /dev/null and b/verl/utils/dataset/__pycache__/__init__.cpython-311.pyc differ diff --git a/verl/utils/dataset/__pycache__/multiturn_sft_dataset.cpython-311.pyc b/verl/utils/dataset/__pycache__/multiturn_sft_dataset.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a3e67040ec8679f79eb8c60d0539a112a8a5ca6 Binary files /dev/null and b/verl/utils/dataset/__pycache__/multiturn_sft_dataset.cpython-311.pyc differ diff --git a/verl/utils/dataset/__pycache__/rl_dataset.cpython-311.pyc b/verl/utils/dataset/__pycache__/rl_dataset.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb6bfa179440925451d12fdb1295a2c909342eb7 Binary files /dev/null and b/verl/utils/dataset/__pycache__/rl_dataset.cpython-311.pyc differ diff --git a/verl/utils/dataset/__pycache__/rm_dataset.cpython-311.pyc b/verl/utils/dataset/__pycache__/rm_dataset.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f8f9efa409d00d012e716f0edb35847d747edc0 Binary files /dev/null and b/verl/utils/dataset/__pycache__/rm_dataset.cpython-311.pyc differ diff --git a/verl/utils/dataset/__pycache__/sft_dataset.cpython-311.pyc b/verl/utils/dataset/__pycache__/sft_dataset.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16b3f52a81cf0fc8190d5c4a2cc91c0304d39449 Binary files /dev/null and b/verl/utils/dataset/__pycache__/sft_dataset.cpython-311.pyc differ diff --git a/verl/utils/dataset/multiturn_sft_dataset.py b/verl/utils/dataset/multiturn_sft_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..444dc8b38f942a83589a71d9c893c55493ad1838 --- /dev/null +++ b/verl/utils/dataset/multiturn_sft_dataset.py @@ -0,0 +1,146 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Multi-turn SFT dataset that supports training on conversation data with multiple turns +""" + +from typing import List, Union + +import pandas as pd +import torch +from torch.utils.data import Dataset +from transformers import PreTrainedTokenizer + +from verl.utils import hf_tokenizer +from verl.utils.fs import copy_local_path_from_hdfs + + +class MultiTurnSFTDataset(Dataset): + """ + Dataset for multi-turn conversations where each assistant response should be trained + """ + + def __init__(self, parquet_files: Union[str, List[str]], tokenizer, config=None): + # Set defaults and extract parameters from config if provided + config = config or {} + self.truncation = config.get("truncation", "error") + self.max_length = config.get("max_length", 1024) + # Get messages_key from the new multiturn config structure + multiturn_config = config.get("multiturn", {}) + self.messages_key = multiturn_config.get("messages_key", "messages") + + assert self.truncation in ["error", "left", "right"] + + if not isinstance(parquet_files, List): + parquet_files = [parquet_files] + + self.parquet_files = parquet_files + if isinstance(tokenizer, str): + tokenizer = hf_tokenizer(tokenizer) + self.tokenizer: PreTrainedTokenizer = tokenizer + + self._download() + self._read_files_and_process() + + def _download(self): + for i, parquet_file in enumerate(self.parquet_files): + self.parquet_files[i] = copy_local_path_from_hdfs(parquet_file, verbose=True) + + def _read_files_and_process(self): + def series_to_item(ls): + import numpy + import pandas + + while isinstance(ls, (pandas.core.series.Series, numpy.ndarray)) and len(ls) == 1: + ls = ls[0] + return ls + + dataframes = [] + for parquet_file in self.parquet_files: + dataframe = pd.read_parquet(parquet_file) + dataframes.append(dataframe) + self.dataframe = pd.concat(dataframes) + + # Extract messages list from dataframe + self.messages = self.dataframe[self.messages_key].apply(series_to_item).tolist() + + def __len__(self): + return len(self.messages) + + def __getitem__(self, item): + tokenizer = self.tokenizer + messages = self.messages[item] + + # First, get the full conversation tokens + full_tokens = tokenizer.apply_chat_template(messages, tokenize=True, return_tensors="pt", add_generation_prompt=False) + input_ids = full_tokens[0] # The output is already a tensor + attention_mask = torch.ones_like(input_ids) + + # Create loss mask by identifying assistant responses + loss_mask = torch.zeros_like(input_ids, dtype=torch.long) + + # Process each message to find assistant responses + for i, msg in enumerate(messages): + # Get tokens for messages up to this point to find the start position + prefix_messages = messages[: i + 1] + prefix_tokens = tokenizer.apply_chat_template(prefix_messages, tokenize=True, return_tensors="pt", add_generation_prompt=False) + + # Get tokens for messages up to previous point + prev_tokens = tokenizer.apply_chat_template(messages[:i], tokenize=True, return_tensors="pt", add_generation_prompt=False) if i > 0 else None + + # Calculate start and end positions + start_pos = prev_tokens[0].shape[0] if prev_tokens is not None else 0 + end_pos = prefix_tokens[0].shape[0] + + # If this is an assistant message, set loss mask + if msg["role"] == "assistant": + loss_mask[start_pos:end_pos] = 1 + + # Handle sequence length + sequence_length = input_ids.shape[0] + if sequence_length < self.max_length: + # Pad sequences + pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else 0 + padded_input_ids = torch.ones(size=(self.max_length - sequence_length,), dtype=input_ids.dtype) * pad_token_id + padded_attention_mask = torch.zeros(size=(self.max_length - sequence_length,), dtype=attention_mask.dtype) + padded_loss_mask = torch.zeros(size=(self.max_length - sequence_length,), dtype=loss_mask.dtype) + + input_ids = torch.cat((input_ids, padded_input_ids)) + attention_mask = torch.cat((attention_mask, padded_attention_mask)) + loss_mask = torch.cat((loss_mask, padded_loss_mask)) + elif sequence_length > self.max_length: + if self.truncation == "left": + input_ids = input_ids[-self.max_length :] + attention_mask = attention_mask[-self.max_length :] + loss_mask = loss_mask[-self.max_length :] + elif self.truncation == "right": + input_ids = input_ids[: self.max_length] + attention_mask = attention_mask[: self.max_length] + loss_mask = loss_mask[: self.max_length] + elif self.truncation == "error": + raise ValueError(f"{sequence_length=} is larger than {self.max_length=}") + else: + raise ValueError(f"Unknown truncation method {self.truncation}") + + # Create position IDs + position_ids = torch.arange(len(input_ids), dtype=torch.long) + # Zero out position IDs for padding + position_ids = position_ids * attention_mask + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "position_ids": position_ids, + "loss_mask": loss_mask, + } diff --git a/verl/utils/dataset/rl_dataset.py b/verl/utils/dataset/rl_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..2d416934ee0ee7fe9a6b428390247e1753219e6a --- /dev/null +++ b/verl/utils/dataset/rl_dataset.py @@ -0,0 +1,264 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import logging +import os +import re +from collections import defaultdict +from typing import List, Optional, Union + +import datasets +import numpy as np +import torch +from omegaconf import DictConfig, ListConfig +from torch.utils.data import Dataset +from transformers import PreTrainedTokenizer, ProcessorMixin + +import verl.utils.torch_functional as verl_F +from verl.utils.model import compute_position_id_with_mask + +logger = logging.getLogger(__name__) + + +def collate_fn(data_list: list[dict]) -> dict: + """Collate a batch of data.""" + tensors = defaultdict(list) + non_tensors = defaultdict(list) + + for data in data_list: + for key, val in data.items(): + if isinstance(val, torch.Tensor): + tensors[key].append(val) + else: + non_tensors[key].append(val) + + for key, val in tensors.items(): + tensors[key] = torch.stack(val, dim=0) + + for key, val in non_tensors.items(): + non_tensors[key] = np.array(val, dtype=object) + + return {**tensors, **non_tensors} + + +class RLHFDataset(Dataset): + """ + We assume the dataset contains a column that contains prompts and other information + """ + + def __init__( + self, + data_files: Union[str, List[str]], + tokenizer: PreTrainedTokenizer, + config: DictConfig, + processor: Optional[ProcessorMixin] = None, + ): + if not isinstance(data_files, (List, ListConfig)): + data_files = [data_files] + + self.data_files = copy.deepcopy(data_files) + self.original_data_files = copy.deepcopy(data_files) # use for resume + self.tokenizer = tokenizer + self.processor = processor + self.config = config + + self.cache_dir = os.path.expanduser(config.get("cache_dir", "~/.cache/verl/rlhf")) + self.prompt_key = config.get("prompt_key", "prompt") + self.image_key = config.get("image_key", "images") + self.video_key = config.get("video_key", "videos") + self.max_prompt_length = config.get("max_prompt_length", 1024) + self.return_raw_chat = config.get("return_raw_chat", False) + self.truncation = config.get("truncation", "error") + self.filter_overlong_prompts = config.get("filter_overlong_prompts", True) + + self.num_workers = config.get("filter_overlong_prompts_workers", max(1, os.cpu_count() // 4)) + self.num_workers = min(self.num_workers, os.cpu_count()) + self.chat_template_func = config.get("chat_template_func", None) + self.need_tools_kwargs = config.get("need_tools_kwargs", False) + self.filter_prompts = config.get("filter_prompts", True) + self.serialize_dataset = False + self._download() + self._read_files_and_tokenize() + + def _download(self, use_origin_parquet=False): + from verl.utils.fs import copy_to_local + + data_files = self.data_files if not use_origin_parquet else self.original_data_files + for i, parquet_file in enumerate(data_files): + self.data_files[i] = copy_to_local(src=parquet_file, cache_dir=self.cache_dir) + + def _read_files_and_tokenize(self): + dataframes = [] + for parquet_file in self.data_files: + # read parquet files and cache + dataframe = datasets.load_dataset("parquet", data_files=parquet_file)["train"] + dataframes.append(dataframe) + self.dataframe: datasets.Dataset = datasets.concatenate_datasets(dataframes) + + print(f"dataset len: {len(self.dataframe)}") + + # filter out too long prompts + if self.filter_overlong_prompts: + tokenizer = self.tokenizer + prompt_key = self.prompt_key + self.dataframe = self.dataframe.filter( + lambda doc: len(tokenizer.apply_chat_template(doc[prompt_key], add_generation_prompt=True)) <= self.max_prompt_length, + num_proc=self.num_workers, + desc=f"Filtering prompts longer than {self.max_prompt_length} tokens", + ) + + print(f"filter dataset len: {len(self.dataframe)}") + + def resume_dataset_state(self): + self.serialize_dataset = not hasattr(self, "original_data_files") + # resume dataframe if not it's serialized in data.pt + if not self.serialize_dataset: + self._download(use_origin_parquet=True) # download and resume from original parquet files + self._read_files_and_tokenize() + else: + print(r"old dataloader ckpt file is used, please train from scratch for better ckpt performance") + + def __len__(self): + return len(self.dataframe) + + def _build_messages(self, example: dict): + messages: list = example.pop(self.prompt_key) + + if self.image_key in example or self.video_key in example: + for message in messages: + content = message["content"] + content_list = [] + for segment in re.split("(|