#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Prepare a character-dialogue fine-tuning dataset.

Converts a JSON Lines dialogue file (``test.jsonl``) into instruction-style
samples suitable for LoRA training.
"""

import json
import os
import random
from typing import List, Dict


def load_dialogue_data(file_path: str) -> List[Dict]:
    """Load dialogue records from a JSON Lines file.

    Args:
        file_path: Path to a UTF-8 file holding one JSON value per line.

    Returns:
        The decoded records, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    dialogues = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            # Blank lines (e.g. a trailing newline at end of file) carry no
            # record; json.loads('') would raise, so skip them.
            if not line:
                continue
            dialogues.append(json.loads(line))
    return dialogues


def get_dialogue_characters(dialogues: List[Dict]) -> List[str]:
    """Return the distinct ``'role'`` values in order of first appearance.

    Args:
        dialogues: Records that each contain a ``'role'`` key.

    Returns:
        Unique role names, ordered by first occurrence.

    Raises:
        KeyError: If a record has no ``'role'`` key.
    """
    # dict.fromkeys keeps insertion order and de-duplicates in a single pass.
    return list(dict.fromkeys(dialogue['role'] for dialogue in dialogues))


def group_dialogues_by_character(dialogues: List[Dict]) -> Dict[str, List[str]]:
    """Group dialogue lines by the character who speaks them.

    Args:
        dialogues: Records that each contain ``'role'`` and ``'dialogue'``.

    Returns:
        A plain dict mapping each role to its lines, in original order.

    Raises:
        KeyError: If a record lacks ``'role'`` or ``'dialogue'``.
    """
    character_dialogues: Dict[str, List[str]] = {}
    for dialogue in dialogues:
        character = dialogue['role']
        content = dialogue['dialogue']
        character_dialogues.setdefault(character, []).append(content)
    return character_dialogues


def create_training_samples(character_dialogues: Dict[str, List[str]],
                            character_profiles: Dict) -> List[Dict]:
    """Build instruction/input/output training samples.

    For every character that has a profile, each of its lines yields two
    samples (one prompted by the role description, one by the background).
    Afterwards up to 50 interaction samples are added that pair the i-th
    line of '塔利姆' (as input) with the i-th line of '克莱恩' (as output).

    Args:
        character_dialogues: Mapping of character name to its lines.
        character_profiles: Mapping of character name to a profile dict with
            the keys ``description``, ``personality``, ``speech_style`` and
            ``background``. Characters without a profile are skipped.

    Returns:
        The list of generated sample dicts.

    Raises:
        KeyError: If a profile is missing one of the keys listed above.
    """
    training_samples = []

    for character, dialogues in character_dialogues.items():
        if character not in character_profiles:
            continue

        profile = character_profiles[character]

        # Create several kinds of training sample for each line.
        for dialogue in dialogues:
            # Sample 1: line generated from the role description.
            training_samples.append({
                "instruction": f"你现在要扮演{character}。{profile['description']}。性格特点:{profile['personality']}。说话风格:{profile['speech_style']}。",
                "input": "请根据你的角色设定说一段话。",
                "output": dialogue,
            })

            # Sample 2: line generated from the scene/background.
            training_samples.append({
                "instruction": f"你是{character},{profile['background']}。",
                "input": "在当前情境下,你会说什么?",
                "output": dialogue,
            })

    # Character-interaction samples. Use .get so that a dataset lacking either
    # character yields no interaction samples instead of raising KeyError.
    klein_lines = character_dialogues.get('克莱恩', [])
    talim_lines = character_dialogues.get('塔利姆', [])
    # zip stops at the shorter sequence, so pairs are limited to the first 50
    # Klein lines and to however many Talim lines exist.
    for klein_line, talim_line in zip(klein_lines[:50], talim_lines):
        training_samples.append({
            "instruction": "你是克莱恩,一位神秘学专家和侦探。塔利姆是你的客户,向你寻求帮助。",
            "input": f"塔利姆对你说:{talim_line}",
            "output": klein_line,
        })

    return training_samples


def create_npc_dialogue_samples() -> List[Dict]:
    """Return a fixed list of hand-written NPC dialogue samples.

    Returns:
        Four dicts, each with ``instruction``, ``input`` and ``output`` keys.
    """
    npc_samples = [
        {
            "instruction": "你是一个游戏中的NPC神秘学导师,名叫克莱恩。玩家向你寻求关于神秘学的建议。",
            "input": "请告诉我关于灵界的知识。",
            "output": "灵界是一个充满危险的地方,好奇往往是导致死亡的主要因素。如果你真的需要了解,记住永远不要直视那些不该看的存在。"
        },
        {
            "instruction": "你是游戏中的阿兹克导师,经验丰富的神秘学大师。玩家遇到了困难。",
            "input": "我在修炼中遇到了瓶颈,该怎么办?",
            "output": "耐心是修炼的基础。不要急于求成,稳扎稳打比什么都重要。先巩固你现有的基础。"
        },
        {
            "instruction": "你是游戏中的塔利姆,一个有文化的普通NPC,遇到了情感问题。",
            "input": "你看起来有些困扰?",
            "output": "噢,尊敬的冒险者,我有个朋友爱上了不该爱的人,这种情况该怎么处理?这不是《罗密欧与朱丽叶》的故事。"
        },
        {
            "instruction": "你是游戏中的艾伦,一个遭遇神秘事件的NPC,需要玩家帮助。",
            "input": "你遇到什么麻烦了?",
            "output": "最近我总是遭遇各种厄运,摔跤、丢钱、被狗咬...我怀疑是不是受到了什么诅咒,请帮帮我!"
        }
    ]
    return npc_samples


def create_character_dialogue_samples(character: str, dialogues: List[Dict]) -> List[Dict]:
    """Pair each reply of ``character`` with the line that preceded it.

    Walks the dialogue in order, remembering the most recent line spoken by
    anyone other than ``character``. When ``character`` speaks next, that
    remembered line becomes the ``instruction`` and the reply the ``output``;
    the remembered line is then cleared, so consecutive lines by
    ``character`` produce only one sample.

    Args:
        character: Role name whose replies become sample outputs.
        dialogues: Ordered records containing ``'role'`` and ``'dialogue'``.

    Returns:
        Sample dicts with ``instruction``, ``input`` (always empty) and
        ``output`` keys.

    Raises:
        KeyError: If a record lacks ``'role'`` or ``'dialogue'``.
    """
    pending_prompt = ""
    character_samples = []
    for dialogue in dialogues:
        speaker = dialogue['role']
        if speaker != character:
            # Only the latest line from another speaker is kept as the prompt.
            pending_prompt = dialogue['dialogue']
        elif pending_prompt != '':
            # A prompt is waiting, so this line is treated as its answer.
            character_samples.append({"instruction": pending_prompt,
                                      "input": "",
                                      "output": dialogue['dialogue']})
            pending_prompt = ''
    return character_samples


def main(input_path: str = './test.jsonl',
         output_path: str = './npc_dialogue_dataset.json') -> None:
    """Convert the raw dialogue file into a per-character sample dataset.

    Args:
        input_path: JSON Lines file with ``role``/``dialogue`` records.
        output_path: Destination JSON file; receives a dict mapping each
            character to its list of samples.
    """
    # Load the raw data.
    dialogues = load_dialogue_data(input_path)
    characters = get_dialogue_characters(dialogues)
    character_dialogues = group_dialogues_by_character(dialogues)
    print("角色统计:")
    for char, convs in character_dialogues.items():
        print(f" {char}: {len(convs)}条对话")

    # Build the final training data, keyed by character.
    final_samples = {
        character: create_character_dialogue_samples(character, dialogues)
        for character in characters
    }

    # Save as JSON.
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(final_samples, f, ensure_ascii=False, indent=2)

    # final_samples is keyed by character, so len() of it would count
    # characters; sum the per-character lists to report the sample count.
    total_samples = sum(len(samples) for samples in final_samples.values())
    print(f"\n生成了 {total_samples} 个训练样本")
    print(f"数据集已保存为: {os.path.basename(output_path)}")


if __name__ == '__main__':
    main()