添加双ai 模型互相对话
This commit is contained in:
parent
8e32ec6689
commit
82427b08ce
@ -468,8 +468,20 @@ class DualAIDialogueEngine:
|
|||||||
context_info_count
|
context_info_count
|
||||||
)
|
)
|
||||||
|
|
||||||
# 生成对话
|
# 生成对话 - 使用双模型系统
|
||||||
try:
|
try:
|
||||||
|
# 检查是否为双模型对话系统
|
||||||
|
if hasattr(self.llm_generator, 'generate_dual_character_dialogue'):
|
||||||
|
# 使用双模型系统
|
||||||
|
response = self.llm_generator.generate_dual_character_dialogue(
|
||||||
|
current_speaker,
|
||||||
|
prompt,
|
||||||
|
topic_hint or "请继续对话",
|
||||||
|
temperature=0.8,
|
||||||
|
max_new_tokens=150
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 兼容旧的单模型系统
|
||||||
response = self.llm_generator.generate_character_dialogue(
|
response = self.llm_generator.generate_character_dialogue(
|
||||||
current_speaker,
|
current_speaker,
|
||||||
prompt,
|
prompt,
|
||||||
@ -535,6 +547,72 @@ class DualAIDialogueEngine:
|
|||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
def run_dual_model_conversation(self, session_id: str, topic: str = "", turns: int = 4,
|
||||||
|
history_context_count: int = 3, context_info_count: int = 2):
|
||||||
|
"""使用双模型系统运行对话
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话ID
|
||||||
|
topic: 对话主题
|
||||||
|
turns: 对话轮数
|
||||||
|
history_context_count: 使用的历史对话轮数
|
||||||
|
context_info_count: 使用的上下文信息数量
|
||||||
|
"""
|
||||||
|
# 检查是否为双模型对话系统
|
||||||
|
if not hasattr(self.llm_generator, 'run_dual_character_conversation'):
|
||||||
|
print("⚠ 当前系统不支持双模型对话")
|
||||||
|
return self.run_conversation_turn(session_id, self.llm_generator.list_characters(), turns, topic,
|
||||||
|
history_context_count, context_info_count)
|
||||||
|
|
||||||
|
# 获取对话历史
|
||||||
|
dialogue_history = self.conv_mgr.get_conversation_history(session_id)
|
||||||
|
|
||||||
|
# 构建上下文信息
|
||||||
|
if dialogue_history:
|
||||||
|
recent_turns = dialogue_history[-history_context_count:] if history_context_count > 0 else []
|
||||||
|
recent_content = " ".join([turn.content for turn in recent_turns])
|
||||||
|
search_query = recent_content + " " + topic
|
||||||
|
else:
|
||||||
|
search_query = f"{topic} introduction greeting"
|
||||||
|
|
||||||
|
# 搜索相关上下文
|
||||||
|
context_info = self.kb.search_relevant_context(search_query, top_k=context_info_count)
|
||||||
|
|
||||||
|
# 构建上下文字符串
|
||||||
|
context_str = ""
|
||||||
|
if context_info:
|
||||||
|
context_str = "相关背景信息:"
|
||||||
|
for info in context_info[:context_info_count]:
|
||||||
|
content = info['content'][:150] + "..." if len(info['content']) > 150 else info['content']
|
||||||
|
context_str += f"\n- {content}"
|
||||||
|
|
||||||
|
print(f"\n=== 双模型对话系统 ===")
|
||||||
|
print(f"主题: {topic}")
|
||||||
|
print(f"角色: {', '.join(self.llm_generator.list_characters())}")
|
||||||
|
print(f"轮数: {turns}")
|
||||||
|
print(f"上下文设置: 历史{history_context_count}轮, 信息{context_info_count}个")
|
||||||
|
|
||||||
|
# 使用双模型系统生成对话
|
||||||
|
conversation_results = self.llm_generator.run_dual_character_conversation(
|
||||||
|
topic=topic,
|
||||||
|
turns=turns,
|
||||||
|
context=context_str,
|
||||||
|
temperature=0.8,
|
||||||
|
max_new_tokens=150
|
||||||
|
)
|
||||||
|
|
||||||
|
# 保存对话到数据库
|
||||||
|
for result in conversation_results:
|
||||||
|
self.conv_mgr.add_dialogue_turn(
|
||||||
|
session_id,
|
||||||
|
result['speaker'],
|
||||||
|
result['dialogue'],
|
||||||
|
[result.get('context_used', '')],
|
||||||
|
0.8 # 默认相关性分数
|
||||||
|
)
|
||||||
|
|
||||||
|
return conversation_results
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
"""主函数 - 演示系统使用"""
|
"""主函数 - 演示系统使用"""
|
||||||
print("=== RAG增强双AI角色对话系统 ===")
|
print("=== RAG增强双AI角色对话系统 ===")
|
||||||
@ -559,21 +637,43 @@ def main():
|
|||||||
kb = RAGKnowledgeBase(knowledge_dir)
|
kb = RAGKnowledgeBase(knowledge_dir)
|
||||||
conv_mgr = ConversationManager()
|
conv_mgr = ConversationManager()
|
||||||
|
|
||||||
# 这里需要你的LLM生成器,使用现有的NPCDialogueGenerator
|
# 这里需要你的LLM生成器,使用新的双模型对话系统
|
||||||
from npc_dialogue_generator import NPCDialogueGenerator
|
from npc_dialogue_generator import DualModelDialogueGenerator
|
||||||
base_model_path = '/mnt/g/Project02/AITrain/Qwen/Qwen3-4B' # 根据你的路径调整
|
base_model_path = '/mnt/g/Project02/AITrain/Qwen/Qwen3-4B' # 根据你的路径调整
|
||||||
lora_model_path = './output/NPC_Dialogue_LoRA/final_model'
|
lora_model_path = './output/NPC_Dialogue_LoRA/final_model'
|
||||||
|
|
||||||
if not os.path.exists(lora_model_path):
|
if not os.path.exists(lora_model_path):
|
||||||
lora_model_path = None
|
lora_model_path = None
|
||||||
|
|
||||||
# 创建对话生成器并传入角色数据
|
# 创建双模型对话生成器
|
||||||
if hasattr(kb, 'character_data') and kb.character_data:
|
if hasattr(kb, 'character_data') and len(kb.character_data) >= 2:
|
||||||
print("✓ 使用knowledge_base角色数据创建对话生成器")
|
print("✓ 使用knowledge_base角色数据创建双模型对话系统")
|
||||||
llm_generator = NPCDialogueGenerator(base_model_path, lora_model_path, kb.character_data)
|
# 获取前两个角色
|
||||||
|
character_names = list(kb.character_data.keys())[:2]
|
||||||
|
char1_name = character_names[0]
|
||||||
|
char2_name = character_names[1]
|
||||||
|
|
||||||
|
# 配置两个角色的模型
|
||||||
|
character1_config = {
|
||||||
|
"name": char1_name,
|
||||||
|
"lora_path": lora_model_path, # 可以为每个角色设置不同的LoRA
|
||||||
|
"character_data": kb.character_data[char1_name]
|
||||||
|
}
|
||||||
|
|
||||||
|
character2_config = {
|
||||||
|
"name": char2_name,
|
||||||
|
"lora_path": lora_model_path, # 可以为每个角色设置不同的LoRA
|
||||||
|
"character_data": kb.character_data[char2_name]
|
||||||
|
}
|
||||||
|
|
||||||
|
llm_generator = DualModelDialogueGenerator(
|
||||||
|
base_model_path,
|
||||||
|
character1_config,
|
||||||
|
character2_config
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
print("⚠ 使用内置角色数据创建对话生成器")
|
print("⚠ 角色数据不足,无法创建双模型对话系统")
|
||||||
llm_generator = NPCDialogueGenerator(base_model_path, lora_model_path)
|
return
|
||||||
|
|
||||||
# 创建对话引擎
|
# 创建对话引擎
|
||||||
dialogue_engine = DualAIDialogueEngine(kb, conv_mgr, llm_generator)
|
dialogue_engine = DualAIDialogueEngine(kb, conv_mgr, llm_generator)
|
||||||
@ -621,6 +721,14 @@ def main():
|
|||||||
|
|
||||||
print(f"\n开始对话 - 会话ID: {session_id}")
|
print(f"\n开始对话 - 会话ID: {session_id}")
|
||||||
print(f"上下文设置: 历史{history_count}轮, 信息{context_info_count}个")
|
print(f"上下文设置: 历史{history_count}轮, 信息{context_info_count}个")
|
||||||
|
|
||||||
|
# 询问是否使用双模型对话
|
||||||
|
use_dual_model = input("是否使用双模型对话系统?(y/n,默认y): ").strip().lower()
|
||||||
|
if use_dual_model != 'n':
|
||||||
|
print("使用双模型对话系统...")
|
||||||
|
dialogue_engine.run_dual_model_conversation(session_id, topic, turns, history_count, context_info_count)
|
||||||
|
else:
|
||||||
|
print("使用传统对话系统...")
|
||||||
dialogue_engine.run_conversation_turn(session_id, characters, turns, topic, history_count, context_info_count)
|
dialogue_engine.run_conversation_turn(session_id, characters, turns, topic, history_count, context_info_count)
|
||||||
|
|
||||||
elif choice == '2':
|
elif choice == '2':
|
||||||
@ -662,6 +770,14 @@ def main():
|
|||||||
|
|
||||||
print(f"\n继续对话 - 会话ID: {session_id}")
|
print(f"\n继续对话 - 会话ID: {session_id}")
|
||||||
print(f"上下文设置: 历史{history_count}轮, 信息{context_info_count}个")
|
print(f"上下文设置: 历史{history_count}轮, 信息{context_info_count}个")
|
||||||
|
|
||||||
|
# 询问是否使用双模型对话
|
||||||
|
use_dual_model = input("是否使用双模型对话系统?(y/n,默认y): ").strip().lower()
|
||||||
|
if use_dual_model != 'n':
|
||||||
|
print("使用双模型对话系统...")
|
||||||
|
dialogue_engine.run_dual_model_conversation(session_id, topic, turns, history_count, context_info_count)
|
||||||
|
else:
|
||||||
|
print("使用传统对话系统...")
|
||||||
dialogue_engine.run_conversation_turn(session_id, characters, turns, topic, history_count, context_info_count)
|
dialogue_engine.run_conversation_turn(session_id, characters, turns, topic, history_count, context_info_count)
|
||||||
else:
|
else:
|
||||||
print("❌ 无效的会话编号")
|
print("❌ 无效的会话编号")
|
||||||
|
|||||||
@ -122,14 +122,133 @@ def show_character_info():
|
|||||||
print(f"✗ 读取角色文件失败: {char_file} - {e}")
|
print(f"✗ 读取角色文件失败: {char_file} - {e}")
|
||||||
|
|
||||||
def run_dialogue_system():
|
def run_dialogue_system():
|
||||||
"""运行对话系统"""
|
"""运行双AI对话系统"""
|
||||||
print("\n" + "="*60)
|
print("\n" + "="*60)
|
||||||
print("启动双AI角色对话系统")
|
print("启动双AI角色对话系统")
|
||||||
print("="*60)
|
print("="*60)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from dual_ai_dialogue_system import main as dialogue_main
|
|
||||||
dialogue_main()
|
# 直接启动双模型对话
|
||||||
|
print("\n正在初始化双模型对话系统...")
|
||||||
|
|
||||||
|
from dual_ai_dialogue_system import RAGKnowledgeBase, ConversationManager, DualAIDialogueEngine
|
||||||
|
from npc_dialogue_generator import DualModelDialogueGenerator
|
||||||
|
|
||||||
|
# 初始化组件
|
||||||
|
kb = RAGKnowledgeBase("./knowledge_base")
|
||||||
|
conv_mgr = ConversationManager("./conversation_data/conversations.db")
|
||||||
|
|
||||||
|
# 检查模型路径
|
||||||
|
base_model_path = '/mnt/g/Project02/AITrain/Qwen/Qwen3-4B'
|
||||||
|
lora_model_path = './output/NPC_Dialogue_LoRA/final_model'
|
||||||
|
|
||||||
|
if not os.path.exists(base_model_path):
|
||||||
|
print(f"✗ 基础模型路径不存在: {base_model_path}")
|
||||||
|
print("请修改 main_controller.py 中的模型路径")
|
||||||
|
return
|
||||||
|
|
||||||
|
if not os.path.exists(lora_model_path):
|
||||||
|
lora_model_path = None
|
||||||
|
print("⚠ LoRA模型不存在,使用基础模型")
|
||||||
|
|
||||||
|
# 检查角色数据
|
||||||
|
if not hasattr(kb, 'character_data') or len(kb.character_data) < 2:
|
||||||
|
print("✗ 角色数据不足,无法创建双模型对话系统")
|
||||||
|
print("请确保knowledge_base目录中有至少两个角色文件")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 获取前两个角色
|
||||||
|
character_names = list(kb.character_data.keys())[:2]
|
||||||
|
char1_name = character_names[0]
|
||||||
|
char2_name = character_names[1]
|
||||||
|
|
||||||
|
print(f"✓ 使用角色: {char1_name} 和 {char2_name}")
|
||||||
|
|
||||||
|
# 配置两个角色的模型
|
||||||
|
character1_config = {
|
||||||
|
"name": char1_name,
|
||||||
|
"lora_path": lora_model_path,
|
||||||
|
"character_data": kb.character_data[char1_name]
|
||||||
|
}
|
||||||
|
|
||||||
|
character2_config = {
|
||||||
|
"name": char2_name,
|
||||||
|
"lora_path": lora_model_path,
|
||||||
|
"character_data": kb.character_data[char2_name]
|
||||||
|
}
|
||||||
|
|
||||||
|
# 创建双模型对话生成器
|
||||||
|
print("正在初始化双模型对话生成器...")
|
||||||
|
dual_generator = DualModelDialogueGenerator(
|
||||||
|
base_model_path,
|
||||||
|
character1_config,
|
||||||
|
character2_config
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建对话引擎
|
||||||
|
dialogue_engine = DualAIDialogueEngine(kb, conv_mgr, dual_generator)
|
||||||
|
|
||||||
|
# 创建对话会话
|
||||||
|
characters = [char1_name, char2_name]
|
||||||
|
worldview = kb.worldview_data.get('worldview_name', '未知世界观') if kb.worldview_data else '未知世界观'
|
||||||
|
|
||||||
|
session_id = conv_mgr.create_session(characters, worldview)
|
||||||
|
print(f"✓ 创建对话会话: {session_id}")
|
||||||
|
|
||||||
|
# 交互式对话循环
|
||||||
|
print(f"\n=== 双AI模型对话系统 ===")
|
||||||
|
print(f"角色: {char1_name} vs {char2_name}")
|
||||||
|
print(f"世界观: {worldview}")
|
||||||
|
print("输入 'quit' 退出对话")
|
||||||
|
print("-" * 50)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
# 获取用户输入
|
||||||
|
user_input = input("\n请输入对话主题或指令: ").strip()
|
||||||
|
|
||||||
|
if user_input.lower() == 'quit':
|
||||||
|
print("退出双AI对话系统")
|
||||||
|
break
|
||||||
|
|
||||||
|
if not user_input:
|
||||||
|
print("请输入有效的对话主题")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 询问对话轮数
|
||||||
|
turns_input = input("请输入对话轮数 (默认4): ").strip()
|
||||||
|
turns = int(turns_input) if turns_input.isdigit() else 4
|
||||||
|
|
||||||
|
# 询问历史上下文设置
|
||||||
|
history_input = input("使用历史对话轮数 (默认3): ").strip()
|
||||||
|
history_count = int(history_input) if history_input.isdigit() else 3
|
||||||
|
|
||||||
|
context_input = input("使用上下文信息数量 (默认2): ").strip()
|
||||||
|
context_info_count = int(context_input) if context_input.isdigit() else 2
|
||||||
|
|
||||||
|
print(f"\n开始对话 - 主题: {user_input}")
|
||||||
|
print(f"轮数: {turns}, 历史: {history_count}, 上下文: {context_info_count}")
|
||||||
|
print("-" * 50)
|
||||||
|
|
||||||
|
# 运行双模型对话
|
||||||
|
dialogue_engine.run_dual_model_conversation(
|
||||||
|
session_id, user_input, turns, history_count, context_info_count
|
||||||
|
)
|
||||||
|
|
||||||
|
print("-" * 50)
|
||||||
|
print("对话完成!")
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n\n用户中断对话")
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
print(f"对话过程中出现错误: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"✗ 对话系统启动失败: {e}")
|
print(f"✗ 对话系统启动失败: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
@ -260,7 +379,7 @@ def main():
|
|||||||
print("主菜单 - 请选择操作:")
|
print("主菜单 - 请选择操作:")
|
||||||
print("1. 处理PDF世界观文档 (转换为RAG格式)")
|
print("1. 处理PDF世界观文档 (转换为RAG格式)")
|
||||||
print("2. 查看角色设定信息")
|
print("2. 查看角色设定信息")
|
||||||
print("3. 启动双AI对话系统")
|
print("3. 启动双AI对话系统 (支持双模型对话)")
|
||||||
print("4. 创建演示对话场景")
|
print("4. 创建演示对话场景")
|
||||||
print("5. 系统状态检查")
|
print("5. 系统状态检查")
|
||||||
print("6. 查看使用说明")
|
print("6. 查看使用说明")
|
||||||
|
|||||||
@ -3,6 +3,7 @@
|
|||||||
'''
|
'''
|
||||||
游戏NPC角色对话生成器
|
游戏NPC角色对话生成器
|
||||||
基于微调后的LoRA模型生成角色对话
|
基于微调后的LoRA模型生成角色对话
|
||||||
|
支持双模型对话系统,每个模型扮演一个角色
|
||||||
'''
|
'''
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -10,8 +11,9 @@ import json
|
|||||||
import random
|
import random
|
||||||
from peft import PeftModel
|
from peft import PeftModel
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional, Tuple
|
||||||
import platform
|
import platform
|
||||||
|
import os
|
||||||
|
|
||||||
# Windows multiprocessing兼容性修复
|
# Windows multiprocessing兼容性修复
|
||||||
if platform.system() == "Windows":
|
if platform.system() == "Windows":
|
||||||
@ -37,9 +39,6 @@ class NPCDialogueGenerator:
|
|||||||
if external_character_data:
|
if external_character_data:
|
||||||
self.character_profiles = self._process_external_character_data(external_character_data)
|
self.character_profiles = self._process_external_character_data(external_character_data)
|
||||||
print(f"✓ 使用外部角色数据: {list(self.character_profiles.keys())}")
|
print(f"✓ 使用外部角色数据: {list(self.character_profiles.keys())}")
|
||||||
# else:
|
|
||||||
# self.character_profiles = self._load_character_profiles()
|
|
||||||
# print(f"✓ 使用内置角色数据: {list(self.character_profiles.keys())}")
|
|
||||||
|
|
||||||
self._load_model()
|
self._load_model()
|
||||||
|
|
||||||
@ -113,71 +112,6 @@ class NPCDialogueGenerator:
|
|||||||
"这太奇怪了。"
|
"这太奇怪了。"
|
||||||
]
|
]
|
||||||
|
|
||||||
# def _load_character_profiles(self) -> Dict:
|
|
||||||
# """加载角色画像数据"""
|
|
||||||
# return {
|
|
||||||
# "维多利亚·布莱克伍德": {
|
|
||||||
# "name": "维多利亚·布莱克伍德",
|
|
||||||
# "title": "神秘学专家",
|
|
||||||
# "personality": ["理性分析", "谨慎小心", "实用主义", "思维缜密"],
|
|
||||||
# "background": "拥有丰富神秘学知识和战斗经验的侦探,既是非凡者也是夏洛克·莫里亚蒂",
|
|
||||||
# "speech_patterns": ["会使用专业术语", "经常进行逻辑分析", "对危险保持警告", "内心独白较多"],
|
|
||||||
# "sample_dialogues": [
|
|
||||||
# "好奇往往是导致死亡的主要因素。",
|
|
||||||
# "总之,我的任务到此为止。",
|
|
||||||
# "这需要仔细分析才能得出结论。"
|
|
||||||
# ]
|
|
||||||
# },
|
|
||||||
# "阿奇博尔德·韦恩博士": {
|
|
||||||
# "name": "阿奇博尔德·韦恩博士",
|
|
||||||
# "title": "神秘学导师",
|
|
||||||
# "personality": ["沉稳睿智", "言简意赅", "关怀学生", "经验丰富"],
|
|
||||||
# "background": "神秘学领域的资深专家,经验极其丰富的导师,知识渊博",
|
|
||||||
# "speech_patterns": ["话语简练但信息量大", "给予实用指导", "语调平和但权威", "关心但保持距离"],
|
|
||||||
# "sample_dialogues": [
|
|
||||||
# "耐心是修炼的基础。",
|
|
||||||
# "不要急于求成,稳扎稳打比什么都重要。",
|
|
||||||
# "这种情况需要格外小心。"
|
|
||||||
# ]
|
|
||||||
# },
|
|
||||||
# "塔利姆": {
|
|
||||||
# "name": "塔利姆",
|
|
||||||
# "title": "文雅绅士",
|
|
||||||
# "personality": ["礼貌尊敬", "有文化素养", "寻求帮助", "温和友善"],
|
|
||||||
# "background": "受过良好教育的普通人,有一定的文学修养,遇到困难时会寻求专家帮助",
|
|
||||||
# "speech_patterns": ["使用礼貌称谓", "表达困惑时措辞文雅", "会引用文学作品", "语气温和"],
|
|
||||||
# "sample_dialogues": [
|
|
||||||
# "噢,尊敬的大侦探,你最近在忙碌什么?",
|
|
||||||
# "这不是《罗密欧与朱丽叶》的故事!",
|
|
||||||
# "我有个朋友遇到了困难..."
|
|
||||||
# ]
|
|
||||||
# },
|
|
||||||
# "艾伦": {
|
|
||||||
# "name": "艾伦",
|
|
||||||
# "title": "困扰的求助者",
|
|
||||||
# "personality": ["焦虑不安", "详细描述", "半信半疑", "急需帮助"],
|
|
||||||
# "background": "普通人,但最近遭遇了一系列神秘的厄运事件,怀疑受到诅咒",
|
|
||||||
# "speech_patterns": ["情绪紧张", "会详细描述遭遇", "语气急切", "表现出恐惧"],
|
|
||||||
# "sample_dialogues": [
|
|
||||||
# "最近我总是遭遇各种厄运...",
|
|
||||||
# "我怀疑是不是受到了什么诅咒。",
|
|
||||||
# "请帮帮我,我不知道该怎么办!"
|
|
||||||
# ]
|
|
||||||
# },
|
|
||||||
# "戴莉.西蒙妮": {
|
|
||||||
# "name": "戴莉·西蒙妮",
|
|
||||||
# "title": "专业调查员",
|
|
||||||
# "personality": ["专业简洁", "直接明确", "严谨认真", "目标导向"],
|
|
||||||
# "background": "负责调查神秘事件的专业人员,办事效率高,问题直接",
|
|
||||||
# "speech_patterns": ["问题直接明确", "语气专业", "注重事实", "简洁有力"],
|
|
||||||
# "sample_dialogues": [
|
|
||||||
# "请详细描述事件经过。",
|
|
||||||
# "有什么证据可以证明?",
|
|
||||||
# "这件事需要立即调查。"
|
|
||||||
# ]
|
|
||||||
# }
|
|
||||||
# }
|
|
||||||
|
|
||||||
def _load_model(self):
|
def _load_model(self):
|
||||||
"""加载模型和分词器"""
|
"""加载模型和分词器"""
|
||||||
print(f"Loading tokenizer from: {self.base_model_path}")
|
print(f"Loading tokenizer from: {self.base_model_path}")
|
||||||
@ -339,14 +273,175 @@ class NPCDialogueGenerator:
|
|||||||
"""列出所有可用角色"""
|
"""列出所有可用角色"""
|
||||||
return list(self.character_profiles.keys())
|
return list(self.character_profiles.keys())
|
||||||
|
|
||||||
def main():
|
class DualModelDialogueGenerator:
|
||||||
|
"""双模型对话生成器 - 每个模型扮演一个角色"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
base_model_path: str,
|
||||||
|
character1_config: Dict,
|
||||||
|
character2_config: Dict,
|
||||||
|
lora_model_path: Optional[str] = None):
|
||||||
|
"""
|
||||||
|
初始化双模型对话生成器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base_model_path: 基础模型路径
|
||||||
|
character1_config: 角色1配置 {"name": "角色名", "lora_path": "LoRA路径", "character_data": 角色数据}
|
||||||
|
character2_config: 角色2配置 {"name": "角色名", "lora_path": "LoRA路径", "character_data": 角色数据}
|
||||||
|
lora_model_path: 通用LoRA模型路径(可选)
|
||||||
|
"""
|
||||||
|
self.base_model_path = base_model_path
|
||||||
|
self.character1_config = character1_config
|
||||||
|
self.character2_config = character2_config
|
||||||
|
|
||||||
|
# 为每个角色创建独立的模型实例
|
||||||
|
self.character1_generator = None
|
||||||
|
self.character2_generator = None
|
||||||
|
|
||||||
|
self._initialize_character_models()
|
||||||
|
|
||||||
|
def _initialize_character_models(self):
|
||||||
|
"""初始化两个角色的模型"""
|
||||||
|
print("=== 初始化双模型对话系统 ===")
|
||||||
|
|
||||||
|
# 初始化角色1的模型
|
||||||
|
print(f"\n初始化角色1: {self.character1_config['name']}")
|
||||||
|
char1_lora_path = self.character1_config.get('lora_path') or self.character1_config.get('lora_model_path')
|
||||||
|
self.character1_generator = NPCDialogueGenerator(
|
||||||
|
self.base_model_path,
|
||||||
|
char1_lora_path,
|
||||||
|
{self.character1_config['name']: self.character1_config['character_data']}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 初始化角色2的模型
|
||||||
|
print(f"\n初始化角色2: {self.character2_config['name']}")
|
||||||
|
char2_lora_path = self.character2_config.get('lora_path') or self.character2_config.get('lora_model_path')
|
||||||
|
self.character2_generator = NPCDialogueGenerator(
|
||||||
|
self.base_model_path,
|
||||||
|
char2_lora_path,
|
||||||
|
{self.character2_config['name']: self.character2_config['character_data']}
|
||||||
|
)
|
||||||
|
|
||||||
|
print("✓ 双模型对话系统初始化完成")
|
||||||
|
|
||||||
|
def generate_dual_character_dialogue(self,
|
||||||
|
character_name: str,
|
||||||
|
context: str = "",
|
||||||
|
user_input: str = "",
|
||||||
|
temperature: float = 0.8,
|
||||||
|
max_new_tokens: int = 150) -> str:
|
||||||
|
"""
|
||||||
|
生成指定角色的对话(使用对应的模型)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
character_name: 角色名称
|
||||||
|
context: 对话上下文
|
||||||
|
user_input: 用户输入
|
||||||
|
temperature: 采样温度
|
||||||
|
max_new_tokens: 最大生成token数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
生成的对话内容
|
||||||
|
"""
|
||||||
|
if character_name == self.character1_config['name']:
|
||||||
|
return self.character1_generator.generate_character_dialogue(
|
||||||
|
character_name, context, user_input, temperature, max_new_tokens
|
||||||
|
)
|
||||||
|
elif character_name == self.character2_config['name']:
|
||||||
|
return self.character2_generator.generate_character_dialogue(
|
||||||
|
character_name, context, user_input, temperature, max_new_tokens
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown character: {character_name}")
|
||||||
|
|
||||||
|
def run_dual_character_conversation(self,
|
||||||
|
topic: str = "",
|
||||||
|
turns: int = 4,
|
||||||
|
context: str = "",
|
||||||
|
temperature: float = 0.8,
|
||||||
|
max_new_tokens: int = 150) -> List[Dict]:
|
||||||
|
"""
|
||||||
|
运行双角色对话
|
||||||
|
|
||||||
|
Args:
|
||||||
|
topic: 对话主题
|
||||||
|
turns: 对话轮数
|
||||||
|
context: 额外上下文
|
||||||
|
temperature: 采样温度
|
||||||
|
max_new_tokens: 最大生成token数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
对话列表
|
||||||
|
"""
|
||||||
|
conversation = []
|
||||||
|
char1_name = self.character1_config['name']
|
||||||
|
char2_name = self.character2_config['name']
|
||||||
|
|
||||||
|
# 构建完整上下文
|
||||||
|
full_context = f"现在{char1_name}和{char2_name}在讨论关于{topic}的话题。{context}"
|
||||||
|
|
||||||
|
print(f"\n=== 开始双角色对话 ===")
|
||||||
|
print(f"主题: {topic}")
|
||||||
|
print(f"角色: {char1_name} vs {char2_name}")
|
||||||
|
print(f"轮数: {turns}")
|
||||||
|
print("-" * 50)
|
||||||
|
|
||||||
|
for turn in range(turns):
|
||||||
|
if turn % 2 == 0:
|
||||||
|
# 角色1说话
|
||||||
|
speaker = char1_name
|
||||||
|
if turn == 0:
|
||||||
|
user_input = f"开始和{char2_name}讨论{topic}这个话题。"
|
||||||
|
else:
|
||||||
|
last_dialogue = conversation[-1]["dialogue"]
|
||||||
|
user_input = f"{char2_name}刚才说:\"{last_dialogue}\"。请回应。"
|
||||||
|
else:
|
||||||
|
# 角色2说话
|
||||||
|
speaker = char2_name
|
||||||
|
last_dialogue = conversation[-1]["dialogue"]
|
||||||
|
user_input = f"{char1_name}刚才说:\"{last_dialogue}\"。请回应。"
|
||||||
|
|
||||||
|
print(f"\n[第{turn+1}轮] {speaker}正在思考...")
|
||||||
|
|
||||||
|
# 使用对应角色的模型生成对话
|
||||||
|
dialogue = self.generate_dual_character_dialogue(
|
||||||
|
speaker, full_context, user_input, temperature, max_new_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
conversation.append({
|
||||||
|
"turn": turn + 1,
|
||||||
|
"speaker": speaker,
|
||||||
|
"dialogue": dialogue,
|
||||||
|
"context_used": full_context[:100] + "..." if len(full_context) > 100 else full_context
|
||||||
|
})
|
||||||
|
|
||||||
|
print(f"{speaker}: {dialogue}")
|
||||||
|
|
||||||
|
print("-" * 50)
|
||||||
|
print("✓ 双角色对话完成")
|
||||||
|
|
||||||
|
return conversation
|
||||||
|
|
||||||
|
def get_character_info(self, character_name: str) -> Dict:
|
||||||
|
"""获取角色信息"""
|
||||||
|
if character_name == self.character1_config['name']:
|
||||||
|
return self.character1_generator.get_character_info(character_name)
|
||||||
|
elif character_name == self.character2_config['name']:
|
||||||
|
return self.character2_generator.get_character_info(character_name)
|
||||||
|
else:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def list_characters(self) -> List[str]:
|
||||||
|
"""列出两个角色名称"""
|
||||||
|
return [self.character1_config['name'], self.character2_config['name']]
|
||||||
|
|
||||||
|
def main():
|
||||||
"""测试对话生成器"""
|
"""测试对话生成器"""
|
||||||
# 配置路径
|
# 配置路径
|
||||||
base_model_path = '/mnt/g/Project02/AITrain/Qwen/Qwen3-8B-AWQ'
|
base_model_path = '/mnt/g/Project02/AITrain/Qwen/Qwen3-8B-AWQ'
|
||||||
lora_model_path = './output/NPC_Dialogue_LoRA/final_model' # 如果没有训练LoRA,设为None
|
lora_model_path = './output/NPC_Dialogue_LoRA/final_model' # 如果没有训练LoRA,设为None
|
||||||
|
|
||||||
# 检查LoRA模型是否存在
|
# 检查LoRA模型是否存在
|
||||||
import os
|
|
||||||
if not os.path.exists(lora_model_path):
|
if not os.path.exists(lora_model_path):
|
||||||
print("LoRA模型不存在,使用基础模型")
|
print("LoRA模型不存在,使用基础模型")
|
||||||
lora_model_path = None
|
lora_model_path = None
|
||||||
@ -436,5 +531,5 @@ class NPCDialogueGenerator:
|
|||||||
|
|
||||||
print("\n对话生成器已退出")
|
print("\n对话生成器已退出")
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
Loading…
x
Reference in New Issue
Block a user