#!/usr/bin/env python # -*- coding: utf-8 -*- ''' RAG增强的角色对话系统 集成世界观知识库,支持角色设定加载和对话生成 ''' import json import os import sqlite3 from datetime import datetime from typing import Dict, List, Optional, Tuple from dataclasses import dataclass, asdict import hashlib # 尝试导入向量化相关库 try: from sentence_transformers import SentenceTransformer import faiss import numpy as np EMBEDDING_AVAILABLE = True except ImportError: EMBEDDING_AVAILABLE = False @dataclass class DialogueTurn: """对话轮次数据结构""" speaker: str content: str timestamp: str context_used: List[str] # 使用的上下文信息 relevance_score: float = 0.0 @dataclass class ConversationSession: """对话会话数据结构""" session_id: str characters: List[str] worldview: str start_time: str last_update: str dialogue_history: List[DialogueTurn] class RAGKnowledgeBase: """RAG知识库管理器""" def __init__(self, knowledge_dir: str): self.knowledge_dir = knowledge_dir self.worldview_data = None self.character_data = {} self.chunks = [] self.embedding_model = None self.index = None # 初始化向量模型 if EMBEDDING_AVAILABLE: try: self.embedding_model = SentenceTransformer('./sentence-transformers/all-MiniLM-L6-v2') print("✓ 向量模型加载成功") except Exception as e: print(f"✗ 向量模型加载失败: {e}") self._load_knowledge_base() def _load_knowledge_base(self): """加载知识库""" # 加载世界观 worldview_files = [f for f in os.listdir(self.knowledge_dir) if f.startswith('worldview') and f.endswith('.json')] if worldview_files: worldview_path = os.path.join(self.knowledge_dir, worldview_files[0]) with open(worldview_path, 'r', encoding='utf-8') as f: self.worldview_data = json.load(f) print(f"✓ 世界观加载成功: {self.worldview_data.get('worldview_name', '未知')}") # 加载角色数据 character_files = [f for f in os.listdir(self.knowledge_dir) if f.startswith('character') and f.endswith('.json')] for char_file in character_files: char_path = os.path.join(self.knowledge_dir, char_file) with open(char_path, 'r', encoding='utf-8') as f: char_data = json.load(f) char_name = char_data.get('character_name', char_file) self.character_data[char_name] = char_data print(f"✓ 角色加载成功: {list(self.character_data.keys())}") # 构建检索用的文本块 self._build_searchable_chunks() # 构建向量索引 if EMBEDDING_AVAILABLE and self.embedding_model: self._build_vector_index() def _build_searchable_chunks(self): """构建可检索的文本块""" self.chunks = [] # 世界观相关文本块 if self.worldview_data: for section_key, section_data in self.worldview_data.items(): if isinstance(section_data, dict): for sub_key, sub_data in section_data.items(): if isinstance(sub_data, (str, list)): content = str(sub_data) if len(content) > 50: # 只保留有意义的文本 self.chunks.append({ "type": "worldview", "section": section_key, "subsection": sub_key, "content": content, "metadata": {"source": "worldview"} }) # 角色相关文本块 for char_name, char_data in self.character_data.items(): for section_key, section_data in char_data.items(): if isinstance(section_data, dict): for sub_key, sub_data in section_data.items(): if isinstance(sub_data, (str, list)): content = str(sub_data) if len(content) > 30: self.chunks.append({ "type": "character", "character": char_name, "section": section_key, "subsection": sub_key, "content": content, "metadata": {"source": char_name} }) print(f"✓ 构建文本块: {len(self.chunks)} 个") def _build_vector_index(self): """构建向量索引""" try: texts = [chunk["content"] for chunk in self.chunks] embeddings = self.embedding_model.encode(texts) dimension = embeddings.shape[1] self.index = faiss.IndexFlatL2(dimension) self.index.add(embeddings.astype(np.float32)) print(f"✓ 向量索引构建成功: {dimension}维, {len(texts)}个向量") except Exception as e: print(f"✗ 向量索引构建失败: {e}") def search_relevant_context(self, query: str, character_name: str = None, top_k: int = 3) -> List[Dict]: """搜索相关上下文""" relevant_chunks = [] # 向量搜索 if EMBEDDING_AVAILABLE and self.embedding_model and self.index: try: query_vector = self.embedding_model.encode([query]) distances, indices = self.index.search(query_vector.astype(np.float32), top_k * 2) for distance, idx in zip(distances[0], indices[0]): if idx < len(self.chunks): chunk = self.chunks[idx].copy() chunk["relevance_score"] = float(1 / (1 + distance)) relevant_chunks.append(chunk) except Exception as e: print(f"向量搜索失败: {e}") # 文本搜索作为备选 if not relevant_chunks: query_lower = query.lower() for chunk in self.chunks: content_lower = chunk["content"].lower() score = 0 for word in query_lower.split(): if word in content_lower: score += content_lower.count(word) if score > 0: chunk_copy = chunk.copy() chunk_copy["relevance_score"] = score relevant_chunks.append(chunk_copy) # 按相关性排序 relevant_chunks.sort(key=lambda x: x["relevance_score"], reverse=True) # 优先返回特定角色的相关信息 if character_name: char_chunks = [c for c in relevant_chunks if c.get("character") == character_name] other_chunks = [c for c in relevant_chunks if c.get("character") != character_name] relevant_chunks = char_chunks + other_chunks return relevant_chunks[:top_k] class ConversationManager: """对话管理器""" def __init__(self, db_path: str = "conversation_history.db"): self.db_path = db_path self._init_database() def _init_database(self): """初始化对话历史数据库""" with sqlite3.connect(self.db_path) as conn: conn.execute(''' CREATE TABLE IF NOT EXISTS conversations ( session_id TEXT PRIMARY KEY, characters TEXT, worldview TEXT, start_time TEXT, last_update TEXT, metadata TEXT ) ''') conn.execute(''' CREATE TABLE IF NOT EXISTS dialogue_turns ( id INTEGER PRIMARY KEY AUTOINCREMENT, session_id TEXT, turn_number INTEGER, speaker TEXT, content TEXT, timestamp TEXT, context_used TEXT, relevance_score REAL, FOREIGN KEY (session_id) REFERENCES conversations (session_id) ) ''') conn.commit() def create_session(self, characters: List[str], worldview: str) -> str: """创建新的对话会话""" session_id = hashlib.md5(f"{'-'.join(characters)}-{datetime.now().isoformat()}".encode()).hexdigest()[:12] with sqlite3.connect(self.db_path) as conn: conn.execute( "INSERT INTO conversations (session_id, characters, worldview, start_time, last_update) VALUES (?, ?, ?, ?, ?)", (session_id, json.dumps(characters), worldview, datetime.now().isoformat(), datetime.now().isoformat()) ) conn.commit() print(f"✓ 创建对话会话: {session_id}") return session_id def add_dialogue_turn(self, session_id: str, speaker: str, content: str, context_used: List[str] = None, relevance_score: float = 0.0): """添加对话轮次""" if context_used is None: context_used = [] with sqlite3.connect(self.db_path) as conn: # 获取当前轮次数 cursor = conn.execute("SELECT COUNT(*) FROM dialogue_turns WHERE session_id = ?", (session_id,)) turn_number = cursor.fetchone()[0] + 1 # 插入对话轮次 conn.execute( """INSERT INTO dialogue_turns (session_id, turn_number, speaker, content, timestamp, context_used, relevance_score) VALUES (?, ?, ?, ?, ?, ?, ?)""", (session_id, turn_number, speaker, content, datetime.now().isoformat(), json.dumps(context_used), relevance_score) ) # 更新会话最后更新时间 conn.execute( "UPDATE conversations SET last_update = ? WHERE session_id = ?", (datetime.now().isoformat(), session_id) ) conn.commit() def get_conversation_history(self, session_id: str, last_n: int = 10) -> List[DialogueTurn]: """获取对话历史""" with sqlite3.connect(self.db_path) as conn: cursor = conn.execute( """SELECT speaker, content, timestamp, context_used, relevance_score FROM dialogue_turns WHERE session_id = ? ORDER BY turn_number DESC LIMIT ?""", (session_id, last_n) ) turns = [] for row in cursor.fetchall(): speaker, content, timestamp, context_used, relevance_score = row turns.append(DialogueTurn( speaker=speaker, content=content, timestamp=timestamp, context_used=json.loads(context_used or "[]"), relevance_score=relevance_score )) return list(reversed(turns)) # 按时间正序返回 def list_sessions(self) -> List[Dict]: """列出所有对话会话""" with sqlite3.connect(self.db_path) as conn: cursor = conn.execute( "SELECT session_id, characters, worldview, start_time, last_update FROM conversations ORDER BY last_update DESC" ) sessions = [] for row in cursor.fetchall(): session_id, characters, worldview, start_time, last_update = row sessions.append({ "session_id": session_id, "characters": json.loads(characters), "worldview": worldview, "start_time": start_time, "last_update": last_update }) return sessions class DualAIDialogueEngine: """双AI对话引擎""" def __init__(self, knowledge_base: RAGKnowledgeBase, conversation_manager: ConversationManager, llm_generator): self.kb = knowledge_base self.conv_mgr = conversation_manager self.llm_generator = llm_generator def generate_character_prompt(self, character_name: str, context_info: List[Dict], dialogue_history: List[DialogueTurn]) -> str: """为角色生成对话提示""" char_data = self.kb.character_data.get(character_name, {}) # 基础角色设定 prompt_parts = [] prompt_parts.append(f"你是{character_name},具有以下设定:") if char_data.get('personality', {}).get('core_traits'): traits = ", ".join(char_data['personality']['core_traits']) prompt_parts.append(f"性格特点:{traits}") if char_data.get('speech_patterns', {}).get('sample_phrases'): phrases = char_data['speech_patterns']['sample_phrases'][:3] prompt_parts.append(f"说话风格示例:{'; '.join(phrases)}") # 当前情境 if char_data.get('current_situation'): situation = char_data['current_situation'] prompt_parts.append(f"当前状态:{situation.get('current_mood', '')}") # 相关世界观信息 if context_info: prompt_parts.append("相关背景信息:") for info in context_info[:2]: # 只使用最相关的2个信息 content = info['content'][:200] + "..." if len(info['content']) > 200 else info['content'] prompt_parts.append(f"- {content}") # 对话历史 if dialogue_history: prompt_parts.append("最近的对话:") for turn in dialogue_history[-3:]: # 只使用最近的3轮对话 prompt_parts.append(f"{turn.speaker}: {turn.content}") prompt_parts.append("\n请根据角色设定和上下文,生成符合角色特点的自然对话。回复应该在50-150字之间。") return "\n".join(prompt_parts) def generate_dialogue(self, session_id: str, current_speaker: str, topic_hint: str = "") -> Tuple[str, List[str]]: """生成角色对话""" # 获取对话历史 dialogue_history = self.conv_mgr.get_conversation_history(session_id) # 构建搜索查询 if dialogue_history: # 基于最近的对话内容 recent_content = " ".join([turn.content for turn in dialogue_history[-2:]]) search_query = recent_content + " " + topic_hint else: # 首次对话 search_query = f"{current_speaker} {topic_hint} introduction greeting" # 搜索相关上下文 context_info = self.kb.search_relevant_context(search_query, current_speaker, 10) # 生成提示 prompt = self.generate_character_prompt(current_speaker, context_info, dialogue_history) # 生成对话 try: response = self.llm_generator.generate_character_dialogue( current_speaker, prompt, topic_hint or "请继续对话", temperature=0.8, max_new_tokens=150 ) # 记录使用的上下文 context_used = [f"{info['section']}.{info['subsection']}" for info in context_info] avg_relevance = sum(info['relevance_score'] for info in context_info) / len(context_info) if context_info else 0.0 # 保存对话轮次 self.conv_mgr.add_dialogue_turn( session_id, current_speaker, response, context_used, avg_relevance ) return response, context_used except Exception as e: print(f"✗ 对话生成失败: {e}") return f"[{current_speaker}暂时无法回应]", [] def run_conversation_turn(self, session_id: str, characters: List[str], turns_count: int = 1, topic: str = ""): """运行对话轮次""" results = [] for i in range(turns_count): for char in characters: response, context_used = self.generate_dialogue(session_id, char, topic) results.append({ "speaker": char, "content": response, "context_used": context_used, "turn": i + 1 }) print(f"{char}: {response}") if context_used: print(f" [使用上下文: {', '.join(context_used)}]") print() return results def main(): """主函数 - 演示系统使用""" print("=== RAG增强双AI角色对话系统 ===") # 设置路径 knowledge_dir = "./knowledge_base" # 包含世界观和角色文档的目录 # 检查必要文件 required_dirs = [knowledge_dir] for dir_path in required_dirs: if not os.path.exists(dir_path): print(f"✗ 目录不存在: {dir_path}") print("请确保以下文件存在:") print("- ./knowledge_base/worldview_template_coc.json") print("- ./knowledge_base/character_template_detective.json") print("- ./knowledge_base/character_template_professor.json") return try: # 初始化系统组件 print("\n初始化系统...") kb = RAGKnowledgeBase(knowledge_dir) conv_mgr = ConversationManager() # 这里需要你的LLM生成器,使用现有的NPCDialogueGenerator from npc_dialogue_generator import NPCDialogueGenerator base_model_path = '/mnt/g/Project02/AITrain/Qwen/Qwen3-4B' # 根据你的路径调整 lora_model_path = './output/NPC_Dialogue_LoRA/final_model' if not os.path.exists(lora_model_path): lora_model_path = None llm_generator = NPCDialogueGenerator(base_model_path, lora_model_path) # 创建对话引擎 dialogue_engine = DualAIDialogueEngine(kb, conv_mgr, llm_generator) print("✓ 系统初始化完成") # 交互式菜单 while True: print("\n" + "="*50) print("双AI角色对话系统") print("1. 创建新对话") print("2. 继续已有对话") print("3. 查看对话历史") print("4. 列出所有会话") print("0. 退出") print("="*50) choice = input("请选择操作: ").strip() if choice == '0': break elif choice == '1': # 创建新对话 print(f"可用角色: {list(kb.character_data.keys())}") characters = input("请输入两个角色名称(用空格分隔): ").strip().split() if len(characters) != 2: print("❌ 请输入正好两个角色名称") continue worldview = kb.worldview_data.get('worldview_name', '未知世界观') if kb.worldview_data else '未知世界观' session_id = conv_mgr.create_session(characters, worldview) topic = input("请输入对话主题(可选): ").strip() turns = int(input("请输入对话轮次数量(默认2): ").strip() or "2") print(f"\n开始对话 - 会话ID: {session_id}") dialogue_engine.run_conversation_turn(session_id, characters, turns, topic) elif choice == '2': # 继续已有对话 sessions = conv_mgr.list_sessions() if not sessions: print("❌ 没有已有对话") continue print("已有会话:") for i, session in enumerate(sessions[:5]): chars = ", ".join(session['characters']) print(f"{i+1}. {session['session_id'][:8]}... ({chars}) - {session['last_update'][:16]}") try: idx = int(input("请选择会话编号: ").strip()) - 1 if 0 <= idx < len(sessions): session = sessions[idx] session_id = session['session_id'] characters = session['characters'] # 显示最近的对话 history = conv_mgr.get_conversation_history(session_id, 4) if history: print("\n最近的对话:") for turn in history: print(f"{turn.speaker}: {turn.content}") topic = input("请输入对话主题(可选): ").strip() turns = int(input("请输入对话轮次数量(默认1): ").strip() or "1") print(f"\n继续对话 - 会话ID: {session_id}") dialogue_engine.run_conversation_turn(session_id, characters, turns, topic) else: print("❌ 无效的会话编号") except ValueError: print("❌ 请输入有效的数字") elif choice == '3': # 查看对话历史 session_id = input("请输入会话ID(前8位即可): ").strip() # 查找匹配的会话 sessions = conv_mgr.list_sessions() matching_session = None for session in sessions: if session['session_id'].startswith(session_id): matching_session = session break if matching_session: full_session_id = matching_session['session_id'] history = conv_mgr.get_conversation_history(full_session_id, 20) if history: print(f"\n对话历史 - {full_session_id}") print(f"角色: {', '.join(matching_session['characters'])}") print(f"世界观: {matching_session['worldview']}") print("-" * 50) for turn in history: print(f"[{turn.timestamp[:16]}] {turn.speaker}:") print(f" {turn.content}") if turn.context_used: print(f" 使用上下文: {', '.join(turn.context_used)}") print() else: print("该会话暂无对话历史") else: print("❌ 未找到匹配的会话") elif choice == '4': # 列出所有会话 sessions = conv_mgr.list_sessions() if sessions: print(f"\n共有 {len(sessions)} 个对话会话:") for session in sessions: chars = ", ".join(session['characters']) print(f"ID: {session['session_id']}") print(f" 角色: {chars}") print(f" 世界观: {session['worldview']}") print(f" 最后更新: {session['last_update']}") print() else: print("暂无对话会话") else: print("❌ 无效选择") except Exception as e: print(f"✗ 系统运行出错: {e}") import traceback traceback.print_exc() if __name__ == '__main__': main()