2025-08-15 17:25:07 +08:00
|
|
|
|
#!/usr/bin/env python
|
|
|
|
|
|
# -*- coding: utf-8 -*-
|
|
|
|
|
|
'''
|
|
|
|
|
|
双AI角色对话系统主控制程序
|
|
|
|
|
|
完整的工作流程:PDF处理 -> 角色加载 -> RAG对话 -> 历史记录
|
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
|
|
import os
|
|
|
|
|
|
import sys
|
|
|
|
|
|
import shutil
|
|
|
|
|
|
from typing import List, Dict
|
|
|
|
|
|
import json
|
|
|
|
|
|
|
|
|
|
|
|
def check_dependencies():
    """Probe required and optional third-party libraries.

    Prints availability notes for the optional PDF/vector stacks and returns
    True only when the hard requirement (PyPDF2) is importable; otherwise
    prints install instructions and returns False.
    """
    unmet = []

    # Hard requirement: PyPDF2 is the baseline PDF reader.
    try:
        __import__("PyPDF2")
    except ImportError:
        unmet.append("PyPDF2")

    # Optional: pymupdf is preferred for extraction; PyPDF2 is the fallback.
    try:
        __import__("pymupdf")
    except ImportError:
        print("⚠ pymupdf 不可用,将使用 PyPDF2")
    else:
        print("✓ pymupdf 可用")

    # Optional: embedding stack (sentence-transformers + faiss) enables
    # vector retrieval; without it the system falls back to text matching.
    try:
        __import__("sentence_transformers")
        __import__("faiss")
    except ImportError:
        print("⚠ 向量化功能不可用,将使用文本匹配")
    else:
        print("✓ 向量化功能可用")

    if unmet:
        print(f"✗ 缺少依赖库: {', '.join(unmet)}")
        print("请运行: pip install PyPDF2 sentence-transformers faiss-cpu")
        return False

    return True
|
|
|
|
|
|
|
|
|
|
|
|
def setup_directories():
    """Create the project's working directory layout.

    Idempotent: already-existing directories are left untouched
    (``exist_ok=True``); one status line is printed per directory.
    """
    for path in (
        "./knowledge_base",
        "./characters",
        "./worldview",
        "./rag_knowledge",
        "./conversation_data",
    ):
        os.makedirs(path, exist_ok=True)
        print(f"✓ 目录就绪: {path}")
|
|
|
|
|
|
|
|
|
|
|
|
def copy_demo_files():
    """Copy the bundled demo worldview/character JSON files into ./knowledge_base.

    Source files that do not exist are silently skipped, so this is safe to
    run on a fresh checkout with no templates present.
    """
    demo_sources = (
        "./worldview/worldview_template_coc.json",
        "./characters/character_template_detective.json",
        "./characters/character_template_professor.json",
    )
    for src in demo_sources:
        # All demo files land in ./knowledge_base under their original name.
        dst = "./knowledge_base/" + os.path.basename(src)
        if os.path.exists(src):
            shutil.copy2(src, dst)
            print(f"✓ 复制文档: {os.path.basename(dst)}")
|
|
|
|
|
|
|
|
|
|
|
|
def process_pdf_workflow():
    """PDF world-lore processing workflow.

    Asks the user for a PDF path, runs it through ``PDFToRAGProcessor`` into
    ./rag_knowledge, and prints a summary of the produced artifacts.

    Returns:
        bool: True on success, False on any failure (missing module, missing
        file, or processing error).
    """
    print("\n" + "="*60)
    print("PDF世界观文档处理")
    print("="*60)

    # BUGFIX: this import used to sit outside any try-block, so a missing
    # pdf_to_rag_processor module crashed the whole menu-driven controller
    # instead of failing gracefully like every other error path here.
    try:
        from pdf_to_rag_processor import PDFToRAGProcessor
    except ImportError as e:
        print(f"✗ PDF处理失败: {e}")
        return False

    pdf_path = input("请输入PDF文件路径 (例: ./coc.pdf): ").strip()

    if not os.path.exists(pdf_path):
        print(f"✗ 文件不存在: {pdf_path}")
        return False

    try:
        processor = PDFToRAGProcessor()
        # result is assumed to be a dict with chunks_count / concepts_count /
        # vector_enabled keys — contract of PDFToRAGProcessor (defined elsewhere).
        result = processor.process_pdf_to_rag(pdf_path, "./rag_knowledge")

        print(f"\n✓ PDF处理完成!")
        print(f"  - 文档块数: {result['chunks_count']}")
        print(f"  - 概念数: {result['concepts_count']}")
        print(f"  - 向量索引: {'启用' if result['vector_enabled'] else '未启用'}")

        return True

    except Exception as e:
        print(f"✗ PDF处理失败: {e}")
        return False
|
|
|
|
|
|
|
|
|
|
|
|
def show_character_info():
    """Print a short summary (name, occupation, top traits) for every
    ``character*.json`` file found in ./knowledge_base.

    Files that cannot be read or parsed are reported individually; the loop
    continues with the remaining files.
    """
    print("\n" + "="*60)
    print("角色设定信息")
    print("="*60)

    knowledge_dir = "./knowledge_base"

    # BUGFIX: os.listdir raised FileNotFoundError on a fresh checkout where
    # ./knowledge_base does not exist yet; guard and bail out cleanly.
    if not os.path.isdir(knowledge_dir):
        print(f"✗ 目录不存在: {knowledge_dir}")
        return

    character_files = [
        f for f in os.listdir(knowledge_dir)
        if f.startswith('character') and f.endswith('.json')
    ]

    for char_file in character_files:
        try:
            with open(os.path.join(knowledge_dir, char_file), 'r', encoding='utf-8') as f:
                char_data = json.load(f)

            # Fall back to placeholder values when fields are absent, so a
            # partially-filled template still renders.
            name = char_data.get('character_name', '未知')
            occupation = char_data.get('basic_info', {}).get('occupation', '未知')
            traits = char_data.get('personality', {}).get('core_traits', [])

            print(f"\n角色: {name}")
            print(f"  职业: {occupation}")
            # Only the first three traits are shown to keep the summary short.
            print(f"  特点: {', '.join(traits[:3])}")

        except Exception as e:
            print(f"✗ 读取角色文件失败: {char_file} - {e}")
|
|
|
|
|
|
|
|
|
|
|
|
def run_dialogue_system():
    """Run the interactive dual-AI character dialogue system.

    Flow: load knowledge base and conversation store, resolve model paths,
    pick the first two characters from the knowledge base, build the
    dual-model generator and dialogue engine, then loop on user topics until
    'quit'. All failures are printed with a traceback; nothing is returned.
    """
    print("\n" + "="*60)
    print("启动双AI角色对话系统")
    print("="*60)

    try:
        # Launch the dual-model dialogue directly.
        print("\n正在初始化双模型对话系统...")

        # Project-local modules; contracts assumed from usage below — TODO confirm.
        from dual_ai_dialogue_system import RAGKnowledgeBase, ConversationManager, DualAIDialogueEngine
        from npc_dialogue_generator import DualModelDialogueGenerator

        # Initialize the knowledge base and the SQLite-backed conversation store.
        kb = RAGKnowledgeBase("./knowledge_base")
        conv_mgr = ConversationManager("./conversation_data/conversations.db")

        # Model paths. NOTE(review): hard-coded absolute base-model path,
        # and it differs from the one in create_demo_scenario() — unify?
        base_model_path = '/mnt/g/Project02/AITrain/Qwen/Qwen3-4B'
        lora_model_path = './output/NPC_Dialogue_LoRA/final_model'

        if not os.path.exists(base_model_path):
            print(f"✗ 基础模型路径不存在: {base_model_path}")
            print("请修改 main_controller.py 中的模型路径")
            return

        # The LoRA adapter is optional: fall back to the base model when absent.
        if not os.path.exists(lora_model_path):
            lora_model_path = None
            print("⚠ LoRA模型不存在,使用基础模型")

        # A dual-character conversation needs at least two character profiles.
        if not hasattr(kb, 'character_data') or len(kb.character_data) < 2:
            print("✗ 角色数据不足,无法创建双模型对话系统")
            print("请确保knowledge_base目录中有至少两个角色文件")
            return

        # Take the first two characters. NOTE(review): dict ordering is
        # insertion order, so this depends on file-loading order in kb.
        character_names = list(kb.character_data.keys())[:2]
        char1_name = character_names[0]
        char2_name = character_names[1]

        print(f"✓ 使用角色: {char1_name} 和 {char2_name}")

        # Per-character model configuration; both share the same LoRA adapter.
        character1_config = {
            "name": char1_name,
            "lora_path": lora_model_path,
            "character_data": kb.character_data[char1_name]
        }

        character2_config = {
            "name": char2_name,
            "lora_path": lora_model_path,
            "character_data": kb.character_data[char2_name]
        }

        # Build the dual-model dialogue generator (loads both models).
        print("正在初始化双模型对话生成器...")
        dual_generator = DualModelDialogueGenerator(
            base_model_path,
            character1_config,
            character2_config
        )

        # Build the dialogue engine with scoring enabled.
        dialogue_engine = DualAIDialogueEngine(
            kb,
            conv_mgr,
            dual_generator,
            enable_scoring=True,
            base_model_path=base_model_path
        )

        # Create a conversation session tagged with the worldview name.
        characters = [char1_name, char2_name]
        worldview = kb.worldview_data.get('worldview_name', '未知世界观') if kb.worldview_data else '未知世界观'

        session_id = conv_mgr.create_session(characters, worldview)
        print(f"✓ 创建对话会话: {session_id}")

        # Interactive dialogue loop.
        print(f"\n=== 双AI模型对话系统 ===")
        print(f"角色: {char1_name} vs {char2_name}")
        print(f"世界观: {worldview}")
        print("输入 'quit' 退出对话")
        print("-" * 50)

        while True:
            try:
                # Read the next topic (or the 'quit' command).
                user_input = input("\n请输入对话主题或指令: ").strip()

                if user_input.lower() == 'quit':
                    print("退出双AI对话系统")
                    break

                if not user_input:
                    print("请输入有效的对话主题")
                    continue

                # Number of dialogue turns; non-numeric input falls back to 4.
                turns_input = input("请输入对话轮数 (默认4): ").strip()
                turns = int(turns_input) if turns_input.isdigit() else 4

                # History-context settings; same isdigit-or-default pattern.
                history_input = input("使用历史对话轮数 (默认2): ").strip()
                history_count = int(history_input) if history_input.isdigit() else 2

                context_input = input("使用上下文信息数量 (默认10): ").strip()
                context_info_count = int(context_input) if context_input.isdigit() else 10

                print(f"\n开始对话 - 主题: {user_input}")
                print(f"轮数: {turns}, 历史: {history_count}, 上下文: {context_info_count}")
                print("-" * 50)

                # Run one dual-model conversation round for this topic.
                dialogue_engine.run_dual_model_conversation(
                    session_id, user_input, turns, history_count, context_info_count
                )

                print("-" * 50)
                print("对话完成!")

            except KeyboardInterrupt:
                # Ctrl-C inside the loop exits the loop, not the program.
                print("\n\n用户中断对话")
                break
            except Exception as e:
                # Per-round errors are reported and the loop continues.
                print(f"对话过程中出现错误: {e}")
                import traceback
                traceback.print_exc()

    except Exception as e:
        # Startup failures (imports, model loading, session creation).
        print(f"✗ 对话系统启动失败: {e}")
        import traceback
        traceback.print_exc()
|
|
|
|
|
|
|
|
|
|
|
|
def create_demo_scenario():
    """Create and run a scripted demo conversation between two fixed
    Cthulhu-setting characters, using a separate demo database.

    Runs two demo rounds with different history/context settings; errors are
    printed with a traceback and nothing is returned.
    """
    print("\n创建演示对话场景...")

    try:
        # Project-local modules; single-generator engine variant here.
        from dual_ai_dialogue_system import RAGKnowledgeBase, ConversationManager, DualAIDialogueEngine
        from npc_dialogue_generator import NPCDialogueGenerator

        # Initialize components against a dedicated demo conversation DB.
        kb = RAGKnowledgeBase("./knowledge_base")
        conv_mgr = ConversationManager("./conversation_data/demo_conversations.db")

        # Model paths. NOTE(review): this absolute path differs from the one
        # in run_dialogue_system() — likely a stale copy; unify.
        base_model_path = '/mnt/e/AI/Project02/AITrain/Qwen/Qwen3-4B'
        lora_model_path = './output/NPC_Dialogue_LoRA/final_model'

        if not os.path.exists(base_model_path):
            print(f"✗ 基础模型路径不存在: {base_model_path}")
            print("请修改 main_controller.py 中的模型路径")
            return

        # LoRA adapter is optional; fall back to the base model.
        if not os.path.exists(lora_model_path):
            lora_model_path = None
            print("⚠ LoRA模型不存在,使用基础模型")

        llm_generator = NPCDialogueGenerator(base_model_path, lora_model_path, kb.character_data)
        dialogue_engine = DualAIDialogueEngine(kb, conv_mgr, llm_generator)

        # Fixed demo cast and worldview (Call of Cthulhu templates).
        characters = ["维多利亚·布莱克伍德", "阿奇博尔德·韦恩"]
        worldview = "克苏鲁的呼唤"

        session_id = conv_mgr.create_session(characters, worldview)
        print(f"✓ 创建演示会话: {session_id}")

        # Run a few demo rounds on a fixed topic.
        topic = "最近发生的神秘事件"
        print(f"\n开始演示对话 - 主题: {topic}")
        print("-" * 40)

        # Demonstrate different history-context settings.
        # print("演示1: 使用默认上下文设置(历史3轮,信息2个)")
        # dialogue_engine.run_conversation_turn(session_id, characters, 6, topic)

        session_id = conv_mgr.create_session(characters, worldview)
        print(f"✓ 创建演示会话: {session_id}")
        print("\n演示3: 使用最少历史上下文(历史1轮,信息1个)")
        # NOTE(review): label says "信息1个" but the call passes 10 context
        # items; also demo 3 runs before demo 2 — confirm intent.
        dialogue_engine.run_conversation_turn(session_id, characters, 6, topic, 1, 10)

        session_id = conv_mgr.create_session(characters, worldview)
        print(f"✓ 创建演示会话: {session_id}")
        print("\n演示2: 使用更多历史上下文(历史10轮,信息10个)")
        # NOTE(review): label says "历史10轮" but the call passes 5 — confirm.
        dialogue_engine.run_conversation_turn(session_id, characters, 6, topic, 5, 10)

        print(f"\n✓ 演示完成!会话ID: {session_id}")
        print("你可以通过主对话系统继续这个对话")

    except Exception as e:
        print(f"✗ 演示场景创建失败: {e}")
        import traceback
        traceback.print_exc()
|
|
|
|
|
|
|
2025-08-23 17:10:23 +08:00
|
|
|
|
def analyze_model_performance():
    """Analyze dialogue-model performance from the scored conversation DB.

    Reads the ``dialogue_turns`` table and prints five report sections:
    7-day score trend, per-dimension averages, per-character ranking,
    low-score examples, and derived optimization suggestions. Read-only
    except for the SQLite connection; errors print a traceback.
    """
    print("\n" + "="*60)
    print("模型性能分析")
    print("="*60)

    try:
        from dual_ai_dialogue_system import ConversationManager
        import sqlite3
        import json
        # NOTE(review): timedelta is imported but unused below.
        from datetime import datetime, timedelta

        # Only used to resolve db_path; assumes conv_mgr.db_path attribute.
        conv_mgr = ConversationManager("./conversation_data/conversations.db")

        with sqlite3.connect(conv_mgr.db_path) as conn:
            print("\n1. 总体性能趋势分析:")

            # Daily score trend over the last 7 days; dialogue_score == 0 is
            # treated as "unscored" and excluded throughout this function.
            cursor = conn.execute("""
                SELECT
                    DATE(timestamp) as date,
                    COUNT(*) as dialogue_count,
                    AVG(dialogue_score) as avg_score,
                    AVG(CASE WHEN dialogue_score >= 8.0 THEN 1.0 ELSE 0.0 END) as high_quality_rate
                FROM dialogue_turns
                WHERE dialogue_score > 0
                AND timestamp >= datetime('now', '-7 days')
                GROUP BY DATE(timestamp)
                ORDER BY date DESC
            """)

            trend_data = cursor.fetchall()
            if trend_data:
                print(f"  最近7天性能趋势:")
                for date, count, avg_score, hq_rate in trend_data:
                    print(f"    {date}: 平均{avg_score:.2f}分 ({count}轮对话, {hq_rate*100:.1f}%高质量)")
            else:
                print("  暂无足够数据进行趋势分析")

            print("\n2. 维度问题分析:")

            # Collect per-dimension scores from the most recent 100 scored
            # turns; score_details is a JSON object of dimension -> score.
            cursor = conn.execute("""
                SELECT score_details
                FROM dialogue_turns
                WHERE dialogue_score > 0 AND score_details != '{}'
                ORDER BY timestamp DESC
                LIMIT 100
            """)

            dimension_scores = {
                'coherence': [],
                'character_consistency': [],
                'naturalness': [],
                'information_density': [],
                'creativity': []
            }

            for (score_details,) in cursor.fetchall():
                try:
                    scores = json.loads(score_details)
                    for dim, score in scores.items():
                        # Unknown dimensions are silently ignored.
                        if dim in dimension_scores:
                            dimension_scores[dim].append(float(score))
                except:
                    # Malformed JSON rows are skipped. NOTE(review): bare
                    # except — narrow to (ValueError, TypeError) eventually.
                    continue

            # Chinese display names for the scoring dimensions.
            dimension_names = {
                'coherence': '连贯性',
                'character_consistency': '角色一致性',
                'naturalness': '自然度',
                'information_density': '信息密度',
                'creativity': '创意性'
            }

            # A dimension averaging under 7.0 is considered weak.
            weak_dimensions = []
            for dim, scores in dimension_scores.items():
                if scores:
                    avg_score = sum(scores) / len(scores)
                    print(f"  {dimension_names[dim]}: 平均{avg_score:.2f}分 ({len(scores)}个样本)")
                    if avg_score < 7.0:
                        weak_dimensions.append(dim)

            if weak_dimensions:
                print(f"\n  ⚠ 发现薄弱维度: {[dimension_names[d] for d in weak_dimensions]}")
                print("  建议进行针对性优化训练")

            print("\n3. 角色表现分析:")

            # Per-speaker score statistics, best average first.
            cursor = conn.execute("""
                SELECT
                    speaker,
                    COUNT(*) as dialogue_count,
                    AVG(dialogue_score) as avg_score,
                    MIN(dialogue_score) as min_score,
                    MAX(dialogue_score) as max_score,
                    AVG(CASE WHEN dialogue_score >= 8.0 THEN 1.0 ELSE 0.0 END) as high_quality_rate
                FROM dialogue_turns
                WHERE dialogue_score > 0
                GROUP BY speaker
                ORDER BY avg_score DESC
            """)

            character_performance = cursor.fetchall()
            if character_performance:
                print("  角色表现排名:")
                for i, (speaker, count, avg, min_s, max_s, hq_rate) in enumerate(character_performance, 1):
                    # Status glyph thresholds: >=7.5 good, >=6.5 warning, else bad.
                    status = "✓" if avg >= 7.5 else "⚠" if avg >= 6.5 else "✗"
                    print(f"    {i}. {speaker} {status}")
                    print(f"       平均{avg:.2f}分 (范围{min_s:.1f}-{max_s:.1f}, {hq_rate*100:.1f}%高质量, {count}轮)")

            print("\n4. 问题模式识别:")

            # Show the five worst-scoring turns (below 6.0) with feedback.
            cursor = conn.execute("""
                SELECT content, dialogue_score, score_feedback
                FROM dialogue_turns
                WHERE dialogue_score > 0 AND dialogue_score < 6.0
                ORDER BY dialogue_score ASC
                LIMIT 5
            """)

            low_score_examples = cursor.fetchall()
            if low_score_examples:
                print("  低分对话示例:")
                for i, (content, score, feedback) in enumerate(low_score_examples, 1):
                    # Truncate long content/feedback for a compact report.
                    print(f"    {i}. 分数{score:.1f}: {content[:50]}...")
                    if feedback:
                        print(f"       问题: {feedback[:80]}...")
            else:
                print("  暂无低分对话样本")

            print("\n5. 优化建议:")

            # Derive one suggestion per weak dimension found above.
            suggestions = []

            if weak_dimensions:
                if 'character_consistency' in weak_dimensions:
                    suggestions.append("• 加强角色设定训练,增加角色特征描述的权重")
                if 'creativity' in weak_dimensions:
                    suggestions.append("• 增加创意性训练数据,提高对话的趣味性")
                if 'coherence' in weak_dimensions:
                    suggestions.append("• 优化上下文理解,加强对话逻辑连贯性")
                if 'naturalness' in weak_dimensions:
                    suggestions.append("• 增加自然语言训练,改善表达流畅度")
                if 'information_density' in weak_dimensions:
                    suggestions.append("• 优化信息组织,避免冗余表达")

            # Data-volume suggestions: <50 scored turns means more data is
            # needed; >=100 means enough to start iterating on the model.
            cursor = conn.execute("SELECT COUNT(*) FROM dialogue_turns WHERE dialogue_score > 0")
            total_scored = cursor.fetchone()[0]

            if total_scored < 50:
                suggestions.append("• 需要收集更多评分数据以进行准确分析")

            if total_scored >= 100:
                suggestions.append("• 数据量充足,建议开始模型迭代优化")

            if suggestions:
                for suggestion in suggestions:
                    print(f"  {suggestion}")
            else:
                print("  当前性能表现良好,继续保持!")

    except Exception as e:
        print(f"✗ 性能分析失败: {e}")
        import traceback
        traceback.print_exc()
|
|
|
|
|
|
|
|
|
|
|
|
def generate_training_dataset():
    """Interactively export a training dataset from the scored dialogue DB.

    Offers five dataset flavors (high-quality, improvement, character
    consistency, creativity, complete) and writes the selection as a
    timestamped JSON file under ./training_data, then prints statistics.
    """
    print("\n" + "="*60)
    print("生成训练数据集")
    print("="*60)

    try:
        from dual_ai_dialogue_system import ConversationManager
        import sqlite3
        import json
        import os
        from datetime import datetime

        # Only used to resolve the SQLite db_path.
        conv_mgr = ConversationManager("./conversation_data/conversations.db")

        # Ensure the output directory exists.
        output_dir = "./training_data"
        os.makedirs(output_dir, exist_ok=True)

        print("请选择训练数据生成类型:")
        print("1. 高质量对话数据集 (分数≥8.0)")
        print("2. 问题对话改进数据集 (分数<6.0)")
        print("3. 角色一致性训练集")
        print("4. 创意性增强训练集")
        print("5. 完整对话质量数据集")

        choice = input("请输入选择 (1-5): ").strip()

        with sqlite3.connect(conv_mgr.db_path) as conn:
            training_data = []

            if choice == '1':
                # High-quality dialogue dataset: top 200 turns scoring >= 8.0.
                print("\n生成高质量对话数据集...")
                cursor = conn.execute("""
                    SELECT speaker, content, score_details, score_feedback
                    FROM dialogue_turns
                    WHERE dialogue_score >= 8.0
                    ORDER BY dialogue_score DESC
                    LIMIT 200
                """)

                for speaker, content, score_details, feedback in cursor.fetchall():
                    training_data.append({
                        "type": "high_quality",
                        "speaker": speaker,
                        "content": content,
                        "scores": json.loads(score_details) if score_details else {},
                        "feedback": feedback,
                        "label": "positive"
                    })

                output_file = f"{output_dir}/high_quality_dialogues_{datetime.now().strftime('%Y%m%d_%H%M')}.json"

            elif choice == '2':
                # Improvement dataset: the 100 worst scored turns (0 < score < 6).
                print("\n生成问题对话改进数据集...")
                cursor = conn.execute("""
                    SELECT speaker, content, score_details, score_feedback
                    FROM dialogue_turns
                    WHERE dialogue_score < 6.0 AND dialogue_score > 0
                    ORDER BY dialogue_score ASC
                    LIMIT 100
                """)

                for speaker, content, score_details, feedback in cursor.fetchall():
                    # Build a rewrite-instruction prompt for each low-scoring turn.
                    improvement_prompt = generate_improvement_prompt(content, feedback)

                    training_data.append({
                        "type": "improvement",
                        "speaker": speaker,
                        "original_content": content,
                        "scores": json.loads(score_details) if score_details else {},
                        "problems": feedback,
                        "improvement_prompt": improvement_prompt,
                        "label": "needs_improvement"
                    })

                output_file = f"{output_dir}/improvement_dialogues_{datetime.now().strftime('%Y%m%d_%H%M')}.json"

            elif choice == '3':
                # Character-consistency set: top 150 turns by the
                # character_consistency dimension inside the JSON details.
                print("\n生成角色一致性训练集...")
                cursor = conn.execute("""
                    SELECT speaker, content, score_details
                    FROM dialogue_turns
                    WHERE dialogue_score > 0
                    ORDER BY json_extract(score_details, '$.character_consistency') DESC
                    LIMIT 150
                """)

                for speaker, content, score_details in cursor.fetchall():
                    scores = json.loads(score_details) if score_details else {}
                    char_consistency = scores.get('character_consistency', 0)

                    training_data.append({
                        "type": "character_consistency",
                        "speaker": speaker,
                        "content": content,
                        "character_consistency_score": char_consistency,
                        # 8.0 is the high/medium consistency cutoff.
                        "label": "high_consistency" if char_consistency >= 8.0 else "medium_consistency"
                    })

                output_file = f"{output_dir}/character_consistency_{datetime.now().strftime('%Y%m%d_%H%M')}.json"

            elif choice == '4':
                # Creativity-enhancement set: top 150 by the creativity dimension.
                print("\n生成创意性增强训练集...")
                cursor = conn.execute("""
                    SELECT speaker, content, score_details
                    FROM dialogue_turns
                    WHERE dialogue_score > 0
                    ORDER BY json_extract(score_details, '$.creativity') DESC
                    LIMIT 150
                """)

                for speaker, content, score_details in cursor.fetchall():
                    scores = json.loads(score_details) if score_details else {}
                    creativity = scores.get('creativity', 0)

                    training_data.append({
                        "type": "creativity",
                        "speaker": speaker,
                        "content": content,
                        "creativity_score": creativity,
                        "label": "high_creativity" if creativity >= 8.0 else "medium_creativity"
                    })

                output_file = f"{output_dir}/creativity_enhancement_{datetime.now().strftime('%Y%m%d_%H%M')}.json"

            elif choice == '5':
                # Complete quality dataset: the 300 most recent scored turns.
                print("\n生成完整对话质量数据集...")
                cursor = conn.execute("""
                    SELECT speaker, content, dialogue_score, score_details, score_feedback
                    FROM dialogue_turns
                    WHERE dialogue_score > 0
                    ORDER BY timestamp DESC
                    LIMIT 300
                """)

                for speaker, content, score, score_details, feedback in cursor.fetchall():
                    training_data.append({
                        "type": "complete_dataset",
                        "speaker": speaker,
                        "content": content,
                        "overall_score": score,
                        "dimension_scores": json.loads(score_details) if score_details else {},
                        "feedback": feedback,
                        "quality_label": get_quality_label(score)
                    })

                output_file = f"{output_dir}/complete_quality_dataset_{datetime.now().strftime('%Y%m%d_%H%M')}.json"

            else:
                # Invalid menu choice: nothing to generate.
                print("❌ 无效选择")
                return

            if training_data:
                # Persist the dataset with metadata describing its origin.
                with open(output_file, 'w', encoding='utf-8') as f:
                    json.dump({
                        "metadata": {
                            "created_at": datetime.now().isoformat(),
                            "total_samples": len(training_data),
                            "data_type": choice,
                            "source": "dialogue_scoring_system"
                        },
                        "data": training_data
                    }, f, ensure_ascii=False, indent=2)

                print(f"\n✓ 训练数据集生成成功!")
                print(f"  - 文件路径: {output_file}")
                print(f"  - 数据条数: {len(training_data)}")
                print(f"  - 数据类型: {get_dataset_description(choice)}")

                # Print simple distribution statistics for the new dataset.
                generate_dataset_statistics(training_data, choice)

            else:
                print("❌ 未找到符合条件的数据")

    except Exception as e:
        print(f"✗ 训练数据集生成失败: {e}")
        import traceback
        traceback.print_exc()
|
|
|
|
|
|
|
|
|
|
|
|
def generate_improvement_prompt(content, feedback):
    """Build the rewrite-instruction prompt for a low-scoring dialogue turn.

    The returned text embeds the original line and the scorer's feedback,
    followed by a fixed checklist of improvement requirements.
    """
    template = (
        "原对话: {content}\n"
        "\n"
        "问题分析: {feedback}\n"
        "\n"
        "改进要求:\n"
        "1. 保持角色特征和设定\n"
        "2. 增强对话的逻辑性和连贯性\n"
        "3. 提升表达的自然度\n"
        "4. 增加信息密度,避免冗余\n"
        "5. 适当增加创意元素\n"
        "\n"
        "请生成一个改进版本的对话。"
    )
    return template.format(content=content, feedback=feedback)
|
|
|
|
|
|
|
|
|
|
|
|
def get_quality_label(score):
    """Map a numeric dialogue score to a coarse quality label.

    Bands (inclusive lower bounds): 9.0 excellent, 8.0 good, 7.0 average,
    6.0 below_average; everything lower is poor.
    """
    bands = (
        (9.0, "excellent"),
        (8.0, "good"),
        (7.0, "average"),
        (6.0, "below_average"),
    )
    for threshold, label in bands:
        if score >= threshold:
            return label
    return "poor"
|
|
|
|
|
|
|
|
|
|
|
|
def get_dataset_description(choice):
    """Return the display name for a dataset-type menu choice ('1'-'5')."""
    if choice == '1':
        return "高质量对话数据集"
    if choice == '2':
        return "问题对话改进数据集"
    if choice == '3':
        return "角色一致性训练集"
    if choice == '4':
        return "创意性增强训练集"
    if choice == '5':
        return "完整对话质量数据集"
    # Anything outside the menu range.
    return "未知类型"
|
|
|
|
|
|
|
|
|
|
|
|
def generate_dataset_statistics(training_data, data_type):
    """Print simple distribution statistics for a generated dataset.

    Type '1' reports the per-speaker line count; type '5' reports the
    quality-label distribution; other types print only the header.
    """
    print(f"\n数据集统计信息:")

    if data_type == '1':
        # High-quality set: count contributions per character.
        per_speaker = {}
        for record in training_data:
            who = record['speaker']
            per_speaker[who] = per_speaker.get(who, 0) + 1

        print(f"  - 角色分布:")
        for who, n in per_speaker.items():
            print(f"    • {who}: {n}条对话")

    elif data_type == '5':
        # Complete set: count samples per quality label.
        per_label = {}
        for record in training_data:
            tag = record['quality_label']
            per_label[tag] = per_label.get(tag, 0) + 1

        print(f"  - 质量分布:")
        for tag, n in per_label.items():
            print(f"    • {tag}: {n}条对话")
|
|
|
|
|
|
|
|
|
|
|
|
def run_model_optimization():
    """Interactive model-iteration menu backed by the scored dialogue DB.

    Options: (1) analyze optimization needs from score statistics,
    (2) write a LoRA training script, (3) write a prompt-optimization
    config, (4) prepare/confirm incremental training on an exported
    dataset, (5) performance comparison tools. Errors print a traceback.
    """
    print("\n" + "="*60)
    print("模型迭代优化")
    print("="*60)

    try:
        from dual_ai_dialogue_system import ConversationManager
        import sqlite3
        import json
        import os
        from datetime import datetime

        # Only used to resolve the SQLite db_path.
        conv_mgr = ConversationManager("./conversation_data/conversations.db")

        print("模型优化选项:")
        print("1. 分析优化需求")
        print("2. 生成LoRA训练脚本")
        print("3. 创建提示优化配置")
        print("4. 执行增量训练")
        print("5. 性能对比验证")

        choice = input("请输入选择 (1-5): ").strip()

        if choice == '1':
            # Analyze optimization needs.
            print("\n=== 优化需求分析 ===")

            with sqlite3.connect(conv_mgr.db_path) as conn:
                # Overall statistics over all scored turns (score > 0).
                cursor = conn.execute("""
                    SELECT
                        COUNT(*) as total,
                        AVG(dialogue_score) as avg_score,
                        AVG(CASE WHEN dialogue_score >= 8.0 THEN 1.0 ELSE 0.0 END) as high_quality_rate
                    FROM dialogue_turns WHERE dialogue_score > 0
                """)

                # NOTE(review): with zero scored rows avg_score/hq_rate are
                # NULL (None) and the :.2f formatting below raises — caught
                # by the outer except, but worth guarding explicitly.
                total, avg_score, hq_rate = cursor.fetchone()

                print(f"当前性能指标:")
                print(f"  - 总评分样本: {total}")
                print(f"  - 平均分数: {avg_score:.2f}")
                print(f"  - 高质量率: {hq_rate*100:.1f}%")

                # Collect the concrete optimization needs found below.
                optimization_needs = []

                if avg_score < 7.0:
                    optimization_needs.append("整体质量需要提升")

                if hq_rate < 0.6:
                    optimization_needs.append("高质量对话比例偏低")

                # Per-dimension averages over the 100 most recent scored turns.
                cursor = conn.execute("""
                    SELECT score_details FROM dialogue_turns
                    WHERE dialogue_score > 0 AND score_details != '{}'
                    ORDER BY timestamp DESC LIMIT 100
                """)

                dim_scores = {'coherence': [], 'character_consistency': [],
                              'naturalness': [], 'information_density': [], 'creativity': []}

                for (score_details,) in cursor.fetchall():
                    try:
                        scores = json.loads(score_details)
                        for dim, score in scores.items():
                            if dim in dim_scores:
                                dim_scores[dim].append(float(score))
                    except:
                        # Skip malformed JSON rows (bare except kept as-is).
                        continue

                # A dimension averaging under 7.0 counts as weak.
                weak_dimensions = []
                print(f"\n维度分析:")
                for dim, scores in dim_scores.items():
                    if scores:
                        avg = sum(scores) / len(scores)
                        print(f"  - {dim}: {avg:.2f}分")
                        if avg < 7.0:
                            weak_dimensions.append(dim)

                if weak_dimensions:
                    optimization_needs.append(f"薄弱维度: {weak_dimensions}")

                print(f"\n优化建议:")
                if optimization_needs:
                    for i, need in enumerate(optimization_needs, 1):
                        print(f"  {i}. {need}")
                else:
                    print("  当前模型表现良好,可考虑细微调优")

        elif choice == '2':
            # Write the generated LoRA training script to ./scripts.
            print("\n=== 生成LoRA训练脚本 ===")

            script_content = generate_lora_training_script()
            script_path = "./scripts/iterative_lora_training.py"

            os.makedirs("./scripts", exist_ok=True)
            with open(script_path, 'w', encoding='utf-8') as f:
                f.write(script_content)

            print(f"✓ LoRA训练脚本已生成: {script_path}")
            print("使用方法:")
            # NOTE(review): "选项8" refers to the main-menu numbering defined
            # elsewhere — confirm it still matches the current menu.
            print("  1. 先运行训练数据生成 (选项8)")
            print("  2. 修改脚本中的路径配置")
            print(f"  3. 运行: python {script_path}")

        elif choice == '3':
            # Write the prompt-optimization config to ./config.
            # NOTE(review): generate_prompt_optimization_config is defined
            # elsewhere in this module/project — not visible here.
            print("\n=== 创建提示优化配置 ===")

            config = generate_prompt_optimization_config()
            config_path = "./config/prompt_optimization.json"

            os.makedirs("./config", exist_ok=True)
            with open(config_path, 'w', encoding='utf-8') as f:
                json.dump(config, f, ensure_ascii=False, indent=2)

            print(f"✓ 提示优化配置已生成: {config_path}")
            print("配置包含:")
            print("  - 动态提示调整规则")
            print("  - 质量阈值设置")
            print("  - 生成参数优化")

        elif choice == '4':
            # Prepare incremental training on a previously exported dataset.
            print("\n=== 执行增量训练 ===")

            # Training data must have been generated first.
            training_dir = "./training_data"
            if not os.path.exists(training_dir):
                print("❌ 训练数据目录不存在,请先生成训练数据 (选项8)")
                return

            training_files = [f for f in os.listdir(training_dir) if f.endswith('.json')]
            if not training_files:
                print("❌ 未找到训练数据文件,请先生成训练数据 (选项8)")
                return

            print(f"找到训练数据文件:")
            for i, file in enumerate(training_files, 1):
                print(f"  {i}. {file}")

            file_idx = input(f"选择训练数据文件 (1-{len(training_files)}): ").strip()
            try:
                # int() may raise ValueError; indexing may raise IndexError —
                # both handled below as an invalid selection.
                selected_file = training_files[int(file_idx) - 1]
                training_file_path = os.path.join(training_dir, selected_file)

                print(f"将使用训练文件: {selected_file}")
                print("⚠ 注意:实际训练需要配置正确的模型路径和计算资源")

                # Build the suggested shell command for the selected file.
                training_command = generate_training_command(training_file_path)
                print(f"建议训练命令:")
                print(f"  {training_command}")

                # Optional: run training now (requires user confirmation).
                confirm = input("是否现在执行训练?(y/N): ").strip().lower()
                if confirm == 'y':
                    print("开始增量训练...")
                    # Actual training execution logic would go here.
                    print("⚠ 训练功能需要根据实际环境配置")

            except (ValueError, IndexError):
                print("❌ 无效的文件选择")

        elif choice == '5':
            # Performance comparison and validation sub-menu.
            print("\n=== 性能对比验证 ===")

            print("验证选项:")
            print("1. 历史性能趋势对比")
            print("2. A/B测试配置生成")
            print("3. 模型版本性能对比")

            verify_choice = input("请输入选择 (1-3): ").strip()

            if verify_choice == '1':
                # 30-day historical score trend.
                with sqlite3.connect(conv_mgr.db_path) as conn:
                    cursor = conn.execute("""
                        SELECT
                            DATE(timestamp) as date,
                            AVG(dialogue_score) as avg_score,
                            COUNT(*) as count
                        FROM dialogue_turns
                        WHERE dialogue_score > 0
                        AND timestamp >= datetime('now', '-30 days')
                        GROUP BY DATE(timestamp)
                        ORDER BY date ASC
                    """)

                    trend_data = cursor.fetchall()
                    if trend_data:
                        print(f"30天性能趋势:")
                        for date, avg_score, count in trend_data:
                            # Glyph thresholds: >7.5 up, <6.5 down, else flat.
                            trend = "📈" if avg_score > 7.5 else "📉" if avg_score < 6.5 else "📊"
                            print(f"  {date}: {avg_score:.2f}分 {trend} ({count}条对话)")
                    else:
                        print("暂无足够的历史数据")

            elif verify_choice == '2':
                # Write the A/B test configuration to ./config.
                # NOTE(review): generate_ab_test_config is defined elsewhere.
                ab_config = generate_ab_test_config()
                ab_config_path = "./config/ab_test_config.json"

                os.makedirs("./config", exist_ok=True)
                with open(ab_config_path, 'w', encoding='utf-8') as f:
                    json.dump(ab_config, f, ensure_ascii=False, indent=2)

                print(f"✓ A/B测试配置已生成: {ab_config_path}")
                print("配置包含:")
                print("  - 对照组和实验组设置")
                print("  - 评估指标定义")
                print("  - 测试持续时间配置")

            elif verify_choice == '3':
                print("模型版本对比功能开发中...")
                print("建议手动记录不同版本的性能指标进行对比")

        else:
            # Invalid top-level menu choice.
            print("❌ 无效选择")

    except Exception as e:
        print(f"✗ 模型优化失败: {e}")
        import traceback
        traceback.print_exc()
|
|
|
|
|
|
|
|
|
|
|
|
def generate_lora_training_script():
    """Return the full source text of a standalone LoRA incremental-training script.

    The returned text is a complete Python program (shebang, imports, an
    ``IterativeLoRATrainer`` class and a ``__main__`` section) that is meant to
    be written to disk elsewhere and run in a separate environment with
    ``transformers``/``peft`` installed.  Its content — including the Chinese
    comments and the hard-coded example paths — is runtime data, not code of
    this module, and is returned verbatim.
    """
    # NOTE(review): the ``\\n`` below is intentional — it must survive as a
    # literal ``\n`` escape inside the generated script's f-string.
    return '''#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
基于评分数据的LoRA增量训练脚本
自动生成 - 请根据实际环境调整配置
"""

import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model, TaskType
import os

class IterativeLoRATrainer:
    def __init__(self, base_model_path, training_data_path, output_path):
        self.base_model_path = base_model_path
        self.training_data_path = training_data_path
        self.output_path = output_path

        # LoRA配置
        self.lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=16,  # LoRA rank
            lora_alpha=32,
            lora_dropout=0.1,
            target_modules=["q_proj", "v_proj", "k_proj", "o_proj"]
        )

    def load_training_data(self):
        """加载训练数据"""
        with open(self.training_data_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        return data['data']

    def prepare_training_samples(self, data):
        """准备训练样本"""
        samples = []

        for item in data:
            if item.get('label') == 'positive' or item.get('overall_score', 0) >= 8.0:
                # 高质量样本
                sample = {
                    'input': f"角色: {item['speaker']}\\n请生成高质量对话:",
                    'output': item['content'],
                    'quality_score': item.get('overall_score', 8.0)
                }
                samples.append(sample)

        return samples

    def train(self):
        """执行训练"""
        print("开始LoRA增量训练...")

        # 加载模型和分词器
        tokenizer = AutoTokenizer.from_pretrained(self.base_model_path)
        model = AutoModelForCausalLM.from_pretrained(
            self.base_model_path,
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )

        # 应用LoRA
        model = get_peft_model(model, self.lora_config)

        # 加载训练数据
        training_data = self.load_training_data()
        training_samples = self.prepare_training_samples(training_data)

        print(f"训练样本数量: {len(training_samples)}")

        # 这里添加实际的训练循环
        # 建议使用transformers的Trainer或自定义训练循环

        # 保存模型
        model.save_pretrained(self.output_path)
        tokenizer.save_pretrained(self.output_path)

        print(f"训练完成,模型保存到: {self.output_path}")

if __name__ == '__main__':
    # 配置参数 - 请根据实际情况修改
    BASE_MODEL_PATH = '/mnt/e/AI/Project02/AITrain/Qwen/Qwen3-4B'
    TRAINING_DATA_PATH = './training_data/high_quality_dialogues_latest.json'
    OUTPUT_PATH = './output/iterative_lora_v2'

    trainer = IterativeLoRATrainer(BASE_MODEL_PATH, TRAINING_DATA_PATH, OUTPUT_PATH)
    trainer.train()
'''
|
|
|
|
|
|
|
|
|
|
|
|
def generate_prompt_optimization_config():
    """Return the static prompt-optimization configuration.

    The result has three top-level sections:
      * ``optimization_rules``     -- score thresholds plus per-problem
        adaptive adjustments (coherence / consistency / creativity).
      * ``dynamic_prompts``        -- reusable quality-boost templates and
        problem-specific guidance strings.
      * ``generation_parameters``  -- adaptive sampling parameters
        (temperature / top_p presets).
    """
    # Score bands used to classify dialogue quality.
    quality_thresholds = {
        "excellent": 9.0,
        "good": 8.0,
        "acceptable": 7.0,
        "needs_improvement": 6.0,
    }

    # Per-weakness knobs applied when a specific quality dimension is low.
    adaptive_adjustments = {
        "low_coherence": {
            "add_context_emphasis": True,
            "increase_history_weight": 1.2,
            "add_logical_constraints": True,
        },
        "low_character_consistency": {
            "enhance_character_description": True,
            "add_personality_reminders": True,
            "increase_character_weight": 1.5,
        },
        "low_creativity": {
            "add_creativity_prompts": True,
            "increase_temperature": 0.1,
            "diversify_examples": True,
        },
    }

    # Prompt snippets injected at generation time (user-facing Chinese text).
    dynamic_prompts = {
        "quality_boost_templates": [
            "请生成一个富有创意且符合角色特征的高质量对话",
            "确保对话逻辑连贯,信息丰富,表达自然",
            "体现角色的独特性格和说话风格",
        ],
        "problem_specific_guidance": {
            "repetitive": "避免重复之前的内容,提供新的信息和观点",
            "inconsistent": "严格遵循角色设定,保持人格一致性",
            "dull": "增加对话的趣味性和深度,使用生动的表达",
        },
    }

    # Sampling presets chosen according to the desired creativity/consistency
    # trade-off.
    generation_parameters = {
        "adaptive_temperature": {
            "high_creativity_needed": 0.9,
            "normal": 0.8,
            "high_consistency_needed": 0.7,
        },
        "adaptive_top_p": {
            "creative_mode": 0.9,
            "balanced_mode": 0.8,
            "conservative_mode": 0.7,
        },
    }

    return {
        "optimization_rules": {
            "quality_thresholds": quality_thresholds,
            "adaptive_adjustments": adaptive_adjustments,
        },
        "dynamic_prompts": dynamic_prompts,
        "generation_parameters": generation_parameters,
    }
|
|
|
|
|
|
|
|
|
|
|
|
def generate_ab_test_config():
    """Build the default A/B test configuration for model-optimization validation.

    Returns:
        dict: Test metadata, control/experimental group definitions (equal
        50/50 traffic split, placeholder model paths to be edited by the
        user), primary/secondary evaluation metrics, duration limits and
        statistical test settings.  ``created_at`` is the current local time
        in ISO-8601 format.
    """
    # Local import: ``datetime`` is not among this module's visible top-level
    # imports, so resolve it here to keep the function self-contained and
    # avoid a NameError at call time.
    from datetime import datetime

    return {
        "test_name": "model_optimization_validation",
        "created_at": datetime.now().isoformat(),
        "groups": {
            "control": {
                "name": "原始模型",
                "description": "未优化的基础模型",
                "model_path": "/path/to/base/model",  # placeholder: edit before use
                "sample_ratio": 0.5
            },
            "experimental": {
                "name": "优化模型",
                "description": "经过迭代优化的模型",
                "model_path": "/path/to/optimized/model",  # placeholder: edit before use
                "sample_ratio": 0.5
            }
        },
        "evaluation_metrics": {
            "primary": [
                "overall_dialogue_score",
                "character_consistency_score",
                "creativity_score"
            ],
            "secondary": [
                "user_satisfaction",
                "response_time",
                "coherence_score"
            ]
        },
        "test_duration": {
            "target_samples": 200,
            "max_duration_days": 7,
            "min_samples_per_group": 50
        },
        "statistical_settings": {
            "confidence_level": 0.95,
            "minimum_effect_size": 0.3,
            "power": 0.8
        }
    }
|
|
|
|
|
|
|
|
|
|
|
|
def generate_training_command(training_file_path):
    """Build the shell command line that launches iterative LoRA training.

    Args:
        training_file_path: Path to the exported training-data JSON file.

    Returns:
        str: A ``python`` command invoking ``./scripts/iterative_lora_training.py``;
        the output directory is versioned with today's date (``YYYYMMDD``).
    """
    # Local import: ``datetime`` is not among this module's visible top-level
    # imports, so resolve it here to avoid a NameError at call time.
    from datetime import datetime

    version = datetime.now().strftime('%Y%m%d')
    return (
        "python ./scripts/iterative_lora_training.py "
        f"--data {training_file_path} "
        f"--output ./output/optimized_model_v{version}"
    )
|
|
|
|
|
|
|
|
|
|
|
|
def show_scoring_statistics():
    """Print dialogue-scoring statistics read from the conversation SQLite DB.

    Reports, in order: overall counts/averages, per-speaker averages, the five
    most recent high-scoring turns (score >= 8.0), and a bucketed score
    distribution.  Returns early when no scored turns exist.  All failures are
    caught and reported rather than raised.
    """
    print("\n" + "="*60)
    print("对话评分统计")
    print("="*60)

    try:
        # Project-local import kept inside the function so the menu still
        # works when the dialogue system module is absent.
        from dual_ai_dialogue_system import ConversationManager
        import json  # NOTE(review): appears unused in this function — confirm before removing

        conv_mgr = ConversationManager("./conversation_data/conversations.db")

        # Query the scoring data directly from the manager's database.
        import sqlite3
        with sqlite3.connect(conv_mgr.db_path) as conn:
            # Overall statistics across all scored turns (score 0 means "unscored").
            cursor = conn.execute("""
                SELECT
                    COUNT(*) as total_turns,
                    AVG(dialogue_score) as avg_score,
                    MAX(dialogue_score) as max_score,
                    MIN(dialogue_score) as min_score
                FROM dialogue_turns
                WHERE dialogue_score > 0
            """)

            stats = cursor.fetchone()
            if stats and stats[0] > 0:
                total_turns, avg_score, max_score, min_score = stats
                print(f"总体统计:")
                print(f"  - 已评分对话轮数: {total_turns}")
                print(f"  - 平均分数: {avg_score:.2f}")
                print(f"  - 最高分数: {max_score:.2f}")
                print(f"  - 最低分数: {min_score:.2f}")
            else:
                # No scored data at all — nothing further to report.
                print("暂无评分数据")
                return

            # Per-speaker statistics, best average first.
            print(f"\n按角色统计:")
            cursor = conn.execute("""
                SELECT
                    speaker,
                    COUNT(*) as turns,
                    AVG(dialogue_score) as avg_score,
                    MAX(dialogue_score) as max_score
                FROM dialogue_turns
                WHERE dialogue_score > 0
                GROUP BY speaker
                ORDER BY avg_score DESC
            """)

            for row in cursor.fetchall():
                speaker, turns, avg_score, max_score = row
                print(f"  - {speaker}: 平均{avg_score:.2f}分 (最高{max_score:.2f}分, {turns}轮对话)")

            # Five most recent high-scoring turns (score >= 8.0), newest first.
            print(f"\n最近高分对话 (分数≥8.0):")
            cursor = conn.execute("""
                SELECT speaker, content, dialogue_score, score_feedback, timestamp
                FROM dialogue_turns
                WHERE dialogue_score >= 8.0
                ORDER BY timestamp DESC
                LIMIT 5
            """)

            high_score_turns = cursor.fetchall()
            if high_score_turns:
                for speaker, content, score, feedback, timestamp in high_score_turns:
                    # timestamp[:16] trims to "YYYY-MM-DD HH:MM"; content/feedback
                    # are truncated for display.
                    print(f"  [{timestamp[:16]}] {speaker} ({score:.2f}分)")
                    print(f"    内容: {content[:80]}...")
                    if feedback:
                        print(f"    评价: {feedback[:60]}...")
                    print()
            else:
                print("  暂无高分对话")

            # Bucketed score distribution; ORDER BY MIN(...) keeps the buckets
            # sorted from the highest score band down.
            print(f"\n分数分布:")
            cursor = conn.execute("""
                SELECT
                    CASE
                        WHEN dialogue_score >= 9.0 THEN '优秀 (9-10分)'
                        WHEN dialogue_score >= 8.0 THEN '良好 (8-9分)'
                        WHEN dialogue_score >= 7.0 THEN '中等 (7-8分)'
                        WHEN dialogue_score >= 6.0 THEN '及格 (6-7分)'
                        ELSE '待改进 (<6分)'
                    END as score_range,
                    COUNT(*) as count
                FROM dialogue_turns
                WHERE dialogue_score > 0
                GROUP BY score_range
                ORDER BY MIN(dialogue_score) DESC
            """)

            for score_range, count in cursor.fetchall():
                # total_turns is guaranteed non-zero here (early return above).
                percentage = (count / total_turns) * 100
                print(f"  - {score_range}: {count}轮 ({percentage:.1f}%)")

    except Exception as e:
        print(f"✗ 评分统计查询失败: {e}")
|
|
|
|
|
|
|
|
|
|
|
def show_system_status():
    """Print a quick health report: required files, data directories, sessions.

    Purely informational — every check only prints a ✓/✗ line and nothing is
    raised; the session check swallows and reports any failure.
    """
    print("\n" + "="*60)
    print("系统状态检查")
    print("="*60)

    # Required project files, each paired with a human-readable label.
    expected_files = (
        ("./knowledge_base/worldview_template_coc.json", "世界观模板"),
        ("./knowledge_base/character_template_detective.json", "侦探角色"),
        ("./knowledge_base/character_template_professor.json", "教授角色"),
        ("./pdf_to_rag_processor.py", "PDF处理器"),
        ("./dual_ai_dialogue_system.py", "对话系统"),
        ("./npc_dialogue_generator.py", "NPC生成器"),
    )

    print("\n文件检查:")
    for file_path, description in expected_files:
        if os.path.exists(file_path):
            print(f"✓ {description}: {file_path}")
        else:
            print(f"✗ {description}: {file_path} (不存在)")

    # Data directories: report existence and how many regular files each holds.
    print("\n目录检查:")
    for dir_path in ("./knowledge_base", "./rag_knowledge", "./conversation_data"):
        if not os.path.exists(dir_path):
            print(f"✗ {dir_path}: 不存在")
            continue
        file_count = sum(
            1 for name in os.listdir(dir_path)
            if os.path.isfile(os.path.join(dir_path, name))
        )
        print(f"✓ {dir_path}: {file_count} 个文件")

    # Conversation sessions: best-effort check via the dialogue system's DB.
    try:
        from dual_ai_dialogue_system import ConversationManager
        sessions = ConversationManager("./conversation_data/conversations.db").list_sessions()
        print(f"\n✓ 对话会话: {len(sessions)} 个")
    except Exception as e:
        print(f"\n✗ 对话会话检查失败: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Interactive entry point: banner, dependency check, then the menu loop.

    Loops until the user picks '0'; every other valid choice dispatches to its
    handler function, and unknown input just prints an error and re-prompts.
    """
    print("="*70)
    print(" 双AI角色对话系统 - 主控制程序")
    print(" 基于RAG的世界观增强对话引擎")
    print("="*70)

    # Abort early when required third-party packages are missing.
    if not check_dependencies():
        return

    # Directory bootstrap is currently disabled; re-enable when needed.
    # setup_directories()
    # copy_demo_files()

    # Menu text shown each iteration (order matters, '0' last).
    menu_items = (
        "1. 处理PDF世界观文档 (转换为RAG格式)",
        "2. 查看角色设定信息",
        "3. 启动双AI对话系统 (支持双模型对话)",
        "4. 创建演示对话场景",
        "5. 系统状态检查",
        "6. 查看对话评分统计",
        "7. 模型性能分析与优化",
        "8. 生成训练数据集",
        "9. 模型迭代优化",
        "10. 查看使用说明",
        "0. 退出",
    )

    # Choice -> handler dispatch table (replaces a long if/elif chain).
    handlers = {
        '1': process_pdf_workflow,
        '2': show_character_info,
        '3': run_dialogue_system,
        '4': create_demo_scenario,
        '5': show_system_status,
        '6': show_scoring_statistics,
        '7': analyze_model_performance,
        '8': generate_training_dataset,
        '9': run_model_optimization,
        '10': show_usage_guide,
    }

    while True:
        print("\n" + "="*50)
        print("主菜单 - 请选择操作:")
        for item in menu_items:
            print(item)
        print("="*50)

        choice = input("请输入选择 (0-10): ").strip()

        if choice == '0':
            print("\n感谢使用双AI角色对话系统!")
            break

        action = handlers.get(choice)
        if action is None:
            print("❌ 无效选择,请重新输入")
        else:
            action()
|
|
|
|
|
|
|
|
|
|
|
|
def show_usage_guide():
    """Print the built-in usage guide (quick start, formats, features, caveats)."""
    print("\n" + "="*60)
    print("系统使用说明")
    print("="*60)

    # The guide text is user-facing output and intentionally stays in Chinese.
    guide = """
🚀 快速开始:
1. 首次使用建议先运行"创建演示对话场景"
2. 如有PDF世界观文档,选择"处理PDF世界观文档"
3. 通过"启动双AI对话系统"开始角色对话

📁 文档格式说明:
- 世界观文档: worldview_template_coc.json (参考COC设定)
- 角色设定: character_template_*.json (包含详细人设)

🔧 系统功能:
- PDF自动转换为RAG知识库
- 基于向量相似度的上下文检索
- 持久化对话历史存储
- 角色设定一致性保持

📝 自定义角色:
1. 参考 character_template_*.json 格式
2. 保存到 knowledge_base/ 目录
3. 重启对话系统加载新角色

💾 对话数据:
- 历史对话保存在 conversation_data/ 目录
- 支持会话恢复和历史查看
- 自动记录使用的上下文信息

⚠️ 注意事项:
- 确保模型路径正确设置
- 首次运行需要下载向量化模型
- PDF处理需要足够内存
"""
    print(guide)
|
|
|
|
|
|
|
|
|
|
|
|
# Script entry point: run the interactive control program only when executed
# directly (not when imported as a module).
if __name__ == '__main__':
    main()
|