#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
RAG-enhanced character dialogue system.
Loads a worldview knowledge base plus character profiles and retrieves
relevant snippets to ground generated dialogue.
'''

import json
import os
import sqlite3
from datetime import datetime
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass, asdict
import hashlib

# Optional vector-search stack; when unavailable we fall back to keyword search.
try:
    from sentence_transformers import SentenceTransformer
    import faiss
    import numpy as np
    EMBEDDING_AVAILABLE = True
except ImportError:
    EMBEDDING_AVAILABLE = False


@dataclass
class DialogueTurn:
    """A single utterance within a conversation."""
    speaker: str
    content: str
    timestamp: str
    context_used: List[str]  # labels of knowledge-base snippets consulted for this turn
    relevance_score: float = 0.0


@dataclass
class ConversationSession:
    """Metadata and accumulated history for one dialogue session."""
    session_id: str
    characters: List[str]
    worldview: str
    start_time: str
    last_update: str
    dialogue_history: List[DialogueTurn]


class RAGKnowledgeBase:
    """Loads worldview/character JSON documents and serves relevant text chunks."""

    def __init__(self, knowledge_dir: str):
        self.knowledge_dir = knowledge_dir
        self.worldview_data = None
        self.character_data = {}
        self.chunks = []
        self.embedding_model = None
        self.index = None

        # Try to bring up the embedding model; a failure only disables vector search.
        if EMBEDDING_AVAILABLE:
            try:
                self.embedding_model = SentenceTransformer('./sentence-transformers/all-MiniLM-L6-v2')
                print("✓ 向量模型加载成功")
            except Exception as e:
                print(f"✗ 向量模型加载失败: {e}")

        self._load_knowledge_base()

    def _load_knowledge_base(self):
        """Read worldview and character JSON files from knowledge_dir."""
        entries = os.listdir(self.knowledge_dir)

        # Worldview: pick the first file named worldview*.json.
        worldview_candidates = [name for name in entries
                                if name.startswith('worldview') and name.endswith('.json')]
        if worldview_candidates:
            path = os.path.join(self.knowledge_dir, worldview_candidates[0])
            with open(path, 'r', encoding='utf-8') as fh:
                self.worldview_data = json.load(fh)
            print(f"✓ 世界观加载成功: {self.worldview_data.get('worldview_name', '未知')}")

        # Characters: every character*.json file, keyed by its declared name.
        for name in entries:
            if not (name.startswith('character') and name.endswith('.json')):
                continue
            with open(os.path.join(self.knowledge_dir, name), 'r', encoding='utf-8') as fh:
                profile = json.load(fh)
            self.character_data[profile.get('character_name', name)] = profile

        print(f"✓ 角色加载成功: {list(self.character_data.keys())}")

        self._build_searchable_chunks()

        if EMBEDDING_AVAILABLE and self.embedding_model:
            self._build_vector_index()

    def _build_searchable_chunks(self):
        """Flatten loaded documents into a list of retrievable text chunks."""
        self.chunks = []

        # Worldview sections: dict-of-dict structure, keep only sizeable text.
        if self.worldview_data:
            for section, body in self.worldview_data.items():
                if not isinstance(body, dict):
                    continue
                for sub, value in body.items():
                    if not isinstance(value, (str, list)):
                        continue
                    text = str(value)
                    if len(text) > 50:  # keep only meaningful text
                        self.chunks.append({
                            "type": "worldview",
                            "section": section,
                            "subsection": sub,
                            "content": text,
                            "metadata": {"source": "worldview"},
                        })

        # Character sections: same shape, slightly lower length threshold.
        for char_name, profile in self.character_data.items():
            for section, body in profile.items():
                if not isinstance(body, dict):
                    continue
                for sub, value in body.items():
                    if not isinstance(value, (str, list)):
                        continue
                    text = str(value)
                    if len(text) > 30:
                        self.chunks.append({
                            "type": "character",
                            "character": char_name,
                            "section": section,
                            "subsection": sub,
                            "content": text,
                            "metadata": {"source": char_name},
                        })

        print(f"✓ 构建文本块: {len(self.chunks)} 个")

    def _build_vector_index(self):
        """Embed every chunk and build a flat L2 FAISS index over the vectors."""
        try:
            corpus = [c["content"] for c in self.chunks]
            vectors = self.embedding_model.encode(corpus)
            dim = vectors.shape[1]
            self.index = faiss.IndexFlatL2(dim)
            self.index.add(vectors.astype(np.float32))
            print(f"✓ 向量索引构建成功: {dim}维, {len(corpus)}个向量")
        except Exception as e:
            print(f"✗ 向量索引构建失败: {e}")

    def search_relevant_context(self, query: str, character_name: str = None, top_k: int = 3) -> List[Dict]:
        """Return up to top_k chunks relevant to query, preferring character_name's own chunks."""
        hits = []

        # Primary path: vector search (over-fetch 2x, trimmed after re-ranking).
        if EMBEDDING_AVAILABLE and self.embedding_model and self.index:
            try:
                qvec = self.embedding_model.encode([query])
                distances, indices = self.index.search(qvec.astype(np.float32), top_k * 2)
                for dist, pos in zip(distances[0], indices[0]):
                    if pos < len(self.chunks):
                        hit = self.chunks[pos].copy()
                        hit["relevance_score"] = float(1 / (1 + dist))
                        hits.append(hit)
            except Exception as e:
                print(f"向量搜索失败: {e}")

        # Fallback path: keyword occurrence counting.
        if not hits:
            words = query.lower().split()
            for chunk in self.chunks:
                haystack = chunk["content"].lower()
                score = sum(haystack.count(w) for w in words if w in haystack)
                if score > 0:
                    hit = chunk.copy()
                    hit["relevance_score"] = score
                    hits.append(hit)

        hits.sort(key=lambda h: h["relevance_score"], reverse=True)

        # Stable re-partition: the requested character's chunks come first.
        if character_name:
            own = [h for h in hits if h.get("character") == character_name]
            rest = [h for h in hits if h.get("character") != character_name]
            hits = own + rest

        return hits[:top_k]
class ConversationManager:
    """Persists dialogue sessions and their turns in a SQLite database."""

    def __init__(self, db_path: str = "conversation_history.db"):
        self.db_path = db_path
        self._init_database()

    def _init_database(self):
        """Create the backing tables if they do not exist (idempotent)."""
        with sqlite3.connect(self.db_path) as conn:
            conn.execute('''
                CREATE TABLE IF NOT EXISTS conversations (
                    session_id TEXT PRIMARY KEY,
                    characters TEXT,
                    worldview TEXT,
                    start_time TEXT,
                    last_update TEXT,
                    metadata TEXT
                )
            ''')
            conn.execute('''
                CREATE TABLE IF NOT EXISTS dialogue_turns (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    session_id TEXT,
                    turn_number INTEGER,
                    speaker TEXT,
                    content TEXT,
                    timestamp TEXT,
                    context_used TEXT,
                    relevance_score REAL,
                    FOREIGN KEY (session_id) REFERENCES conversations (session_id)
                )
            ''')
            conn.commit()

    def create_session(self, characters: List[str], worldview: str) -> str:
        """Register a new session row and return its short (12 hex chars) id."""
        fingerprint = f"{'-'.join(characters)}-{datetime.now().isoformat()}"
        session_id = hashlib.md5(fingerprint.encode()).hexdigest()[:12]

        with sqlite3.connect(self.db_path) as conn:
            conn.execute(
                "INSERT INTO conversations (session_id, characters, worldview, start_time, last_update) VALUES (?, ?, ?, ?, ?)",
                (session_id, json.dumps(characters), worldview,
                 datetime.now().isoformat(), datetime.now().isoformat())
            )
            conn.commit()

        print(f"✓ 创建对话会话: {session_id}")
        return session_id

    def add_dialogue_turn(self, session_id: str, speaker: str, content: str,
                          context_used: List[str] = None, relevance_score: float = 0.0):
        """Append one turn to a session and refresh its last_update stamp."""
        context_used = [] if context_used is None else context_used

        with sqlite3.connect(self.db_path) as conn:
            # Turn numbers are 1-based and derived from the current row count.
            count_row = conn.execute(
                "SELECT COUNT(*) FROM dialogue_turns WHERE session_id = ?",
                (session_id,)
            ).fetchone()
            next_turn = count_row[0] + 1

            conn.execute(
                """INSERT INTO dialogue_turns
                (session_id, turn_number, speaker, content, timestamp, context_used, relevance_score)
                VALUES (?, ?, ?, ?, ?, ?, ?)""",
                (session_id, next_turn, speaker, content,
                 datetime.now().isoformat(), json.dumps(context_used), relevance_score)
            )
            conn.execute(
                "UPDATE conversations SET last_update = ? WHERE session_id = ?",
                (datetime.now().isoformat(), session_id)
            )
            conn.commit()

    def get_conversation_history(self, session_id: str, last_n: int = 10) -> List[DialogueTurn]:
        """Return the most recent last_n turns of a session, oldest first."""
        with sqlite3.connect(self.db_path) as conn:
            rows = conn.execute(
                """SELECT speaker, content, timestamp, context_used, relevance_score
                FROM dialogue_turns
                WHERE session_id = ?
                ORDER BY turn_number DESC LIMIT ?""",
                (session_id, last_n)
            ).fetchall()

        newest_first = [
            DialogueTurn(
                speaker=speaker,
                content=content,
                timestamp=timestamp,
                context_used=json.loads(context_used or "[]"),
                relevance_score=relevance_score,
            )
            for speaker, content, timestamp, context_used, relevance_score in rows
        ]
        # Query fetched newest-first; flip to chronological order for callers.
        return list(reversed(newest_first))

    def list_sessions(self) -> List[Dict]:
        """Return every session as a dict, most recently updated first."""
        with sqlite3.connect(self.db_path) as conn:
            rows = conn.execute(
                "SELECT session_id, characters, worldview, start_time, last_update FROM conversations ORDER BY last_update DESC"
            ).fetchall()

        return [
            {
                "session_id": sid,
                "characters": json.loads(chars),
                "worldview": wv,
                "start_time": start,
                "last_update": updated,
            }
            for sid, chars, wv, start, updated in rows
        ]
class DualAIDialogueEngine:
    """Drives turn-by-turn dialogue between LLM-played characters."""

    def __init__(self, knowledge_base: RAGKnowledgeBase, conversation_manager: ConversationManager, llm_generator):
        self.kb = knowledge_base
        self.conv_mgr = conversation_manager
        self.llm_generator = llm_generator

    def generate_character_prompt(self, character_name: str, context_info: List[Dict], dialogue_history: List[DialogueTurn]) -> str:
        """Assemble the prompt for one character from profile, retrieved context and recent turns."""
        profile = self.kb.character_data.get(character_name, {})

        lines = [f"你是{character_name},具有以下设定:"]

        # Core personality traits, if the profile declares any.
        traits = profile.get('personality', {}).get('core_traits')
        if traits:
            lines.append(f"性格特点:{', '.join(traits)}")

        # Up to three sample phrases to anchor the speaking style.
        phrases = profile.get('speech_patterns', {}).get('sample_phrases')
        if phrases:
            lines.append(f"说话风格示例:{'; '.join(phrases[:3])}")

        situation = profile.get('current_situation')
        if situation:
            lines.append(f"当前状态:{situation.get('current_mood', '')}")

        # At most the two most relevant snippets, truncated to 200 chars each.
        if context_info:
            lines.append("相关背景信息:")
            for snippet in context_info[:2]:
                text = snippet['content']
                if len(text) > 200:
                    text = text[:200] + "..."
                lines.append(f"- {text}")

        # Only the last three turns to keep the prompt short.
        if dialogue_history:
            lines.append("最近的对话:")
            lines.extend(f"{turn.speaker}: {turn.content}" for turn in dialogue_history[-3:])

        lines.append("\n请根据角色设定和上下文,生成符合角色特点的自然对话。回复应该在50-150字之间。")
        return "\n".join(lines)

    def generate_dialogue(self, session_id: str, current_speaker: str, topic_hint: str = "") -> Tuple[str, List[str]]:
        """Generate one utterance for current_speaker, persist it and return (text, context labels)."""
        history = self.conv_mgr.get_conversation_history(session_id)

        if history:
            # Seed retrieval with the two most recent utterances plus the topic.
            recent = " ".join(turn.content for turn in history[-2:])
            search_query = recent + " " + topic_hint
        else:
            # Opening line: bias retrieval toward introduction material.
            search_query = f"{current_speaker} {topic_hint} introduction greeting"

        context_info = self.kb.search_relevant_context(search_query, current_speaker, 10)
        prompt = self.generate_character_prompt(current_speaker, context_info, history)

        try:
            response = self.llm_generator.generate_character_dialogue(
                current_speaker,
                prompt,
                topic_hint or "请继续对话",
                temperature=0.8,
                max_new_tokens=150
            )

            context_used = [f"{info['section']}.{info['subsection']}" for info in context_info]
            avg_relevance = (
                sum(info['relevance_score'] for info in context_info) / len(context_info)
                if context_info else 0.0
            )

            self.conv_mgr.add_dialogue_turn(
                session_id, current_speaker, response, context_used, avg_relevance
            )
            return response, context_used
        except Exception as e:
            # Any failure (generation or persistence) yields a placeholder line.
            print(f"✗ 对话生成失败: {e}")
            return f"[{current_speaker}暂时无法回应]", []

    def run_conversation_turn(self, session_id: str, characters: List[str],
                              turns_count: int = 1, topic: str = ""):
        """Run turns_count rounds in which every character speaks once, echoing to stdout."""
        results = []
        for round_idx in range(turns_count):
            for speaker in characters:
                response, context_used = self.generate_dialogue(session_id, speaker, topic)
                results.append({
                    "speaker": speaker,
                    "content": response,
                    "context_used": context_used,
                    "turn": round_idx + 1
                })

                print(f"{speaker}: {response}")
                if context_used:
                    print(f"  [使用上下文: {', '.join(context_used)}]")
                print()
        return results
def main():
    """Entry point — interactive demo of the RAG-enhanced dual-AI dialogue system."""
    print("=== RAG增强双AI角色对话系统 ===")

    # Directory expected to hold the worldview and character JSON documents.
    knowledge_dir = "./knowledge_base"

    # Abort early with guidance if the knowledge directory is missing.
    required_dirs = [knowledge_dir]
    for dir_path in required_dirs:
        if not os.path.exists(dir_path):
            print(f"✗ 目录不存在: {dir_path}")
            print("请确保以下文件存在:")
            print("- ./knowledge_base/worldview_template_coc.json")
            print("- ./knowledge_base/character_template_detective.json")
            print("- ./knowledge_base/character_template_professor.json")
            return

    try:
        # Initialize system components.
        print("\n初始化系统...")
        kb = RAGKnowledgeBase(knowledge_dir)
        conv_mgr = ConversationManager()

        # LLM backend: reuses the existing NPCDialogueGenerator.
        from npc_dialogue_generator import NPCDialogueGenerator
        base_model_path = '/mnt/g/Project02/AITrain/Qwen/Qwen3-4B'  # adjust to your local path
        lora_model_path = './output/NPC_Dialogue_LoRA/final_model'

        # Fall back to the base model when no LoRA adapter has been trained.
        if not os.path.exists(lora_model_path):
            lora_model_path = None

        llm_generator = NPCDialogueGenerator(base_model_path, lora_model_path)

        # Create the dialogue engine.
        dialogue_engine = DualAIDialogueEngine(kb, conv_mgr, llm_generator)

        print("✓ 系统初始化完成")

        # Interactive menu loop.
        while True:
            print("\n" + "="*50)
            print("双AI角色对话系统")
            print("1. 创建新对话")
            print("2. 继续已有对话")
            print("3. 查看对话历史")
            print("4. 列出所有会话")
            print("0. 退出")
            print("="*50)

            choice = input("请选择操作: ").strip()

            if choice == '0':
                break

            elif choice == '1':
                # Create a new two-character conversation.
                print(f"可用角色: {list(kb.character_data.keys())}")
                characters = input("请输入两个角色名称(用空格分隔): ").strip().split()

                if len(characters) != 2:
                    print("❌ 请输入正好两个角色名称")
                    continue

                worldview = kb.worldview_data.get('worldview_name', '未知世界观') if kb.worldview_data else '未知世界观'
                session_id = conv_mgr.create_session(characters, worldview)

                topic = input("请输入对话主题(可选): ").strip()
                # NOTE(review): non-numeric input here raises ValueError, which is
                # only caught by the outer try and terminates the menu loop.
                turns = int(input("请输入对话轮次数量(默认2): ").strip() or "2")

                print(f"\n开始对话 - 会话ID: {session_id}")
                dialogue_engine.run_conversation_turn(session_id, characters, turns, topic)

            elif choice == '2':
                # Resume an existing conversation.
                sessions = conv_mgr.list_sessions()
                if not sessions:
                    print("❌ 没有已有对话")
                    continue

                # Offer at most the five most recently updated sessions.
                print("已有会话:")
                for i, session in enumerate(sessions[:5]):
                    chars = ", ".join(session['characters'])
                    print(f"{i+1}. {session['session_id'][:8]}... ({chars}) - {session['last_update'][:16]}")

                try:
                    idx = int(input("请选择会话编号: ").strip()) - 1
                    if 0 <= idx < len(sessions):
                        session = sessions[idx]
                        session_id = session['session_id']
                        characters = session['characters']

                        # Show the most recent turns for context before resuming.
                        history = conv_mgr.get_conversation_history(session_id, 4)
                        if history:
                            print("\n最近的对话:")
                            for turn in history:
                                print(f"{turn.speaker}: {turn.content}")

                        topic = input("请输入对话主题(可选): ").strip()
                        turns = int(input("请输入对话轮次数量(默认1): ").strip() or "1")

                        print(f"\n继续对话 - 会话ID: {session_id}")
                        dialogue_engine.run_conversation_turn(session_id, characters, turns, topic)
                    else:
                        print("❌ 无效的会话编号")
                except ValueError:
                    print("❌ 请输入有效的数字")

            elif choice == '3':
                # Inspect a session's dialogue history by id prefix.
                session_id = input("请输入会话ID(前8位即可): ").strip()

                # Prefix match against all known sessions.
                sessions = conv_mgr.list_sessions()
                matching_session = None
                for session in sessions:
                    if session['session_id'].startswith(session_id):
                        matching_session = session
                        break

                if matching_session:
                    full_session_id = matching_session['session_id']
                    history = conv_mgr.get_conversation_history(full_session_id, 20)

                    if history:
                        print(f"\n对话历史 - {full_session_id}")
                        print(f"角色: {', '.join(matching_session['characters'])}")
                        print(f"世界观: {matching_session['worldview']}")
                        print("-" * 50)

                        for turn in history:
                            print(f"[{turn.timestamp[:16]}] {turn.speaker}:")
                            print(f"  {turn.content}")
                            if turn.context_used:
                                print(f"  使用上下文: {', '.join(turn.context_used)}")
                            print()
                    else:
                        print("该会话暂无对话历史")
                else:
                    print("❌ 未找到匹配的会话")

            elif choice == '4':
                # List every stored session.
                sessions = conv_mgr.list_sessions()
                if sessions:
                    print(f"\n共有 {len(sessions)} 个对话会话:")
                    for session in sessions:
                        chars = ", ".join(session['characters'])
                        print(f"ID: {session['session_id']}")
                        print(f"  角色: {chars}")
                        print(f"  世界观: {session['worldview']}")
                        print(f"  最后更新: {session['last_update']}")
                        print()
                else:
                    print("暂无对话会话")

            else:
                print("❌ 无效选择")

    except Exception as e:
        print(f"✗ 系统运行出错: {e}")
        import traceback
        traceback.print_exc()

if __name__ == '__main__':
    main()
#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
Game NPC dialogue generator.
Generates in-character dialogue from a causal LM, optionally with LoRA
fine-tuned weights applied on top.
'''

import torch
import json
import random
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Dict, List, Optional
import platform

# Windows multiprocessing compatibility fix: force the 'spawn' start method.
if platform.system() == "Windows":
    import multiprocessing
    multiprocessing.set_start_method('spawn', force=True)

class NPCDialogueGenerator:
    def __init__(self, base_model_path: str, lora_model_path: Optional[str] = None):
        """
        Initialize the NPC dialogue generator.

        Args:
            base_model_path: path to the base model
            lora_model_path: path to the LoRA adapter (optional)
        """
        self.base_model_path = base_model_path
        self.lora_model_path = lora_model_path
        self.model = None
        self.tokenizer = None
        # Static, in-code character profile table (see _load_character_profiles).
        self.character_profiles = self._load_character_profiles()

        self._load_model()

    def _load_character_profiles(self) -> Dict:
        """Return the built-in character profile table, keyed by character name."""
        return {
            "维多利亚·布莱克伍德": {
                "name": "维多利亚·布莱克伍德",
                "title": "神秘学专家",
                "personality": ["理性分析", "谨慎小心", "实用主义", "思维缜密"],
                "background": "拥有丰富神秘学知识和战斗经验的侦探,既是非凡者也是夏洛克·莫里亚蒂",
                "speech_patterns": ["会使用专业术语", "经常进行逻辑分析", "对危险保持警告", "内心独白较多"],
                "sample_dialogues": [
                    "好奇往往是导致死亡的主要因素。",
                    "总之,我的任务到此为止。",
                    "这需要仔细分析才能得出结论。"
                ]
            },
            "阿奇博尔德·韦恩博士": {
                "name": "阿奇博尔德·韦恩博士",
                "title": "神秘学导师",
                "personality": ["沉稳睿智", "言简意赅", "关怀学生", "经验丰富"],
                "background": "神秘学领域的资深专家,经验极其丰富的导师,知识渊博",
                "speech_patterns": ["话语简练但信息量大", "给予实用指导", "语调平和但权威", "关心但保持距离"],
                "sample_dialogues": [
                    "耐心是修炼的基础。",
                    "不要急于求成,稳扎稳打比什么都重要。",
                    "这种情况需要格外小心。"
                ]
            },
            "塔利姆": {
                "name": "塔利姆",
                "title": "文雅绅士",
                "personality": ["礼貌尊敬", "有文化素养", "寻求帮助", "温和友善"],
                "background": "受过良好教育的普通人,有一定的文学修养,遇到困难时会寻求专家帮助",
                "speech_patterns": ["使用礼貌称谓", "表达困惑时措辞文雅", "会引用文学作品", "语气温和"],
                "sample_dialogues": [
                    "噢,尊敬的大侦探,你最近在忙碌什么?",
                    "这不是《罗密欧与朱丽叶》的故事!",
                    "我有个朋友遇到了困难..."
                ]
            },
            "艾伦": {
                "name": "艾伦",
                "title": "困扰的求助者",
                "personality": ["焦虑不安", "详细描述", "半信半疑", "急需帮助"],
                "background": "普通人,但最近遭遇了一系列神秘的厄运事件,怀疑受到诅咒",
                "speech_patterns": ["情绪紧张", "会详细描述遭遇", "语气急切", "表现出恐惧"],
                "sample_dialogues": [
                    "最近我总是遭遇各种厄运...",
                    "我怀疑是不是受到了什么诅咒。",
                    "请帮帮我,我不知道该怎么办!"
                ]
            },
            # NOTE(review): the dict key uses an ASCII dot ("戴莉.西蒙妮") while the
            # display name uses a middle dot ("戴莉·西蒙妮") — confirm which spelling
            # callers pass, as lookups are by exact key.
            "戴莉.西蒙妮": {
                "name": "戴莉·西蒙妮",
                "title": "专业调查员",
                "personality": ["专业简洁", "直接明确", "严谨认真", "目标导向"],
                "background": "负责调查神秘事件的专业人员,办事效率高,问题直接",
                "speech_patterns": ["问题直接明确", "语气专业", "注重事实", "简洁有力"],
                "sample_dialogues": [
                    "请详细描述事件经过。",
                    "有什么证据可以证明?",
                    "这件事需要立即调查。"
                ]
            }
        }

    def _load_model(self) -> None:
        """Load the tokenizer and base model, then apply LoRA weights if configured."""
        print(f"Loading tokenizer from: {self.base_model_path}")
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.base_model_path,
            use_fast=False,
            trust_remote_code=True
        )

        # Some models ship without a pad token; reuse EOS so batching works.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        print(f"Loading base model from: {self.base_model_path}")
        self.model = AutoModelForCausalLM.from_pretrained(
            self.base_model_path,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True
        )

        # Apply LoRA adapter weights when a path was provided.
        if self.lora_model_path:
            print(f"Loading LoRA weights from: {self.lora_model_path}")
            self.model = PeftModel.from_pretrained(self.model, self.lora_model_path)

    def generate_character_dialogue(
        self,
        character_name: str,
        context: str = "",
        user_input: str = "",
        temperature: float = 0.8,
        max_new_tokens: int = 150,
        top_p: float = 0.9
    ) -> str:
        """
        Generate one dialogue line for the given character.

        Args:
            character_name: character name (must be a key in character_profiles)
            context: dialogue context / situation description
            user_input: user input / triggering content
            temperature: sampling temperature
            max_new_tokens: maximum number of generated tokens
            top_p: nucleus sampling parameter

        Returns:
            The generated dialogue text.

        Raises:
            ValueError: if character_name has no profile.
        """
        if character_name not in self.character_profiles:
            raise ValueError(f"Unknown character: {character_name}")

        profile = self.character_profiles[character_name]

        # Build the system prompt from the profile plus optional context.
        system_prompt = self._build_system_prompt(profile, context)

        # Fall back to a generic trigger when no user input was supplied.
        if not user_input:
            user_input = "请说一段符合你角色设定的话。"

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_input}
        ]

        # Apply the chat template; enable_thinking=False targets Qwen-style
        # templates — presumably to suppress the "thinking" block; confirm for
        # other model families.
        inputs = self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True,
            enable_thinking=False
        )

        # Move tensors onto the model's device.
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                pad_token_id=self.tokenizer.eos_token_id,
                repetition_penalty=1.1
            )

        # Strip the prompt tokens; decode only the newly generated tail.
        response = outputs[0][inputs['input_ids'].shape[1]:]
        dialogue = self.tokenizer.decode(response, skip_special_tokens=True).strip()

        return dialogue

    def _build_system_prompt(self, profile: Dict, context: str = "") -> str:
        """Render a profile (and optional situation) into a system prompt string."""
        personality_str = "、".join(profile["personality"])
        speech_pattern_str = ";".join(profile["speech_patterns"])

        system_prompt = f"""你是游戏中的NPC角色{profile["name"]}({profile["title"]})。
        角色背景:{profile["background"]}
        性格特点:{personality_str}
        说话风格:{speech_pattern_str}
        请严格按照这个角色的设定来回应,保持角色的一致性和独特性。"""
        if context:
            system_prompt += f"\n\n当前情境:{context}"
        return system_prompt

    def generate_dialogue_conversation(self, character1: str, character2: str, topic: str, turns: int = 4) -> List[Dict]:
        """Generate an alternating conversation between two characters.

        Args:
            character1: first character (speaks on even turns, starting the topic)
            character2: second character (speaks on odd turns)
            topic: conversation topic
            turns: number of utterances to generate

        Returns:
            List of dicts, each with "speaker" and "dialogue" keys.
        """
        conversation = []
        context = f"现在{character1}和{character2}在讨论关于{topic}的话题。"

        for turn in range(turns):
            if turn % 2 == 0:
                # character1's turn.
                speaker = character1
                if turn == 0:
                    user_input = f"开始和{character2}讨论{topic}这个话题。"
                else:
                    # React to the previous utterance.
                    last_dialogue = conversation[-1]["dialogue"]
                    user_input = f"{character2}刚才说:\"{last_dialogue}\"。请回应。"
            else:
                # character2's turn; always reacts to the previous utterance.
                speaker = character2
                last_dialogue = conversation[-1]["dialogue"]
                user_input = f"{character1}刚才说:\"{last_dialogue}\"。请回应。"

            dialogue = self.generate_character_dialogue(
                speaker, context, user_input, temperature=0.8
            )

            conversation.append({
                "speaker": speaker,
                "dialogue": dialogue
            })

        return conversation

    def get_character_info(self, character_name: str) -> Dict:
        """Return the profile dict for a character, or {} when unknown."""
        return self.character_profiles.get(character_name, {})

    def list_available_characters(self) -> List[str]:
        """Return the names of all characters with a profile."""
        return list(self.character_profiles.keys())
def main():
    """Smoke-test the dialogue generator.

    Runs scripted single-character scenarios, a two-character conversation,
    then an interactive prompt loop. (Defined at module level — in the original
    the function and the ``__main__`` guard were indented inside the class,
    which would have executed at class-definition time.)
    """
    # Model paths — adjust to your environment.
    base_model_path = '/mnt/g/Project02/AITrain/Qwen/Qwen3-8B-AWQ'
    lora_model_path = './output/NPC_Dialogue_LoRA/final_model'  # set to None if no LoRA was trained

    # Fall back to the base model when no LoRA adapter exists on disk.
    import os
    if not os.path.exists(lora_model_path):
        print("LoRA模型不存在,使用基础模型")
        lora_model_path = None

    generator = NPCDialogueGenerator(base_model_path, lora_model_path)

    print("=== 游戏NPC角色对话生成器 ===")
    print(f"可用角色:{', '.join(generator.list_available_characters())}")

    # Single-character generation test.
    print("\n=== 单角色对话测试 ===")
    # BUG FIX: the original scenarios referenced "克莱恩" and "阿兹克", which are
    # not keys in character_profiles, so generate_character_dialogue raised
    # ValueError. Use the actual profile names instead.
    test_scenarios = [
        {
            "character": "维多利亚·布莱克伍德",
            "context": "玩家向你咨询神秘学知识",
            "input": "请告诉我一些关于灵界的注意事项。"
        },
        {
            "character": "阿奇博尔德·韦恩博士",
            "context": "学生遇到了修炼瓶颈",
            "input": "导师,我在修炼中遇到了困难。"
        },
        {
            "character": "塔利姆",
            "context": "在俱乐部偶遇老朋友",
            "input": "好久不见,最近怎么样?"
        }
    ]

    for scenario in test_scenarios:
        print(f"\n--- {scenario['character']} ---")
        print(f"情境:{scenario['context']}")
        print(f"输入:{scenario['input']}")

        dialogue = generator.generate_character_dialogue(
            scenario["character"],
            scenario["context"],
            scenario["input"]
        )
        print(f"回复:{dialogue}")

    # Two-character conversation test (names fixed for the same reason).
    print("\n=== 角色间对话测试 ===")
    conversation = generator.generate_dialogue_conversation(
        "维多利亚·布莱克伍德", "塔利姆", "最近遇到的神秘事件", turns=4
    )

    for turn in conversation:
        print(f"{turn['speaker']}:{turn['dialogue']}")

    # Interactive mode: "<character> <context> [user input]" per line.
    print("\n=== 交互式对话模式 ===")
    print("输入格式:角色名 上下文 用户输入")
    print("例如:维多利亚·布莱克伍德 在俱乐部 请给我一些建议")
    print("输入'quit'退出")

    while True:
        try:
            user_command = input("\n请输入指令: ").strip()
            if user_command.lower() == 'quit':
                break

            # Split into at most 3 parts: character, context, optional input.
            parts = user_command.split(' ', 2)
            if len(parts) < 2:
                print("格式错误,请使用:角色名 上下文 [用户输入]")
                continue

            character = parts[0]
            context = parts[1]
            user_input = parts[2] if len(parts) > 2 else ""

            if character not in generator.list_available_characters():
                print(f"未知角色:{character}")
                print(f"可用角色:{', '.join(generator.list_available_characters())}")
                continue

            dialogue = generator.generate_character_dialogue(
                character, context, user_input
            )
            print(f"\n{character}:{dialogue}")

        except KeyboardInterrupt:
            break
        except Exception as e:
            print(f"生成对话时出错:{e}")

    print("\n对话生成器已退出")

if __name__ == '__main__':
    main()