diff --git a/AITrain/iterative_lora_training.py b/AITrain/iterative_lora_training.py
new file mode 100644
index 0000000..e12277d
--- /dev/null
+++ b/AITrain/iterative_lora_training.py
@@ -0,0 +1,260 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+Incremental LoRA training script based on scored dialogue data - simplified version
+Keeps only the core functionality of training the model on the data
+"""
+
+import json
+import torch
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    TrainingArguments,
+    Trainer,
+    DataCollatorForLanguageModeling
+)
+from peft import LoraConfig, get_peft_model, TaskType
+import os
+from typing import List, Dict
+import numpy as np
+
+from torch.utils.data import Dataset
+import swanlab
+from swanlab.integration.transformers import SwanLabCallback
+
+class DialogueDataset(Dataset):
+    """Dialogue dataset"""
+
+    def __init__(self, samples: List[Dict], tokenizer, max_length: int = 512):
+        self.samples = samples
+        self.tokenizer = tokenizer
+        self.max_length = max_length
+
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+
+    def __len__(self):
+        return len(self.samples)
+
+    def __getitem__(self, idx):
+        sample = self.samples[idx]
+        # The training text is kept in Chinese ("角色" = role, "对话内容" = dialogue
+        # content) because the model is trained on Chinese NPC dialogue.
+        text = f"角色: {sample['speaker']}\n对话内容: {sample['content']}"
+
+        encoding = self.tokenizer(
+            text,
+            truncation=True,
+            padding='max_length',
+            max_length=self.max_length,
+            return_tensors='pt'
+        )
+
+        # Mask padded positions with -100 so they are ignored by the loss;
+        # otherwise the model is also trained to predict pad/eos tokens.
+        labels = encoding['input_ids'].flatten().clone()
+        labels[encoding['attention_mask'].flatten() == 0] = -100
+
+        return {
+            'input_ids': encoding['input_ids'].flatten(),
+            'attention_mask': encoding['attention_mask'].flatten(),
+            'labels': labels
+        }
+
+class IterativeLoRATrainer:
+    def __init__(self, config: Dict):
+        self.config = config
+        self.base_model_path = config['base_model_path']
+        self.training_data_path = config['training_data_path']
+        self.output_path = config['output_path']
+
+        # Training configuration
+        self.training_args = TrainingArguments(
+            output_dir=self.output_path,
+            overwrite_output_dir=True,
+            num_train_epochs=config.get('epochs', 3),
+            per_device_train_batch_size=config.get('batch_size', 2),
+            gradient_accumulation_steps=config.get('gradient_accumulation_steps', 8),
+            learning_rate=config.get('learning_rate', 5e-4),
+            weight_decay=config.get('weight_decay', 0.01),
+            warmup_steps=config.get('warmup_steps', 100),
+            logging_steps=config.get('logging_steps', 10),
+            save_steps=config.get('save_steps', 500),
+            bf16=config.get('bf16', True),
+            remove_unused_columns=False,
+            report_to="none"
+        )
+
+        # LoRA configuration
+        self.lora_config = LoraConfig(
+            task_type=TaskType.CAUSAL_LM,
+            inference_mode=False,
+            r=config.get('lora_r', 32),
+            lora_alpha=config.get('lora_alpha', 64),
+            lora_dropout=config.get('lora_dropout', 0.05),
+            target_modules=config.get('target_modules', ["q_proj", "v_proj", "k_proj", "o_proj"]),
+            bias="none"
+        )
+
+        torch.manual_seed(config.get('seed', 42))
+        np.random.seed(config.get('seed', 42))
+
+    def load_training_data(self):
+        """Load the training data"""
+        with open(self.training_data_path, 'r', encoding='utf-8') as f:
+            data = json.load(f)
+        return data['data']
+
+    def prepare_training_samples(self, data):
+        """Prepare the training samples"""
+        samples = []
+        quality_threshold = self.config.get('quality_threshold', 7.0)
+
+        for item in data:
+            overall_score = self._calculate_overall_score(item.get('scores', {}))
+
+            if overall_score >= quality_threshold:
+                sample = {
+                    'speaker': item['speaker'],
+                    'content': item['content'],
+                    'quality_score': overall_score
+                }
+                samples.append(sample)
+
+        print(f"Selected {len(samples)} high-quality samples")
+        return samples
+
+    def _calculate_overall_score(self, scores: Dict) -> float:
+        """Compute the weighted overall score"""
+        if not scores:
+            return 0.0
+
+        weights = {
+            'coherence': 0.3,
+            'character_consistency': 0.25,
+            'naturalness': 0.25,
+            'information_density': 0.1,
+            'creativity': 0.1
+        }
+
+        total_score = 0.0
+        total_weight = 0.0
+
+        for key, weight in weights.items():
+            if key in scores:
+                total_score += scores[key] * weight
+                total_weight += weight
+
+        return total_score / total_weight if total_weight > 0 else 0.0
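+
+    # Worked example (hypothetical values): with scores = {'coherence': 8.0,
+    # 'naturalness': 6.0}, only those two weights are present, so
+    #   total_score  = 8.0 * 0.30 + 6.0 * 0.25 = 3.90
+    #   total_weight = 0.30 + 0.25 = 0.55
+    #   overall      = 3.90 / 0.55 ≈ 7.09
+    # Missing dimensions therefore renormalize the weights rather than count as zero.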
+
+    def train(self):
+        """Run training"""
+        print("Starting incremental LoRA training...")
+
+        # Create the output directory
+        os.makedirs(self.output_path, exist_ok=True)
+
+        # Load the model and tokenizer
+        print(f"Loading base model: {self.base_model_path}")
+        tokenizer = AutoTokenizer.from_pretrained(
+            self.base_model_path,
+            trust_remote_code=True
+        )
+
+        model = AutoModelForCausalLM.from_pretrained(
+            self.base_model_path,
+            torch_dtype=torch.bfloat16,
+            device_map="auto",
+            trust_remote_code=True
+        )
+
+        # Set the pad token
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+            model.config.pad_token_id = tokenizer.eos_token_id
+
+        # Apply LoRA
+        print("Applying LoRA configuration...")
+        model = get_peft_model(model, self.lora_config)
+        model.print_trainable_parameters()
+
+        # Load and prepare the training data
+        training_data = self.load_training_data()
+        training_samples = self.prepare_training_samples(training_data)
+
+        if len(training_samples) == 0:
+            raise ValueError("No training samples met the quality threshold")
+
+        print(f"Number of training samples: {len(training_samples)}")
+
+        # Build the dataset
+        train_dataset = DialogueDataset(
+            training_samples,
+            tokenizer,
+            self.config.get('max_length', 512)
+        )
+
+        # Build the data collator
+        data_collator = DataCollatorForLanguageModeling(
+            tokenizer=tokenizer,
+            mlm=False
+        )
+
+        # Add SwanLab experiment tracking; the API key is read from the
+        # SWANLAB_API_KEY environment variable rather than hard-coded.
+        swanlab_callback = SwanLabCallback(
+            project="QwenLora_Learn",
+            experiment_name="Qwen3-8B-LoRA-experiment"
+        )
+        swanlab.login(api_key=os.environ["SWANLAB_API_KEY"])
+
+        # Build the trainer
+        trainer = Trainer(
+            model=model,
+            args=self.training_args,
+            train_dataset=train_dataset,
+            data_collator=data_collator,
+            tokenizer=tokenizer,
+            callbacks=[swanlab_callback]
+        )
+
+        # Start training
+        print("Training...")
+        trainer.train()
+
+        # Save the model
+        print("Saving the trained model...")
+        final_output_dir = os.path.join(self.output_path, "final_model")
+        trainer.save_model(final_output_dir)
+        tokenizer.save_pretrained(final_output_dir)
+
+        print(f"Training complete; model saved to: {final_output_dir}")
+
+def main():
+    config = {
+        'base_model_path': '/mnt/g/Project02/AITrain/Qwen/Qwen3-4B',
+        'training_data_path': './training_data/high_quality_dialogues_20250823_1819.json',
+        'output_path': './output/iterative_lora_simple',
+
+        # Training parameters
+        'epochs': 3,
+        'batch_size': 1,
+        'gradient_accumulation_steps': 16,
+        'learning_rate': 5e-6,
+        'weight_decay': 0.01,
+        'warmup_steps': 100,
+        'max_length': 1024,
+        'bf16': True,
+
+        # LoRA parameters
+        'lora_r': 8,
+        'lora_alpha': 8,
+        'lora_dropout': 0.05,
+        'target_modules': ["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
+
+        # Data parameters
+        'quality_threshold': 7.0,
+        'logging_steps': 5,
+        'save_steps': 100,
+        'seed': 42
+    }
+
+    trainer = IterativeLoRATrainer(config)
+    trainer.train()
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
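For reference, `load_training_data()` expects the file at `training_data_path` to be JSON of roughly the following shape (a minimal sketch inferred from the reads above; the field values are hypothetical, and any of the five score dimensions may be absent, in which case the remaining weights are renormalized):

    {
      "data": [
        {
          "speaker": "村长",
          "content": "欢迎来到我们的村庄,旅行者。",
          "scores": {
            "coherence": 8.5,
            "character_consistency": 8.0,
            "naturalness": 7.5,
            "information_density": 6.5,
            "creativity": 7.0
          }
        }
      ]
    }

Only entries whose weighted overall score reaches `quality_threshold` (7.0 in the config above) are kept as training samples.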
diff --git a/AITrain/main_controller.py b/AITrain/main_controller.py
index 605387a..f02a6d0 100644
--- a/AITrain/main_controller.py
+++ b/AITrain/main_controller.py
@@ -141,7 +141,7 @@ def run_dialogue_system(enableScore: bool, useManualScoring: bool = False):
 
     # Check the model paths
     base_model_path = '/mnt/g/Project02/AITrain/Qwen/Qwen3-4B'
-    lora_model_path = './output/NPC_Dialogue_LoRA/final_model'
+    lora_model_path = './output/iterative_lora_simple/final_model'
 
     if not os.path.exists(base_model_path):
         print(f"✗ Base model path does not exist: {base_model_path}")
@@ -529,175 +529,17 @@ def run_model_optimization():
 
         conv_mgr = ConversationManager("./conversation_data/conversations.db")
 
-        # print("Model optimization options:")
-
-        # print("1. Generate the LoRA training script")
-        # print("2. Run incremental training")
-
-
-        # choice = input("Enter a choice (1-5): ").strip()
-
-        # if choice == '1':
-        # Generate the LoRA training script
-        print("\n=== Generating the LoRA training script ===")
-
-        script_content = generate_lora_training_script()
-        script_path = "./scripts/iterative_lora_training.py"
-
-        os.makedirs("./scripts", exist_ok=True)
-        with open(script_path, 'w', encoding='utf-8') as f:
-            f.write(script_content)
-
-        print(f"✓ LoRA training script generated: {script_path}")
-        print("Usage:")
-        print("  1. Generate the training data first (option 8)")
-        print("  2. Adjust the path configuration in the script")
-        print(f"  3. Run: python {script_path}")
-
-        # elif choice == '2':
-        #     # Run incremental training
-        #     print("\n=== Running incremental training ===")
-
-        #     # Check the training data
-        #     training_dir = "./training_data"
-        #     if not os.path.exists(training_dir):
-        #         print("❌ Training data directory not found; generate training data first (option 8)")
-        #         return
-
-        #     training_files = [f for f in os.listdir(training_dir) if f.endswith('.json')]
-        #     if not training_files:
-        #         print("❌ No training data files found; generate training data first (option 8)")
-        #         return
-
-        #     print(f"Found training data files:")
-        #     for i, file in enumerate(training_files, 1):
-        #         print(f"  {i}. {file}")
-
-        #     file_idx = input(f"Select a training data file (1-{len(training_files)}): ").strip()
-        #     try:
-        #         selected_file = training_files[int(file_idx) - 1]
-        #         training_file_path = os.path.join(training_dir, selected_file)
-
-        #         print(f"Using training file: {selected_file}")
-        #         print("⚠ Note: actual training requires correct model paths and compute resources")
-
-        #         # Generate the training command
-        #         training_command = generate_training_command(training_file_path)
-        #         print(f"Suggested training command:")
-        #         print(f"  {training_command}")
-
-        #         # Optional: run the training (requires user confirmation)
-        #         confirm = input("Run the training now? (y/N): ").strip().lower()
-        #         if confirm == 'y':
-        #             print("Starting incremental training...")
-        #             # The actual training logic could be added here
-        #             print("⚠ The training step must be configured for the local environment")
-
-        #     except (ValueError, IndexError):
-        #         print("❌ Invalid file selection")
-
-        #     else:
-        #         print("❌ Invalid choice")
+        print("\n=== Running the LoRA training script ===")
+        import iterative_lora_training
+        iterative_lora_training.main()
+
     except Exception as e:
         print(f"✗ Model optimization failed: {e}")
         import traceback
         traceback.print_exc()
 
-def generate_lora_training_script():
-    """Generate the LoRA training script"""
-    return '''#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-"""
-Incremental LoRA training script based on scored dialogue data
-Auto-generated - adjust the configuration for your environment
-"""
-
-import json
-import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
-from peft import LoraConfig, get_peft_model, TaskType
-import os
-
-class IterativeLoRATrainer:
-    def __init__(self, base_model_path, training_data_path, output_path):
-        self.base_model_path = base_model_path
-        self.training_data_path = training_data_path
-        self.output_path = output_path
-
-        # LoRA configuration
-        self.lora_config = LoraConfig(
-            task_type=TaskType.CAUSAL_LM,
-            inference_mode=False,
-            r=16,  # LoRA rank
-            lora_alpha=32,
-            lora_dropout=0.1,
-            target_modules=["q_proj", "v_proj", "k_proj", "o_proj"]
-        )
-
-    def load_training_data(self):
-        """Load the training data"""
-        with open(self.training_data_path, 'r', encoding='utf-8') as f:
-            data = json.load(f)
-        return data['data']
-
-    def prepare_training_samples(self, data):
-        """Prepare the training samples"""
-        samples = []
-
-        for item in data:
-            if item.get('label') == 'positive' or item.get('overall_score', 0) >= 8.0:
-                # High-quality sample
-                sample = {
-                    'input': f"角色: {item['speaker']}\\n请生成高质量对话:",
-                    'output': item['content'],
-                    'quality_score': item.get('overall_score', 8.0)
-                }
-                samples.append(sample)
-
-        return samples
-
-    def train(self):
-        """Run training"""
-        print("Starting incremental LoRA training...")
-
-        # Load the model and tokenizer
-        tokenizer = AutoTokenizer.from_pretrained(self.base_model_path)
-        model = AutoModelForCausalLM.from_pretrained(
-            self.base_model_path,
-            torch_dtype=torch.bfloat16,
-            device_map="auto"
-        )
-
-        # Apply LoRA
-        model = get_peft_model(model, self.lora_config)
-
-        # Load the training data
-        training_data = self.load_training_data()
-        training_samples = self.prepare_training_samples(training_data)
-
-        print(f"Number of training samples: {len(training_samples)}")
-
-        # Add the actual training loop here
-        # (the transformers Trainer or a custom training loop is recommended)
-
-        # Save the model
-        model.save_pretrained(self.output_path)
-        tokenizer.save_pretrained(self.output_path)
-
-        print(f"Training complete; model saved to: {self.output_path}")
-
-if __name__ == '__main__':
-    # Configuration - adjust to your environment
-    BASE_MODEL_PATH = '/mnt/e/AI/Project02/AITrain/Qwen/Qwen3-4B'
-    TRAINING_DATA_PATH = './training_data/high_quality_dialogues_latest.json'
-    OUTPUT_PATH = './output/iterative_lora_v2'
-
-    trainer = IterativeLoRATrainer(BASE_MODEL_PATH, TRAINING_DATA_PATH, OUTPUT_PATH)
-    trainer.train()
-'''
-
-
 def generate_training_command(training_file_path):
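The replacement above runs training in-process by importing `iterative_lora_training` and calling its `main()`. A minimal alternative sketch (an assumption, not part of this patch): launching the script in a child interpreter so the GPU memory held by the loaded model is released when training finishes:

    import subprocess
    import sys

    # Run the trainer in its own process; CUDA allocations are freed on exit.
    subprocess.run([sys.executable, "iterative_lora_training.py"], check=True)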
@@ -878,7 +720,7 @@ def main():
         print("0. Exit")
         print("="*50)
 
-        choice = input("Enter a choice (0-11): ").strip()
+        choice = input("Enter a choice (0-9): ").strip()
 
         if choice == '0':
             print("\nThanks for using the dual-AI character dialogue system!")
diff --git a/AITrain/npc_dialogue_generator.py b/AITrain/npc_dialogue_generator.py
index 1289128..925c4a4 100644
--- a/AITrain/npc_dialogue_generator.py
+++ b/AITrain/npc_dialogue_generator.py
@@ -133,9 +133,9 @@ class NPCDialogueGenerator:
         )
 
         # If a LoRA model path is set, load the LoRA weights
-        # if self.lora_model_path:
-        #     print(f"Loading LoRA weights from: {self.lora_model_path}")
-        #     self.model = PeftModel.from_pretrained(self.model, self.lora_model_path)
+        if self.lora_model_path:
+            print(f"Loading LoRA weights from: {self.lora_model_path}")
+            self.model = PeftModel.from_pretrained(self.model, self.lora_model_path)
 
     def generate_character_dialogue(
         self,
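For reference, the adapter loading that this patch re-enables can also be exercised standalone. A minimal sketch (paths taken from the configuration above; `merge_and_unload()` is optional and assumes a reasonably recent peft release):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import PeftModel

    base_path = '/mnt/g/Project02/AITrain/Qwen/Qwen3-4B'
    adapter_path = './output/iterative_lora_simple/final_model'

    tokenizer = AutoTokenizer.from_pretrained(base_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        base_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True
    )
    # Attach the trained adapter; merge_and_unload() folds the LoRA weights
    # into the base model so inference no longer needs the peft wrapper.
    model = PeftModel.from_pretrained(model, adapter_path)
    model = model.merge_and_unload()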