verl quick start 实战
【摘要】 环境信息3090 24G软件版本python3.12.3pytorch2.7.1cuda12.8transformers4.54.1flash-attn2.8.2vllm0.10.0verl0.5.0操作步骤git clone https://github.com/volcengine/verl.git数据集下载cd verlpython3 examples/data_preprocess/...
环境信息
3090 24G
软件 | 版本 |
python | 3.12.3 |
pytorch | 2.7.1 |
cuda | 12.8 |
transformers | 4.54.1 |
flash-attn | 2.8.2 |
vllm | 0.10.0 |
verl | 0.5.0 |
操作步骤
git clone https://github.com/volcengine/verl.git
数据集下载
cd verl
python3 examples/data_preprocess/gsm8k.py --local_dir ~/data/gsm8k
模型下载
python3 -c "import transformers; transformers.pipeline('text-generation', model='Qwen/Qwen2.5-0.5B-Instruct')"
训练启动命令
如下配置在 3090 单卡上大约需要运行 7 小时
# Launch PPO training on GSM8K with Qwen2.5-0.5B-Instruct as both actor and
# critic. Sized for a single 24 GB RTX 3090: tensor parallel = 1, small
# micro-batches, and vLLM capped at 40% of GPU memory so the FSDP actor,
# critic and rollout engine can coexist on one card.
# PYTHONUNBUFFERED=1 makes the tee'd log (verl_demo.log) stream in real time.
PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \
data.train_files=~/data/gsm8k/train.parquet \
data.val_files=~/data/gsm8k/test.parquet \
data.train_batch_size=256 \
data.max_prompt_length=512 \
data.max_response_length=256 \
actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.ppo_mini_batch_size=64 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
critic.optim.lr=1e-5 \
critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \
critic.ppo_micro_batch_size_per_gpu=4 \
algorithm.kl_ctrl.kl_coef=0.001 \
trainer.logger=console \
trainer.val_before_train=False \
trainer.n_gpus_per_node=1 \
trainer.nnodes=1 \
trainer.save_freq=500 \
trainer.test_freq=500 \
trainer.max_actor_ckpt_to_keep=2 \
trainer.max_critic_ckpt_to_keep=2 \
trainer.total_epochs=15 2>&1 | tee verl_demo.log
训练日志大致如下:
输出目录结构大致如下:
checkpoint合并
# Merge the sharded FSDP actor checkpoint into a single HuggingFace-format
# model directory so it can be loaded with transformers' from_pretrained.
python3 -m verl.model_merger merge \
--backend fsdp \
--local_dir checkpoints/verl_examples/gsm8k/global_step_1/actor \
--target_dir checkpoints/verl_examples/gsm8k/global_step_1/actor/huggingface
测试验证
import re
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm
import os
import torch
from torch.utils.data import DataLoader
os.environ["TRANSFORMERS_VERBOSITY"] = "error"

# ========= 1. Paths =========
# BUG FIX: from_pretrained does not expand "~" — an unexpanded path is
# misinterpreted as a hub repo id and fails to resolve. Expand it explicitly.
# Accuracy = 0.362 (477/1319)
CKPT_DIR = os.path.expanduser(
    "~/checkpoints/verl_examples/gsm8k/global_step_435/actor/huggingface"
)
# Accuracy = 0.061 (81/1319)
#CKPT_DIR = "Qwen/Qwen2.5-0.5B-Instruct"
DATA_PATH = "/root/autodl-tmp/data/gsm8k/test.parquet"
DEVICE = "cuda"
BATCH = 256

# ---------- 2. Environment tuning ----------
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# ========= 2. Load tokenizer & model =========
tokenizer = AutoTokenizer.from_pretrained(CKPT_DIR)
# Decoder-only models need left padding (and a pad token) for batched generation.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
model = AutoModelForCausalLM.from_pretrained(
    CKPT_DIR,
    torch_dtype="auto",
    device_map="auto",
    attn_implementation="flash_attention_2",
)
model.eval()  # generate() does not switch modes; disable dropout explicitly
model = torch.compile(model, mode="reduce-overhead")
# ========= 3. 读取预处理后 parquet =========
class DS:
    """Minimal map-style dataset over the preprocessed GSM8K parquet.

    Each item is a ``(prompt_text, ground_truth)`` pair of strings, taken
    from the first chat turn of ``prompt`` and the reward-model label.
    """

    def __init__(self, df):
        self.df = df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        prompt_text = row["prompt"][0]["content"]
        gold = str(row["reward_model"]["ground_truth"])
        return prompt_text, gold
# Load the preprocessed GSM8K test split and batch it for inference.
# shuffle=False keeps evaluation order deterministic and reproducible.
df = pd.read_parquet(DATA_PATH)
loader = DataLoader(DS(df), batch_size=BATCH, shuffle=False)
# ========= 4. 打分函数 =========
def extract_answer(text: str) -> str:
m = re.search(r"####\s*([\d\.,\-]+)", text)
return m.group(1).replace(",", "") if m else None
def judge(pred: str, gold: str) -> bool:
m = extract_answer(pred)
return m == gold
# ========= 5. 推理 & 评估 =========
correct = total = 0
for prompts, golds in tqdm(loader, desc="Evaluating"):
# 构造对话格式 prompt
inputs = tokenizer(
prompts,
return_tensors="pt",
padding=True,
truncation=True,
).to(model.device)
outs = model.generate(
**inputs,
max_new_tokens=256,
pad_token_id=tokenizer.eos_token_id,
)
preds = tokenizer.batch_decode(outs, skip_special_tokens=True)
for pred, gold in zip(preds, golds):
total += 1
if extract_answer(pred) == gold:
correct += 1
print(f"Accuracy = {correct / total:.3f} ({correct}/{total})")
训练效果:
训练前:
训练后:
【声明】本内容来自华为云开发者社区博主,不代表华为云及华为云开发者社区的观点和立场。转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息,否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱:
cloudbbs@huaweicloud.com
- 点赞
- 收藏
- 关注作者
评论(0)