#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
RAG-augmented character dialogue system.
Integrates a worldview knowledge base, supports loading character profiles
and generating dialogue.
'''
import json
import os
import sqlite3
from datetime import datetime
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass, asdict
import hashlib

# Try to import the embedding / vector-search dependencies
try:
    from sentence_transformers import SentenceTransformer
    import faiss
    import numpy as np
    EMBEDDING_AVAILABLE = True
except ImportError:
    EMBEDDING_AVAILABLE = False

@dataclass
class DialogueTurn:
    """A single turn of dialogue."""
    speaker: str
    content: str
    timestamp: str
    context_used: List[str]  # context snippets used for this turn
    relevance_score: float = 0.0


@dataclass
class ConversationSession:
    """A conversation session."""
    session_id: str
    characters: List[str]
    worldview: str
    start_time: str
    last_update: str
    dialogue_history: List[DialogueTurn]
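

# A minimal construction sketch for the data structures above (values are
# illustrative only):
#
#     turn = DialogueTurn(
#         speaker="Armitage",
#         content="The library holds more than books.",
#         timestamp=datetime.now().isoformat(),
#         context_used=["worldview:library"],
#     )
#     print(asdict(turn))  # dataclass turns serialize cleanly for storage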

class RAGKnowledgeBase:
    """RAG knowledge-base manager."""

    def __init__(self, knowledge_dir: str):
        self.knowledge_dir = knowledge_dir
        self.worldview_data = None
        self.character_data = {}
        self.chunks = []
        self.embedding_model = None
        self.index = None
        # Initialize the embedding model
        if EMBEDDING_AVAILABLE:
            try:
                self.embedding_model = SentenceTransformer('./sentence-transformers/all-MiniLM-L6-v2')
                print("✓ Embedding model loaded")
            except Exception as e:
                print(f"✗ Failed to load embedding model: {e}")
        self._load_knowledge_base()
    def _load_knowledge_base(self):
        """Load the knowledge base."""
        # Prefer the RAG knowledge base as the worldview source
        rag_worldview_path = "./rag_knowledge/knowledge_base.json"
        if os.path.exists(rag_worldview_path):
            try:
                with open(rag_worldview_path, 'r', encoding='utf-8') as f:
                    rag_data = json.load(f)
                # Extract worldview information from the RAG metadata
                metadata = rag_data.get("metadata", {})
                self.worldview_data = {
                    "worldview_name": "Cthulhu Mythos worldview (RAG)",
                    "source": metadata.get("source_file", "unknown"),
                    "description": f"RAG knowledge base built from {metadata.get('source_file', 'a PDF document')}",
                    "total_chunks": metadata.get("total_chunks", 0),
                    "total_concepts": metadata.get("total_concepts", 0),
                    "rag_enabled": True
                }
                # Keep the raw RAG chunks for retrieval
                self.rag_chunks = rag_data.get("chunks", [])
                print(f"✓ RAG worldview loaded: {self.worldview_data['worldview_name']}")
                print(f"  - chunks: {self.worldview_data['total_chunks']}")
                print(f"  - concepts: {self.worldview_data['total_concepts']}")
            except Exception as e:
                print(f"✗ Failed to load RAG worldview: {e}")
                self.rag_chunks = []
        # Fall back to a traditional worldview file when no RAG knowledge base exists
        if not hasattr(self, 'rag_chunks') or not self.rag_chunks:
            worldview_files = [f for f in os.listdir(self.knowledge_dir)
                               if f.startswith('worldview') and f.endswith('.json')]
            if worldview_files:
                worldview_path = os.path.join(self.knowledge_dir, worldview_files[0])
                with open(worldview_path, 'r', encoding='utf-8') as f:
                    self.worldview_data = json.load(f)
                print(f"✓ Traditional worldview loaded: {self.worldview_data.get('worldview_name', 'unknown')}")
        # Load character data
        character_files = [f for f in os.listdir(self.knowledge_dir)
                           if f.startswith('character') and f.endswith('.json')]
        for char_file in character_files:
            char_path = os.path.join(self.knowledge_dir, char_file)
            with open(char_path, 'r', encoding='utf-8') as f:
                char_data = json.load(f)
            char_name = char_data.get('character_name', char_file)
            self.character_data[char_name] = char_data
        print(f"✓ Characters loaded: {list(self.character_data.keys())}")
        # Build the searchable text chunks
        self._build_searchable_chunks()
        # Build the vector index
        if EMBEDDING_AVAILABLE and self.embedding_model:
            self._build_vector_index()
    def _build_searchable_chunks(self):
        """Build the searchable text chunks."""
        self.chunks = []
        # Prefer chunks from the RAG knowledge base
        if hasattr(self, 'rag_chunks') and self.rag_chunks:
            for rag_chunk in self.rag_chunks:
                self.chunks.append({
                    "type": "worldview_rag",
                    "section": "rag_knowledge",
                    "subsection": rag_chunk.get("type", "unknown"),
                    "content": rag_chunk.get("content", ""),
                    "metadata": {
                        "source": "rag_worldview",
                        "chunk_id": rag_chunk.get("id", ""),
                        "size": rag_chunk.get("size", 0),
                        "hash": rag_chunk.get("hash", "")
                    }
                })
            print(f"✓ Using RAG knowledge-base chunks: {len(self.rag_chunks)}")
        else:
            # Chunks from the traditional worldview data
            if self.worldview_data:
                for section_key, section_data in self.worldview_data.items():
                    if isinstance(section_data, dict):
                        for sub_key, sub_data in section_data.items():
                            if isinstance(sub_data, (str, list)):
                                content = str(sub_data)
                                if len(content) > 50:  # keep only meaningful text
                                    self.chunks.append({
                                        "type": "worldview",
                                        "section": section_key,
                                        "subsection": sub_key,
                                        "content": content,
                                        "metadata": {"source": "worldview"}
                                    })
        # Character-related chunks
        for char_name, char_data in self.character_data.items():
            for section_key, section_data in char_data.items():
                if isinstance(section_data, dict):
                    for sub_key, sub_data in section_data.items():
                        if isinstance(sub_data, (str, list)):
                            content = str(sub_data)
                            if len(content) > 30:
                                self.chunks.append({
                                    "type": "character",
                                    "character": char_name,
                                    "section": section_key,
                                    "subsection": sub_key,
                                    "content": content,
                                    "metadata": {"source": char_name}
                                })
        print(f"✓ Built text chunks: {len(self.chunks)}")
    def _build_vector_index(self):
        """Build the vector index."""
        try:
            # Prefer the prebuilt vector index shipped with the RAG knowledge base
            rag_vector_path = "./rag_knowledge/vector_index.faiss"
            rag_embeddings_path = "./rag_knowledge/embeddings.npy"
            if os.path.exists(rag_vector_path) and os.path.exists(rag_embeddings_path):
                # Load the prebuilt index and embeddings
                self.index = faiss.read_index(rag_vector_path)
                self.rag_embeddings = np.load(rag_embeddings_path)
                print(f"✓ Using prebuilt RAG vector index: {self.index.ntotal} vectors")
                return
            # Otherwise build the index from the chunks
            texts = [chunk["content"] for chunk in self.chunks]
            embeddings = self.embedding_model.encode(texts)
            dimension = embeddings.shape[1]
            self.index = faiss.IndexFlatL2(dimension)
            self.index.add(embeddings.astype(np.float32))
            print(f"✓ Vector index built: {dimension} dims, {len(texts)} vectors")
        except Exception as e:
            print(f"✗ Failed to build vector index: {e}")
    def search_relevant_context(self, query: str, character_name: str = None, top_k: int = 3) -> List[Dict]:
        """Search for context relevant to the query."""
        relevant_chunks = []
        # Vector search (the same code path serves both a prebuilt RAG index
        # and a locally built one)
        if EMBEDDING_AVAILABLE and self.embedding_model and self.index:
            try:
                query_vector = self.embedding_model.encode([query])
                distances, indices = self.index.search(query_vector.astype(np.float32), top_k * 2)
                for distance, idx in zip(distances[0], indices[0]):
                    if idx < len(self.chunks):
                        chunk = self.chunks[idx].copy()
                        # Convert L2 distance into a similarity-style score in (0, 1]
                        chunk["relevance_score"] = float(1 / (1 + distance))
                        relevant_chunks.append(chunk)
            except Exception as e:
                print(f"Vector search failed: {e}")
        # Keyword search as a fallback
        if not relevant_chunks:
            query_lower = query.lower()
            for chunk in self.chunks:
                content_lower = chunk["content"].lower()
                score = 0
                for word in query_lower.split():
                    if word in content_lower:
                        score += content_lower.count(word)
                if score > 0:
                    chunk_copy = chunk.copy()
                    chunk_copy["relevance_score"] = score
                    relevant_chunks.append(chunk_copy)
        # Sort by relevance
        relevant_chunks.sort(key=lambda x: x["relevance_score"], reverse=True)
        # Put chunks for the requested character first
        if character_name:
            char_chunks = [c for c in relevant_chunks if c.get("character") == character_name]
            other_chunks = [c for c in relevant_chunks if c.get("character") != character_name]
            relevant_chunks = char_chunks + other_chunks
        return relevant_chunks[:top_k]
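

# A minimal retrieval sketch (assumes ./knowledge holds worldview*.json and
# character*.json files; the query and character name are illustrative only):
#
#     kb = RAGKnowledgeBase("./knowledge")
#     for hit in kb.search_relevant_context("forbidden library", character_name="Armitage"):
#         print(f"{hit['relevance_score']:.3f}  {hit['content'][:80]}")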

class ConversationManager:
    """Conversation manager."""

    def __init__(self, db_path: str = "conversation_history.db"):
        self.db_path = db_path
        self._init_database()

    def _init_database(self):
        """Initialize the conversation-history database."""
        with sqlite3.connect(self.db_path) as conn:
            conn.execute('''
                CREATE TABLE IF NOT EXISTS conversations (
                    session_id TEXT PRIMARY KEY,
                    characters TEXT,
                    worldview TEXT,
                    start_time TEXT,
                    last_update TEXT,
                    metadata TEXT
                )
            ''')
            conn.execute('''
                CREATE TABLE IF NOT EXISTS dialogue_turns (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    session_id TEXT,
                    turn_number INTEGER,
                    speaker TEXT,
                    content TEXT,
                    timestamp TEXT,
                    context_used TEXT,
                    relevance_score REAL,
                    dialogue_score REAL DEFAULT 0.0,
                    score_details TEXT,
                    score_feedback TEXT,
                    FOREIGN KEY (session_id) REFERENCES conversations (session_id)
                )
            ''')
            conn.commit()
    def create_session(self, characters: List[str], worldview: str) -> str:
        """Create a new conversation session."""
        session_id = hashlib.md5(
            f"{'-'.join(characters)}-{datetime.now().isoformat()}".encode()
        ).hexdigest()[:12]
        with sqlite3.connect(self.db_path) as conn:
            conn.execute(
                "INSERT INTO conversations (session_id, characters, worldview, start_time, last_update) VALUES (?, ?, ?, ?, ?)",
                (session_id, json.dumps(characters), worldview, datetime.now().isoformat(), datetime.now().isoformat())
            )
            conn.commit()
        print(f"✓ Created conversation session: {session_id}")
        return session_id
    def add_dialogue_turn(self, session_id: str, speaker: str, content: str, context_used: List[str] = None,
                          relevance_score: float = 0.0, dialogue_score: float = 0.0,
                          score_details: str = None, score_feedback: str = None):
        """Append a dialogue turn to a session."""
        if context_used is None:
            context_used = []
        with sqlite3.connect(self.db_path) as conn:
            # Determine the next turn number
            cursor = conn.execute("SELECT COUNT(*) FROM dialogue_turns WHERE session_id = ?", (session_id,))
            turn_number = cursor.fetchone()[0] + 1
            # Insert the turn
            conn.execute(
                """INSERT INTO dialogue_turns
                   (session_id, turn_number, speaker, content, timestamp, context_used, relevance_score,
                    dialogue_score, score_details, score_feedback)
                   VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
                (session_id, turn_number, speaker, content, datetime.now().isoformat(),
                 json.dumps(context_used), relevance_score, dialogue_score, score_details, score_feedback)
            )
            # Touch the session's last-update timestamp
            conn.execute(
                "UPDATE conversations SET last_update = ? WHERE session_id = ?",
                (datetime.now().isoformat(), session_id)
            )
            conn.commit()
    def get_conversation_history(self, session_id: str, last_n: int = 10) -> List[DialogueTurn]:
        """Fetch the most recent dialogue turns for a session."""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.execute(
                """SELECT speaker, content, timestamp, context_used, relevance_score, dialogue_score, score_feedback
                   FROM dialogue_turns
                   WHERE session_id = ?
                   ORDER BY turn_number DESC LIMIT ?""",
                (session_id, last_n)
            )
            turns = []
            for row in cursor.fetchall():
                speaker, content, timestamp, context_used, relevance_score, dialogue_score, score_feedback = row
                turn = DialogueTurn(
                    speaker=speaker,
                    content=content,
                    timestamp=timestamp,
                    context_used=json.loads(context_used or "[]"),
                    relevance_score=relevance_score
                )
                # Attach scoring info as extra attributes on the turn
                if dialogue_score:
                    turn.dialogue_score = dialogue_score
                if score_feedback:
                    turn.score_feedback = score_feedback
                turns.append(turn)
            return list(reversed(turns))  # return in chronological order
    def list_sessions(self) -> List[Dict]:
        """List all conversation sessions."""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.execute(
                "SELECT session_id, characters, worldview, start_time, last_update FROM conversations ORDER BY last_update DESC"
            )
            sessions = []
            for row in cursor.fetchall():
                session_id, characters, worldview, start_time, last_update = row
                sessions.append({
                    "session_id": session_id,
                    "characters": json.loads(characters),
                    "worldview": worldview,
                    "start_time": start_time,
                    "last_update": last_update
                })
            return sessions
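

# A minimal persistence sketch (illustrative only; the DB path and names are
# placeholders):
#
#     mgr = ConversationManager("demo_history.db")
#     sid = mgr.create_session(["Armitage", "Wilmarth"], "Cthulhu Mythos worldview (RAG)")
#     mgr.add_dialogue_turn(sid, "Armitage", "Have you read the sealed folios?")
#     for turn in mgr.get_conversation_history(sid):
#         print(turn.speaker, "->", turn.content)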

class DualAIDialogueEngine:
    """Dual-AI dialogue engine."""

    def __init__(self, knowledge_base: RAGKnowledgeBase, conversation_manager: ConversationManager, llm_generator,
                 enable_scoring: bool = True, base_model_path: str = None, use_manual_scoring: bool = False):
        self.kb = knowledge_base
        self.conv_mgr = conversation_manager
        self.llm_generator = llm_generator
        self.enable_scoring = enable_scoring
        self.use_manual_scoring = use_manual_scoring
        self.scorer = None
        # Initialize the automatic scorer
        if enable_scoring and base_model_path and not use_manual_scoring:
            try:
                from dialogue_scorer import DialogueAIScorer
                print("Initializing the dialogue scoring system...")
                self.scorer = DialogueAIScorer(
                    base_model_path=base_model_path,
                    tokenizer=getattr(llm_generator, 'tokenizer', None),
                    model=getattr(llm_generator, 'model', None)
                )
                print("✓ Dialogue scoring system initialized")
            except Exception as e:
                print(f"⚠ Failed to initialize dialogue scoring: {e}")
                self.enable_scoring = False
    def _manual_score_dialogue_turn(self, dialogue_content: str, speaker: str, dialogue_history: List[DialogueTurn]) -> Tuple[float, str, str]:
        """Manually score a dialogue turn.

        Args:
            dialogue_content: the dialogue text
            speaker: the speaker
            dialogue_history: prior dialogue turns

        Returns:
            tuple: (overall score, detailed scores as JSON, feedback)
        """
        print("\n" + "=" * 60)
        print("Manual dialogue scoring")
        print("=" * 60)
        # print(f"Speaker: {speaker}")
        # print(f"Content: {dialogue_content}")
        print("-" * 40)
        # # Show the most recent history for reference
        # if dialogue_history:
        #     print("Recent dialogue history:")
        #     for i, turn in enumerate(dialogue_history[-3:], 1):
        #         print(f"  {i}. {turn.speaker}: {turn.content[:100]}...")
        #     print("-" * 40)
        # The five scoring dimensions
        dimensions = {
            'coherence': 'Logical coherence (1-10)',
            'character_consistency': 'Character consistency (1-10)',
            'naturalness': 'Naturalness and fluency (1-10)',
            'information_density': 'Information density (1-10)',
            'creativity': 'Creativity and novelty (1-10)'
        }
        scores = {}
        print("\nScore each dimension (enter 1-10, or press Enter to skip it and use the default):")
        for key, desc in dimensions.items():
            while True:
                try:
                    score_input = input(f"{desc}: ").strip()
                    if score_input == "":
                        scores[key] = 7.0  # default score
                        break
                    score = float(score_input)
                    if 1 <= score <= 10:
                        scores[key] = score
                        break
                    else:
                        print("Please enter a score between 1 and 10")
                except ValueError:
                    print("Please enter a valid number")
        # Overall score is the mean of the dimension scores
        overall_score = sum(scores.values()) / len(scores)
        # Collect free-form feedback
        print("\nEnter comments or suggestions for this turn (optional, press Enter to skip):")
        feedback = input("Feedback: ").strip()
        if not feedback:
            feedback = f"Manual scoring complete, overall: {overall_score:.1f}"
        print(f"\n✓ Scoring complete - overall: {overall_score:.1f}")
        print("=" * 60)
        return overall_score, json.dumps(scores), feedback
    def score_dialogue_turn(self, dialogue_content: str, speaker: str, dialogue_history: List[DialogueTurn]) -> Tuple[float, str, str]:
        """Score a single dialogue turn.

        Args:
            dialogue_content: the dialogue text
            speaker: the speaker
            dialogue_history: prior dialogue turns

        Returns:
            tuple: (overall score, detailed scores as JSON, feedback)
        """
        if not self.enable_scoring:
            return 0.0, "{}", "Scoring is disabled"
        # Manual scoring mode
        if self.use_manual_scoring:
            return self._manual_score_dialogue_turn(dialogue_content, speaker, dialogue_history)
        # Automatic (AI) scoring mode
        if not self.scorer:
            return 0.0, "{}", "AI scorer not initialized"
        try:
            # Look up the speaker's character profile
            character_data = self.kb.character_data.get(speaker, {})
            # Convert the history to the scorer's expected format
            history_for_scoring = []
            for turn in dialogue_history[-5:]:  # last 5 turns
                history_for_scoring.append({
                    'speaker': turn.speaker,
                    'content': turn.content
                })
            # Run the AI scorer
            score_result = self.scorer.score_dialogue(
                dialogue_content=dialogue_content,
                speaker=speaker,
                character_data=character_data,
                dialogue_history=history_for_scoring,
                context_info=[]
            )
            return score_result.overall_score, json.dumps(score_result.scores), score_result.feedback
        except Exception as e:
            print(f"⚠ Dialogue scoring failed: {e}")
            return 0.0, "{}", f"Scoring failed: {str(e)}"
    def run_dual_model_conversation(self, session_id: str, topic: str = "", turns: int = 4,
                                    history_context_count: int = 3, context_info_count: int = 2):
        """Run a conversation with the dual-model system.

        Args:
            session_id: session ID
            topic: conversation topic
            turns: number of dialogue rounds
            history_context_count: number of history turns to include
            context_info_count: number of context snippets to include
        """
        # Check whether the generator supports dual-model conversation
        if not hasattr(self.llm_generator, 'run_dual_character_conversation'):
            print("⚠ The current system does not support dual-model conversation")
            return self.run_conversation_turn(session_id, self.llm_generator.list_characters(), turns, topic,
                                              history_context_count, context_info_count)
        # Fetch the dialogue history
        dialogue_history = self.conv_mgr.get_conversation_history(session_id)
        # Build the retrieval query
        if dialogue_history:
            recent_turns = dialogue_history[-history_context_count:] if history_context_count > 0 else []
            recent_content = " ".join([turn.content for turn in recent_turns])
            search_query = recent_content + " " + topic
        else:
            search_query = f"{topic} introduction greeting"
        # Retrieve relevant context
        context_info = self.kb.search_relevant_context(search_query, top_k=context_info_count)
        # Format the context string
        context_str = ""
        if context_info:
            context_str = "Relevant background information:"
            for info in context_info[:context_info_count]:
                content = info['content'][:150] + "..." if len(info['content']) > 150 else info['content']
                context_str += f"\n- {content}"
        print(f"\n=== Dual-model dialogue system ===")
        print(f"Topic: {topic}")
        print(f"Characters: {', '.join(self.llm_generator.list_characters())}")
        print(f"Rounds: {turns}")
        print(f"Context settings: {history_context_count} history turns, {context_info_count} snippets")
        # Generate dialogue with the dual-model system
        for turn in range(turns):
            # Refresh the dialogue history each round
            dialogue_history = self.conv_mgr.get_conversation_history(session_id)
            conversation_results = self.llm_generator.run_dual_character_conversation(
                topic=topic,
                turn_index=turn,
                context=context_str,
                dialogue_history=dialogue_history,
                history_context_count=history_context_count,
                max_new_tokens=150
            )
            # Persist each generated turn and score it
            for result in conversation_results:
                # Score against the history as it stands
                current_dialogue_history = self.conv_mgr.get_conversation_history(session_id)
                if self.enable_scoring:
                    dialogue_score, score_details, score_feedback = self.score_dialogue_turn(
                        result['dialogue'], result['speaker'], current_dialogue_history
                    )
                    print(f"  [score: {dialogue_score:.2f}] {score_feedback[:100]}...")
                else:
                    dialogue_score, score_details, score_feedback = 0.0, "{}", ""
                self.conv_mgr.add_dialogue_turn(
                    session_id,
                    result['speaker'],
                    result['dialogue'],
                    [result.get('context_used', '')],
                    0.8,  # default relevance score
                    dialogue_score,
                    score_details,
                    score_feedback
                )
        return conversation_results
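

# An end-to-end wiring sketch (illustrative only; `llm_generator` stands in
# for whatever generator object the surrounding project provides, assumed to
# expose run_dual_character_conversation / list_characters as used above):
#
#     kb = RAGKnowledgeBase("./knowledge")
#     conv_mgr = ConversationManager()
#     engine = DualAIDialogueEngine(kb, conv_mgr, llm_generator,
#                                   enable_scoring=True, use_manual_scoring=True)
#     sid = conv_mgr.create_session(llm_generator.list_characters(),
#                                   "Cthulhu Mythos worldview (RAG)")
#     engine.run_dual_model_conversation(sid, topic="the sealed archive", turns=2)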