#!/usr/bin/env python
# Copyright 2024 Bytedance
# Apache-2.0
#
# VERL + vLLM inference with runtime LoRA (no merge).
# - Wraps a LoRA .pt into a PEFT adapter and attaches via rollout.lora_modules
# - Mixed precision defaults for H100: dtype=bf16, kv_cache_dtype=fp8_e5m2
# - Pins max_model_len, max_num_batched_tokens, sets swap_space
# - Uses OmegaConf.open_dict to add keys safely (no "not in struct" errors)
# - Prevents FSDP from trying to load LoRA .pt as a full model

import os
import ast
import json
import hydra
import numpy as np
import ray
import torch
from pathlib import Path
from pprint import pprint

# Quiet logs
os.environ["NCCL_DEBUG"] = os.environ.get("NCCL_DEBUG", "WARN")
os.environ["TOKENIZERS_PARALLELISM"] = os.environ.get("TOKENIZERS_PARALLELISM", "true")

# vLLM CuMem allocator is incompatible with expandable_segments
_bad = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "")
if "expandable_segments:True" in _bad:
    print(f"[fix] Removing incompatible PYTORCH_CUDA_ALLOC_CONF={_bad}")
    os.environ.pop("PYTORCH_CUDA_ALLOC_CONF", None)

import pandas as pd
from omegaconf import OmegaConf, open_dict

from verl import DataProto
from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
from verl.utils import hf_tokenizer
from verl.utils.fs import copy_to_local
from verl.utils.hdfs_io import makedirs
from verl.utils.model import compute_position_id_with_mask
from verl.workers.fsdp_workers import ActorRolloutRefWorker

# ---------------- LoRA helpers ----------------

DEFAULT_TARGET_MODULES = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "up_proj", "gate_proj", "down_proj",
]


def _infer_lengths_and_defaults(config):
    """Ensure rollout/data keys exist and set reasonable H100 defaults."""
    # Ensure nested structs exist
    with open_dict(config):
        if "rollout" not in config:
            config["rollout"] = OmegaConf.create()
        if "data" not in config:
            config["data"] = OmegaConf.create()
        if "trainer" not in config:
            config["trainer"] = OmegaConf.create()
        if "ray_init" not in config:
            config["ray_init"] = OmegaConf.create()

    # Defaults that work on a single H100
    with open_dict(config.rollout):
        # If user didn't set these, choose H100-friendly defaults
        config.rollout.setdefault("dtype", "bfloat16")               # weights/activations
        config.rollout.setdefault("kv_cache_dtype", "fp8_e5m2")      # KV cache precision
        config.rollout.setdefault("tensor_model_parallel_size", 1)
        config.rollout.setdefault("enable_chunked_prefill", True)
        config.rollout.setdefault("swap_space", 8)                   # GB of host swap for KV
        config.rollout.setdefault("gpu_memory_utilization", 0.62)    # adjust 0.60~0.75 if needed

        # Pin lengths to avoid vLLM over-reserving KV cache
        pl = int(config.rollout.get("prompt_length", 1024))
        rl = int(config.rollout.get("response_length", 128))
        need = int(pl + rl)
        config.rollout.setdefault("max_model_len", need)
        config.rollout.setdefault("max_num_batched_tokens", need)

        # Users may pass +rollout.quantization={fp8|awq|gptq} to shrink weights further.
        # We don't force it here.
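        # With the defaults above this works out to 1024 + 128 = 1152 tokens, so
        # per-sequence KV-cache demand is capped at 1152 positions rather than the
        # model's full context window.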
    with open_dict(config.data):
        config.data.setdefault("batch_size", 1)
        config.data.setdefault("n_samples", 1)
        config.data.setdefault("prompt_key", "prompt")

    with open_dict(config.trainer):
        config.trainer.setdefault("n_gpus_per_node", 1)
        config.trainer.setdefault("nnodes", 1)

    with open_dict(config.ray_init):
        config.ray_init.setdefault("num_cpus", 4)


def _infer_lora_rank_from_state(sd):
    for k, v in sd.items():
        if k.endswith("lora_A.weight") and hasattr(v, "dim") and v.dim() == 2:
            return int(v.shape[0])
    return None


def _list_target_modules_from_state(sd):
    found = set()
    for k in sd.keys():
        if "lora_A.weight" in k or "lora_B.weight" in k:
            if ".q_proj." in k:
                found.add("q_proj")
            if ".k_proj." in k:
                found.add("k_proj")
            if ".v_proj." in k:
                found.add("v_proj")
            if ".o_proj." in k:
                found.add("o_proj")
            if ".up_proj." in k:
                found.add("up_proj")
            if ".gate_proj." in k:
                found.add("gate_proj")
            if ".down_proj." in k:
                found.add("down_proj")
    return sorted(found)


def _write_adapter_config(adapter_dir, r, alpha, target_modules, dropout=0.0):
    cfg = {
        "peft_type": "LORA",
        "auto_mapping": None,
        "base_model_name_or_path": "",
        "bias": "none",
        "inference_mode": True,
        "lora_alpha": int(alpha),
        "lora_dropout": float(dropout),
        "r": int(r),
        "target_modules": target_modules,
        "task_type": "CAUSAL_LM",
    }
    with open(os.path.join(adapter_dir, "adapter_config.json"), "w", encoding="utf-8") as f:
        json.dump(cfg, f, ensure_ascii=False, indent=2)


def _wrap_lora_pt_as_peft(adapter_pt_path: str, out_dir: str, fallback_rank=32, fallback_alpha=16):
    os.makedirs(out_dir, exist_ok=True)
    print(f"[lora] Loading LoRA state from: {adapter_pt_path}")
    sd = torch.load(adapter_pt_path, map_location="cpu")
    if isinstance(sd, dict) and "state_dict" in sd:
        sd = sd["state_dict"]
    r = _infer_lora_rank_from_state(sd) or int(fallback_rank)
    tmods = _list_target_modules_from_state(sd) or DEFAULT_TARGET_MODULES
    print(f"[lora] inferred rank={r}, target_modules={tmods}")
    _write_adapter_config(out_dir, r=r, alpha=fallback_alpha, target_modules=tmods)
    torch.save(sd, os.path.join(out_dir, "adapter_model.bin"))
    return r, tmods


def _maybe_attach_lora_adapter(config):
    """Attach LoRA adapter directory to vLLM rollout (runtime LoRA)."""
    # Accept either +lora.pt_path or model.load_param_path as a hint
    lora_pt = None
    if "lora" in config and getattr(config.lora, "pt_path", ""):
        lora_pt = config.lora.pt_path
    elif getattr(config.model, "load_param_path", ""):
        lora_pt = config.model.load_param_path

    if not lora_pt or not Path(lora_pt).is_file():
        print("[lora] No LoRA .pt provided; running base model only.")
        return

    adapter_dir = os.path.join("/tmp", "lora_adapter_vllm")
    r, _ = _wrap_lora_pt_as_peft(lora_pt, adapter_dir, fallback_rank=32, fallback_alpha=16)

    # Ensure rollout keys exist and add LoRA knobs required by vLLM
    with open_dict(config):
        if "rollout" not in config:
            config["rollout"] = OmegaConf.create()
    with open_dict(config.rollout):
        config.rollout.setdefault("max_loras", 1)
        config.rollout.setdefault("max_lora_rank", int(r))
        config.rollout["lora_modules"] = [{"path": adapter_dir, "scale": 1.0}]
    print(f"[lora] Attached PEFT adapter: {adapter_dir} (rank={r})")

    # CRITICAL: don't let FSDP try to load the LoRA .pt as a full state dict
    with open_dict(config.model):
        if getattr(config.model, "load_param", False):
            print("[lora] Disabling model.load_param to avoid FSDP load_state_dict mismatch.")
            config.model["load_param"] = False


# ---------------- Hydra entry ----------------

@hydra.main(config_path="config", config_name="infer", version_base=None)
def main(config):
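    """Hydra driver: fill in defaults, start Ray (if needed), and run inference in a remote task."""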
    _infer_lengths_and_defaults(config)

    # Ray env for workers
    if not ray.is_initialized():
        ray.init(
            runtime_env={"env_vars": {
                "TOKENIZERS_PARALLELISM": "true",
                "NCCL_DEBUG": "WARN",
                "PYTORCH_CUDA_ALLOC_CONF": "",  # keep allocator happy for vLLM
            }},
            num_cpus=config.ray_init.num_cpus,
        )

    ray.get(main_task.remote(config))


@ray.remote(num_cpus=1)
def main_task(config):
    print("[worker] PYTORCH_CUDA_ALLOC_CONF =", os.environ.get("PYTORCH_CUDA_ALLOC_CONF"))
    pprint(OmegaConf.to_container(config, resolve=True))
    OmegaConf.resolve(config)

    # Build LoRA adapter if provided
    _maybe_attach_lora_adapter(config)

    # Optionally pre-generate the dataset schema if your repo provides it
    try:
        from prompts.infer_prompt import infer_dataset

        infer_dataset(
            model_name=config.model.path,
            data_path=os.path.dirname(os.path.dirname(config.data.path)),
        )
    except Exception as e:
        print(f"[info] infer_dataset() skipped: {e}")

    # ---- Tokenizer from base model
    local_path = copy_to_local(config.model.path)
    trust_remote_code = getattr(config.model, "trust_remote_code", False)
    tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # ---- Sampling checks
    if float(config.rollout.temperature) == 0.0:
        assert int(config.data.n_samples) == 1, "When temperature=0, n_samples must be 1."
    assert int(config.data.n_samples) >= 1, "n_samples must always be >= 1"

    # ---- Load dataset
    dataset = pd.read_parquet(config.data.path)
    prompt_key = getattr(config.data, "prompt_key", "prompt")
    if prompt_key not in dataset.columns:
        raise KeyError(f"Dataset missing column '{prompt_key}'")
    chat_lst = dataset[prompt_key].tolist()
    chat_lst = [chat.tolist() if hasattr(chat, "tolist") else chat for chat in chat_lst]

    # ---- Worker group (vLLM inside Rollout)
    ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorRolloutRefWorker), config=config, role="rollout")
    resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes)
    print("[debug] rollout.lora_modules =", config.rollout.get("lora_modules", None))
    wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init)
    wg.init_model()  # vLLM spins up; adapter is used if set in rollout.lora_modules

    total = len(dataset)
    bs = int(config.data.batch_size)
    num_batch = -(-total // bs)  # ceil(total / bs)
    slots = [[] for _ in range(int(config.data.n_samples))]

    for b in range(num_batch):
        print(f"[{b + 1}/{num_batch}] Start to process.")
        batch_chat = chat_lst[b * bs : (b + 1) * bs]
        inputs = tokenizer.apply_chat_template(
            batch_chat,
            add_generation_prompt=True,
            padding=True,
            truncation=True,
            max_length=int(config.rollout.prompt_length),
            return_tensors="pt",
            return_dict=True,
            tokenize=True,
        )
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        position_ids = compute_position_id_with_mask(attention_mask)
        batch_dict = {"input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids}
        data = DataProto.from_dict(batch_dict)
        data_padded, pad_size = pad_dataproto_to_divisor(data, wg.world_size)

        print(f"[{b + 1}/{num_batch}] Start to generate.")
        for n in range(int(config.data.n_samples)):
            output_padded = wg.generate_sequences(data_padded)
            output = unpad_dataproto(output_padded, pad_size=pad_size)
            texts = []
            for i in range(len(output)):
                item = output[i]
                pl = item.batch["prompts"].shape[-1]
                valid_len = item.batch["attention_mask"][pl:].sum()
                resp_ids = item.batch["responses"][:valid_len]
                s = tokenizer.decode(resp_ids, skip_special_tokens=True)
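                # Responses sometimes render a Python literal (e.g. a list); the block
                # below tries ast.literal_eval and keeps the raw string if parsing fails.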
                print(f"[raw] Response {i}: {s!r}")
                ix = s.find("")
                if ix != -1:
                    s = s[ix + len("") :].lstrip()
                print(f"Response {i}: {s!r}")
                try:
                    texts.append(ast.literal_eval(s))
                except Exception:
                    texts.append(s)
            slots[n].extend(texts)

    outputs = np.array(slots, dtype=object)
    outputs = np.transpose(outputs, (1, 0)).tolist()
    dataset["response"] = outputs

    keep = ["file_id", "vt", "gt", "response"]
    cols = [c for c in keep if c in dataset.columns]
    if cols:
        dataset = dataset[cols]

    out_path = config.data.output_path
    makedirs(os.path.dirname(out_path), exist_ok=True)
    dataset.to_json(out_path, orient="records", lines=True, force_ascii=False)
    print(f"[done] Wrote: {out_path}")


if __name__ == "__main__":
    main()
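# A minimal launch sketch (hypothetical script name and model/data paths; the Hydra
# config "config/infer.yaml" referenced above must supply the model, data, rollout,
# and trainer sections this script reads):
#
#   python infer_lora_vllm.py \
#       model.path=Qwen/Qwen2.5-7B-Instruct \
#       data.path=/data/prompts.parquet \
#       data.output_path=/data/out/responses.jsonl \
#       +lora.pt_path=/ckpts/lora_weights.pt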