#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Incremental LoRA training script driven by scored dialogue data (simplified version).

Only the core functionality is kept: filter the scored samples and train the model on them.
"""
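# Pipeline overview (see the code below):
#   1. Load scored dialogue data from a JSON file.
#   2. Keep only samples whose weighted overall score passes a quality threshold.
#   3. Fine-tune the base model with LoRA adapters on those samples.
#   4. Save the resulting adapter and tokenizer under <output_path>/final_model.
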
import json
import os
import time
from typing import List, Dict

import numpy as np
import torch
from torch.utils.data import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from peft import LoraConfig, get_peft_model, TaskType

import swanlab
from swanlab.integration.transformers import SwanLabCallback


class DialogueDataset(Dataset):
    """Dialogue dataset that wraps scored samples for causal-LM training."""

    def __init__(self, samples: List[Dict], tokenizer, max_length: int = 512):
        self.samples = samples
        self.tokenizer = tokenizer
        self.max_length = max_length

        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        # The training text keeps the original Chinese template used by the data
        # ("角色" = speaker, "对话内容" = dialogue content).
        text = f"角色: {sample['speaker']}\n对话内容: {sample['content']}"

        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        input_ids = encoding['input_ids'].flatten()
        attention_mask = encoding['attention_mask'].flatten()

        # Use the input ids as labels, but mark padding positions with -100 so they
        # are ignored by the cross-entropy loss.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }

class IterativeLoRATrainer:
    """Runs one round of incremental LoRA training on score-filtered dialogue data."""

    def __init__(self, config: Dict):
        self.config = config
        self.base_model_path = config['base_model_path']
        self.training_data_path = config['training_data_path']
        self.output_path = config['output_path']

        # Training configuration
        self.training_args = TrainingArguments(
            output_dir=self.output_path,
            overwrite_output_dir=True,
            num_train_epochs=config.get('epochs', 3),
            per_device_train_batch_size=config.get('batch_size', 2),
            gradient_accumulation_steps=config.get('gradient_accumulation_steps', 8),
            learning_rate=config.get('learning_rate', 5e-4),
            weight_decay=config.get('weight_decay', 0.01),
            warmup_steps=config.get('warmup_steps', 100),
            logging_steps=config.get('logging_steps', 10),
            save_steps=config.get('save_steps', 500),
            bf16=config.get('bf16', True),
            remove_unused_columns=False,
            report_to="none"
        )

        # LoRA configuration
        self.lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=config.get('lora_r', 32),
            lora_alpha=config.get('lora_alpha', 64),
            lora_dropout=config.get('lora_dropout', 0.05),
            target_modules=config.get('target_modules', ["q_proj", "v_proj", "k_proj", "o_proj"]),
            bias="none"
        )

        # Fix the random seeds for reproducibility
        torch.manual_seed(config.get('seed', 42))
        np.random.seed(config.get('seed', 42))
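
    # Expected layout of the training-data JSON, inferred from load_training_data()
    # and prepare_training_samples() below (the example values are illustrative):
    #
    # {
    #   "data": [
    #     {
    #       "speaker": "...",
    #       "content": "...",
    #       "scores": {
    #         "coherence": 8.5,
    #         "character_consistency": 8.0,
    #         "naturalness": 7.5,
    #         "information_density": 7.0,
    #         "creativity": 6.5
    #       }
    #     }
    #   ]
    # }
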
    def load_training_data(self):
        """Load the scored dialogue data from the JSON file."""
        with open(self.training_data_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        return data['data']

    def prepare_training_samples(self, data):
        """Keep only the samples whose weighted overall score passes the quality threshold."""
        samples = []
        quality_threshold = self.config.get('quality_threshold', 7.0)

        for item in data:
            overall_score = self._calculate_overall_score(item.get('scores', {}))

            if overall_score >= quality_threshold:
                sample = {
                    'speaker': item['speaker'],
                    'content': item['content'],
                    'quality_score': overall_score
                }
                samples.append(sample)

        print(f"Selected {len(samples)} high-quality samples")
        return samples

    def _calculate_overall_score(self, scores: Dict) -> float:
        """Compute a weighted overall score from the individual rating dimensions."""
        if not scores:
            return 0.0

        weights = {
            'coherence': 0.3,
            'character_consistency': 0.25,
            'naturalness': 0.25,
            'information_density': 0.1,
            'creativity': 0.1
        }

        total_score = 0.0
        total_weight = 0.0

        # Only dimensions present in `scores` contribute; the result is renormalized
        # by the total weight actually used.
        for key, weight in weights.items():
            if key in scores:
                total_score += scores[key] * weight
                total_weight += weight

        return total_score / total_weight if total_weight > 0 else 0.0
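
    # Worked example of the weighting above (illustrative values):
    #   scores = {'coherence': 8.0, 'character_consistency': 7.0, 'naturalness': 9.0}
    #   total_score  = 8.0*0.3 + 7.0*0.25 + 9.0*0.25 = 6.4
    #   total_weight = 0.3 + 0.25 + 0.25 = 0.8
    #   overall      = 6.4 / 0.8 = 8.0  -> passes the default threshold of 7.0
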
    def train(self):
        """Run the full training pass: load, filter, fine-tune, save."""
        print("Starting incremental LoRA training...")

        # Create the output directory
        os.makedirs(self.output_path, exist_ok=True)

        # Load the model and tokenizer
        print(f"Loading base model: {self.base_model_path}")
        tokenizer = AutoTokenizer.from_pretrained(
            self.base_model_path,
            trust_remote_code=True
        )

        model = AutoModelForCausalLM.from_pretrained(
            self.base_model_path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True
        )

        # Make sure a pad token is defined
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            model.config.pad_token_id = tokenizer.eos_token_id

        # Apply LoRA
        print("Applying LoRA configuration...")
        model = get_peft_model(model, self.lora_config)
        model.print_trainable_parameters()

        # Load and prepare the training data
        training_data = self.load_training_data()
        training_samples = self.prepare_training_samples(training_data)

        if len(training_samples) == 0:
            raise ValueError("No training samples met the quality threshold")

        print(f"Number of training samples: {len(training_samples)}")

        # Build the dataset
        train_dataset = DialogueDataset(
            training_samples,
            tokenizer,
            self.config.get('max_length', 512)
        )

        # Build the data collator (causal LM, so no masked-LM objective)
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer,
            mlm=False
        )

        # Add SwanLab experiment tracking
        swanlab_callback = SwanLabCallback(
            project="QwenLora_Learn",
            experiment_name="Qwen3-8B-LoRA-experiment"
        )
        # NOTE: avoid hardcoding API keys; prefer `swanlab login` or an environment variable.
        swanlab.login(api_key="pAxFTROvv3aspmEijax46")

        # Create the trainer
        trainer = Trainer(
            model=model,
            args=self.training_args,
            train_dataset=train_dataset,
            data_collator=data_collator,
            tokenizer=tokenizer,
            callbacks=[swanlab_callback]
        )

        # Start training
        print("Starting training...")
        trainer.train()

        # Save the final model
        print("Saving the trained model...")
        final_output_dir = os.path.join(self.output_path, "final_model")
        trainer.save_model(final_output_dir)
        tokenizer.save_pretrained(final_output_dir)

        print(f"Training finished, model saved to: {final_output_dir}")
        # Brief pause before exiting (e.g. to let the tracking callback flush)
        time.sleep(3)
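
# Using the result (a minimal sketch, not part of this script): the adapter saved under
# <output_path>/final_model can be loaded on top of the base model with PEFT, e.g.:
#
#   from peft import PeftModel
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#
#   base = AutoModelForCausalLM.from_pretrained(
#       '/mnt/g/Project02/AITrain/Qwen/Qwen3-4B', trust_remote_code=True)
#   model = PeftModel.from_pretrained(base, './output/iterative_lora_simple/final_model')
#   tokenizer = AutoTokenizer.from_pretrained('./output/iterative_lora_simple/final_model')
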

def main():
    config = {
        'base_model_path': '/mnt/g/Project02/AITrain/Qwen/Qwen3-4B',
        'training_data_path': './training_data/high_quality_dialogues.json',
        'output_path': './output/iterative_lora_simple',

        # Training parameters
        'epochs': 3,
        'batch_size': 1,
        'gradient_accumulation_steps': 16,
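        # Effective batch size per optimizer step = batch_size * gradient_accumulation_steps
        # = 1 * 16 = 16 (per device).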
        'learning_rate': 5e-6,
        'weight_decay': 0.01,
        'warmup_steps': 100,
        'max_length': 1024,
        'bf16': True,

        # LoRA parameters
        'lora_r': 8,
        'lora_alpha': 8,
        'lora_dropout': 0.05,
        'target_modules': ["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],

        # Data parameters
        'quality_threshold': 7.0,
        'logging_steps': 5,
        'save_steps': 100,
        'seed': 42
    }

    trainer = IterativeLoRATrainer(config)
    trainer.train()

if __name__ == '__main__':
    main()