#!/usr/bin/env python # -*- coding: utf-8 -*- ''' RAG增强的角色对话系统 集成世界观知识库,支持角色设定加载和对话生成 ''' import json import os import sqlite3 from datetime import datetime from typing import Dict, List, Optional, Tuple from dataclasses import dataclass, asdict import hashlib # 尝试导入向量化相关库 try: from sentence_transformers import SentenceTransformer import faiss import numpy as np EMBEDDING_AVAILABLE = True except ImportError: EMBEDDING_AVAILABLE = False @dataclass class DialogueTurn: """对话轮次数据结构""" speaker: str content: str timestamp: str context_used: List[str] # 使用的上下文信息 relevance_score: float = 0.0 @dataclass class ConversationSession: """对话会话数据结构""" session_id: str characters: List[str] worldview: str start_time: str last_update: str dialogue_history: List[DialogueTurn] class RAGKnowledgeBase: """RAG知识库管理器""" def __init__(self, knowledge_dir: str): self.knowledge_dir = knowledge_dir self.worldview_data = None self.character_data = {} self.chunks = [] self.embedding_model = None self.index = None # 初始化向量模型 if EMBEDDING_AVAILABLE: try: self.embedding_model = SentenceTransformer('./sentence-transformers/all-MiniLM-L6-v2') print("✓ 向量模型加载成功") except Exception as e: print(f"✗ 向量模型加载失败: {e}") self._load_knowledge_base() def _load_knowledge_base(self): """加载知识库""" # 优先加载RAG知识库作为世界观 rag_worldview_path = "./rag_knowledge/knowledge_base.json" if os.path.exists(rag_worldview_path): try: with open(rag_worldview_path, 'r', encoding='utf-8') as f: rag_data = json.load(f) # 从RAG数据中提取世界观信息 self.worldview_data = { "worldview_name": "克苏鲁神话世界观 (RAG)", "source": rag_data.get("metadata", {}).get("source_file", "未知"), "description": f"基于{rag_data.get('metadata', {}).get('source_file', 'PDF文档')}的RAG知识库", "total_chunks": rag_data.get("metadata", {}).get("total_chunks", 0), "total_concepts": rag_data.get("metadata", {}).get("total_concepts", 0), "rag_enabled": True } # 保存RAG数据用于检索 self.rag_chunks = rag_data.get("chunks", []) print(f"✓ RAG世界观加载成功: {self.worldview_data['worldview_name']}") print(f" - 文档块数: {self.worldview_data['total_chunks']}") print(f" - 概念数: {self.worldview_data['total_concepts']}") except Exception as e: print(f"✗ RAG世界观加载失败: {e}") self.rag_chunks = [] # 如果没有RAG知识库,则加载传统世界观文件 if not hasattr(self, 'rag_chunks') or not self.rag_chunks: worldview_files = [f for f in os.listdir(self.knowledge_dir) if f.startswith('worldview') and f.endswith('.json')] if worldview_files: worldview_path = os.path.join(self.knowledge_dir, worldview_files[0]) with open(worldview_path, 'r', encoding='utf-8') as f: self.worldview_data = json.load(f) print(f"✓ 传统世界观加载成功: {self.worldview_data.get('worldview_name', '未知')}") # 加载角色数据 character_files = [f for f in os.listdir(self.knowledge_dir) if f.startswith('character') and f.endswith('.json')] for char_file in character_files: char_path = os.path.join(self.knowledge_dir, char_file) with open(char_path, 'r', encoding='utf-8') as f: char_data = json.load(f) char_name = char_data.get('character_name', char_file) self.character_data[char_name] = char_data print(f"✓ 角色加载成功: {list(self.character_data.keys())}") # 构建检索用的文本块 self._build_searchable_chunks() # 构建向量索引 if EMBEDDING_AVAILABLE and self.embedding_model: self._build_vector_index() def _build_searchable_chunks(self): """构建可检索的文本块""" self.chunks = [] # 优先使用RAG知识库的文本块 if hasattr(self, 'rag_chunks') and self.rag_chunks: for rag_chunk in self.rag_chunks: self.chunks.append({ "type": "worldview_rag", "section": "rag_knowledge", "subsection": rag_chunk.get("type", "unknown"), "content": rag_chunk.get("content", ""), "metadata": { "source": "rag_worldview", "chunk_id": rag_chunk.get("id", ""), "size": rag_chunk.get("size", 0), "hash": rag_chunk.get("hash", "") } }) print(f"✓ 使用RAG知识库文本块: {len(self.rag_chunks)} 个") else: # 传统世界观相关文本块 if self.worldview_data: for section_key, section_data in self.worldview_data.items(): if isinstance(section_data, dict): for sub_key, sub_data in section_data.items(): if isinstance(sub_data, (str, list)): content = str(sub_data) if len(content) > 50: # 只保留有意义的文本 self.chunks.append({ "type": "worldview", "section": section_key, "subsection": sub_key, "content": content, "metadata": {"source": "worldview"} }) # 角色相关文本块 for char_name, char_data in self.character_data.items(): for section_key, section_data in char_data.items(): if isinstance(section_data, dict): for sub_key, sub_data in section_data.items(): if isinstance(sub_data, (str, list)): content = str(sub_data) if len(content) > 30: self.chunks.append({ "type": "character", "character": char_name, "section": section_key, "subsection": sub_key, "content": content, "metadata": {"source": char_name} }) print(f"✓ 构建文本块: {len(self.chunks)} 个") def _build_vector_index(self): """构建向量索引""" try: # 优先使用RAG知识库的预构建向量索引 rag_vector_path = "./rag_knowledge/vector_index.faiss" rag_embeddings_path = "./rag_knowledge/embeddings.npy" if os.path.exists(rag_vector_path) and os.path.exists(rag_embeddings_path): # 加载预构建的向量索引 self.index = faiss.read_index(rag_vector_path) self.rag_embeddings = np.load(rag_embeddings_path) print(f"✓ 使用RAG预构建向量索引: {self.index.ntotal}个向量") return # 如果没有预构建的向量索引,则重新构建 texts = [chunk["content"] for chunk in self.chunks] embeddings = self.embedding_model.encode(texts) dimension = embeddings.shape[1] self.index = faiss.IndexFlatL2(dimension) self.index.add(embeddings.astype(np.float32)) print(f"✓ 向量索引构建成功: {dimension}维, {len(texts)}个向量") except Exception as e: print(f"✗ 向量索引构建失败: {e}") def search_relevant_context(self, query: str, character_name: str = None, top_k: int = 3) -> List[Dict]: """搜索相关上下文""" relevant_chunks = [] # 向量搜索 if EMBEDDING_AVAILABLE and self.embedding_model and self.index: try: # 如果使用RAG预构建向量索引,直接搜索 if hasattr(self, 'rag_embeddings'): query_vector = self.embedding_model.encode([query]) distances, indices = self.index.search(query_vector.astype(np.float32), top_k * 2) for distance, idx in zip(distances[0], indices[0]): if idx < len(self.chunks): chunk = self.chunks[idx].copy() chunk["relevance_score"] = float(1 / (1 + distance)) relevant_chunks.append(chunk) else: # 传统向量搜索 query_vector = self.embedding_model.encode([query]) distances, indices = self.index.search(query_vector.astype(np.float32), top_k * 2) for distance, idx in zip(distances[0], indices[0]): if idx < len(self.chunks): chunk = self.chunks[idx].copy() chunk["relevance_score"] = float(1 / (1 + distance)) relevant_chunks.append(chunk) except Exception as e: print(f"向量搜索失败: {e}") # 文本搜索作为备选 if not relevant_chunks: query_lower = query.lower() for chunk in self.chunks: content_lower = chunk["content"].lower() score = 0 for word in query_lower.split(): if word in content_lower: score += content_lower.count(word) if score > 0: chunk_copy = chunk.copy() chunk_copy["relevance_score"] = score relevant_chunks.append(chunk_copy) # 按相关性排序 relevant_chunks.sort(key=lambda x: x["relevance_score"], reverse=True) # 优先返回特定角色的相关信息 if character_name: char_chunks = [c for c in relevant_chunks if c.get("character") == character_name] other_chunks = [c for c in relevant_chunks if c.get("character") != character_name] relevant_chunks = char_chunks + other_chunks return relevant_chunks[:top_k] class ConversationManager: """对话管理器""" def __init__(self, db_path: str = "conversation_history.db"): self.db_path = db_path self._init_database() def _init_database(self): """初始化对话历史数据库""" with sqlite3.connect(self.db_path) as conn: conn.execute(''' CREATE TABLE IF NOT EXISTS conversations ( session_id TEXT PRIMARY KEY, characters TEXT, worldview TEXT, start_time TEXT, last_update TEXT, metadata TEXT ) ''') conn.execute(''' CREATE TABLE IF NOT EXISTS dialogue_turns ( id INTEGER PRIMARY KEY AUTOINCREMENT, session_id TEXT, turn_number INTEGER, speaker TEXT, content TEXT, timestamp TEXT, context_used TEXT, relevance_score REAL, dialogue_score REAL DEFAULT 0.0, score_details TEXT, score_feedback TEXT, FOREIGN KEY (session_id) REFERENCES conversations (session_id) ) ''') conn.commit() def create_session(self, characters: List[str], worldview: str) -> str: """创建新的对话会话""" session_id = hashlib.md5(f"{'-'.join(characters)}-{datetime.now().isoformat()}".encode()).hexdigest()[:12] with sqlite3.connect(self.db_path) as conn: conn.execute( "INSERT INTO conversations (session_id, characters, worldview, start_time, last_update) VALUES (?, ?, ?, ?, ?)", (session_id, json.dumps(characters), worldview, datetime.now().isoformat(), datetime.now().isoformat()) ) conn.commit() print(f"✓ 创建对话会话: {session_id}") return session_id def add_dialogue_turn(self, session_id: str, speaker: str, content: str, context_used: List[str] = None, relevance_score: float = 0.0, dialogue_score: float = 0.0, score_details: str = None, score_feedback: str = None): """添加对话轮次""" if context_used is None: context_used = [] with sqlite3.connect(self.db_path) as conn: # 获取当前轮次数 cursor = conn.execute("SELECT COUNT(*) FROM dialogue_turns WHERE session_id = ?", (session_id,)) turn_number = cursor.fetchone()[0] + 1 # 插入对话轮次 conn.execute( """INSERT INTO dialogue_turns (session_id, turn_number, speaker, content, timestamp, context_used, relevance_score, dialogue_score, score_details, score_feedback) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", (session_id, turn_number, speaker, content, datetime.now().isoformat(), json.dumps(context_used), relevance_score, dialogue_score, score_details, score_feedback) ) # 更新会话最后更新时间 conn.execute( "UPDATE conversations SET last_update = ? WHERE session_id = ?", (datetime.now().isoformat(), session_id) ) conn.commit() def get_conversation_history(self, session_id: str, last_n: int = 10) -> List[DialogueTurn]: """获取对话历史""" with sqlite3.connect(self.db_path) as conn: cursor = conn.execute( """SELECT speaker, content, timestamp, context_used, relevance_score, dialogue_score, score_feedback FROM dialogue_turns WHERE session_id = ? ORDER BY turn_number DESC LIMIT ?""", (session_id, last_n) ) turns = [] for row in cursor.fetchall(): speaker, content, timestamp, context_used, relevance_score, dialogue_score, score_feedback = row turn = DialogueTurn( speaker=speaker, content=content, timestamp=timestamp, context_used=json.loads(context_used or "[]"), relevance_score=relevance_score ) # 添加评分信息到turn对象 if dialogue_score: turn.dialogue_score = dialogue_score if score_feedback: turn.score_feedback = score_feedback turns.append(turn) return list(reversed(turns)) # 按时间正序返回 def list_sessions(self) -> List[Dict]: """列出所有对话会话""" with sqlite3.connect(self.db_path) as conn: cursor = conn.execute( "SELECT session_id, characters, worldview, start_time, last_update FROM conversations ORDER BY last_update DESC" ) sessions = [] for row in cursor.fetchall(): session_id, characters, worldview, start_time, last_update = row sessions.append({ "session_id": session_id, "characters": json.loads(characters), "worldview": worldview, "start_time": start_time, "last_update": last_update }) return sessions class DualAIDialogueEngine: """双AI对话引擎""" def __init__(self, knowledge_base: RAGKnowledgeBase, conversation_manager: ConversationManager, llm_generator, enable_scoring: bool = True, base_model_path: str = None, use_manual_scoring: bool = False): self.kb = knowledge_base self.conv_mgr = conversation_manager self.llm_generator = llm_generator self.enable_scoring = enable_scoring self.use_manual_scoring = use_manual_scoring self.scorer = None # 初始化评分器 if enable_scoring and base_model_path and not use_manual_scoring: try: from dialogue_scorer import DialogueAIScorer print("正在初始化对话评分系统...") self.scorer = DialogueAIScorer( base_model_path=base_model_path, tokenizer=getattr(llm_generator, 'tokenizer', None), model=getattr(llm_generator, 'model', None) ) print("✓ 对话评分系统初始化成功") except Exception as e: print(f"⚠ 对话评分系统初始化失败: {e}") self.enable_scoring = False def _manual_score_dialogue_turn(self, dialogue_content: str, speaker: str, dialogue_history: List[DialogueTurn]) -> Tuple[float, str, str]: """人工打分对话轮次 Args: dialogue_content: 对话内容 speaker: 说话者 dialogue_history: 对话历史 Returns: tuple: (总分, 详细分数JSON, 反馈意见) """ print("\n" + "="*60) print("人工对话评分") print("="*60) # print(f"说话者: {speaker}") # print(f"对话内容: {dialogue_content}") print("-" * 40) # # 显示最近的对话历史作为参考 # if dialogue_history: # print("最近对话历史:") # for i, turn in enumerate(dialogue_history[-3:], 1): # print(f" {i}. {turn.speaker}: {turn.content[:100]}...") # print("-" * 40) # 五个评分维度 dimensions = { 'coherence': '逻辑连贯性 (1-10)', 'character_consistency': '角色一致性 (1-10)', 'naturalness': '自然流畅度 (1-10)', 'information_density': '信息密度 (1-10)', 'creativity': '创意新颖度 (1-10)' } scores = {} print("\n请为以下维度打分 (输入1-10的分数,直接回车跳过该维度):") for key, desc in dimensions.items(): while True: try: score_input = input(f"{desc}: ").strip() if score_input == "": scores[key] = 7.0 # 默认分数 break score = float(score_input) if 1 <= score <= 10: scores[key] = score break else: print("请输入1-10之间的分数") except ValueError: print("请输入有效的数字") # 计算总分 overall_score = sum(scores.values()) / len(scores) # 获取反馈意见 print("\n请输入对该对话的评价和建议 (可选,直接回车跳过):") feedback = input("反馈意见: ").strip() if not feedback: feedback = f"人工评分完成,总分: {overall_score:.1f}" print(f"\n✓ 评分完成 - 总分: {overall_score:.1f}") print("="*60) return overall_score, json.dumps(scores), feedback def score_dialogue_turn(self, dialogue_content: str, speaker: str, dialogue_history: List[DialogueTurn]) -> Tuple[float, str, str]: """对单条对话进行评分 Args: dialogue_content: 对话内容 speaker: 说话者 dialogue_history: 对话历史 Returns: tuple: (总分, 详细分数JSON, 反馈意见) """ if not self.enable_scoring: return 0.0, "{}", "评分系统未启用" # 人工打分模式 if self.use_manual_scoring: return self._manual_score_dialogue_turn(dialogue_content, speaker, dialogue_history) # AI自动打分模式 if not self.scorer: return 0.0, "{}", "AI评分器未初始化" try: # 获取角色数据 character_data = self.kb.character_data.get(speaker, {}) # 转换对话历史格式 history_for_scoring = [] for turn in dialogue_history[-5:]: # 最近5轮对话 history_for_scoring.append({ 'speaker': turn.speaker, 'content': turn.content }) # 进行AI评分 score_result = self.scorer.score_dialogue( dialogue_content=dialogue_content, speaker=speaker, character_data=character_data, dialogue_history=history_for_scoring, context_info=[] ) # 返回评分结果 return score_result.overall_score, json.dumps(score_result.scores), score_result.feedback except Exception as e: print(f"⚠ 对话评分失败: {e}") return 0.0, "{}", f"评分失败: {str(e)}" def run_dual_model_conversation(self, session_id: str, topic: str = "", turns: int = 4, history_context_count: int = 3, context_info_count: int = 2): """使用双模型系统运行对话 Args: session_id: 会话ID topic: 对话主题 turns: 对话轮数 history_context_count: 使用的历史对话轮数 context_info_count: 使用的上下文信息数量 """ # 检查是否为双模型对话系统 if not hasattr(self.llm_generator, 'run_dual_character_conversation'): print("⚠ 当前系统不支持双模型对话") return self.run_conversation_turn(session_id, self.llm_generator.list_characters(), turns, topic, history_context_count, context_info_count) # 获取对话历史 dialogue_history = self.conv_mgr.get_conversation_history(session_id) # 构建上下文信息 if dialogue_history: recent_turns = dialogue_history[-history_context_count:] if history_context_count > 0 else [] recent_content = " ".join([turn.content for turn in recent_turns]) search_query = recent_content + " " + topic else: search_query = f"{topic} introduction greeting" # 搜索相关上下文 context_info = self.kb.search_relevant_context(search_query, top_k=context_info_count) # 构建上下文字符串 context_str = "" if context_info: context_str = "相关背景信息:" for info in context_info[:context_info_count]: content = info['content'][:150] + "..." if len(info['content']) > 150 else info['content'] context_str += f"\n- {content}" print(f"\n=== 双模型对话系统 ===") print(f"主题: {topic}") print(f"角色: {', '.join(self.llm_generator.list_characters())}") print(f"轮数: {turns}") print(f"上下文设置: 历史{history_context_count}轮, 信息{context_info_count}个") # 使用双模型系统生成对话 for turn in range(turns): # 获取对话历史 dialogue_history = self.conv_mgr.get_conversation_history(session_id) conversation_results = self.llm_generator.run_dual_character_conversation( topic=topic, turn_index = turn, context=context_str, dialogue_history = dialogue_history, history_context_count = history_context_count, max_new_tokens=150 ) # 保存对话到数据库并进行评分 for result in conversation_results: # 获取当前对话历史进行评分 current_dialogue_history = self.conv_mgr.get_conversation_history(session_id) # 对对话进行评分 if self.enable_scoring: dialogue_score, score_details, score_feedback = self.score_dialogue_turn( result['dialogue'], result['speaker'], current_dialogue_history ) print(f" [评分: {dialogue_score:.2f}] {score_feedback[:100]}...") else: dialogue_score, score_details, score_feedback = 0.0, "{}", "" self.conv_mgr.add_dialogue_turn( session_id, result['speaker'], result['dialogue'], [result.get('context_used', '')], 0.8, # 默认相关性分数 dialogue_score, score_details, score_feedback ) return conversation_results