#!/usr/bin/env python # -*- coding: utf-8 -*- ''' 双AI角色对话系统主控制程序 完整的工作流程:PDF处理 -> 角色加载 -> RAG对话 -> 历史记录 ''' import os import sys import shutil from typing import List, Dict import json def check_dependencies(): """检查依赖库""" missing_deps = [] try: import PyPDF2 except ImportError: missing_deps.append("PyPDF2") try: import pymupdf print("✓ pymupdf 可用") except ImportError: print("⚠ pymupdf 不可用,将使用 PyPDF2") try: import sentence_transformers import faiss print("✓ 向量化功能可用") except ImportError: print("⚠ 向量化功能不可用,将使用文本匹配") if missing_deps: print(f"✗ 缺少依赖库: {', '.join(missing_deps)}") print("请运行: pip install PyPDF2 sentence-transformers faiss-cpu") return False return True def setup_directories(): """设置项目目录结构""" directories = [ "./knowledge_base", "./characters", "./worldview", "./rag_knowledge", "./conversation_data" ] for dir_path in directories: os.makedirs(dir_path, exist_ok=True) print(f"✓ 目录就绪: {dir_path}") def copy_demo_files(): """复制演示文档到知识库目录""" file_mappings = [ ("./worldview/worldview_template_coc.json", "./knowledge_base/worldview_template_coc.json"), ("./characters/character_template_detective.json", "./knowledge_base/character_template_detective.json"), ("./characters/character_template_professor.json", "./knowledge_base/character_template_professor.json") ] for source, target in file_mappings: if os.path.exists(source): shutil.copy2(source, target) print(f"✓ 复制文档: {os.path.basename(target)}") def process_pdf_workflow(): """PDF处理工作流""" print("\n" + "="*60) print("PDF世界观文档处理") print("="*60) from pdf_to_rag_processor import PDFToRAGProcessor pdf_path = input("请输入PDF文件路径 (例: ./coc.pdf): ").strip() if not os.path.exists(pdf_path): print(f"✗ 文件不存在: {pdf_path}") return False try: processor = PDFToRAGProcessor() result = processor.process_pdf_to_rag(pdf_path, "./rag_knowledge") print(f"\n✓ PDF处理完成!") print(f" - 文档块数: {result['chunks_count']}") print(f" - 概念数: {result['concepts_count']}") print(f" - 向量索引: {'启用' if result['vector_enabled'] else '未启用'}") return True except Exception as e: print(f"✗ PDF处理失败: {e}") return False def show_character_info(): """显示角色信息""" print("\n" + "="*60) print("角色设定信息") print("="*60) knowledge_dir = "./knowledge_base" character_files = [f for f in os.listdir(knowledge_dir) if f.startswith('character') and f.endswith('.json')] for char_file in character_files: try: with open(os.path.join(knowledge_dir, char_file), 'r', encoding='utf-8') as f: char_data = json.load(f) name = char_data.get('character_name', '未知') occupation = char_data.get('basic_info', {}).get('occupation', '未知') traits = char_data.get('personality', {}).get('core_traits', []) print(f"\n角色: {name}") print(f" 职业: {occupation}") print(f" 特点: {', '.join(traits[:3])}") except Exception as e: print(f"✗ 读取角色文件失败: {char_file} - {e}") def run_dialogue_system(): """运行双AI对话系统""" print("\n" + "="*60) print("启动双AI角色对话系统") print("="*60) try: # 直接启动双模型对话 print("\n正在初始化双模型对话系统...") from dual_ai_dialogue_system import RAGKnowledgeBase, ConversationManager, DualAIDialogueEngine from npc_dialogue_generator import DualModelDialogueGenerator # 初始化组件 kb = RAGKnowledgeBase("./knowledge_base") conv_mgr = ConversationManager("./conversation_data/conversations.db") # 检查模型路径 base_model_path = '/mnt/g/Project02/AITrain/Qwen/Qwen3-4B' lora_model_path = './output/NPC_Dialogue_LoRA/final_model' if not os.path.exists(base_model_path): print(f"✗ 基础模型路径不存在: {base_model_path}") print("请修改 main_controller.py 中的模型路径") return if not os.path.exists(lora_model_path): lora_model_path = None print("⚠ LoRA模型不存在,使用基础模型") # 检查角色数据 if not hasattr(kb, 'character_data') or len(kb.character_data) < 2: print("✗ 角色数据不足,无法创建双模型对话系统") print("请确保knowledge_base目录中有至少两个角色文件") return # 获取前两个角色 character_names = list(kb.character_data.keys())[:2] char1_name = character_names[0] char2_name = character_names[1] print(f"✓ 使用角色: {char1_name} 和 {char2_name}") # 配置两个角色的模型 character1_config = { "name": char1_name, "lora_path": lora_model_path, "character_data": kb.character_data[char1_name] } character2_config = { "name": char2_name, "lora_path": lora_model_path, "character_data": kb.character_data[char2_name] } # 创建双模型对话生成器 print("正在初始化双模型对话生成器...") dual_generator = DualModelDialogueGenerator( base_model_path, character1_config, character2_config ) # 创建对话引擎(启用评分功能) dialogue_engine = DualAIDialogueEngine( kb, conv_mgr, dual_generator, enable_scoring=True, base_model_path=base_model_path ) # 创建对话会话 characters = [char1_name, char2_name] worldview = kb.worldview_data.get('worldview_name', '未知世界观') if kb.worldview_data else '未知世界观' session_id = conv_mgr.create_session(characters, worldview) print(f"✓ 创建对话会话: {session_id}") # 交互式对话循环 print(f"\n=== 双AI模型对话系统 ===") print(f"角色: {char1_name} vs {char2_name}") print(f"世界观: {worldview}") print("输入 'quit' 退出对话") print("-" * 50) while True: try: # 获取用户输入 user_input = input("\n请输入对话主题或指令: ").strip() if user_input.lower() == 'quit': print("退出双AI对话系统") break if not user_input: print("请输入有效的对话主题") continue # 询问对话轮数 turns_input = input("请输入对话轮数 (默认4): ").strip() turns = int(turns_input) if turns_input.isdigit() else 4 # 询问历史上下文设置 history_input = input("使用历史对话轮数 (默认2): ").strip() history_count = int(history_input) if history_input.isdigit() else 2 context_input = input("使用上下文信息数量 (默认10): ").strip() context_info_count = int(context_input) if context_input.isdigit() else 10 print(f"\n开始对话 - 主题: {user_input}") print(f"轮数: {turns}, 历史: {history_count}, 上下文: {context_info_count}") print("-" * 50) # 运行双模型对话 dialogue_engine.run_dual_model_conversation( session_id, user_input, turns, history_count, context_info_count ) print("-" * 50) print("对话完成!") except KeyboardInterrupt: print("\n\n用户中断对话") break except Exception as e: print(f"对话过程中出现错误: {e}") import traceback traceback.print_exc() except Exception as e: print(f"✗ 对话系统启动失败: {e}") import traceback traceback.print_exc() def create_demo_scenario(): """创建演示场景""" print("\n创建演示对话场景...") try: from dual_ai_dialogue_system import RAGKnowledgeBase, ConversationManager, DualAIDialogueEngine from npc_dialogue_generator import NPCDialogueGenerator # 初始化组件 kb = RAGKnowledgeBase("./knowledge_base") conv_mgr = ConversationManager("./conversation_data/demo_conversations.db") # 检查模型路径 base_model_path = '/mnt/e/AI/Project02/AITrain/Qwen/Qwen3-4B' lora_model_path = './output/NPC_Dialogue_LoRA/final_model' if not os.path.exists(base_model_path): print(f"✗ 基础模型路径不存在: {base_model_path}") print("请修改 main_controller.py 中的模型路径") return if not os.path.exists(lora_model_path): lora_model_path = None print("⚠ LoRA模型不存在,使用基础模型") llm_generator = NPCDialogueGenerator(base_model_path, lora_model_path, kb.character_data) dialogue_engine = DualAIDialogueEngine(kb, conv_mgr, llm_generator) # 创建演示对话 characters = ["维多利亚·布莱克伍德", "阿奇博尔德·韦恩"] worldview = "克苏鲁的呼唤" session_id = conv_mgr.create_session(characters, worldview) print(f"✓ 创建演示会话: {session_id}") # 运行几轮对话 topic = "最近发生的神秘事件" print(f"\n开始演示对话 - 主题: {topic}") print("-" * 40) # 演示不同的历史上下文设置 # print("演示1: 使用默认上下文设置(历史3轮,信息2个)") # dialogue_engine.run_conversation_turn(session_id, characters, 6, topic) session_id = conv_mgr.create_session(characters, worldview) print(f"✓ 创建演示会话: {session_id}") print("\n演示3: 使用最少历史上下文(历史1轮,信息1个)") dialogue_engine.run_conversation_turn(session_id, characters, 6, topic, 1, 10) session_id = conv_mgr.create_session(characters, worldview) print(f"✓ 创建演示会话: {session_id}") print("\n演示2: 使用更多历史上下文(历史10轮,信息10个)") dialogue_engine.run_conversation_turn(session_id, characters, 6, topic, 5, 10) print(f"\n✓ 演示完成!会话ID: {session_id}") print("你可以通过主对话系统继续这个对话") except Exception as e: print(f"✗ 演示场景创建失败: {e}") import traceback traceback.print_exc() def analyze_model_performance(): """分析模型性能""" print("\n" + "="*60) print("模型性能分析") print("="*60) try: from dual_ai_dialogue_system import ConversationManager import sqlite3 import json from datetime import datetime, timedelta conv_mgr = ConversationManager("./conversation_data/conversations.db") with sqlite3.connect(conv_mgr.db_path) as conn: print("\n1. 总体性能趋势分析:") # 按时间段分析性能趋势 cursor = conn.execute(""" SELECT DATE(timestamp) as date, COUNT(*) as dialogue_count, AVG(dialogue_score) as avg_score, AVG(CASE WHEN dialogue_score >= 8.0 THEN 1.0 ELSE 0.0 END) as high_quality_rate FROM dialogue_turns WHERE dialogue_score > 0 AND timestamp >= datetime('now', '-7 days') GROUP BY DATE(timestamp) ORDER BY date DESC """) trend_data = cursor.fetchall() if trend_data: print(f" 最近7天性能趋势:") for date, count, avg_score, hq_rate in trend_data: print(f" {date}: 平均{avg_score:.2f}分 ({count}轮对话, {hq_rate*100:.1f}%高质量)") else: print(" 暂无足够数据进行趋势分析") print("\n2. 维度问题分析:") # 分析各维度的问题 cursor = conn.execute(""" SELECT score_details FROM dialogue_turns WHERE dialogue_score > 0 AND score_details != '{}' ORDER BY timestamp DESC LIMIT 100 """) dimension_scores = { 'coherence': [], 'character_consistency': [], 'naturalness': [], 'information_density': [], 'creativity': [] } for (score_details,) in cursor.fetchall(): try: scores = json.loads(score_details) for dim, score in scores.items(): if dim in dimension_scores: dimension_scores[dim].append(float(score)) except: continue dimension_names = { 'coherence': '连贯性', 'character_consistency': '角色一致性', 'naturalness': '自然度', 'information_density': '信息密度', 'creativity': '创意性' } weak_dimensions = [] for dim, scores in dimension_scores.items(): if scores: avg_score = sum(scores) / len(scores) print(f" {dimension_names[dim]}: 平均{avg_score:.2f}分 ({len(scores)}个样本)") if avg_score < 7.0: weak_dimensions.append(dim) if weak_dimensions: print(f"\n ⚠ 发现薄弱维度: {[dimension_names[d] for d in weak_dimensions]}") print(" 建议进行针对性优化训练") print("\n3. 角色表现分析:") # 分析不同角色的表现 cursor = conn.execute(""" SELECT speaker, COUNT(*) as dialogue_count, AVG(dialogue_score) as avg_score, MIN(dialogue_score) as min_score, MAX(dialogue_score) as max_score, AVG(CASE WHEN dialogue_score >= 8.0 THEN 1.0 ELSE 0.0 END) as high_quality_rate FROM dialogue_turns WHERE dialogue_score > 0 GROUP BY speaker ORDER BY avg_score DESC """) character_performance = cursor.fetchall() if character_performance: print(" 角色表现排名:") for i, (speaker, count, avg, min_s, max_s, hq_rate) in enumerate(character_performance, 1): status = "✓" if avg >= 7.5 else "⚠" if avg >= 6.5 else "✗" print(f" {i}. {speaker} {status}") print(f" 平均{avg:.2f}分 (范围{min_s:.1f}-{max_s:.1f}, {hq_rate*100:.1f}%高质量, {count}轮)") print("\n4. 问题模式识别:") # 识别低分对话的常见问题 cursor = conn.execute(""" SELECT content, dialogue_score, score_feedback FROM dialogue_turns WHERE dialogue_score > 0 AND dialogue_score < 6.0 ORDER BY dialogue_score ASC LIMIT 5 """) low_score_examples = cursor.fetchall() if low_score_examples: print(" 低分对话示例:") for i, (content, score, feedback) in enumerate(low_score_examples, 1): print(f" {i}. 分数{score:.1f}: {content[:50]}...") if feedback: print(f" 问题: {feedback[:80]}...") else: print(" 暂无低分对话样本") print("\n5. 优化建议:") # 生成优化建议 suggestions = [] if weak_dimensions: if 'character_consistency' in weak_dimensions: suggestions.append("• 加强角色设定训练,增加角色特征描述的权重") if 'creativity' in weak_dimensions: suggestions.append("• 增加创意性训练数据,提高对话的趣味性") if 'coherence' in weak_dimensions: suggestions.append("• 优化上下文理解,加强对话逻辑连贯性") if 'naturalness' in weak_dimensions: suggestions.append("• 增加自然语言训练,改善表达流畅度") if 'information_density' in weak_dimensions: suggestions.append("• 优化信息组织,避免冗余表达") # 检查是否需要数据收集 cursor = conn.execute("SELECT COUNT(*) FROM dialogue_turns WHERE dialogue_score > 0") total_scored = cursor.fetchone()[0] if total_scored < 50: suggestions.append("• 需要收集更多评分数据以进行准确分析") if total_scored >= 100: suggestions.append("• 数据量充足,建议开始模型迭代优化") if suggestions: for suggestion in suggestions: print(f" {suggestion}") else: print(" 当前性能表现良好,继续保持!") except Exception as e: print(f"✗ 性能分析失败: {e}") import traceback traceback.print_exc() def generate_training_dataset(): """生成训练数据集""" print("\n" + "="*60) print("生成训练数据集") print("="*60) try: from dual_ai_dialogue_system import ConversationManager import sqlite3 import json import os from datetime import datetime conv_mgr = ConversationManager("./conversation_data/conversations.db") # 创建输出目录 output_dir = "./training_data" os.makedirs(output_dir, exist_ok=True) print("请选择训练数据生成类型:") print("1. 高质量对话数据集 (分数≥8.0)") print("2. 问题对话改进数据集 (分数<6.0)") print("3. 角色一致性训练集") print("4. 创意性增强训练集") print("5. 完整对话质量数据集") choice = input("请输入选择 (1-5): ").strip() with sqlite3.connect(conv_mgr.db_path) as conn: training_data = [] if choice == '1': # 高质量对话数据集 print("\n生成高质量对话数据集...") cursor = conn.execute(""" SELECT speaker, content, score_details, score_feedback FROM dialogue_turns WHERE dialogue_score >= 8.0 ORDER BY dialogue_score DESC LIMIT 200 """) for speaker, content, score_details, feedback in cursor.fetchall(): training_data.append({ "type": "high_quality", "speaker": speaker, "content": content, "scores": json.loads(score_details) if score_details else {}, "feedback": feedback, "label": "positive" }) output_file = f"{output_dir}/high_quality_dialogues_{datetime.now().strftime('%Y%m%d_%H%M')}.json" elif choice == '2': # 问题对话改进数据集 print("\n生成问题对话改进数据集...") cursor = conn.execute(""" SELECT speaker, content, score_details, score_feedback FROM dialogue_turns WHERE dialogue_score < 6.0 AND dialogue_score > 0 ORDER BY dialogue_score ASC LIMIT 100 """) for speaker, content, score_details, feedback in cursor.fetchall(): # 为每个低分对话生成改进指导 improvement_prompt = generate_improvement_prompt(content, feedback) training_data.append({ "type": "improvement", "speaker": speaker, "original_content": content, "scores": json.loads(score_details) if score_details else {}, "problems": feedback, "improvement_prompt": improvement_prompt, "label": "needs_improvement" }) output_file = f"{output_dir}/improvement_dialogues_{datetime.now().strftime('%Y%m%d_%H%M')}.json" elif choice == '3': # 角色一致性训练集 print("\n生成角色一致性训练集...") cursor = conn.execute(""" SELECT speaker, content, score_details FROM dialogue_turns WHERE dialogue_score > 0 ORDER BY json_extract(score_details, '$.character_consistency') DESC LIMIT 150 """) for speaker, content, score_details in cursor.fetchall(): scores = json.loads(score_details) if score_details else {} char_consistency = scores.get('character_consistency', 0) training_data.append({ "type": "character_consistency", "speaker": speaker, "content": content, "character_consistency_score": char_consistency, "label": "high_consistency" if char_consistency >= 8.0 else "medium_consistency" }) output_file = f"{output_dir}/character_consistency_{datetime.now().strftime('%Y%m%d_%H%M')}.json" elif choice == '4': # 创意性增强训练集 print("\n生成创意性增强训练集...") cursor = conn.execute(""" SELECT speaker, content, score_details FROM dialogue_turns WHERE dialogue_score > 0 ORDER BY json_extract(score_details, '$.creativity') DESC LIMIT 150 """) for speaker, content, score_details in cursor.fetchall(): scores = json.loads(score_details) if score_details else {} creativity = scores.get('creativity', 0) training_data.append({ "type": "creativity", "speaker": speaker, "content": content, "creativity_score": creativity, "label": "high_creativity" if creativity >= 8.0 else "medium_creativity" }) output_file = f"{output_dir}/creativity_enhancement_{datetime.now().strftime('%Y%m%d_%H%M')}.json" elif choice == '5': # 完整对话质量数据集 print("\n生成完整对话质量数据集...") cursor = conn.execute(""" SELECT speaker, content, dialogue_score, score_details, score_feedback FROM dialogue_turns WHERE dialogue_score > 0 ORDER BY timestamp DESC LIMIT 300 """) for speaker, content, score, score_details, feedback in cursor.fetchall(): training_data.append({ "type": "complete_dataset", "speaker": speaker, "content": content, "overall_score": score, "dimension_scores": json.loads(score_details) if score_details else {}, "feedback": feedback, "quality_label": get_quality_label(score) }) output_file = f"{output_dir}/complete_quality_dataset_{datetime.now().strftime('%Y%m%d_%H%M')}.json" else: print("❌ 无效选择") return if training_data: # 保存训练数据 with open(output_file, 'w', encoding='utf-8') as f: json.dump({ "metadata": { "created_at": datetime.now().isoformat(), "total_samples": len(training_data), "data_type": choice, "source": "dialogue_scoring_system" }, "data": training_data }, f, ensure_ascii=False, indent=2) print(f"\n✓ 训练数据集生成成功!") print(f" - 文件路径: {output_file}") print(f" - 数据条数: {len(training_data)}") print(f" - 数据类型: {get_dataset_description(choice)}") # 生成数据集统计信息 generate_dataset_statistics(training_data, choice) else: print("❌ 未找到符合条件的数据") except Exception as e: print(f"✗ 训练数据集生成失败: {e}") import traceback traceback.print_exc() def generate_improvement_prompt(content, feedback): """生成改进提示""" return f"""原对话: {content} 问题分析: {feedback} 改进要求: 1. 保持角色特征和设定 2. 增强对话的逻辑性和连贯性 3. 提升表达的自然度 4. 增加信息密度,避免冗余 5. 适当增加创意元素 请生成一个改进版本的对话。""" def get_quality_label(score): """根据分数获取质量标签""" if score >= 9.0: return "excellent" elif score >= 8.0: return "good" elif score >= 7.0: return "average" elif score >= 6.0: return "below_average" else: return "poor" def get_dataset_description(choice): """获取数据集描述""" descriptions = { '1': "高质量对话数据集", '2': "问题对话改进数据集", '3': "角色一致性训练集", '4': "创意性增强训练集", '5': "完整对话质量数据集" } return descriptions.get(choice, "未知类型") def generate_dataset_statistics(training_data, data_type): """生成数据集统计信息""" print(f"\n数据集统计信息:") if data_type == '1': # 高质量数据集 speakers = {} for item in training_data: speaker = item['speaker'] speakers[speaker] = speakers.get(speaker, 0) + 1 print(f" - 角色分布:") for speaker, count in speakers.items(): print(f" • {speaker}: {count}条对话") elif data_type == '5': # 完整数据集 quality_dist = {} for item in training_data: label = item['quality_label'] quality_dist[label] = quality_dist.get(label, 0) + 1 print(f" - 质量分布:") for label, count in quality_dist.items(): print(f" • {label}: {count}条对话") def run_model_optimization(): """运行模型迭代优化""" print("\n" + "="*60) print("模型迭代优化") print("="*60) try: from dual_ai_dialogue_system import ConversationManager import sqlite3 import json import os from datetime import datetime conv_mgr = ConversationManager("./conversation_data/conversations.db") print("模型优化选项:") print("1. 分析优化需求") print("2. 生成LoRA训练脚本") print("3. 创建提示优化配置") print("4. 执行增量训练") print("5. 性能对比验证") choice = input("请输入选择 (1-5): ").strip() if choice == '1': # 分析优化需求 print("\n=== 优化需求分析 ===") with sqlite3.connect(conv_mgr.db_path) as conn: # 获取性能数据 cursor = conn.execute(""" SELECT COUNT(*) as total, AVG(dialogue_score) as avg_score, AVG(CASE WHEN dialogue_score >= 8.0 THEN 1.0 ELSE 0.0 END) as high_quality_rate FROM dialogue_turns WHERE dialogue_score > 0 """) total, avg_score, hq_rate = cursor.fetchone() print(f"当前性能指标:") print(f" - 总评分样本: {total}") print(f" - 平均分数: {avg_score:.2f}") print(f" - 高质量率: {hq_rate*100:.1f}%") # 分析优化潜力 optimization_needs = [] if avg_score < 7.0: optimization_needs.append("整体质量需要提升") if hq_rate < 0.6: optimization_needs.append("高质量对话比例偏低") # 分析各维度表现 cursor = conn.execute(""" SELECT score_details FROM dialogue_turns WHERE dialogue_score > 0 AND score_details != '{}' ORDER BY timestamp DESC LIMIT 100 """) dim_scores = {'coherence': [], 'character_consistency': [], 'naturalness': [], 'information_density': [], 'creativity': []} for (score_details,) in cursor.fetchall(): try: scores = json.loads(score_details) for dim, score in scores.items(): if dim in dim_scores: dim_scores[dim].append(float(score)) except: continue weak_dimensions = [] print(f"\n维度分析:") for dim, scores in dim_scores.items(): if scores: avg = sum(scores) / len(scores) print(f" - {dim}: {avg:.2f}分") if avg < 7.0: weak_dimensions.append(dim) if weak_dimensions: optimization_needs.append(f"薄弱维度: {weak_dimensions}") print(f"\n优化建议:") if optimization_needs: for i, need in enumerate(optimization_needs, 1): print(f" {i}. {need}") else: print(" 当前模型表现良好,可考虑细微调优") elif choice == '2': # 生成LoRA训练脚本 print("\n=== 生成LoRA训练脚本 ===") script_content = generate_lora_training_script() script_path = "./scripts/iterative_lora_training.py" os.makedirs("./scripts", exist_ok=True) with open(script_path, 'w', encoding='utf-8') as f: f.write(script_content) print(f"✓ LoRA训练脚本已生成: {script_path}") print("使用方法:") print(" 1. 先运行训练数据生成 (选项8)") print(" 2. 修改脚本中的路径配置") print(f" 3. 运行: python {script_path}") elif choice == '3': # 创建提示优化配置 print("\n=== 创建提示优化配置 ===") config = generate_prompt_optimization_config() config_path = "./config/prompt_optimization.json" os.makedirs("./config", exist_ok=True) with open(config_path, 'w', encoding='utf-8') as f: json.dump(config, f, ensure_ascii=False, indent=2) print(f"✓ 提示优化配置已生成: {config_path}") print("配置包含:") print(" - 动态提示调整规则") print(" - 质量阈值设置") print(" - 生成参数优化") elif choice == '4': # 执行增量训练 print("\n=== 执行增量训练 ===") # 检查训练数据 training_dir = "./training_data" if not os.path.exists(training_dir): print("❌ 训练数据目录不存在,请先生成训练数据 (选项8)") return training_files = [f for f in os.listdir(training_dir) if f.endswith('.json')] if not training_files: print("❌ 未找到训练数据文件,请先生成训练数据 (选项8)") return print(f"找到训练数据文件:") for i, file in enumerate(training_files, 1): print(f" {i}. {file}") file_idx = input(f"选择训练数据文件 (1-{len(training_files)}): ").strip() try: selected_file = training_files[int(file_idx) - 1] training_file_path = os.path.join(training_dir, selected_file) print(f"将使用训练文件: {selected_file}") print("⚠ 注意:实际训练需要配置正确的模型路径和计算资源") # 生成训练命令 training_command = generate_training_command(training_file_path) print(f"建议训练命令:") print(f" {training_command}") # 可选:执行训练(需要用户确认) confirm = input("是否现在执行训练?(y/N): ").strip().lower() if confirm == 'y': print("开始增量训练...") # 这里可以添加实际的训练执行逻辑 print("⚠ 训练功能需要根据实际环境配置") except (ValueError, IndexError): print("❌ 无效的文件选择") elif choice == '5': # 性能对比验证 print("\n=== 性能对比验证 ===") print("验证选项:") print("1. 历史性能趋势对比") print("2. A/B测试配置生成") print("3. 模型版本性能对比") verify_choice = input("请输入选择 (1-3): ").strip() if verify_choice == '1': # 历史性能趋势 with sqlite3.connect(conv_mgr.db_path) as conn: cursor = conn.execute(""" SELECT DATE(timestamp) as date, AVG(dialogue_score) as avg_score, COUNT(*) as count FROM dialogue_turns WHERE dialogue_score > 0 AND timestamp >= datetime('now', '-30 days') GROUP BY DATE(timestamp) ORDER BY date ASC """) trend_data = cursor.fetchall() if trend_data: print(f"30天性能趋势:") for date, avg_score, count in trend_data: trend = "📈" if avg_score > 7.5 else "📉" if avg_score < 6.5 else "📊" print(f" {date}: {avg_score:.2f}分 {trend} ({count}条对话)") else: print("暂无足够的历史数据") elif verify_choice == '2': # A/B测试配置 ab_config = generate_ab_test_config() ab_config_path = "./config/ab_test_config.json" os.makedirs("./config", exist_ok=True) with open(ab_config_path, 'w', encoding='utf-8') as f: json.dump(ab_config, f, ensure_ascii=False, indent=2) print(f"✓ A/B测试配置已生成: {ab_config_path}") print("配置包含:") print(" - 对照组和实验组设置") print(" - 评估指标定义") print(" - 测试持续时间配置") elif verify_choice == '3': print("模型版本对比功能开发中...") print("建议手动记录不同版本的性能指标进行对比") else: print("❌ 无效选择") except Exception as e: print(f"✗ 模型优化失败: {e}") import traceback traceback.print_exc() def generate_lora_training_script(): """生成LoRA训练脚本""" return '''#!/usr/bin/env python # -*- coding: utf-8 -*- """ 基于评分数据的LoRA增量训练脚本 自动生成 - 请根据实际环境调整配置 """ import json import torch from transformers import AutoModelForCausalLM, AutoTokenizer from peft import LoraConfig, get_peft_model, TaskType import os class IterativeLoRATrainer: def __init__(self, base_model_path, training_data_path, output_path): self.base_model_path = base_model_path self.training_data_path = training_data_path self.output_path = output_path # LoRA配置 self.lora_config = LoraConfig( task_type=TaskType.CAUSAL_LM, inference_mode=False, r=16, # LoRA rank lora_alpha=32, lora_dropout=0.1, target_modules=["q_proj", "v_proj", "k_proj", "o_proj"] ) def load_training_data(self): """加载训练数据""" with open(self.training_data_path, 'r', encoding='utf-8') as f: data = json.load(f) return data['data'] def prepare_training_samples(self, data): """准备训练样本""" samples = [] for item in data: if item.get('label') == 'positive' or item.get('overall_score', 0) >= 8.0: # 高质量样本 sample = { 'input': f"角色: {item['speaker']}\\n请生成高质量对话:", 'output': item['content'], 'quality_score': item.get('overall_score', 8.0) } samples.append(sample) return samples def train(self): """执行训练""" print("开始LoRA增量训练...") # 加载模型和分词器 tokenizer = AutoTokenizer.from_pretrained(self.base_model_path) model = AutoModelForCausalLM.from_pretrained( self.base_model_path, torch_dtype=torch.bfloat16, device_map="auto" ) # 应用LoRA model = get_peft_model(model, self.lora_config) # 加载训练数据 training_data = self.load_training_data() training_samples = self.prepare_training_samples(training_data) print(f"训练样本数量: {len(training_samples)}") # 这里添加实际的训练循环 # 建议使用transformers的Trainer或自定义训练循环 # 保存模型 model.save_pretrained(self.output_path) tokenizer.save_pretrained(self.output_path) print(f"训练完成,模型保存到: {self.output_path}") if __name__ == '__main__': # 配置参数 - 请根据实际情况修改 BASE_MODEL_PATH = '/mnt/e/AI/Project02/AITrain/Qwen/Qwen3-4B' TRAINING_DATA_PATH = './training_data/high_quality_dialogues_latest.json' OUTPUT_PATH = './output/iterative_lora_v2' trainer = IterativeLoRATrainer(BASE_MODEL_PATH, TRAINING_DATA_PATH, OUTPUT_PATH) trainer.train() ''' def generate_prompt_optimization_config(): """生成提示优化配置""" return { "optimization_rules": { "quality_thresholds": { "excellent": 9.0, "good": 8.0, "acceptable": 7.0, "needs_improvement": 6.0 }, "adaptive_adjustments": { "low_coherence": { "add_context_emphasis": True, "increase_history_weight": 1.2, "add_logical_constraints": True }, "low_character_consistency": { "enhance_character_description": True, "add_personality_reminders": True, "increase_character_weight": 1.5 }, "low_creativity": { "add_creativity_prompts": True, "increase_temperature": 0.1, "diversify_examples": True } } }, "dynamic_prompts": { "quality_boost_templates": [ "请生成一个富有创意且符合角色特征的高质量对话", "确保对话逻辑连贯,信息丰富,表达自然", "体现角色的独特性格和说话风格" ], "problem_specific_guidance": { "repetitive": "避免重复之前的内容,提供新的信息和观点", "inconsistent": "严格遵循角色设定,保持人格一致性", "dull": "增加对话的趣味性和深度,使用生动的表达" } }, "generation_parameters": { "adaptive_temperature": { "high_creativity_needed": 0.9, "normal": 0.8, "high_consistency_needed": 0.7 }, "adaptive_top_p": { "creative_mode": 0.9, "balanced_mode": 0.8, "conservative_mode": 0.7 } } } def generate_ab_test_config(): """生成A/B测试配置""" return { "test_name": "model_optimization_validation", "created_at": datetime.now().isoformat(), "groups": { "control": { "name": "原始模型", "description": "未优化的基础模型", "model_path": "/path/to/base/model", "sample_ratio": 0.5 }, "experimental": { "name": "优化模型", "description": "经过迭代优化的模型", "model_path": "/path/to/optimized/model", "sample_ratio": 0.5 } }, "evaluation_metrics": { "primary": [ "overall_dialogue_score", "character_consistency_score", "creativity_score" ], "secondary": [ "user_satisfaction", "response_time", "coherence_score" ] }, "test_duration": { "target_samples": 200, "max_duration_days": 7, "min_samples_per_group": 50 }, "statistical_settings": { "confidence_level": 0.95, "minimum_effect_size": 0.3, "power": 0.8 } } def generate_training_command(training_file_path): """生成训练命令""" return f"python ./scripts/iterative_lora_training.py --data {training_file_path} --output ./output/optimized_model_v{datetime.now().strftime('%Y%m%d')}" def show_scoring_statistics(): """显示对话评分统计""" print("\n" + "="*60) print("对话评分统计") print("="*60) try: from dual_ai_dialogue_system import ConversationManager import json conv_mgr = ConversationManager("./conversation_data/conversations.db") # 查询评分数据 import sqlite3 with sqlite3.connect(conv_mgr.db_path) as conn: # 总体统计 cursor = conn.execute(""" SELECT COUNT(*) as total_turns, AVG(dialogue_score) as avg_score, MAX(dialogue_score) as max_score, MIN(dialogue_score) as min_score FROM dialogue_turns WHERE dialogue_score > 0 """) stats = cursor.fetchone() if stats and stats[0] > 0: total_turns, avg_score, max_score, min_score = stats print(f"总体统计:") print(f" - 已评分对话轮数: {total_turns}") print(f" - 平均分数: {avg_score:.2f}") print(f" - 最高分数: {max_score:.2f}") print(f" - 最低分数: {min_score:.2f}") else: print("暂无评分数据") return # 按角色统计 print(f"\n按角色统计:") cursor = conn.execute(""" SELECT speaker, COUNT(*) as turns, AVG(dialogue_score) as avg_score, MAX(dialogue_score) as max_score FROM dialogue_turns WHERE dialogue_score > 0 GROUP BY speaker ORDER BY avg_score DESC """) for row in cursor.fetchall(): speaker, turns, avg_score, max_score = row print(f" - {speaker}: 平均{avg_score:.2f}分 (最高{max_score:.2f}分, {turns}轮对话)") # 最近高分对话 print(f"\n最近高分对话 (分数≥8.0):") cursor = conn.execute(""" SELECT speaker, content, dialogue_score, score_feedback, timestamp FROM dialogue_turns WHERE dialogue_score >= 8.0 ORDER BY timestamp DESC LIMIT 5 """) high_score_turns = cursor.fetchall() if high_score_turns: for speaker, content, score, feedback, timestamp in high_score_turns: print(f" [{timestamp[:16]}] {speaker} ({score:.2f}分)") print(f" 内容: {content[:80]}...") if feedback: print(f" 评价: {feedback[:60]}...") print() else: print(" 暂无高分对话") # 分数分布统计 print(f"\n分数分布:") cursor = conn.execute(""" SELECT CASE WHEN dialogue_score >= 9.0 THEN '优秀 (9-10分)' WHEN dialogue_score >= 8.0 THEN '良好 (8-9分)' WHEN dialogue_score >= 7.0 THEN '中等 (7-8分)' WHEN dialogue_score >= 6.0 THEN '及格 (6-7分)' ELSE '待改进 (<6分)' END as score_range, COUNT(*) as count FROM dialogue_turns WHERE dialogue_score > 0 GROUP BY score_range ORDER BY MIN(dialogue_score) DESC """) for score_range, count in cursor.fetchall(): percentage = (count / total_turns) * 100 print(f" - {score_range}: {count}轮 ({percentage:.1f}%)") except Exception as e: print(f"✗ 评分统计查询失败: {e}") def show_system_status(): """显示系统状态""" print("\n" + "="*60) print("系统状态检查") print("="*60) # 检查文件 files_to_check = [ ("./knowledge_base/worldview_template_coc.json", "世界观模板"), ("./knowledge_base/character_template_detective.json", "侦探角色"), ("./knowledge_base/character_template_professor.json", "教授角色"), ("./pdf_to_rag_processor.py", "PDF处理器"), ("./dual_ai_dialogue_system.py", "对话系统"), ("./npc_dialogue_generator.py", "NPC生成器") ] print("\n文件检查:") for file_path, description in files_to_check: if os.path.exists(file_path): print(f"✓ {description}: {file_path}") else: print(f"✗ {description}: {file_path} (不存在)") # 检查目录 print("\n目录检查:") directories = ["./knowledge_base", "./rag_knowledge", "./conversation_data"] for dir_path in directories: if os.path.exists(dir_path): file_count = len([f for f in os.listdir(dir_path) if os.path.isfile(os.path.join(dir_path, f))]) print(f"✓ {dir_path}: {file_count} 个文件") else: print(f"✗ {dir_path}: 不存在") # 检查对话会话 try: from dual_ai_dialogue_system import ConversationManager conv_mgr = ConversationManager("./conversation_data/conversations.db") sessions = conv_mgr.list_sessions() print(f"\n✓ 对话会话: {len(sessions)} 个") except Exception as e: print(f"\n✗ 对话会话检查失败: {e}") def main(): """主控制程序""" print("="*70) print(" 双AI角色对话系统 - 主控制程序") print(" 基于RAG的世界观增强对话引擎") print("="*70) # 检查依赖 if not check_dependencies(): return # 设置目录 # setup_directories() # copy_demo_files() while True: print("\n" + "="*50) print("主菜单 - 请选择操作:") print("1. 处理PDF世界观文档 (转换为RAG格式)") print("2. 查看角色设定信息") print("3. 启动双AI对话系统 (支持双模型对话)") print("4. 创建演示对话场景") print("5. 系统状态检查") print("6. 查看对话评分统计") print("7. 模型性能分析与优化") print("8. 生成训练数据集") print("9. 模型迭代优化") print("10. 查看使用说明") print("0. 退出") print("="*50) choice = input("请输入选择 (0-10): ").strip() if choice == '0': print("\n感谢使用双AI角色对话系统!") break elif choice == '1': process_pdf_workflow() elif choice == '2': show_character_info() elif choice == '3': run_dialogue_system() elif choice == '4': create_demo_scenario() elif choice == '5': show_system_status() elif choice == '6': show_scoring_statistics() elif choice == '7': analyze_model_performance() elif choice == '8': generate_training_dataset() elif choice == '9': run_model_optimization() elif choice == '10': show_usage_guide() else: print("❌ 无效选择,请重新输入") def show_usage_guide(): """显示使用说明""" print("\n" + "="*60) print("系统使用说明") print("="*60) guide = """ 🚀 快速开始: 1. 首次使用建议先运行"创建演示对话场景" 2. 如有PDF世界观文档,选择"处理PDF世界观文档" 3. 通过"启动双AI对话系统"开始角色对话 📁 文档格式说明: - 世界观文档: worldview_template_coc.json (参考COC设定) - 角色设定: character_template_*.json (包含详细人设) 🔧 系统功能: - PDF自动转换为RAG知识库 - 基于向量相似度的上下文检索 - 持久化对话历史存储 - 角色设定一致性保持 📝 自定义角色: 1. 参考 character_template_*.json 格式 2. 保存到 knowledge_base/ 目录 3. 重启对话系统加载新角色 💾 对话数据: - 历史对话保存在 conversation_data/ 目录 - 支持会话恢复和历史查看 - 自动记录使用的上下文信息 ⚠️ 注意事项: - 确保模型路径正确设置 - 首次运行需要下载向量化模型 - PDF处理需要足够内存 """ print(guide) if __name__ == '__main__': main()