#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
Dual-AI character dialogue system: main controller.
Full workflow: PDF processing -> character loading -> RAG dialogue -> conversation history.
'''
import os
import sys
import shutil
from typing import List, Dict
import json
def check_dependencies():
    """Check that the required third-party libraries are installed."""
    missing_deps = []
    try:
        import PyPDF2
    except ImportError:
        missing_deps.append("PyPDF2")
    try:
        import pymupdf
        print("✓ pymupdf available")
    except ImportError:
        print("⚠ pymupdf unavailable, falling back to PyPDF2")
    try:
        import sentence_transformers
        import faiss
        print("✓ vector search available")
    except ImportError:
        print("⚠ vector search unavailable, falling back to plain-text matching")
    if missing_deps:
        print(f"✗ Missing dependencies: {', '.join(missing_deps)}")
        print("Please run: pip install PyPDF2 sentence-transformers faiss-cpu")
        return False
    return True
def setup_directories():
    """Create the project directory structure."""
    directories = [
        "./knowledge_base",
        "./characters",
        "./worldview",
        "./rag_knowledge",
        "./conversation_data"
    ]
    for dir_path in directories:
        os.makedirs(dir_path, exist_ok=True)
        print(f"✓ Directory ready: {dir_path}")
def copy_demo_files():
    """Copy the demo documents into the knowledge base directory."""
    file_mappings = [
        ("./worldview/worldview_template_coc.json", "./knowledge_base/worldview_template_coc.json"),
        ("./characters/character_template_detective.json", "./knowledge_base/character_template_detective.json"),
        ("./characters/character_template_professor.json", "./knowledge_base/character_template_professor.json")
    ]
    for source, target in file_mappings:
        if os.path.exists(source):
            shutil.copy2(source, target)
            print(f"✓ Copied: {os.path.basename(target)}")
def process_pdf_workflow():
    """PDF processing workflow."""
    print("\n" + "=" * 60)
    print("Worldview PDF processing")
    print("=" * 60)
    from pdf_to_rag_processor import PDFToRAGProcessor
    pdf_path = input("Enter the PDF file path (e.g. ./coc.pdf): ").strip()
    if not os.path.exists(pdf_path):
        print(f"✗ File not found: {pdf_path}")
        return False
    try:
        processor = PDFToRAGProcessor()
        result = processor.process_pdf_to_rag(pdf_path, "./rag_knowledge")
        print("\n✓ PDF processing finished!")
        print(f"  - Document chunks: {result['chunks_count']}")
        print(f"  - Concepts: {result['concepts_count']}")
        print(f"  - Vector index: {'enabled' if result['vector_enabled'] else 'disabled'}")
        return True
    except Exception as e:
        print(f"✗ PDF processing failed: {e}")
        return False
def show_character_info():
    """Show the configured character profiles."""
    print("\n" + "=" * 60)
    print("Character profiles")
    print("=" * 60)
    knowledge_dir = "./knowledge_base"
    if not os.path.isdir(knowledge_dir):
        print(f"✗ Directory not found: {knowledge_dir}")
        return
    character_files = [f for f in os.listdir(knowledge_dir) if f.startswith('character') and f.endswith('.json')]
    for char_file in character_files:
        try:
            with open(os.path.join(knowledge_dir, char_file), 'r', encoding='utf-8') as f:
                char_data = json.load(f)
            name = char_data.get('character_name', 'unknown')
            occupation = char_data.get('basic_info', {}).get('occupation', 'unknown')
            traits = char_data.get('personality', {}).get('core_traits', [])
            print(f"\nCharacter: {name}")
            print(f"  Occupation: {occupation}")
            print(f"  Traits: {', '.join(traits[:3])}")
        except Exception as e:
            print(f"✗ Failed to read character file: {char_file} - {e}")
def run_dialogue_system(enable_scoring: bool, use_manual_scoring: bool = False):
    """Run the dual-AI dialogue system."""
    print("\n" + "=" * 60)
    print("Starting the dual-AI character dialogue system")
    print("=" * 60)
    try:
        # Start the dual-model dialogue directly
        print("\nInitializing the dual-model dialogue system...")
        from dual_ai_dialogue_system import RAGKnowledgeBase, ConversationManager, DualAIDialogueEngine
        from npc_dialogue_generator import DualModelDialogueGenerator
        # Initialize components
        kb = RAGKnowledgeBase("./knowledge_base")
        conv_mgr = ConversationManager("./conversation_data/conversations.db")
        # Check the model paths
        base_model_path = '/mnt/g/Project02/AITrain/Qwen/Qwen3-4B'
        lora_model_path = './output/NPC_Dialogue_LoRA/final_model'
        if not os.path.exists(base_model_path):
            print(f"✗ Base model path not found: {base_model_path}")
            print("Please update the model path in main_controller.py")
            return
        if not os.path.exists(lora_model_path):
            lora_model_path = None
            print("⚠ LoRA model not found, using the base model")
        # Check the character data
        if not hasattr(kb, 'character_data') or len(kb.character_data) < 2:
            print("✗ Not enough character data to build a dual-model dialogue system")
            print("Make sure the knowledge_base directory contains at least two character files")
            return
        # Pick the first two characters
        character_names = list(kb.character_data.keys())[:2]
        char1_name = character_names[0]
        char2_name = character_names[1]
        print(f"✓ Using characters: {char1_name} vs {char2_name}")
        # Configure a model for each character
        character1_config = {
            "name": char1_name,
            "lora_path": lora_model_path,
            "character_data": kb.character_data[char1_name]
        }
        character2_config = {
            "name": char2_name,
            "lora_path": lora_model_path,
            "character_data": kb.character_data[char2_name]
        }
        # Create the dual-model dialogue generator
        print("Initializing the dual-model dialogue generator...")
        dual_generator = DualModelDialogueGenerator(
            base_model_path,
            character1_config,
            character2_config
        )
        # Create the dialogue engine (scoring configurable)
        dialogue_engine = DualAIDialogueEngine(
            kb,
            conv_mgr,
            dual_generator,
            enable_scoring=enable_scoring,
            base_model_path=base_model_path,
            use_manual_scoring=use_manual_scoring
        )
        # Create a conversation session
        characters = [char1_name, char2_name]
        worldview = kb.worldview_data.get('worldview_name', 'unknown worldview') if kb.worldview_data else 'unknown worldview'
        session_id = conv_mgr.create_session(characters, worldview)
        print(f"✓ Created session: {session_id}")
        # Interactive dialogue loop
        print("\n=== Dual-AI model dialogue ===")
        print(f"Characters: {char1_name} vs {char2_name}")
        print(f"Worldview: {worldview}")
        print("Type 'quit' to exit")
        print("-" * 50)
        while True:
            try:
                # Read user input
                user_input = input("\nEnter a dialogue topic or command: ").strip()
                if user_input.lower() == 'quit':
                    print("Exiting the dual-AI dialogue system")
                    break
                if not user_input:
                    print("Please enter a valid topic")
                    continue
                # Ask for the number of dialogue turns
                turns_input = input("Number of dialogue turns (default 4): ").strip()
                turns = int(turns_input) if turns_input.isdigit() else 4
                # Ask for history/context settings
                history_input = input("Number of history turns to use (default 2): ").strip()
                history_count = int(history_input) if history_input.isdigit() else 2
                context_input = input("Number of context items to use (default 10): ").strip()
                context_info_count = int(context_input) if context_input.isdigit() else 10
                print(f"\nStarting dialogue - topic: {user_input}")
                print(f"Turns: {turns}, history: {history_count}, context: {context_info_count}")
                print("-" * 50)
                # Run the dual-model conversation
                dialogue_engine.run_dual_model_conversation(
                    session_id, user_input, turns, history_count, context_info_count
                )
                print("-" * 50)
                print("Dialogue finished!")
            except KeyboardInterrupt:
                print("\n\nDialogue interrupted by user")
                break
            except Exception as e:
                print(f"Error during dialogue: {e}")
                import traceback
                traceback.print_exc()
    except Exception as e:
        print(f"✗ Failed to start the dialogue system: {e}")
        import traceback
        traceback.print_exc()
def generate_training_dataset():
    """Generate training datasets from scored dialogue turns."""
    print("\n" + "=" * 60)
    print("Generate training dataset")
    print("=" * 60)
    try:
        from dual_ai_dialogue_system import ConversationManager
        import sqlite3
        import json
        import os
        from datetime import datetime
        conv_mgr = ConversationManager("./conversation_data/conversations.db")
        # Create the output directory
        output_dir = "./training_data"
        os.makedirs(output_dir, exist_ok=True)
        print("Select the dataset type (default: 1):")
        print("1. High-quality dialogue dataset (score >= 8.0)")
        print("2. Problem-dialogue improvement dataset (score < 6.0)")
        print("3. Character-consistency training set")
        print("4. Creativity-enhancement training set")
        print("5. Full dialogue-quality dataset")
        choice = input("Enter your choice (1-5): ").strip()
        choice = int(choice) if choice.isdigit() else 1
        with sqlite3.connect(conv_mgr.db_path) as conn:
            training_data = []
            if choice == 1:
                # High-quality dialogue dataset
                print("\nGenerating the high-quality dialogue dataset...")
                cursor = conn.execute("""
                    SELECT speaker, content, score_details, score_feedback
                    FROM dialogue_turns
                    WHERE dialogue_score >= 8.0
                    ORDER BY dialogue_score DESC
                    LIMIT 200
                """)
                for speaker, content, score_details, feedback in cursor.fetchall():
                    training_data.append({
                        "type": "high_quality",
                        "speaker": speaker,
                        "content": content,
                        "scores": json.loads(score_details) if score_details else {},
                        "feedback": feedback,
                        "label": "positive"
                    })
                output_file = f"{output_dir}/high_quality_dialogues_{datetime.now().strftime('%Y%m%d_%H%M')}.json"
            elif choice == 2:
                # Problem-dialogue improvement dataset
                print("\nGenerating the problem-dialogue improvement dataset...")
                cursor = conn.execute("""
                    SELECT speaker, content, score_details, score_feedback
                    FROM dialogue_turns
                    WHERE dialogue_score < 6.0 AND dialogue_score > 0
                    ORDER BY dialogue_score ASC
                    LIMIT 100
                """)
                for speaker, content, score_details, feedback in cursor.fetchall():
                    # Generate improvement guidance for each low-scoring turn
                    improvement_prompt = generate_improvement_prompt(content, feedback)
                    training_data.append({
                        "type": "improvement",
                        "speaker": speaker,
                        "original_content": content,
                        "scores": json.loads(score_details) if score_details else {},
                        "problems": feedback,
                        "improvement_prompt": improvement_prompt,
                        "label": "needs_improvement"
                    })
                output_file = f"{output_dir}/improvement_dialogues_{datetime.now().strftime('%Y%m%d_%H%M')}.json"
            elif choice == 3:
                # Character-consistency training set
                print("\nGenerating the character-consistency training set...")
                cursor = conn.execute("""
                    SELECT speaker, content, score_details
                    FROM dialogue_turns
                    WHERE dialogue_score > 0
                    ORDER BY json_extract(score_details, '$.character_consistency') DESC
                    LIMIT 150
                """)
                for speaker, content, score_details in cursor.fetchall():
                    scores = json.loads(score_details) if score_details else {}
                    char_consistency = scores.get('character_consistency', 0)
                    training_data.append({
                        "type": "character_consistency",
                        "speaker": speaker,
                        "content": content,
                        "character_consistency_score": char_consistency,
                        "label": "high_consistency" if char_consistency >= 8.0 else "medium_consistency"
                    })
                output_file = f"{output_dir}/character_consistency_{datetime.now().strftime('%Y%m%d_%H%M')}.json"
            elif choice == 4:
                # Creativity-enhancement training set
                print("\nGenerating the creativity-enhancement training set...")
                cursor = conn.execute("""
                    SELECT speaker, content, score_details
                    FROM dialogue_turns
                    WHERE dialogue_score > 0
                    ORDER BY json_extract(score_details, '$.creativity') DESC
                    LIMIT 150
                """)
                for speaker, content, score_details in cursor.fetchall():
                    scores = json.loads(score_details) if score_details else {}
                    creativity = scores.get('creativity', 0)
                    training_data.append({
                        "type": "creativity",
                        "speaker": speaker,
                        "content": content,
                        "creativity_score": creativity,
                        "label": "high_creativity" if creativity >= 8.0 else "medium_creativity"
                    })
                output_file = f"{output_dir}/creativity_enhancement_{datetime.now().strftime('%Y%m%d_%H%M')}.json"
            elif choice == 5:
                # Full dialogue-quality dataset
                print("\nGenerating the full dialogue-quality dataset...")
                cursor = conn.execute("""
                    SELECT speaker, content, dialogue_score, score_details, score_feedback
                    FROM dialogue_turns
                    WHERE dialogue_score > 0
                    ORDER BY timestamp DESC
                    LIMIT 300
                """)
                for speaker, content, score, score_details, feedback in cursor.fetchall():
                    training_data.append({
                        "type": "complete_dataset",
                        "speaker": speaker,
                        "content": content,
                        "overall_score": score,
                        "dimension_scores": json.loads(score_details) if score_details else {},
                        "feedback": feedback,
                        "quality_label": get_quality_label(score)
                    })
                output_file = f"{output_dir}/complete_quality_dataset_{datetime.now().strftime('%Y%m%d_%H%M')}.json"
            else:
                print("✗ Invalid choice")
                return
            if training_data:
                # Save the training data
                with open(output_file, 'w', encoding='utf-8') as f:
                    json.dump({
                        "metadata": {
                            "created_at": datetime.now().isoformat(),
                            "total_samples": len(training_data),
                            "data_type": choice,
                            "source": "dialogue_scoring_system"
                        },
                        "data": training_data
                    }, f, ensure_ascii=False, indent=2)
                print("\n✓ Training dataset generated!")
                print(f"  - File: {output_file}")
                print(f"  - Samples: {len(training_data)}")
                print(f"  - Type: {get_dataset_description(choice)}")
                # Print dataset statistics
                generate_dataset_statistics(training_data, choice)
            else:
                print("✗ No data matched the criteria")
    except Exception as e:
        print(f"✗ Training dataset generation failed: {e}")
        import traceback
        traceback.print_exc()
def generate_improvement_prompt(content, feedback):
    """Build an improvement prompt for a low-scoring dialogue turn."""
    return f"""Original dialogue: {content}
Problem analysis: {feedback}
Improvement requirements:
1. Keep the character's traits and profile
2. Strengthen the logic and coherence of the dialogue
3. Make the wording more natural
4. Increase information density and avoid redundancy
5. Add creative elements where appropriate
Please produce an improved version of the dialogue."""
def get_quality_label(score):
    """Map a score to a quality label."""
    if score >= 9.0:
        return "excellent"
    elif score >= 8.0:
        return "good"
    elif score >= 7.0:
        return "average"
    elif score >= 6.0:
        return "below_average"
    else:
        return "poor"
def get_dataset_description(choice):
    """Describe a dataset type (keys match the integer menu choice)."""
    descriptions = {
        1: "high-quality dialogue dataset",
        2: "problem-dialogue improvement dataset",
        3: "character-consistency training set",
        4: "creativity-enhancement training set",
        5: "full dialogue-quality dataset"
    }
    return descriptions.get(choice, "unknown type")
def generate_dataset_statistics(training_data, data_type):
    """Print statistics for a generated dataset (data_type is the integer menu choice)."""
    print("\nDataset statistics:")
    if data_type == 1:  # high-quality dataset
        speakers = {}
        for item in training_data:
            speaker = item['speaker']
            speakers[speaker] = speakers.get(speaker, 0) + 1
        print("  - Speaker distribution:")
        for speaker, count in speakers.items():
            print(f"    {speaker}: {count} turns")
    elif data_type == 5:  # full dataset
        quality_dist = {}
        for item in training_data:
            label = item['quality_label']
            quality_dist[label] = quality_dist.get(label, 0) + 1
        print("  - Quality distribution:")
        for label, count in quality_dist.items():
            print(f"    {label}: {count} turns")
def run_model_optimization():
    """Run iterative model optimization."""
    print("\n" + "=" * 60)
    print("Iterative model optimization")
    print("=" * 60)
    try:
        from dual_ai_dialogue_system import ConversationManager
        import sqlite3
        import json
        import os
        from datetime import datetime
        conv_mgr = ConversationManager("./conversation_data/conversations.db")
        # print("Optimization options:")
        # print("1. Generate the LoRA training script")
        # print("2. Run incremental training")
        # choice = input("Enter your choice (1-2): ").strip()
        # if choice == '1':
        # Generate the LoRA training script
        print("\n=== Generating the LoRA training script ===")
        script_content = generate_lora_training_script()
        script_path = "./scripts/iterative_lora_training.py"
        os.makedirs("./scripts", exist_ok=True)
        with open(script_path, 'w', encoding='utf-8') as f:
            f.write(script_content)
        print(f"✓ LoRA training script generated: {script_path}")
        print("Usage:")
        print("  1. Generate training data first (menu option 8)")
        print("  2. Adjust the paths configured in the script")
        print(f"  3. Run: python {script_path}")
        # elif choice == '2':
        #     # Run incremental training
        #     print("\n=== Incremental training ===")
        #     # Check for training data
        #     training_dir = "./training_data"
        #     if not os.path.exists(training_dir):
        #         print("✗ Training data directory not found; generate training data first (menu option 8)")
        #         return
        #     training_files = [f for f in os.listdir(training_dir) if f.endswith('.json')]
        #     if not training_files:
        #         print("✗ No training data files found; generate training data first (menu option 8)")
        #         return
        #     print("Training data files found:")
        #     for i, file in enumerate(training_files, 1):
        #         print(f"  {i}. {file}")
        #     file_idx = input(f"Select a training data file (1-{len(training_files)}): ").strip()
        #     try:
        #         selected_file = training_files[int(file_idx) - 1]
        #         training_file_path = os.path.join(training_dir, selected_file)
        #         print(f"Using training file: {selected_file}")
        #         print("⚠ Note: actual training requires correct model paths and compute resources")
        #         # Build the training command
        #         training_command = generate_training_command(training_file_path)
        #         print("Suggested training command:")
        #         print(f"  {training_command}")
        #         # Optional: run the training (after user confirmation)
        #         confirm = input("Run the training now? (y/N): ").strip().lower()
        #         if confirm == 'y':
        #             print("Starting incremental training...")
        #             # Add the actual training launch logic here
        #             print("⚠ Training must be configured for your environment")
        #     except (ValueError, IndexError):
        #         print("✗ Invalid file selection")
        # else:
        #     print("✗ Invalid choice")
    except Exception as e:
        print(f"✗ Model optimization failed: {e}")
        import traceback
        traceback.print_exc()
def generate_lora_training_script():
    """Generate the LoRA training script."""
    return '''#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
LoRA incremental training script driven by scoring data.
Auto-generated - adjust the configuration for your environment.
"""
import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model, TaskType
import os
class IterativeLoRATrainer:
    def __init__(self, base_model_path, training_data_path, output_path):
        self.base_model_path = base_model_path
        self.training_data_path = training_data_path
        self.output_path = output_path
        # LoRA configuration
        self.lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=16,  # LoRA rank
            lora_alpha=32,
            lora_dropout=0.1,
            target_modules=["q_proj", "v_proj", "k_proj", "o_proj"]
        )
    def load_training_data(self):
        """Load the training data."""
        with open(self.training_data_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        return data['data']
    def prepare_training_samples(self, data):
        """Prepare training samples."""
        samples = []
        for item in data:
            if item.get('label') == 'positive' or item.get('overall_score', 0) >= 8.0:
                # High-quality sample
                sample = {
                    'input': f"Character: {item['speaker']}\\nGenerate a high-quality dialogue line:",
                    'output': item['content'],
                    'quality_score': item.get('overall_score', 8.0)
                }
                samples.append(sample)
        return samples
    def train(self):
        """Run the training."""
        print("Starting LoRA incremental training...")
        # Load the model and tokenizer
        tokenizer = AutoTokenizer.from_pretrained(self.base_model_path)
        model = AutoModelForCausalLM.from_pretrained(
            self.base_model_path,
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )
        # Apply LoRA
        model = get_peft_model(model, self.lora_config)
        # Load the training data
        training_data = self.load_training_data()
        training_samples = self.prepare_training_samples(training_data)
        print(f"Training samples: {len(training_samples)}")
        # Run the actual training loop here - transformers' Trainer (or a custom
        # loop) both work. A minimal Trainer-based sketch follows (assumption:
        # the `datasets` library is installed; tune lengths/epochs to your hardware):
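        from datasets import Dataset
        from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        def tokenize(sample):
            # Concatenate prompt and completion into one causal-LM sequence
            text = sample['input'] + '\\n' + sample['output']
            return tokenizer(text, truncation=True, max_length=512)
        dataset = Dataset.from_list(training_samples).map(
            tokenize, remove_columns=['input', 'output', 'quality_score'])
        trainer = Trainer(
            model=model,
            args=TrainingArguments(
                output_dir=self.output_path,
                num_train_epochs=1,
                per_device_train_batch_size=1,
                gradient_accumulation_steps=8,
                learning_rate=1e-4,
                logging_steps=10,
                bf16=True
            ),
            train_dataset=dataset,
            # mlm=False: labels are input_ids shifted for causal-LM training
            data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
        )
        trainer.train()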
        # Save the model
        model.save_pretrained(self.output_path)
        tokenizer.save_pretrained(self.output_path)
        print(f"Training finished, model saved to: {self.output_path}")
if __name__ == '__main__':
    # Configuration - adjust for your setup
    BASE_MODEL_PATH = '/mnt/e/AI/Project02/AITrain/Qwen/Qwen3-4B'
    TRAINING_DATA_PATH = './training_data/high_quality_dialogues_latest.json'
    OUTPUT_PATH = './output/iterative_lora_v2'
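    # A small argparse shim (an addition so the --data/--output flags used by the
    # suggested command in main_controller.py actually work; defaults stay as above):
    import argparse
    parser = argparse.ArgumentParser(description='Iterative LoRA training')
    parser.add_argument('--data', default=TRAINING_DATA_PATH)
    parser.add_argument('--output', default=OUTPUT_PATH)
    args = parser.parse_args()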
    trainer = IterativeLoRATrainer(BASE_MODEL_PATH, args.data, args.output)
    trainer.train()
'''
def generate_training_command(training_file_path):
    """Build the suggested training command."""
    from datetime import datetime  # imported here because the module does not import it globally
    return f"python ./scripts/iterative_lora_training.py --data {training_file_path} --output ./output/optimized_model_v{datetime.now().strftime('%Y%m%d')}"
def show_scoring_statistics():
    """Show dialogue scoring statistics."""
    print("\n" + "=" * 60)
    print("Dialogue scoring statistics")
    print("=" * 60)
    try:
        from dual_ai_dialogue_system import ConversationManager
        import json
        conv_mgr = ConversationManager("./conversation_data/conversations.db")
        # Query the scoring data
        import sqlite3
        with sqlite3.connect(conv_mgr.db_path) as conn:
            # Overall statistics
            cursor = conn.execute("""
                SELECT
                    COUNT(*) as total_turns,
                    AVG(dialogue_score) as avg_score,
                    MAX(dialogue_score) as max_score,
                    MIN(dialogue_score) as min_score
                FROM dialogue_turns
                WHERE dialogue_score > 0
            """)
            stats = cursor.fetchone()
            if stats and stats[0] > 0:
                total_turns, avg_score, max_score, min_score = stats
                print("Overall:")
                print(f"  - Scored dialogue turns: {total_turns}")
                print(f"  - Average score: {avg_score:.2f}")
                print(f"  - Highest score: {max_score:.2f}")
                print(f"  - Lowest score: {min_score:.2f}")
            else:
                print("No scoring data yet")
                return
            # Per-speaker statistics
            print("\nBy speaker:")
            cursor = conn.execute("""
                SELECT
                    speaker,
                    COUNT(*) as turns,
                    AVG(dialogue_score) as avg_score,
                    MAX(dialogue_score) as max_score
                FROM dialogue_turns
                WHERE dialogue_score > 0
                GROUP BY speaker
                ORDER BY avg_score DESC
            """)
            for row in cursor.fetchall():
                speaker, turns, avg_score, max_score = row
                print(f"  - {speaker}: average {avg_score:.2f} (best {max_score:.2f}, {turns} turns)")
            # Recent high-scoring turns
            print("\nRecent high-scoring turns (score >= 8.0):")
            cursor = conn.execute("""
                SELECT speaker, content, dialogue_score, score_feedback, timestamp
                FROM dialogue_turns
                WHERE dialogue_score >= 8.0
                ORDER BY timestamp DESC
                LIMIT 5
            """)
            high_score_turns = cursor.fetchall()
            if high_score_turns:
                for speaker, content, score, feedback, timestamp in high_score_turns:
                    print(f"  [{timestamp[:16]}] {speaker} ({score:.2f})")
                    print(f"    Content: {content[:80]}...")
                    if feedback:
                        print(f"    Feedback: {feedback[:60]}...")
                    print()
            else:
                print("  No high-scoring turns yet")
            # Score distribution
            print("\nScore distribution:")
            cursor = conn.execute("""
                SELECT
                    CASE
                        WHEN dialogue_score >= 9.0 THEN 'excellent (9-10)'
                        WHEN dialogue_score >= 8.0 THEN 'good (8-9)'
                        WHEN dialogue_score >= 7.0 THEN 'average (7-8)'
                        WHEN dialogue_score >= 6.0 THEN 'passing (6-7)'
                        ELSE 'needs improvement (<6)'
                    END as score_range,
                    COUNT(*) as count
                FROM dialogue_turns
                WHERE dialogue_score > 0
                GROUP BY score_range
                ORDER BY MIN(dialogue_score) DESC
            """)
            for score_range, count in cursor.fetchall():
                percentage = (count / total_turns) * 100
                print(f"  - {score_range}: {count} turns ({percentage:.1f}%)")
    except Exception as e:
        print(f"✗ Scoring statistics query failed: {e}")
def show_system_status():
    """Show the system status."""
    print("\n" + "=" * 60)
    print("System status check")
    print("=" * 60)
    # Check files
    files_to_check = [
        ("./knowledge_base/worldview_template_coc.json", "Worldview template"),
        ("./knowledge_base/character_template_detective.json", "Detective character"),
        ("./knowledge_base/character_template_professor.json", "Professor character"),
        ("./pdf_to_rag_processor.py", "PDF processor"),
        ("./dual_ai_dialogue_system.py", "Dialogue system"),
        ("./npc_dialogue_generator.py", "NPC generator")
    ]
    print("\nFiles:")
    for file_path, description in files_to_check:
        if os.path.exists(file_path):
            print(f"✓ {description}: {file_path}")
        else:
            print(f"✗ {description}: {file_path} (missing)")
    # Check directories
    print("\nDirectories:")
    directories = ["./knowledge_base", "./rag_knowledge", "./conversation_data"]
    for dir_path in directories:
        if os.path.exists(dir_path):
            file_count = len([f for f in os.listdir(dir_path) if os.path.isfile(os.path.join(dir_path, f))])
            print(f"✓ {dir_path}: {file_count} files")
        else:
            print(f"✗ {dir_path}: missing")
    # Check conversation sessions
    try:
        from dual_ai_dialogue_system import ConversationManager
        conv_mgr = ConversationManager("./conversation_data/conversations.db")
        sessions = conv_mgr.list_sessions()
        print(f"\n✓ Conversation sessions: {len(sessions)}")
    except Exception as e:
        print(f"\n✗ Conversation session check failed: {e}")
def main():
    """Main control program."""
    print("=" * 70)
    print("  Dual-AI character dialogue system - main controller")
    print("  RAG-based worldview-augmented dialogue engine")
    print("=" * 70)
    # Check dependencies
    if not check_dependencies():
        return
    # Set up directories
    # setup_directories()
    # copy_demo_files()
    while True:
        print("\n" + "=" * 50)
        print("Main menu - choose an action:")
        print("1. Process a worldview PDF (convert to RAG format)")
        print("2. Show character profiles")
        print("3. Start the dual-AI dialogue system (AI scoring on)")
        print("4. Start the dual-AI dialogue system (AI scoring off)")
        print("5. Start the dual-AI dialogue system (manual scoring)")
        print("6. System status check")
        print("7. Show dialogue scoring statistics")
        print("8. Generate a training dataset")
        print("9. Iterative model optimization")
        print("10. Usage guide")
        print("0. Exit")
        print("=" * 50)
        choice = input("Enter your choice (0-10): ").strip()
        if choice == '0':
            print("\nThanks for using the dual-AI character dialogue system")
            break
        elif choice == '1':
            process_pdf_workflow()
        elif choice == '2':
            show_character_info()
        elif choice == '3':
            run_dialogue_system(enable_scoring=True)
        elif choice == '4':
            run_dialogue_system(enable_scoring=False)
        elif choice == '5':
            run_dialogue_system(enable_scoring=True, use_manual_scoring=True)
        elif choice == '6':
            show_system_status()
        elif choice == '7':
            show_scoring_statistics()
        elif choice == '8':
            generate_training_dataset()
        elif choice == '9':
            run_model_optimization()
        elif choice == '10':
            show_usage_guide()
        else:
            print("✗ Invalid choice, please try again")
def show_usage_guide():
    """Show the usage guide."""
    print("\n" + "=" * 60)
    print("Usage guide")
    print("=" * 60)
    guide = """
🚀 Quick start:
1. On first use, run "System status check" to verify files and model paths
2. If you have a worldview PDF, choose "Process a worldview PDF"
3. Start character conversations via "Start the dual-AI dialogue system"
📁 Document formats:
- Worldview document: worldview_template_coc.json (based on the CoC setting)
- Character profiles: character_template_*.json (detailed persona data)
🔧 System features:
- Automatic PDF conversion into a RAG knowledge base
- Context retrieval based on vector similarity
- Persistent conversation history storage
- Character profile consistency
📝 Custom characters (a minimal skeleton follows):
1. Follow the character_template_*.json format
2. Save the file to the knowledge_base/ directory
3. Restart the dialogue system to load the new character
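   Example skeleton (illustrative only - it covers just the fields this
   controller reads; real templates may contain more sections):
   {
     "character_name": "Example Detective",
     "basic_info": { "occupation": "Private investigator" },
     "personality": { "core_traits": ["observant", "skeptical", "dry-witted"] }
   }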
💾 Conversation data:
- History is stored in the conversation_data/ directory
- Session restore and history browsing are supported
- Context used in each turn is recorded automatically
⚠️ Notes:
- Make sure the model paths are set correctly
- The first run downloads the embedding model
- PDF processing needs sufficient memory
"""
    print(guide)
if __name__ == '__main__':
    main()