#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
RAG-enhanced character dialogue system.

Integrates a worldview knowledge base; supports loading character profiles
and generating dialogue.
'''

import json
import os
import sqlite3
from datetime import datetime
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass, asdict
import hashlib

# Try to import the embedding / vector-search libraries.
try:
    from sentence_transformers import SentenceTransformer
    import faiss
    import numpy as np
    EMBEDDING_AVAILABLE = True
except ImportError:
    EMBEDDING_AVAILABLE = False
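
# When sentence-transformers / faiss / numpy are unavailable, RAGKnowledgeBase
# skips vector indexing and search_relevant_context() falls back to the plain
# keyword matching implemented below.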


@dataclass
class DialogueTurn:
    """Data structure for a single dialogue turn."""
    speaker: str
    content: str
    timestamp: str
    context_used: List[str]  # context snippets used when generating this turn
    relevance_score: float = 0.0


@dataclass
class ConversationSession:
    """Data structure for a conversation session."""
    session_id: str
    characters: List[str]
    worldview: str
    start_time: str
    last_update: str
    dialogue_history: List[DialogueTurn]
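
# A minimal serialization sketch (values are illustrative): dataclasses.asdict
# turns a DialogueTurn into a plain dict, e.g.
#   turn = DialogueTurn(speaker="角色A", content="...", timestamp=datetime.now().isoformat(),
#                       context_used=["rag_knowledge.text"], relevance_score=0.42)
#   json.dumps(asdict(turn), ensure_ascii=False)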


class RAGKnowledgeBase:
    """RAG knowledge base manager."""

    def __init__(self, knowledge_dir: str):
        self.knowledge_dir = knowledge_dir
        self.worldview_data = None
        self.character_data = {}
        self.chunks = []
        self.embedding_model = None
        self.index = None

        # Initialise the embedding model.
        if EMBEDDING_AVAILABLE:
            try:
                self.embedding_model = SentenceTransformer('./sentence-transformers/all-MiniLM-L6-v2')
                print("✓ Embedding model loaded")
            except Exception as e:
                print(f"✗ Failed to load embedding model: {e}")

        self._load_knowledge_base()

    def _load_knowledge_base(self):
        """Load the knowledge base."""
        # Prefer the RAG knowledge base as the worldview source.
        rag_worldview_path = "./rag_knowledge/knowledge_base.json"
        if os.path.exists(rag_worldview_path):
            try:
                with open(rag_worldview_path, 'r', encoding='utf-8') as f:
                    rag_data = json.load(f)

                # Extract worldview information from the RAG data.
                self.worldview_data = {
                    "worldview_name": "克苏鲁神话世界观 (RAG)",
                    "source": rag_data.get("metadata", {}).get("source_file", "未知"),
                    "description": f"基于{rag_data.get('metadata', {}).get('source_file', 'PDF文档')}的RAG知识库",
                    "total_chunks": rag_data.get("metadata", {}).get("total_chunks", 0),
                    "total_concepts": rag_data.get("metadata", {}).get("total_concepts", 0),
                    "rag_enabled": True
                }

                # Keep the RAG chunks for retrieval.
                self.rag_chunks = rag_data.get("chunks", [])

                print(f"✓ RAG worldview loaded: {self.worldview_data['worldview_name']}")
                print(f"  - document chunks: {self.worldview_data['total_chunks']}")
                print(f"  - concepts: {self.worldview_data['total_concepts']}")
            except Exception as e:
                print(f"✗ Failed to load RAG worldview: {e}")
                self.rag_chunks = []

        # Without a RAG knowledge base, fall back to traditional worldview files.
        if not hasattr(self, 'rag_chunks') or not self.rag_chunks:
            worldview_files = [f for f in os.listdir(self.knowledge_dir)
                               if f.startswith('worldview') and f.endswith('.json')]
            if worldview_files:
                worldview_path = os.path.join(self.knowledge_dir, worldview_files[0])
                with open(worldview_path, 'r', encoding='utf-8') as f:
                    self.worldview_data = json.load(f)
                print(f"✓ Traditional worldview loaded: {self.worldview_data.get('worldview_name', '未知')}")

        # Load the character data.
        character_files = [f for f in os.listdir(self.knowledge_dir)
                           if f.startswith('character') and f.endswith('.json')]
        for char_file in character_files:
            char_path = os.path.join(self.knowledge_dir, char_file)
            with open(char_path, 'r', encoding='utf-8') as f:
                char_data = json.load(f)
                char_name = char_data.get('character_name', char_file)
                self.character_data[char_name] = char_data

        print(f"✓ Characters loaded: {list(self.character_data.keys())}")

        # Build the searchable text chunks.
        self._build_searchable_chunks()

        # Build the vector index.
        if EMBEDDING_AVAILABLE and self.embedding_model:
            self._build_vector_index()

    def _build_searchable_chunks(self):
        """Build the searchable text chunks."""
        self.chunks = []

        # Prefer the text chunks from the RAG knowledge base.
        if hasattr(self, 'rag_chunks') and self.rag_chunks:
            for rag_chunk in self.rag_chunks:
                self.chunks.append({
                    "type": "worldview_rag",
                    "section": "rag_knowledge",
                    "subsection": rag_chunk.get("type", "unknown"),
                    "content": rag_chunk.get("content", ""),
                    "metadata": {
                        "source": "rag_worldview",
                        "chunk_id": rag_chunk.get("id", ""),
                        "size": rag_chunk.get("size", 0),
                        "hash": rag_chunk.get("hash", "")
                    }
                })
            print(f"✓ Using RAG knowledge base chunks: {len(self.rag_chunks)}")
        else:
            # Traditional worldview chunks.
            if self.worldview_data:
                for section_key, section_data in self.worldview_data.items():
                    if isinstance(section_data, dict):
                        for sub_key, sub_data in section_data.items():
                            if isinstance(sub_data, (str, list)):
                                content = str(sub_data)
                                if len(content) > 50:  # keep only meaningful text
                                    self.chunks.append({
                                        "type": "worldview",
                                        "section": section_key,
                                        "subsection": sub_key,
                                        "content": content,
                                        "metadata": {"source": "worldview"}
                                    })

        # Character chunks.
        for char_name, char_data in self.character_data.items():
            for section_key, section_data in char_data.items():
                if isinstance(section_data, dict):
                    for sub_key, sub_data in section_data.items():
                        if isinstance(sub_data, (str, list)):
                            content = str(sub_data)
                            if len(content) > 30:
                                self.chunks.append({
                                    "type": "character",
                                    "character": char_name,
                                    "section": section_key,
                                    "subsection": sub_key,
                                    "content": content,
                                    "metadata": {"source": char_name}
                                })

        print(f"✓ Built text chunks: {len(self.chunks)}")
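
    # NOTE (assumption): the indices returned by self.index.search() are mapped
    # back onto self.chunks in search_relevant_context(), so chunk order here
    # must match the vector index; a prebuilt ./rag_knowledge/vector_index.faiss
    # is assumed to have been built from these same chunks in the same order.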

    def _build_vector_index(self):
        """Build the vector index."""
        try:
            # Prefer the prebuilt vector index from the RAG knowledge base.
            rag_vector_path = "./rag_knowledge/vector_index.faiss"
            rag_embeddings_path = "./rag_knowledge/embeddings.npy"

            if os.path.exists(rag_vector_path) and os.path.exists(rag_embeddings_path):
                # Load the prebuilt vector index.
                self.index = faiss.read_index(rag_vector_path)
                self.rag_embeddings = np.load(rag_embeddings_path)
                print(f"✓ Using prebuilt RAG vector index: {self.index.ntotal} vectors")
                return

            # No prebuilt index available, so build one from the chunks.
            texts = [chunk["content"] for chunk in self.chunks]
            embeddings = self.embedding_model.encode(texts)

            dimension = embeddings.shape[1]
            self.index = faiss.IndexFlatL2(dimension)
            self.index.add(embeddings.astype(np.float32))

            print(f"✓ Vector index built: {dimension} dims, {len(texts)} vectors")
        except Exception as e:
            print(f"✗ Failed to build vector index: {e}")

    def search_relevant_context(self, query: str, character_name: Optional[str] = None, top_k: int = 3) -> List[Dict]:
        """Search for relevant context."""
        relevant_chunks = []

        # Vector search (the same code path serves both a prebuilt RAG index
        # and a freshly built one).
        if EMBEDDING_AVAILABLE and self.embedding_model and self.index:
            try:
                query_vector = self.embedding_model.encode([query])
                distances, indices = self.index.search(query_vector.astype(np.float32), top_k * 2)

                for distance, idx in zip(distances[0], indices[0]):
                    if idx < len(self.chunks):
                        chunk = self.chunks[idx].copy()
                        # Map the L2 distance to a (0, 1] relevance score.
                        chunk["relevance_score"] = float(1 / (1 + distance))
                        relevant_chunks.append(chunk)
            except Exception as e:
                print(f"Vector search failed: {e}")

        # Keyword search as a fallback.
        if not relevant_chunks:
            query_lower = query.lower()
            for chunk in self.chunks:
                content_lower = chunk["content"].lower()
                score = 0
                for word in query_lower.split():
                    if word in content_lower:
                        score += content_lower.count(word)

                if score > 0:
                    chunk_copy = chunk.copy()
                    chunk_copy["relevance_score"] = score
                    relevant_chunks.append(chunk_copy)

        # Sort by relevance.
        relevant_chunks.sort(key=lambda x: x["relevance_score"], reverse=True)

        # Prefer chunks that belong to the requested character.
        if character_name:
            char_chunks = [c for c in relevant_chunks if c.get("character") == character_name]
            other_chunks = [c for c in relevant_chunks if c.get("character") != character_name]
            relevant_chunks = char_chunks + other_chunks

        return relevant_chunks[:top_k]
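
# A minimal usage sketch (the "./knowledge_base" path mirrors the commented-out
# demo at the bottom of this file; query text and top_k are illustrative):
#   kb = RAGKnowledgeBase("./knowledge_base")
#   hits = kb.search_relevant_context("克苏鲁 神庙", top_k=3)
#   for hit in hits:
#       print(f"{hit['relevance_score']:.3f}", hit["content"][:80])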


class ConversationManager:
    """Conversation manager."""

    def __init__(self, db_path: str = "conversation_history.db"):
        self.db_path = db_path
        self._init_database()

    def _init_database(self):
        """Initialise the dialogue-history database."""
        with sqlite3.connect(self.db_path) as conn:
            conn.execute('''
                CREATE TABLE IF NOT EXISTS conversations (
                    session_id TEXT PRIMARY KEY,
                    characters TEXT,
                    worldview TEXT,
                    start_time TEXT,
                    last_update TEXT,
                    metadata TEXT
                )
            ''')

            conn.execute('''
                CREATE TABLE IF NOT EXISTS dialogue_turns (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    session_id TEXT,
                    turn_number INTEGER,
                    speaker TEXT,
                    content TEXT,
                    timestamp TEXT,
                    context_used TEXT,
                    relevance_score REAL,
                    dialogue_score REAL DEFAULT 0.0,
                    score_details TEXT,
                    score_feedback TEXT,
                    FOREIGN KEY (session_id) REFERENCES conversations (session_id)
                )
            ''')
            conn.commit()

    def create_session(self, characters: List[str], worldview: str) -> str:
        """Create a new conversation session."""
        # Session id: first 12 hex chars of md5("char1-char2-<ISO timestamp>").
        session_id = hashlib.md5(f"{'-'.join(characters)}-{datetime.now().isoformat()}".encode()).hexdigest()[:12]

        with sqlite3.connect(self.db_path) as conn:
            conn.execute(
                "INSERT INTO conversations (session_id, characters, worldview, start_time, last_update) VALUES (?, ?, ?, ?, ?)",
                (session_id, json.dumps(characters), worldview, datetime.now().isoformat(), datetime.now().isoformat())
            )
            conn.commit()

        print(f"✓ Created conversation session: {session_id}")
        return session_id

    def add_dialogue_turn(self, session_id: str, speaker: str, content: str, context_used: Optional[List[str]] = None,
                          relevance_score: float = 0.0, dialogue_score: float = 0.0,
                          score_details: Optional[str] = None, score_feedback: Optional[str] = None):
        """Add a dialogue turn."""
        if context_used is None:
            context_used = []

        with sqlite3.connect(self.db_path) as conn:
            # Determine the current turn number.
            cursor = conn.execute("SELECT COUNT(*) FROM dialogue_turns WHERE session_id = ?", (session_id,))
            turn_number = cursor.fetchone()[0] + 1

            # Insert the dialogue turn.
            conn.execute(
                """INSERT INTO dialogue_turns
                   (session_id, turn_number, speaker, content, timestamp, context_used, relevance_score,
                    dialogue_score, score_details, score_feedback)
                   VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
                (session_id, turn_number, speaker, content, datetime.now().isoformat(),
                 json.dumps(context_used), relevance_score, dialogue_score, score_details, score_feedback)
            )

            # Update the session's last-update time.
            conn.execute(
                "UPDATE conversations SET last_update = ? WHERE session_id = ?",
                (datetime.now().isoformat(), session_id)
            )
            conn.commit()

    def get_conversation_history(self, session_id: str, last_n: int = 10) -> List[DialogueTurn]:
        """Fetch the dialogue history."""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.execute(
                """SELECT speaker, content, timestamp, context_used, relevance_score, dialogue_score, score_feedback
                   FROM dialogue_turns
                   WHERE session_id = ?
                   ORDER BY turn_number DESC LIMIT ?""",
                (session_id, last_n)
            )

            turns = []
            for row in cursor.fetchall():
                speaker, content, timestamp, context_used, relevance_score, dialogue_score, score_feedback = row
                turn = DialogueTurn(
                    speaker=speaker,
                    content=content,
                    timestamp=timestamp,
                    context_used=json.loads(context_used or "[]"),
                    relevance_score=relevance_score
                )
                # Attach scoring info to the turn object (set dynamically;
                # DialogueTurn declares no fields for these).
                if dialogue_score:
                    turn.dialogue_score = dialogue_score
                if score_feedback:
                    turn.score_feedback = score_feedback
                turns.append(turn)

        return list(reversed(turns))  # return in chronological order

    def list_sessions(self) -> List[Dict]:
        """List all conversation sessions."""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.execute(
                "SELECT session_id, characters, worldview, start_time, last_update FROM conversations ORDER BY last_update DESC"
            )

            sessions = []
            for row in cursor.fetchall():
                session_id, characters, worldview, start_time, last_update = row
                sessions.append({
                    "session_id": session_id,
                    "characters": json.loads(characters),
                    "worldview": worldview,
                    "start_time": start_time,
                    "last_update": last_update
                })

        return sessions
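
# A minimal usage sketch (character names and text are illustrative):
#   conv_mgr = ConversationManager()   # creates/opens conversation_history.db
#   sid = conv_mgr.create_session(["角色A", "角色B"], "克苏鲁神话世界观 (RAG)")
#   conv_mgr.add_dialogue_turn(sid, "角色A", "你好。", context_used=["rag_knowledge.text"])
#   for turn in conv_mgr.get_conversation_history(sid, last_n=5):
#       print(turn.speaker, turn.content)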


class DualAIDialogueEngine:
    """Dual-AI dialogue engine."""

    def __init__(self, knowledge_base: RAGKnowledgeBase, conversation_manager: ConversationManager, llm_generator,
                 enable_scoring: bool = True, base_model_path: Optional[str] = None):
        self.kb = knowledge_base
        self.conv_mgr = conversation_manager
        self.llm_generator = llm_generator
        self.enable_scoring = enable_scoring
        self.scorer = None

        # Initialise the dialogue scorer.
        if enable_scoring and base_model_path:
            try:
                from dialogue_scorer import DialogueAIScorer
                print("Initialising the dialogue scoring system...")
                self.scorer = DialogueAIScorer(
                    base_model_path=base_model_path,
                    tokenizer=getattr(llm_generator, 'tokenizer', None),
                    model=getattr(llm_generator, 'model', None)
                )
                print("✓ Dialogue scoring system initialised")
            except Exception as e:
                print(f"⚠ Failed to initialise the dialogue scoring system: {e}")
                self.enable_scoring = False

    def score_dialogue_turn(self, dialogue_content: str, speaker: str, dialogue_history: List[DialogueTurn]) -> Tuple[float, str, str]:
        """Score a single dialogue turn.

        Args:
            dialogue_content: the dialogue text
            speaker: the speaking character
            dialogue_history: the dialogue history

        Returns:
            tuple: (overall score, detailed scores as JSON, feedback)
        """
        if not self.enable_scoring or not self.scorer:
            return 0.0, "{}", "Scoring system disabled"

        try:
            # Fetch the character data.
            character_data = self.kb.character_data.get(speaker, {})

            # Convert the dialogue history into the scorer's format.
            history_for_scoring = []
            for turn in dialogue_history[-5:]:  # the 5 most recent turns
                history_for_scoring.append({
                    'speaker': turn.speaker,
                    'content': turn.content
                })

            # Run the AI scorer.
            score_result = self.scorer.score_dialogue(
                dialogue_content=dialogue_content,
                speaker=speaker,
                character_data=character_data,
                dialogue_history=history_for_scoring,
                context_info=[]
            )

            # Return the scoring result.
            return score_result.overall_score, json.dumps(score_result.scores), score_result.feedback

        except Exception as e:
            print(f"⚠ Dialogue scoring failed: {e}")
            return 0.0, "{}", f"Scoring failed: {str(e)}"

    def generate_character_prompt(self, character_name: str, context_info: List[Dict], dialogue_history: List[DialogueTurn],
                                  history_context_count: int = 3, context_info_count: int = 2) -> str:
        """Build the dialogue prompt for a character.

        Args:
            character_name: character name
            context_info: relevant context chunks
            dialogue_history: dialogue history
            history_context_count: number of history turns to include (default 3)
            context_info_count: number of context chunks to include (default 2)
        """
        char_data = self.kb.character_data.get(character_name, {})

        # Basic character profile (the prompt text itself is model-facing and in Chinese).
        prompt_parts = []
        prompt_parts.append(f"你是{character_name},具有以下设定:")

        if char_data.get('personality', {}).get('core_traits'):
            traits = ", ".join(char_data['personality']['core_traits'])
            prompt_parts.append(f"性格特点:{traits}")

        if char_data.get('speech_patterns', {}).get('sample_phrases'):
            phrases = char_data['speech_patterns']['sample_phrases'][:3]
            prompt_parts.append(f"说话风格示例:{'; '.join(phrases)}")

        # Current situation.
        if char_data.get('current_situation'):
            situation = char_data['current_situation']
            prompt_parts.append(f"当前状态:{situation.get('current_mood', '')}")

        # Relevant worldview information (count is configurable).
        if context_info:
            prompt_parts.append("相关背景信息:")
            for info in context_info[:context_info_count]:
                content = info['content'][:200] + "..." if len(info['content']) > 200 else info['content']
                prompt_parts.append(f"- {content}")

        # Dialogue history (count is configurable).
        if dialogue_history:
            prompt_parts.append("最近的对话:")
            # Use the parameter to control how many history turns are included.
            history_to_use = dialogue_history[-history_context_count:] if history_context_count > 0 else []
            for turn in history_to_use:
                prompt_parts.append(f"{turn.speaker}: {turn.content}")

        prompt_parts.append("\n请根据角色设定和上下文,生成符合角色特点的自然对话。回复应该在50-150字之间。")

        return "\n".join(prompt_parts)
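
    # Illustrative shape of the assembled prompt (placeholder values only):
    #   你是<角色名>,具有以下设定:
    #   性格特点:<trait1>, <trait2>
    #   说话风格示例:<phrase1>; <phrase2>
    #   当前状态:<current_mood>
    #   相关背景信息:
    #   - <context chunk, truncated to 200 chars>
    #   最近的对话:
    #   <speaker>: <content>
    #
    #   请根据角色设定和上下文,生成符合角色特点的自然对话。回复应该在50-150字之间。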

    def generate_dialogue(self, session_id: str, current_speaker: str, topic_hint: str = "",
                          history_context_count: int = 3, context_info_count: int = 2) -> Tuple[str, List[str]]:
        """Generate a character's dialogue turn.

        Args:
            session_id: session id
            current_speaker: the character speaking now
            topic_hint: topic hint
            history_context_count: number of history turns to use (default 3)
            context_info_count: number of context chunks to use (default 2)
        """
        # Fetch the dialogue history.
        dialogue_history = self.conv_mgr.get_conversation_history(session_id)

        # Build the search query.
        if dialogue_history:
            # Based on the most recent dialogue content (count is configurable).
            recent_turns = dialogue_history[-history_context_count:] if history_context_count > 0 else []
            recent_content = " ".join([turn.content for turn in recent_turns])
            search_query = recent_content + " " + topic_hint
        else:
            # First turn of the conversation.
            search_query = f"{current_speaker} {topic_hint} introduction greeting"

        # Search for relevant context.
        context_info = self.kb.search_relevant_context(search_query, current_speaker, context_info_count)

        # Build the prompt (context counts controlled by the parameters).
        prompt = self.generate_character_prompt(
            current_speaker,
            context_info,
            dialogue_history,
            history_context_count,
            context_info_count
        )

        # Generate the dialogue, preferring the dual-model system.
        try:
            # Check whether the generator is a dual-model dialogue system.
            if hasattr(self.llm_generator, 'generate_dual_character_dialogue'):
                # Dual-model system.
                response = self.llm_generator.generate_dual_character_dialogue(
                    current_speaker,
                    prompt,
                    topic_hint or "请继续对话",
                    temperature=0.8,
                    max_new_tokens=150
                )
            else:
                # Backwards compatibility with the old single-model system.
                response = self.llm_generator.generate_character_dialogue(
                    current_speaker,
                    prompt,
                    topic_hint or "请继续对话",
                    temperature=0.8,
                    max_new_tokens=150
                )

            # Record the context that was used.
            context_used = [f"{info['section']}.{info['subsection']}" for info in context_info[:context_info_count]]
            avg_relevance = sum(info['relevance_score'] for info in context_info[:context_info_count]) / len(context_info[:context_info_count]) if context_info else 0.0

            # Score the dialogue turn.
            if self.enable_scoring:
                dialogue_score, score_details, score_feedback = self.score_dialogue_turn(response, current_speaker, dialogue_history)
                print(f"  [score: {dialogue_score:.2f}] {score_feedback}")
            else:
                dialogue_score, score_details, score_feedback = 0.0, "{}", ""

            # Persist the dialogue turn (including the scoring information).
            self.conv_mgr.add_dialogue_turn(
                session_id, current_speaker, response, context_used, avg_relevance,
                dialogue_score, score_details, score_feedback
            )

            return response, context_used

        except Exception as e:
            print(f"✗ Dialogue generation failed: {e}")
            return f"[{current_speaker}暂时无法回应]", []

    def run_conversation_turn(self, session_id: str, characters: List[str], turns_count: int = 1, topic: str = "",
                              history_context_count: int = 3, context_info_count: int = 2):
        """Run one or more conversation rounds.

        Args:
            session_id: session id
            characters: list of characters
            turns_count: number of rounds
            topic: conversation topic
            history_context_count: number of history turns to use (default 3)
            context_info_count: number of context chunks to use (default 2)
        """
        results = []
        print(f"  [context settings: {history_context_count} history turns, {context_info_count} context chunks]")
        for i in range(turns_count):
            for char in characters:
                response, context_used = self.generate_dialogue(
                    session_id,
                    char,
                    topic,
                    history_context_count,
                    context_info_count
                )
                results.append({
                    "speaker": char,
                    "content": response,
                    "context_used": context_used,
                    "turn": i + 1,
                    "context_settings": {
                        "history_count": history_context_count,
                        "context_info_count": context_info_count
                    }
                })

                print(f"{char}: {response}")
                # if context_used:
                #     print(f"  [context used: {', '.join(context_used)}]")

            print()

        return results

    def run_dual_model_conversation(self, session_id: str, topic: str = "", turns: int = 4,
                                    history_context_count: int = 3, context_info_count: int = 2):
        """Run a conversation with the dual-model system.

        Args:
            session_id: session id
            topic: conversation topic
            turns: number of rounds
            history_context_count: number of history turns to use
            context_info_count: number of context chunks to use
        """
        # Check whether the generator supports dual-model conversations.
        if not hasattr(self.llm_generator, 'run_dual_character_conversation'):
            print("⚠ The current system does not support dual-model dialogue")
            return self.run_conversation_turn(session_id, self.llm_generator.list_characters(), turns, topic,
                                              history_context_count, context_info_count)

        # Fetch the dialogue history.
        dialogue_history = self.conv_mgr.get_conversation_history(session_id)

        # Build the search query.
        if dialogue_history:
            recent_turns = dialogue_history[-history_context_count:] if history_context_count > 0 else []
            recent_content = " ".join([turn.content for turn in recent_turns])
            search_query = recent_content + " " + topic
        else:
            search_query = f"{topic} introduction greeting"

        # Search for relevant context.
        context_info = self.kb.search_relevant_context(search_query, top_k=context_info_count)

        # Build the context string (model-facing).
        context_str = ""
        if context_info:
            context_str = "相关背景信息:"
            for info in context_info[:context_info_count]:
                content = info['content'][:150] + "..." if len(info['content']) > 150 else info['content']
                context_str += f"\n- {content}"

        print("\n=== Dual-model dialogue system ===")
        print(f"Topic: {topic}")
        print(f"Characters: {', '.join(self.llm_generator.list_characters())}")
        print(f"Rounds: {turns}")
        print(f"Context settings: {history_context_count} history turns, {context_info_count} context chunks")

        # Generate the conversation with the dual-model system.
        for turn in range(turns):
            # Refresh the dialogue history for this round.
            dialogue_history = self.conv_mgr.get_conversation_history(session_id)
            conversation_results = self.llm_generator.run_dual_character_conversation(
                topic=topic,
                turn_index=turn,
                context=context_str,
                dialogue_history=dialogue_history,
                history_context_count=history_context_count,
                max_new_tokens=150
            )

            # Persist each result to the database and score it.
            for result in conversation_results:
                # Fetch the current history for scoring.
                current_dialogue_history = self.conv_mgr.get_conversation_history(session_id)

                # Score the dialogue turn.
                if self.enable_scoring:
                    dialogue_score, score_details, score_feedback = self.score_dialogue_turn(
                        result['dialogue'], result['speaker'], current_dialogue_history
                    )
                    print(f"  [score: {dialogue_score:.2f}] {score_feedback[:100]}...")
                else:
                    dialogue_score, score_details, score_feedback = 0.0, "{}", ""

                self.conv_mgr.add_dialogue_turn(
                    session_id,
                    result['speaker'],
                    result['dialogue'],
                    [result.get('context_used', '')],
                    0.8,  # default relevance score
                    dialogue_score,
                    score_details,
                    score_feedback
                )

        return conversation_results
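
# A minimal wiring sketch (paths, topic, and the generator module follow the
# commented-out demo in main() below; treat them as assumptions, not a tested
# configuration):
#   from npc_dialogue_generator import DualModelDialogueGenerator
#   kb = RAGKnowledgeBase("./knowledge_base")
#   conv_mgr = ConversationManager()
#   llm_generator = DualModelDialogueGenerator(...)  # see main() for the character configs
#   engine = DualAIDialogueEngine(kb, conv_mgr, llm_generator,
#                                 enable_scoring=True,
#                                 base_model_path='/mnt/g/Project02/AITrain/Qwen/Qwen3-4B')
#   sid = conv_mgr.create_session(list(kb.character_data.keys())[:2],
#                                 kb.worldview_data.get('worldview_name', '未知世界观'))
#   engine.run_dual_model_conversation(sid, topic="初次见面", turns=2)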


# def main():
#     """Main entry point - demo of the system."""
#     print("=== RAG-enhanced dual-AI character dialogue system ===")
#
#     # Paths.
#     knowledge_dir = "./knowledge_base"  # directory with the worldview and character documents
#
#     # Check that the required directories exist.
#     required_dirs = [knowledge_dir]
#     for dir_path in required_dirs:
#         if not os.path.exists(dir_path):
#             print(f"✗ Directory not found: {dir_path}")
#             print("Please make sure the following files exist:")
#             print("- ./knowledge_base/worldview_template_coc.json")
#             print("- ./knowledge_base/character_template_detective.json")
#             print("- ./knowledge_base/character_template_professor.json")
#             return
#
#     try:
#         # Initialise the system components.
#         print("\nInitialising the system...")
#         kb = RAGKnowledgeBase(knowledge_dir)
#         conv_mgr = ConversationManager()
#
#         # Your LLM generator is needed here; use the new dual-model dialogue system.
#         from npc_dialogue_generator import DualModelDialogueGenerator
#         base_model_path = '/mnt/g/Project02/AITrain/Qwen/Qwen3-4B'  # adjust to your path
#         lora_model_path = './output/NPC_Dialogue_LoRA/final_model'
#
#         if not os.path.exists(lora_model_path):
#             lora_model_path = None
#
#         # Create the dual-model dialogue generator.
#         if hasattr(kb, 'character_data') and len(kb.character_data) >= 2:
#             print("✓ Creating the dual-model dialogue system from knowledge_base character data")
#             # Take the first two characters.
#             character_names = list(kb.character_data.keys())[:2]
#             char1_name = character_names[0]
#             char2_name = character_names[1]
#
#             # Configure the two characters' models.
#             character1_config = {
#                 "name": char1_name,
#                 "lora_path": lora_model_path,  # each character may use a different LoRA
#                 "character_data": kb.character_data[char1_name]
#             }
#
#             character2_config = {
#                 "name": char2_name,
#                 "lora_path": lora_model_path,  # each character may use a different LoRA
#                 "character_data": kb.character_data[char2_name]
#             }
#
#             llm_generator = DualModelDialogueGenerator(
#                 base_model_path,
#                 character1_config,
#                 character2_config
#             )
#         else:
#             print("⚠ Not enough character data to create the dual-model dialogue system")
#             return
#
#         # Create the dialogue engine.
#         dialogue_engine = DualAIDialogueEngine(kb, conv_mgr, llm_generator)
#
#         print("✓ System initialised")
#
#         # Interactive menu.
#         while True:
#             print("\n" + "="*50)
#             print("Dual-AI character dialogue system")
#             print("1. Create a new conversation")
#             print("2. Continue an existing conversation")
#             print("3. View conversation history")
#             print("4. List all sessions")
#             print("0. Exit")
#             print("="*50)
#
#             choice = input("Choose an option: ").strip()
#
#             if choice == '0':
#                 break
#
#             elif choice == '1':
#                 # Create a new conversation.
#                 print(f"Available characters: {list(kb.character_data.keys())}")
#                 characters = input("Enter two character names (separated by a space): ").strip().split()
#
#                 if len(characters) != 2:
#                     print("❌ Please enter exactly two character names")
#                     continue
#
#                 worldview = kb.worldview_data.get('worldview_name', '未知世界观') if kb.worldview_data else '未知世界观'
#                 session_id = conv_mgr.create_session(characters, worldview)
#
#                 topic = input("Enter a conversation topic (optional): ").strip()
#                 turns = int(input("Enter the number of rounds (default 2): ").strip() or "2")
#
#                 # History-context settings.
#                 print("\nHistory context settings:")
#                 history_count = input("History turns to use (default 3, 0 for none): ").strip()
#                 history_count = int(history_count) if history_count.isdigit() else 3
#
#                 context_info_count = input("Context chunks to use (default 2): ").strip()
#                 context_info_count = int(context_info_count) if context_info_count.isdigit() else 2
#
#                 print(f"\nStarting conversation - session id: {session_id}")
#                 print(f"Context settings: {history_count} history turns, {context_info_count} context chunks")
#
#                 # Ask whether to use the dual-model dialogue system.
#                 use_dual_model = input("Use the dual-model dialogue system? (y/n, default y): ").strip().lower()
#                 if use_dual_model != 'n':
#                     print("Using the dual-model dialogue system...")
#                     dialogue_engine.run_dual_model_conversation(session_id, topic, turns, history_count, context_info_count)
#                 else:
#                     print("Using the traditional dialogue system...")
#                     dialogue_engine.run_conversation_turn(session_id, characters, turns, topic, history_count, context_info_count)
#
#             elif choice == '2':
#                 # Continue an existing conversation.
#                 sessions = conv_mgr.list_sessions()
#                 if not sessions:
#                     print("❌ No existing conversations")
#                     continue
#
#                 print("Existing sessions:")
#                 for i, session in enumerate(sessions[:5]):
#                     chars = ", ".join(session['characters'])
#                     print(f"{i+1}. {session['session_id'][:8]}... ({chars}) - {session['last_update'][:16]}")
#
#                 try:
#                     idx = int(input("Choose a session number: ").strip()) - 1
#                     if 0 <= idx < len(sessions):
#                         session = sessions[idx]
#                         session_id = session['session_id']
#                         characters = session['characters']
#
#                         # Show the most recent dialogue.
#                         history = conv_mgr.get_conversation_history(session_id, 4)
#                         if history:
#                             print("\nRecent dialogue:")
#                             for turn in history:
#                                 print(f"{turn.speaker}: {turn.content}")
#
#                         topic = input("Enter a conversation topic (optional): ").strip()
#                         turns = int(input("Enter the number of rounds (default 1): ").strip() or "1")
#
#                         # History-context settings.
#                         print("\nHistory context settings:")
#                         history_count = input("History turns to use (default 3, 0 for none): ").strip()
#                         history_count = int(history_count) if history_count.isdigit() else 3
#
#                         context_info_count = input("Context chunks to use (default 2): ").strip()
#                         context_info_count = int(context_info_count) if context_info_count.isdigit() else 2
#
#                         print(f"\nContinuing conversation - session id: {session_id}")
#                         print(f"Context settings: {history_count} history turns, {context_info_count} context chunks")
#
#                         # Ask whether to use the dual-model dialogue system.
#                         use_dual_model = input("Use the dual-model dialogue system? (y/n, default y): ").strip().lower()
#                         if use_dual_model != 'n':
#                             print("Using the dual-model dialogue system...")
#                             dialogue_engine.run_dual_model_conversation(session_id, topic, turns, history_count, context_info_count)
#                         else:
#                             print("Using the traditional dialogue system...")
#                             dialogue_engine.run_conversation_turn(session_id, characters, turns, topic, history_count, context_info_count)
#                     else:
#                         print("❌ Invalid session number")
#                 except ValueError:
#                     print("❌ Please enter a valid number")
#
#             elif choice == '3':
#                 # View conversation history.
#                 session_id = input("Enter the session id (first 8 characters are enough): ").strip()
#
#                 # Find the matching session.
#                 sessions = conv_mgr.list_sessions()
#                 matching_session = None
#                 for session in sessions:
#                     if session['session_id'].startswith(session_id):
#                         matching_session = session
#                         break
#
#                 if matching_session:
#                     full_session_id = matching_session['session_id']
#                     history = conv_mgr.get_conversation_history(full_session_id, 20)
#
#                     if history:
#                         print(f"\nConversation history - {full_session_id}")
#                         print(f"Characters: {', '.join(matching_session['characters'])}")
#                         print(f"Worldview: {matching_session['worldview']}")
#                         print("-" * 50)
#
#                         for turn in history:
#                             print(f"[{turn.timestamp[:16]}] {turn.speaker}:")
#                             print(f"  {turn.content}")
#                             if turn.context_used:
#                                 print(f"  context used: {', '.join(turn.context_used)}")
#                             print()
#                     else:
#                         print("This session has no dialogue history yet")
#                 else:
#                     print("❌ No matching session found")
#
#             elif choice == '4':
#                 # List all sessions.
#                 sessions = conv_mgr.list_sessions()
#                 if sessions:
#                     print(f"\nThere are {len(sessions)} conversation sessions:")
#                     for session in sessions:
#                         chars = ", ".join(session['characters'])
#                         print(f"ID: {session['session_id']}")
#                         print(f"  Characters: {chars}")
#                         print(f"  Worldview: {session['worldview']}")
#                         print(f"  Last update: {session['last_update']}")
#                         print()
#                 else:
#                     print("No conversation sessions yet")
#
#             else:
#                 print("❌ Invalid choice")
#
#     except Exception as e:
#         print(f"✗ System error: {e}")
#         import traceback
#         traceback.print_exc()
#
#
# if __name__ == '__main__':
#     main()