Project02/AITrain/iterative_lora_training.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
LoRA incremental training script driven by scored dialogue data - simplified version.
Keeps only the core functionality of training the model on that data.
"""
import json
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from peft import LoraConfig, get_peft_model, TaskType
import os
from typing import List, Dict
import numpy as np
from torch.utils.data import Dataset
import swanlab
from swanlab.integration.transformers import SwanLabCallback

class DialogueDataset(Dataset):
    """Dialogue dataset class."""

    def __init__(self, samples: List[Dict], tokenizer, max_length: int = 512):
        self.samples = samples
        self.tokenizer = tokenizer
        self.max_length = max_length
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        # Prompt template kept in Chinese: "角色" = "Role", "对话内容" = "Dialogue content".
        text = f"角色: {sample['speaker']}\n对话内容: {sample['content']}"
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            # For causal LM training the labels are a copy of the input IDs.
            'labels': encoding['input_ids'].flatten().clone()
        }

class IterativeLoRATrainer:
    def __init__(self, config: Dict):
        self.config = config
        self.base_model_path = config['base_model_path']
        self.training_data_path = config['training_data_path']
        self.output_path = config['output_path']
        # Training configuration
        self.training_args = TrainingArguments(
            output_dir=self.output_path,
            overwrite_output_dir=True,
            num_train_epochs=config.get('epochs', 3),
            per_device_train_batch_size=config.get('batch_size', 2),
            gradient_accumulation_steps=config.get('gradient_accumulation_steps', 8),
            learning_rate=config.get('learning_rate', 5e-4),
            weight_decay=config.get('weight_decay', 0.01),
            warmup_steps=config.get('warmup_steps', 100),
            logging_steps=config.get('logging_steps', 10),
            save_steps=config.get('save_steps', 500),
            bf16=config.get('bf16', True),
            remove_unused_columns=False,
            report_to="none"
        )
        # LoRA configuration
        self.lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=config.get('lora_r', 32),
            lora_alpha=config.get('lora_alpha', 64),
            lora_dropout=config.get('lora_dropout', 0.05),
            target_modules=config.get('target_modules', ["q_proj", "v_proj", "k_proj", "o_proj"]),
            bias="none"
        )
        # Fix random seeds for reproducibility
        torch.manual_seed(config.get('seed', 42))
        np.random.seed(config.get('seed', 42))

    def load_training_data(self):
        """Load the training data from the JSON file."""
        with open(self.training_data_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        return data['data']
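    # Expected structure of the training-data JSON, inferred from how this class
    # reads it (an assumption based on the code, not a documented schema):
    # {
    #   "data": [
    #     {
    #       "speaker": "...",
    #       "content": "...",
    #       "scores": {"coherence": 8.0, "character_consistency": 7.5, ...}
    #     },
    #     ...
    #   ]
    # }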

    def prepare_training_samples(self, data):
        """Filter the raw items down to high-quality training samples."""
        samples = []
        quality_threshold = self.config.get('quality_threshold', 7.0)
        for item in data:
            overall_score = self._calculate_overall_score(item.get('scores', {}))
            if overall_score >= quality_threshold:
                sample = {
                    'speaker': item['speaker'],
                    'content': item['content'],
                    'quality_score': overall_score
                }
                samples.append(sample)
        print(f"High-quality samples selected: {len(samples)}")
        return samples

    def _calculate_overall_score(self, scores: Dict) -> float:
        """Compute a weighted overall score from whichever sub-scores are present."""
        if not scores:
            return 0.0
        weights = {
            'coherence': 0.3,
            'character_consistency': 0.25,
            'naturalness': 0.25,
            'information_density': 0.1,
            'creativity': 0.1
        }
        total_score = 0.0
        total_weight = 0.0
        for key, weight in weights.items():
            if key in scores:
                total_score += scores[key] * weight
                total_weight += weight
        # Normalize by the weights that were actually present so that missing
        # sub-scores do not drag the overall score down.
        return total_score / total_weight if total_weight > 0 else 0.0

    def train(self):
        """Run training."""
        print("Starting LoRA incremental training...")
        # Create the output directory
        os.makedirs(self.output_path, exist_ok=True)
        # Load the model and tokenizer
        print(f"Loading base model: {self.base_model_path}")
        tokenizer = AutoTokenizer.from_pretrained(
            self.base_model_path,
            trust_remote_code=True
        )
        model = AutoModelForCausalLM.from_pretrained(
            self.base_model_path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True
        )
        # Set the pad token if the tokenizer does not define one
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            model.config.pad_token_id = tokenizer.eos_token_id
        # Apply LoRA adapters to the base model
        print("Applying LoRA configuration...")
        model = get_peft_model(model, self.lora_config)
        model.print_trainable_parameters()
        # Load and prepare the training data
        training_data = self.load_training_data()
        training_samples = self.prepare_training_samples(training_data)
        if len(training_samples) == 0:
            raise ValueError("No training samples matching the quality criteria were found")
        print(f"Number of training samples: {len(training_samples)}")
        # Build the dataset
        train_dataset = DialogueDataset(
            training_samples,
            tokenizer,
            self.config.get('max_length', 512)
        )
        # Build the data collator (mlm=False selects causal LM batching)
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer,
            mlm=False
        )
        # Add SwanLab experiment tracking
        swanlab_callback = SwanLabCallback(
            project="QwenLora_Learn",
            experiment_name="Qwen3-8B-LoRA-experiment"
        )
        # Read the API key from the environment rather than hard-coding a
        # credential in source (SWANLAB_API_KEY is an assumed variable name).
        swanlab.login(api_key=os.environ.get("SWANLAB_API_KEY"))
        # Create the trainer
        trainer = Trainer(
            model=model,
            args=self.training_args,
            train_dataset=train_dataset,
            data_collator=data_collator,
            tokenizer=tokenizer,
            callbacks=[swanlab_callback]
        )
        # Start training
        print("Starting training...")
        trainer.train()
        # Save the trained model (adapter weights) and tokenizer
        print("Saving the trained model...")
        final_output_dir = os.path.join(self.output_path, "final_model")
        trainer.save_model(final_output_dir)
        tokenizer.save_pretrained(final_output_dir)
        print(f"Training complete; model saved to: {final_output_dir}")

def main():
    config = {
        'base_model_path': '/mnt/g/Project02/AITrain/Qwen/Qwen3-4B',
        'training_data_path': './training_data/high_quality_dialogues_20250823_1819.json',
        'output_path': './output/iterative_lora_simple',
        # Training parameters
        'epochs': 3,
        'batch_size': 1,
        'gradient_accumulation_steps': 16,
        'learning_rate': 5e-6,
        'weight_decay': 0.01,
        'warmup_steps': 100,
        'max_length': 1024,
        'bf16': True,
        # LoRA parameters
        'lora_r': 8,
        'lora_alpha': 8,
        'lora_dropout': 0.05,
        'target_modules': ["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
        # Data parameters
        'quality_threshold': 7.0,
        'logging_steps': 5,
        'save_steps': 100,
        'seed': 42
    }
    trainer = IterativeLoRATrainer(config)
    trainer.train()
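
# Illustrative sketch (not part of the original script): one way the adapter saved
# under <output_path>/final_model could be loaded back for inference with PEFT.
# The helper name and the separate loading step are assumptions for illustration;
# the arguments simply mirror the defaults used in main().
def load_trained_adapter(base_model_path: str, adapter_dir: str):
    """Load the base model and attach the trained LoRA adapter for inference."""
    from peft import PeftModel
    tokenizer = AutoTokenizer.from_pretrained(base_model_path, trust_remote_code=True)
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True
    )
    # Attach the saved LoRA weights on top of the frozen base model
    model = PeftModel.from_pretrained(base_model, adapter_dir)
    model.eval()
    return model, tokenizer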

if __name__ == '__main__':
    main()