Increase the randomness of the AI output, and add a mechanism to regenerate when the output content is incorrectly formatted
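In short: the fixed sampling options (temperature 0.8, top_p 0.9, seed 42) are replaced with values drawn per request, and generation is retried when the reply does not match the required dialog format. A minimal sketch of the per-request randomization is shown below; the helper name random_sampling_options is an illustrative assumption, while the fields and ranges mirror the ones introduced in the diff.

import random
import time

def random_sampling_options() -> dict:
    # Fresh sampling options for every request; ranges follow the diff below.
    return {
        "temperature": random.uniform(1.0, 1.5),    # previously fixed at 0.8
        "repeat_penalty": 1.2,                      # suppress repetition
        "top_p": random.uniform(0.7, 0.95),         # previously fixed at 0.9
        "num_ctx": 4096,                            # context window length
        "seed": int(time.time() * 1000) % 1000000,  # previously fixed at 42
    }

Passing options like these to each ollamaClient.chat call, as the diff does inline, makes repeated identical prompts much less likely to produce identical dialog.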
This commit is contained in:
parent 1cb89e167c
commit 609b4cd072
AIGC/main.py (68 changed lines)
@@ -2,6 +2,7 @@ import argparse
 import asyncio
 import json
 import logging
+import random
 import re
 import threading
 import time
@@ -23,6 +24,9 @@ parser.add_argument('--model', type=str, default='deepseek-r1:1.5b',
 args = parser.parse_args()
 logger.log(logging.INFO, f"使用的模型是 {args.model}")
 
+maxAIRegerateCount = 5
+lastPrompt = ""
+
 # Initialize the Ollama client
 ollamaClient = Client(host='http://localhost:11434')
 
@@ -34,6 +38,7 @@ async def heartbeat(websocket: WebSocket):
         except ConnectionClosed:
             break  # exit the loop once the connection has been closed
 
+# statusCode: -1 server error, 0 heartbeat, 1 normal, 2 output error
 async def senddata(websocket: WebSocket, statusCode: int, messages: List[str]):
     # Send the AI response back to UE5
     if websocket.client_state == WebSocketState.CONNECTED:
@@ -53,7 +58,7 @@ async def websocket_endpoint(websocket: WebSocket, client_id: str):
     asyncio.create_task(heartbeat(websocket))
     try:
         while True:
             logger.log(logging.INFO, "websocket_endpoint ")
 
             # Receive the message sent by UE5
             data = await websocket.receive_text()
             logger.log(logging.INFO, f"收到UE5消息 [{client_id}]: {data}")
@@ -77,6 +82,7 @@ async def websocket_endpoint(websocket: WebSocket, client_id: str):
 
 def process_prompt(promptFromUE: str) -> Tuple[bool, str]:
     try:
+        lastPrompt = promptFromUE
         data = json.loads(promptFromUE)
 
         # Extract the data
@@ -86,8 +92,10 @@ def process_prompt(promptFromUE: str) -> Tuple[bool, str]:
         assert len(persons) == 2
         for person in persons:
             print(f" 姓名: {person['name']}, 职业: {person['job']}")
-
+        # Dynamic token, so that repeated identical input does not produce identical results
+        dynamic_token = str(int(time.time() % 1000))
         prompt = f"""
+[动态标识码:{dynamic_token}]
 你是一个游戏NPC对话生成器。请严格按以下要求生成两个路人NPC({persons[0]["name"]}和{persons[1]["name"]})的日常对话:
 1. 生成【2轮完整对话】,每轮包含双方各一次发言(共4句)
 2. 对话场景:{dialog_scene}
@@ -99,14 +107,13 @@ def process_prompt(promptFromUE: str) -> Tuple[bool, str]:
 * 避免任务指引或玩家交互内容
 * 结尾保持对话未完成感
 5. 输出格式:
-必须确保输出内容的第一行是三个连字符`---`,最后一行也是三个连字符`---`
 中间内容严格按以下顺序排列,禁止添加任何额外说明或换行:
----
+<format>
 {persons[0]["name"]}:[第一轮发言]
 {persons[1]["name"]}:[第一轮回应]
 {persons[0]["name"]}:[第二轮发言]
 {persons[1]["name"]}:[第二轮回应]
----
-
+</format>
+6.重要!若未按此格式输出,请重新生成直至完全符合
 """
         return True, prompt
@@ -144,34 +151,63 @@ async def generateAIChat(prompt: str, websocket: WebSocket):
             stream = False,
             messages = receivemessage,
             options={
-                "temperature": 0.8,
+                "temperature": random.uniform(1.0, 1.5),
                 "repeat_penalty": 1.2,  # suppress repetition
-                "top_p": 0.9,
+                "top_p": random.uniform(0.7, 0.95),
                 "num_ctx": 4096,  # context window length
-                "seed": 42
+                "seed": int(time.time() * 1000) % 1000000
             }
         )
     except ResponseError as e:
         if e.status_code == 503:
             print("🔄 服务不可用,5秒后重试...")
-            return await senddata(websocket, -1, messages={"ollama 服务不可用"})
+            return await senddata(websocket, -1, messages=["ollama 服务不可用"])
     except Exception as e:
         print(f"🔥 未预料错误: {str(e)}")
-        return await senddata(websocket, -1, messages={"未预料错误"})
+        return await senddata(websocket, -1, messages=["未预料错误"])
     logger.log(logging.INFO, "接口调用耗时 :" + str(time.time() - starttime))
     logger.log(logging.INFO, "AI生成" + response['message']['content'])
     # Post-process the AI output
     think_remove_text = re.sub(r'<think>.*?</think>', '', response['message']['content'], flags=re.DOTALL)
-    pattern = r".*---(.*?)---"  # .* consumes everything before it, so we capture the last group
+    pattern = r".*<format>(.*?)</format>"  # .* consumes everything before it, so we capture the last group
     match = re.search(pattern, think_remove_text, re.DOTALL)
 
     if not match:
-        await senddata(websocket, -1, messages={"ai生成格式不正确"})
+        if await reGenerateAIChat(prompt, websocket):
+            pass
+        else:
+            logger.log(logging.ERROR, "请更换prompt,或者升级模型大小")
+            await senddata(websocket, -1, messages=["请更换prompt,或者升级模型大小"])
 
     else:
         core_dialog = match.group(1).strip()
         logger.log(logging.INFO, "AI内容处理:\n" + core_dialog)
-        await senddata(websocket, 1, core_dialog.split('\n'))
+        dialog_lines = core_dialog.split('\n')
+        if len(dialog_lines) != 4:
+            if await reGenerateAIChat(prompt, websocket):
+                pass
+            else:
+                logger.log(logging.ERROR, "请更换prompt,或者升级模型大小")
+                await senddata(websocket, -1, messages=["请更换prompt,或者升级模型大小"])
+        else:
+            logger.log(logging.INFO, "AI的输出正确:\n" + core_dialog)
+            global regenerateCount
+            regenerateCount = 0
+            await senddata(websocket, 1, dialog_lines)
+
+regenerateCount = 1
+async def reGenerateAIChat(prompt: str, websocket: WebSocket):
+    global regenerateCount
+    if regenerateCount < maxAIRegerateCount:
+        regenerateCount += 1
+        logger.log(logging.ERROR, f"AI输出格式不正确,重新进行生成 {regenerateCount}/{maxAIRegerateCount}")
+        await senddata(websocket, 2, messages=["ai生成格式不正确, 重新进行生成"])
+        await asyncio.sleep(0)
+        await generateAIChat(prompt, websocket)
+        return True
+    else:
+        regenerateCount = 0
+        logger.log(logging.ERROR, "输出不正确 超过最大生成次数")
+        return False
 
 if __name__ == "__main__":
     # Create and start the server thread
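Taken together, the new code validates the model reply by stripping <think> blocks, extracting the last <format>...</format> section, and checking that it contains exactly four lines; on failure it retries through reGenerateAIChat, which counts attempts in the module-level global regenerateCount (initialized to 1 and reset to 0 after a successful generation) up to maxAIRegerateCount. The fragment below is an editorial sketch of the same bounded-retry idea using a local loop instead of mutually recursive coroutines and a shared global; validate_dialog, generate_with_retry and the generate_once callable are illustrative assumptions, not part of the committed code.

import re
from typing import Awaitable, Callable, List, Optional

MAX_ATTEMPTS = 5  # mirrors maxAIRegerateCount in the commit

def validate_dialog(raw: str) -> Optional[List[str]]:
    # Strip <think> blocks, then pull the last <format>...</format> section,
    # mirroring the regular expressions introduced in this commit.
    text = re.sub(r'<think>.*?</think>', '', raw, flags=re.DOTALL)
    match = re.search(r".*<format>(.*?)</format>", text, re.DOTALL)
    if not match:
        return None
    lines = match.group(1).strip().split('\n')
    return lines if len(lines) == 4 else None

async def generate_with_retry(prompt: str,
                              generate_once: Callable[[str], Awaitable[str]]) -> Optional[List[str]]:
    # generate_once is assumed to call the model once and return the raw reply text.
    for _ in range(MAX_ATTEMPTS):
        lines = validate_dialog(await generate_once(prompt))
        if lines is not None:
            return lines
    return None  # caller decides how to report the failure (e.g. statusCode -1)

A local attempt counter keeps the retry state out of module-level globals and avoids the recursion between generateAIChat and reGenerateAIChat; this is only an alternative sketch, not the behavior of the commit.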