调整双ai对话

This commit is contained in:
997146918 2025-08-18 09:55:18 +08:00
parent 82427b08ce
commit c5b81be7fd
3 changed files with 272 additions and 272 deletions

View File

@ -613,232 +613,232 @@ class DualAIDialogueEngine:
return conversation_results return conversation_results
def main(): # def main():
"""主函数 - 演示系统使用""" # """主函数 - 演示系统使用"""
print("=== RAG增强双AI角色对话系统 ===") # print("=== RAG增强双AI角色对话系统 ===")
# 设置路径 # # 设置路径
knowledge_dir = "./knowledge_base" # 包含世界观和角色文档的目录 # knowledge_dir = "./knowledge_base" # 包含世界观和角色文档的目录
# 检查必要文件 # # 检查必要文件
required_dirs = [knowledge_dir] # required_dirs = [knowledge_dir]
for dir_path in required_dirs: # for dir_path in required_dirs:
if not os.path.exists(dir_path): # if not os.path.exists(dir_path):
print(f"✗ 目录不存在: {dir_path}") # print(f"✗ 目录不存在: {dir_path}")
print("请确保以下文件存在:") # print("请确保以下文件存在:")
print("- ./knowledge_base/worldview_template_coc.json") # print("- ./knowledge_base/worldview_template_coc.json")
print("- ./knowledge_base/character_template_detective.json") # print("- ./knowledge_base/character_template_detective.json")
print("- ./knowledge_base/character_template_professor.json") # print("- ./knowledge_base/character_template_professor.json")
return # return
try: # try:
# 初始化系统组件 # # 初始化系统组件
print("\n初始化系统...") # print("\n初始化系统...")
kb = RAGKnowledgeBase(knowledge_dir) # kb = RAGKnowledgeBase(knowledge_dir)
conv_mgr = ConversationManager() # conv_mgr = ConversationManager()
# 这里需要你的LLM生成器使用新的双模型对话系统 # # 这里需要你的LLM生成器使用新的双模型对话系统
from npc_dialogue_generator import DualModelDialogueGenerator # from npc_dialogue_generator import DualModelDialogueGenerator
base_model_path = '/mnt/g/Project02/AITrain/Qwen/Qwen3-4B' # 根据你的路径调整 # base_model_path = '/mnt/g/Project02/AITrain/Qwen/Qwen3-4B' # 根据你的路径调整
lora_model_path = './output/NPC_Dialogue_LoRA/final_model' # lora_model_path = './output/NPC_Dialogue_LoRA/final_model'
if not os.path.exists(lora_model_path): # if not os.path.exists(lora_model_path):
lora_model_path = None # lora_model_path = None
# 创建双模型对话生成器 # # 创建双模型对话生成器
if hasattr(kb, 'character_data') and len(kb.character_data) >= 2: # if hasattr(kb, 'character_data') and len(kb.character_data) >= 2:
print("✓ 使用knowledge_base角色数据创建双模型对话系统") # print("✓ 使用knowledge_base角色数据创建双模型对话系统")
# 获取前两个角色 # # 获取前两个角色
character_names = list(kb.character_data.keys())[:2] # character_names = list(kb.character_data.keys())[:2]
char1_name = character_names[0] # char1_name = character_names[0]
char2_name = character_names[1] # char2_name = character_names[1]
# 配置两个角色的模型 # # 配置两个角色的模型
character1_config = { # character1_config = {
"name": char1_name, # "name": char1_name,
"lora_path": lora_model_path, # 可以为每个角色设置不同的LoRA # "lora_path": lora_model_path, # 可以为每个角色设置不同的LoRA
"character_data": kb.character_data[char1_name] # "character_data": kb.character_data[char1_name]
} # }
character2_config = { # character2_config = {
"name": char2_name, # "name": char2_name,
"lora_path": lora_model_path, # 可以为每个角色设置不同的LoRA # "lora_path": lora_model_path, # 可以为每个角色设置不同的LoRA
"character_data": kb.character_data[char2_name] # "character_data": kb.character_data[char2_name]
} # }
llm_generator = DualModelDialogueGenerator( # llm_generator = DualModelDialogueGenerator(
base_model_path, # base_model_path,
character1_config, # character1_config,
character2_config # character2_config
) # )
else: # else:
print("⚠ 角色数据不足,无法创建双模型对话系统") # print("⚠ 角色数据不足,无法创建双模型对话系统")
return # return
# 创建对话引擎 # # 创建对话引擎
dialogue_engine = DualAIDialogueEngine(kb, conv_mgr, llm_generator) # dialogue_engine = DualAIDialogueEngine(kb, conv_mgr, llm_generator)
print("✓ 系统初始化完成") # print("✓ 系统初始化完成")
# 交互式菜单 # # 交互式菜单
while True: # while True:
print("\n" + "="*50) # print("\n" + "="*50)
print("双AI角色对话系统") # print("双AI角色对话系统")
print("1. 创建新对话") # print("1. 创建新对话")
print("2. 继续已有对话") # print("2. 继续已有对话")
print("3. 查看对话历史") # print("3. 查看对话历史")
print("4. 列出所有会话") # print("4. 列出所有会话")
print("0. 退出") # print("0. 退出")
print("="*50) # print("="*50)
choice = input("请选择操作: ").strip() # choice = input("请选择操作: ").strip()
if choice == '0': # if choice == '0':
break # break
elif choice == '1': # elif choice == '1':
# 创建新对话 # # 创建新对话
print(f"可用角色: {list(kb.character_data.keys())}") # print(f"可用角色: {list(kb.character_data.keys())}")
characters = input("请输入两个角色名称(用空格分隔): ").strip().split() # characters = input("请输入两个角色名称(用空格分隔): ").strip().split()
if len(characters) != 2: # if len(characters) != 2:
print("❌ 请输入正好两个角色名称") # print("❌ 请输入正好两个角色名称")
continue # continue
worldview = kb.worldview_data.get('worldview_name', '未知世界观') if kb.worldview_data else '未知世界观' # worldview = kb.worldview_data.get('worldview_name', '未知世界观') if kb.worldview_data else '未知世界观'
session_id = conv_mgr.create_session(characters, worldview) # session_id = conv_mgr.create_session(characters, worldview)
topic = input("请输入对话主题(可选): ").strip() # topic = input("请输入对话主题(可选): ").strip()
turns = int(input("请输入对话轮次数量默认2: ").strip() or "2") # turns = int(input("请输入对话轮次数量默认2: ").strip() or "2")
# 历史上下文控制选项 # # 历史上下文控制选项
print("\n历史上下文设置:") # print("\n历史上下文设置:")
history_count = input("使用历史对话轮数默认30表示不使用: ").strip() # history_count = input("使用历史对话轮数默认30表示不使用: ").strip()
history_count = int(history_count) if history_count.isdigit() else 3 # history_count = int(history_count) if history_count.isdigit() else 3
context_info_count = input("使用上下文信息数量默认2: ").strip() # context_info_count = input("使用上下文信息数量默认2: ").strip()
context_info_count = int(context_info_count) if context_info_count.isdigit() else 2 # context_info_count = int(context_info_count) if context_info_count.isdigit() else 2
print(f"\n开始对话 - 会话ID: {session_id}") # print(f"\n开始对话 - 会话ID: {session_id}")
print(f"上下文设置: 历史{history_count}轮, 信息{context_info_count}") # print(f"上下文设置: 历史{history_count}轮, 信息{context_info_count}个")
# 询问是否使用双模型对话 # # 询问是否使用双模型对话
use_dual_model = input("是否使用双模型对话系统?(y/n默认y): ").strip().lower() # use_dual_model = input("是否使用双模型对话系统?(y/n默认y): ").strip().lower()
if use_dual_model != 'n': # if use_dual_model != 'n':
print("使用双模型对话系统...") # print("使用双模型对话系统...")
dialogue_engine.run_dual_model_conversation(session_id, topic, turns, history_count, context_info_count) # dialogue_engine.run_dual_model_conversation(session_id, topic, turns, history_count, context_info_count)
else: # else:
print("使用传统对话系统...") # print("使用传统对话系统...")
dialogue_engine.run_conversation_turn(session_id, characters, turns, topic, history_count, context_info_count) # dialogue_engine.run_conversation_turn(session_id, characters, turns, topic, history_count, context_info_count)
elif choice == '2': # elif choice == '2':
# 继续已有对话 # # 继续已有对话
sessions = conv_mgr.list_sessions() # sessions = conv_mgr.list_sessions()
if not sessions: # if not sessions:
print("❌ 没有已有对话") # print("❌ 没有已有对话")
continue # continue
print("已有会话:") # print("已有会话:")
for i, session in enumerate(sessions[:5]): # for i, session in enumerate(sessions[:5]):
chars = ", ".join(session['characters']) # chars = ", ".join(session['characters'])
print(f"{i+1}. {session['session_id'][:8]}... ({chars}) - {session['last_update'][:16]}") # print(f"{i+1}. {session['session_id'][:8]}... ({chars}) - {session['last_update'][:16]}")
try: # try:
idx = int(input("请选择会话编号: ").strip()) - 1 # idx = int(input("请选择会话编号: ").strip()) - 1
if 0 <= idx < len(sessions): # if 0 <= idx < len(sessions):
session = sessions[idx] # session = sessions[idx]
session_id = session['session_id'] # session_id = session['session_id']
characters = session['characters'] # characters = session['characters']
# 显示最近的对话 # # 显示最近的对话
history = conv_mgr.get_conversation_history(session_id, 4) # history = conv_mgr.get_conversation_history(session_id, 4)
if history: # if history:
print("\n最近的对话:") # print("\n最近的对话:")
for turn in history: # for turn in history:
print(f"{turn.speaker}: {turn.content}") # print(f"{turn.speaker}: {turn.content}")
topic = input("请输入对话主题(可选): ").strip() # topic = input("请输入对话主题(可选): ").strip()
turns = int(input("请输入对话轮次数量默认1: ").strip() or "1") # turns = int(input("请输入对话轮次数量默认1: ").strip() or "1")
# 历史上下文控制选项 # # 历史上下文控制选项
print("\n历史上下文设置:") # print("\n历史上下文设置:")
history_count = input("使用历史对话轮数默认30表示不使用: ").strip() # history_count = input("使用历史对话轮数默认30表示不使用: ").strip()
history_count = int(history_count) if history_count.isdigit() else 3 # history_count = int(history_count) if history_count.isdigit() else 3
context_info_count = input("使用上下文信息数量默认2: ").strip() # context_info_count = input("使用上下文信息数量默认2: ").strip()
context_info_count = int(context_info_count) if context_info_count.isdigit() else 2 # context_info_count = int(context_info_count) if context_info_count.isdigit() else 2
print(f"\n继续对话 - 会话ID: {session_id}") # print(f"\n继续对话 - 会话ID: {session_id}")
print(f"上下文设置: 历史{history_count}轮, 信息{context_info_count}") # print(f"上下文设置: 历史{history_count}轮, 信息{context_info_count}个")
# 询问是否使用双模型对话 # # 询问是否使用双模型对话
use_dual_model = input("是否使用双模型对话系统?(y/n默认y): ").strip().lower() # use_dual_model = input("是否使用双模型对话系统?(y/n默认y): ").strip().lower()
if use_dual_model != 'n': # if use_dual_model != 'n':
print("使用双模型对话系统...") # print("使用双模型对话系统...")
dialogue_engine.run_dual_model_conversation(session_id, topic, turns, history_count, context_info_count) # dialogue_engine.run_dual_model_conversation(session_id, topic, turns, history_count, context_info_count)
else: # else:
print("使用传统对话系统...") # print("使用传统对话系统...")
dialogue_engine.run_conversation_turn(session_id, characters, turns, topic, history_count, context_info_count) # dialogue_engine.run_conversation_turn(session_id, characters, turns, topic, history_count, context_info_count)
else: # else:
print("❌ 无效的会话编号") # print("❌ 无效的会话编号")
except ValueError: # except ValueError:
print("❌ 请输入有效的数字") # print("❌ 请输入有效的数字")
elif choice == '3': # elif choice == '3':
# 查看对话历史 # # 查看对话历史
session_id = input("请输入会话ID前8位即可: ").strip() # session_id = input("请输入会话ID前8位即可: ").strip()
# 查找匹配的会话 # # 查找匹配的会话
sessions = conv_mgr.list_sessions() # sessions = conv_mgr.list_sessions()
matching_session = None # matching_session = None
for session in sessions: # for session in sessions:
if session['session_id'].startswith(session_id): # if session['session_id'].startswith(session_id):
matching_session = session # matching_session = session
break # break
if matching_session: # if matching_session:
full_session_id = matching_session['session_id'] # full_session_id = matching_session['session_id']
history = conv_mgr.get_conversation_history(full_session_id, 20) # history = conv_mgr.get_conversation_history(full_session_id, 20)
if history: # if history:
print(f"\n对话历史 - {full_session_id}") # print(f"\n对话历史 - {full_session_id}")
print(f"角色: {', '.join(matching_session['characters'])}") # print(f"角色: {', '.join(matching_session['characters'])}")
print(f"世界观: {matching_session['worldview']}") # print(f"世界观: {matching_session['worldview']}")
print("-" * 50) # print("-" * 50)
for turn in history: # for turn in history:
print(f"[{turn.timestamp[:16]}] {turn.speaker}:") # print(f"[{turn.timestamp[:16]}] {turn.speaker}:")
print(f" {turn.content}") # print(f" {turn.content}")
if turn.context_used: # if turn.context_used:
print(f" 使用上下文: {', '.join(turn.context_used)}") # print(f" 使用上下文: {', '.join(turn.context_used)}")
print() # print()
else: # else:
print("该会话暂无对话历史") # print("该会话暂无对话历史")
else: # else:
print("❌ 未找到匹配的会话") # print("❌ 未找到匹配的会话")
elif choice == '4': # elif choice == '4':
# 列出所有会话 # # 列出所有会话
sessions = conv_mgr.list_sessions() # sessions = conv_mgr.list_sessions()
if sessions: # if sessions:
print(f"\n共有 {len(sessions)} 个对话会话:") # print(f"\n共有 {len(sessions)} 个对话会话:")
for session in sessions: # for session in sessions:
chars = ", ".join(session['characters']) # chars = ", ".join(session['characters'])
print(f"ID: {session['session_id']}") # print(f"ID: {session['session_id']}")
print(f" 角色: {chars}") # print(f" 角色: {chars}")
print(f" 世界观: {session['worldview']}") # print(f" 世界观: {session['worldview']}")
print(f" 最后更新: {session['last_update']}") # print(f" 最后更新: {session['last_update']}")
print() # print()
else: # else:
print("暂无对话会话") # print("暂无对话会话")
else: # else:
print("❌ 无效选择") # print("❌ 无效选择")
except Exception as e: # except Exception as e:
print(f"✗ 系统运行出错: {e}") # print(f"✗ 系统运行出错: {e}")
import traceback # import traceback
traceback.print_exc() # traceback.print_exc()
if __name__ == '__main__': # if __name__ == '__main__':
main() # main()

View File

@ -224,8 +224,8 @@ def run_dialogue_system():
history_input = input("使用历史对话轮数 (默认3): ").strip() history_input = input("使用历史对话轮数 (默认3): ").strip()
history_count = int(history_input) if history_input.isdigit() else 3 history_count = int(history_input) if history_input.isdigit() else 3
context_input = input("使用上下文信息数量 (默认2): ").strip() context_input = input("使用上下文信息数量 (默认10): ").strip()
context_info_count = int(context_input) if context_input.isdigit() else 2 context_info_count = int(context_input) if context_input.isdigit() else 10
print(f"\n开始对话 - 主题: {user_input}") print(f"\n开始对话 - 主题: {user_input}")
print(f"轮数: {turns}, 历史: {history_count}, 上下文: {context_info_count}") print(f"轮数: {turns}, 历史: {history_count}, 上下文: {context_info_count}")

View File

@ -435,101 +435,101 @@ class DualModelDialogueGenerator:
"""列出两个角色名称""" """列出两个角色名称"""
return [self.character1_config['name'], self.character2_config['name']] return [self.character1_config['name'], self.character2_config['name']]
def main(): # def main():
"""测试对话生成器""" # """测试对话生成器"""
# 配置路径 # # 配置路径
base_model_path = '/mnt/g/Project02/AITrain/Qwen/Qwen3-8B-AWQ' # base_model_path = '/mnt/g/Project02/AITrain/Qwen/Qwen3-8B-AWQ'
lora_model_path = './output/NPC_Dialogue_LoRA/final_model' # 如果没有训练LoRA设为None # lora_model_path = './output/NPC_Dialogue_LoRA/final_model' # 如果没有训练LoRA设为None
# 检查LoRA模型是否存在 # # 检查LoRA模型是否存在
if not os.path.exists(lora_model_path): # if not os.path.exists(lora_model_path):
print("LoRA模型不存在使用基础模型") # print("LoRA模型不存在使用基础模型")
lora_model_path = None # lora_model_path = None
# 创建对话生成器 # # 创建对话生成器
generator = NPCDialogueGenerator(base_model_path, lora_model_path) # generator = NPCDialogueGenerator(base_model_path, lora_model_path)
print("=== 游戏NPC角色对话生成器 ===") # print("=== 游戏NPC角色对话生成器 ===")
print(f"可用角色:{', '.join(generator.list_available_characters())}") # print(f"可用角色:{', '.join(generator.list_available_characters())}")
# 测试单个角色对话生成 # # 测试单个角色对话生成
print("\n=== 单角色对话测试 ===") # print("\n=== 单角色对话测试 ===")
test_scenarios = [ # test_scenarios = [
{ # {
"character": "克莱恩", # "character": "克莱恩",
"context": "玩家向你咨询神秘学知识", # "context": "玩家向你咨询神秘学知识",
"input": "请告诉我一些关于灵界的注意事项。" # "input": "请告诉我一些关于灵界的注意事项。"
}, # },
{ # {
"character": "阿兹克", # "character": "阿兹克",
"context": "学生遇到了修炼瓶颈", # "context": "学生遇到了修炼瓶颈",
"input": "导师,我在修炼中遇到了困难。" # "input": "导师,我在修炼中遇到了困难。"
}, # },
{ # {
"character": "塔利姆", # "character": "塔利姆",
"context": "在俱乐部偶遇老朋友", # "context": "在俱乐部偶遇老朋友",
"input": "好久不见,最近怎么样?" # "input": "好久不见,最近怎么样?"
} # }
] # ]
for scenario in test_scenarios: # for scenario in test_scenarios:
print(f"\n--- {scenario['character']} ---") # print(f"\n--- {scenario['character']} ---")
print(f"情境:{scenario['context']}") # print(f"情境:{scenario['context']}")
print(f"输入:{scenario['input']}") # print(f"输入:{scenario['input']}")
dialogue = generator.generate_character_dialogue( # dialogue = generator.generate_character_dialogue(
scenario["character"], # scenario["character"],
scenario["context"], # scenario["context"],
scenario["input"] # scenario["input"]
) # )
print(f"回复:{dialogue}") # print(f"回复:{dialogue}")
# 测试角色间对话 # # 测试角色间对话
print("\n=== 角色间对话测试 ===") # print("\n=== 角色间对话测试 ===")
conversation = generator.generate_dialogue_conversation( # conversation = generator.generate_dialogue_conversation(
"克莱恩", "塔利姆", "最近遇到的神秘事件", turns=4 # "克莱恩", "塔利姆", "最近遇到的神秘事件", turns=4
) # )
for turn in conversation: # for turn in conversation:
print(f"{turn['speaker']}{turn['dialogue']}") # print(f"{turn['speaker']}{turn['dialogue']}")
# 交互式对话模式 # # 交互式对话模式
print("\n=== 交互式对话模式 ===") # print("\n=== 交互式对话模式 ===")
print("输入格式:角色名 上下文 用户输入") # print("输入格式:角色名 上下文 用户输入")
print("例如:克莱恩 在俱乐部 请给我一些建议") # print("例如:克莱恩 在俱乐部 请给我一些建议")
print("输入'quit'退出") # print("输入'quit'退出")
while True: # while True:
try: # try:
user_command = input("\n请输入指令: ").strip() # user_command = input("\n请输入指令: ").strip()
if user_command.lower() == 'quit': # if user_command.lower() == 'quit':
break # break
parts = user_command.split(' ', 2) # parts = user_command.split(' ', 2)
if len(parts) < 2: # if len(parts) < 2:
print("格式错误,请使用:角色名 上下文 [用户输入]") # print("格式错误,请使用:角色名 上下文 [用户输入]")
continue # continue
character = parts[0] # character = parts[0]
context = parts[1] # context = parts[1]
user_input = parts[2] if len(parts) > 2 else "" # user_input = parts[2] if len(parts) > 2 else ""
if character not in generator.list_available_characters(): # if character not in generator.list_available_characters():
print(f"未知角色:{character}") # print(f"未知角色:{character}")
print(f"可用角色:{', '.join(generator.list_available_characters())}") # print(f"可用角色:{', '.join(generator.list_available_characters())}")
continue # continue
dialogue = generator.generate_character_dialogue( # dialogue = generator.generate_character_dialogue(
character, context, user_input # character, context, user_input
) # )
print(f"\n{character}{dialogue}") # print(f"\n{character}{dialogue}")
except KeyboardInterrupt: # except KeyboardInterrupt:
break # break
except Exception as e: # except Exception as e:
print(f"生成对话时出错:{e}") # print(f"生成对话时出错:{e}")
print("\n对话生成器已退出") # print("\n对话生成器已退出")
if __name__ == '__main__': # if __name__ == '__main__':
main() # main()