Project02/AITrain/prepare_dialogue_data.py

152 lines
6.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
准备角色对话微调数据集
将test.jsonl转换为适合LoRA训练的格式
'''
import json
import random
from typing import List, Dict
def load_dialogue_data(file_path: str) -> List[Dict]:
    """Load dialogue records from a JSON Lines file.

    Args:
        file_path: Path to a UTF-8 encoded file holding one JSON value
            per line.

    Returns:
        The decoded records, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    dialogues = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            # Blank lines (for example a trailing newline at end of file)
            # carry no record; json.loads('') would raise on them.
            if not line:
                continue
            dialogues.append(json.loads(line))
    return dialogues
def get_dialogue_characters(dialogues: List[Dict]) -> List[str]:
    """Return the distinct speaker names in order of first appearance.

    Args:
        dialogues: Records that each carry the speaker under the 'role' key.

    Returns:
        Each 'role' value exactly once, ordered by where it first occurs.

    Raises:
        KeyError: If a record has no 'role' key.
    """
    # dict keys keep insertion order, so fromkeys de-duplicates while
    # preserving first-seen order without rescanning a list per record.
    return list(dict.fromkeys(dialogue['role'] for dialogue in dialogues))
def group_dialogues_by_character(dialogues: List[Dict]) -> Dict[str, List[str]]:
    """Group dialogue texts by the character who speaks them.

    Args:
        dialogues: Records with a 'role' key (speaker) and a 'dialogue'
            key (spoken text).

    Returns:
        A dict mapping each speaker to the list of their texts, in the
        order the records were given.
    """
    grouped = {}
    for record in dialogues:
        speaker = record['role']
        text = record['dialogue']
        # setdefault creates the speaker's list on first sight.
        grouped.setdefault(speaker, []).append(text)
    return grouped
def create_training_samples(character_dialogues: Dict[str, List[str]], character_profiles: Dict) -> List[Dict]:
    """Build instruction-tuning samples from per-character dialogue lines.

    For every character that has a profile, each of its dialogue lines
    produces two samples: one prompted by the character's description,
    personality and speech style, and one prompted by the character's
    background. Afterwards, interaction samples pair the i-th line of
    '塔利姆' (as input) with the i-th line of '克莱恩' (as output), for at
    most the first 50 lines of '克莱恩'.

    Args:
        character_dialogues: Mapping of character name to that character's
            dialogue lines.
        character_profiles: Mapping of character name to a profile dict
            read via the keys 'description', 'personality', 'speech_style'
            and 'background'. Characters without a profile are skipped.

    Returns:
        A list of dicts with the keys "instruction", "input" and "output".
    """
    training_samples = []

    for character, dialogues in character_dialogues.items():
        if character not in character_profiles:
            continue
        profile = character_profiles[character]

        # Two sample flavours per dialogue line.
        for dialogue in dialogues:
            # Flavour 1: prompt built from the character description.
            training_samples.append({
                "instruction": f"你现在要扮演{character}{profile['description']}。性格特点:{profile['personality']}。说话风格:{profile['speech_style']}",
                "input": "请根据你的角色设定说一段话。",
                "output": dialogue
            })
            # Flavour 2: prompt built from the character background.
            training_samples.append({
                "instruction": f"你是{character}{profile['background']}",
                "input": "在当前情境下,你会说什么?",
                "output": dialogue
            })

    # Interaction samples. Use .get so that a dataset lacking either
    # character yields no interaction samples instead of a KeyError.
    klein_lines = character_dialogues.get('克莱恩', [])
    tallim_lines = character_dialogues.get('塔利姆', [])
    # zip stops at the shorter sequence, which matches pairing index i of
    # both lists only while i is valid for each of them.
    for klein_line, tallim_line in zip(klein_lines[:50], tallim_lines):
        training_samples.append({
            "instruction": "你是克莱恩,一位神秘学专家和侦探。塔利姆是你的客户,向你寻求帮助。",
            "input": f"塔利姆对你说:{tallim_line}",
            "output": klein_line
        })

    return training_samples
def create_npc_dialogue_samples() -> List[Dict]:
    """Return the hand-written NPC dialogue samples.

    Returns:
        A new list of four dicts, each with the keys "instruction",
        "input" and "output" (in that order).
    """
    # (instruction, input, output) triples, one per NPC scenario.
    scenarios = (
        (
            "你是一个游戏中的NPC神秘学导师名叫克莱恩。玩家向你寻求关于神秘学的建议。",
            "请告诉我关于灵界的知识。",
            "灵界是一个充满危险的地方,好奇往往是导致死亡的主要因素。如果你真的需要了解,记住永远不要直视那些不该看的存在。",
        ),
        (
            "你是游戏中的阿兹克导师,经验丰富的神秘学大师。玩家遇到了困难。",
            "我在修炼中遇到了瓶颈,该怎么办?",
            "耐心是修炼的基础。不要急于求成,稳扎稳打比什么都重要。先巩固你现有的基础。",
        ),
        (
            "你是游戏中的塔利姆一个有文化的普通NPC遇到了情感问题。",
            "你看起来有些困扰?",
            "噢,尊敬的冒险者,我有个朋友爱上了不该爱的人,这种情况该怎么处理?这不是《罗密欧与朱丽叶》的故事。",
        ),
        (
            "你是游戏中的艾伦一个遭遇神秘事件的NPC需要玩家帮助。",
            "你遇到什么麻烦了?",
            "最近我总是遭遇各种厄运,摔跤、丢钱、被狗咬...我怀疑是不是受到了什么诅咒,请帮帮我!",
        ),
    )
    return [
        {"instruction": instruction, "input": prompt, "output": reply}
        for instruction, prompt, reply in scenarios
    ]
def create_character_dialogue_samples(character: str, dialogues: List[Dict]) -> List[Dict]:
    """Pair each reply of `character` with the line spoken just before it.

    The records are scanned in order. The most recent line from any other
    speaker is remembered; when `character` speaks next, that remembered
    line becomes the "instruction" and the character's line the "output".
    The remembered line is then cleared, so consecutive lines by
    `character` produce only one sample.

    Args:
        character: The speaker whose replies become sample outputs.
        dialogues: Records with 'role' and 'dialogue' keys, in
            conversation order.

    Returns:
        A list of dicts with "instruction", "input" (always empty) and
        "output".
    """
    pending_prompt = ""
    samples = []
    for turn in dialogues:
        if turn['role'] != character:
            # Someone else is talking: remember the latest such line.
            pending_prompt = turn['dialogue']
            continue
        if pending_prompt == '':
            # The character speaks with nothing to answer; no sample.
            continue
        # A remembered line followed by the character's reply.
        samples.append({
            "instruction": pending_prompt,
            "input": "",
            "output": turn['dialogue'],
        })
        pending_prompt = ''
    return samples
def main():
    """Convert './test.jsonl' into './npc_dialogue_dataset.json'.

    Loads the dialogue records, prints how many lines each character has,
    builds prompt/reply samples for every character, and writes them as a
    JSON object keyed by character name.
    """
    # Load the raw data.
    dialogues = load_dialogue_data('./test.jsonl')
    characters = get_dialogue_characters(dialogues)
    character_dialogues = group_dialogues_by_character(dialogues)

    print("角色统计:")
    for char, convs in character_dialogues.items():
        print(f" {char}: {len(convs)}条对话")

    # Build the final training data, one sample list per character.
    final_samples = {
        character: create_character_dialogue_samples(character, dialogues)
        for character in characters
    }

    # Save as JSON.
    with open('./npc_dialogue_dataset.json', 'w', encoding='utf-8') as f:
        json.dump(final_samples, f, ensure_ascii=False, indent=2)

    # final_samples is keyed by character, so its length is the number of
    # characters; the message reports samples, so count the list entries.
    total_samples = sum(len(samples) for samples in final_samples.values())
    print(f"\n生成了 {total_samples} 个训练样本")
    print("数据集已保存为: npc_dialogue_dataset.json")
# Run the conversion only when this file is executed as a script.
if __name__ == "__main__":
    main()