添加对话生成

This commit is contained in:
997146918 2025-08-14 07:17:50 +08:00
parent 9206ef9352
commit bf6b7451cb
2 changed files with 952 additions and 0 deletions

View File

@ -0,0 +1,590 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
RAG增强的角色对话系统
集成世界观知识库支持角色设定加载和对话生成
'''
import json
import os
import sqlite3
from datetime import datetime
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass, asdict
import hashlib
# 尝试导入向量化相关库
try:
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
EMBEDDING_AVAILABLE = True
except ImportError:
EMBEDDING_AVAILABLE = False
@dataclass
class DialogueTurn:
    """A single utterance in a conversation: who spoke, what, and when."""
    speaker: str                    # name of the character who produced this turn
    content: str                    # the utterance text
    timestamp: str                  # ISO-format timestamp string (datetime.isoformat)
    context_used: List[str]         # labels of knowledge-base chunks used for this turn
    relevance_score: float = 0.0    # retrieval relevance of the context; 0.0 when none was used
@dataclass
class ConversationSession:
    """A whole conversation session between two (or more) characters."""
    session_id: str                         # short unique id (see ConversationManager.create_session)
    characters: List[str]                   # participating character names
    worldview: str                          # name of the worldview the session takes place in
    start_time: str                         # ISO-format timestamp of session creation
    last_update: str                        # ISO-format timestamp of the latest turn
    dialogue_history: List[DialogueTurn]    # ordered turns of the conversation
class RAGKnowledgeBase:
    """RAG knowledge-base manager.

    Loads ``worldview*.json`` and ``character*.json`` documents from a
    directory, flattens them into searchable text chunks and — when the
    optional sentence-transformers / faiss stack imported successfully —
    builds a vector index over those chunks.  Falls back to plain keyword
    matching when vectors are unavailable.
    """

    def __init__(self, knowledge_dir: str):
        """Load every knowledge file found under ``knowledge_dir``."""
        self.knowledge_dir = knowledge_dir
        self.worldview_data = None      # parsed worldview JSON (dict) or None
        self.character_data = {}        # character name -> parsed character JSON
        self.chunks = []                # flat list of retrievable text chunks
        self.embedding_model = None     # SentenceTransformer instance, when loaded
        self.index = None               # faiss index over chunk embeddings, when built
        # Load the embedding model only when the optional dependencies imported.
        if EMBEDDING_AVAILABLE:
            try:
                self.embedding_model = SentenceTransformer('./sentence-transformers/all-MiniLM-L6-v2')
                print("✓ 向量模型加载成功")
            except Exception as e:
                print(f"✗ 向量模型加载失败: {e}")
        self._load_knowledge_base()

    def _load_knowledge_base(self):
        """Load worldview + character JSON files, then build search structures."""
        # Worldview: use the first worldview*.json found (at most one expected).
        worldview_files = [f for f in os.listdir(self.knowledge_dir)
                           if f.startswith('worldview') and f.endswith('.json')]
        if worldview_files:
            worldview_path = os.path.join(self.knowledge_dir, worldview_files[0])
            with open(worldview_path, 'r', encoding='utf-8') as f:
                self.worldview_data = json.load(f)
            print(f"✓ 世界观加载成功: {self.worldview_data.get('worldview_name', '未知')}")
        # Characters: load every character*.json, keyed by its declared name
        # (falls back to the filename when 'character_name' is missing).
        character_files = [f for f in os.listdir(self.knowledge_dir)
                           if f.startswith('character') and f.endswith('.json')]
        for char_file in character_files:
            char_path = os.path.join(self.knowledge_dir, char_file)
            with open(char_path, 'r', encoding='utf-8') as f:
                char_data = json.load(f)
            char_name = char_data.get('character_name', char_file)
            self.character_data[char_name] = char_data
        print(f"✓ 角色加载成功: {list(self.character_data.keys())}")
        self._build_searchable_chunks()
        if EMBEDDING_AVAILABLE and self.embedding_model:
            self._build_vector_index()

    def _build_searchable_chunks(self):
        """Flatten the loaded JSON into text chunks suitable for retrieval."""
        self.chunks = []
        # Worldview chunks: two levels deep (section -> subsection), text only.
        if self.worldview_data:
            for section_key, section_data in self.worldview_data.items():
                if isinstance(section_data, dict):
                    for sub_key, sub_data in section_data.items():
                        if isinstance(sub_data, (str, list)):
                            content = str(sub_data)
                            if len(content) > 50:  # keep only substantial text
                                self.chunks.append({
                                    "type": "worldview",
                                    "section": section_key,
                                    "subsection": sub_key,
                                    "content": content,
                                    "metadata": {"source": "worldview"}
                                })
        # Character chunks: same shape, tagged with the owning character.
        for char_name, char_data in self.character_data.items():
            for section_key, section_data in char_data.items():
                if isinstance(section_data, dict):
                    for sub_key, sub_data in section_data.items():
                        if isinstance(sub_data, (str, list)):
                            content = str(sub_data)
                            if len(content) > 30:  # lower threshold for character text
                                self.chunks.append({
                                    "type": "character",
                                    "character": char_name,
                                    "section": section_key,
                                    "subsection": sub_key,
                                    "content": content,
                                    "metadata": {"source": char_name}
                                })
        print(f"✓ 构建文本块: {len(self.chunks)}")

    def _build_vector_index(self):
        """Encode every chunk and build a flat L2 faiss index over them."""
        if not self.chunks:
            # BUG FIX: encoding an empty text list would fail (or produce a
            # useless index); with no chunks there is nothing to index.
            return
        try:
            texts = [chunk["content"] for chunk in self.chunks]
            embeddings = self.embedding_model.encode(texts)
            dimension = embeddings.shape[1]
            self.index = faiss.IndexFlatL2(dimension)
            self.index.add(embeddings.astype(np.float32))
            print(f"✓ 向量索引构建成功: {dimension}维, {len(texts)}个向量")
        except Exception as e:
            print(f"✗ 向量索引构建失败: {e}")

    def search_relevant_context(self, query: str, character_name: Optional[str] = None, top_k: int = 3) -> List[Dict]:
        """Return up to ``top_k`` chunks relevant to ``query``.

        Uses vector search when the index is available, otherwise a keyword
        overlap count.  When ``character_name`` is given, that character's
        own chunks are moved to the front of the result list.
        """
        relevant_chunks = []
        # Primary path: semantic search over the faiss index.
        if EMBEDDING_AVAILABLE and self.embedding_model and self.index:
            try:
                query_vector = self.embedding_model.encode([query])
                # Over-fetch (2x) so character-priority reordering below still
                # has enough candidates after slicing to top_k.
                distances, indices = self.index.search(query_vector.astype(np.float32), top_k * 2)
                for distance, idx in zip(distances[0], indices[0]):
                    # BUG FIX: faiss pads missing results with -1; the old
                    # `idx < len(self.chunks)` check let -1 through, silently
                    # returning the *last* chunk. Require a valid index.
                    if 0 <= idx < len(self.chunks):
                        chunk = self.chunks[idx].copy()
                        # Convert L2 distance into a (0, 1] similarity score.
                        chunk["relevance_score"] = float(1 / (1 + distance))
                        relevant_chunks.append(chunk)
            except Exception as e:
                print(f"向量搜索失败: {e}")
        # Fallback path: naive keyword-frequency scoring.
        if not relevant_chunks:
            query_lower = query.lower()
            for chunk in self.chunks:
                content_lower = chunk["content"].lower()
                score = 0
                for word in query_lower.split():
                    if word in content_lower:
                        score += content_lower.count(word)
                if score > 0:
                    chunk_copy = chunk.copy()
                    chunk_copy["relevance_score"] = score
                    relevant_chunks.append(chunk_copy)
        # Most relevant first.
        relevant_chunks.sort(key=lambda x: x["relevance_score"], reverse=True)
        # Prefer chunks belonging to the requested character, keeping order
        # within each group.
        if character_name:
            char_chunks = [c for c in relevant_chunks if c.get("character") == character_name]
            other_chunks = [c for c in relevant_chunks if c.get("character") != character_name]
            relevant_chunks = char_chunks + other_chunks
        return relevant_chunks[:top_k]
class ConversationManager:
    """Persists conversation sessions and their dialogue turns in SQLite."""

    def __init__(self, db_path: str = "conversation_history.db"):
        """Open (creating if needed) the history database at ``db_path``."""
        self.db_path = db_path
        self._init_database()

    def _init_database(self):
        """Create the two backing tables; idempotent via IF NOT EXISTS."""
        with sqlite3.connect(self.db_path) as conn:
            conn.execute('''
                CREATE TABLE IF NOT EXISTS conversations (
                    session_id TEXT PRIMARY KEY,
                    characters TEXT,
                    worldview TEXT,
                    start_time TEXT,
                    last_update TEXT,
                    metadata TEXT
                )
            ''')
            conn.execute('''
                CREATE TABLE IF NOT EXISTS dialogue_turns (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    session_id TEXT,
                    turn_number INTEGER,
                    speaker TEXT,
                    content TEXT,
                    timestamp TEXT,
                    context_used TEXT,
                    relevance_score REAL,
                    FOREIGN KEY (session_id) REFERENCES conversations (session_id)
                )
            ''')
            conn.commit()

    def create_session(self, characters: List[str], worldview: str) -> str:
        """Create a new conversation session and return its 12-char id.

        The id is the truncated MD5 of the character list plus the creation
        timestamp (uniqueness only — not security sensitive).
        """
        # BUG FIX: the original called datetime.now() three separate times, so
        # the hashed timestamp, start_time and last_update all differed by
        # microseconds. Capture a single timestamp and use it everywhere.
        now = datetime.now().isoformat()
        session_id = hashlib.md5(f"{'-'.join(characters)}-{now}".encode()).hexdigest()[:12]
        with sqlite3.connect(self.db_path) as conn:
            conn.execute(
                "INSERT INTO conversations (session_id, characters, worldview, start_time, last_update) VALUES (?, ?, ?, ?, ?)",
                (session_id, json.dumps(characters), worldview, now, now)
            )
            conn.commit()
        print(f"✓ 创建对话会话: {session_id}")
        return session_id

    def add_dialogue_turn(self, session_id: str, speaker: str, content: str, context_used: Optional[List[str]] = None, relevance_score: float = 0.0):
        """Append one turn to a session and bump its last_update time."""
        if context_used is None:
            context_used = []
        with sqlite3.connect(self.db_path) as conn:
            # Turn numbers are 1-based and derived from the current count.
            cursor = conn.execute("SELECT COUNT(*) FROM dialogue_turns WHERE session_id = ?", (session_id,))
            turn_number = cursor.fetchone()[0] + 1
            conn.execute(
                """INSERT INTO dialogue_turns
                   (session_id, turn_number, speaker, content, timestamp, context_used, relevance_score)
                   VALUES (?, ?, ?, ?, ?, ?, ?)""",
                (session_id, turn_number, speaker, content, datetime.now().isoformat(),
                 json.dumps(context_used), relevance_score)
            )
            conn.execute(
                "UPDATE conversations SET last_update = ? WHERE session_id = ?",
                (datetime.now().isoformat(), session_id)
            )
            conn.commit()

    def get_conversation_history(self, session_id: str, last_n: int = 10) -> "List[DialogueTurn]":
        """Return the most recent ``last_n`` turns, oldest first."""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.execute(
                """SELECT speaker, content, timestamp, context_used, relevance_score
                   FROM dialogue_turns
                   WHERE session_id = ?
                   ORDER BY turn_number DESC LIMIT ?""",
                (session_id, last_n)
            )
            turns = []
            for row in cursor.fetchall():
                speaker, content, timestamp, context_used, relevance_score = row
                turns.append(DialogueTurn(
                    speaker=speaker,
                    content=content,
                    timestamp=timestamp,
                    context_used=json.loads(context_used or "[]"),
                    relevance_score=relevance_score
                ))
            # Query fetched newest-first; reverse to chronological order.
            return list(reversed(turns))

    def list_sessions(self) -> List[Dict]:
        """Return all sessions as dicts, most recently updated first."""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.execute(
                "SELECT session_id, characters, worldview, start_time, last_update FROM conversations ORDER BY last_update DESC"
            )
            sessions = []
            for row in cursor.fetchall():
                session_id, characters, worldview, start_time, last_update = row
                sessions.append({
                    "session_id": session_id,
                    "characters": json.loads(characters),
                    "worldview": worldview,
                    "start_time": start_time,
                    "last_update": last_update
                })
            return sessions
class DualAIDialogueEngine:
    """Orchestrates two AI characters talking to each other.

    Combines the knowledge base (retrieval), the conversation manager
    (persistence) and an LLM generator (text generation) into per-turn
    dialogue generation.
    """
    def __init__(self, knowledge_base: RAGKnowledgeBase, conversation_manager: ConversationManager, llm_generator):
        # llm_generator must expose generate_character_dialogue(name, prompt,
        # user_input, temperature=..., max_new_tokens=...) — see usage below.
        self.kb = knowledge_base
        self.conv_mgr = conversation_manager
        self.llm_generator = llm_generator
    def generate_character_prompt(self, character_name: str, context_info: List[Dict], dialogue_history: List[DialogueTurn]) -> str:
        """Assemble the persona + context + recent-history prompt for one character.

        Sections are only included when the character data / context / history
        actually provide them; the result is a newline-joined string.
        """
        char_data = self.kb.character_data.get(character_name, {})
        # Base persona header.
        prompt_parts = []
        prompt_parts.append(f"你是{character_name},具有以下设定:")
        if char_data.get('personality', {}).get('core_traits'):
            traits = ", ".join(char_data['personality']['core_traits'])
            prompt_parts.append(f"性格特点:{traits}")
        if char_data.get('speech_patterns', {}).get('sample_phrases'):
            # At most three sample phrases to keep the prompt short.
            phrases = char_data['speech_patterns']['sample_phrases'][:3]
            prompt_parts.append(f"说话风格示例:{'; '.join(phrases)}")
        # Current situation / mood, when present in the character file.
        if char_data.get('current_situation'):
            situation = char_data['current_situation']
            prompt_parts.append(f"当前状态:{situation.get('current_mood', '')}")
        # Retrieved worldview/character background, truncated to 200 chars each.
        if context_info:
            prompt_parts.append("相关背景信息:")
            for info in context_info[:2]:  # only the two most relevant chunks
                content = info['content'][:200] + "..." if len(info['content']) > 200 else info['content']
                prompt_parts.append(f"- {content}")
        # Last few turns of conversation for continuity.
        if dialogue_history:
            prompt_parts.append("最近的对话:")
            for turn in dialogue_history[-3:]:  # only the three most recent turns
                prompt_parts.append(f"{turn.speaker}: {turn.content}")
        prompt_parts.append("\n请根据角色设定和上下文生成符合角色特点的自然对话。回复应该在50-150字之间。")
        return "\n".join(prompt_parts)
    def generate_dialogue(self, session_id: str, current_speaker: str, topic_hint: str = "") -> Tuple[str, List[str]]:
        """Generate and persist one utterance for ``current_speaker``.

        Returns (response_text, context_labels_used). On generation failure,
        returns a placeholder line and an empty label list instead of raising.
        """
        dialogue_history = self.conv_mgr.get_conversation_history(session_id)
        # Build the retrieval query from recent conversation (or a seed query
        # for the very first turn).
        if dialogue_history:
            recent_content = " ".join([turn.content for turn in dialogue_history[-2:]])
            search_query = recent_content + " " + topic_hint
        else:
            search_query = f"{current_speaker} {topic_hint} introduction greeting"
        # Retrieve relevant knowledge, prioritising the current speaker.
        context_info = self.kb.search_relevant_context(search_query, current_speaker, 10)
        prompt = self.generate_character_prompt(current_speaker, context_info, dialogue_history)
        try:
            response = self.llm_generator.generate_character_dialogue(
                current_speaker,
                prompt,
                topic_hint or "请继续对话",
                temperature=0.8,
                max_new_tokens=150
            )
            # Record which chunks contributed, as "section.subsection" labels.
            context_used = [f"{info['section']}.{info['subsection']}" for info in context_info]
            avg_relevance = sum(info['relevance_score'] for info in context_info) / len(context_info) if context_info else 0.0
            self.conv_mgr.add_dialogue_turn(
                session_id, current_speaker, response, context_used, avg_relevance
            )
            return response, context_used
        except Exception as e:
            # Best-effort: report and return a placeholder so the loop continues.
            print(f"✗ 对话生成失败: {e}")
            return f"[{current_speaker}暂时无法回应]", []
    def run_conversation_turn(self, session_id: str, characters: List[str], turns_count: int = 1, topic: str = ""):
        """Run ``turns_count`` rounds where each character speaks once per round.

        Prints each utterance (plus the context labels used) and returns a
        list of result dicts.
        """
        results = []
        for i in range(turns_count):
            for char in characters:
                response, context_used = self.generate_dialogue(session_id, char, topic)
                results.append({
                    "speaker": char,
                    "content": response,
                    "context_used": context_used,
                    "turn": i + 1
                })
                print(f"{char}: {response}")
                if context_used:
                    print(f"  [使用上下文: {', '.join(context_used)}]")
                print()
        return results
def _prompt_int(prompt: str, default: int) -> int:
    """Read an integer from stdin; return ``default`` on blank or invalid input."""
    raw = input(prompt).strip()
    if not raw:
        return default
    try:
        return int(raw)
    except ValueError:
        # BUG FIX: the original int(input(...)) let ValueError escape to the
        # outer handler, aborting the whole menu loop on a single typo.
        print(f"❌ 无效数字使用默认值 {default}")
        return default

def main():
    """Interactive demo: menu-driven dual-AI role-play dialogue system.

    Wires together RAGKnowledgeBase, ConversationManager and the
    NPCDialogueGenerator-backed DualAIDialogueEngine, then loops on a
    text menu (create / continue / inspect / list sessions).
    """
    print("=== RAG增强双AI角色对话系统 ===")
    # Directory containing the worldview and character JSON documents.
    knowledge_dir = "./knowledge_base"
    # Fail early with guidance when the knowledge directory is missing.
    required_dirs = [knowledge_dir]
    for dir_path in required_dirs:
        if not os.path.exists(dir_path):
            print(f"✗ 目录不存在: {dir_path}")
            print("请确保以下文件存在:")
            print("- ./knowledge_base/worldview_template_coc.json")
            print("- ./knowledge_base/character_template_detective.json")
            print("- ./knowledge_base/character_template_professor.json")
            return
    try:
        # Initialise system components.
        print("\n初始化系统...")
        kb = RAGKnowledgeBase(knowledge_dir)
        conv_mgr = ConversationManager()
        # LLM backend: the existing NPCDialogueGenerator (local module).
        from npc_dialogue_generator import NPCDialogueGenerator
        base_model_path = '/mnt/g/Project02/AITrain/Qwen/Qwen3-4B'  # adjust to your path
        lora_model_path = './output/NPC_Dialogue_LoRA/final_model'
        if not os.path.exists(lora_model_path):
            lora_model_path = None  # fall back to the base model
        llm_generator = NPCDialogueGenerator(base_model_path, lora_model_path)
        dialogue_engine = DualAIDialogueEngine(kb, conv_mgr, llm_generator)
        print("✓ 系统初始化完成")
        # Interactive menu loop.
        while True:
            print("\n" + "="*50)
            print("双AI角色对话系统")
            print("1. 创建新对话")
            print("2. 继续已有对话")
            print("3. 查看对话历史")
            print("4. 列出所有会话")
            print("0. 退出")
            print("="*50)
            choice = input("请选择操作: ").strip()
            if choice == '0':
                break
            elif choice == '1':
                # Create a new two-character conversation.
                print(f"可用角色: {list(kb.character_data.keys())}")
                characters = input("请输入两个角色名称(用空格分隔): ").strip().split()
                if len(characters) != 2:
                    print("❌ 请输入正好两个角色名称")
                    continue
                worldview = kb.worldview_data.get('worldview_name', '未知世界观') if kb.worldview_data else '未知世界观'
                session_id = conv_mgr.create_session(characters, worldview)
                topic = input("请输入对话主题(可选): ").strip()
                turns = _prompt_int("请输入对话轮次数量默认2: ", 2)
                print(f"\n开始对话 - 会话ID: {session_id}")
                dialogue_engine.run_conversation_turn(session_id, characters, turns, topic)
            elif choice == '2':
                # Continue an existing conversation.
                sessions = conv_mgr.list_sessions()
                if not sessions:
                    print("❌ 没有已有对话")
                    continue
                print("已有会话:")
                for i, session in enumerate(sessions[:5]):
                    chars = ", ".join(session['characters'])
                    print(f"{i+1}. {session['session_id'][:8]}... ({chars}) - {session['last_update'][:16]}")
                try:
                    idx = int(input("请选择会话编号: ").strip()) - 1
                    if 0 <= idx < len(sessions):
                        session = sessions[idx]
                        session_id = session['session_id']
                        characters = session['characters']
                        # Show the most recent turns for context.
                        history = conv_mgr.get_conversation_history(session_id, 4)
                        if history:
                            print("\n最近的对话:")
                            for turn in history:
                                print(f"{turn.speaker}: {turn.content}")
                        topic = input("请输入对话主题(可选): ").strip()
                        turns = _prompt_int("请输入对话轮次数量默认1: ", 1)
                        print(f"\n继续对话 - 会话ID: {session_id}")
                        dialogue_engine.run_conversation_turn(session_id, characters, turns, topic)
                    else:
                        print("❌ 无效的会话编号")
                except ValueError:
                    print("❌ 请输入有效的数字")
            elif choice == '3':
                # Inspect the history of one session (prefix match on the id).
                session_id = input("请输入会话ID前8位即可: ").strip()
                sessions = conv_mgr.list_sessions()
                matching_session = None
                for session in sessions:
                    if session['session_id'].startswith(session_id):
                        matching_session = session
                        break
                if matching_session:
                    full_session_id = matching_session['session_id']
                    history = conv_mgr.get_conversation_history(full_session_id, 20)
                    if history:
                        print(f"\n对话历史 - {full_session_id}")
                        print(f"角色: {', '.join(matching_session['characters'])}")
                        print(f"世界观: {matching_session['worldview']}")
                        print("-" * 50)
                        for turn in history:
                            print(f"[{turn.timestamp[:16]}] {turn.speaker}:")
                            print(f"  {turn.content}")
                            if turn.context_used:
                                print(f"  使用上下文: {', '.join(turn.context_used)}")
                            print()
                    else:
                        print("该会话暂无对话历史")
                else:
                    print("❌ 未找到匹配的会话")
            elif choice == '4':
                # List every stored session.
                sessions = conv_mgr.list_sessions()
                if sessions:
                    print(f"\n共有 {len(sessions)} 个对话会话:")
                    for session in sessions:
                        chars = ", ".join(session['characters'])
                        print(f"ID: {session['session_id']}")
                        print(f"  角色: {chars}")
                        print(f"  世界观: {session['worldview']}")
                        print(f"  最后更新: {session['last_update']}")
                        print()
                else:
                    print("暂无对话会话")
            else:
                print("❌ 无效选择")
    except Exception as e:
        # Top-level boundary: report and dump the traceback, then exit cleanly.
        print(f"✗ 系统运行出错: {e}")
        import traceback
        traceback.print_exc()

if __name__ == '__main__':
    main()

View File

@ -0,0 +1,362 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
游戏NPC角色对话生成器
基于微调后的LoRA模型生成角色对话
'''
import torch
import json
import random
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Dict, List, Optional
import platform
# Windows multiprocessing compatibility fix: force the 'spawn' start method
# (the only one available on Windows); force=True overrides any method that
# was already set earlier in the process.
if platform.system() == "Windows":
    import multiprocessing
    multiprocessing.set_start_method('spawn', force=True)
class NPCDialogueGenerator:
    """Generates in-character NPC dialogue from a causal LM, optionally
    augmented with LoRA adapter weights."""

    def __init__(self, base_model_path: str, lora_model_path: Optional[str] = None):
        """
        Initialise the NPC dialogue generator.

        Args:
            base_model_path: path to the base (foundation) model.
            lora_model_path: path to LoRA adapter weights (optional; when
                None, only the base model is used).
        """
        self.base_model_path = base_model_path
        self.lora_model_path = lora_model_path
        self.model = None        # set by _load_model()
        self.tokenizer = None    # set by _load_model()
        self.character_profiles = self._load_character_profiles()
        self._load_model()

    def _load_character_profiles(self) -> Dict:
        """Return the built-in character profile table (name -> profile dict).

        Each profile carries: name, title, personality traits, background,
        speech patterns and sample dialogue lines — all consumed by
        _build_system_prompt().
        """
        return {
            "维多利亚·布莱克伍德": {
                "name": "维多利亚·布莱克伍德",
                "title": "神秘学专家",
                "personality": ["理性分析", "谨慎小心", "实用主义", "思维缜密"],
                "background": "拥有丰富神秘学知识和战斗经验的侦探,既是非凡者也是夏洛克·莫里亚蒂",
                "speech_patterns": ["会使用专业术语", "经常进行逻辑分析", "对危险保持警告", "内心独白较多"],
                "sample_dialogues": [
                    "好奇往往是导致死亡的主要因素。",
                    "总之,我的任务到此为止。",
                    "这需要仔细分析才能得出结论。"
                ]
            },
            "阿奇博尔德·韦恩博士": {
                "name": "阿奇博尔德·韦恩博士",
                "title": "神秘学导师",
                "personality": ["沉稳睿智", "言简意赅", "关怀学生", "经验丰富"],
                "background": "神秘学领域的资深专家,经验极其丰富的导师,知识渊博",
                "speech_patterns": ["话语简练但信息量大", "给予实用指导", "语调平和但权威", "关心但保持距离"],
                "sample_dialogues": [
                    "耐心是修炼的基础。",
                    "不要急于求成,稳扎稳打比什么都重要。",
                    "这种情况需要格外小心。"
                ]
            },
            "塔利姆": {
                "name": "塔利姆",
                "title": "文雅绅士",
                "personality": ["礼貌尊敬", "有文化素养", "寻求帮助", "温和友善"],
                "background": "受过良好教育的普通人,有一定的文学修养,遇到困难时会寻求专家帮助",
                "speech_patterns": ["使用礼貌称谓", "表达困惑时措辞文雅", "会引用文学作品", "语气温和"],
                "sample_dialogues": [
                    "噢,尊敬的大侦探,你最近在忙碌什么?",
                    "这不是《罗密欧与朱丽叶》的故事!",
                    "我有个朋友遇到了困难..."
                ]
            },
            "艾伦": {
                "name": "艾伦",
                "title": "困扰的求助者",
                "personality": ["焦虑不安", "详细描述", "半信半疑", "急需帮助"],
                "background": "普通人,但最近遭遇了一系列神秘的厄运事件,怀疑受到诅咒",
                "speech_patterns": ["情绪紧张", "会详细描述遭遇", "语气急切", "表现出恐惧"],
                "sample_dialogues": [
                    "最近我总是遭遇各种厄运...",
                    "我怀疑是不是受到了什么诅咒。",
                    "请帮帮我,我不知道该怎么办!"
                ]
            },
            # NOTE(review): the dict key uses an ASCII '.' while the profile's
            # display name uses '·' — lookups must use the '.' form; confirm
            # this mismatch is intentional.
            "戴莉.西蒙妮": {
                "name": "戴莉·西蒙妮",
                "title": "专业调查员",
                "personality": ["专业简洁", "直接明确", "严谨认真", "目标导向"],
                "background": "负责调查神秘事件的专业人员,办事效率高,问题直接",
                "speech_patterns": ["问题直接明确", "语气专业", "注重事实", "简洁有力"],
                "sample_dialogues": [
                    "请详细描述事件经过。",
                    "有什么证据可以证明?",
                    "这件事需要立即调查。"
                ]
            }
        }

    def _load_model(self):
        """Load tokenizer and base model; attach LoRA weights when configured."""
        print(f"Loading tokenizer from: {self.base_model_path}")
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.base_model_path,
            use_fast=False,
            trust_remote_code=True
        )
        # Some models ship without a pad token; reuse EOS so generation can pad.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        print(f"Loading base model from: {self.base_model_path}")
        self.model = AutoModelForCausalLM.from_pretrained(
            self.base_model_path,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True
        )
        # Optionally wrap the base model with the fine-tuned LoRA adapter.
        if self.lora_model_path:
            print(f"Loading LoRA weights from: {self.lora_model_path}")
            self.model = PeftModel.from_pretrained(self.model, self.lora_model_path)

    def generate_character_dialogue(
        self,
        character_name: str,
        context: str = "",
        user_input: str = "",
        temperature: float = 0.8,
        max_new_tokens: int = 150,
        top_p: float = 0.9
    ) -> str:
        """
        Generate one utterance for the named character.

        Args:
            character_name: character name; must exist in character_profiles.
            context: situational context injected into the system prompt.
            user_input: user input / trigger text (a default request is used
                when empty).
            temperature: sampling temperature.
            max_new_tokens: maximum number of tokens to generate.
            top_p: nucleus-sampling parameter.

        Returns:
            The generated dialogue text (special tokens stripped).

        Raises:
            ValueError: when character_name is not in the profile table.
        """
        if character_name not in self.character_profiles:
            raise ValueError(f"Unknown character: {character_name}")
        profile = self.character_profiles[character_name]
        # Persona + context go into the system prompt.
        system_prompt = self._build_system_prompt(profile, context)
        # Fall back to a generic request when no user input was given.
        if not user_input:
            user_input = "请说一段符合你角色设定的话。"
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_input}
        ]
        # Apply the model's chat template; enable_thinking=False suppresses
        # the "thinking" mode (model-specific template kwarg).
        inputs = self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True,
            enable_thinking=False
        )
        # Move all input tensors onto the model's device.
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                pad_token_id=self.tokenizer.eos_token_id,
                repetition_penalty=1.1
            )
        # Slice off the prompt tokens; decode only the newly generated tail.
        response = outputs[0][inputs['input_ids'].shape[1]:]
        dialogue = self.tokenizer.decode(response, skip_special_tokens=True).strip()
        return dialogue

    def _build_system_prompt(self, profile: Dict, context: str = "") -> str:
        """Render a profile (plus optional situational context) into the
        system prompt string."""
        # NOTE(review): traits/patterns are joined with NO separator, producing
        # one long run of text — possibly a dropped "、" delimiter; confirm the
        # intended prompt format.
        personality_str = "".join(profile["personality"])
        speech_pattern_str = "".join(profile["speech_patterns"])
        system_prompt = f"""你是游戏中的NPC角色{profile["name"]}{profile["title"]})。
角色背景{profile["background"]}
性格特点{personality_str}
说话风格{speech_pattern_str}
请严格按照这个角色的设定来回应保持角色的一致性和独特性"""
        if context:
            system_prompt += f"\n\n当前情境:{context}"
        return system_prompt

    def generate_dialogue_conversation(self, character1: str, character2: str, topic: str, turns: int = 4) -> List[Dict]:
        """Generate an alternating conversation between two characters.

        Args:
            character1: first character (speaks on even turns, starting).
            character2: second character (speaks on odd turns).
            topic: conversation topic woven into the shared context.
            turns: total number of utterances to generate.

        Returns:
            A list of dicts, each with "speaker" and "dialogue" keys.
        """
        conversation = []
        context = f"现在{character1}和{character2}在讨论关于{topic}的话题。"
        for turn in range(turns):
            if turn % 2 == 0:
                # character1's turn.
                speaker = character1
                if turn == 0:
                    # Opening line: no previous utterance to respond to.
                    user_input = f"开始和{character2}讨论{topic}这个话题。"
                else:
                    # Respond to the other character's last line.
                    last_dialogue = conversation[-1]["dialogue"]
                    user_input = f"{character2}刚才说:\"{last_dialogue}\"。请回应。"
            else:
                # character2's turn — always responds to the previous line.
                speaker = character2
                last_dialogue = conversation[-1]["dialogue"]
                user_input = f"{character1}刚才说:\"{last_dialogue}\"。请回应。"
            dialogue = self.generate_character_dialogue(
                speaker, context, user_input, temperature=0.8
            )
            conversation.append({
                "speaker": speaker,
                "dialogue": dialogue
            })
        return conversation

    def get_character_info(self, character_name: str) -> Dict:
        """Return the profile dict for a character ({} when unknown)."""
        return self.character_profiles.get(character_name, {})

    def list_available_characters(self) -> List[str]:
        """Return the names of all built-in characters."""
        return list(self.character_profiles.keys())
def main():
    """Smoke-test the dialogue generator.

    Runs three phases: single-character generation for a few scenarios,
    a two-character conversation, and an interactive prompt loop.
    """
    # Model paths.
    base_model_path = '/mnt/g/Project02/AITrain/Qwen/Qwen3-8B-AWQ'
    lora_model_path = './output/NPC_Dialogue_LoRA/final_model'  # set to None if no LoRA was trained
    # Fall back to the base model when the LoRA checkpoint is absent.
    import os
    if not os.path.exists(lora_model_path):
        print("LoRA模型不存在使用基础模型")
        lora_model_path = None
    generator = NPCDialogueGenerator(base_model_path, lora_model_path)
    print("=== 游戏NPC角色对话生成器 ===")
    print(f"可用角色:{', '.join(generator.list_available_characters())}")
    available = set(generator.list_available_characters())
    # Phase 1: single-character generation.
    print("\n=== 单角色对话测试 ===")
    test_scenarios = [
        {
            "character": "克莱恩",
            "context": "玩家向你咨询神秘学知识",
            "input": "请告诉我一些关于灵界的注意事项。"
        },
        {
            "character": "阿兹克",
            "context": "学生遇到了修炼瓶颈",
            "input": "导师,我在修炼中遇到了困难。"
        },
        {
            "character": "塔利姆",
            "context": "在俱乐部偶遇老朋友",
            "input": "好久不见,最近怎么样?"
        }
    ]
    for scenario in test_scenarios:
        print(f"\n--- {scenario['character']} ---")
        # BUG FIX: several scenario characters (克莱恩, 阿兹克) are not in the
        # profile table, so generate_character_dialogue raised ValueError and
        # crashed the demo. Skip unknown characters instead.
        if scenario["character"] not in available:
            print(f"未知角色:{scenario['character']},跳过该场景")
            continue
        print(f"情境:{scenario['context']}")
        print(f"输入:{scenario['input']}")
        dialogue = generator.generate_character_dialogue(
            scenario["character"],
            scenario["context"],
            scenario["input"]
        )
        print(f"回复:{dialogue}")
    # Phase 2: two-character conversation (skipped when a participant is unknown
    # — same ValueError crash as above otherwise).
    print("\n=== 角色间对话测试 ===")
    pair = ("克莱恩", "塔利姆")
    if all(name in available for name in pair):
        conversation = generator.generate_dialogue_conversation(
            pair[0], pair[1], "最近遇到的神秘事件", turns=4
        )
        for turn in conversation:
            print(f"{turn['speaker']}{turn['dialogue']}")
    else:
        print(f"角色不可用,跳过角色间对话测试:{', '.join(pair)}")
    # Phase 3: interactive loop.
    print("\n=== 交互式对话模式 ===")
    print("输入格式:角色名 上下文 用户输入")
    print("例如:克莱恩 在俱乐部 请给我一些建议")
    print("输入'quit'退出")
    while True:
        try:
            user_command = input("\n请输入指令: ").strip()
            if user_command.lower() == 'quit':
                break
            # Split into at most three fields: character, context, input.
            parts = user_command.split(' ', 2)
            if len(parts) < 2:
                print("格式错误,请使用:角色名 上下文 [用户输入]")
                continue
            character = parts[0]
            context = parts[1]
            user_input = parts[2] if len(parts) > 2 else ""
            if character not in generator.list_available_characters():
                print(f"未知角色:{character}")
                print(f"可用角色:{', '.join(generator.list_available_characters())}")
                continue
            dialogue = generator.generate_character_dialogue(
                character, context, user_input
            )
            print(f"\n{character}{dialogue}")
        except KeyboardInterrupt:
            break
        except Exception as e:
            # Best-effort loop: report and keep prompting.
            print(f"生成对话时出错:{e}")
    print("\n对话生成器已退出")

if __name__ == '__main__':
    main()