text2text / _infer.py
braindeck
Initial commit
bcdf9fa
#!/usr/bin/env python
# Copyright 2024 Bytedance
# Apache-2.0
#
# VERL + vLLM inference with runtime LoRA (no merge).
# - Wraps a LoRA .pt into a PEFT adapter and attaches via rollout.lora_modules
# - Mixed precision defaults for H100: dtype=bf16, kv_cache_dtype=fp8_e5m2
# - Pins max_model_len, max_num_batched_tokens, sets swap_space
# - Uses OmegaConf.open_dict to add keys safely (no "not in struct" errors)
# - Prevents FSDP from trying to load LoRA .pt as a full model
import os
import ast
import json
import hydra
import numpy as np
import ray
import torch
from pathlib import Path
from pprint import pprint
# Keep logging quiet by default, but never override a value the caller already exported.
os.environ.setdefault("NCCL_DEBUG", "WARN")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "true")

# vLLM CuMem allocator is incompatible with expandable_segments, so drop the
# whole allocator setting when it requests that mode.
_bad = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "")
if "expandable_segments:True" in _bad:
    print(f"[fix] Removing incompatible PYTORCH_CUDA_ALLOC_CONF={_bad}")
    os.environ.pop("PYTORCH_CUDA_ALLOC_CONF", None)
import pandas as pd
from omegaconf import OmegaConf, open_dict
from verl import DataProto
from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
from verl.utils import hf_tokenizer
from verl.utils.fs import copy_to_local
from verl.utils.hdfs_io import makedirs
from verl.utils.model import compute_position_id_with_mask
from verl.workers.fsdp_workers import ActorRolloutRefWorker
# ---------------- LoRA helpers ----------------
# Module names used for the adapter config when none can be detected from the
# LoRA state dict keys.
DEFAULT_TARGET_MODULES = [
    # attention projections
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    # MLP projections
    "up_proj",
    "gate_proj",
    "down_proj",
]
def _infer_lengths_and_defaults(config):
    """Ensure rollout/data/trainer/ray_init keys exist and fill in defaults.

    The config is mutated in place; values the user already set are never
    overwritten (every write goes through ``setdefault``).

    Args:
        config: OmegaConf ``DictConfig`` for the inference run. Missing
            top-level sections are created as empty configs.

    Returns:
        None.
    """
    # Ensure nested structs exist. open_dict lets us add keys even when the
    # config is in struct mode (no "not in struct" errors).
    with open_dict(config):
        for section in ("rollout", "data", "trainer", "ray_init"):
            if section not in config:
                config[section] = OmegaConf.create()

    # Defaults that work on a single H100
    with open_dict(config.rollout):
        rollout = config.rollout
        rollout.setdefault("dtype", "bfloat16")  # weights/activations
        rollout.setdefault("kv_cache_dtype", "fp8_e5m2")  # KV cache precision
        rollout.setdefault("tensor_model_parallel_size", 1)
        rollout.setdefault("enable_chunked_prefill", True)
        rollout.setdefault("swap_space", 8)  # GB of host swap for KV
        rollout.setdefault("gpu_memory_utilization", 0.62)  # adjust 0.60~0.75 if needed

        # Persist the length fallbacks instead of only using them locally:
        # later code reads config.rollout.prompt_length directly, which would
        # fail if the key had been left out of the user's config.
        prompt_len = int(rollout.setdefault("prompt_length", 1024))
        response_len = int(rollout.setdefault("response_length", 128))

        # Pin lengths to avoid vLLM over-reserving KV cache
        need = prompt_len + response_len
        rollout.setdefault("max_model_len", need)
        rollout.setdefault("max_num_batched_tokens", need)
        # Users may pass +rollout.quantization={fp8|awq|gptq} to shrink weights further
        # We don't force it here.

    with open_dict(config.data):
        config.data.setdefault("batch_size", 1)
        config.data.setdefault("n_samples", 1)
        config.data.setdefault("prompt_key", "prompt")

    with open_dict(config.trainer):
        config.trainer.setdefault("n_gpus_per_node", 1)
        config.trainer.setdefault("nnodes", 1)

    with open_dict(config.ray_init):
        config.ray_init.setdefault("num_cpus", 4)
def _infer_lora_rank_from_state(sd):
    """Return the LoRA rank found in state dict ``sd``, or ``None``.

    The rank is taken as the first dimension of the first entry whose key ends
    with ``lora_A.weight`` and whose value exposes ``dim()`` and reports two
    dimensions. Entries are visited in the mapping's iteration order.
    """
    candidate_ranks = (
        int(tensor.shape[0])
        for name, tensor in sd.items()
        if name.endswith("lora_A.weight")
        and hasattr(tensor, "dim")
        and tensor.dim() == 2
    )
    # The generator is lazy, so only entries up to the first match are inspected.
    return next(candidate_ranks, None)
def _list_target_modules_from_state(sd):
    """Return the sorted projection-module names that carry LoRA weights in ``sd``.

    Only keys containing ``lora_A.weight`` or ``lora_B.weight`` are considered,
    and a module counts as present when its name appears in the key as a full
    dotted component (``.q_proj.`` and so on).
    """
    known_modules = (
        "q_proj", "k_proj", "v_proj", "o_proj",
        "up_proj", "gate_proj", "down_proj",
    )
    detected = set()
    for key in sd.keys():
        if "lora_A.weight" not in key and "lora_B.weight" not in key:
            continue  # not a LoRA weight entry
        detected.update(name for name in known_modules if f".{name}." in key)
    return sorted(detected)
def _write_adapter_config(adapter_dir, r, alpha, target_modules, dropout=0.0):
    """Write ``adapter_config.json`` describing a LoRA adapter into ``adapter_dir``.

    Args:
        adapter_dir: Existing directory that receives the JSON file.
        r: LoRA rank, stored as an int.
        alpha: LoRA alpha, stored as an int.
        target_modules: Module names stored verbatim under ``target_modules``.
        dropout: LoRA dropout, stored as a float.

    Returns:
        None.
    """
    payload = dict(
        peft_type="LORA",
        auto_mapping=None,
        base_model_name_or_path="",
        bias="none",
        inference_mode=True,
        lora_alpha=int(alpha),
        lora_dropout=float(dropout),
        r=int(r),
        target_modules=target_modules,
        task_type="CAUSAL_LM",
    )
    config_path = os.path.join(adapter_dir, "adapter_config.json")
    with open(config_path, "w", encoding="utf-8") as handle:
        handle.write(json.dumps(payload, ensure_ascii=False, indent=2))
def _wrap_lora_pt_as_peft(adapter_pt_path: str, out_dir: str,
                          fallback_rank=32, fallback_alpha=16):
    """Repackage a LoRA ``.pt`` checkpoint as an adapter directory.

    Writes ``adapter_config.json`` and ``adapter_model.bin`` into ``out_dir``
    (created if missing).

    Args:
        adapter_pt_path: Path of the checkpoint handed to ``torch.load``.
        out_dir: Destination directory for the adapter files.
        fallback_rank: Rank used when none can be inferred from the weights.
        fallback_alpha: Value written as ``lora_alpha``; it is always used,
            since alpha is not inferred from the checkpoint.

    Returns:
        Tuple ``(rank, target_modules)`` as written to the adapter config.
    """
    os.makedirs(out_dir, exist_ok=True)

    print(f"[lora] Loading LoRA state from: {adapter_pt_path}")
    # NOTE(review): torch.load unpickles the file; only use checkpoints from a trusted source.
    state = torch.load(adapter_pt_path, map_location="cpu")
    # Some checkpoints nest the weights under a "state_dict" entry.
    is_wrapped = isinstance(state, dict) and "state_dict" in state
    if is_wrapped:
        state = state["state_dict"]

    rank = _infer_lora_rank_from_state(state)
    if not rank:
        rank = int(fallback_rank)
    modules = _list_target_modules_from_state(state)
    if not modules:
        modules = DEFAULT_TARGET_MODULES
    print(f"[lora] inferred rank={rank}, target_modules={modules}")

    _write_adapter_config(out_dir, r=rank, alpha=fallback_alpha, target_modules=modules)
    weights_path = os.path.join(out_dir, "adapter_model.bin")
    torch.save(state, weights_path)
    return rank, modules
def _maybe_attach_lora_adapter(config):
    """Attach LoRA adapter directory to vLLM rollout (runtime LoRA).

    The checkpoint path is taken from ``config.lora.pt_path`` when set,
    otherwise from ``config.model.load_param_path``. If no path is configured,
    or the path is not an existing file, the config is left untouched and the
    run continues with the base model only; the two cases are logged
    separately so a mistyped path is not reported as "not provided".

    When a checkpoint is found it is wrapped into ``/tmp/lora_adapter_vllm``
    and the config is mutated in place: ``rollout.max_loras`` and
    ``rollout.max_lora_rank`` are defaulted, ``rollout.lora_modules`` is set,
    and ``model.load_param`` is forced to ``False``.

    Args:
        config: OmegaConf ``DictConfig`` for the run; must contain ``model``.

    Returns:
        None.
    """
    # Accept either +lora.pt_path or model.load_param_path as a hint
    lora_pt = None
    if "lora" in config and getattr(config.lora, "pt_path", ""):
        lora_pt = config.lora.pt_path
    elif getattr(config.model, "load_param_path", ""):
        lora_pt = config.model.load_param_path

    if not lora_pt:
        print("[lora] No LoRA .pt provided; running base model only.")
        return
    if not Path(lora_pt).is_file():
        # A path was configured but cannot be used: name it so the user can
        # spot a typo instead of silently evaluating the base model.
        print(f"[lora] LoRA path is not a file: {lora_pt!r}; running base model only.")
        return

    adapter_dir = os.path.join("/tmp", "lora_adapter_vllm")
    r, _ = _wrap_lora_pt_as_peft(lora_pt, adapter_dir, fallback_rank=32, fallback_alpha=16)

    # Ensure rollout keys exist and add LoRA knobs required by vLLM
    with open_dict(config):
        if "rollout" not in config:
            config["rollout"] = OmegaConf.create()
    with open_dict(config.rollout):
        config.rollout.setdefault("max_loras", 1)
        config.rollout.setdefault("max_lora_rank", int(r))
        config.rollout["lora_modules"] = [{"path": adapter_dir, "scale": 1.0}]
    print(f"[lora] Attached PEFT adapter: {adapter_dir} (rank={r})")

    # CRITICAL: don't let FSDP try to load the LoRA .pt as a full state dict
    with open_dict(config.model):
        if getattr(config.model, "load_param", False):
            print("[lora] Disabling model.load_param to avoid FSDP load_state_dict mismatch.")
        # Written unconditionally; open_dict allows adding the key when absent.
        config.model["load_param"] = False
# ---------------- Hydra entry ----------------
@hydra.main(config_path="config", config_name="infer", version_base=None)
def main(config):
    """Hydra entry point: fill config defaults, start Ray if needed, run ``main_task``.

    Blocks until the remote ``main_task`` finishes.
    """
    _infer_lengths_and_defaults(config)

    # Ray env for workers
    if not ray.is_initialized():
        worker_env = {
            "TOKENIZERS_PARALLELISM": "true",
            "NCCL_DEBUG": "WARN",
            "PYTORCH_CUDA_ALLOC_CONF": "",  # keep allocator happy for vLLM
        }
        ray.init(
            runtime_env={"env_vars": worker_env},
            num_cpus=config.ray_init.num_cpus,
        )

    pending = main_task.remote(config)
    ray.get(pending)
def _decode_rollout_batch(tokenizer, output):
    """Decode each generated sequence in ``output`` into a Python value.

    For every item the response tokens are decoded, anything up to and
    including the first ``</think>`` marker is dropped, and the remainder is
    parsed with ``ast.literal_eval``; text that does not parse is kept as a
    plain string.

    Returns:
        List with one entry per item of ``output``, in order.
    """
    texts = []
    for i in range(len(output)):
        item = output[i]
        prompt_len = item.batch["prompts"].shape[-1]
        # Presumably attention_mask spans prompt + response, so the tail sum is
        # the number of real response tokens — TODO confirm against the rollout worker.
        valid_len = item.batch["attention_mask"][prompt_len:].sum()
        resp_ids = item.batch["responses"][:valid_len]
        s = tokenizer.decode(resp_ids, skip_special_tokens=True)
        print(f"[raw] Response {i}: {s!r}")
        ix = s.find("</think>")
        if ix != -1:
            s = s[ix + len("</think>"):].lstrip()
        print(f"Response {i}: {s!r}")
        try:
            texts.append(ast.literal_eval(s))
        except Exception:
            # Deliberate best-effort: model output is free-form text, so any
            # parse failure falls back to the raw string.
            texts.append(s)
    return texts


@ray.remote(num_cpus=1)
def main_task(config):
    """Run batched generation over a parquet dataset and write JSON-lines output.

    Steps: resolve the config, attach a LoRA adapter if one is configured,
    load the tokenizer and dataset, start the rollout worker group, generate
    ``data.n_samples`` responses per prompt in batches of ``data.batch_size``,
    and write the result to ``data.output_path``.

    Args:
        config: OmegaConf ``DictConfig`` with ``model``, ``rollout``, ``data``
            and ``trainer`` sections.

    Raises:
        AssertionError: If ``n_samples`` is below 1, or differs from 1 while
            ``rollout.temperature`` is 0.
        ValueError: If ``data.batch_size`` is below 1.
        KeyError: If the dataset lacks the configured prompt column.
    """
    print("[worker] PYTORCH_CUDA_ALLOC_CONF =", os.environ.get("PYTORCH_CUDA_ALLOC_CONF"))
    pprint(OmegaConf.to_container(config, resolve=True))
    OmegaConf.resolve(config)

    # Build LoRA adapter if provided
    _maybe_attach_lora_adapter(config)

    # Optionally pre-gen dataset schema if your repo provides it
    try:
        from prompts.infer_prompt import infer_dataset
        infer_dataset(
            model_name=config.model.path,
            data_path=os.path.dirname(os.path.dirname(config.data.path)),
        )
    except Exception as e:
        print(f"[info] infer_dataset() skipped: {e}")

    # ---- Tokenizer from base model
    local_path = copy_to_local(config.model.path)
    trust_remote_code = getattr(config.model, "trust_remote_code", False)
    tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # ---- Sampling checks
    n_samples = int(config.data.n_samples)
    if float(config.rollout.temperature) == 0.0:
        assert n_samples == 1, "When temperature=0, n_samples must be 1."
    assert n_samples >= 1, "n_samples should always >= 1"
    bs = int(config.data.batch_size)
    if bs < 1:
        # Fail before spinning up workers instead of dividing by zero later.
        raise ValueError(f"data.batch_size must be >= 1, got {bs}")

    # ---- Load dataset
    dataset = pd.read_parquet(config.data.path)
    prompt_key = getattr(config.data, "prompt_key", "prompt")
    if prompt_key not in dataset.columns:
        raise KeyError(f"Dataset missing column '{prompt_key}'")
    chat_lst = dataset[prompt_key].tolist()
    # Cells may come back as array-like objects; convert those to plain lists.
    chat_lst = [chat.tolist() if hasattr(chat, "tolist") else chat for chat in chat_lst]

    # ---- Worker group (vLLM inside Rollout)
    ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorRolloutRefWorker), config=config, role="rollout")
    resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes)
    print("[debug] rollout.lora_modules =", config.rollout.get("lora_modules", None))
    wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init)
    wg.init_model()  # vLLM spins up; adapter used if set in rollout.lora_modules

    total = len(dataset)
    num_batch = -(-total // bs)  # ceiling division
    prompt_length = int(config.rollout.prompt_length)
    # slots[n] collects the n-th sample for every prompt, in dataset order.
    slots = [[] for _ in range(n_samples)]

    for b in range(num_batch):
        print(f"[{b+1}/{num_batch}] Start to process.")
        batch_chat = chat_lst[b * bs : (b + 1) * bs]
        inputs = tokenizer.apply_chat_template(
            batch_chat,
            add_generation_prompt=True,
            padding=True,
            truncation=True,
            max_length=prompt_length,
            return_tensors="pt",
            return_dict=True,
            tokenize=True,
        )
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        position_ids = compute_position_id_with_mask(attention_mask)
        batch_dict = {"input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids}
        data = DataProto.from_dict(batch_dict)
        # Pad the batch to a multiple of the worker count; the padding is removed after generation.
        data_padded, pad_size = pad_dataproto_to_divisor(data, wg.world_size)

        print(f"[{b+1}/{num_batch}] Start to generate.")
        for n in range(n_samples):
            output_padded = wg.generate_sequences(data_padded)
            output = unpad_dataproto(output_padded, pad_size=pad_size)
            slots[n].extend(_decode_rollout_batch(tokenizer, output))

    # Transpose (n_samples, total) -> (total, n_samples) with zip rather than a
    # NumPy object array: NumPy would add an extra axis when every parsed
    # response is an equal-length list/tuple and the 2-D transpose would fail.
    outputs = [list(per_prompt) for per_prompt in zip(*slots)]
    dataset["response"] = outputs

    keep = ["file_id", "vt", "gt", "response"]
    cols = [c for c in keep if c in dataset.columns]
    if cols:
        dataset = dataset[cols]

    out_path = config.data.output_path
    out_dir = os.path.dirname(out_path)
    if out_dir:
        # A bare filename has no directory part; there is nothing to create then.
        makedirs(out_dir, exist_ok=True)
    dataset.to_json(out_path, orient="records", lines=True, force_ascii=False)
    print(f"[done] Wrote: {out_path}")
# Script entry point: invoke the Hydra-decorated main() only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()