89 lines
2.8 KiB
Python
89 lines
2.8 KiB
Python
#!/usr/bin/env python
|
|
# -*- coding: utf-8 -*-
|
|
'''
Prepare a character-dialogue fine-tuning dataset.

Converts test.jsonl into a format suitable for LoRA training.
'''
|
|
|
|
import json
|
|
import random
|
|
from typing import List, Dict
|
|
|
|
def load_dialogue_data(file_path: str) -> List[Dict]:
    """Load dialogue records from a JSON Lines file.

    Every non-blank line is parsed as one JSON value. Only non-empty JSON
    objects that contain a ``role`` key are kept; any other parsed value is
    silently skipped.

    Args:
        file_path: Path to a UTF-8 encoded JSON Lines file.

    Returns:
        The accepted records, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    dialogues = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            # Blank lines (a spacer between records, a trailing newline) are
            # not records; json.loads('') would raise on them.
            if not line:
                continue
            data = json.loads(line)
            if data and isinstance(data, dict) and 'role' in data:
                dialogues.append(data)
    return dialogues
|
|
|
|
def get_dialogue_characters(dialogues: List[Dict]) -> List[str]:
    """Return the distinct speakers in order of first appearance.

    Args:
        dialogues: Records that each carry a ``role`` key.

    Returns:
        Each distinct ``role`` value once, ordered by where it first occurs
        in ``dialogues``.

    Raises:
        KeyError: If a record has no ``role`` key.
    """
    # dict keys keep insertion order, so fromkeys de-duplicates while
    # preserving first-seen order (role values must be hashable).
    return list(dict.fromkeys(dialogue['role'] for dialogue in dialogues))
|
|
|
|
def group_dialogues_by_character(dialogues: List[Dict]) -> Dict[str, List[str]]:
    """Group dialogue lines by the character who speaks them.

    Args:
        dialogues: Records that each carry ``role`` and ``dialogue`` keys.

    Returns:
        A mapping from each ``role`` value to the list of its ``dialogue``
        values, in the order they occur in ``dialogues``.

    Raises:
        KeyError: If a record lacks the ``role`` or ``dialogue`` key.
    """
    character_dialogues: Dict[str, List[str]] = {}
    for dialogue in dialogues:
        character = dialogue['role']
        content = dialogue['dialogue']
        # setdefault creates the per-character list on first sight.
        character_dialogues.setdefault(character, []).append(content)
    return character_dialogues
|
|
|
|
|
|
def create_character_dialogue_samples(character: str, dialogues: List[Dict]) -> List[Dict]:
    """Build instruction/response training samples for one character.

    Walks ``dialogues`` in order. The most recent line spoken by anyone other
    than ``character`` is remembered as the pending prompt. When ``character``
    speaks while a non-empty prompt is pending, that prompt and the reply
    form one sample and the prompt is cleared. A line spoken by ``character``
    with no pending prompt (for example a second consecutive line) produces
    no sample.

    Args:
        character: The role whose replies become the ``output`` field.
        dialogues: Records that each carry ``role`` and ``dialogue`` keys.

    Returns:
        A list of dicts with the keys ``character``, ``instruction``,
        ``input`` (always an empty string) and ``output``.

    Raises:
        KeyError: If a record has no ``role`` key, or lacks ``dialogue``
            where it is read.
    """
    pending_prompt = ""
    character_samples = []
    for dialogue in dialogues:
        speaker = dialogue['role']
        if speaker != character:
            # Someone else is talking: their latest line becomes the prompt,
            # replacing any earlier unanswered one.
            pending_prompt = dialogue['dialogue']
        elif pending_prompt != '':
            # The character answers a pending prompt: emit one sample.
            character_samples.append({
                "character": character,
                "instruction": pending_prompt,
                "input": "",
                "output": dialogue['dialogue'],
            })
            pending_prompt = ''
    return character_samples
|
|
|
|
def main():
    """Convert ``./test.jsonl`` into ``./npc_dialogue_dataset.json``.

    Loads the dialogue records, prints how many lines each character has,
    builds prompt/reply samples for every character, and writes all samples
    to the output file as a single pretty-printed JSON array.
    """
    # Load the raw data.
    dialogues = load_dialogue_data('./test.jsonl')
    characters = get_dialogue_characters(dialogues)
    character_dialogues = group_dialogues_by_character(dialogues)

    print("角色统计:")
    for char, convs in character_dialogues.items():
        print(f" {char}: {len(convs)}条对话")

    # Build the final training data: one pass over the dialogues per character.
    final_samples = []
    for character in characters:
        final_samples.extend(create_character_dialogue_samples(character, dialogues))

    # Save as JSON; ensure_ascii=False keeps non-ASCII text readable in the file.
    with open('./npc_dialogue_dataset.json', 'w', encoding='utf-8') as f:
        json.dump(final_samples, f, ensure_ascii=False, indent=2)

    print(f"\n生成了 {len(final_samples)} 个训练样本")
    print("数据集已保存为: npc_dialogue_dataset.json")
|
|
|
|
|
|
|
|
# Run the conversion only when executed as a script, not when imported.
if __name__ == '__main__':
    main()