调整prompt
This commit is contained in:
parent
7d61d1ce85
commit
b7b1944729
23
AIGC/main.py
23
AIGC/main.py
@ -15,7 +15,7 @@ from Utils.AIGCLog import AIGCLog
|
|||||||
from ollama import Client, ResponseError
|
from ollama import Client, ResponseError
|
||||||
|
|
||||||
app = FastAPI(title = "AI 通信服务")
|
app = FastAPI(title = "AI 通信服务")
|
||||||
logger = AIGCLog(name = "AIGC", log_file = "./Log/aigc.log")
|
logger = AIGCLog(name = "AIGC", log_file = "aigc.log")
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--model', type=str, default='deepseek-r1:1.5b',
|
parser.add_argument('--model', type=str, default='deepseek-r1:1.5b',
|
||||||
@ -31,12 +31,13 @@ lastPrompt = ""
|
|||||||
ollamaClient = Client(host='http://localhost:11434')
|
ollamaClient = Client(host='http://localhost:11434')
|
||||||
|
|
||||||
async def heartbeat(websocket: WebSocket):
|
async def heartbeat(websocket: WebSocket):
|
||||||
while True:
|
pass
|
||||||
await asyncio.sleep(30) # 每30秒发送一次心跳
|
# while True:
|
||||||
try:
|
# await asyncio.sleep(30) # 每30秒发送一次心跳
|
||||||
await senddata(websocket, 0, [])
|
# try:
|
||||||
except ConnectionClosed:
|
# await senddata(websocket, 0, [])
|
||||||
break # 连接已关闭时退出循环
|
# except ConnectionClosed:
|
||||||
|
# break # 连接已关闭时退出循环
|
||||||
|
|
||||||
#statuscode -1 服务器运行错误 0 心跳标志 1 正常 2 输出异常
|
#statuscode -1 服务器运行错误 0 心跳标志 1 正常 2 输出异常
|
||||||
async def senddata(websocket: WebSocket, statusCode: int, messages: List[str]):
|
async def senddata(websocket: WebSocket, statusCode: int, messages: List[str]):
|
||||||
@ -63,6 +64,7 @@ async def websocket_endpoint(websocket: WebSocket, client_id: str):
|
|||||||
data = await websocket.receive_text()
|
data = await websocket.receive_text()
|
||||||
logger.log(logging.INFO, f"收到UE5消息 [{client_id}]: {data}")
|
logger.log(logging.INFO, f"收到UE5消息 [{client_id}]: {data}")
|
||||||
success, prompt = process_prompt(data)
|
success, prompt = process_prompt(data)
|
||||||
|
global lastPrompt
|
||||||
lastPrompt = prompt
|
lastPrompt = prompt
|
||||||
# 调用AI生成响应
|
# 调用AI生成响应
|
||||||
if(success):
|
if(success):
|
||||||
@ -175,7 +177,7 @@ async def generateAIChat(prompt: str, websocket: WebSocket):
|
|||||||
match = re.search(pattern, think_remove_text, re.DOTALL)
|
match = re.search(pattern, think_remove_text, re.DOTALL)
|
||||||
|
|
||||||
if not match:
|
if not match:
|
||||||
if await reGenerateAIChat(prompt, websocket):
|
if await reGenerateAIChat(lastPrompt, websocket):
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
logger.log(logging.ERROR, "请更换prompt,或者升级模型大小")
|
logger.log(logging.ERROR, "请更换prompt,或者升级模型大小")
|
||||||
@ -183,9 +185,9 @@ async def generateAIChat(prompt: str, websocket: WebSocket):
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
core_dialog = match.group(1).strip()
|
core_dialog = match.group(1).strip()
|
||||||
dialog_lines = core_dialog.split('\n')
|
dialog_lines = [line for line in core_dialog.split('\n') if line.strip()]
|
||||||
if len(dialog_lines) != 4:
|
if len(dialog_lines) != 4:
|
||||||
if await reGenerateAIChat(prompt, websocket):
|
if await reGenerateAIChat(lastPrompt, websocket):
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
logger.log(logging.ERROR, "请更换prompt,或者升级模型大小")
|
logger.log(logging.ERROR, "请更换prompt,或者升级模型大小")
|
||||||
@ -204,6 +206,7 @@ async def reGenerateAIChat(prompt: str, websocket: WebSocket):
|
|||||||
logger.log(logging.ERROR, f"AI输出格式不正确,重新进行生成 {regenerateCount}/{maxAIRegerateCount}")
|
logger.log(logging.ERROR, f"AI输出格式不正确,重新进行生成 {regenerateCount}/{maxAIRegerateCount}")
|
||||||
await senddata(websocket, 2, messages=["ai生成格式不正确, 重新进行生成"])
|
await senddata(websocket, 2, messages=["ai生成格式不正确, 重新进行生成"])
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
|
prompt = prompt + "补充:上一次的输出格式错误,严格执行prompt中第5条的输出格式要求"
|
||||||
await generateAIChat(prompt, websocket)
|
await generateAIChat(prompt, websocket)
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user