Project02/AITrain/dual_ai_dialogue_system.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
RAG-enhanced character dialogue system.
Integrates a worldview knowledge base to support loading character profiles
and generating dialogue.
'''
import json
import os
import sqlite3
from datetime import datetime
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass, asdict
import hashlib
# Try to import the embedding-related libraries; when they are missing,
# search_relevant_context() falls back to plain keyword matching.
try:
    from sentence_transformers import SentenceTransformer
    import faiss
    import numpy as np
    EMBEDDING_AVAILABLE = True
except ImportError:
    EMBEDDING_AVAILABLE = False


@dataclass
class DialogueTurn:
    """A single dialogue turn."""
    speaker: str
    content: str
    timestamp: str
    context_used: List[str]  # context snippets used for this turn
    relevance_score: float = 0.0


@dataclass
class ConversationSession:
    """A conversation session."""
    session_id: str
    characters: List[str]
    worldview: str
    start_time: str
    last_update: str
    dialogue_history: List[DialogueTurn]
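
# Example turn (all values are illustrative):
#     DialogueTurn(speaker="Alice", content="Hello.",
#                  timestamp=datetime.now().isoformat(),
#                  context_used=["worldview: geography"], relevance_score=0.8)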


class RAGKnowledgeBase:
    """RAG knowledge base manager."""

    def __init__(self, knowledge_dir: str):
        self.knowledge_dir = knowledge_dir
        self.worldview_data = None
        self.character_data = {}
        self.chunks = []
        self.rag_chunks = []  # initialized here so later checks need no hasattr()
        self.embedding_model = None
        self.index = None
        # Initialize the embedding model
        if EMBEDDING_AVAILABLE:
            try:
                self.embedding_model = SentenceTransformer('./sentence-transformers/all-MiniLM-L6-v2')
                print("✓ Embedding model loaded")
            except Exception as e:
                print(f"✗ Failed to load embedding model: {e}")
        self._load_knowledge_base()
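
    # Expected layout of ./rag_knowledge/knowledge_base.json, inferred from the
    # accesses in _load_knowledge_base() below (any fields beyond these are
    # assumptions):
    #
    #     {
    #         "metadata": {"source_file": "...", "total_chunks": 0, "total_concepts": 0},
    #         "chunks": [{"id": "...", "type": "...", "content": "...", "size": 0, "hash": "..."}]
    #     }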
    def _load_knowledge_base(self):
        """Load the knowledge base."""
        # Prefer the RAG knowledge base as the worldview source
        rag_worldview_path = "./rag_knowledge/knowledge_base.json"
        if os.path.exists(rag_worldview_path):
            try:
                with open(rag_worldview_path, 'r', encoding='utf-8') as f:
                    rag_data = json.load(f)
                # Extract worldview information from the RAG data
                metadata = rag_data.get("metadata", {})
                self.worldview_data = {
                    "worldview_name": "Cthulhu Mythos worldview (RAG)",
                    "source": metadata.get("source_file", "unknown"),
                    "description": f"RAG knowledge base built from {metadata.get('source_file', 'a PDF document')}",
                    "total_chunks": metadata.get("total_chunks", 0),
                    "total_concepts": metadata.get("total_concepts", 0),
                    "rag_enabled": True
                }
                # Keep the RAG chunks for retrieval
                self.rag_chunks = rag_data.get("chunks", [])
                print(f"✓ RAG worldview loaded: {self.worldview_data['worldview_name']}")
                print(f"  - chunks: {self.worldview_data['total_chunks']}")
                print(f"  - concepts: {self.worldview_data['total_concepts']}")
            except Exception as e:
                print(f"✗ Failed to load RAG worldview: {e}")
                self.rag_chunks = []
        # Fall back to a traditional worldview file when no RAG knowledge base exists
        if not self.rag_chunks:
            worldview_files = [f for f in os.listdir(self.knowledge_dir)
                               if f.startswith('worldview') and f.endswith('.json')]
            if worldview_files:
                worldview_path = os.path.join(self.knowledge_dir, worldview_files[0])
                with open(worldview_path, 'r', encoding='utf-8') as f:
                    self.worldview_data = json.load(f)
                print(f"✓ Traditional worldview loaded: {self.worldview_data.get('worldview_name', 'unknown')}")
        # Load character data
        character_files = [f for f in os.listdir(self.knowledge_dir)
                           if f.startswith('character') and f.endswith('.json')]
        for char_file in character_files:
            char_path = os.path.join(self.knowledge_dir, char_file)
            with open(char_path, 'r', encoding='utf-8') as f:
                char_data = json.load(f)
            char_name = char_data.get('character_name', char_file)
            self.character_data[char_name] = char_data
        print(f"✓ Characters loaded: {list(self.character_data.keys())}")
        # Build the searchable text chunks
        self._build_searchable_chunks()
        # Build the vector index
        if EMBEDDING_AVAILABLE and self.embedding_model:
            self._build_vector_index()

    def _build_searchable_chunks(self):
        """Build the searchable text chunks."""
        self.chunks = []
        # Prefer the chunks from the RAG knowledge base
        if self.rag_chunks:
            for rag_chunk in self.rag_chunks:
                self.chunks.append({
                    "type": "worldview_rag",
                    "section": "rag_knowledge",
                    "subsection": rag_chunk.get("type", "unknown"),
                    "content": rag_chunk.get("content", ""),
                    "metadata": {
                        "source": "rag_worldview",
                        "chunk_id": rag_chunk.get("id", ""),
                        "size": rag_chunk.get("size", 0),
                        "hash": rag_chunk.get("hash", "")
                    }
                })
            print(f"✓ Using RAG knowledge base chunks: {len(self.rag_chunks)}")
        else:
            # Traditional worldview chunks
            if self.worldview_data:
                for section_key, section_data in self.worldview_data.items():
                    if isinstance(section_data, dict):
                        for sub_key, sub_data in section_data.items():
                            if isinstance(sub_data, (str, list)):
                                content = str(sub_data)
                                if len(content) > 50:  # keep only meaningful text
                                    self.chunks.append({
                                        "type": "worldview",
                                        "section": section_key,
                                        "subsection": sub_key,
                                        "content": content,
                                        "metadata": {"source": "worldview"}
                                    })
        # Character chunks
        for char_name, char_data in self.character_data.items():
            for section_key, section_data in char_data.items():
                if isinstance(section_data, dict):
                    for sub_key, sub_data in section_data.items():
                        if isinstance(sub_data, (str, list)):
                            content = str(sub_data)
                            if len(content) > 30:
                                self.chunks.append({
                                    "type": "character",
                                    "character": char_name,
                                    "section": section_key,
                                    "subsection": sub_key,
                                    "content": content,
                                    "metadata": {"source": char_name}
                                })
        print(f"✓ Built text chunks: {len(self.chunks)}")

    def _build_vector_index(self):
        """Build the vector index."""
        try:
            # Prefer the prebuilt vector index from the RAG knowledge base
            rag_vector_path = "./rag_knowledge/vector_index.faiss"
            rag_embeddings_path = "./rag_knowledge/embeddings.npy"
            if os.path.exists(rag_vector_path) and os.path.exists(rag_embeddings_path):
                # Load the prebuilt index and its embeddings
                self.index = faiss.read_index(rag_vector_path)
                self.rag_embeddings = np.load(rag_embeddings_path)
                print(f"✓ Using prebuilt RAG vector index: {self.index.ntotal} vectors")
                return
            # Otherwise build the index from scratch
            texts = [chunk["content"] for chunk in self.chunks]
            embeddings = self.embedding_model.encode(texts)
            dimension = embeddings.shape[1]
            # IndexFlatL2 performs exact (brute-force) L2 search
            self.index = faiss.IndexFlatL2(dimension)
            self.index.add(embeddings.astype(np.float32))
            print(f"✓ Vector index built: {dimension} dims, {len(texts)} vectors")
        except Exception as e:
            print(f"✗ Failed to build vector index: {e}")

    def search_relevant_context(self, query: str, character_name: str = None, top_k: int = 3) -> List[Dict]:
        """Search for context relevant to the query."""
        relevant_chunks = []
        # Vector search (the prebuilt RAG index and a freshly built one are
        # queried identically, so a single code path serves both)
        if EMBEDDING_AVAILABLE and self.embedding_model and self.index:
            try:
                query_vector = self.embedding_model.encode([query])
                distances, indices = self.index.search(query_vector.astype(np.float32), top_k * 2)
                for distance, idx in zip(distances[0], indices[0]):
                    if idx < len(self.chunks):
                        chunk = self.chunks[idx].copy()
                        # Map L2 distance to a (0, 1] similarity score
                        chunk["relevance_score"] = float(1 / (1 + distance))
                        relevant_chunks.append(chunk)
            except Exception as e:
                print(f"Vector search failed: {e}")
        # Keyword search as a fallback
        if not relevant_chunks:
            query_lower = query.lower()
            for chunk in self.chunks:
                content_lower = chunk["content"].lower()
                score = 0
                for word in query_lower.split():
                    if word in content_lower:
                        score += content_lower.count(word)
                if score > 0:
                    chunk_copy = chunk.copy()
                    chunk_copy["relevance_score"] = score
                    relevant_chunks.append(chunk_copy)
        # Sort by relevance
        relevant_chunks.sort(key=lambda x: x["relevance_score"], reverse=True)
        # Put chunks belonging to the requested character first
        if character_name:
            char_chunks = [c for c in relevant_chunks if c.get("character") == character_name]
            other_chunks = [c for c in relevant_chunks if c.get("character") != character_name]
            relevant_chunks = char_chunks + other_chunks
        return relevant_chunks[:top_k]
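
# Minimal usage sketch (the directory and the query are illustrative):
#
#     kb = RAGKnowledgeBase("./knowledge")
#     for hit in kb.search_relevant_context("ancient gods", top_k=3):
#         print(hit["relevance_score"], hit["content"][:80])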


class ConversationManager:
    """Conversation manager."""

    def __init__(self, db_path: str = "conversation_history.db"):
        self.db_path = db_path
        self._init_database()

    def _init_database(self):
        """Initialize the conversation history database."""
        with sqlite3.connect(self.db_path) as conn:
            conn.execute('''
                CREATE TABLE IF NOT EXISTS conversations (
                    session_id TEXT PRIMARY KEY,
                    characters TEXT,
                    worldview TEXT,
                    start_time TEXT,
                    last_update TEXT,
                    metadata TEXT
                )
            ''')
            conn.execute('''
                CREATE TABLE IF NOT EXISTS dialogue_turns (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    session_id TEXT,
                    turn_number INTEGER,
                    speaker TEXT,
                    content TEXT,
                    timestamp TEXT,
                    context_used TEXT,
                    relevance_score REAL,
                    dialogue_score REAL DEFAULT 0.0,
                    score_details TEXT,
                    score_feedback TEXT,
                    FOREIGN KEY (session_id) REFERENCES conversations (session_id)
                )
            ''')
            conn.commit()

    def create_session(self, characters: List[str], worldview: str) -> str:
        """Create a new conversation session."""
        # Derive a short session id from the characters and the current time
        session_id = hashlib.md5(f"{'-'.join(characters)}-{datetime.now().isoformat()}".encode()).hexdigest()[:12]
        with sqlite3.connect(self.db_path) as conn:
            conn.execute(
                "INSERT INTO conversations (session_id, characters, worldview, start_time, last_update) VALUES (?, ?, ?, ?, ?)",
                (session_id, json.dumps(characters), worldview, datetime.now().isoformat(), datetime.now().isoformat())
            )
            conn.commit()
        print(f"✓ Conversation session created: {session_id}")
        return session_id

    def add_dialogue_turn(self, session_id: str, speaker: str, content: str, context_used: List[str] = None,
                          relevance_score: float = 0.0, dialogue_score: float = 0.0,
                          score_details: str = None, score_feedback: str = None):
        """Add a dialogue turn."""
        if context_used is None:
            context_used = []
        with sqlite3.connect(self.db_path) as conn:
            # Determine the current turn number
            cursor = conn.execute("SELECT COUNT(*) FROM dialogue_turns WHERE session_id = ?", (session_id,))
            turn_number = cursor.fetchone()[0] + 1
            # Insert the dialogue turn
            conn.execute(
                """INSERT INTO dialogue_turns
                   (session_id, turn_number, speaker, content, timestamp, context_used, relevance_score,
                    dialogue_score, score_details, score_feedback)
                   VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
                (session_id, turn_number, speaker, content, datetime.now().isoformat(),
                 json.dumps(context_used), relevance_score, dialogue_score, score_details, score_feedback)
            )
            # Refresh the session's last-update timestamp
            conn.execute(
                "UPDATE conversations SET last_update = ? WHERE session_id = ?",
                (datetime.now().isoformat(), session_id)
            )
            conn.commit()

    def get_conversation_history(self, session_id: str, last_n: int = 10) -> List[DialogueTurn]:
        """Fetch the most recent dialogue turns."""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.execute(
                """SELECT speaker, content, timestamp, context_used, relevance_score, dialogue_score, score_feedback
                   FROM dialogue_turns
                   WHERE session_id = ?
                   ORDER BY turn_number DESC LIMIT ?""",
                (session_id, last_n)
            )
            turns = []
            for row in cursor.fetchall():
                speaker, content, timestamp, context_used, relevance_score, dialogue_score, score_feedback = row
                turn = DialogueTurn(
                    speaker=speaker,
                    content=content,
                    timestamp=timestamp,
                    context_used=json.loads(context_used or "[]"),
                    relevance_score=relevance_score
                )
                # Attach scoring info as dynamic attributes on the turn
                if dialogue_score:
                    turn.dialogue_score = dialogue_score
                if score_feedback:
                    turn.score_feedback = score_feedback
                turns.append(turn)
        return list(reversed(turns))  # return in chronological order

    def list_sessions(self) -> List[Dict]:
        """List all conversation sessions."""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.execute(
                "SELECT session_id, characters, worldview, start_time, last_update FROM conversations ORDER BY last_update DESC"
            )
            sessions = []
            for row in cursor.fetchall():
                session_id, characters, worldview, start_time, last_update = row
                sessions.append({
                    "session_id": session_id,
                    "characters": json.loads(characters),
                    "worldview": worldview,
                    "start_time": start_time,
                    "last_update": last_update
                })
        return sessions
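
# Minimal usage sketch (speaker names and content are placeholders):
#
#     mgr = ConversationManager()
#     sid = mgr.create_session(["Alice", "Bob"], "Cthulhu Mythos")
#     mgr.add_dialogue_turn(sid, "Alice", "Hello there.", relevance_score=0.8)
#     for turn in mgr.get_conversation_history(sid, last_n=5):
#         print(turn.speaker, turn.content)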


class DualAIDialogueEngine:
    """Dual-AI dialogue engine."""

    def __init__(self, knowledge_base: RAGKnowledgeBase, conversation_manager: ConversationManager, llm_generator,
                 enable_scoring: bool = True, base_model_path: str = None):
        self.kb = knowledge_base
        self.conv_mgr = conversation_manager
        self.llm_generator = llm_generator
        self.enable_scoring = enable_scoring
        self.scorer = None
        # Initialize the scorer
        if enable_scoring and base_model_path:
            try:
                from dialogue_scorer import DialogueAIScorer
                print("Initializing the dialogue scoring system...")
                self.scorer = DialogueAIScorer(
                    base_model_path=base_model_path,
                    tokenizer=getattr(llm_generator, 'tokenizer', None),
                    model=getattr(llm_generator, 'model', None)
                )
                print("✓ Dialogue scoring system initialized")
            except Exception as e:
                print(f"⚠ Failed to initialize the dialogue scoring system: {e}")
                self.enable_scoring = False

    def score_dialogue_turn(self, dialogue_content: str, speaker: str, dialogue_history: List[DialogueTurn]) -> Tuple[float, str, str]:
        """Score a single dialogue turn.

        Args:
            dialogue_content: the dialogue text
            speaker: the speaking character
            dialogue_history: previous dialogue turns

        Returns:
            tuple: (overall score, detailed scores as JSON, feedback)
        """
        if not self.enable_scoring or not self.scorer:
            return 0.0, "{}", "Scoring system not enabled"
        try:
            # Fetch the character profile
            character_data = self.kb.character_data.get(speaker, {})
            # Convert the dialogue history into the scorer's format
            history_for_scoring = []
            for turn in dialogue_history[-5:]:  # last 5 turns
                history_for_scoring.append({
                    'speaker': turn.speaker,
                    'content': turn.content
                })
            # Run the AI scoring
            score_result = self.scorer.score_dialogue(
                dialogue_content=dialogue_content,
                speaker=speaker,
                character_data=character_data,
                dialogue_history=history_for_scoring,
                context_info=[]
            )
            return score_result.overall_score, json.dumps(score_result.scores), score_result.feedback
        except Exception as e:
            print(f"⚠ Dialogue scoring failed: {e}")
            return 0.0, "{}", f"Scoring failed: {str(e)}"

    def run_dual_model_conversation(self, session_id: str, topic: str = "", turns: int = 4,
                                    history_context_count: int = 3, context_info_count: int = 2):
        """Run a conversation with the dual-model system.

        Args:
            session_id: the session id
            topic: the conversation topic
            turns: number of dialogue rounds
            history_context_count: how many past turns to use as context
            context_info_count: how many context snippets to use
        """
        # Check whether the generator supports dual-model conversation
        if not hasattr(self.llm_generator, 'run_dual_character_conversation'):
            print("⚠ The current system does not support dual-model conversation")
            return self.run_conversation_turn(session_id, self.llm_generator.list_characters(), turns, topic,
                                              history_context_count, context_info_count)
        # Fetch the dialogue history
        dialogue_history = self.conv_mgr.get_conversation_history(session_id)
        # Build the retrieval query
        if dialogue_history:
            recent_turns = dialogue_history[-history_context_count:] if history_context_count > 0 else []
            recent_content = " ".join([turn.content for turn in recent_turns])
            search_query = recent_content + " " + topic
        else:
            search_query = f"{topic} introduction greeting"
        # Retrieve relevant context
        context_info = self.kb.search_relevant_context(search_query, top_k=context_info_count)
        # Assemble the context string
        context_str = ""
        if context_info:
            context_str = "Relevant background information:"
            for info in context_info[:context_info_count]:
                content = info['content'][:150] + "..." if len(info['content']) > 150 else info['content']
                context_str += f"\n- {content}"
        print("\n=== Dual-model dialogue system ===")
        print(f"Topic: {topic}")
        print(f"Characters: {', '.join(self.llm_generator.list_characters())}")
        print(f"Rounds: {turns}")
        print(f"Context settings: {history_context_count} history turns, {context_info_count} snippets")
        # Generate dialogue with the dual-model system
        conversation_results = []  # guards the return below when turns == 0
        for turn in range(turns):
            # Refresh the dialogue history each round
            dialogue_history = self.conv_mgr.get_conversation_history(session_id)
            conversation_results = self.llm_generator.run_dual_character_conversation(
                topic=topic,
                turn_index=turn,
                context=context_str,
                dialogue_history=dialogue_history,
                history_context_count=history_context_count,
                max_new_tokens=150
            )
            # Persist each result to the database and score it
            for result in conversation_results:
                # Fetch the current history for scoring
                current_dialogue_history = self.conv_mgr.get_conversation_history(session_id)
                # Score the dialogue turn
                if self.enable_scoring:
                    dialogue_score, score_details, score_feedback = self.score_dialogue_turn(
                        result['dialogue'], result['speaker'], current_dialogue_history
                    )
                    print(f"  [score: {dialogue_score:.2f}] {score_feedback[:100]}...")
                else:
                    dialogue_score, score_details, score_feedback = 0.0, "{}", ""
                self.conv_mgr.add_dialogue_turn(
                    session_id,
                    result['speaker'],
                    result['dialogue'],
                    [result.get('context_used', '')],
                    0.8,  # default relevance score
                    dialogue_score,
                    score_details,
                    score_feedback
                )
        return conversation_results
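
    # End-to-end wiring sketch (llm_generator stands for any object exposing
    # run_dual_character_conversation() and list_characters(); the model path
    # is an assumption):
    #
    #     kb = RAGKnowledgeBase("./knowledge")
    #     mgr = ConversationManager()
    #     engine = DualAIDialogueEngine(kb, mgr, llm_generator,
    #                                   enable_scoring=True,
    #                                   base_model_path="./models/base")
    #     sid = mgr.create_session(llm_generator.list_characters(), "Cthulhu Mythos")
    #     engine.run_dual_model_conversation(sid, topic="a strange discovery", turns=4)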