#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Prepare a character-dialogue fine-tuning dataset.

Reads ``test.jsonl`` (one JSON object per line, each with a ``role`` and a
``dialogue`` key) and converts it into instruction/output pairs suitable for
LoRA training, written to ``npc_dialogue_dataset.json``.
"""

import json
import random
from typing import List, Dict


def load_dialogue_data(file_path: str) -> List[Dict]:
    """Load dialogue records from a JSON Lines file.

    Args:
        file_path: Path to a UTF-8 encoded file with one JSON value per line.

    Returns:
        The decoded records, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    dialogues = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            stripped = line.strip()
            # Blank lines (e.g. a trailing newline at end of file) carry no
            # record; feeding them to json.loads would raise.
            if not stripped:
                continue
            dialogues.append(json.loads(stripped))
    return dialogues


def get_dialogue_characters(dialogues: List[Dict]) -> List[str]:
    """Return the distinct ``role`` values in order of first appearance.

    Args:
        dialogues: Records that each contain a ``role`` key.

    Returns:
        Unique roles, ordered by where they first occur in ``dialogues``.

    Raises:
        KeyError: If a record has no ``role`` key.
    """
    # dict.fromkeys de-duplicates while keeping insertion order.
    return list(dict.fromkeys(dialogue['role'] for dialogue in dialogues))


def group_dialogues_by_character(dialogues: List[Dict]) -> Dict[str, List[str]]:
    """Group dialogue lines by the character who speaks them.

    Args:
        dialogues: Records that each contain ``role`` and ``dialogue`` keys.

    Returns:
        A mapping from role to that role's ``dialogue`` values, in input order.

    Raises:
        KeyError: If a record lacks ``role`` or ``dialogue``.
    """
    character_dialogues: Dict[str, List[str]] = {}
    for dialogue in dialogues:
        character = dialogue['role']
        content = dialogue['dialogue']
        character_dialogues.setdefault(character, []).append(content)
    return character_dialogues


def create_character_dialogue_samples(character: str, dialogues: List[Dict]) -> List[Dict]:
    """Build instruction/output training samples for one character.

    Walks the dialogue in order. The most recent line spoken by anyone other
    than ``character`` is remembered as the pending prompt; when ``character``
    speaks next, that prompt and the reply form one sample and the pending
    prompt is cleared. Consecutive lines from ``character`` therefore yield at
    most one sample, and only the last of several consecutive lines from other
    speakers is used as the prompt.

    Args:
        character: The role whose replies become the ``output`` field.
        dialogues: Records that each contain ``role`` and ``dialogue`` keys.

    Returns:
        A list of dicts with keys ``instruction``, ``input`` (always the empty
        string) and ``output``.

    Raises:
        KeyError: If a record lacks ``role`` or ``dialogue``.
    """
    pending_prompt = ""
    character_samples = []
    for dialogue in dialogues:
        speaker = dialogue['role']
        if speaker != character:
            pending_prompt = dialogue['dialogue']
        elif pending_prompt != '':
            # Someone else spoke just before: treat this line as the reply.
            character_samples.append({
                "instruction": pending_prompt,
                "input": "",
                "output": dialogue['dialogue'],
            })
            pending_prompt = ''
    return character_samples


def main():
    """Convert ``./test.jsonl`` into ``./npc_dialogue_dataset.json``.

    Prints a per-character line count, writes a JSON object mapping each
    character to its list of training samples, and reports the total number
    of samples generated.
    """
    # Load the raw data.
    dialogues = load_dialogue_data('./test.jsonl')

    characters = get_dialogue_characters(dialogues)
    character_dialogues = group_dialogues_by_character(dialogues)

    print("角色统计:")
    for char, convs in character_dialogues.items():
        print(f" {char}: {len(convs)}条对话")

    # Build the final training data, one sample list per character.
    final_samples = {
        character: create_character_dialogue_samples(character, dialogues)
        for character in characters
    }

    # Save as JSON.
    with open('./npc_dialogue_dataset.json', 'w', encoding='utf-8') as f:
        json.dump(final_samples, f, ensure_ascii=False, indent=2)

    # len(final_samples) would count characters, not samples; sum the lists.
    total_samples = sum(len(samples) for samples in final_samples.values())
    print(f"\n生成了 {total_samples} 个训练样本")
    print("数据集已保存为: npc_dialogue_dataset.json")


if __name__ == '__main__':
    main()