# Project02/AITrain/example.py
# Standard library
import json
import platform

# Third-party
from tqdm import tqdm

# Project-local
from extract import system_prompt
from schema import novel_schema
from LLM import DeepseekChat
from utils import ReadFiles

# Windows multiprocessing compatibility fix.
# NOTE(review): the multiprocessing docs say freeze_support() should be called
# immediately after the `if __name__ == '__main__':` guard of the main module;
# calling it at import time is a no-op outside frozen executables — confirm intent.
if platform.system() == "Windows":
    import multiprocessing
    multiprocessing.freeze_support()
def _strip_code_fences(text):
    """Return *text* with a surrounding markdown code fence removed.

    LLM responses often wrap their JSON payload in a fenced block
    (```json ... ```); strip the opening ```json / ``` marker and the
    closing ``` so that json.loads can parse the remainder.
    """
    text = text.strip()
    if text.startswith('```json'):
        text = text[7:]   # drop leading ```json
    if text.startswith('```'):
        text = text[3:]   # drop leading ``` (fence without a language tag)
    if text.endswith('```'):
        text = text[:-3]  # drop trailing ```
    return text.strip()


def main():
    """Convert a novel text file into instruction-tuning JSONL via an LLM.

    Reads ./data/test.txt in ~500-token chunks, asks the model to turn each
    chunk into structured JSON records, filters out records with an empty
    instruction/output/character field, and appends valid records to
    '<file_name>.jsonl'. Model resources are cleaned up on exit.
    """
    file_path = './data/test.txt'
    model_path = '/mnt/g/Project02/AITrain/Qwen/Qwen3-8B-AWQ'
    docs = ReadFiles(file_path).get_content(max_token_len=500, cover_content=0)
    sys_prompt = system_prompt(novel_schema)
    # model = DeepseekChat()  # using the API requires configuring a url and api key
    model = DeepseekChat(path=model_path, use_api=False)  # local model; adjust the path as needed
    file_name = file_path.split('/')[-1].split('.')[0]
    try:
        # Open the output file once instead of re-opening it for every record.
        with open(f'{file_name}.jsonl', 'a', encoding='utf-8') as f:
            for doc in tqdm(docs):
                # Clean the response format: remove any markdown code fence.
                response = _strip_code_fences(model.chat(sys_prompt, doc))
                try:
                    items = json.loads(response)
                    for item in items:
                        # Data-quality check: skip records where instruction,
                        # output or character is empty/blank.
                        if (item.get('instruction', '').strip() and
                                item.get('output', '').strip() and
                                item.get('character', '').strip()):
                            json.dump(item, f, ensure_ascii=False)
                            f.write('\n')
                        else:
                            print(f"跳过空字段数据: instruction='{item.get('instruction', '')}', output='{item.get('output', '')}', character='{item.get('character', '')}'")
                except Exception as e:
                    # json.loads raises JSONDecodeError on malformed payloads;
                    # a non-list payload raises on item.get — report either and
                    # keep processing the remaining chunks.
                    print(f"解析错误: {e}")
                    print(f"原始响应: {repr(response[:200])}")  # first 200 chars for debugging
    finally:
        # Ensure the LLM instance is released even on error or interrupt.
        print("Cleaning up model resources...")
        model.cleanup()
# Script entry point: run the extraction pipeline only when executed directly,
# not when this module is imported.
if __name__ == '__main__':
    main()