diff --git a/AITrain/conversation_data/conversations.db b/AITrain/conversation_data/conversations.db index 17da33b..5cd3317 100644 Binary files a/AITrain/conversation_data/conversations.db and b/AITrain/conversation_data/conversations.db differ diff --git a/AITrain/conversation_data/demo_conversations.db b/AITrain/conversation_data/demo_conversations.db deleted file mode 100644 index b64a0fd..0000000 Binary files a/AITrain/conversation_data/demo_conversations.db and /dev/null differ diff --git a/AITrain/dual_ai_dialogue_system.py b/AITrain/dual_ai_dialogue_system.py index ab4de91..c6dbf2c 100644 --- a/AITrain/dual_ai_dialogue_system.py +++ b/AITrain/dual_ai_dialogue_system.py @@ -286,6 +286,9 @@ class ConversationManager: timestamp TEXT, context_used TEXT, relevance_score REAL, + dialogue_score REAL DEFAULT 0.0, + score_details TEXT, + score_feedback TEXT, FOREIGN KEY (session_id) REFERENCES conversations (session_id) ) ''') @@ -305,7 +308,9 @@ class ConversationManager: print(f"✓ 创建对话会话: {session_id}") return session_id - def add_dialogue_turn(self, session_id: str, speaker: str, content: str, context_used: List[str] = None, relevance_score: float = 0.0): + def add_dialogue_turn(self, session_id: str, speaker: str, content: str, context_used: List[str] = None, + relevance_score: float = 0.0, dialogue_score: float = 0.0, + score_details: str = None, score_feedback: str = None): """添加对话轮次""" if context_used is None: context_used = [] @@ -318,10 +323,11 @@ class ConversationManager: # 插入对话轮次 conn.execute( """INSERT INTO dialogue_turns - (session_id, turn_number, speaker, content, timestamp, context_used, relevance_score) - VALUES (?, ?, ?, ?, ?, ?, ?)""", + (session_id, turn_number, speaker, content, timestamp, context_used, relevance_score, + dialogue_score, score_details, score_feedback) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", (session_id, turn_number, speaker, content, datetime.now().isoformat(), - json.dumps(context_used), relevance_score) + json.dumps(context_used), relevance_score, dialogue_score, score_details, score_feedback) ) # 更新会话最后更新时间 @@ -335,7 +341,7 @@ class ConversationManager: """获取对话历史""" with sqlite3.connect(self.db_path) as conn: cursor = conn.execute( - """SELECT speaker, content, timestamp, context_used, relevance_score + """SELECT speaker, content, timestamp, context_used, relevance_score, dialogue_score, score_feedback FROM dialogue_turns WHERE session_id = ? ORDER BY turn_number DESC LIMIT ?""", @@ -344,14 +350,20 @@ class ConversationManager: turns = [] for row in cursor.fetchall(): - speaker, content, timestamp, context_used, relevance_score = row - turns.append(DialogueTurn( + speaker, content, timestamp, context_used, relevance_score, dialogue_score, score_feedback = row + turn = DialogueTurn( speaker=speaker, content=content, timestamp=timestamp, context_used=json.loads(context_used or "[]"), relevance_score=relevance_score - )) + ) + # 添加评分信息到turn对象 + if dialogue_score: + turn.dialogue_score = dialogue_score + if score_feedback: + turn.score_feedback = score_feedback + turns.append(turn) return list(reversed(turns)) # 按时间正序返回 @@ -378,10 +390,70 @@ class ConversationManager: class DualAIDialogueEngine: """双AI对话引擎""" - def __init__(self, knowledge_base: RAGKnowledgeBase, conversation_manager: ConversationManager, llm_generator): + def __init__(self, knowledge_base: RAGKnowledgeBase, conversation_manager: ConversationManager, llm_generator, + enable_scoring: bool = True, base_model_path: str = None): self.kb = knowledge_base self.conv_mgr = conversation_manager self.llm_generator = llm_generator + self.enable_scoring = enable_scoring + self.scorer = None + + # 初始化评分器 + if enable_scoring and base_model_path: + try: + from dialogue_scorer import DialogueAIScorer + print("正在初始化对话评分系统...") + self.scorer = DialogueAIScorer( + base_model_path=base_model_path, + tokenizer=getattr(llm_generator, 'tokenizer', None), + model=getattr(llm_generator, 'model', None) + ) + print("✓ 对话评分系统初始化成功") + except Exception as e: + print(f"⚠ 对话评分系统初始化失败: {e}") + self.enable_scoring = False + + def score_dialogue_turn(self, dialogue_content: str, speaker: str, dialogue_history: List[DialogueTurn]) -> Tuple[float, str, str]: + """对单条对话进行评分 + + Args: + dialogue_content: 对话内容 + speaker: 说话者 + dialogue_history: 对话历史 + + Returns: + tuple: (总分, 详细分数JSON, 反馈意见) + """ + if not self.enable_scoring or not self.scorer: + return 0.0, "{}", "评分系统未启用" + + try: + # 获取角色数据 + character_data = self.kb.character_data.get(speaker, {}) + + # 转换对话历史格式 + history_for_scoring = [] + for turn in dialogue_history[-5:]: # 最近5轮对话 + history_for_scoring.append({ + 'speaker': turn.speaker, + 'content': turn.content + }) + + # 进行AI评分 + score_result = self.scorer.score_dialogue( + dialogue_content=dialogue_content, + speaker=speaker, + character_data=character_data, + dialogue_history=history_for_scoring, + context_info=[] + ) + + # 返回评分结果 + return score_result.overall_score, json.dumps(score_result.scores), score_result.feedback + + except Exception as e: + print(f"⚠ 对话评分失败: {e}") + return 0.0, "{}", f"评分失败: {str(e)}" def generate_character_prompt(self, character_name: str, context_info: List[Dict], dialogue_history: List[DialogueTurn], history_context_count: int = 3, context_info_count: int = 2) -> str: @@ -494,9 +566,17 @@ class DualAIDialogueEngine: context_used = [f"{info['section']}.{info['subsection']}" for info in context_info[:context_info_count]] avg_relevance = sum(info['relevance_score'] for info in context_info[:context_info_count]) / len(context_info[:context_info_count]) if context_info else 0.0 - # 保存对话轮次 + # 对对话进行评分 + if self.enable_scoring: + dialogue_score, score_details, score_feedback = self.score_dialogue_turn(response, current_speaker, dialogue_history) + print(f" [评分: {dialogue_score:.2f}] {score_feedback}") + else: + dialogue_score, score_details, score_feedback = 0.0, "{}", "" + + # 保存对话轮次(包含评分信息) self.conv_mgr.add_dialogue_turn( - session_id, current_speaker, response, context_used, avg_relevance + session_id, current_speaker, response, context_used, avg_relevance, + dialogue_score, score_details, score_feedback ) return response, context_used @@ -605,14 +685,29 @@ class DualAIDialogueEngine: max_new_tokens=150 ) - # 保存对话到数据库 + # 保存对话到数据库并进行评分 for result in conversation_results: + # 获取当前对话历史进行评分 + current_dialogue_history = self.conv_mgr.get_conversation_history(session_id) + + # 对对话进行评分 + if self.enable_scoring: + dialogue_score, score_details, score_feedback = self.score_dialogue_turn( + result['dialogue'], result['speaker'], current_dialogue_history + ) + print(f" [评分: {dialogue_score:.2f}] {score_feedback[:100]}...") + else: + dialogue_score, score_details, score_feedback = 0.0, "{}", "" + self.conv_mgr.add_dialogue_turn( session_id, result['speaker'], result['dialogue'], [result.get('context_used', '')], - 0.8 # 默认相关性分数 + 0.8, # 默认相关性分数 + dialogue_score, + score_details, + score_feedback ) diff --git a/AITrain/main_controller.py b/AITrain/main_controller.py index e6c0f51..904febd 100644 --- a/AITrain/main_controller.py +++ b/AITrain/main_controller.py @@ -140,7 +140,7 @@ def run_dialogue_system(): conv_mgr = ConversationManager("./conversation_data/conversations.db") # 检查模型路径 - base_model_path = '/mnt/e/AI/Project02/AITrain/Qwen/Qwen3-4B' + base_model_path = '/mnt/g/Project02/AITrain/Qwen/Qwen3-4B' lora_model_path = './output/NPC_Dialogue_LoRA/final_model' if not os.path.exists(base_model_path): @@ -186,8 +186,14 @@ def run_dialogue_system(): character2_config ) - # 创建对话引擎 - dialogue_engine = DualAIDialogueEngine(kb, conv_mgr, dual_generator) + # 创建对话引擎(启用评分功能) + dialogue_engine = DualAIDialogueEngine( + kb, + conv_mgr, + dual_generator, + enable_scoring=True, + base_model_path=base_model_path + ) # 创建对话会话 characters = [char1_name, char2_name] @@ -317,6 +323,954 @@ def create_demo_scenario(): import traceback traceback.print_exc() +def analyze_model_performance(): + """分析模型性能""" + print("\n" + "="*60) + print("模型性能分析") + print("="*60) + + try: + from dual_ai_dialogue_system import ConversationManager + import sqlite3 + import json + from datetime import datetime, timedelta + + conv_mgr = ConversationManager("./conversation_data/conversations.db") + + with sqlite3.connect(conv_mgr.db_path) as conn: + print("\n1. 总体性能趋势分析:") + + # 按时间段分析性能趋势 + cursor = conn.execute(""" + SELECT + DATE(timestamp) as date, + COUNT(*) as dialogue_count, + AVG(dialogue_score) as avg_score, + AVG(CASE WHEN dialogue_score >= 8.0 THEN 1.0 ELSE 0.0 END) as high_quality_rate + FROM dialogue_turns + WHERE dialogue_score > 0 + AND timestamp >= datetime('now', '-7 days') + GROUP BY DATE(timestamp) + ORDER BY date DESC + """) + + trend_data = cursor.fetchall() + if trend_data: + print(f" 最近7天性能趋势:") + for date, count, avg_score, hq_rate in trend_data: + print(f" {date}: 平均{avg_score:.2f}分 ({count}轮对话, {hq_rate*100:.1f}%高质量)") + else: + print(" 暂无足够数据进行趋势分析") + + print("\n2. 维度问题分析:") + + # 分析各维度的问题 + cursor = conn.execute(""" + SELECT score_details + FROM dialogue_turns + WHERE dialogue_score > 0 AND score_details != '{}' + ORDER BY timestamp DESC + LIMIT 100 + """) + + dimension_scores = { + 'coherence': [], + 'character_consistency': [], + 'naturalness': [], + 'information_density': [], + 'creativity': [] + } + + for (score_details,) in cursor.fetchall(): + try: + scores = json.loads(score_details) + for dim, score in scores.items(): + if dim in dimension_scores: + dimension_scores[dim].append(float(score)) + except: + continue + + dimension_names = { + 'coherence': '连贯性', + 'character_consistency': '角色一致性', + 'naturalness': '自然度', + 'information_density': '信息密度', + 'creativity': '创意性' + } + + weak_dimensions = [] + for dim, scores in dimension_scores.items(): + if scores: + avg_score = sum(scores) / len(scores) + print(f" {dimension_names[dim]}: 平均{avg_score:.2f}分 ({len(scores)}个样本)") + if avg_score < 7.0: + weak_dimensions.append(dim) + + if weak_dimensions: + print(f"\n ⚠ 发现薄弱维度: {[dimension_names[d] for d in weak_dimensions]}") + print(" 建议进行针对性优化训练") + + print("\n3. 角色表现分析:") + + # 分析不同角色的表现 + cursor = conn.execute(""" + SELECT + speaker, + COUNT(*) as dialogue_count, + AVG(dialogue_score) as avg_score, + MIN(dialogue_score) as min_score, + MAX(dialogue_score) as max_score, + AVG(CASE WHEN dialogue_score >= 8.0 THEN 1.0 ELSE 0.0 END) as high_quality_rate + FROM dialogue_turns + WHERE dialogue_score > 0 + GROUP BY speaker + ORDER BY avg_score DESC + """) + + character_performance = cursor.fetchall() + if character_performance: + print(" 角色表现排名:") + for i, (speaker, count, avg, min_s, max_s, hq_rate) in enumerate(character_performance, 1): + status = "✓" if avg >= 7.5 else "⚠" if avg >= 6.5 else "✗" + print(f" {i}. {speaker} {status}") + print(f" 平均{avg:.2f}分 (范围{min_s:.1f}-{max_s:.1f}, {hq_rate*100:.1f}%高质量, {count}轮)") + + print("\n4. 问题模式识别:") + + # 识别低分对话的常见问题 + cursor = conn.execute(""" + SELECT content, dialogue_score, score_feedback + FROM dialogue_turns + WHERE dialogue_score > 0 AND dialogue_score < 6.0 + ORDER BY dialogue_score ASC + LIMIT 5 + """) + + low_score_examples = cursor.fetchall() + if low_score_examples: + print(" 低分对话示例:") + for i, (content, score, feedback) in enumerate(low_score_examples, 1): + print(f" {i}. 分数{score:.1f}: {content[:50]}...") + if feedback: + print(f" 问题: {feedback[:80]}...") + else: + print(" 暂无低分对话样本") + + print("\n5. 优化建议:") + + # 生成优化建议 + suggestions = [] + + if weak_dimensions: + if 'character_consistency' in weak_dimensions: + suggestions.append("• 加强角色设定训练,增加角色特征描述的权重") + if 'creativity' in weak_dimensions: + suggestions.append("• 增加创意性训练数据,提高对话的趣味性") + if 'coherence' in weak_dimensions: + suggestions.append("• 优化上下文理解,加强对话逻辑连贯性") + if 'naturalness' in weak_dimensions: + suggestions.append("• 增加自然语言训练,改善表达流畅度") + if 'information_density' in weak_dimensions: + suggestions.append("• 优化信息组织,避免冗余表达") + + # 检查是否需要数据收集 + cursor = conn.execute("SELECT COUNT(*) FROM dialogue_turns WHERE dialogue_score > 0") + total_scored = cursor.fetchone()[0] + + if total_scored < 50: + suggestions.append("• 需要收集更多评分数据以进行准确分析") + + if total_scored >= 100: + suggestions.append("• 数据量充足,建议开始模型迭代优化") + + if suggestions: + for suggestion in suggestions: + print(f" {suggestion}") + else: + print(" 当前性能表现良好,继续保持!") + + except Exception as e: + print(f"✗ 性能分析失败: {e}") + import traceback + traceback.print_exc() + +def generate_training_dataset(): + """生成训练数据集""" + print("\n" + "="*60) + print("生成训练数据集") + print("="*60) + + try: + from dual_ai_dialogue_system import ConversationManager + import sqlite3 + import json + import os + from datetime import datetime + + conv_mgr = ConversationManager("./conversation_data/conversations.db") + + # 创建输出目录 + output_dir = "./training_data" + os.makedirs(output_dir, exist_ok=True) + + print("请选择训练数据生成类型:") + print("1. 高质量对话数据集 (分数≥8.0)") + print("2. 问题对话改进数据集 (分数<6.0)") + print("3. 角色一致性训练集") + print("4. 创意性增强训练集") + print("5. 完整对话质量数据集") + + choice = input("请输入选择 (1-5): ").strip() + + with sqlite3.connect(conv_mgr.db_path) as conn: + training_data = [] + + if choice == '1': + # 高质量对话数据集 + print("\n生成高质量对话数据集...") + cursor = conn.execute(""" + SELECT speaker, content, score_details, score_feedback + FROM dialogue_turns + WHERE dialogue_score >= 8.0 + ORDER BY dialogue_score DESC + LIMIT 200 + """) + + for speaker, content, score_details, feedback in cursor.fetchall(): + training_data.append({ + "type": "high_quality", + "speaker": speaker, + "content": content, + "scores": json.loads(score_details) if score_details else {}, + "feedback": feedback, + "label": "positive" + }) + + output_file = f"{output_dir}/high_quality_dialogues_{datetime.now().strftime('%Y%m%d_%H%M')}.json" + + elif choice == '2': + # 问题对话改进数据集 + print("\n生成问题对话改进数据集...") + cursor = conn.execute(""" + SELECT speaker, content, score_details, score_feedback + FROM dialogue_turns + WHERE dialogue_score < 6.0 AND dialogue_score > 0 + ORDER BY dialogue_score ASC + LIMIT 100 + """) + + for speaker, content, score_details, feedback in cursor.fetchall(): + # 为每个低分对话生成改进指导 + improvement_prompt = generate_improvement_prompt(content, feedback) + + training_data.append({ + "type": "improvement", + "speaker": speaker, + "original_content": content, + "scores": json.loads(score_details) if score_details else {}, + "problems": feedback, + "improvement_prompt": improvement_prompt, + "label": "needs_improvement" + }) + + output_file = f"{output_dir}/improvement_dialogues_{datetime.now().strftime('%Y%m%d_%H%M')}.json" + + elif choice == '3': + # 角色一致性训练集 + print("\n生成角色一致性训练集...") + cursor = conn.execute(""" + SELECT speaker, content, score_details + FROM dialogue_turns + WHERE dialogue_score > 0 + ORDER BY json_extract(score_details, '$.character_consistency') DESC + LIMIT 150 + """) + + for speaker, content, score_details in cursor.fetchall(): + scores = json.loads(score_details) if score_details else {} + char_consistency = scores.get('character_consistency', 0) + + training_data.append({ + "type": "character_consistency", + "speaker": speaker, + "content": content, + "character_consistency_score": char_consistency, + "label": "high_consistency" if char_consistency >= 8.0 else "medium_consistency" + }) + + output_file = f"{output_dir}/character_consistency_{datetime.now().strftime('%Y%m%d_%H%M')}.json" + + elif choice == '4': + # 创意性增强训练集 + print("\n生成创意性增强训练集...") + cursor = conn.execute(""" + SELECT speaker, content, score_details + FROM dialogue_turns + WHERE dialogue_score > 0 + ORDER BY json_extract(score_details, '$.creativity') DESC + LIMIT 150 + """) + + for speaker, content, score_details in cursor.fetchall(): + scores = json.loads(score_details) if score_details else {} + creativity = scores.get('creativity', 0) + + training_data.append({ + "type": "creativity", + "speaker": speaker, + "content": content, + "creativity_score": creativity, + "label": "high_creativity" if creativity >= 8.0 else "medium_creativity" + }) + + output_file = f"{output_dir}/creativity_enhancement_{datetime.now().strftime('%Y%m%d_%H%M')}.json" + + elif choice == '5': + # 完整对话质量数据集 + print("\n生成完整对话质量数据集...") + cursor = conn.execute(""" + SELECT speaker, content, dialogue_score, score_details, score_feedback + FROM dialogue_turns + WHERE dialogue_score > 0 + ORDER BY timestamp DESC + LIMIT 300 + """) + + for speaker, content, score, score_details, feedback in cursor.fetchall(): + training_data.append({ + "type": "complete_dataset", + "speaker": speaker, + "content": content, + "overall_score": score, + "dimension_scores": json.loads(score_details) if score_details else {}, + "feedback": feedback, + "quality_label": get_quality_label(score) + }) + + output_file = f"{output_dir}/complete_quality_dataset_{datetime.now().strftime('%Y%m%d_%H%M')}.json" + + else: + print("❌ 无效选择") + return + + if training_data: + # 保存训练数据 + with open(output_file, 'w', encoding='utf-8') as f: + json.dump({ + "metadata": { + "created_at": datetime.now().isoformat(), + "total_samples": len(training_data), + "data_type": choice, + "source": "dialogue_scoring_system" + }, + "data": training_data + }, f, ensure_ascii=False, indent=2) + + print(f"\n✓ 训练数据集生成成功!") + print(f" - 文件路径: {output_file}") + print(f" - 数据条数: {len(training_data)}") + print(f" - 数据类型: {get_dataset_description(choice)}") + + # 生成数据集统计信息 + generate_dataset_statistics(training_data, choice) + + else: + print("❌ 未找到符合条件的数据") + + except Exception as e: + print(f"✗ 训练数据集生成失败: {e}") + import traceback + traceback.print_exc() + +def generate_improvement_prompt(content, feedback): + """生成改进提示""" + return f"""原对话: {content} + +问题分析: {feedback} + +改进要求: +1. 保持角色特征和设定 +2. 增强对话的逻辑性和连贯性 +3. 提升表达的自然度 +4. 增加信息密度,避免冗余 +5. 适当增加创意元素 + +请生成一个改进版本的对话。""" + +def get_quality_label(score): + """根据分数获取质量标签""" + if score >= 9.0: + return "excellent" + elif score >= 8.0: + return "good" + elif score >= 7.0: + return "average" + elif score >= 6.0: + return "below_average" + else: + return "poor" + +def get_dataset_description(choice): + """获取数据集描述""" + descriptions = { + '1': "高质量对话数据集", + '2': "问题对话改进数据集", + '3': "角色一致性训练集", + '4': "创意性增强训练集", + '5': "完整对话质量数据集" + } + return descriptions.get(choice, "未知类型") + +def generate_dataset_statistics(training_data, data_type): + """生成数据集统计信息""" + print(f"\n数据集统计信息:") + + if data_type == '1': # 高质量数据集 + speakers = {} + for item in training_data: + speaker = item['speaker'] + speakers[speaker] = speakers.get(speaker, 0) + 1 + + print(f" - 角色分布:") + for speaker, count in speakers.items(): + print(f" • {speaker}: {count}条对话") + + elif data_type == '5': # 完整数据集 + quality_dist = {} + for item in training_data: + label = item['quality_label'] + quality_dist[label] = quality_dist.get(label, 0) + 1 + + print(f" - 质量分布:") + for label, count in quality_dist.items(): + print(f" • {label}: {count}条对话") + +def run_model_optimization(): + """运行模型迭代优化""" + print("\n" + "="*60) + print("模型迭代优化") + print("="*60) + + try: + from dual_ai_dialogue_system import ConversationManager + import sqlite3 + import json + import os + from datetime import datetime + + conv_mgr = ConversationManager("./conversation_data/conversations.db") + + print("模型优化选项:") + print("1. 分析优化需求") + print("2. 生成LoRA训练脚本") + print("3. 创建提示优化配置") + print("4. 执行增量训练") + print("5. 性能对比验证") + + choice = input("请输入选择 (1-5): ").strip() + + if choice == '1': + # 分析优化需求 + print("\n=== 优化需求分析 ===") + + with sqlite3.connect(conv_mgr.db_path) as conn: + # 获取性能数据 + cursor = conn.execute(""" + SELECT + COUNT(*) as total, + AVG(dialogue_score) as avg_score, + AVG(CASE WHEN dialogue_score >= 8.0 THEN 1.0 ELSE 0.0 END) as high_quality_rate + FROM dialogue_turns WHERE dialogue_score > 0 + """) + + total, avg_score, hq_rate = cursor.fetchone() + + print(f"当前性能指标:") + print(f" - 总评分样本: {total}") + print(f" - 平均分数: {avg_score:.2f}") + print(f" - 高质量率: {hq_rate*100:.1f}%") + + # 分析优化潜力 + optimization_needs = [] + + if avg_score < 7.0: + optimization_needs.append("整体质量需要提升") + + if hq_rate < 0.6: + optimization_needs.append("高质量对话比例偏低") + + # 分析各维度表现 + cursor = conn.execute(""" + SELECT score_details FROM dialogue_turns + WHERE dialogue_score > 0 AND score_details != '{}' + ORDER BY timestamp DESC LIMIT 100 + """) + + dim_scores = {'coherence': [], 'character_consistency': [], + 'naturalness': [], 'information_density': [], 'creativity': []} + + for (score_details,) in cursor.fetchall(): + try: + scores = json.loads(score_details) + for dim, score in scores.items(): + if dim in dim_scores: + dim_scores[dim].append(float(score)) + except: + continue + + weak_dimensions = [] + print(f"\n维度分析:") + for dim, scores in dim_scores.items(): + if scores: + avg = sum(scores) / len(scores) + print(f" - {dim}: {avg:.2f}分") + if avg < 7.0: + weak_dimensions.append(dim) + + if weak_dimensions: + optimization_needs.append(f"薄弱维度: {weak_dimensions}") + + print(f"\n优化建议:") + if optimization_needs: + for i, need in enumerate(optimization_needs, 1): + print(f" {i}. {need}") + else: + print(" 当前模型表现良好,可考虑细微调优") + + elif choice == '2': + # 生成LoRA训练脚本 + print("\n=== 生成LoRA训练脚本 ===") + + script_content = generate_lora_training_script() + script_path = "./scripts/iterative_lora_training.py" + + os.makedirs("./scripts", exist_ok=True) + with open(script_path, 'w', encoding='utf-8') as f: + f.write(script_content) + + print(f"✓ LoRA训练脚本已生成: {script_path}") + print("使用方法:") + print(" 1. 先运行训练数据生成 (选项8)") + print(" 2. 修改脚本中的路径配置") + print(f" 3. 运行: python {script_path}") + + elif choice == '3': + # 创建提示优化配置 + print("\n=== 创建提示优化配置 ===") + + config = generate_prompt_optimization_config() + config_path = "./config/prompt_optimization.json" + + os.makedirs("./config", exist_ok=True) + with open(config_path, 'w', encoding='utf-8') as f: + json.dump(config, f, ensure_ascii=False, indent=2) + + print(f"✓ 提示优化配置已生成: {config_path}") + print("配置包含:") + print(" - 动态提示调整规则") + print(" - 质量阈值设置") + print(" - 生成参数优化") + + elif choice == '4': + # 执行增量训练 + print("\n=== 执行增量训练 ===") + + # 检查训练数据 + training_dir = "./training_data" + if not os.path.exists(training_dir): + print("❌ 训练数据目录不存在,请先生成训练数据 (选项8)") + return + + training_files = [f for f in os.listdir(training_dir) if f.endswith('.json')] + if not training_files: + print("❌ 未找到训练数据文件,请先生成训练数据 (选项8)") + return + + print(f"找到训练数据文件:") + for i, file in enumerate(training_files, 1): + print(f" {i}. {file}") + + file_idx = input(f"选择训练数据文件 (1-{len(training_files)}): ").strip() + try: + selected_file = training_files[int(file_idx) - 1] + training_file_path = os.path.join(training_dir, selected_file) + + print(f"将使用训练文件: {selected_file}") + print("⚠ 注意:实际训练需要配置正确的模型路径和计算资源") + + # 生成训练命令 + training_command = generate_training_command(training_file_path) + print(f"建议训练命令:") + print(f" {training_command}") + + # 可选:执行训练(需要用户确认) + confirm = input("是否现在执行训练?(y/N): ").strip().lower() + if confirm == 'y': + print("开始增量训练...") + # 这里可以添加实际的训练执行逻辑 + print("⚠ 训练功能需要根据实际环境配置") + + except (ValueError, IndexError): + print("❌ 无效的文件选择") + + elif choice == '5': + # 性能对比验证 + print("\n=== 性能对比验证 ===") + + print("验证选项:") + print("1. 历史性能趋势对比") + print("2. A/B测试配置生成") + print("3. 模型版本性能对比") + + verify_choice = input("请输入选择 (1-3): ").strip() + + if verify_choice == '1': + # 历史性能趋势 + with sqlite3.connect(conv_mgr.db_path) as conn: + cursor = conn.execute(""" + SELECT + DATE(timestamp) as date, + AVG(dialogue_score) as avg_score, + COUNT(*) as count + FROM dialogue_turns + WHERE dialogue_score > 0 + AND timestamp >= datetime('now', '-30 days') + GROUP BY DATE(timestamp) + ORDER BY date ASC + """) + + trend_data = cursor.fetchall() + if trend_data: + print(f"30天性能趋势:") + for date, avg_score, count in trend_data: + trend = "📈" if avg_score > 7.5 else "📉" if avg_score < 6.5 else "📊" + print(f" {date}: {avg_score:.2f}分 {trend} ({count}条对话)") + else: + print("暂无足够的历史数据") + + elif verify_choice == '2': + # A/B测试配置 + ab_config = generate_ab_test_config() + ab_config_path = "./config/ab_test_config.json" + + os.makedirs("./config", exist_ok=True) + with open(ab_config_path, 'w', encoding='utf-8') as f: + json.dump(ab_config, f, ensure_ascii=False, indent=2) + + print(f"✓ A/B测试配置已生成: {ab_config_path}") + print("配置包含:") + print(" - 对照组和实验组设置") + print(" - 评估指标定义") + print(" - 测试持续时间配置") + + elif verify_choice == '3': + print("模型版本对比功能开发中...") + print("建议手动记录不同版本的性能指标进行对比") + + else: + print("❌ 无效选择") + + except Exception as e: + print(f"✗ 模型优化失败: {e}") + import traceback + traceback.print_exc() + +def generate_lora_training_script(): + """生成LoRA训练脚本""" + return '''#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +基于评分数据的LoRA增量训练脚本 +自动生成 - 请根据实际环境调整配置 +""" + +import json +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer +from peft import LoraConfig, get_peft_model, TaskType +import os + +class IterativeLoRATrainer: + def __init__(self, base_model_path, training_data_path, output_path): + self.base_model_path = base_model_path + self.training_data_path = training_data_path + self.output_path = output_path + + # LoRA配置 + self.lora_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, + inference_mode=False, + r=16, # LoRA rank + lora_alpha=32, + lora_dropout=0.1, + target_modules=["q_proj", "v_proj", "k_proj", "o_proj"] + ) + + def load_training_data(self): + """加载训练数据""" + with open(self.training_data_path, 'r', encoding='utf-8') as f: + data = json.load(f) + return data['data'] + + def prepare_training_samples(self, data): + """准备训练样本""" + samples = [] + + for item in data: + if item.get('label') == 'positive' or item.get('overall_score', 0) >= 8.0: + # 高质量样本 + sample = { + 'input': f"角色: {item['speaker']}\\n请生成高质量对话:", + 'output': item['content'], + 'quality_score': item.get('overall_score', 8.0) + } + samples.append(sample) + + return samples + + def train(self): + """执行训练""" + print("开始LoRA增量训练...") + + # 加载模型和分词器 + tokenizer = AutoTokenizer.from_pretrained(self.base_model_path) + model = AutoModelForCausalLM.from_pretrained( + self.base_model_path, + torch_dtype=torch.bfloat16, + device_map="auto" + ) + + # 应用LoRA + model = get_peft_model(model, self.lora_config) + + # 加载训练数据 + training_data = self.load_training_data() + training_samples = self.prepare_training_samples(training_data) + + print(f"训练样本数量: {len(training_samples)}") + + # 这里添加实际的训练循环 + # 建议使用transformers的Trainer或自定义训练循环 + + # 保存模型 + model.save_pretrained(self.output_path) + tokenizer.save_pretrained(self.output_path) + + print(f"训练完成,模型保存到: {self.output_path}") + +if __name__ == '__main__': + # 配置参数 - 请根据实际情况修改 + BASE_MODEL_PATH = '/mnt/e/AI/Project02/AITrain/Qwen/Qwen3-4B' + TRAINING_DATA_PATH = './training_data/high_quality_dialogues_latest.json' + OUTPUT_PATH = './output/iterative_lora_v2' + + trainer = IterativeLoRATrainer(BASE_MODEL_PATH, TRAINING_DATA_PATH, OUTPUT_PATH) + trainer.train() +''' + +def generate_prompt_optimization_config(): + """生成提示优化配置""" + return { + "optimization_rules": { + "quality_thresholds": { + "excellent": 9.0, + "good": 8.0, + "acceptable": 7.0, + "needs_improvement": 6.0 + }, + "adaptive_adjustments": { + "low_coherence": { + "add_context_emphasis": True, + "increase_history_weight": 1.2, + "add_logical_constraints": True + }, + "low_character_consistency": { + "enhance_character_description": True, + "add_personality_reminders": True, + "increase_character_weight": 1.5 + }, + "low_creativity": { + "add_creativity_prompts": True, + "increase_temperature": 0.1, + "diversify_examples": True + } + } + }, + "dynamic_prompts": { + "quality_boost_templates": [ + "请生成一个富有创意且符合角色特征的高质量对话", + "确保对话逻辑连贯,信息丰富,表达自然", + "体现角色的独特性格和说话风格" + ], + "problem_specific_guidance": { + "repetitive": "避免重复之前的内容,提供新的信息和观点", + "inconsistent": "严格遵循角色设定,保持人格一致性", + "dull": "增加对话的趣味性和深度,使用生动的表达" + } + }, + "generation_parameters": { + "adaptive_temperature": { + "high_creativity_needed": 0.9, + "normal": 0.8, + "high_consistency_needed": 0.7 + }, + "adaptive_top_p": { + "creative_mode": 0.9, + "balanced_mode": 0.8, + "conservative_mode": 0.7 + } + } + } + +def generate_ab_test_config(): + """生成A/B测试配置""" + return { + "test_name": "model_optimization_validation", + "created_at": datetime.now().isoformat(), + "groups": { + "control": { + "name": "原始模型", + "description": "未优化的基础模型", + "model_path": "/path/to/base/model", + "sample_ratio": 0.5 + }, + "experimental": { + "name": "优化模型", + "description": "经过迭代优化的模型", + "model_path": "/path/to/optimized/model", + "sample_ratio": 0.5 + } + }, + "evaluation_metrics": { + "primary": [ + "overall_dialogue_score", + "character_consistency_score", + "creativity_score" + ], + "secondary": [ + "user_satisfaction", + "response_time", + "coherence_score" + ] + }, + "test_duration": { + "target_samples": 200, + "max_duration_days": 7, + "min_samples_per_group": 50 + }, + "statistical_settings": { + "confidence_level": 0.95, + "minimum_effect_size": 0.3, + "power": 0.8 + } + } + +def generate_training_command(training_file_path): + """生成训练命令""" + return f"python ./scripts/iterative_lora_training.py --data {training_file_path} --output ./output/optimized_model_v{datetime.now().strftime('%Y%m%d')}" + +def show_scoring_statistics(): + """显示对话评分统计""" + print("\n" + "="*60) + print("对话评分统计") + print("="*60) + + try: + from dual_ai_dialogue_system import ConversationManager + import json + + conv_mgr = ConversationManager("./conversation_data/conversations.db") + + # 查询评分数据 + import sqlite3 + with sqlite3.connect(conv_mgr.db_path) as conn: + # 总体统计 + cursor = conn.execute(""" + SELECT + COUNT(*) as total_turns, + AVG(dialogue_score) as avg_score, + MAX(dialogue_score) as max_score, + MIN(dialogue_score) as min_score + FROM dialogue_turns + WHERE dialogue_score > 0 + """) + + stats = cursor.fetchone() + if stats and stats[0] > 0: + total_turns, avg_score, max_score, min_score = stats + print(f"总体统计:") + print(f" - 已评分对话轮数: {total_turns}") + print(f" - 平均分数: {avg_score:.2f}") + print(f" - 最高分数: {max_score:.2f}") + print(f" - 最低分数: {min_score:.2f}") + else: + print("暂无评分数据") + return + + # 按角色统计 + print(f"\n按角色统计:") + cursor = conn.execute(""" + SELECT + speaker, + COUNT(*) as turns, + AVG(dialogue_score) as avg_score, + MAX(dialogue_score) as max_score + FROM dialogue_turns + WHERE dialogue_score > 0 + GROUP BY speaker + ORDER BY avg_score DESC + """) + + for row in cursor.fetchall(): + speaker, turns, avg_score, max_score = row + print(f" - {speaker}: 平均{avg_score:.2f}分 (最高{max_score:.2f}分, {turns}轮对话)") + + # 最近高分对话 + print(f"\n最近高分对话 (分数≥8.0):") + cursor = conn.execute(""" + SELECT speaker, content, dialogue_score, score_feedback, timestamp + FROM dialogue_turns + WHERE dialogue_score >= 8.0 + ORDER BY timestamp DESC + LIMIT 5 + """) + + high_score_turns = cursor.fetchall() + if high_score_turns: + for speaker, content, score, feedback, timestamp in high_score_turns: + print(f" [{timestamp[:16]}] {speaker} ({score:.2f}分)") + print(f" 内容: {content[:80]}...") + if feedback: + print(f" 评价: {feedback[:60]}...") + print() + else: + print(" 暂无高分对话") + + # 分数分布统计 + print(f"\n分数分布:") + cursor = conn.execute(""" + SELECT + CASE + WHEN dialogue_score >= 9.0 THEN '优秀 (9-10分)' + WHEN dialogue_score >= 8.0 THEN '良好 (8-9分)' + WHEN dialogue_score >= 7.0 THEN '中等 (7-8分)' + WHEN dialogue_score >= 6.0 THEN '及格 (6-7分)' + ELSE '待改进 (<6分)' + END as score_range, + COUNT(*) as count + FROM dialogue_turns + WHERE dialogue_score > 0 + GROUP BY score_range + ORDER BY MIN(dialogue_score) DESC + """) + + for score_range, count in cursor.fetchall(): + percentage = (count / total_turns) * 100 + print(f" - {score_range}: {count}轮 ({percentage:.1f}%)") + + except Exception as e: + print(f"✗ 评分统计查询失败: {e}") + def show_system_status(): """显示系统状态""" print("\n" + "="*60) @@ -382,11 +1336,15 @@ def main(): print("3. 启动双AI对话系统 (支持双模型对话)") print("4. 创建演示对话场景") print("5. 系统状态检查") - print("6. 查看使用说明") + print("6. 查看对话评分统计") + print("7. 模型性能分析与优化") + print("8. 生成训练数据集") + print("9. 模型迭代优化") + print("10. 查看使用说明") print("0. 退出") print("="*50) - choice = input("请输入选择 (0-6): ").strip() + choice = input("请输入选择 (0-10): ").strip() if choice == '0': print("\n感谢使用双AI角色对话系统!") @@ -408,6 +1366,18 @@ def main(): show_system_status() elif choice == '6': + show_scoring_statistics() + + elif choice == '7': + analyze_model_performance() + + elif choice == '8': + generate_training_dataset() + + elif choice == '9': + run_model_optimization() + + elif choice == '10': show_usage_guide() else: