# Project02/AITrain/train_npc_dialogue_lora.py
# 2025-08-11 14:10:38 +08:00
#
# 480 lines
# 18 KiB
# Python
# Raw Permalink Blame History
#
# This file contains ambiguous Unicode characters
#
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
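'''
LoRA fine-tuning script for in-game NPC character dialogue.
Fine-tunes a Qwen 8B model to generate game NPC dialogue.
Note: this header originally named test.jsonl as the data source; main() loads ./npc_dialogue_dataset.json.
'''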
import json
import os
import torch
import numpy as np
from peft import LoraConfig, PeftModel, TaskType, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import TrainingArguments, Trainer, DataCollatorForSeq2Seq, TrainerCallback
from datasets import Dataset
import platform
import swanlab
from swanlab.integration.transformers import SwanLabCallback
import logging
# Windows multiprocessing compatibility fix: force the 'spawn' start method.
if platform.system() == "Windows":
    import multiprocessing

    multiprocessing.set_start_method('spawn', force=True)

# Process-wide environment switches consumed by downstream libraries.
# Their exact semantics are defined by those libraries, not by this script.
os.environ.update({
    'VLLM_USE_MODELSCOPE': 'True',
    "CUDA_LAUNCH_BLOCKING": "1",
    "TORCH_USE_CUDA_DSA": "1",
})

# Logging setup: every record goes both to training_debug.log and to the console.
logging.basicConfig(
    handlers=[
        logging.FileHandler('training_debug.log'),
        logging.StreamHandler(),
    ],
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
class GradientMonitorCallback(TrainerCallback):
    """Trainer callback that logs NaN/Inf values and abnormally large magnitudes.

    * ``on_step_begin`` scans trainable parameters.
    * ``on_pre_optimizer_step`` scans gradients. The Trainer zeroes gradients
      before it fires ``on_step_end``, so inspecting them there finds nothing;
      ``on_step_end`` only falls back to the scan when the earlier hook did
      not run (NOTE(review): hook availability depends on the installed
      transformers version - confirm).
    * ``on_log`` checks the logged loss and, when present, the trainer-reported
      ``grad_norm``.

    The callback only logs; it never modifies ``control`` or the model.
    """

    def __init__(self):
        # Number of on_step_begin events seen so far; used only in log lines.
        self.step_count = 0
        # True once gradients were actually inspected during the current step.
        self._grads_checked = False

    def on_step_begin(self, args, state, control, model=None, **kwargs):
        """Log a step header and scan trainable parameters for bad values."""
        self.step_count += 1
        self._grads_checked = False
        logger.info(f"\n=== Step {self.step_count} Begin ===")
        if model is None:
            return
        for name, param in model.named_parameters():
            if not param.requires_grad or param.numel() == 0:
                continue
            values = param.data
            if torch.isnan(values).any():
                logger.error(f"NaN detected in parameter: {name}")
            if torch.isinf(values).any():
                logger.error(f"Inf detected in parameter: {name}")
            smallest = values.min().item()
            largest = values.max().item()
            if abs(largest) > 1e6 or abs(smallest) > 1e6:
                # mean/std are only needed for the warning text, so they are
                # computed here instead of once per parameter per step.
                param_stats = {
                    'min': smallest,
                    'max': largest,
                    'mean': values.mean().item(),
                    'std': values.std().item(),
                }
                logger.warning(f"Large parameter values in {name}: {param_stats}")

    def on_pre_optimizer_step(self, args, state, control, model=None, **kwargs):
        """Inspect gradients while they still exist (before the optimizer step).

        NOTE(review): this hook is expected to run after gradient clipping, so
        the norms seen here may already be clipped; the unclipped norm is
        checked from the logs in ``on_log``.
        """
        self._check_gradients(model)

    def on_step_end(self, args, state, control, model=None, **kwargs):
        """Log the step footer; inspect gradients only if no earlier hook did."""
        logger.info(f"=== Step {self.step_count} End ===")
        if not self._grads_checked:
            self._check_gradients(model)

    def _check_gradients(self, model):
        """Log NaN/Inf/large gradients and the total L2 gradient norm."""
        if model is None:
            return
        squared_norm = 0.0
        saw_gradient = False
        nan_grads = []
        inf_grads = []
        large_grads = []
        for name, param in model.named_parameters():
            if not param.requires_grad or param.grad is None:
                continue
            saw_gradient = True
            grad = param.grad
            if torch.isnan(grad).any():
                nan_grads.append(name)
                logger.error(f"NaN gradient detected in: {name}")
            if torch.isinf(grad).any():
                inf_grads.append(name)
                logger.error(f"Inf gradient detected in: {name}")
            param_norm = grad.data.norm(2).item()
            squared_norm += param_norm ** 2
            if param_norm > 10.0:
                large_grads.append((name, param_norm))
                logger.warning(f"Large gradient in {name}: {param_norm:.6f}")
        if not saw_gradient:
            # Reporting a norm of 0 here would be misleading: there was simply
            # nothing to measure at this point of the step.
            logger.info("No gradients available to inspect at this point")
            return
        self._grads_checked = True
        total_norm = squared_norm ** 0.5
        logger.info(f"Total gradient norm: {total_norm:.6f}")
        # Summary report
        if nan_grads:
            logger.error(f"Parameters with NaN gradients: {nan_grads}")
        if inf_grads:
            logger.error(f"Parameters with Inf gradients: {inf_grads}")
        if large_grads:
            logger.warning(f"Parameters with large gradients: {large_grads}")
        if total_norm > 100.0:
            logger.error(f"Gradient explosion detected! Total norm: {total_norm}")

    def on_log(self, args, state, control, model=None, logs=None, **kwargs):
        """Check the most recent log record for a NaN or exploding loss.

        Uses the ``logs`` dict passed by the trainer when available, otherwise
        the last entry of ``state.log_history``. Both the per-step key
        ``loss`` and the end-of-training key ``train_loss`` are recognised.
        """
        record = logs
        if not record:
            history = getattr(state, 'log_history', None)
            record = history[-1] if history else None
        if not record:
            return
        loss = record.get('loss', record.get('train_loss'))
        if loss is not None:
            logger.info(f"Current loss: {loss:.6f}")
            if np.isnan(loss):
                logger.error("Loss is NaN!")
            elif loss > 1e6:
                logger.error(f"Loss explosion detected: {loss}")
        grad_norm = record.get('grad_norm')
        if grad_norm is not None:
            grad_norm = float(grad_norm)
            if np.isnan(grad_norm) or np.isinf(grad_norm):
                logger.error(f"Non-finite gradient norm reported by trainer: {grad_norm}")
            elif grad_norm > 100.0:
                logger.error(f"Gradient explosion detected! Total norm: {grad_norm}")
def process_func(example, tokenizer):
    """Tokenize one dialogue record into a causal-LM training sample.

    Args:
        example: mapping with the keys ``character``, ``instruction``,
            ``input`` (may be missing/None, treated as empty) and ``output``.
        tokenizer: callable returning ``input_ids`` and ``attention_mask``
            lists, and exposing ``pad_token_id``.

    Returns:
        A dict with ``input_ids``, ``attention_mask`` and ``labels`` (prompt
        positions masked with -100), truncated to 1024 tokens; or ``None``
        when the record is invalid, tokenizes to nothing, or has no
        supervised token left after truncation.
    """
    MAX_LENGTH = 1024
    try:
        character = example['character']
        instruction = example['instruction']
        user_input = example.get('input')
        if user_input is None:
            user_input = ""
        # Validate the raw fields: the formatted system prompt is always a
        # str, so checking it would never reject anything.
        if not all(isinstance(field, str) for field in (character, instruction, user_input)):
            logger.warning(
                f"Invalid input data types: character={type(character)}, "
                f"instruction={type(instruction)}, input={type(user_input)}"
            )
            return None
        # Role-play system prompt for the NPC character.
        system_prompt = f"你是一个游戏中的NPC角色。{character}"
        # Prompt part (not supervised).
        # NOTE(review): the literal "<s>" prefix is tokenized as plain text
        # unless the tokenizer defines it as a special token - confirm it is
        # intended for the model in use.
        instruction_tokens = tokenizer(
            f"<s><|im_start|>system\n{system_prompt}<|im_end|>\n"
            f"<|im_start|>user\n{instruction + user_input}<|im_end|>\n"
            f"<|im_start|>assistant\n",
            add_special_tokens=False
        )
        # Response part (supervised).
        response_tokens = tokenizer(f"{example['output']}", add_special_tokens=False)
        if not instruction_tokens["input_ids"] or not response_tokens["input_ids"]:
            logger.warning("Empty tokenization result")
            return None
        # NOTE(review): the sequence is terminated with the pad token, not the
        # EOS token; verify both are acceptable stop tokens for the model.
        end_token = [tokenizer.pad_token_id]
        prompt_len = len(instruction_tokens["input_ids"])
        input_ids = instruction_tokens["input_ids"] + response_tokens["input_ids"] + end_token
        attention_mask = instruction_tokens["attention_mask"] + response_tokens["attention_mask"] + [1]
        # Loss is computed on the response (and end token) only.
        labels = [-100] * prompt_len + response_tokens["input_ids"] + end_token
        # Right-truncate to the maximum sequence length.
        if len(input_ids) > MAX_LENGTH:
            input_ids = input_ids[:MAX_LENGTH]
            attention_mask = attention_mask[:MAX_LENGTH]
            labels = labels[:MAX_LENGTH]
        # A prompt of MAX_LENGTH or more leaves only masked labels; such a
        # sample contributes no loss and makes a mean-reduced loss undefined.
        if all(label == -100 for label in labels):
            logger.warning(
                f"Prompt of {prompt_len} tokens leaves no response tokens within "
                f"MAX_LENGTH={MAX_LENGTH}; sample skipped"
            )
            return None
        if len(input_ids) != len(attention_mask) or len(input_ids) != len(labels):
            logger.error(f"Length mismatch: input_ids={len(input_ids)}, attention_mask={len(attention_mask)}, labels={len(labels)}")
            return None
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels
        }
    except Exception as e:
        # Best-effort preprocessing: a broken record is logged and skipped.
        logger.error(f"Error in process_func: {e}")
        return None
def load_model_and_tokenizer(model_path):
    """Load the tokenizer and the causal-LM model stored at ``model_path``.

    Returns:
        A ``(model, tokenizer)`` tuple. The tokenizer's pad token falls back
        to its EOS token when none is defined.
    """
    print(f"Loading model from: {model_path}")

    # Tokenizer first (slow implementation, remote code allowed).
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        use_fast=False,
        trust_remote_code=True,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Then the model: automatic device placement, bfloat16 weights.
    model_kwargs = {
        "device_map": "auto",
        "torch_dtype": torch.bfloat16,
        "trust_remote_code": True,
    }
    model = AutoModelForCausalLM.from_pretrained(model_path, **model_kwargs)

    return model, tokenizer
def create_lora_config():
    """Build the LoRA configuration used for fine-tuning.

    Rank 8 and alpha 8 with dropout 0.05, applied to the listed projection
    modules. ``modules_to_save`` is deliberately not set: the original author
    removed ``["lm_head", "embed_tokens"]`` to avoid NaN problems with the
    embed_tokens parameters.
    """
    adapted_modules = ["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"]
    return LoraConfig(
        r=8,
        lora_alpha=8,  # kept low for stability (author's note)
        lora_dropout=0.05,  # kept low to reduce instability (author's note)
        inference_mode=False,
        target_modules=adapted_modules,
        task_type=TaskType.CAUSAL_LM,
    )
def prepare_dataset(data_path, tokenizer):
    """Load the JSON dataset, tokenize it and return a ``Dataset`` of samples.

    Args:
        data_path: path to a JSON file that ``json.load`` turns into a list
            of records.
        tokenizer: tokenizer handed to ``process_func``.

    Returns:
        A ``Dataset`` with ``input_ids``, ``attention_mask`` and ``labels``.

    Raises:
        ValueError: when no sample survives preprocessing.
    """
    print(f"Loading dataset from: {data_path}")
    logger.info(f"Loading dataset from: {data_path}")
    with open(data_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    print(f"Total samples before filtering: {len(data)}")
    logger.info(f"Total samples before filtering: {len(data)}")

    dataset = Dataset.from_list(data)
    # Drop records without a usable target text.
    dataset = dataset.filter(
        lambda example: example.get("output") not in [None, ""]
    )
    original_size = len(dataset)

    # process_func signals a rejected record by returning None. Tokenize in a
    # plain loop so rejected records are dropped here instead of handing None
    # to Dataset.map.
    records = []
    for example in dataset:
        record = process_func(example, tokenizer)
        if record and record.get("input_ids") and record.get("labels"):
            records.append(record)

    filtered_size = len(records)
    print(f"Total samples after filtering: {filtered_size} ({original_size - filtered_size} samples removed)")
    logger.info(f"Total samples after filtering: {filtered_size} ({original_size - filtered_size} samples removed)")
    if filtered_size == 0:
        logger.error("No valid samples remaining after filtering!")
        raise ValueError("Dataset is empty after preprocessing")

    # Spot-check the first few samples. len(tokenizer) is used as the id
    # bound because vocab_size does not count added special tokens, and valid
    # ids are strictly below that bound.
    vocab_limit = len(tokenizer)
    for i, sample in enumerate(records[:3]):
        logger.info(f"Sample {i} - input_ids length: {len(sample['input_ids'])}, labels length: {len(sample['labels'])}")
        if any(token_id < 0 or token_id >= vocab_limit for token_id in sample['input_ids']):
            logger.warning(f"Sample {i} contains out-of-vocabulary tokens")

    return Dataset.from_list(records)
def _init_lora_weights(model):
    """Re-initialise trainable LoRA matrices: A ~ N(0, 0.01), B = 0."""
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if 'lora_A' in name:
            torch.nn.init.normal_(param, mean=0.0, std=0.01)
        elif 'lora_B' in name:
            torch.nn.init.zeros_(param)
        # Re-draw any parameter that still holds NaN/Inf after initialisation.
        if torch.isnan(param).any() or torch.isinf(param).any():
            logger.error(f"Abnormal values detected in parameter {name} after initialization")
            torch.nn.init.normal_(param, mean=0.0, std=0.001)


def train_lora_model(model_path, data_path, output_dir):
    """Fine-tune the base model with LoRA and save the adapter and tokenizer.

    Args:
        model_path: directory or hub id of the base model.
        data_path: JSON dataset consumed by ``prepare_dataset``.
        output_dir: directory for checkpoints; the final model goes to
            ``<output_dir>/final_model``.

    Returns:
        Path of the directory the final model was saved to.

    Raises:
        Exception: any error raised by ``trainer.train()`` is logged with its
            traceback and re-raised.
    """
    # 1. Load model and tokenizer
    model, tokenizer = load_model_and_tokenizer(model_path)
    # 2./3. Build the LoRA config and wrap the model
    lora_config = create_lora_config()
    model = get_peft_model(model, lora_config)
    # 4. Numerically conservative initialisation of the LoRA weights
    _init_lora_weights(model)
    # 5. Gradient checkpointing on a model whose base weights are frozen
    #    needs the embedding outputs to require grad, otherwise no gradient
    #    reaches the adapters. (The previous loop here only re-set
    #    requires_grad on parameters that already had it.)
    model.enable_input_require_grads()
    model.config.use_cache = False  # KV cache is disabled during training to save memory
    # 6. Prepare the dataset
    train_preparedataset = prepare_dataset(data_path, tokenizer)
    # 7. Training arguments (author's note: tuned for an RTX 3080)
    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        logging_steps=10,
        num_train_epochs=3,
        save_steps=50,
        learning_rate=5e-6,
        warmup_ratio=0.1,
        max_grad_norm=0.5,  # strict gradient clipping
        save_on_each_node=True,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": True},
        dataloader_pin_memory=False,
        remove_unused_columns=False,
        report_to="none",
        bf16=True,  # matches the dtype the model is loaded with
        save_total_limit=3,  # keep only the 3 most recent checkpoints
        adam_epsilon=1e-8,
        weight_decay=0.01,
    )
    # SwanLab experiment tracking. The API key is read from the environment;
    # credentials must not live in source control.
    # NOTE: the key that was previously committed in this file should be
    # treated as leaked and rotated.
    swanlab_api_key = os.environ.get("SWANLAB_API_KEY")
    if swanlab_api_key:
        swanlab.login(api_key=swanlab_api_key)
    else:
        logger.warning("SWANLAB_API_KEY is not set; relying on an existing SwanLab login")
    swanlab_callback = SwanLabCallback(
        project="QwenLora_Learn",
        experiment_name="Qwen3-8B-LoRA-experiment"
    )
    gradient_monitor = GradientMonitorCallback()
    # 8. Build the trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_preparedataset,
        data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
        callbacks=[swanlab_callback, gradient_monitor]
    )
    # 9. Train
    print("Starting training...")
    logger.info("Starting training...")
    try:
        trainer.train()
        logger.info("Training completed successfully!")
    except Exception as e:
        # logger.exception records the traceback together with the message.
        logger.exception(f"Training failed with error: {e}")
        raise
    # 10. Save the final adapter and tokenizer
    final_output_dir = os.path.join(output_dir, "final_model")
    trainer.save_model(final_output_dir)
    tokenizer.save_pretrained(final_output_dir)
    print(f"Training completed! Model saved to: {final_output_dir}")
    return final_output_dir
def test_trained_model(model_path, lora_path):
    """Load the base model plus LoRA adapter and print sample NPC replies.

    Args:
        model_path: directory or hub id of the base model.
        lora_path: directory containing the trained LoRA adapter.

    For each built-in test prompt the function runs one forward pass to check
    the output distribution for NaN/Inf/negative values, then samples up to
    200 new tokens and prints the decoded reply.
    """
    print("Testing trained model...")
    # Base model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        trust_remote_code=True
    )
    # Attach the LoRA weights
    model = PeftModel.from_pretrained(model, lora_path)
    model.eval()  # make sure dropout is off for the checks below
    # Test conversations
    test_cases = [
        {
            "system": "你是克莱恩,一位神秘学专家和侦探。",
            "user": "请告诉我一些关于神秘学的知识。"
        },
        {
            "system": "你是阿兹克,经验丰富的神秘学导师。",
            "user": "学生遇到了危险,你会给出什么建议?"
        },
        {
            "system": "你是塔利姆,一个有礼貌的普通人,遇到了困难。",
            "user": "你最近怎么样?"
        }
    ]
    for i, case in enumerate(test_cases):
        messages = [
            {"role": "system", "content": case["system"]},
            {"role": "user", "content": case["user"]}
        ]
        encoded = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True,
            enable_thinking=False
        )
        # Move tensors to wherever the model lives instead of assuming 'cuda';
        # this also works on CPU-only hosts and non-default devices.
        inputs = {k: v.to(model.device) for k, v in encoded.items()}
        with torch.no_grad():
            logits = model(**inputs).logits
            probs = torch.softmax(logits, dim=-1)
            # Check for illegal values in the probability tensor
            if torch.isnan(probs).any():
                print("概率张量包含NaN")
            if torch.isinf(probs).any():
                print("概率张量包含Inf")
            if (probs < 0).any():
                print("概率张量包含负数!")
            outputs = model.generate(
                **inputs,
                max_new_tokens=200,
                do_sample=True,
                temperature=0.7,
                top_p=0.8,
                pad_token_id=tokenizer.eos_token_id
            )
        # Keep only the newly generated tokens (strip the prompt)
        response = outputs[0][inputs['input_ids'].shape[1]:]
        decoded_response = tokenizer.decode(response, skip_special_tokens=True)
        print(f"\n--- 测试用例 {i+1} ---")
        print(f"系统提示: {case['system']}")
        print(f"用户输入: {case['user']}")
        print(f"模型回复: {decoded_response}")
def main():
    """Script entry point: train the LoRA adapter, then run the smoke test.

    Prints a hint and returns when the training data file is missing. Any
    exception raised during training or testing is printed with its traceback
    and not propagated.
    """
    # Paths: base model, training data, output directory
    model_path = '/mnt/g/Project02/AITrain/Qwen/Qwen3-8B-AWQ'
    data_path = './npc_dialogue_dataset.json'
    output_dir = './output/NPC_Dialogue_LoRA'

    # To test an already trained adapter only, call
    # test_trained_model(model_path, os.path.join(output_dir, "final_model")).

    if os.path.exists(data_path):
        try:
            # Train, then test the freshly saved model
            final_model_path = train_lora_model(model_path, data_path, output_dir)
            test_trained_model(model_path, final_model_path)
        except Exception as e:
            print(f"训练过程中出现错误: {e}")
            import traceback
            traceback.print_exc()
        return

    # The data file must exist before training can start
    print(f"数据文件不存在: {data_path}")
    print("请先运行 prepare_dialogue_data.py 生成训练数据")


if __name__ == '__main__':
    main()