#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
双AI角色对话系统主控制程序
完整的工作流程PDF处理 -> 角色加载 -> RAG对话 -> 历史记录
'''
import os
import sys
import shutil
from typing import List, Dict
import json
def check_dependencies():
    """Probe required and optional third-party libraries.

    Hard requirement: PyPDF2. Optional: pymupdf (better PDF parsing) and
    sentence-transformers + faiss (vector retrieval); their absence only
    downgrades functionality to simpler fallbacks.

    Returns:
        bool: True when every hard requirement can be imported.
    """
    unmet = []

    try:
        import PyPDF2  # noqa: F401 -- availability probe only
    except ImportError:
        unmet.append("PyPDF2")

    try:
        import pymupdf  # noqa: F401
        print("✓ pymupdf 可用")
    except ImportError:
        print("⚠ pymupdf 不可用,将使用 PyPDF2")

    try:
        import sentence_transformers  # noqa: F401
        import faiss  # noqa: F401
        print("✓ 向量化功能可用")
    except ImportError:
        print("⚠ 向量化功能不可用,将使用文本匹配")

    if unmet:
        print(f"✗ 缺少依赖库: {', '.join(unmet)}")
        print("请运行: pip install PyPDF2 sentence-transformers faiss-cpu")
        return False
    return True
def setup_directories():
    """Create the project's working directories (idempotent)."""
    for folder in (
        "./knowledge_base",
        "./characters",
        "./worldview",
        "./rag_knowledge",
        "./conversation_data",
    ):
        os.makedirs(folder, exist_ok=True)
        print(f"✓ 目录就绪: {folder}")
def copy_demo_files():
    """Copy bundled demo worldview/character templates into the knowledge base.

    Missing source files are skipped silently so a partial checkout still works.
    """
    demo_pairs = (
        ("./worldview/worldview_template_coc.json", "./knowledge_base/worldview_template_coc.json"),
        ("./characters/character_template_detective.json", "./knowledge_base/character_template_detective.json"),
        ("./characters/character_template_professor.json", "./knowledge_base/character_template_professor.json"),
    )
    for src, dst in demo_pairs:
        if not os.path.exists(src):
            continue
        shutil.copy2(src, dst)
        print(f"✓ 复制文档: {os.path.basename(dst)}")
def process_pdf_workflow():
    """Interactively convert a worldview PDF into the RAG knowledge store.

    Prompts for a PDF path, runs PDFToRAGProcessor over it and reports the
    resulting chunk/concept counts.

    Returns:
        bool: True on success, False when the file is missing or processing fails.
    """
    print("\n" + "="*60)
    print("PDF世界观文档处理")
    print("="*60)
    from pdf_to_rag_processor import PDFToRAGProcessor

    pdf_path = input("请输入PDF文件路径 (例: ./coc.pdf): ").strip()
    if not os.path.exists(pdf_path):
        print(f"✗ 文件不存在: {pdf_path}")
        return False

    try:
        stats = PDFToRAGProcessor().process_pdf_to_rag(pdf_path, "./rag_knowledge")
        print(f"\n✓ PDF处理完成!")
        print(f" - 文档块数: {stats['chunks_count']}")
        print(f" - 概念数: {stats['concepts_count']}")
        print(f" - 向量索引: {'启用' if stats['vector_enabled'] else '未启用'}")
        return True
    except Exception as err:
        print(f"✗ PDF处理失败: {err}")
        return False
def show_character_info():
    """Print a short profile for every character JSON in ./knowledge_base.

    Unreadable or malformed files are reported individually and skipped.
    """
    print("\n" + "="*60)
    print("角色设定信息")
    print("="*60)
    base_dir = "./knowledge_base"
    profiles = [
        entry for entry in os.listdir(base_dir)
        if entry.startswith('character') and entry.endswith('.json')
    ]
    for file_name in profiles:
        try:
            with open(os.path.join(base_dir, file_name), 'r', encoding='utf-8') as fh:
                data = json.load(fh)
            print(f"\n角色: {data.get('character_name', '未知')}")
            print(f" 职业: {data.get('basic_info', {}).get('occupation', '未知')}")
            core_traits = data.get('personality', {}).get('core_traits', [])
            # Only the first three traits are shown to keep the summary short.
            print(f" 特点: {', '.join(core_traits[:3])}")
        except Exception as err:
            print(f"✗ 读取角色文件失败: {file_name} - {err}")
def run_dialogue_system(enableScore: bool, useManualScoring: bool = False):
    """Launch the interactive dual-AI character dialogue loop.

    Loads the knowledge base and the first two character definitions, builds a
    dual-model generator (optionally with LoRA weights), creates a session and
    then loops on user-supplied topics until 'quit'.

    Args:
        enableScore: enable scoring of generated dialogue turns.
        useManualScoring: when True, scores are entered manually instead of
            being produced by the AI scorer.
    """
    print("\n" + "="*60)
    print("启动双AI角色对话系统")
    print("="*60)
    try:
        # Heavy project imports are deferred so menu startup stays fast.
        print("\n正在初始化双模型对话系统...")
        from dual_ai_dialogue_system import RAGKnowledgeBase, ConversationManager, DualAIDialogueEngine
        from npc_dialogue_generator import DualModelDialogueGenerator

        knowledge = RAGKnowledgeBase("./knowledge_base")
        sessions = ConversationManager("./conversation_data/conversations.db")

        # Model locations; the LoRA adapter is optional.
        base_model_path = '/mnt/g/Project02/AITrain/Qwen/Qwen3-4B'
        lora_model_path = './output/NPC_Dialogue_LoRA/final_model'
        if not os.path.exists(base_model_path):
            print(f"✗ 基础模型路径不存在: {base_model_path}")
            print("请修改 main_controller.py 中的模型路径")
            return
        if not os.path.exists(lora_model_path):
            lora_model_path = None
            print("⚠ LoRA模型不存在使用基础模型")

        # Two characters are required for a dual-model conversation.
        if not hasattr(knowledge, 'character_data') or len(knowledge.character_data) < 2:
            print("✗ 角色数据不足,无法创建双模型对话系统")
            print("请确保knowledge_base目录中有至少两个角色文件")
            return

        first_name, second_name = list(knowledge.character_data.keys())[:2]
        print(f"✓ 使用角色: {first_name}{second_name}")

        config_a = {
            "name": first_name,
            "lora_path": lora_model_path,
            "character_data": knowledge.character_data[first_name]
        }
        config_b = {
            "name": second_name,
            "lora_path": lora_model_path,
            "character_data": knowledge.character_data[second_name]
        }

        print("正在初始化双模型对话生成器...")
        generator = DualModelDialogueGenerator(
            base_model_path,
            config_a,
            config_b
        )

        # Scoring behaviour is controlled by the two keyword flags.
        engine = DualAIDialogueEngine(
            knowledge,
            sessions,
            generator,
            enable_scoring=enableScore,
            base_model_path=base_model_path,
            use_manual_scoring=useManualScoring
        )

        participants = [first_name, second_name]
        worldview = knowledge.worldview_data.get('worldview_name', '未知世界观') if knowledge.worldview_data else '未知世界观'
        session_id = sessions.create_session(participants, worldview)
        print(f"✓ 创建对话会话: {session_id}")

        print(f"\n=== 双AI模型对话系统 ===")
        print(f"角色: {first_name} vs {second_name}")
        print(f"世界观: {worldview}")
        print("输入 'quit' 退出对话")
        print("-" * 50)

        while True:
            try:
                topic = input("\n请输入对话主题或指令: ").strip()
                if topic.lower() == 'quit':
                    print("退出双AI对话系统")
                    break
                if not topic:
                    print("请输入有效的对话主题")
                    continue

                # Per-run tunables; non-numeric input falls back to defaults.
                turns_raw = input("请输入对话轮数 (默认4): ").strip()
                turns = int(turns_raw) if turns_raw.isdigit() else 4
                history_raw = input("使用历史对话轮数 (默认2): ").strip()
                history_count = int(history_raw) if history_raw.isdigit() else 2
                context_raw = input("使用上下文信息数量 (默认10): ").strip()
                context_info_count = int(context_raw) if context_raw.isdigit() else 10

                print(f"\n开始对话 - 主题: {topic}")
                print(f"轮数: {turns}, 历史: {history_count}, 上下文: {context_info_count}")
                print("-" * 50)
                engine.run_dual_model_conversation(
                    session_id, topic, turns, history_count, context_info_count
                )
                print("-" * 50)
                print("对话完成!")
            except KeyboardInterrupt:
                print("\n\n用户中断对话")
                break
            except Exception as loop_err:
                print(f"对话过程中出现错误: {loop_err}")
                import traceback
                traceback.print_exc()
    except Exception as boot_err:
        print(f"✗ 对话系统启动失败: {boot_err}")
        import traceback
        traceback.print_exc()
def generate_training_dataset():
    """Interactively export scored dialogue turns as a JSON training dataset.

    Queries the scoring columns of the conversations SQLite DB, filters rows
    according to the user's chosen dataset type (1-5), and writes a timestamped
    JSON file under ./training_data. All failures are reported, never raised.
    """
    print("\n" + "="*60)
    print("生成训练数据集")
    print("="*60)
    try:
        from dual_ai_dialogue_system import ConversationManager
        import sqlite3
        import json
        import os
        from datetime import datetime

        manager = ConversationManager("./conversation_data/conversations.db")
        export_dir = "./training_data"
        os.makedirs(export_dir, exist_ok=True)

        print("请选择训练数据生成类型:")
        print("1. 高质量对话数据集 (分数≥8.0)")
        print("2. 问题对话改进数据集 (分数<6.0)")
        print("3. 角色一致性训练集")
        print("4. 创意性增强训练集")
        print("5. 完整对话质量数据集")
        choice = input("请输入选择 (1-5): ").strip()

        samples = []
        with sqlite3.connect(manager.db_path) as db:
            if choice == '1':
                # Best-scored turns become positive examples.
                print("\n生成高质量对话数据集...")
                rows = db.execute("""
                    SELECT speaker, content, score_details, score_feedback
                    FROM dialogue_turns
                    WHERE dialogue_score >= 8.0
                    ORDER BY dialogue_score DESC
                    LIMIT 200
                """).fetchall()
                for speaker, content, details, feedback in rows:
                    samples.append({
                        "type": "high_quality",
                        "speaker": speaker,
                        "content": content,
                        "scores": json.loads(details) if details else {},
                        "feedback": feedback,
                        "label": "positive"
                    })
                out_path = f"{export_dir}/high_quality_dialogues_{datetime.now().strftime('%Y%m%d_%H%M')}.json"
            elif choice == '2':
                # Low-scored turns paired with rewrite instructions.
                print("\n生成问题对话改进数据集...")
                rows = db.execute("""
                    SELECT speaker, content, score_details, score_feedback
                    FROM dialogue_turns
                    WHERE dialogue_score < 6.0 AND dialogue_score > 0
                    ORDER BY dialogue_score ASC
                    LIMIT 100
                """).fetchall()
                for speaker, content, details, feedback in rows:
                    samples.append({
                        "type": "improvement",
                        "speaker": speaker,
                        "original_content": content,
                        "scores": json.loads(details) if details else {},
                        "problems": feedback,
                        "improvement_prompt": generate_improvement_prompt(content, feedback),
                        "label": "needs_improvement"
                    })
                out_path = f"{export_dir}/improvement_dialogues_{datetime.now().strftime('%Y%m%d_%H%M')}.json"
            elif choice == '3':
                # Turns ranked by the character_consistency dimension.
                print("\n生成角色一致性训练集...")
                rows = db.execute("""
                    SELECT speaker, content, score_details
                    FROM dialogue_turns
                    WHERE dialogue_score > 0
                    ORDER BY json_extract(score_details, '$.character_consistency') DESC
                    LIMIT 150
                """).fetchall()
                for speaker, content, details in rows:
                    dims = json.loads(details) if details else {}
                    consistency = dims.get('character_consistency', 0)
                    samples.append({
                        "type": "character_consistency",
                        "speaker": speaker,
                        "content": content,
                        "character_consistency_score": consistency,
                        "label": "high_consistency" if consistency >= 8.0 else "medium_consistency"
                    })
                out_path = f"{export_dir}/character_consistency_{datetime.now().strftime('%Y%m%d_%H%M')}.json"
            elif choice == '4':
                # Turns ranked by the creativity dimension.
                print("\n生成创意性增强训练集...")
                rows = db.execute("""
                    SELECT speaker, content, score_details
                    FROM dialogue_turns
                    WHERE dialogue_score > 0
                    ORDER BY json_extract(score_details, '$.creativity') DESC
                    LIMIT 150
                """).fetchall()
                for speaker, content, details in rows:
                    dims = json.loads(details) if details else {}
                    creativity = dims.get('creativity', 0)
                    samples.append({
                        "type": "creativity",
                        "speaker": speaker,
                        "content": content,
                        "creativity_score": creativity,
                        "label": "high_creativity" if creativity >= 8.0 else "medium_creativity"
                    })
                out_path = f"{export_dir}/creativity_enhancement_{datetime.now().strftime('%Y%m%d_%H%M')}.json"
            elif choice == '5':
                # Most recent scored turns, labelled by overall quality.
                print("\n生成完整对话质量数据集...")
                rows = db.execute("""
                    SELECT speaker, content, dialogue_score, score_details, score_feedback
                    FROM dialogue_turns
                    WHERE dialogue_score > 0
                    ORDER BY timestamp DESC
                    LIMIT 300
                """).fetchall()
                for speaker, content, score, details, feedback in rows:
                    samples.append({
                        "type": "complete_dataset",
                        "speaker": speaker,
                        "content": content,
                        "overall_score": score,
                        "dimension_scores": json.loads(details) if details else {},
                        "feedback": feedback,
                        "quality_label": get_quality_label(score)
                    })
                out_path = f"{export_dir}/complete_quality_dataset_{datetime.now().strftime('%Y%m%d_%H%M')}.json"
            else:
                print("❌ 无效选择")
                return

        if samples:
            with open(out_path, 'w', encoding='utf-8') as fh:
                json.dump({
                    "metadata": {
                        "created_at": datetime.now().isoformat(),
                        "total_samples": len(samples),
                        "data_type": choice,
                        "source": "dialogue_scoring_system"
                    },
                    "data": samples
                }, fh, ensure_ascii=False, indent=2)
            print(f"\n✓ 训练数据集生成成功!")
            print(f" - 文件路径: {out_path}")
            print(f" - 数据条数: {len(samples)}")
            print(f" - 数据类型: {get_dataset_description(choice)}")
            generate_dataset_statistics(samples, choice)
        else:
            print("❌ 未找到符合条件的数据")
    except Exception as err:
        print(f"✗ 训练数据集生成失败: {err}")
        import traceback
        traceback.print_exc()
def generate_improvement_prompt(content, feedback):
    """Build a rewrite instruction pairing a low-scored line with its critique.

    Args:
        content: the original dialogue text.
        feedback: the scorer's problem analysis for that text.

    Returns:
        str: a prompt asking the model to produce an improved version.
    """
    header = f"原对话: {content}\n问题分析: {feedback}\n"
    requirements = (
        "改进要求:\n"
        "1. 保持角色特征和设定\n"
        "2. 增强对话的逻辑性和连贯性\n"
        "3. 提升表达的自然度\n"
        "4. 增加信息密度避免冗余\n"
        "5. 适当增加创意元素\n"
        "请生成一个改进版本的对话"
    )
    return header + requirements
def get_quality_label(score):
    """Map a numeric dialogue score to a coarse quality label.

    Args:
        score: overall dialogue score, nominally on a 0-10 scale.

    Returns:
        str: one of "excellent", "good", "average", "below_average", "poor".
    """
    for threshold, label in (
        (9.0, "excellent"),
        (8.0, "good"),
        (7.0, "average"),
        (6.0, "below_average"),
    ):
        if score >= threshold:
            return label
    return "poor"
def get_dataset_description(choice):
    """Return the human-readable name for a dataset menu choice ('1'-'5')."""
    return {
        '1': "高质量对话数据集",
        '2': "问题对话改进数据集",
        '3': "角色一致性训练集",
        '4': "创意性增强训练集",
        '5': "完整对话质量数据集"
    }.get(choice, "未知类型")
def generate_dataset_statistics(training_data, data_type):
    """Print simple distribution statistics for a freshly exported dataset.

    Args:
        training_data: list of sample dicts produced by generate_training_dataset.
        data_type: the menu choice string; only '1' and '5' print extra stats.
    """
    print(f"\n数据集统计信息:")
    if data_type == '1':
        # High-quality set: samples contributed per speaker.
        per_speaker = {}
        for sample in training_data:
            per_speaker[sample['speaker']] = per_speaker.get(sample['speaker'], 0) + 1
        print(f" - 角色分布:")
        for speaker, total in per_speaker.items():
            print(f"{speaker}: {total}条对话")
    elif data_type == '5':
        # Complete set: distribution of quality labels.
        per_label = {}
        for sample in training_data:
            per_label[sample['quality_label']] = per_label.get(sample['quality_label'], 0) + 1
        print(f" - 质量分布:")
        for label, total in per_label.items():
            print(f"{label}: {total}条对话")
def run_model_optimization():
    """Interactive menu for analysing scores and preparing model optimisation.

    Options: analyse optimisation needs from the scoring DB, generate a LoRA
    training script / prompt-optimisation config / A-B test config, scaffold an
    incremental training run, or compare historic performance. All failures are
    reported, never raised.
    """
    print("\n" + "="*60)
    print("模型迭代优化")
    print("="*60)
    try:
        from dual_ai_dialogue_system import ConversationManager
        import sqlite3
        import json
        import os
        from datetime import datetime

        manager = ConversationManager("./conversation_data/conversations.db")

        print("模型优化选项:")
        print("1. 分析优化需求")
        print("2. 生成LoRA训练脚本")
        print("3. 创建提示优化配置")
        print("4. 执行增量训练")
        print("5. 性能对比验证")
        choice = input("请输入选择 (1-5): ").strip()

        if choice == '1':
            print("\n=== 优化需求分析 ===")
            with sqlite3.connect(manager.db_path) as db:
                # Headline metrics over every scored turn.
                total, avg_score, hq_rate = db.execute("""
                    SELECT
                        COUNT(*) as total,
                        AVG(dialogue_score) as avg_score,
                        AVG(CASE WHEN dialogue_score >= 8.0 THEN 1.0 ELSE 0.0 END) as high_quality_rate
                    FROM dialogue_turns WHERE dialogue_score > 0
                """).fetchone()
                print(f"当前性能指标:")
                print(f" - 总评分样本: {total}")
                print(f" - 平均分数: {avg_score:.2f}")
                print(f" - 高质量率: {hq_rate*100:.1f}%")

                needs = []
                if avg_score < 7.0:
                    needs.append("整体质量需要提升")
                if hq_rate < 0.6:
                    needs.append("高质量对话比例偏低")

                # Per-dimension averages over the 100 most recent scored turns.
                rows = db.execute("""
                    SELECT score_details FROM dialogue_turns
                    WHERE dialogue_score > 0 AND score_details != '{}'
                    ORDER BY timestamp DESC LIMIT 100
                """).fetchall()
                by_dim = {'coherence': [], 'character_consistency': [],
                          'naturalness': [], 'information_density': [], 'creativity': []}
                for (details,) in rows:
                    try:
                        for dim, value in json.loads(details).items():
                            if dim in by_dim:
                                by_dim[dim].append(float(value))
                    except Exception:
                        # Malformed score_details rows are simply skipped.
                        continue

                weak = []
                print(f"\n维度分析:")
                for dim, values in by_dim.items():
                    if values:
                        mean = sum(values) / len(values)
                        print(f" - {dim}: {mean:.2f}")
                        if mean < 7.0:
                            weak.append(dim)
                if weak:
                    needs.append(f"薄弱维度: {weak}")

                print(f"\n优化建议:")
                if needs:
                    for idx, need in enumerate(needs, 1):
                        print(f" {idx}. {need}")
                else:
                    print(" 当前模型表现良好,可考虑细微调优")
        elif choice == '2':
            print("\n=== 生成LoRA训练脚本 ===")
            script_path = "./scripts/iterative_lora_training.py"
            os.makedirs("./scripts", exist_ok=True)
            with open(script_path, 'w', encoding='utf-8') as fh:
                fh.write(generate_lora_training_script())
            print(f"✓ LoRA训练脚本已生成: {script_path}")
            print("使用方法:")
            print(" 1. 先运行训练数据生成 (选项8)")
            print(" 2. 修改脚本中的路径配置")
            print(f" 3. 运行: python {script_path}")
        elif choice == '3':
            print("\n=== 创建提示优化配置 ===")
            config_path = "./config/prompt_optimization.json"
            os.makedirs("./config", exist_ok=True)
            with open(config_path, 'w', encoding='utf-8') as fh:
                json.dump(generate_prompt_optimization_config(), fh, ensure_ascii=False, indent=2)
            print(f"✓ 提示优化配置已生成: {config_path}")
            print("配置包含:")
            print(" - 动态提示调整规则")
            print(" - 质量阈值设置")
            print(" - 生成参数优化")
        elif choice == '4':
            print("\n=== 执行增量训练 ===")
            training_dir = "./training_data"
            if not os.path.exists(training_dir):
                print("❌ 训练数据目录不存在,请先生成训练数据 (选项8)")
                return
            candidates = [f for f in os.listdir(training_dir) if f.endswith('.json')]
            if not candidates:
                print("❌ 未找到训练数据文件,请先生成训练数据 (选项8)")
                return
            print(f"找到训练数据文件:")
            for idx, name in enumerate(candidates, 1):
                print(f" {idx}. {name}")
            picked = input(f"选择训练数据文件 (1-{len(candidates)}): ").strip()
            try:
                chosen = candidates[int(picked) - 1]
                chosen_path = os.path.join(training_dir, chosen)
                print(f"将使用训练文件: {chosen}")
                print("⚠ 注意:实际训练需要配置正确的模型路径和计算资源")
                print(f"建议训练命令:")
                print(f" {generate_training_command(chosen_path)}")
                # Training is only suggested, never started without confirmation.
                if input("是否现在执行训练?(y/N): ").strip().lower() == 'y':
                    print("开始增量训练...")
                    # Actual training execution would be wired up here.
                    print("⚠ 训练功能需要根据实际环境配置")
            except (ValueError, IndexError):
                print("❌ 无效的文件选择")
        elif choice == '5':
            print("\n=== 性能对比验证 ===")
            print("验证选项:")
            print("1. 历史性能趋势对比")
            print("2. A/B测试配置生成")
            print("3. 模型版本性能对比")
            verify_choice = input("请输入选择 (1-3): ").strip()
            if verify_choice == '1':
                # Daily averages over the last 30 days.
                with sqlite3.connect(manager.db_path) as db:
                    trend = db.execute("""
                        SELECT
                            DATE(timestamp) as date,
                            AVG(dialogue_score) as avg_score,
                            COUNT(*) as count
                        FROM dialogue_turns
                        WHERE dialogue_score > 0
                        AND timestamp >= datetime('now', '-30 days')
                        GROUP BY DATE(timestamp)
                        ORDER BY date ASC
                    """).fetchall()
                    if trend:
                        print(f"30天性能趋势:")
                        for day, day_avg, day_count in trend:
                            marker = "📈" if day_avg > 7.5 else "📉" if day_avg < 6.5 else "📊"
                            print(f" {day}: {day_avg:.2f}{marker} ({day_count}条对话)")
                    else:
                        print("暂无足够的历史数据")
            elif verify_choice == '2':
                ab_path = "./config/ab_test_config.json"
                os.makedirs("./config", exist_ok=True)
                with open(ab_path, 'w', encoding='utf-8') as fh:
                    json.dump(generate_ab_test_config(), fh, ensure_ascii=False, indent=2)
                print(f"✓ A/B测试配置已生成: {ab_path}")
                print("配置包含:")
                print(" - 对照组和实验组设置")
                print(" - 评估指标定义")
                print(" - 测试持续时间配置")
            elif verify_choice == '3':
                print("模型版本对比功能开发中...")
                print("建议手动记录不同版本的性能指标进行对比")
        else:
            print("❌ 无效选择")
    except Exception as err:
        print(f"✗ 模型优化失败: {err}")
        import traceback
        traceback.print_exc()
def generate_lora_training_script():
    """Return the source text of a standalone LoRA incremental-training script.

    The returned script is written to ./scripts by run_model_optimization and
    is meant to be edited (paths, actual training loop) before use.

    Returns:
        str: complete Python source for the generated script.
    """
    script_text = '''#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
基于评分数据的LoRA增量训练脚本
自动生成 - 请根据实际环境调整配置
"""
import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model, TaskType
import os


class IterativeLoRATrainer:
    def __init__(self, base_model_path, training_data_path, output_path):
        self.base_model_path = base_model_path
        self.training_data_path = training_data_path
        self.output_path = output_path
        # LoRA配置
        self.lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=16,  # LoRA rank
            lora_alpha=32,
            lora_dropout=0.1,
            target_modules=["q_proj", "v_proj", "k_proj", "o_proj"]
        )

    def load_training_data(self):
        """加载训练数据"""
        with open(self.training_data_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        return data['data']

    def prepare_training_samples(self, data):
        """准备训练样本"""
        samples = []
        for item in data:
            if item.get('label') == 'positive' or item.get('overall_score', 0) >= 8.0:
                # 高质量样本
                sample = {
                    'input': f"角色: {item['speaker']}\\n请生成高质量对话:",
                    'output': item['content'],
                    'quality_score': item.get('overall_score', 8.0)
                }
                samples.append(sample)
        return samples

    def train(self):
        """执行训练"""
        print("开始LoRA增量训练...")
        # 加载模型和分词器
        tokenizer = AutoTokenizer.from_pretrained(self.base_model_path)
        model = AutoModelForCausalLM.from_pretrained(
            self.base_model_path,
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )
        # 应用LoRA
        model = get_peft_model(model, self.lora_config)
        # 加载训练数据
        training_data = self.load_training_data()
        training_samples = self.prepare_training_samples(training_data)
        print(f"训练样本数量: {len(training_samples)}")
        # 这里添加实际的训练循环
        # 建议使用transformers的Trainer或自定义训练循环
        # 保存模型
        model.save_pretrained(self.output_path)
        tokenizer.save_pretrained(self.output_path)
        print(f"训练完成,模型保存到: {self.output_path}")


if __name__ == '__main__':
    # 配置参数 - 请根据实际情况修改
    BASE_MODEL_PATH = '/mnt/e/AI/Project02/AITrain/Qwen/Qwen3-4B'
    TRAINING_DATA_PATH = './training_data/high_quality_dialogues_latest.json'
    OUTPUT_PATH = './output/iterative_lora_v2'
    trainer = IterativeLoRATrainer(BASE_MODEL_PATH, TRAINING_DATA_PATH, OUTPUT_PATH)
    trainer.train()
'''
    return script_text
def generate_prompt_optimization_config():
    """Build the static prompt-optimisation config written to ./config.

    Returns:
        dict: quality thresholds, per-weakness prompt adjustments, dynamic
        prompt templates and adaptive sampling parameters.
    """
    thresholds = {
        "excellent": 9.0,
        "good": 8.0,
        "acceptable": 7.0,
        "needs_improvement": 6.0
    }
    adjustments = {
        "low_coherence": {
            "add_context_emphasis": True,
            "increase_history_weight": 1.2,
            "add_logical_constraints": True
        },
        "low_character_consistency": {
            "enhance_character_description": True,
            "add_personality_reminders": True,
            "increase_character_weight": 1.5
        },
        "low_creativity": {
            "add_creativity_prompts": True,
            "increase_temperature": 0.1,
            "diversify_examples": True
        }
    }
    dynamic_prompts = {
        "quality_boost_templates": [
            "请生成一个富有创意且符合角色特征的高质量对话",
            "确保对话逻辑连贯,信息丰富,表达自然",
            "体现角色的独特性格和说话风格"
        ],
        "problem_specific_guidance": {
            "repetitive": "避免重复之前的内容,提供新的信息和观点",
            "inconsistent": "严格遵循角色设定,保持人格一致性",
            "dull": "增加对话的趣味性和深度,使用生动的表达"
        }
    }
    generation_parameters = {
        "adaptive_temperature": {
            "high_creativity_needed": 0.9,
            "normal": 0.8,
            "high_consistency_needed": 0.7
        },
        "adaptive_top_p": {
            "creative_mode": 0.9,
            "balanced_mode": 0.8,
            "conservative_mode": 0.7
        }
    }
    return {
        "optimization_rules": {
            "quality_thresholds": thresholds,
            "adaptive_adjustments": adjustments
        },
        "dynamic_prompts": dynamic_prompts,
        "generation_parameters": generation_parameters
    }
def generate_ab_test_config():
    """Build the A/B test configuration written to ./config/ab_test_config.json.

    Returns:
        dict: control/experimental group definitions, evaluation metrics,
        test duration and statistical settings, stamped with the creation time.
    """
    # Fix: `datetime` is never imported at module level in this file (other
    # functions import it locally), so referencing it here raised NameError.
    from datetime import datetime
    return {
        "test_name": "model_optimization_validation",
        "created_at": datetime.now().isoformat(),
        "groups": {
            "control": {
                "name": "原始模型",
                "description": "未优化的基础模型",
                "model_path": "/path/to/base/model",
                "sample_ratio": 0.5
            },
            "experimental": {
                "name": "优化模型",
                "description": "经过迭代优化的模型",
                "model_path": "/path/to/optimized/model",
                "sample_ratio": 0.5
            }
        },
        "evaluation_metrics": {
            "primary": [
                "overall_dialogue_score",
                "character_consistency_score",
                "creativity_score"
            ],
            "secondary": [
                "user_satisfaction",
                "response_time",
                "coherence_score"
            ]
        },
        "test_duration": {
            "target_samples": 200,
            "max_duration_days": 7,
            "min_samples_per_group": 50
        },
        "statistical_settings": {
            "confidence_level": 0.95,
            "minimum_effect_size": 0.3,
            "power": 0.8
        }
    }
def generate_training_command(training_file_path):
    """Build the suggested shell command for an incremental LoRA training run.

    Args:
        training_file_path: path to the generated training-data JSON file.

    Returns:
        str: a ready-to-run python command line, with a date-stamped output dir.
    """
    # Fix: `datetime` is never imported at module level in this file (other
    # functions import it locally), so referencing it here raised NameError.
    from datetime import datetime
    return f"python ./scripts/iterative_lora_training.py --data {training_file_path} --output ./output/optimized_model_v{datetime.now().strftime('%Y%m%d')}"
def show_scoring_statistics():
    """Print scoring statistics for stored dialogue turns.

    Reports overall aggregates, per-speaker averages, the five most recent
    high-scoring turns and a score-range distribution. Prints an error line
    instead of raising when the DB or project modules are unavailable.
    """
    print("\n" + "="*60)
    print("对话评分统计")
    print("="*60)
    try:
        from dual_ai_dialogue_system import ConversationManager
        import json
        import sqlite3

        manager = ConversationManager("./conversation_data/conversations.db")
        with sqlite3.connect(manager.db_path) as db:
            # Overall aggregates across every scored turn.
            overall = db.execute("""
                SELECT
                    COUNT(*) as total_turns,
                    AVG(dialogue_score) as avg_score,
                    MAX(dialogue_score) as max_score,
                    MIN(dialogue_score) as min_score
                FROM dialogue_turns
                WHERE dialogue_score > 0
            """).fetchone()
            if not overall or overall[0] == 0:
                print("暂无评分数据")
                return
            total_turns, avg_score, max_score, min_score = overall
            print(f"总体统计:")
            print(f" - 已评分对话轮数: {total_turns}")
            print(f" - 平均分数: {avg_score:.2f}")
            print(f" - 最高分数: {max_score:.2f}")
            print(f" - 最低分数: {min_score:.2f}")

            # Per-speaker averages, best speakers first.
            print(f"\n按角色统计:")
            for speaker, turns, spk_avg, spk_max in db.execute("""
                SELECT
                    speaker,
                    COUNT(*) as turns,
                    AVG(dialogue_score) as avg_score,
                    MAX(dialogue_score) as max_score
                FROM dialogue_turns
                WHERE dialogue_score > 0
                GROUP BY speaker
                ORDER BY avg_score DESC
            """):
                print(f" - {speaker}: 平均{spk_avg:.2f}分 (最高{spk_max:.2f}分, {turns}轮对话)")

            # Five most recent turns scoring 8.0 or above.
            print(f"\n最近高分对话 (分数≥8.0):")
            best = db.execute("""
                SELECT speaker, content, dialogue_score, score_feedback, timestamp
                FROM dialogue_turns
                WHERE dialogue_score >= 8.0
                ORDER BY timestamp DESC
                LIMIT 5
            """).fetchall()
            if best:
                for speaker, content, score, feedback, ts in best:
                    print(f" [{ts[:16]}] {speaker} ({score:.2f}分)")
                    print(f" 内容: {content[:80]}...")
                    if feedback:
                        print(f" 评价: {feedback[:60]}...")
                    print()
            else:
                print(" 暂无高分对话")

            # Bucketed distribution of all scores.
            print(f"\n分数分布:")
            for score_range, count in db.execute("""
                SELECT
                    CASE
                        WHEN dialogue_score >= 9.0 THEN '优秀 (9-10分)'
                        WHEN dialogue_score >= 8.0 THEN '良好 (8-9分)'
                        WHEN dialogue_score >= 7.0 THEN '中等 (7-8分)'
                        WHEN dialogue_score >= 6.0 THEN '及格 (6-7分)'
                        ELSE '待改进 (<6分)'
                    END as score_range,
                    COUNT(*) as count
                FROM dialogue_turns
                WHERE dialogue_score > 0
                GROUP BY score_range
                ORDER BY MIN(dialogue_score) DESC
            """):
                share = (count / total_turns) * 100
                print(f" - {score_range}: {count}轮 ({share:.1f}%)")
    except Exception as err:
        print(f"✗ 评分统计查询失败: {err}")
def show_system_status():
    """Print a health report: expected files, directory contents and the
    number of stored dialogue sessions."""
    print("\n" + "="*60)
    print("系统状态检查")
    print("="*60)

    expected_files = (
        ("./knowledge_base/worldview_template_coc.json", "世界观模板"),
        ("./knowledge_base/character_template_detective.json", "侦探角色"),
        ("./knowledge_base/character_template_professor.json", "教授角色"),
        ("./pdf_to_rag_processor.py", "PDF处理器"),
        ("./dual_ai_dialogue_system.py", "对话系统"),
        ("./npc_dialogue_generator.py", "NPC生成器"),
    )
    print("\n文件检查:")
    for path, label in expected_files:
        if os.path.exists(path):
            print(f"✓ {label}: {path}")
        else:
            print(f"✗ {label}: {path} (不存在)")

    print("\n目录检查:")
    for folder in ("./knowledge_base", "./rag_knowledge", "./conversation_data"):
        if os.path.exists(folder):
            count = sum(
                1 for entry in os.listdir(folder)
                if os.path.isfile(os.path.join(folder, entry))
            )
            print(f"✓ {folder}: {count} 个文件")
        else:
            print(f"✗ {folder}: 不存在")

    # Session count requires the project's conversation manager; report
    # (rather than raise) when it is unavailable.
    try:
        from dual_ai_dialogue_system import ConversationManager
        manager = ConversationManager("./conversation_data/conversations.db")
        print(f"\n✓ 对话会话: {len(manager.list_sessions())}")
    except Exception as err:
        print(f"\n✗ 对话会话检查失败: {err}")
def main():
    """Entry point: top-level interactive menu for the dual-AI dialogue suite."""
    print("="*70)
    print(" 双AI角色对话系统 - 主控制程序")
    print(" 基于RAG的世界观增强对话引擎")
    print("="*70)

    # Abort early when hard dependencies are missing.
    if not check_dependencies():
        return

    # One-time project scaffolding, currently disabled by default.
    # setup_directories()
    # copy_demo_files()

    while True:
        print("\n" + "="*50)
        print("主菜单 - 请选择操作:")
        print("1. 处理PDF世界观文档 (转换为RAG格式)")
        print("2. 查看角色设定信息")
        print("3. 启动双AI对话系统 (开启AI打分)")
        print("4. 启动双AI对话系统 (关闭AI打分)")
        print("5. 启动双AI对话系统 (开启人工打分)")
        print("6. 系统状态检查")
        print("7. 查看对话评分统计")
        print("8. 模型性能分析与优化")
        print("9. 生成训练数据集")
        print("10. 模型迭代优化")
        print("11. 查看使用说明")
        print("0. 退出")
        print("="*50)

        choice = input("请输入选择 (0-11): ").strip()
        if choice == '0':
            print("\n感谢使用双AI角色对话系统")
            break
        elif choice == '1':
            process_pdf_workflow()
        elif choice == '2':
            show_character_info()
        elif choice == '3':
            run_dialogue_system(enableScore=True)
        elif choice == '4':
            run_dialogue_system(enableScore=False)
        elif choice == '5':
            run_dialogue_system(enableScore=True, useManualScoring=True)
        elif choice == '6':
            show_system_status()
        elif choice == '7':
            show_scoring_statistics()
        elif choice == '8':
            # Menu slot reserved; feature not implemented yet.
            print("模型性能分析与优化功能开发中...")
        elif choice == '9':
            generate_training_dataset()
        elif choice == '10':
            run_model_optimization()
        elif choice == '11':
            show_usage_guide()
        else:
            print("❌ 无效选择,请重新输入")
def show_usage_guide():
    """Print the built-in usage guide."""
    print("\n" + "="*60)
    print("系统使用说明")
    print("="*60)
    guide = """
🚀 快速开始:
1. 首次使用建议先运行"创建演示对话场景"
2. 如有PDF世界观文档选择"处理PDF世界观文档"
3. 通过"启动双AI对话系统"开始角色对话
📁 文档格式说明:
- 世界观文档: worldview_template_coc.json (参考COC设定)
- 角色设定: character_template_*.json (包含详细人设)
🔧 系统功能:
- PDF自动转换为RAG知识库
- 基于向量相似度的上下文检索
- 持久化对话历史存储
- 角色设定一致性保持
📝 自定义角色:
1. 参考 character_template_*.json 格式
2. 保存到 knowledge_base/ 目录
3. 重启对话系统加载新角色
💾 对话数据:
- 历史对话保存在 conversation_data/ 目录
- 支持会话恢复和历史查看
- 自动记录使用的上下文信息
注意事项:
- 确保模型路径正确设置
- 首次运行需要下载向量化模型
- PDF处理需要足够内存
"""
    print(guide)
# Run the interactive controller only when executed as a script.
if __name__ == '__main__':
    main()