import os
import ast
import json
from pathlib import Path
from pprint import pprint

import hydra
import numpy as np
import ray
import torch

# Set conservative defaults for NCCL logging and tokenizer parallelism without
# overriding values already provided by the environment.
os.environ["NCCL_DEBUG"] = os.environ.get("NCCL_DEBUG", "WARN")
os.environ["TOKENIZERS_PARALLELISM"] = os.environ.get("TOKENIZERS_PARALLELISM", "true")

# Drop PYTORCH_CUDA_ALLOC_CONF if it requests expandable segments, which is
# incompatible with this rollout setup.
_bad = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "")
if "expandable_segments:True" in _bad:
    print(f"[fix] Removing incompatible PYTORCH_CUDA_ALLOC_CONF={_bad}")
    os.environ.pop("PYTORCH_CUDA_ALLOC_CONF", None)

import pandas as pd
from omegaconf import OmegaConf, open_dict

from verl import DataProto
from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
from verl.utils import hf_tokenizer
from verl.utils.fs import copy_to_local
from verl.utils.hdfs_io import makedirs
from verl.utils.model import compute_position_id_with_mask
from verl.workers.fsdp_workers import ActorRolloutRefWorker

# Attention and MLP projection layers targeted by LoRA in LLaMA-style models.
DEFAULT_TARGET_MODULES = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "up_proj", "gate_proj", "down_proj",
]


def _infer_lengths_and_defaults(config):
    """Ensure rollout/data/trainer/ray_init config groups exist and set reasonable H100 inference defaults."""
    with open_dict(config):
        if "rollout" not in config:
            config["rollout"] = OmegaConf.create()
        if "data" not in config:
            config["data"] = OmegaConf.create()
        if "trainer" not in config:
            config["trainer"] = OmegaConf.create()
        if "ray_init" not in config:
            config["ray_init"] = OmegaConf.create()

    with open_dict(config.rollout):
        # Inference-friendly vLLM defaults; every value can be overridden by the config.
        config.rollout.setdefault("dtype", "bfloat16")
        config.rollout.setdefault("kv_cache_dtype", "fp8_e5m2")
        config.rollout.setdefault("tensor_model_parallel_size", 1)
        config.rollout.setdefault("enable_chunked_prefill", True)
        config.rollout.setdefault("swap_space", 8)
        config.rollout.setdefault("gpu_memory_utilization", 0.62)

        # Size the model context and prefill budget from prompt + response lengths.
        pl = int(config.rollout.get("prompt_length", 1024))
        rl = int(config.rollout.get("response_length", 128))
        need = pl + rl
        config.rollout.setdefault("max_model_len", need)
        config.rollout.setdefault("max_num_batched_tokens", need)

    with open_dict(config.data):
        config.data.setdefault("batch_size", 1)
        config.data.setdefault("n_samples", 1)
        config.data.setdefault("prompt_key", "prompt")

    with open_dict(config.trainer):
        config.trainer.setdefault("n_gpus_per_node", 1)
        config.trainer.setdefault("nnodes", 1)

    with open_dict(config.ray_init):
        config.ray_init.setdefault("num_cpus", 4)
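
# Every default set above is a plain config value, so it can be overridden from
# the Hydra command line. The script name and values here are illustrative only:
#
#   python infer.py rollout.gpu_memory_utilization=0.8 rollout.response_length=256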


def _infer_lora_rank_from_state(sd):
    """Infer the LoRA rank from the first 2-D lora_A weight in the state dict."""
    for k, v in sd.items():
        if k.endswith("lora_A.weight") and hasattr(v, "dim") and v.dim() == 2:
            return int(v.shape[0])
    return None


def _list_target_modules_from_state(sd):
    """Collect the projection modules that actually carry LoRA weights in the state dict."""
    found = set()
    for k in sd.keys():
        if "lora_A.weight" in k or "lora_B.weight" in k:
            for mod in DEFAULT_TARGET_MODULES:
                if f".{mod}." in k:
                    found.add(mod)
    return sorted(found)


def _write_adapter_config(adapter_dir, r, alpha, target_modules, dropout=0.0):
    """Write a minimal PEFT adapter_config.json so the adapter directory can be loaded at runtime."""
    cfg = {
        "peft_type": "LORA",
        "auto_mapping": None,
        "base_model_name_or_path": "",
        "bias": "none",
        "inference_mode": True,
        "lora_alpha": int(alpha),
        "lora_dropout": float(dropout),
        "r": int(r),
        "target_modules": target_modules,
        "task_type": "CAUSAL_LM",
    }
    with open(os.path.join(adapter_dir, "adapter_config.json"), "w", encoding="utf-8") as f:
        json.dump(cfg, f, ensure_ascii=False, indent=2)
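
# For illustration only: adapter_config.json produced with the fallback settings
# used below (rank 32, alpha 16). Real values are inferred from the checkpoint
# when possible, and a few bookkeeping keys are omitted here:
#
#   {
#     "peft_type": "LORA",
#     "r": 32,
#     "lora_alpha": 16,
#     "lora_dropout": 0.0,
#     "target_modules": ["q_proj", "k_proj", "..."],
#     "task_type": "CAUSAL_LM"
#   }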


def _wrap_lora_pt_as_peft(adapter_pt_path: str, out_dir: str,
                          fallback_rank=32, fallback_alpha=16):
    """Convert a raw LoRA .pt state dict into a PEFT-style adapter directory."""
    os.makedirs(out_dir, exist_ok=True)
    print(f"[lora] Loading LoRA state from: {adapter_pt_path}")
    sd = torch.load(adapter_pt_path, map_location="cpu")
    if isinstance(sd, dict) and "state_dict" in sd:
        sd = sd["state_dict"]

    # Prefer values inferred from the checkpoint; fall back to the arguments otherwise.
    r = _infer_lora_rank_from_state(sd) or int(fallback_rank)
    tmods = _list_target_modules_from_state(sd) or DEFAULT_TARGET_MODULES
    print(f"[lora] inferred rank={r}, target_modules={tmods}")

    _write_adapter_config(out_dir, r=r, alpha=fallback_alpha, target_modules=tmods)
    torch.save(sd, os.path.join(out_dir, "adapter_model.bin"))
    return r, tmods
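
# Expected on-disk layout of the adapter directory (a sketch; the path below is
# the default chosen in _maybe_attach_lora_adapter()):
#
#   /tmp/lora_adapter_vllm/
#     adapter_config.json   # written by _write_adapter_config()
#     adapter_model.bin     # raw LoRA state dict saved with torch.save()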


def _maybe_attach_lora_adapter(config):
    """Attach a LoRA adapter directory to the vLLM rollout (runtime LoRA)."""
    # Locate the LoRA checkpoint: an explicit lora.pt_path wins, otherwise fall
    # back to model.load_param_path.
    lora_pt = None
    if "lora" in config and getattr(config.lora, "pt_path", ""):
        lora_pt = config.lora.pt_path
    elif getattr(config.model, "load_param_path", ""):
        lora_pt = config.model.load_param_path

    if not lora_pt or not Path(lora_pt).is_file():
        print("[lora] No LoRA .pt provided; running base model only.")
        return

    adapter_dir = os.path.join("/tmp", "lora_adapter_vllm")
    r, _ = _wrap_lora_pt_as_peft(lora_pt, adapter_dir, fallback_rank=32, fallback_alpha=16)

    # Expose the adapter to the rollout worker.
    with open_dict(config):
        if "rollout" not in config:
            config["rollout"] = OmegaConf.create()
    with open_dict(config.rollout):
        config.rollout.setdefault("max_loras", 1)
        config.rollout.setdefault("max_lora_rank", int(r))
        config.rollout["lora_modules"] = [{"path": adapter_dir, "scale": 1.0}]
    print(f"[lora] Attached PEFT adapter: {adapter_dir} (rank={r})")

    # The LoRA weights are applied by vLLM at runtime, so do not also load them
    # into the FSDP actor.
    with open_dict(config.model):
        if getattr(config.model, "load_param", False):
            print("[lora] Disabling model.load_param to avoid FSDP load_state_dict mismatch.")
            config.model["load_param"] = False
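
# A LoRA checkpoint can be supplied either way (script name and paths are
# hypothetical, shown only to illustrate the two config keys read above):
#
#   python infer.py lora.pt_path=/path/to/lora_weights.pt
#   python infer.py model.load_param_path=/path/to/lora_weights.pt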


@hydra.main(config_path="config", config_name="infer", version_base=None)
def main(config):
    _infer_lengths_and_defaults(config)

    if not ray.is_initialized():
        ray.init(
            runtime_env={"env_vars": {
                "TOKENIZERS_PARALLELISM": "true",
                "NCCL_DEBUG": "WARN",
                "PYTORCH_CUDA_ALLOC_CONF": "",
            }},
            num_cpus=config.ray_init.num_cpus,
        )

    ray.get(main_task.remote(config))
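
# A minimal sketch of config/infer.yaml covering only the keys this script reads
# directly. Paths and values are illustrative, and the verl rollout worker will
# typically require additional fields from its own config schema:
#
#   model:
#     path: /path/to/base_model          # hypothetical
#     trust_remote_code: false
#     load_param: false
#     load_param_path: ""
#   data:
#     path: /path/to/prompts.parquet     # hypothetical
#     output_path: /path/to/out.jsonl    # hypothetical
#     batch_size: 1
#     n_samples: 1
#     prompt_key: prompt
#   rollout:
#     temperature: 0.0
#     prompt_length: 1024
#     response_length: 128
#   trainer:
#     n_gpus_per_node: 1
#     nnodes: 1
#   ray_init:
#     num_cpus: 4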


@ray.remote(num_cpus=1)
def main_task(config):
    print("[worker] PYTORCH_CUDA_ALLOC_CONF =", os.environ.get("PYTORCH_CUDA_ALLOC_CONF"))
    pprint(OmegaConf.to_container(config, resolve=True))
    OmegaConf.resolve(config)

    _maybe_attach_lora_adapter(config)

    # Optionally (re)build the prompt dataset; skip quietly if the helper or its
    # inputs are unavailable.
    try:
        from prompts.infer_prompt import infer_dataset
        infer_dataset(
            model_name=config.model.path,
            data_path=os.path.dirname(os.path.dirname(config.data.path)),
        )
    except Exception as e:
        print(f"[info] infer_dataset() skipped: {e}")

    # Load the tokenizer from a local copy of the model and left-pad prompts for generation.
    local_path = copy_to_local(config.model.path)
    trust_remote_code = getattr(config.model, "trust_remote_code", False)
    tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Greedy decoding cannot produce distinct samples.
    if float(config.rollout.temperature) == 0.0:
        assert int(config.data.n_samples) == 1, "When temperature=0, n_samples must be 1."
    assert int(config.data.n_samples) >= 1, "n_samples must be >= 1."

    # Read prompts from parquet; each row's prompt is a chat-style list of messages.
    dataset = pd.read_parquet(config.data.path)
    prompt_key = getattr(config.data, "prompt_key", "prompt")
    if prompt_key not in dataset.columns:
        raise KeyError(f"Dataset missing column '{prompt_key}'")
    chat_lst = dataset[prompt_key].tolist()
    chat_lst = [chat.tolist() if hasattr(chat, "tolist") else chat for chat in chat_lst]

    # Spin up the rollout worker group and load the model.
    ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorRolloutRefWorker), config=config, role="rollout")
    resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes)
    print("[debug] rollout.lora_modules =", config.rollout.get("lora_modules", None))
    wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init)
    wg.init_model()

    # Generate n_samples completions per prompt, batch by batch.
    total = len(dataset)
    bs = int(config.data.batch_size)
    num_batch = -(-total // bs)  # ceiling division
    slots = [[] for _ in range(int(config.data.n_samples))]  # slots[n][i] = sample n for prompt i

    for b in range(num_batch):
        print(f"[{b + 1}/{num_batch}] Start to process.")
        batch_chat = chat_lst[b * bs : (b + 1) * bs]

        inputs = tokenizer.apply_chat_template(
            batch_chat,
            add_generation_prompt=True,
            padding=True,
            truncation=True,
            max_length=int(config.rollout.prompt_length),
            return_tensors="pt",
            return_dict=True,
            tokenize=True,
        )
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        position_ids = compute_position_id_with_mask(attention_mask)
        batch_dict = {"input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids}

        # Pad the batch so it divides evenly across the worker group, then unpad the outputs.
        data = DataProto.from_dict(batch_dict)
        data_padded, pad_size = pad_dataproto_to_divisor(data, wg.world_size)

        print(f"[{b + 1}/{num_batch}] Start to generate.")
        for n in range(int(config.data.n_samples)):
            output_padded = wg.generate_sequences(data_padded)
            output = unpad_dataproto(output_padded, pad_size=pad_size)
            texts = []
            for i in range(len(output)):
                item = output[i]
                # The attention mask covers prompt + response; positions after the
                # prompt length count the valid response tokens.
                pl = item.batch["prompts"].shape[-1]
                valid_len = item.batch["attention_mask"][pl:].sum()
                resp_ids = item.batch["responses"][:valid_len]
                s = tokenizer.decode(resp_ids, skip_special_tokens=True)
                print(f"[raw] Response {i}: {s!r}")
                # Drop any chain-of-thought up to and including the </think> marker.
                ix = s.find("</think>")
                if ix != -1:
                    s = s[ix + len("</think>"):].lstrip()
                print(f"Response {i}: {s!r}")
                # Keep the response as a Python literal when it parses as one,
                # otherwise keep the raw string.
                try:
                    texts.append(ast.literal_eval(s))
                except Exception:
                    texts.append(s)
            slots[n].extend(texts)

    # Transpose slots[n_samples][num_prompts] into one list of samples per prompt.
    # A plain zip is used instead of np.transpose so that responses parsed into
    # lists do not get folded into an extra numpy axis.
    dataset["response"] = [list(per_prompt) for per_prompt in zip(*slots)]

    # Keep only the columns of interest when they exist.
    keep = ["file_id", "vt", "gt", "response"]
    cols = [c for c in keep if c in dataset.columns]
    if cols:
        dataset = dataset[cols]

    out_path = config.data.output_path
    makedirs(os.path.dirname(out_path), exist_ok=True)
    dataset.to_json(out_path, orient="records", lines=True, force_ascii=False)
    print(f"[done] Wrote: {out_path}")


if __name__ == "__main__":
    main()