完善读取本地世界观对话
This commit is contained in:
parent
18b982fadb
commit
8824d2d25d
@ -64,14 +64,39 @@ class RAGKnowledgeBase:
|
|||||||
|
|
||||||
def _load_knowledge_base(self):
|
def _load_knowledge_base(self):
|
||||||
"""加载知识库"""
|
"""加载知识库"""
|
||||||
# 加载世界观
|
# 优先加载RAG知识库作为世界观
|
||||||
worldview_files = [f for f in os.listdir(self.knowledge_dir)
|
rag_worldview_path = "./rag_knowledge/knowledge_base.json"
|
||||||
if f.startswith('worldview') and f.endswith('.json')]
|
if os.path.exists(rag_worldview_path):
|
||||||
if worldview_files:
|
try:
|
||||||
worldview_path = os.path.join(self.knowledge_dir, worldview_files[0])
|
with open(rag_worldview_path, 'r', encoding='utf-8') as f:
|
||||||
with open(worldview_path, 'r', encoding='utf-8') as f:
|
rag_data = json.load(f)
|
||||||
self.worldview_data = json.load(f)
|
# 从RAG数据中提取世界观信息
|
||||||
print(f"✓ 世界观加载成功: {self.worldview_data.get('worldview_name', '未知')}")
|
self.worldview_data = {
|
||||||
|
"worldview_name": "克苏鲁神话世界观 (RAG)",
|
||||||
|
"source": rag_data.get("metadata", {}).get("source_file", "未知"),
|
||||||
|
"description": f"基于{rag_data.get('metadata', {}).get('source_file', 'PDF文档')}的RAG知识库",
|
||||||
|
"total_chunks": rag_data.get("metadata", {}).get("total_chunks", 0),
|
||||||
|
"total_concepts": rag_data.get("metadata", {}).get("total_concepts", 0),
|
||||||
|
"rag_enabled": True
|
||||||
|
}
|
||||||
|
# 保存RAG数据用于检索
|
||||||
|
self.rag_chunks = rag_data.get("chunks", [])
|
||||||
|
print(f"✓ RAG世界观加载成功: {self.worldview_data['worldview_name']}")
|
||||||
|
print(f" - 文档块数: {self.worldview_data['total_chunks']}")
|
||||||
|
print(f" - 概念数: {self.worldview_data['total_concepts']}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"✗ RAG世界观加载失败: {e}")
|
||||||
|
self.rag_chunks = []
|
||||||
|
|
||||||
|
# 如果没有RAG知识库,则加载传统世界观文件
|
||||||
|
if not hasattr(self, 'rag_chunks') or not self.rag_chunks:
|
||||||
|
worldview_files = [f for f in os.listdir(self.knowledge_dir)
|
||||||
|
if f.startswith('worldview') and f.endswith('.json')]
|
||||||
|
if worldview_files:
|
||||||
|
worldview_path = os.path.join(self.knowledge_dir, worldview_files[0])
|
||||||
|
with open(worldview_path, 'r', encoding='utf-8') as f:
|
||||||
|
self.worldview_data = json.load(f)
|
||||||
|
print(f"✓ 传统世界观加载成功: {self.worldview_data.get('worldview_name', '未知')}")
|
||||||
|
|
||||||
# 加载角色数据
|
# 加载角色数据
|
||||||
character_files = [f for f in os.listdir(self.knowledge_dir)
|
character_files = [f for f in os.listdir(self.knowledge_dir)
|
||||||
@ -96,21 +121,38 @@ class RAGKnowledgeBase:
|
|||||||
"""构建可检索的文本块"""
|
"""构建可检索的文本块"""
|
||||||
self.chunks = []
|
self.chunks = []
|
||||||
|
|
||||||
# 世界观相关文本块
|
# 优先使用RAG知识库的文本块
|
||||||
if self.worldview_data:
|
if hasattr(self, 'rag_chunks') and self.rag_chunks:
|
||||||
for section_key, section_data in self.worldview_data.items():
|
for rag_chunk in self.rag_chunks:
|
||||||
if isinstance(section_data, dict):
|
self.chunks.append({
|
||||||
for sub_key, sub_data in section_data.items():
|
"type": "worldview_rag",
|
||||||
if isinstance(sub_data, (str, list)):
|
"section": "rag_knowledge",
|
||||||
content = str(sub_data)
|
"subsection": rag_chunk.get("type", "unknown"),
|
||||||
if len(content) > 50: # 只保留有意义的文本
|
"content": rag_chunk.get("content", ""),
|
||||||
self.chunks.append({
|
"metadata": {
|
||||||
"type": "worldview",
|
"source": "rag_worldview",
|
||||||
"section": section_key,
|
"chunk_id": rag_chunk.get("id", ""),
|
||||||
"subsection": sub_key,
|
"size": rag_chunk.get("size", 0),
|
||||||
"content": content,
|
"hash": rag_chunk.get("hash", "")
|
||||||
"metadata": {"source": "worldview"}
|
}
|
||||||
})
|
})
|
||||||
|
print(f"✓ 使用RAG知识库文本块: {len(self.rag_chunks)} 个")
|
||||||
|
else:
|
||||||
|
# 传统世界观相关文本块
|
||||||
|
if self.worldview_data:
|
||||||
|
for section_key, section_data in self.worldview_data.items():
|
||||||
|
if isinstance(section_data, dict):
|
||||||
|
for sub_key, sub_data in section_data.items():
|
||||||
|
if isinstance(sub_data, (str, list)):
|
||||||
|
content = str(sub_data)
|
||||||
|
if len(content) > 50: # 只保留有意义的文本
|
||||||
|
self.chunks.append({
|
||||||
|
"type": "worldview",
|
||||||
|
"section": section_key,
|
||||||
|
"subsection": sub_key,
|
||||||
|
"content": content,
|
||||||
|
"metadata": {"source": "worldview"}
|
||||||
|
})
|
||||||
|
|
||||||
# 角色相关文本块
|
# 角色相关文本块
|
||||||
for char_name, char_data in self.character_data.items():
|
for char_name, char_data in self.character_data.items():
|
||||||
@ -134,6 +176,18 @@ class RAGKnowledgeBase:
|
|||||||
def _build_vector_index(self):
|
def _build_vector_index(self):
|
||||||
"""构建向量索引"""
|
"""构建向量索引"""
|
||||||
try:
|
try:
|
||||||
|
# 优先使用RAG知识库的预构建向量索引
|
||||||
|
rag_vector_path = "./rag_knowledge/vector_index.faiss"
|
||||||
|
rag_embeddings_path = "./rag_knowledge/embeddings.npy"
|
||||||
|
|
||||||
|
if os.path.exists(rag_vector_path) and os.path.exists(rag_embeddings_path):
|
||||||
|
# 加载预构建的向量索引
|
||||||
|
self.index = faiss.read_index(rag_vector_path)
|
||||||
|
self.rag_embeddings = np.load(rag_embeddings_path)
|
||||||
|
print(f"✓ 使用RAG预构建向量索引: {self.index.ntotal}个向量")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 如果没有预构建的向量索引,则重新构建
|
||||||
texts = [chunk["content"] for chunk in self.chunks]
|
texts = [chunk["content"] for chunk in self.chunks]
|
||||||
embeddings = self.embedding_model.encode(texts)
|
embeddings = self.embedding_model.encode(texts)
|
||||||
|
|
||||||
@ -152,14 +206,26 @@ class RAGKnowledgeBase:
|
|||||||
# 向量搜索
|
# 向量搜索
|
||||||
if EMBEDDING_AVAILABLE and self.embedding_model and self.index:
|
if EMBEDDING_AVAILABLE and self.embedding_model and self.index:
|
||||||
try:
|
try:
|
||||||
query_vector = self.embedding_model.encode([query])
|
# 如果使用RAG预构建向量索引,直接搜索
|
||||||
distances, indices = self.index.search(query_vector.astype(np.float32), top_k * 2)
|
if hasattr(self, 'rag_embeddings'):
|
||||||
|
query_vector = self.embedding_model.encode([query])
|
||||||
|
distances, indices = self.index.search(query_vector.astype(np.float32), top_k * 2)
|
||||||
|
|
||||||
for distance, idx in zip(distances[0], indices[0]):
|
for distance, idx in zip(distances[0], indices[0]):
|
||||||
if idx < len(self.chunks):
|
if idx < len(self.chunks):
|
||||||
chunk = self.chunks[idx].copy()
|
chunk = self.chunks[idx].copy()
|
||||||
chunk["relevance_score"] = float(1 / (1 + distance))
|
chunk["relevance_score"] = float(1 / (1 + distance))
|
||||||
relevant_chunks.append(chunk)
|
relevant_chunks.append(chunk)
|
||||||
|
else:
|
||||||
|
# 传统向量搜索
|
||||||
|
query_vector = self.embedding_model.encode([query])
|
||||||
|
distances, indices = self.index.search(query_vector.astype(np.float32), top_k * 2)
|
||||||
|
|
||||||
|
for distance, idx in zip(distances[0], indices[0]):
|
||||||
|
if idx < len(self.chunks):
|
||||||
|
chunk = self.chunks[idx].copy()
|
||||||
|
chunk["relevance_score"] = float(1 / (1 + distance))
|
||||||
|
relevant_chunks.append(chunk)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"向量搜索失败: {e}")
|
print(f"向量搜索失败: {e}")
|
||||||
|
|
||||||
@ -317,8 +383,17 @@ class DualAIDialogueEngine:
|
|||||||
self.conv_mgr = conversation_manager
|
self.conv_mgr = conversation_manager
|
||||||
self.llm_generator = llm_generator
|
self.llm_generator = llm_generator
|
||||||
|
|
||||||
def generate_character_prompt(self, character_name: str, context_info: List[Dict], dialogue_history: List[DialogueTurn]) -> str:
|
def generate_character_prompt(self, character_name: str, context_info: List[Dict], dialogue_history: List[DialogueTurn],
|
||||||
"""为角色生成对话提示"""
|
history_context_count: int = 3, context_info_count: int = 2) -> str:
|
||||||
|
"""为角色生成对话提示
|
||||||
|
|
||||||
|
Args:
|
||||||
|
character_name: 角色名称
|
||||||
|
context_info: 相关上下文信息
|
||||||
|
dialogue_history: 对话历史
|
||||||
|
history_context_count: 使用的历史对话轮数(默认3轮)
|
||||||
|
context_info_count: 使用的上下文信息数量(默认2个)
|
||||||
|
"""
|
||||||
char_data = self.kb.character_data.get(character_name, {})
|
char_data = self.kb.character_data.get(character_name, {})
|
||||||
|
|
||||||
# 基础角色设定
|
# 基础角色设定
|
||||||
@ -338,42 +413,60 @@ class DualAIDialogueEngine:
|
|||||||
situation = char_data['current_situation']
|
situation = char_data['current_situation']
|
||||||
prompt_parts.append(f"当前状态:{situation.get('current_mood', '')}")
|
prompt_parts.append(f"当前状态:{situation.get('current_mood', '')}")
|
||||||
|
|
||||||
# 相关世界观信息
|
# 相关世界观信息(可控制数量)
|
||||||
if context_info:
|
if context_info:
|
||||||
prompt_parts.append("相关背景信息:")
|
prompt_parts.append("相关背景信息:")
|
||||||
for info in context_info[:2]: # 只使用最相关的2个信息
|
for info in context_info[:context_info_count]:
|
||||||
content = info['content'][:200] + "..." if len(info['content']) > 200 else info['content']
|
content = info['content'][:200] + "..." if len(info['content']) > 200 else info['content']
|
||||||
prompt_parts.append(f"- {content}")
|
prompt_parts.append(f"- {content}")
|
||||||
|
|
||||||
# 对话历史
|
# 对话历史(可控制数量)
|
||||||
if dialogue_history:
|
if dialogue_history:
|
||||||
prompt_parts.append("最近的对话:")
|
prompt_parts.append("最近的对话:")
|
||||||
for turn in dialogue_history[-3:]: # 只使用最近的3轮对话
|
# 使用参数控制历史对话轮数
|
||||||
|
history_to_use = dialogue_history[-history_context_count:] if history_context_count > 0 else []
|
||||||
|
for turn in history_to_use:
|
||||||
prompt_parts.append(f"{turn.speaker}: {turn.content}")
|
prompt_parts.append(f"{turn.speaker}: {turn.content}")
|
||||||
|
|
||||||
prompt_parts.append("\n请根据角色设定和上下文,生成符合角色特点的自然对话。回复应该在50-150字之间。")
|
prompt_parts.append("\n请根据角色设定和上下文,生成符合角色特点的自然对话。回复应该在50-150字之间。")
|
||||||
|
|
||||||
return "\n".join(prompt_parts)
|
return "\n".join(prompt_parts)
|
||||||
|
|
||||||
def generate_dialogue(self, session_id: str, current_speaker: str, topic_hint: str = "") -> Tuple[str, List[str]]:
|
def generate_dialogue(self, session_id: str, current_speaker: str, topic_hint: str = "",
|
||||||
"""生成角色对话"""
|
history_context_count: int = 3, context_info_count: int = 2) -> Tuple[str, List[str]]:
|
||||||
|
"""生成角色对话
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话ID
|
||||||
|
current_speaker: 当前说话者
|
||||||
|
topic_hint: 话题提示
|
||||||
|
history_context_count: 使用的历史对话轮数(默认3轮)
|
||||||
|
context_info_count: 使用的上下文信息数量(默认2个)
|
||||||
|
"""
|
||||||
# 获取对话历史
|
# 获取对话历史
|
||||||
dialogue_history = self.conv_mgr.get_conversation_history(session_id)
|
dialogue_history = self.conv_mgr.get_conversation_history(session_id)
|
||||||
|
|
||||||
# 构建搜索查询
|
# 构建搜索查询
|
||||||
if dialogue_history:
|
if dialogue_history:
|
||||||
# 基于最近的对话内容
|
# 基于最近的对话内容(可控制数量)
|
||||||
recent_content = " ".join([turn.content for turn in dialogue_history[-2:]])
|
recent_turns = dialogue_history[-history_context_count:] if history_context_count > 0 else []
|
||||||
|
recent_content = " ".join([turn.content for turn in recent_turns])
|
||||||
search_query = recent_content + " " + topic_hint
|
search_query = recent_content + " " + topic_hint
|
||||||
else:
|
else:
|
||||||
# 首次对话
|
# 首次对话
|
||||||
search_query = f"{current_speaker} {topic_hint} introduction greeting"
|
search_query = f"{current_speaker} {topic_hint} introduction greeting"
|
||||||
|
|
||||||
# 搜索相关上下文
|
# 搜索相关上下文
|
||||||
context_info = self.kb.search_relevant_context(search_query, current_speaker, 10)
|
context_info = self.kb.search_relevant_context(search_query, current_speaker, max(10, context_info_count * 2))
|
||||||
|
|
||||||
# 生成提示
|
# 生成提示(使用参数控制上下文数量)
|
||||||
prompt = self.generate_character_prompt(current_speaker, context_info, dialogue_history)
|
prompt = self.generate_character_prompt(
|
||||||
|
current_speaker,
|
||||||
|
context_info,
|
||||||
|
dialogue_history,
|
||||||
|
history_context_count,
|
||||||
|
context_info_count
|
||||||
|
)
|
||||||
|
|
||||||
# 生成对话
|
# 生成对话
|
||||||
try:
|
try:
|
||||||
@ -386,8 +479,8 @@ class DualAIDialogueEngine:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 记录使用的上下文
|
# 记录使用的上下文
|
||||||
context_used = [f"{info['section']}.{info['subsection']}" for info in context_info]
|
context_used = [f"{info['section']}.{info['subsection']}" for info in context_info[:context_info_count]]
|
||||||
avg_relevance = sum(info['relevance_score'] for info in context_info) / len(context_info) if context_info else 0.0
|
avg_relevance = sum(info['relevance_score'] for info in context_info[:context_info_count]) / len(context_info[:context_info_count]) if context_info else 0.0
|
||||||
|
|
||||||
# 保存对话轮次
|
# 保存对话轮次
|
||||||
self.conv_mgr.add_dialogue_turn(
|
self.conv_mgr.add_dialogue_turn(
|
||||||
@ -400,23 +493,44 @@ class DualAIDialogueEngine:
|
|||||||
print(f"✗ 对话生成失败: {e}")
|
print(f"✗ 对话生成失败: {e}")
|
||||||
return f"[{current_speaker}暂时无法回应]", []
|
return f"[{current_speaker}暂时无法回应]", []
|
||||||
|
|
||||||
def run_conversation_turn(self, session_id: str, characters: List[str], turns_count: int = 1, topic: str = ""):
|
def run_conversation_turn(self, session_id: str, characters: List[str], turns_count: int = 1, topic: str = "",
|
||||||
"""运行对话轮次"""
|
history_context_count: int = 3, context_info_count: int = 2):
|
||||||
results = []
|
"""运行对话轮次
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话ID
|
||||||
|
characters: 角色列表
|
||||||
|
turns_count: 对话轮数
|
||||||
|
topic: 对话主题
|
||||||
|
history_context_count: 使用的历史对话轮数(默认3轮)
|
||||||
|
context_info_count: 使用的上下文信息数量(默认2个)
|
||||||
|
"""
|
||||||
|
results = []
|
||||||
|
print(f" [上下文设置: 历史{history_context_count}轮, 信息{context_info_count}个]")
|
||||||
for i in range(turns_count):
|
for i in range(turns_count):
|
||||||
for char in characters:
|
for char in characters:
|
||||||
response, context_used = self.generate_dialogue(session_id, char, topic)
|
response, context_used = self.generate_dialogue(
|
||||||
|
session_id,
|
||||||
|
char,
|
||||||
|
topic,
|
||||||
|
history_context_count,
|
||||||
|
context_info_count
|
||||||
|
)
|
||||||
results.append({
|
results.append({
|
||||||
"speaker": char,
|
"speaker": char,
|
||||||
"content": response,
|
"content": response,
|
||||||
"context_used": context_used,
|
"context_used": context_used,
|
||||||
"turn": i + 1
|
"turn": i + 1,
|
||||||
|
"context_settings": {
|
||||||
|
"history_count": history_context_count,
|
||||||
|
"context_info_count": context_info_count
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
print(f"{char}: {response}")
|
print(f"{char}: {response}")
|
||||||
if context_used:
|
# if context_used:
|
||||||
print(f" [使用上下文: {', '.join(context_used)}]")
|
# print(f" [使用上下文: {', '.join(context_used)}]")
|
||||||
|
|
||||||
print()
|
print()
|
||||||
|
|
||||||
return results
|
return results
|
||||||
@ -453,7 +567,13 @@ def main():
|
|||||||
if not os.path.exists(lora_model_path):
|
if not os.path.exists(lora_model_path):
|
||||||
lora_model_path = None
|
lora_model_path = None
|
||||||
|
|
||||||
llm_generator = NPCDialogueGenerator(base_model_path, lora_model_path)
|
# 创建对话生成器并传入角色数据
|
||||||
|
if hasattr(kb, 'character_data') and kb.character_data:
|
||||||
|
print("✓ 使用knowledge_base角色数据创建对话生成器")
|
||||||
|
llm_generator = NPCDialogueGenerator(base_model_path, lora_model_path, kb.character_data)
|
||||||
|
else:
|
||||||
|
print("⚠ 使用内置角色数据创建对话生成器")
|
||||||
|
llm_generator = NPCDialogueGenerator(base_model_path, lora_model_path)
|
||||||
|
|
||||||
# 创建对话引擎
|
# 创建对话引擎
|
||||||
dialogue_engine = DualAIDialogueEngine(kb, conv_mgr, llm_generator)
|
dialogue_engine = DualAIDialogueEngine(kb, conv_mgr, llm_generator)
|
||||||
@ -491,8 +611,17 @@ def main():
|
|||||||
topic = input("请输入对话主题(可选): ").strip()
|
topic = input("请输入对话主题(可选): ").strip()
|
||||||
turns = int(input("请输入对话轮次数量(默认2): ").strip() or "2")
|
turns = int(input("请输入对话轮次数量(默认2): ").strip() or "2")
|
||||||
|
|
||||||
|
# 历史上下文控制选项
|
||||||
|
print("\n历史上下文设置:")
|
||||||
|
history_count = input("使用历史对话轮数(默认3,0表示不使用): ").strip()
|
||||||
|
history_count = int(history_count) if history_count.isdigit() else 3
|
||||||
|
|
||||||
|
context_info_count = input("使用上下文信息数量(默认2): ").strip()
|
||||||
|
context_info_count = int(context_info_count) if context_info_count.isdigit() else 2
|
||||||
|
|
||||||
print(f"\n开始对话 - 会话ID: {session_id}")
|
print(f"\n开始对话 - 会话ID: {session_id}")
|
||||||
dialogue_engine.run_conversation_turn(session_id, characters, turns, topic)
|
print(f"上下文设置: 历史{history_count}轮, 信息{context_info_count}个")
|
||||||
|
dialogue_engine.run_conversation_turn(session_id, characters, turns, topic, history_count, context_info_count)
|
||||||
|
|
||||||
elif choice == '2':
|
elif choice == '2':
|
||||||
# 继续已有对话
|
# 继续已有对话
|
||||||
@ -523,8 +652,17 @@ def main():
|
|||||||
topic = input("请输入对话主题(可选): ").strip()
|
topic = input("请输入对话主题(可选): ").strip()
|
||||||
turns = int(input("请输入对话轮次数量(默认1): ").strip() or "1")
|
turns = int(input("请输入对话轮次数量(默认1): ").strip() or "1")
|
||||||
|
|
||||||
|
# 历史上下文控制选项
|
||||||
|
print("\n历史上下文设置:")
|
||||||
|
history_count = input("使用历史对话轮数(默认3,0表示不使用): ").strip()
|
||||||
|
history_count = int(history_count) if history_count.isdigit() else 3
|
||||||
|
|
||||||
|
context_info_count = input("使用上下文信息数量(默认2): ").strip()
|
||||||
|
context_info_count = int(context_info_count) if context_info_count.isdigit() else 2
|
||||||
|
|
||||||
print(f"\n继续对话 - 会话ID: {session_id}")
|
print(f"\n继续对话 - 会话ID: {session_id}")
|
||||||
dialogue_engine.run_conversation_turn(session_id, characters, turns, topic)
|
print(f"上下文设置: 历史{history_count}轮, 信息{context_info_count}个")
|
||||||
|
dialogue_engine.run_conversation_turn(session_id, characters, turns, topic, history_count, context_info_count)
|
||||||
else:
|
else:
|
||||||
print("❌ 无效的会话编号")
|
print("❌ 无效的会话编号")
|
||||||
except ValueError:
|
except ValueError:
|
||||||
|
|||||||
@ -19,22 +19,100 @@ if platform.system() == "Windows":
|
|||||||
multiprocessing.set_start_method('spawn', force=True)
|
multiprocessing.set_start_method('spawn', force=True)
|
||||||
|
|
||||||
class NPCDialogueGenerator:
|
class NPCDialogueGenerator:
|
||||||
def __init__(self, base_model_path: str, lora_model_path: Optional[str] = None):
|
def __init__(self, base_model_path: str, lora_model_path: Optional[str] = None, external_character_data: Optional[Dict] = None):
|
||||||
"""
|
"""
|
||||||
初始化NPC对话生成器
|
初始化NPC对话生成器
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
base_model_path: 基础模型路径
|
base_model_path: 基础模型路径
|
||||||
lora_model_path: LoRA模型路径(可选)
|
lora_model_path: LoRA模型路径(可选)
|
||||||
|
external_character_data: 外部角色数据(可选,优先使用)
|
||||||
"""
|
"""
|
||||||
self.base_model_path = base_model_path
|
self.base_model_path = base_model_path
|
||||||
self.lora_model_path = lora_model_path
|
self.lora_model_path = lora_model_path
|
||||||
self.model = None
|
self.model = None
|
||||||
self.tokenizer = None
|
self.tokenizer = None
|
||||||
self.character_profiles = self._load_character_profiles()
|
|
||||||
|
# 优先使用外部角色数据,如果没有则使用内置数据
|
||||||
|
if external_character_data:
|
||||||
|
self.character_profiles = self._process_external_character_data(external_character_data)
|
||||||
|
print(f"✓ 使用外部角色数据: {list(self.character_profiles.keys())}")
|
||||||
|
else:
|
||||||
|
self.character_profiles = self._load_character_profiles()
|
||||||
|
print(f"✓ 使用内置角色数据: {list(self.character_profiles.keys())}")
|
||||||
|
|
||||||
self._load_model()
|
self._load_model()
|
||||||
|
|
||||||
|
def _process_external_character_data(self, external_data: Dict) -> Dict:
    """Convert knowledge-base character records into generator profiles.

    Each record is flattened into the profile shape used by the dialogue
    generator: ``name``, ``title``, ``personality``, ``background``,
    ``speech_patterns``, ``sample_dialogues`` and ``full_data``.

    Missing, null, or scalar fields are tolerated: a section that is not
    a dict is treated as empty, and a trait/pattern field that holds a
    single value instead of a list is wrapped into a one-element list
    rather than raising ``TypeError`` on concatenation.

    Args:
        external_data: Mapping of character key -> raw character record.

    Returns:
        Mapping of the same character keys -> processed profile dict.
    """

    def as_section(value):
        # A key that is present but null yields None from .get(), which
        # would otherwise crash on the following .get() call.
        return value if isinstance(value, dict) else {}

    def as_items(value):
        # Normalise "list, single scalar, or nothing" to a plain list so
        # the two halves can always be concatenated.
        if not value:
            return []
        if isinstance(value, (list, tuple)):
            return list(value)
        return [value]

    processed_profiles = {}

    for char_name, char_data in external_data.items():
        basic_info = as_section(char_data.get('basic_info'))
        personality = as_section(char_data.get('personality'))
        background = as_section(char_data.get('background'))
        speech_patterns = as_section(char_data.get('speech_patterns'))

        # Join only the parts that exist, so a missing field does not
        # leave a dangling space in the background text.
        background_text = ' '.join(
            str(part)
            for part in (background.get('childhood'), background.get('education'))
            if part
        )

        processed_profiles[char_name] = {
            "name": char_data.get('character_name', char_name),
            "title": basic_info.get('occupation', '未知'),
            "personality": (
                as_items(personality.get('core_traits'))
                + as_items(personality.get('strengths'))
            ),
            "background": background_text,
            "speech_patterns": (
                as_items(speech_patterns.get('vocabulary'))
                + as_items(speech_patterns.get('tone'))
            ),
            "sample_dialogues": self._generate_sample_dialogues(char_data),
            # Keep the untouched record available for callers that need
            # fields not surfaced in the flattened profile.
            "full_data": char_data,
        }

    return processed_profiles
|
||||||
|
|
||||||
|
def _generate_sample_dialogues(self, char_data: Dict) -> List[str]:
    """Return canned sample lines chosen by the character's occupation.

    The occupation is read from ``char_data['basic_info']['occupation']``
    and matched by substring against keyword groups, checked in order;
    the first group with a matching keyword wins. When nothing matches,
    or the occupation is missing, null, or not a string, a generic set
    is returned.

    Args:
        char_data: Raw character record.

    Returns:
        A new list of three sample dialogue strings.
    """
    # A key that is present with a None value bypasses .get()'s default,
    # so validate the types explicitly instead of failing on `in`.
    basic_info = char_data.get('basic_info')
    occupation = basic_info.get('occupation') if isinstance(basic_info, dict) else None
    if not isinstance(occupation, str):
        occupation = '角色'

    # Ordered (keywords, samples) table: earlier rows take precedence.
    samples_by_keywords = (
        (
            ('侦探', '调查员'),
            (
                "我需要仔细分析这个案件。",
                "每个细节都可能很重要。",
                "让我重新梳理一下线索。",
            ),
        ),
        (
            ('教授', '博士'),
            (
                "根据我的研究,这个现象很特殊。",
                "我们需要更谨慎地处理这个问题。",
                "知识就是力量,但也要小心使用。",
            ),
        ),
    )

    for keywords, samples in samples_by_keywords:
        if any(keyword in occupation for keyword in keywords):
            # Return a fresh list so callers may mutate it freely.
            return list(samples)

    return [
        "我遇到了一些困难。",
        "请帮帮我。",
        "这太奇怪了。",
    ]
|
||||||
|
|
||||||
def _load_character_profiles(self) -> Dict:
|
def _load_character_profiles(self) -> Dict:
|
||||||
"""加载角色画像数据"""
|
"""加载角色画像数据"""
|
||||||
return {
|
return {
|
||||||
|
|||||||
3748
AITrain/test.jsonl
Normal file
3748
AITrain/test.jsonl
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user