# Listing metadata: 62 lines, 2.5 KiB, Python
import json
import platform

from tqdm import tqdm

from extract import system_prompt
from schema import novel_schema
from LLM import DeepseekChat
from utils import ReadFiles

# Windows multiprocessing compatibility fix: on Windows only, import
# multiprocessing and call freeze_support() before anything else runs.
# NOTE(review): the Python docs recommend calling freeze_support() directly
# under the `if __name__ == '__main__':` line - confirm module level is intended.
if platform.system() == "Windows":
    import multiprocessing

    multiprocessing.freeze_support()


# Fields that must be non-blank strings for an item to be written out.
_REQUIRED_FIELDS = ('instruction', 'output', 'character')


def _strip_code_fence(text):
    """Return *text* without a surrounding Markdown code fence.

    A leading ```json or ``` marker and a trailing ``` marker are removed,
    then surrounding whitespace is trimmed. Text without a fence is only
    stripped of surrounding whitespace.
    """
    text = text.strip()
    if text.startswith('```json'):
        text = text[7:]  # drop the leading ```json
    if text.startswith('```'):
        text = text[3:]  # drop the leading ```
    if text.endswith('```'):
        text = text[:-3]  # drop the trailing ```
    return text.strip()


def _is_complete_item(item):
    """Return True when *item* is a dict whose required fields are non-blank strings."""
    if not isinstance(item, dict):
        return False
    return all(
        isinstance(item.get(key), str) and bool(item[key].strip())
        for key in _REQUIRED_FIELDS
    )


def main():
    """Build a JSONL dataset by running each text chunk through the chat model.

    Chunks come from ``ReadFiles(file_path).get_content(...)``. For each chunk
    the model response is stripped of any Markdown code fence and parsed as
    JSON. Items whose ``instruction``, ``output`` and ``character`` fields are
    all non-blank strings are appended, one JSON object per line, to
    ``<file_name>.jsonl`` in the current working directory. Responses that are
    not valid JSON are reported and skipped. ``model.cleanup()`` is always
    called on exit, including when an exception propagates.
    """
    file_path = './data/test.txt'
    model_path = '/mnt/g/Project02/AITrain/Qwen/Qwen3-8B-AWQ'
    docs = ReadFiles(file_path).get_content(max_token_len=500, cover_content=0)

    sys_prompt = system_prompt(novel_schema)

    # API mode alternative: model = DeepseekChat()  (requires the URL and API key to be configured)
    # Local model mode: adjust model_path to point at the model directory.
    model = DeepseekChat(path=model_path, use_api=False)

    file_name = file_path.split('/')[-1].split('.')[0]
    output_path = f'{file_name}.jsonl'

    try:
        for doc in tqdm(docs):
            # Clean the response format: remove any Markdown code fence.
            raw = _strip_code_fence(model.chat(sys_prompt, doc))

            # Keep the raw text in its own name so the error report below can
            # always slice a string, never the parsed object.
            try:
                parsed = json.loads(raw)
            except json.JSONDecodeError as e:
                print(f"解析错误: {e}")
                print(f"原始响应: {repr(raw[:200])}")  # first 200 chars for debugging
                continue

            if isinstance(parsed, dict):
                # A single object is treated as a one-item list.
                parsed = [parsed]
            elif not isinstance(parsed, list):
                print(f"解析错误: 期望JSON数组或对象, 实际为 {type(parsed).__name__}")
                print(f"原始响应: {repr(raw[:200])}")
                continue

            # Data quality check: keep only items with all required fields.
            valid_items = []
            for item in parsed:
                if _is_complete_item(item):
                    valid_items.append(item)
                elif isinstance(item, dict):
                    print(f"跳过空字段数据: instruction='{item.get('instruction', '')}', output='{item.get('output', '')}', character='{item.get('character', '')}'")
                else:
                    print(f"跳过非对象数据: {item!r}")

            # Open the output once per response, and only when there is
            # something to write, so no empty file is created.
            if valid_items:
                with open(output_path, 'a', encoding='utf-8') as f:
                    for item in valid_items:
                        f.write(json.dumps(item, ensure_ascii=False) + '\n')
    finally:
        # Make sure the LLM instance is cleaned up when the program ends.
        print("Cleaning up model resources...")
        model.cleanup()


# Run the pipeline only when this file is executed as a script.
if __name__ == '__main__':
    main()