2025-08-15 17:25:07 +08:00
|
|
|
|
#!/usr/bin/env python
|
|
|
|
|
|
# -*- coding: utf-8 -*-
|
|
|
|
|
|
'''
Main control program for the dual-AI character dialogue system.

Full workflow: PDF processing -> character loading -> RAG dialogue -> history.
'''
|
|
|
|
|
|
|
|
|
|
|
|
import os
|
|
|
|
|
|
import sys
|
|
|
|
|
|
import shutil
|
|
|
|
|
|
from typing import List, Dict
|
|
|
|
|
|
import json
|
|
|
|
|
|
|
|
|
|
|
|
def check_dependencies():
    """Probe required and optional third-party packages.

    PyPDF2 is the only hard requirement. pymupdf and the
    sentence_transformers/faiss pair are optional; their availability is
    only reported on stdout.

    Returns:
        True when PyPDF2 can be imported, False otherwise (after printing
        an installation hint).
    """

    def _can_import(*module_names):
        # Import the modules in order and stop at the first failure.
        try:
            for module_name in module_names:
                __import__(module_name)
        except ImportError:
            return False
        return True

    missing_deps = [] if _can_import("PyPDF2") else ["PyPDF2"]

    if _can_import("pymupdf"):
        print("✓ pymupdf 可用")
    else:
        print("⚠ pymupdf 不可用,将使用 PyPDF2")

    if _can_import("sentence_transformers", "faiss"):
        print("✓ 向量化功能可用")
    else:
        print("⚠ 向量化功能不可用,将使用文本匹配")

    if not missing_deps:
        return True

    print(f"✗ 缺少依赖库: {', '.join(missing_deps)}")
    print("请运行: pip install PyPDF2 sentence-transformers faiss-cpu")
    return False
|
|
|
|
|
|
|
|
|
|
|
|
def setup_directories():
    """Create the project's working directories under the current directory.

    Directories that already exist are left untouched; one confirmation
    line is printed per directory.
    """
    project_dirs = (
        "./knowledge_base",
        "./characters",
        "./worldview",
        "./rag_knowledge",
        "./conversation_data",
    )
    for project_dir in project_dirs:
        os.makedirs(project_dir, exist_ok=True)
        print(f"✓ 目录就绪: {project_dir}")
|
|
|
|
|
|
|
|
|
|
|
|
def copy_demo_files():
    """Copy the bundled demo templates into ``./knowledge_base``.

    Each template is copied with ``shutil.copy2`` only if its source file
    exists; missing sources are skipped silently.
    """
    demo_templates = (
        ("./worldview", "worldview_template_coc.json"),
        ("./characters", "character_template_detective.json"),
        ("./characters", "character_template_professor.json"),
    )
    for source_dir, file_name in demo_templates:
        source = f"{source_dir}/{file_name}"
        target = f"./knowledge_base/{file_name}"
        if not os.path.exists(source):
            continue
        shutil.copy2(source, target)
        print(f"✓ 复制文档: {os.path.basename(target)}")
|
|
|
|
|
|
|
|
|
|
|
|
def process_pdf_workflow():
    """Interactively convert a PDF worldview document into RAG knowledge.

    Prompts for a PDF path, validates that it exists, then hands it to
    ``PDFToRAGProcessor.process_pdf_to_rag`` with ``./rag_knowledge`` as the
    output directory and prints a short summary of the result.

    Returns:
        True when processing succeeded, False when the path does not exist
        or processing (including importing the processor module) failed.
    """
    print("\n" + "="*60)
    print("PDF世界观文档处理")
    print("="*60)

    pdf_path = input("请输入PDF文件路径 (例: ./coc.pdf): ").strip()

    # Validate the path first so a bad path is reported without paying for
    # the processor import.
    if not os.path.exists(pdf_path):
        print(f"✗ 文件不存在: {pdf_path}")
        return False

    try:
        # Imported inside the try block (like the other workflows in this
        # file) so a missing or broken module is reported instead of
        # crashing the caller's menu loop.
        from pdf_to_rag_processor import PDFToRAGProcessor

        processor = PDFToRAGProcessor()
        result = processor.process_pdf_to_rag(pdf_path, "./rag_knowledge")

        print(f"\n✓ PDF处理完成!")
        print(f" - 文档块数: {result['chunks_count']}")
        print(f" - 概念数: {result['concepts_count']}")
        print(f" - 向量索引: {'启用' if result['vector_enabled'] else '未启用'}")

        return True

    except Exception as e:
        # Top-level boundary of an interactive workflow: report and return.
        print(f"✗ PDF处理失败: {e}")
        return False
|
|
|
|
|
|
|
|
|
|
|
|
def show_character_info():
    """Print a short summary of every character file in ``./knowledge_base``.

    A character file is any ``character*.json`` file. For each one the
    character name, occupation and up to three core traits are printed.
    A file that cannot be read or parsed is reported and skipped. A missing
    knowledge-base directory is reported instead of raising.
    """
    print("\n" + "="*60)
    print("角色设定信息")
    print("="*60)

    knowledge_dir = "./knowledge_base"
    # The directory may not exist yet (nothing in this function creates it),
    # and os.listdir would raise on it.
    if not os.path.isdir(knowledge_dir):
        print(f"✗ 知识库目录不存在: {knowledge_dir}")
        return

    # Sorted so the listing order is stable across platforms.
    character_files = sorted(
        f for f in os.listdir(knowledge_dir)
        if f.startswith('character') and f.endswith('.json')
    )
    if not character_files:
        print(f"⚠ 未找到角色文件: {knowledge_dir}")
        return

    for char_file in character_files:
        try:
            with open(os.path.join(knowledge_dir, char_file), 'r', encoding='utf-8') as f:
                char_data = json.load(f)

            name = char_data.get('character_name', '未知')
            occupation = char_data.get('basic_info', {}).get('occupation', '未知')
            traits = char_data.get('personality', {}).get('core_traits', [])

            print(f"\n角色: {name}")
            print(f" 职业: {occupation}")
            # str() keeps join() from failing on non-string trait entries.
            print(f" 特点: {', '.join(str(t) for t in traits[:3])}")

        except (OSError, ValueError, AttributeError, TypeError) as e:
            # ValueError covers JSON and encoding errors; AttributeError and
            # TypeError cover files whose structure is not the expected
            # nested mapping.
            print(f"✗ 读取角色文件失败: {char_file} - {e}")
|
|
|
|
|
|
|
2025-08-23 18:13:45 +08:00
|
|
|
|
def run_dialogue_system(enableScore: bool, useManualScoring: bool = False):
    """Start an interactive dual-model dialogue session.

    Loads the knowledge base and conversation store, picks the first two
    characters of the knowledge base, builds the dual-model generator and
    dialogue engine, then repeatedly prompts for a topic and runs a
    conversation until the user types ``quit`` or interrupts the input.

    Args:
        enableScore: Forwarded to the engine as ``enable_scoring``.
        useManualScoring: Forwarded to the engine as ``use_manual_scoring``.

    Returns:
        None. Every failure is reported on stdout (with a traceback) rather
        than raised.
    """
    print("\n" + "="*60)
    print("启动双AI角色对话系统")
    print("="*60)

    try:
        print("\n正在初始化双模型对话系统...")

        from dual_ai_dialogue_system import RAGKnowledgeBase, ConversationManager, DualAIDialogueEngine
        from npc_dialogue_generator import DualModelDialogueGenerator

        # Core components.
        kb = RAGKnowledgeBase("./knowledge_base")
        conv_mgr = ConversationManager("./conversation_data/conversations.db")

        # Model locations (hard-coded; edit here for another machine).
        base_model_path = '/mnt/g/Project02/AITrain/Qwen/Qwen3-4B'
        lora_model_path = './output/iterative_lora_simple/final_model'

        if not os.path.exists(base_model_path):
            print(f"✗ 基础模型路径不存在: {base_model_path}")
            print("请修改 main_controller.py 中的模型路径")
            return

        if not os.path.exists(lora_model_path):
            # Fall back to the bare base model when no LoRA adapter exists.
            lora_model_path = None
            print("⚠ LoRA模型不存在,使用基础模型")

        # At least two characters are needed for a two-sided dialogue.
        if not hasattr(kb, 'character_data') or len(kb.character_data) < 2:
            print("✗ 角色数据不足,无法创建双模型对话系统")
            print("请确保knowledge_base目录中有至少两个角色文件")
            return

        char1_name, char2_name = list(kb.character_data.keys())[:2]
        print(f"✓ 使用角色: {char1_name} 和 {char2_name}")

        # Both characters share the same LoRA adapter path.
        character1_config, character2_config = (
            {
                "name": char_name,
                "lora_path": lora_model_path,
                "character_data": kb.character_data[char_name],
            }
            for char_name in (char1_name, char2_name)
        )

        print("正在初始化双模型对话生成器...")
        dual_generator = DualModelDialogueGenerator(
            base_model_path,
            character1_config,
            character2_config
        )

        dialogue_engine = DualAIDialogueEngine(
            kb,
            conv_mgr,
            dual_generator,
            enable_scoring=enableScore,
            base_model_path=base_model_path,
            use_manual_scoring=useManualScoring
        )

        # Conversation session.
        characters = [char1_name, char2_name]
        worldview = kb.worldview_data.get('worldview_name', '未知世界观') if kb.worldview_data else '未知世界观'

        session_id = conv_mgr.create_session(characters, worldview)
        print(f"✓ 创建对话会话: {session_id}")

        print(f"\n=== 双AI模型对话系统 ===")
        print(f"角色: {char1_name} vs {char2_name}")
        print(f"世界观: {worldview}")
        print("输入 'quit' 退出对话")
        print("-" * 50)

        while True:
            try:
                user_input = input("\n请输入对话主题或指令: ").strip()

                if user_input.lower() == 'quit':
                    print("退出双AI对话系统")
                    break

                if not user_input:
                    print("请输入有效的对话主题")
                    continue

                turns = _prompt_int("请输入对话轮数 (默认4): ", 4)
                history_count = _prompt_int("使用历史对话轮数 (默认2): ", 2)
                context_info_count = _prompt_int("使用上下文信息数量 (默认10): ", 10)

                print(f"\n开始对话 - 主题: {user_input}")
                print(f"轮数: {turns}, 历史: {history_count}, 上下文: {context_info_count}")
                print("-" * 50)

                dialogue_engine.run_dual_model_conversation(
                    session_id, user_input, turns, history_count, context_info_count
                )

                print("-" * 50)
                print("对话完成!")

            except (KeyboardInterrupt, EOFError):
                # EOFError must end the loop too: once stdin is exhausted
                # every later input() fails again, and the generic handler
                # below would keep the loop spinning forever.
                print("\n\n用户中断对话")
                break
            except Exception as e:
                # One failed conversation must not end the session.
                print(f"对话过程中出现错误: {e}")
                import traceback
                traceback.print_exc()

    except Exception as e:
        # Top-level boundary: report startup failures to the menu user.
        print(f"✗ 对话系统启动失败: {e}")
        import traceback
        traceback.print_exc()


def _prompt_int(prompt, default):
    """Ask for a non-negative integer; return *default* for blank or non-numeric input."""
    raw = input(prompt).strip()
    # isdecimal() rather than isdigit(): isdigit() also accepts characters
    # such as superscript digits that int() cannot parse.
    return int(raw) if raw.isdecimal() else default
|
|
|
|
|
|
|
|
|
|
|
|
|
2025-08-23 17:10:23 +08:00
|
|
|
|
|
|
|
|
|
|
def generate_training_dataset():
    """Export scored dialogue turns as a JSON training dataset.

    Prompts for one of five dataset types (default 1), reads the matching
    rows from the ``dialogue_turns`` table of the conversation database,
    and writes them with a metadata header to a timestamped JSON file under
    ``./training_data``. A summary and per-type statistics are printed.

    Returns:
        None. Every failure is reported on stdout (with a traceback) rather
        than raised.
    """
    print("\n" + "="*60)
    print("生成训练数据集")
    print("="*60)

    try:
        import sqlite3
        from contextlib import closing
        from datetime import datetime

        from dual_ai_dialogue_system import ConversationManager

        conv_mgr = ConversationManager("./conversation_data/conversations.db")

        output_dir = "./training_data"
        os.makedirs(output_dir, exist_ok=True)

        print("请选择训练数据生成类型:(默认1)")
        print("1. 高质量对话数据集 (分数≥8.0)")
        print("2. 问题对话改进数据集 (分数<6.0)")
        print("3. 角色一致性训练集")
        print("4. 创意性增强训练集")
        print("5. 完整对话质量数据集")

        raw_choice = input("请输入选择 (1-5): ").strip()
        # isdecimal() rather than isdigit(): isdigit() also accepts
        # characters (e.g. superscript digits) that int() cannot parse.
        choice = int(raw_choice) if raw_choice.isdecimal() else 1

        def parse_scores(raw):
            # score_details is stored as a JSON string and may be empty/NULL.
            return json.loads(raw) if raw else {}

        stamp = datetime.now().strftime('%Y%m%d_%H%M')
        training_data = []

        # sqlite3's own context manager only commits or rolls back; closing()
        # makes sure the connection itself is released.
        with closing(sqlite3.connect(conv_mgr.db_path)) as conn:
            if choice == 1:
                # High-quality dialogues.
                print("\n生成高质量对话数据集...")
                rows = conn.execute("""
                    SELECT speaker, content, score_details, score_feedback
                    FROM dialogue_turns
                    WHERE dialogue_score >= 8.0
                    ORDER BY dialogue_score DESC
                    LIMIT 200
                """).fetchall()

                for speaker, content, score_details, feedback in rows:
                    training_data.append({
                        "type": "high_quality",
                        "speaker": speaker,
                        "content": content,
                        "scores": parse_scores(score_details),
                        "feedback": feedback,
                        "label": "positive"
                    })

                output_file = f"{output_dir}/high_quality_dialogues_{stamp}.json"

            elif choice == 2:
                # Low-scoring dialogues paired with an improvement prompt.
                print("\n生成问题对话改进数据集...")
                rows = conn.execute("""
                    SELECT speaker, content, score_details, score_feedback
                    FROM dialogue_turns
                    WHERE dialogue_score < 6.0 AND dialogue_score > 0
                    ORDER BY dialogue_score ASC
                    LIMIT 100
                """).fetchall()

                for speaker, content, score_details, feedback in rows:
                    training_data.append({
                        "type": "improvement",
                        "speaker": speaker,
                        "original_content": content,
                        "scores": parse_scores(score_details),
                        "problems": feedback,
                        "improvement_prompt": generate_improvement_prompt(content, feedback),
                        "label": "needs_improvement"
                    })

                output_file = f"{output_dir}/improvement_dialogues_{stamp}.json"

            elif choice == 3:
                # Turns ranked by their character-consistency sub-score.
                print("\n生成角色一致性训练集...")
                rows = conn.execute("""
                    SELECT speaker, content, score_details
                    FROM dialogue_turns
                    WHERE dialogue_score > 0
                    ORDER BY json_extract(score_details, '$.character_consistency') DESC
                    LIMIT 150
                """).fetchall()

                for speaker, content, score_details in rows:
                    char_consistency = parse_scores(score_details).get('character_consistency', 0)
                    training_data.append({
                        "type": "character_consistency",
                        "speaker": speaker,
                        "content": content,
                        "character_consistency_score": char_consistency,
                        "label": "high_consistency" if char_consistency >= 8.0 else "medium_consistency"
                    })

                output_file = f"{output_dir}/character_consistency_{stamp}.json"

            elif choice == 4:
                # Turns ranked by their creativity sub-score.
                print("\n生成创意性增强训练集...")
                rows = conn.execute("""
                    SELECT speaker, content, score_details
                    FROM dialogue_turns
                    WHERE dialogue_score > 0
                    ORDER BY json_extract(score_details, '$.creativity') DESC
                    LIMIT 150
                """).fetchall()

                for speaker, content, score_details in rows:
                    creativity = parse_scores(score_details).get('creativity', 0)
                    training_data.append({
                        "type": "creativity",
                        "speaker": speaker,
                        "content": content,
                        "creativity_score": creativity,
                        "label": "high_creativity" if creativity >= 8.0 else "medium_creativity"
                    })

                output_file = f"{output_dir}/creativity_enhancement_{stamp}.json"

            elif choice == 5:
                # Most recent scored turns of any quality.
                print("\n生成完整对话质量数据集...")
                rows = conn.execute("""
                    SELECT speaker, content, dialogue_score, score_details, score_feedback
                    FROM dialogue_turns
                    WHERE dialogue_score > 0
                    ORDER BY timestamp DESC
                    LIMIT 300
                """).fetchall()

                for speaker, content, score, score_details, feedback in rows:
                    training_data.append({
                        "type": "complete_dataset",
                        "speaker": speaker,
                        "content": content,
                        "overall_score": score,
                        "dimension_scores": parse_scores(score_details),
                        "feedback": feedback,
                        "quality_label": get_quality_label(score)
                    })

                output_file = f"{output_dir}/complete_quality_dataset_{stamp}.json"

            else:
                print("❌ 无效选择")
                return

        if not training_data:
            print("❌ 未找到符合条件的数据")
            return

        # Write the dataset after the connection has been closed.
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump({
                "metadata": {
                    "created_at": datetime.now().isoformat(),
                    "total_samples": len(training_data),
                    "data_type": choice,
                    "source": "dialogue_scoring_system"
                },
                "data": training_data
            }, f, ensure_ascii=False, indent=2)

        # The description/statistics helpers key on the menu entry as a
        # string ('1'..'5'), so pass the choice in that form.
        choice_key = str(choice)
        print(f"\n✓ 训练数据集生成成功!")
        print(f" - 文件路径: {output_file}")
        print(f" - 数据条数: {len(training_data)}")
        print(f" - 数据类型: {get_dataset_description(choice_key)}")

        generate_dataset_statistics(training_data, choice_key)

    except Exception as e:
        # Top-level boundary: report failures to the menu user.
        print(f"✗ 训练数据集生成失败: {e}")
        import traceback
        traceback.print_exc()
|
|
|
|
|
|
|
|
|
|
|
|
def generate_improvement_prompt(content, feedback):
    """Build the rewrite prompt for a low-scoring dialogue turn.

    Args:
        content: The original dialogue text.
        feedback: The scorer's description of what is wrong with it.

    Returns:
        A multi-line prompt holding the original text, the problem analysis,
        a fixed list of five improvement requirements and a closing
        instruction, with sections separated by blank lines.
    """
    requirement_lines = (
        "改进要求:",
        "1. 保持角色特征和设定",
        "2. 增强对话的逻辑性和连贯性",
        "3. 提升表达的自然度",
        "4. 增加信息密度,避免冗余",
        "5. 适当增加创意元素",
    )
    sections = (
        f"原对话: {content}",
        f"问题分析: {feedback}",
        "\n".join(requirement_lines),
        "请生成一个改进版本的对话。",
    )
    return "\n\n".join(sections)
|
|
|
|
|
|
|
|
|
|
|
|
def get_quality_label(score):
    """Map a numeric dialogue score to a quality label.

    Args:
        score: The overall dialogue score.

    Returns:
        "excellent" (>= 9.0), "good" (>= 8.0), "average" (>= 7.0),
        "below_average" (>= 6.0) or "poor" for anything lower.
    """
    # Checked from the highest floor down; the first match wins.
    label_floors = (
        (9.0, "excellent"),
        (8.0, "good"),
        (7.0, "average"),
        (6.0, "below_average"),
    )
    for floor, label in label_floors:
        if score >= floor:
            return label
    return "poor"
|
|
|
|
|
|
|
|
|
|
|
|
def get_dataset_description(choice):
    """Return the human-readable name of a training-dataset type.

    Args:
        choice: The menu entry, as an int (1-5) or its string form
            ('1'-'5'). Both are accepted because the dataset generator in
            this file converts the user's selection to an int before
            calling.

    Returns:
        The dataset name, or "未知类型" for an unrecognised choice.
    """
    descriptions = {
        '1': "高质量对话数据集",
        '2': "问题对话改进数据集",
        '3': "角色一致性训练集",
        '4': "创意性增强训练集",
        '5': "完整对话质量数据集",
    }
    # Normalise to the string key so an int choice no longer falls through
    # to the "unknown" default.
    return descriptions.get(str(choice).strip(), "未知类型")
|
|
|
|
|
|
|
|
|
|
|
|
def generate_dataset_statistics(training_data, data_type):
    """Print distribution statistics for a generated training dataset.

    Args:
        training_data: List of sample dicts. Type 1 samples must carry a
            'speaker' key and type 5 samples a 'quality_label' key.
        data_type: The dataset type as an int (1-5) or its string form.
            Only types 1 and 5 have statistics; other types print just the
            header line.
    """
    from collections import Counter

    print(f"\n数据集统计信息:")

    # The caller may pass the menu choice as an int, so compare on the
    # normalised string form.
    kind = str(data_type).strip()
    if kind == '1':    # high-quality dataset: samples per speaker
        field, title = 'speaker', "角色分布"
    elif kind == '5':  # complete dataset: samples per quality label
        field, title = 'quality_label', "质量分布"
    else:
        return

    # Counter keeps first-seen order, matching a manual dict tally.
    distribution = Counter(item[field] for item in training_data)

    print(f" - {title}:")
    for key, count in distribution.items():
        print(f" • {key}: {count}条对话")
|
|
|
|
|
|
|
|
|
|
|
|
def run_model_optimization():
    """Run the iterative LoRA training script from the menu.

    Instantiates the conversation manager, then imports
    ``iterative_lora_training`` and calls its ``main()``. Any failure is
    printed with a traceback instead of being raised.
    """
    header_rule = "="*60
    print("\n" + header_rule)
    print("模型迭代优化")
    print(header_rule)

    try:
        from dual_ai_dialogue_system import ConversationManager

        # The manager is constructed but not used afterwards.
        # NOTE(review): presumably kept for whatever setup its constructor
        # performs - confirm before removing.
        ConversationManager("./conversation_data/conversations.db")

        # Delegate to the LoRA training script.
        print("\n=== 使用LoRA训练脚本 ===")
        import iterative_lora_training
        iterative_lora_training.main()

    except Exception as error:
        print(f"✗ 模型优化失败: {error}")
        import traceback
        traceback.print_exc()
|
|
|
|
|
|
|
2025-08-23 18:37:00 +08:00
|
|
|
|
|
2025-08-23 17:10:23 +08:00
|
|
|
|
|
|
|
|
|
|
def generate_training_command(training_file_path):
    """Build the shell command that launches iterative LoRA training.

    Args:
        training_file_path: Path of the training-data file, inserted
            verbatim after ``--data``.

    Returns:
        The command string; its output directory is suffixed with today's
        date as ``YYYYMMDD``.
    """
    # Imported locally: the module never imports datetime at top level, so
    # the bare name raised NameError here.
    from datetime import datetime

    version = datetime.now().strftime('%Y%m%d')
    return f"python ./scripts/iterative_lora_training.py --data {training_file_path} --output ./output/optimized_model_v{version}"
|
|
|
|
|
|
|
|
|
|
|
|
def show_scoring_statistics():
    """Print scoring statistics for all scored dialogue turns.

    Reads the ``dialogue_turns`` table of the conversation database and
    prints overall totals, per-speaker averages, the five most recent turns
    scoring at least 8.0, and the distribution of scores over five ranges.
    Only rows with ``dialogue_score > 0`` are counted.

    Returns:
        None. Query failures are reported on stdout rather than raised.
    """
    print("\n" + "="*60)
    print("对话评分统计")
    print("="*60)

    try:
        import sqlite3
        from contextlib import closing

        from dual_ai_dialogue_system import ConversationManager

        conv_mgr = ConversationManager("./conversation_data/conversations.db")

        # sqlite3's own context manager only commits or rolls back; closing()
        # makes sure the connection itself is released.
        with closing(sqlite3.connect(conv_mgr.db_path)) as conn:
            # Overall totals.
            stats = conn.execute("""
                SELECT
                    COUNT(*) as total_turns,
                    AVG(dialogue_score) as avg_score,
                    MAX(dialogue_score) as max_score,
                    MIN(dialogue_score) as min_score
                FROM dialogue_turns
                WHERE dialogue_score > 0
            """).fetchone()

            if not stats or stats[0] <= 0:
                print("暂无评分数据")
                return

            total_turns, avg_score, max_score, min_score = stats
            print(f"总体统计:")
            print(f" - 已评分对话轮数: {total_turns}")
            print(f" - 平均分数: {avg_score:.2f}")
            print(f" - 最高分数: {max_score:.2f}")
            print(f" - 最低分数: {min_score:.2f}")

            # Per-speaker averages.
            print(f"\n按角色统计:")
            per_speaker = conn.execute("""
                SELECT
                    speaker,
                    COUNT(*) as turns,
                    AVG(dialogue_score) as avg_score,
                    MAX(dialogue_score) as max_score
                FROM dialogue_turns
                WHERE dialogue_score > 0
                GROUP BY speaker
                ORDER BY avg_score DESC
            """)
            for speaker, turns, avg_score, max_score in per_speaker:
                print(f" - {speaker}: 平均{avg_score:.2f}分 (最高{max_score:.2f}分, {turns}轮对话)")

            # Most recent high-scoring turns.
            print(f"\n最近高分对话 (分数≥8.0):")
            high_score_turns = conn.execute("""
                SELECT speaker, content, dialogue_score, score_feedback, timestamp
                FROM dialogue_turns
                WHERE dialogue_score >= 8.0
                ORDER BY timestamp DESC
                LIMIT 5
            """).fetchall()

            if high_score_turns:
                for speaker, content, score, feedback, timestamp in high_score_turns:
                    # str()/"or ''" keep a NULL column from aborting the
                    # whole report on a slice of None.
                    print(f" [{str(timestamp)[:16]}] {speaker} ({score:.2f}分)")
                    print(f" 内容: {(content or '')[:80]}...")
                    if feedback:
                        print(f" 评价: {feedback[:60]}...")
                    print()
            else:
                print(" 暂无高分对话")

            # Score distribution by range.
            print(f"\n分数分布:")
            distribution = conn.execute("""
                SELECT
                    CASE
                        WHEN dialogue_score >= 9.0 THEN '优秀 (9-10分)'
                        WHEN dialogue_score >= 8.0 THEN '良好 (8-9分)'
                        WHEN dialogue_score >= 7.0 THEN '中等 (7-8分)'
                        WHEN dialogue_score >= 6.0 THEN '及格 (6-7分)'
                        ELSE '待改进 (<6分)'
                    END as score_range,
                    COUNT(*) as count
                FROM dialogue_turns
                WHERE dialogue_score > 0
                GROUP BY score_range
                ORDER BY MIN(dialogue_score) DESC
            """)
            for score_range, count in distribution:
                percentage = (count / total_turns) * 100
                print(f" - {score_range}: {count}轮 ({percentage:.1f}%)")

    except Exception as e:
        # Top-level boundary: report failures to the menu user.
        print(f"✗ 评分统计查询失败: {e}")
|
|
|
|
|
|
|
2025-08-15 17:25:07 +08:00
|
|
|
|
def show_system_status():
    """Print a health report of the project's files, directories and sessions.

    Reports whether each expected template/module file exists, how many
    regular files each working directory holds, and how many conversation
    sessions the conversation manager lists. A failure of the session check
    is printed instead of raised.
    """
    header_rule = "="*60
    print("\n" + header_rule)
    print("系统状态检查")
    print(header_rule)

    # Expected files, as (path, description) pairs.
    expected_files = (
        ("./knowledge_base/worldview_template_coc.json", "世界观模板"),
        ("./knowledge_base/character_template_detective.json", "侦探角色"),
        ("./knowledge_base/character_template_professor.json", "教授角色"),
        ("./pdf_to_rag_processor.py", "PDF处理器"),
        ("./dual_ai_dialogue_system.py", "对话系统"),
        ("./npc_dialogue_generator.py", "NPC生成器"),
    )

    print("\n文件检查:")
    for path, label in expected_files:
        if not os.path.exists(path):
            print(f"✗ {label}: {path} (不存在)")
            continue
        print(f"✓ {label}: {path}")

    # Working directories: count only regular files directly inside each.
    print("\n目录检查:")
    for folder in ("./knowledge_base", "./rag_knowledge", "./conversation_data"):
        if not os.path.exists(folder):
            print(f"✗ {folder}: 不存在")
            continue
        entries = os.listdir(folder)
        file_count = sum(1 for entry in entries if os.path.isfile(os.path.join(folder, entry)))
        print(f"✓ {folder}: {file_count} 个文件")

    # Conversation sessions.
    try:
        from dual_ai_dialogue_system import ConversationManager
        manager = ConversationManager("./conversation_data/conversations.db")
        session_list = manager.list_sessions()
        print(f"\n✓ 对话会话: {len(session_list)} 个")
    except Exception as error:
        print(f"\n✗ 对话会话检查失败: {error}")
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Run the interactive main menu of the dual-AI dialogue system.

    Checks the dependencies first and returns immediately if they are not
    satisfied. Otherwise it prints the menu, reads a choice and dispatches
    to the matching workflow until the user selects ``0``.
    """
    title_rule = "="*70
    print(title_rule)
    print(" 双AI角色对话系统 - 主控制程序")
    print(" 基于RAG的世界观增强对话引擎")
    print(title_rule)

    # Dependency check.
    if not check_dependencies():
        return

    # NOTE: setup_directories() and copy_demo_files() are not invoked here;
    # both calls are commented out in the original.

    # Menu choice -> zero-argument action. '11' (usage guide) is accepted
    # although the printed menu does not list it.
    actions = {
        '1': process_pdf_workflow,
        '2': show_character_info,
        '3': lambda: run_dialogue_system(enableScore=True),
        '4': lambda: run_dialogue_system(enableScore=False),
        '5': lambda: run_dialogue_system(enableScore=True, useManualScoring=True),
        '6': show_system_status,
        '7': show_scoring_statistics,
        '8': generate_training_dataset,
        '9': run_model_optimization,
        '11': show_usage_guide,
    }

    menu_lines = (
        "主菜单 - 请选择操作:",
        "1. 处理PDF世界观文档 (转换为RAG格式)",
        "2. 查看角色设定信息",
        "3. 启动双AI对话系统 (开启AI打分)",
        "4. 启动双AI对话系统 (关闭AI打分)",
        "5. 启动双AI对话系统 (开启人工打分)",
        "6. 系统状态检查",
        "7. 查看对话评分统计",
        "8. 生成训练数据集",
        "9. 模型迭代优化",
        "0. 退出",
    )
    menu_rule = "="*50

    while True:
        print("\n" + menu_rule)
        for menu_line in menu_lines:
            print(menu_line)
        print(menu_rule)

        choice = input("请输入选择 (0-9): ").strip()

        if choice == '0':
            print("\n感谢使用双AI角色对话系统!")
            break

        action = actions.get(choice)
        if action is None:
            print("❌ 无效选择,请重新输入")
            continue
        action()
|
|
|
|
|
|
|
|
|
|
|
|
def show_usage_guide():
    """Print the usage guide.

    The guide covers quick start, document formats, system features, custom
    characters, conversation data and caveats.
    """
    header_rule = "="*60
    for header_line in ("\n" + header_rule, "系统使用说明", header_rule):
        print(header_line)

    usage_text = """
🚀 快速开始:
1. 首次使用建议先运行"创建演示对话场景"
2. 如有PDF世界观文档,选择"处理PDF世界观文档"
3. 通过"启动双AI对话系统"开始角色对话

📁 文档格式说明:
- 世界观文档: worldview_template_coc.json (参考COC设定)
- 角色设定: character_template_*.json (包含详细人设)

🔧 系统功能:
- PDF自动转换为RAG知识库
- 基于向量相似度的上下文检索
- 持久化对话历史存储
- 角色设定一致性保持

📝 自定义角色:
1. 参考 character_template_*.json 格式
2. 保存到 knowledge_base/ 目录
3. 重启对话系统加载新角色

💾 对话数据:
- 历史对话保存在 conversation_data/ 目录
- 支持会话恢复和历史查看
- 自动记录使用的上下文信息

⚠️ 注意事项:
- 确保模型路径正确设置
- 首次运行需要下载向量化模型
- PDF处理需要足够内存
"""
    print(usage_text)
|
|
|
|
|
|
|
|
|
|
|
|
# Script entry point: start the interactive menu only when run directly.
if __name__ == "__main__":
    main()
|