#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
Character-dialogue LoRA fine-tuning script.
Fine-tunes a Qwen 8B model on test.jsonl data to generate in-game NPC dialogue.
'''
import json
import logging
import os
import platform

import numpy as np
import torch
from datasets import Dataset
from peft import LoraConfig, PeftModel, TaskType, get_peft_model
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Trainer,
    TrainerCallback,
    TrainingArguments,
)

import swanlab
from swanlab.integration.transformers import SwanLabCallback

# Windows multiprocessing compatibility fix
if platform.system() == "Windows":
    import multiprocessing

    multiprocessing.set_start_method('spawn', force=True)

os.environ['VLLM_USE_MODELSCOPE'] = 'True'
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"  # synchronous CUDA calls for clearer error traces
os.environ["TORCH_USE_CUDA_DSA"] = "1"    # device-side assertions for CUDA debugging

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('training_debug.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)


class GradientMonitorCallback(TrainerCallback):
    """Gradient-monitoring callback that detects NaN values and exploding gradients."""

    def __init__(self):
        self.step_count = 0

    def on_step_begin(self, args, state, control, model=None, **kwargs):
        """Check parameter state before each training step."""
        self.step_count += 1
        logger.info(f"\n=== Step {self.step_count} Begin ===")
        # Check model parameters for abnormal values
        for name, param in model.named_parameters():
            if param.requires_grad:
                if torch.isnan(param.data).any():
                    logger.error(f"NaN detected in parameter: {name}")
                if torch.isinf(param.data).any():
                    logger.error(f"Inf detected in parameter: {name}")
                # Record parameter statistics
                param_stats = {
                    'min': param.data.min().item(),
                    'max': param.data.max().item(),
                    'mean': param.data.mean().item(),
                    'std': param.data.std().item()
                }
                if abs(param_stats['max']) > 1e6 or abs(param_stats['min']) > 1e6:
                    logger.warning(f"Large parameter values in {name}: {param_stats}")

    def on_step_end(self, args, state, control, model=None, **kwargs):
        """Check gradients after each training step."""
        logger.info(f"=== Step {self.step_count} End ===")
        total_norm = 0.0
        nan_grads = []
        inf_grads = []
        large_grads = []
        for name, param in model.named_parameters():
            if param.requires_grad and param.grad is not None:
                # Check gradients for NaN and Inf
                if torch.isnan(param.grad).any():
                    nan_grads.append(name)
                    logger.error(f"NaN gradient detected in: {name}")
                if torch.isinf(param.grad).any():
                    inf_grads.append(name)
                    logger.error(f"Inf gradient detected in: {name}")
                # Accumulate the squared gradient norm
                param_norm = param.grad.data.norm(2)
                total_norm += param_norm.item() ** 2
                # Flag unusually large gradients
                if param_norm > 10.0:
                    large_grads.append((name, param_norm.item()))
                    logger.warning(f"Large gradient in {name}: {param_norm.item():.6f}")
        total_norm = total_norm ** 0.5
        logger.info(f"Total gradient norm: {total_norm:.6f}")
        # Summary report
        if nan_grads:
            logger.error(f"Parameters with NaN gradients: {nan_grads}")
        if inf_grads:
            logger.error(f"Parameters with Inf gradients: {inf_grads}")
        if large_grads:
            logger.warning(f"Parameters with large gradients: {large_grads}")
        # Record details if the total gradient norm is excessive
        if total_norm > 100.0:
            logger.error(f"Gradient explosion detected! Total norm: {total_norm}")

    def on_log(self, args, state, control, model=None, **kwargs):
        """Inspect the latest training log entry."""
        if hasattr(state, 'log_history') and state.log_history:
            last_log = state.log_history[-1]
            # During training the Trainer logs the loss under 'loss';
            # 'train_loss' only appears in the final summary entry.
            if 'loss' in last_log or 'train_loss' in last_log:
                loss = last_log.get('loss', last_log.get('train_loss'))
                logger.info(f"Current loss: {loss:.6f}")
                if np.isnan(loss):
                    logger.error("Loss is NaN!")
                elif loss > 1e6:
                    logger.error(f"Loss explosion detected: {loss}")


def process_func(example, tokenizer):
    """Preprocess one sample (with numeric-stability checks).

    Returns empty sequences instead of None on failure, because Dataset.map
    cannot drop rows; invalid rows are filtered out afterwards by length.
    """
    MAX_LENGTH = 1024
    empty = {"input_ids": [], "attention_mask": [], "labels": []}

    try:
        # Build the chat template, specialized for character dialogue
        system_prompt = f"你是一个游戏中的NPC角色。{example['character']}"
        instruction = example['instruction']
        user_input = example['input']
        # Validate input data
        if not isinstance(system_prompt, str) or not isinstance(instruction, str):
            logger.warning(f"Invalid input data types: system_prompt={type(system_prompt)}, instruction={type(instruction)}")
            return empty
        # Tokenize the prompt part
        instruction_tokens = tokenizer(
            f"<s><|im_start|>system\n{system_prompt}<|im_end|>\n"
            f"<|im_start|>user\n{instruction + user_input}<|im_end|>\n"
            f"<|im_start|>assistant\n",
            add_special_tokens=False
        )
        # Tokenize the response part
        response_tokens = tokenizer(f"{example['output']}", add_special_tokens=False)
        # Validate the tokenization results
        if not instruction_tokens["input_ids"] or not response_tokens["input_ids"]:
            logger.warning("Empty tokenization result")
            return empty
        # Concatenate prompt and response (pad_token_id == eos_token_id here; see load_model_and_tokenizer)
        input_ids = instruction_tokens["input_ids"] + response_tokens["input_ids"] + [tokenizer.pad_token_id]
        attention_mask = instruction_tokens["attention_mask"] + response_tokens["attention_mask"] + [1]
        # Labels: compute the loss only on the response part
        labels = [-100] * len(instruction_tokens["input_ids"]) + response_tokens["input_ids"] + [tokenizer.pad_token_id]
        # Truncate to MAX_LENGTH
        if len(input_ids) > MAX_LENGTH:
            input_ids = input_ids[:MAX_LENGTH]
            attention_mask = attention_mask[:MAX_LENGTH]
            labels = labels[:MAX_LENGTH]
        # Final consistency check
        if len(input_ids) != len(attention_mask) or len(input_ids) != len(labels):
            logger.error(f"Length mismatch: input_ids={len(input_ids)}, attention_mask={len(attention_mask)}, labels={len(labels)}")
            return empty
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels
        }
    except Exception as e:
        logger.error(f"Error in process_func: {e}")
        return empty
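
# For reference, a hypothetical record with the fields process_func expects
# (the real records live in npc_dialogue_dataset.json):
#
#   {
#       "character": "克莱恩,一位神秘学专家和侦探。",
#       "instruction": "以角色的口吻回答玩家。",
#       "input": "请告诉我一些关于神秘学的知识。",
#       "output": "……(角色台词)"
#   }
#
# Labels are -100 over the prompt tokens, so the loss is computed only on the
# response plus the trailing EOS/pad token.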


def load_model_and_tokenizer(model_path):
    """Load the model and tokenizer."""
    print(f"Loading model from: {model_path}")
    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Load the model
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        trust_remote_code=True
    )
    return model, tokenizer
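
# Note: aliasing pad_token to eos_token above means DataCollatorForSeq2Seq pads
# input_ids with the EOS id (masked out via attention_mask) and pads labels with
# -100, which the loss ignores.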


def create_lora_config():
    """Create the LoRA configuration."""
    config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        target_modules=["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
        inference_mode=False,
        r=8,  # rank
        lora_alpha=8,  # lower alpha for more stability
        lora_dropout=0.05,  # lower dropout to reduce instability
        # modules_to_save removed to avoid NaN issues in the embed_tokens parameters
        # modules_to_save=["lm_head", "embed_tokens"]
    )
    return config
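
# Sanity-check sketch (optional): after get_peft_model(model, create_lora_config()),
# model.print_trainable_parameters() reports trainable vs. total parameters; with
# r=8 on an 8B model the trainable share should be well under 1%.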


def prepare_dataset(data_path, tokenizer):
    """Prepare the dataset (with robustness and numeric-stability checks)."""
    print(f"Loading dataset from: {data_path}")
    logger.info(f"Loading dataset from: {data_path}")
    # Load the JSON data
    with open(data_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    print(f"Total samples before filtering: {len(data)}")
    logger.info(f"Total samples before filtering: {len(data)}")
    # Convert to a Dataset
    dataset = Dataset.from_list(data)
    # Filter out None and empty-string outputs
    dataset = dataset.filter(
        lambda example: example.get("output") not in [None, ""]
    )
    # Apply the preprocessing function
    tokenized_dataset = dataset.map(
        lambda example: process_func(example, tokenizer),
        remove_columns=dataset.column_names,
        batched=False,
        desc="Tokenizing dataset"
    )
    # Drop samples that came back empty from preprocessing
    original_size = len(tokenized_dataset)
    tokenized_dataset = tokenized_dataset.filter(
        lambda example: len(example.get("input_ids", [])) > 0 and
                        len(example.get("labels", [])) > 0
    )
    filtered_size = len(tokenized_dataset)
    print(f"Total samples after filtering: {filtered_size} ({original_size - filtered_size} samples removed)")
    logger.info(f"Total samples after filtering: {filtered_size} ({original_size - filtered_size} samples removed)")
    # Data-quality check
    if filtered_size == 0:
        logger.error("No valid samples remaining after filtering!")
        raise ValueError("Dataset is empty after preprocessing")
    # Spot-check a few samples
    for i in range(min(3, filtered_size)):
        sample = tokenized_dataset[i]
        logger.info(f"Sample {i} - input_ids length: {len(sample['input_ids'])}, labels length: {len(sample['labels'])}")
        # Check for out-of-range token ids (len(tokenizer) includes added special tokens)
        if any(token_id < 0 or token_id >= len(tokenizer) for token_id in sample['input_ids']):
            logger.warning(f"Sample {i} contains out-of-vocabulary tokens")
    return tokenized_dataset
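
# Verification sketch (optional): decoding one processed sample should show the
# system/user/assistant markers in order, e.g.
#   sample = tokenized_dataset[0]
#   print(tokenizer.decode(sample["input_ids"]))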


def train_lora_model(model_path, data_path, output_dir):
    """Train the LoRA model."""
    # 1. Load the model and tokenizer
    model, tokenizer = load_model_and_tokenizer(model_path)
    # 2. Create the LoRA config
    lora_config = create_lora_config()
    # 3. Apply LoRA
    model = get_peft_model(model, lora_config)
    # 4. Numeric-stability initialization of the LoRA weights
    for name, param in model.named_parameters():
        if param.requires_grad:
            if 'lora_A' in name:
                # Initialize the LoRA A matrices from a normal distribution
                torch.nn.init.normal_(param, mean=0.0, std=0.01)
            elif 'lora_B' in name:
                # Initialize the LoRA B matrices to zero
                torch.nn.init.zeros_(param)
            # Re-initialize if any abnormal values remain
            if torch.isnan(param).any() or torch.isinf(param).any():
                logger.error(f"Abnormal values detected in parameter {name} after initialization")
                torch.nn.init.normal_(param, mean=0.0, std=0.001)
    # 5. Enable gradient flow for gradient checkpointing: with the base weights
    # frozen, the embedding outputs must require grads for checkpointing to work
    model.enable_input_require_grads()
    model.config.use_cache = False  # disable KV cache (saves VRAM; incompatible with gradient checkpointing)
    # 6. Prepare the dataset
    train_dataset = prepare_dataset(data_path, tokenizer)

    # 7. Training arguments, tuned for an RTX 3080
    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=2,  # smaller batch size
        gradient_accumulation_steps=4,  # more gradient accumulation
        logging_steps=10,
        num_train_epochs=3,  # more epochs so the model fully learns the character styles
        save_steps=50,
        learning_rate=5e-6,  # further reduced learning rate
        warmup_ratio=0.1,
        max_grad_norm=0.5,  # stricter gradient clipping
        save_on_each_node=True,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": True},
        dataloader_pin_memory=False,  # reduce memory usage
        remove_unused_columns=False,
        report_to="none",
        bf16=True,  # explicitly enable bf16 to match the model's load dtype
        # fp16=False,  # make sure fp16 stays disabled
        save_total_limit=3,  # keep only the 3 newest checkpoints
        adam_epsilon=1e-8,  # numeric stability
        weight_decay=0.01,  # weight decay
    )
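    # Effective batch size: per_device_train_batch_size (2) x
    # gradient_accumulation_steps (4) = 8 samples per optimizer update on one GPU.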
    # Add SwanLab monitoring
    swanlab_callback = SwanLabCallback(
        project="QwenLora_Learn",
        experiment_name="Qwen3-8B-LoRA-experiment"
    )
    # Read the API key from the environment rather than hardcoding a secret here
    swanlab.login(api_key=os.environ.get("SWANLAB_API_KEY", ""))
    # Create the gradient-monitoring callback
    gradient_monitor = GradientMonitorCallback()
    # 8. Create the trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
        callbacks=[swanlab_callback, gradient_monitor]  # include the gradient monitor
    )
    # 9. Start training
    print("Starting training...")
    logger.info("Starting training...")
    try:
        trainer.train()
        logger.info("Training completed successfully!")
    except Exception as e:
        logger.error(f"Training failed with error: {e}")
        import traceback
        logger.error(traceback.format_exc())
        raise
    # 10. Save the final model
    final_output_dir = os.path.join(output_dir, "final_model")
    trainer.save_model(final_output_dir)
    tokenizer.save_pretrained(final_output_dir)
    print(f"Training completed! Model saved to: {final_output_dir}")
    return final_output_dir
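
# Deployment sketch (optional, assumed workflow): for inference without the PEFT
# wrapper, the adapter can be merged into the base weights, e.g.
#   base = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16)
#   merged = PeftModel.from_pretrained(base, final_output_dir).merge_and_unload()
#   merged.save_pretrained("./output/merged")
# Note: merging into AWQ-quantized base weights (as used in main()) may not be
# supported, so merge against a full-precision checkpoint if needed.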


def test_trained_model(model_path, lora_path):
    """Test the fine-tuned model."""
    print("Testing trained model...")
    # Load the base model
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        trust_remote_code=True
    )
    # Load the LoRA weights
    model = PeftModel.from_pretrained(model, lora_path)
    # Dialogue test cases (kept in Chinese to match the training data)
    test_cases = [
        {
            "system": "你是克莱恩,一位神秘学专家和侦探。",
            "user": "请告诉我一些关于神秘学的知识。"
        },
        {
            "system": "你是阿兹克,经验丰富的神秘学导师。",
            "user": "学生遇到了危险,你会给出什么建议?"
        },
        {
            "system": "你是塔利姆,一个有礼貌的普通人,遇到了困难。",
            "user": "你最近怎么样?"
        }
    ]
    for i, case in enumerate(test_cases):
        messages = [
            {"role": "system", "content": case["system"]},
            {"role": "user", "content": case["user"]}
        ]
        inputs = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True,
            enable_thinking=False
        )
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        with torch.no_grad():
            logits = model(**inputs).logits
            probs = torch.softmax(logits, dim=-1)
            # Check for invalid values
            if torch.isnan(probs).any():
                print("Probability tensor contains NaN")
            if torch.isinf(probs).any():
                print("Probability tensor contains Inf")
            if (probs < 0).any():
                print("Probability tensor contains negative values!")
        outputs = model.generate(
            **inputs,
            max_new_tokens=200,
            do_sample=True,
            temperature=0.7,
            top_p=0.8,
            pad_token_id=tokenizer.eos_token_id
        )
        # Strip the prompt tokens and decode only the newly generated response
        response = outputs[0][inputs['input_ids'].shape[1]:]
        decoded_response = tokenizer.decode(response, skip_special_tokens=True)
        print(f"\n--- Test case {i+1} ---")
        print(f"System prompt: {case['system']}")
        print(f"User input: {case['user']}")
        print(f"Model reply: {decoded_response}")


def main():
    # Paths
    model_path = '/mnt/g/Project02/AITrain/Qwen/Qwen3-8B-AWQ'  # base model path
    data_path = './npc_dialogue_dataset.json'  # training data path
    output_dir = './output/NPC_Dialogue_LoRA'  # output directory

    # ##### test only:
    # final_model_path = os.path.join(output_dir, "final_model")
    # test_trained_model(model_path, final_model_path)

    # Make sure the data file exists
    if not os.path.exists(data_path):
        print(f"Data file not found: {data_path}")
        print("Run prepare_dialogue_data.py first to generate the training data")
        return
    try:
        # Train the model
        final_model_path = train_lora_model(model_path, data_path, output_dir)
        # Test the model
        test_trained_model(model_path, final_model_path)
    except Exception as e:
        print(f"Error during training: {e}")
        import traceback
        traceback.print_exc()


if __name__ == '__main__':
    main()