大模型实践--QLoRA代码实践(基于Unsloth)

举报
剑指南天 发表于 2026/05/15 18:46:03 2026/05/15
【摘要】 Qwen/Qwen3-0.6B 全参数微调需要的显卡内存接近30GB,无法在RTX3060显卡上面进行微调。所以本文基于 TRL 中用于监督微调的 SFT Trainer 和 Unsloth(QLoRA)工具,在RTX3060显卡上面实现 Qwen/Qwen3-0.6B 的参数高效微调。

1.概述

TRL(Transformers Reinforcement Learning)是一个全栈库,提供了一整套工具,如监督微调(SFT)、组相对策略优化(GRPO)、直接偏好优化(DPO)、奖励建模等方法训练 Transformer 语言模型。

Qwen/Qwen3-0.6B 全参数微调需要的显卡内存接近30GB,无法在RTX3060显卡上面进行微调。所以本文基于 TRL 中用于监督微调的 SFT Trainer  Unsloth(QLoRA)工具,在RTX3060显卡上面实现 Qwen/Qwen3-0.6B 的参数高效微调。

2. 模型选择

2. 训练数据集

数据集主要是提取段落关键词。

3. 微调方法

QLoRA参数高效微调

4. 分布式训练

否,单卡训练

5. 显卡选择

RTX3060

6. 微调工具

基于 TRL 中用于监督微调的 SFT Trainer  Unsloth 工具。

7. 代码实践

①配置环境变量

import os
import torch
import unsloth
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

# 配置预训练模型下载地址
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com/'
os.environ['TRITON_CACHE_DIR'] = 'C:\\triton_cache'

加载模型和分词器,将模型进行量化

# Configure model and tokenizer
model_name = "Qwen/Qwen3-0.6B"
max_length = 1024  # Supports automatic RoPE Scaling, so choose any number
model, tokenizer = unsloth.FastLanguageModel.from_pretrained(
    model_name=model_name,
    max_seq_length=max_length,
    dtype=torch.bfloat16,  # For auto-detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
    load_in_4bit=True,  # Use 4bit quantization to reduce memory usage. Can be False
)

tokenizer = unsloth.chat_templates.get_chat_template(
    tokenizer,
    chat_template="qwen3",  # change this to the right chat_template name
)

# r: rank dimension for LoRA update matrices (smaller = more compression)
rank_dimension = 4
# lora_alpha: scaling factor for LoRA layers (higher = stronger adaptation)
lora_alpha = 8
# lora_dropout: dropout probability for LoRA layers (helps prevent overfitting)
lora_dropout = 0.05

# Do model patching and add fast LoRA weights
model = unsloth.FastLanguageModel.get_peft_model(
    model,
    r=rank_dimension,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,  # Dropout = 0 is currently optimized
    bias="none",  # Bias = "none" is currently optimized
    use_gradient_checkpointing=True,
    random_state=3407,
)

③处理数据集

# 数据集
# Load dataset
dataset_dict = load_dataset("json",
                            data_files={"train": "hf/model/Qwen3-0.6B/peft_sample/data/keywords_data_test.jsonl",
                                        "test": "hf/model/Qwen3-0.6B/peft_sample/data/keywords_data_train.jsonl"})
def map_func(message):
    conversation = message['conversation']
    messages = []
    for item in conversation:
        messages.append({"role": "user", "content": item['human']})
        messages.append({"role": "assistant", "content": item['assistant']})
    return {"messages": messages}
# 将数据转化为对话式的 ShareGPT 格式,并移除不需要的字段
dataset_dict = dataset_dict.map(map_func, batched=False,
                                 remove_columns=['conversation_id', 'category', 'conversation', 'dataset'])

def formatting_prompts_func(messages):
    conversation = messages["messages"]
    texts = [tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False) for convo in
             conversation]
    return {"text": texts}
# 应用Chat Template
dataset_dict = dataset_dict.map(formatting_prompts_func, batched=True, remove_columns=['messages'])

④配置SFTTrainer工具

training_args = SFTConfig(
    output_dir="hf/model/Qwen3-0.6B/unsloth_sample/model1",
    max_steps=1000,
    per_device_train_batch_size=4,
    learning_rate=5e-5,
    save_steps=50,
    save_total_limit=2,
    eval_strategy="steps",
    eval_steps=50,
    load_best_model_at_end=True,
    logging_dir="hf/model/Qwen3-0.6B/peft_sample/logs/",
    logging_strategy='steps',
    logging_steps=50,
    bf16=True,
    warmup_steps=50,
    assistant_only_loss=False,
)

# Initialize trainer
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset_dict["train"],
    eval_dataset=dataset_dict["test"].train_test_split(test_size=0.005, shuffle=True)['test'],
    processing_class=tokenizer,
)

⑤启动训练

# Start training
trainer.train()

⑥保存模型

# 保存模型
trainer.save_model("hf/model/Qwen3-0.6B/unsloth_sample/model1/final_model1/")

8. 推理效果对比

①Qwen/Qwen3-0.6B

from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Qwen/Qwen3-0.6B"

# load the tokenizer and the model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    dtype="auto",
    device_map="auto"
)
# prepare the model input
prompt = "识别出文本中的关键词:\n题目:基于产权理论的智慧城市信息共享机制研究\n摘要:[目的/意义]新型智慧城市的建设是实现国家治理体系和治理能力现代化的关键手段,然而目前普遍存在信息孤岛问题严重影响新型智慧城市落地实现,如何合理高效解决信息资源共享问题,已经成为急需解决的问题。[方法/过程]通过分析智慧城市信息共享存在的问题原因以及解决思路,以信息产权及信息确权分析为基础,对智慧城市信息共享的利益冲突与平衡进行分析,最后提出一种基于产权理论的智慧城市信息共享机制框架。[结果/结论]新型智慧城市应充分考虑信息价值及信息产权问题,形成信息共享的正向激励机制,使智慧城市信息资源得以高效利用,实现政府各部门相互协作,适应复杂多变的城市环境需求,形成智慧政府能力。"
messages = [
    {"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False # Switches between thinking and non-thinking modes. Default is True.
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

# conduct text completion
generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=32768
)
output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist() 

# parsing thinking content
try:
    # rindex finding 151668 (</think>)
    index = len(output_ids) - output_ids[::-1].index(151668)
except ValueError:
    index = 0

thinking_content = tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip("\n")
content = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")

print("thinking content:", thinking_content)
print("content:", content)

②LoRA微调后的模型

import torch
from transformers import AutoModelForCausalLM,AutoTokenizer
import os
from peft import PeftModel

# 配置预训练模型下载地址
os.environ['HF_ENDPOINT']='https://hf-mirror.com/'

# Configure model and tokenizer
model_name = "Qwen/Qwen3-0.6B"

# 1. Load the base model
base_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

# 2. Load the Unsloth model with adapter
peft_model = PeftModel.from_pretrained(
base_model, "hf/model/Qwen3-0.6B/unsloth_sample/model1/final_model1/",device_map="auto")

# 3. Merge adapter weights with base model
merged_model = peft_model.merge_and_unload()

# 4. prepare the model input
prompt = "识别出文本中的关键词:\n题目:基于产权理论的智慧城市信息共享机制研究\n摘要:[目的/意义]新型智慧城市的建设是实现国家治理体系和治理能力现代化的关键手段,然而目前普遍存在信息孤岛问题严重影响新型智慧城市落地实现,如何合理高效解决信息资源共享问题,已经成为急需解决的问题。[方法/过程]通过分析智慧城市信息共享存在的问题原因以及解决思路,以信息产权及信息确权分析为基础,对智慧城市信息共享的利益冲突与平衡进行分析,最后提出一种基于产权理论的智慧城市信息共享机制框架。[结果/结论]新型智慧城市应充分考虑信息价值及信息产权问题,形成信息共享的正向激励机制,使智慧城市信息资源得以高效利用,实现政府各部门相互协作,适应复杂多变的城市环境需求,形成智慧政府能力。"
messages = [
{"role": "user", "content": prompt}
]

# 5. load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=False # Switches between thinking and non-thinking modes. Default is True.
)
model_inputs = tokenizer([text], return_tensors="pt").to(merged_model.device)

# 6. conduct text completion
generated_ids = merged_model.generate(
**model_inputs,
max_new_tokens=32768
)
output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()

# 7. parsing thinking content
try:
# rindex finding 151668 (</think>)
index = len(output_ids) - output_ids[::-1].index(151668)
except ValueError:
index = 0

thinking_content = tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip("\n")
content = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")

print("thinking content:", thinking_content)
print("content:", content)

9. 总结:基于unsloth库的QLoRA在显存占用和计算方面对比peft有显著的降低。

【版权声明】本文为华为云社区用户原创内容,未经允许不得转载,如需转载请自行联系原作者进行授权。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。