大模型实践--全参数微调代码实践

举报
剑指南天 发表于 2026/05/14 17:25:00 2026/05/14
【摘要】 本文基于 TRL 中用于监督微调的 SFT Trainer 工具,实现 Qwen/Qwen3-0.6B 的全参微调。

1.概述

TRL(Transformers Reinforcement Learning)是一个全栈库,提供了一整套工具,如监督微调(SFT)、组相对策略优化(GRPO)、直接偏好优化(DPO)、奖励建模等方法训练 Transformer 语言模型。

本文基于 TRL 中用于监督微调的 SFT Trainer 工具,实现 Qwen/Qwen3-0.6B 的全参微调。

2. 模型选择


2. 训练数据集

数据集主要是提取段落关键词。

3. 微调方法

全参数微调

4. 分布式训练

否,单卡训练

5. 显卡选择

6. 微调工具

基于 TRL 中用于监督微调的 SFT Trainer 工具

7. 代码实践

①配置环境变量

②加载模型和分词器

from datasets import load_dataset
from trl import SFTConfig, SFTTrainer
import torch
from transformers import AutoModelForCausalLM,AutoTokenizer

# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"
# Configure model and tokenizer
model_name = "Qwen/Qwen3-0.6B"
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)

③处理数据集

# 数据集
# Load dataset
dataset_dict = load_dataset("json",data_files={"train":"/root/autodl-tmp/hf/data/keywords_data_train.jsonl",
                                           "test":"/root/autodl-tmp/hf/data/keywords_data_test.jsonl"})
def map_func(message):
    conversation = message['conversation']
    messages = []
    for item in conversation:
        messages.append({"role": "user", "content": item['human']})
        messages.append({"role": "assistant", "content":item['assistant']})
    return {"messages":messages}
dataset_dict = dataset_dict.map(map_func,batched=False,remove_columns=['conversation_id','category','conversation','dataset'])
print(dataset_dict['train'][0])

④配置SFTTrainer工具

# Configure trainer
training_args = SFTConfig(
    output_dir="/root/autodl-tmp/hf/model/Qwen3-0.6B/sft_sample/model1/",
    max_steps=1000,
    per_device_train_batch_size=10,
    learning_rate=5e-5,
    save_steps=50,
    save_total_limit=2,
    eval_strategy="steps",
    eval_steps=50,
    load_best_model_at_end=True,
    logging_dir="/root/tf-logs/",
    logging_strategy='steps',
    logging_steps=50,
    bf16=True,
    warmup_steps=50,
    assistant_only_loss=True,
)
# Initialize trainer
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset_dict["train"],
    eval_dataset=dataset_dict["test"],
    processing_class=tokenizer,
)
train_dataloader = trainer.get_train_dataloader()
# print(next(iter(train_dataloader)))
print(tokenizer.decode(next(iter(train_dataloader))['input_ids']))

⑤启动训练

# Start training
trainer.train()

⑥保存模型

trainer.save_model("/root/autodl-tmp/hf/model/Qwen3-0.6B/sft_sample/model1/final_model1/")

8. 推理效果对比

①Qwen/Qwen3-0.6B

from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Qwen/Qwen3-0.6B"

# load the tokenizer and the model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    dtype="auto",
    device_map="auto"
)
# prepare the model input
prompt = "识别出文本中的关键词:\n题目:基于产权理论的智慧城市信息共享机制研究\n摘要:[目的/意义]新型智慧城市的建设是实现国家治理体系和治理能力现代化的关键手段,然而目前普遍存在信息孤岛问题严重影响新型智慧城市落地实现,如何合理高效解决信息资源共享问题,已经成为急需解决的问题。[方法/过程]通过分析智慧城市信息共享存在的问题原因以及解决思路,以信息产权及信息确权分析为基础,对智慧城市信息共享的利益冲突与平衡进行分析,最后提出一种基于产权理论的智慧城市信息共享机制框架。[结果/结论]新型智慧城市应充分考虑信息价值及信息产权问题,形成信息共享的正向激励机制,使智慧城市信息资源得以高效利用,实现政府各部门相互协作,适应复杂多变的城市环境需求,形成智慧政府能力。"
messages = [
    {"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False # Switches between thinking and non-thinking modes. Default is True.
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

# conduct text completion
generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=32768
)
output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist() 

# parsing thinking content
try:
    # rindex finding 151668 (</think>)
    index = len(output_ids) - output_ids[::-1].index(151668)
except ValueError:
    index = 0

thinking_content = tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip("\n")
content = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")

print("thinking content:", thinking_content)
print("content:", content)

②微调后的模型

from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "/root/autodl-tmp/hf/model/Qwen3-0.6B/sft_sample/model1/final_model1/"

# load the tokenizer and the model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    dtype="auto",
    device_map="auto"
)
# prepare the model input
prompt = "识别出文本中的关键词:\n题目:基于产权理论的智慧城市信息共享机制研究\n摘要:[目的/意义]新型智慧城市的建设是实现国家治理体系和治理能力现代化的关键手段,然而目前普遍存在信息孤岛问题严重影响新型智慧城市落地实现,如何合理高效解决信息资源共享问题,已经成为急需解决的问题。[方法/过程]通过分析智慧城市信息共享存在的问题原因以及解决思路,以信息产权及信息确权分析为基础,对智慧城市信息共享的利益冲突与平衡进行分析,最后提出一种基于产权理论的智慧城市信息共享机制框架。[结果/结论]新型智慧城市应充分考虑信息价值及信息产权问题,形成信息共享的正向激励机制,使智慧城市信息资源得以高效利用,实现政府各部门相互协作,适应复杂多变的城市环境需求,形成智慧政府能力。"
messages = [
    {"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False # Switches between thinking and non-thinking modes. Default is True.
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

# conduct text completion
generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=32768
)
output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()

# parsing thinking content
try:
    # rindex finding 151668 (</think>)
    index = len(output_ids) - output_ids[::-1].index(151668)
except ValueError:
    index = 0

thinking_content = tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip("\n")
content = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")

print("thinking content:", thinking_content)
print("content:", content)

【版权声明】本文为华为云社区用户原创内容,未经允许不得转载,如需转载请自行联系原作者进行授权。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。