Project02/AITrain/dual_ai_dialogue_system.py

844 lines
37 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
RAG增强的角色对话系统
集成世界观知识库,支持角色设定加载和对话生成
'''
import json
import os
import sqlite3
from datetime import datetime
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass, asdict
import hashlib
# 尝试导入向量化相关库
try:
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
EMBEDDING_AVAILABLE = True
except ImportError:
EMBEDDING_AVAILABLE = False
@dataclass
class DialogueTurn:
"""对话轮次数据结构"""
speaker: str
content: str
timestamp: str
context_used: List[str] # 使用的上下文信息
relevance_score: float = 0.0
@dataclass
class ConversationSession:
"""对话会话数据结构"""
session_id: str
characters: List[str]
worldview: str
start_time: str
last_update: str
dialogue_history: List[DialogueTurn]
class RAGKnowledgeBase:
"""RAG知识库管理器"""
def __init__(self, knowledge_dir: str):
self.knowledge_dir = knowledge_dir
self.worldview_data = None
self.character_data = {}
self.chunks = []
self.embedding_model = None
self.index = None
# 初始化向量模型
if EMBEDDING_AVAILABLE:
try:
self.embedding_model = SentenceTransformer('./sentence-transformers/all-MiniLM-L6-v2')
print("✓ 向量模型加载成功")
except Exception as e:
print(f"✗ 向量模型加载失败: {e}")
self._load_knowledge_base()
def _load_knowledge_base(self):
"""加载知识库"""
# 优先加载RAG知识库作为世界观
rag_worldview_path = "./rag_knowledge/knowledge_base.json"
if os.path.exists(rag_worldview_path):
try:
with open(rag_worldview_path, 'r', encoding='utf-8') as f:
rag_data = json.load(f)
# 从RAG数据中提取世界观信息
self.worldview_data = {
"worldview_name": "克苏鲁神话世界观 (RAG)",
"source": rag_data.get("metadata", {}).get("source_file", "未知"),
"description": f"基于{rag_data.get('metadata', {}).get('source_file', 'PDF文档')}的RAG知识库",
"total_chunks": rag_data.get("metadata", {}).get("total_chunks", 0),
"total_concepts": rag_data.get("metadata", {}).get("total_concepts", 0),
"rag_enabled": True
}
# 保存RAG数据用于检索
self.rag_chunks = rag_data.get("chunks", [])
print(f"✓ RAG世界观加载成功: {self.worldview_data['worldview_name']}")
print(f" - 文档块数: {self.worldview_data['total_chunks']}")
print(f" - 概念数: {self.worldview_data['total_concepts']}")
except Exception as e:
print(f"✗ RAG世界观加载失败: {e}")
self.rag_chunks = []
# 如果没有RAG知识库则加载传统世界观文件
if not hasattr(self, 'rag_chunks') or not self.rag_chunks:
worldview_files = [f for f in os.listdir(self.knowledge_dir)
if f.startswith('worldview') and f.endswith('.json')]
if worldview_files:
worldview_path = os.path.join(self.knowledge_dir, worldview_files[0])
with open(worldview_path, 'r', encoding='utf-8') as f:
self.worldview_data = json.load(f)
print(f"✓ 传统世界观加载成功: {self.worldview_data.get('worldview_name', '未知')}")
# 加载角色数据
character_files = [f for f in os.listdir(self.knowledge_dir)
if f.startswith('character') and f.endswith('.json')]
for char_file in character_files:
char_path = os.path.join(self.knowledge_dir, char_file)
with open(char_path, 'r', encoding='utf-8') as f:
char_data = json.load(f)
char_name = char_data.get('character_name', char_file)
self.character_data[char_name] = char_data
print(f"✓ 角色加载成功: {list(self.character_data.keys())}")
# 构建检索用的文本块
self._build_searchable_chunks()
# 构建向量索引
if EMBEDDING_AVAILABLE and self.embedding_model:
self._build_vector_index()
def _build_searchable_chunks(self):
"""构建可检索的文本块"""
self.chunks = []
# 优先使用RAG知识库的文本块
if hasattr(self, 'rag_chunks') and self.rag_chunks:
for rag_chunk in self.rag_chunks:
self.chunks.append({
"type": "worldview_rag",
"section": "rag_knowledge",
"subsection": rag_chunk.get("type", "unknown"),
"content": rag_chunk.get("content", ""),
"metadata": {
"source": "rag_worldview",
"chunk_id": rag_chunk.get("id", ""),
"size": rag_chunk.get("size", 0),
"hash": rag_chunk.get("hash", "")
}
})
print(f"✓ 使用RAG知识库文本块: {len(self.rag_chunks)}")
else:
# 传统世界观相关文本块
if self.worldview_data:
for section_key, section_data in self.worldview_data.items():
if isinstance(section_data, dict):
for sub_key, sub_data in section_data.items():
if isinstance(sub_data, (str, list)):
content = str(sub_data)
if len(content) > 50: # 只保留有意义的文本
self.chunks.append({
"type": "worldview",
"section": section_key,
"subsection": sub_key,
"content": content,
"metadata": {"source": "worldview"}
})
# 角色相关文本块
for char_name, char_data in self.character_data.items():
for section_key, section_data in char_data.items():
if isinstance(section_data, dict):
for sub_key, sub_data in section_data.items():
if isinstance(sub_data, (str, list)):
content = str(sub_data)
if len(content) > 30:
self.chunks.append({
"type": "character",
"character": char_name,
"section": section_key,
"subsection": sub_key,
"content": content,
"metadata": {"source": char_name}
})
print(f"✓ 构建文本块: {len(self.chunks)}")
def _build_vector_index(self):
"""构建向量索引"""
try:
# 优先使用RAG知识库的预构建向量索引
rag_vector_path = "./rag_knowledge/vector_index.faiss"
rag_embeddings_path = "./rag_knowledge/embeddings.npy"
if os.path.exists(rag_vector_path) and os.path.exists(rag_embeddings_path):
# 加载预构建的向量索引
self.index = faiss.read_index(rag_vector_path)
self.rag_embeddings = np.load(rag_embeddings_path)
print(f"✓ 使用RAG预构建向量索引: {self.index.ntotal}个向量")
return
# 如果没有预构建的向量索引,则重新构建
texts = [chunk["content"] for chunk in self.chunks]
embeddings = self.embedding_model.encode(texts)
dimension = embeddings.shape[1]
self.index = faiss.IndexFlatL2(dimension)
self.index.add(embeddings.astype(np.float32))
print(f"✓ 向量索引构建成功: {dimension}维, {len(texts)}个向量")
except Exception as e:
print(f"✗ 向量索引构建失败: {e}")
def search_relevant_context(self, query: str, character_name: str = None, top_k: int = 3) -> List[Dict]:
"""搜索相关上下文"""
relevant_chunks = []
# 向量搜索
if EMBEDDING_AVAILABLE and self.embedding_model and self.index:
try:
# 如果使用RAG预构建向量索引直接搜索
if hasattr(self, 'rag_embeddings'):
query_vector = self.embedding_model.encode([query])
distances, indices = self.index.search(query_vector.astype(np.float32), top_k * 2)
for distance, idx in zip(distances[0], indices[0]):
if idx < len(self.chunks):
chunk = self.chunks[idx].copy()
chunk["relevance_score"] = float(1 / (1 + distance))
relevant_chunks.append(chunk)
else:
# 传统向量搜索
query_vector = self.embedding_model.encode([query])
distances, indices = self.index.search(query_vector.astype(np.float32), top_k * 2)
for distance, idx in zip(distances[0], indices[0]):
if idx < len(self.chunks):
chunk = self.chunks[idx].copy()
chunk["relevance_score"] = float(1 / (1 + distance))
relevant_chunks.append(chunk)
except Exception as e:
print(f"向量搜索失败: {e}")
# 文本搜索作为备选
if not relevant_chunks:
query_lower = query.lower()
for chunk in self.chunks:
content_lower = chunk["content"].lower()
score = 0
for word in query_lower.split():
if word in content_lower:
score += content_lower.count(word)
if score > 0:
chunk_copy = chunk.copy()
chunk_copy["relevance_score"] = score
relevant_chunks.append(chunk_copy)
# 按相关性排序
relevant_chunks.sort(key=lambda x: x["relevance_score"], reverse=True)
# 优先返回特定角色的相关信息
if character_name:
char_chunks = [c for c in relevant_chunks if c.get("character") == character_name]
other_chunks = [c for c in relevant_chunks if c.get("character") != character_name]
relevant_chunks = char_chunks + other_chunks
return relevant_chunks[:top_k]
class ConversationManager:
"""对话管理器"""
def __init__(self, db_path: str = "conversation_history.db"):
self.db_path = db_path
self._init_database()
def _init_database(self):
"""初始化对话历史数据库"""
with sqlite3.connect(self.db_path) as conn:
conn.execute('''
CREATE TABLE IF NOT EXISTS conversations (
session_id TEXT PRIMARY KEY,
characters TEXT,
worldview TEXT,
start_time TEXT,
last_update TEXT,
metadata TEXT
)
''')
conn.execute('''
CREATE TABLE IF NOT EXISTS dialogue_turns (
id INTEGER PRIMARY KEY AUTOINCREMENT,
session_id TEXT,
turn_number INTEGER,
speaker TEXT,
content TEXT,
timestamp TEXT,
context_used TEXT,
relevance_score REAL,
FOREIGN KEY (session_id) REFERENCES conversations (session_id)
)
''')
conn.commit()
def create_session(self, characters: List[str], worldview: str) -> str:
"""创建新的对话会话"""
session_id = hashlib.md5(f"{'-'.join(characters)}-{datetime.now().isoformat()}".encode()).hexdigest()[:12]
with sqlite3.connect(self.db_path) as conn:
conn.execute(
"INSERT INTO conversations (session_id, characters, worldview, start_time, last_update) VALUES (?, ?, ?, ?, ?)",
(session_id, json.dumps(characters), worldview, datetime.now().isoformat(), datetime.now().isoformat())
)
conn.commit()
print(f"✓ 创建对话会话: {session_id}")
return session_id
def add_dialogue_turn(self, session_id: str, speaker: str, content: str, context_used: List[str] = None, relevance_score: float = 0.0):
"""添加对话轮次"""
if context_used is None:
context_used = []
with sqlite3.connect(self.db_path) as conn:
# 获取当前轮次数
cursor = conn.execute("SELECT COUNT(*) FROM dialogue_turns WHERE session_id = ?", (session_id,))
turn_number = cursor.fetchone()[0] + 1
# 插入对话轮次
conn.execute(
"""INSERT INTO dialogue_turns
(session_id, turn_number, speaker, content, timestamp, context_used, relevance_score)
VALUES (?, ?, ?, ?, ?, ?, ?)""",
(session_id, turn_number, speaker, content, datetime.now().isoformat(),
json.dumps(context_used), relevance_score)
)
# 更新会话最后更新时间
conn.execute(
"UPDATE conversations SET last_update = ? WHERE session_id = ?",
(datetime.now().isoformat(), session_id)
)
conn.commit()
def get_conversation_history(self, session_id: str, last_n: int = 10) -> List[DialogueTurn]:
"""获取对话历史"""
with sqlite3.connect(self.db_path) as conn:
cursor = conn.execute(
"""SELECT speaker, content, timestamp, context_used, relevance_score
FROM dialogue_turns
WHERE session_id = ?
ORDER BY turn_number DESC LIMIT ?""",
(session_id, last_n)
)
turns = []
for row in cursor.fetchall():
speaker, content, timestamp, context_used, relevance_score = row
turns.append(DialogueTurn(
speaker=speaker,
content=content,
timestamp=timestamp,
context_used=json.loads(context_used or "[]"),
relevance_score=relevance_score
))
return list(reversed(turns)) # 按时间正序返回
def list_sessions(self) -> List[Dict]:
"""列出所有对话会话"""
with sqlite3.connect(self.db_path) as conn:
cursor = conn.execute(
"SELECT session_id, characters, worldview, start_time, last_update FROM conversations ORDER BY last_update DESC"
)
sessions = []
for row in cursor.fetchall():
session_id, characters, worldview, start_time, last_update = row
sessions.append({
"session_id": session_id,
"characters": json.loads(characters),
"worldview": worldview,
"start_time": start_time,
"last_update": last_update
})
return sessions
class DualAIDialogueEngine:
"""双AI对话引擎"""
def __init__(self, knowledge_base: RAGKnowledgeBase, conversation_manager: ConversationManager, llm_generator):
self.kb = knowledge_base
self.conv_mgr = conversation_manager
self.llm_generator = llm_generator
def generate_character_prompt(self, character_name: str, context_info: List[Dict], dialogue_history: List[DialogueTurn],
history_context_count: int = 3, context_info_count: int = 2) -> str:
"""为角色生成对话提示
Args:
character_name: 角色名称
context_info: 相关上下文信息
dialogue_history: 对话历史
history_context_count: 使用的历史对话轮数默认3轮
context_info_count: 使用的上下文信息数量默认2个
"""
char_data = self.kb.character_data.get(character_name, {})
# 基础角色设定
prompt_parts = []
prompt_parts.append(f"你是{character_name},具有以下设定:")
if char_data.get('personality', {}).get('core_traits'):
traits = ", ".join(char_data['personality']['core_traits'])
prompt_parts.append(f"性格特点:{traits}")
if char_data.get('speech_patterns', {}).get('sample_phrases'):
phrases = char_data['speech_patterns']['sample_phrases'][:3]
prompt_parts.append(f"说话风格示例:{'; '.join(phrases)}")
# 当前情境
if char_data.get('current_situation'):
situation = char_data['current_situation']
prompt_parts.append(f"当前状态:{situation.get('current_mood', '')}")
# 相关世界观信息(可控制数量)
if context_info:
prompt_parts.append("相关背景信息:")
for info in context_info[:context_info_count]:
content = info['content'][:200] + "..." if len(info['content']) > 200 else info['content']
prompt_parts.append(f"- {content}")
# 对话历史(可控制数量)
if dialogue_history:
prompt_parts.append("最近的对话:")
# 使用参数控制历史对话轮数
history_to_use = dialogue_history[-history_context_count:] if history_context_count > 0 else []
for turn in history_to_use:
prompt_parts.append(f"{turn.speaker}: {turn.content}")
prompt_parts.append("\n请根据角色设定和上下文生成符合角色特点的自然对话。回复应该在50-150字之间。")
return "\n".join(prompt_parts)
def generate_dialogue(self, session_id: str, current_speaker: str, topic_hint: str = "",
history_context_count: int = 3, context_info_count: int = 2) -> Tuple[str, List[str]]:
"""生成角色对话
Args:
session_id: 会话ID
current_speaker: 当前说话者
topic_hint: 话题提示
history_context_count: 使用的历史对话轮数默认3轮
context_info_count: 使用的上下文信息数量默认2个
"""
# 获取对话历史
dialogue_history = self.conv_mgr.get_conversation_history(session_id)
# 构建搜索查询
if dialogue_history:
# 基于最近的对话内容(可控制数量)
recent_turns = dialogue_history[-history_context_count:] if history_context_count > 0 else []
recent_content = " ".join([turn.content for turn in recent_turns])
search_query = recent_content + " " + topic_hint
else:
# 首次对话
search_query = f"{current_speaker} {topic_hint} introduction greeting"
# 搜索相关上下文
context_info = self.kb.search_relevant_context(search_query, current_speaker, context_info_count)
# 生成提示(使用参数控制上下文数量)
prompt = self.generate_character_prompt(
current_speaker,
context_info,
dialogue_history,
history_context_count,
context_info_count
)
# 生成对话 - 使用双模型系统
try:
# 检查是否为双模型对话系统
if hasattr(self.llm_generator, 'generate_dual_character_dialogue'):
# 使用双模型系统
response = self.llm_generator.generate_dual_character_dialogue(
current_speaker,
prompt,
topic_hint or "请继续对话",
temperature=0.8,
max_new_tokens=150
)
else:
# 兼容旧的单模型系统
response = self.llm_generator.generate_character_dialogue(
current_speaker,
prompt,
topic_hint or "请继续对话",
temperature=0.8,
max_new_tokens=150
)
# 记录使用的上下文
context_used = [f"{info['section']}.{info['subsection']}" for info in context_info[:context_info_count]]
avg_relevance = sum(info['relevance_score'] for info in context_info[:context_info_count]) / len(context_info[:context_info_count]) if context_info else 0.0
# 保存对话轮次
self.conv_mgr.add_dialogue_turn(
session_id, current_speaker, response, context_used, avg_relevance
)
return response, context_used
except Exception as e:
print(f"✗ 对话生成失败: {e}")
return f"[{current_speaker}暂时无法回应]", []
def run_conversation_turn(self, session_id: str, characters: List[str], turns_count: int = 1, topic: str = "",
history_context_count: int = 3, context_info_count: int = 2):
"""运行对话轮次
Args:
session_id: 会话ID
characters: 角色列表
turns_count: 对话轮数
topic: 对话主题
history_context_count: 使用的历史对话轮数默认3轮
context_info_count: 使用的上下文信息数量默认2个
"""
results = []
print(f" [上下文设置: 历史{history_context_count}轮, 信息{context_info_count}个]")
for i in range(turns_count):
for char in characters:
response, context_used = self.generate_dialogue(
session_id,
char,
topic,
history_context_count,
context_info_count
)
results.append({
"speaker": char,
"content": response,
"context_used": context_used,
"turn": i + 1,
"context_settings": {
"history_count": history_context_count,
"context_info_count": context_info_count
}
})
print(f"{char}: {response}")
# if context_used:
# print(f" [使用上下文: {', '.join(context_used)}]")
print()
return results
def run_dual_model_conversation(self, session_id: str, topic: str = "", turns: int = 4,
history_context_count: int = 3, context_info_count: int = 2):
"""使用双模型系统运行对话
Args:
session_id: 会话ID
topic: 对话主题
turns: 对话轮数
history_context_count: 使用的历史对话轮数
context_info_count: 使用的上下文信息数量
"""
# 检查是否为双模型对话系统
if not hasattr(self.llm_generator, 'run_dual_character_conversation'):
print("⚠ 当前系统不支持双模型对话")
return self.run_conversation_turn(session_id, self.llm_generator.list_characters(), turns, topic,
history_context_count, context_info_count)
# 获取对话历史
dialogue_history = self.conv_mgr.get_conversation_history(session_id)
# 构建上下文信息
if dialogue_history:
recent_turns = dialogue_history[-history_context_count:] if history_context_count > 0 else []
recent_content = " ".join([turn.content for turn in recent_turns])
search_query = recent_content + " " + topic
else:
search_query = f"{topic} introduction greeting"
# 搜索相关上下文
context_info = self.kb.search_relevant_context(search_query, top_k=context_info_count)
# 构建上下文字符串
context_str = ""
if context_info:
context_str = "相关背景信息:"
for info in context_info[:context_info_count]:
content = info['content'][:150] + "..." if len(info['content']) > 150 else info['content']
context_str += f"\n- {content}"
print(f"\n=== 双模型对话系统 ===")
print(f"主题: {topic}")
print(f"角色: {', '.join(self.llm_generator.list_characters())}")
print(f"轮数: {turns}")
print(f"上下文设置: 历史{history_context_count}轮, 信息{context_info_count}")
# 使用双模型系统生成对话
conversation_results = self.llm_generator.run_dual_character_conversation(
topic=topic,
turns=turns,
context=context_str,
temperature=0.8,
max_new_tokens=150
)
# 保存对话到数据库
for result in conversation_results:
self.conv_mgr.add_dialogue_turn(
session_id,
result['speaker'],
result['dialogue'],
[result.get('context_used', '')],
0.8 # 默认相关性分数
)
return conversation_results
def main():
"""主函数 - 演示系统使用"""
print("=== RAG增强双AI角色对话系统 ===")
# 设置路径
knowledge_dir = "./knowledge_base" # 包含世界观和角色文档的目录
# 检查必要文件
required_dirs = [knowledge_dir]
for dir_path in required_dirs:
if not os.path.exists(dir_path):
print(f"✗ 目录不存在: {dir_path}")
print("请确保以下文件存在:")
print("- ./knowledge_base/worldview_template_coc.json")
print("- ./knowledge_base/character_template_detective.json")
print("- ./knowledge_base/character_template_professor.json")
return
try:
# 初始化系统组件
print("\n初始化系统...")
kb = RAGKnowledgeBase(knowledge_dir)
conv_mgr = ConversationManager()
# 这里需要你的LLM生成器使用新的双模型对话系统
from npc_dialogue_generator import DualModelDialogueGenerator
base_model_path = '/mnt/g/Project02/AITrain/Qwen/Qwen3-4B' # 根据你的路径调整
lora_model_path = './output/NPC_Dialogue_LoRA/final_model'
if not os.path.exists(lora_model_path):
lora_model_path = None
# 创建双模型对话生成器
if hasattr(kb, 'character_data') and len(kb.character_data) >= 2:
print("✓ 使用knowledge_base角色数据创建双模型对话系统")
# 获取前两个角色
character_names = list(kb.character_data.keys())[:2]
char1_name = character_names[0]
char2_name = character_names[1]
# 配置两个角色的模型
character1_config = {
"name": char1_name,
"lora_path": lora_model_path, # 可以为每个角色设置不同的LoRA
"character_data": kb.character_data[char1_name]
}
character2_config = {
"name": char2_name,
"lora_path": lora_model_path, # 可以为每个角色设置不同的LoRA
"character_data": kb.character_data[char2_name]
}
llm_generator = DualModelDialogueGenerator(
base_model_path,
character1_config,
character2_config
)
else:
print("⚠ 角色数据不足,无法创建双模型对话系统")
return
# 创建对话引擎
dialogue_engine = DualAIDialogueEngine(kb, conv_mgr, llm_generator)
print("✓ 系统初始化完成")
# 交互式菜单
while True:
print("\n" + "="*50)
print("双AI角色对话系统")
print("1. 创建新对话")
print("2. 继续已有对话")
print("3. 查看对话历史")
print("4. 列出所有会话")
print("0. 退出")
print("="*50)
choice = input("请选择操作: ").strip()
if choice == '0':
break
elif choice == '1':
# 创建新对话
print(f"可用角色: {list(kb.character_data.keys())}")
characters = input("请输入两个角色名称(用空格分隔): ").strip().split()
if len(characters) != 2:
print("❌ 请输入正好两个角色名称")
continue
worldview = kb.worldview_data.get('worldview_name', '未知世界观') if kb.worldview_data else '未知世界观'
session_id = conv_mgr.create_session(characters, worldview)
topic = input("请输入对话主题(可选): ").strip()
turns = int(input("请输入对话轮次数量默认2: ").strip() or "2")
# 历史上下文控制选项
print("\n历史上下文设置:")
history_count = input("使用历史对话轮数默认30表示不使用: ").strip()
history_count = int(history_count) if history_count.isdigit() else 3
context_info_count = input("使用上下文信息数量默认2: ").strip()
context_info_count = int(context_info_count) if context_info_count.isdigit() else 2
print(f"\n开始对话 - 会话ID: {session_id}")
print(f"上下文设置: 历史{history_count}轮, 信息{context_info_count}")
# 询问是否使用双模型对话
use_dual_model = input("是否使用双模型对话系统?(y/n默认y): ").strip().lower()
if use_dual_model != 'n':
print("使用双模型对话系统...")
dialogue_engine.run_dual_model_conversation(session_id, topic, turns, history_count, context_info_count)
else:
print("使用传统对话系统...")
dialogue_engine.run_conversation_turn(session_id, characters, turns, topic, history_count, context_info_count)
elif choice == '2':
# 继续已有对话
sessions = conv_mgr.list_sessions()
if not sessions:
print("❌ 没有已有对话")
continue
print("已有会话:")
for i, session in enumerate(sessions[:5]):
chars = ", ".join(session['characters'])
print(f"{i+1}. {session['session_id'][:8]}... ({chars}) - {session['last_update'][:16]}")
try:
idx = int(input("请选择会话编号: ").strip()) - 1
if 0 <= idx < len(sessions):
session = sessions[idx]
session_id = session['session_id']
characters = session['characters']
# 显示最近的对话
history = conv_mgr.get_conversation_history(session_id, 4)
if history:
print("\n最近的对话:")
for turn in history:
print(f"{turn.speaker}: {turn.content}")
topic = input("请输入对话主题(可选): ").strip()
turns = int(input("请输入对话轮次数量默认1: ").strip() or "1")
# 历史上下文控制选项
print("\n历史上下文设置:")
history_count = input("使用历史对话轮数默认30表示不使用: ").strip()
history_count = int(history_count) if history_count.isdigit() else 3
context_info_count = input("使用上下文信息数量默认2: ").strip()
context_info_count = int(context_info_count) if context_info_count.isdigit() else 2
print(f"\n继续对话 - 会话ID: {session_id}")
print(f"上下文设置: 历史{history_count}轮, 信息{context_info_count}")
# 询问是否使用双模型对话
use_dual_model = input("是否使用双模型对话系统?(y/n默认y): ").strip().lower()
if use_dual_model != 'n':
print("使用双模型对话系统...")
dialogue_engine.run_dual_model_conversation(session_id, topic, turns, history_count, context_info_count)
else:
print("使用传统对话系统...")
dialogue_engine.run_conversation_turn(session_id, characters, turns, topic, history_count, context_info_count)
else:
print("❌ 无效的会话编号")
except ValueError:
print("❌ 请输入有效的数字")
elif choice == '3':
# 查看对话历史
session_id = input("请输入会话ID前8位即可: ").strip()
# 查找匹配的会话
sessions = conv_mgr.list_sessions()
matching_session = None
for session in sessions:
if session['session_id'].startswith(session_id):
matching_session = session
break
if matching_session:
full_session_id = matching_session['session_id']
history = conv_mgr.get_conversation_history(full_session_id, 20)
if history:
print(f"\n对话历史 - {full_session_id}")
print(f"角色: {', '.join(matching_session['characters'])}")
print(f"世界观: {matching_session['worldview']}")
print("-" * 50)
for turn in history:
print(f"[{turn.timestamp[:16]}] {turn.speaker}:")
print(f" {turn.content}")
if turn.context_used:
print(f" 使用上下文: {', '.join(turn.context_used)}")
print()
else:
print("该会话暂无对话历史")
else:
print("❌ 未找到匹配的会话")
elif choice == '4':
# 列出所有会话
sessions = conv_mgr.list_sessions()
if sessions:
print(f"\n共有 {len(sessions)} 个对话会话:")
for session in sessions:
chars = ", ".join(session['characters'])
print(f"ID: {session['session_id']}")
print(f" 角色: {chars}")
print(f" 世界观: {session['worldview']}")
print(f" 最后更新: {session['last_update']}")
print()
else:
print("暂无对话会话")
else:
print("❌ 无效选择")
except Exception as e:
print(f"✗ 系统运行出错: {e}")
import traceback
traceback.print_exc()
if __name__ == '__main__':
main()