diff --git a/AIGC/main.py b/AIGC/main.py index b13fbe6..39dedab 100644 --- a/AIGC/main.py +++ b/AIGC/main.py @@ -2,6 +2,7 @@ import argparse import asyncio import json import logging +import random import re import threading import time @@ -23,6 +24,9 @@ parser.add_argument('--model', type=str, default='deepseek-r1:1.5b', args = parser.parse_args() logger.log(logging.INFO, f"使用的模型是 {args.model}") +maxAIRegerateCount = 5 +lastPrompt = "" + #初始化ollama客户端 ollamaClient = Client(host='http://localhost:11434') @@ -34,6 +38,7 @@ async def heartbeat(websocket: WebSocket): except ConnectionClosed: break # 连接已关闭时退出循环 +#statuscode -1 服务器运行错误 0 心跳标志 1 正常 2 输出异常 async def senddata(websocket: WebSocket, statusCode: int, messages: List[str]): # 将AI响应发送回UE5 if websocket.client_state == WebSocketState.CONNECTED: @@ -53,7 +58,7 @@ async def websocket_endpoint(websocket: WebSocket, client_id: str): asyncio.create_task(heartbeat(websocket)) try: while True: - logger.log(logging.INFO, "websocket_endpoint ") + # 接收UE5发来的消息 data = await websocket.receive_text() logger.log(logging.INFO, f"收到UE5消息 [{client_id}]: {data}") @@ -77,6 +82,7 @@ async def websocket_endpoint(websocket: WebSocket, client_id: str): def process_prompt(promptFromUE: str) -> Tuple[bool, str]: try: + lastPrompt = promptFromUE data = json.loads(promptFromUE) # 提取数据 @@ -86,8 +92,10 @@ def process_prompt(promptFromUE: str) -> Tuple[bool, str]: assert len(persons) == 2 for person in persons: print(f" 姓名: {person['name']}, 职业: {person['job']}") - + #动态标识吗 防止重复输入导致的结果重复 + dynamic_token = str(int(time.time() % 1000)) prompt = f""" + [动态标识码:{dynamic_token}] 你是一个游戏NPC对话生成器。请严格按以下要求生成两个路人NPC({persons[0]["name"]}和{persons[1]["name"]})的日常对话: 1. 生成【2轮完整对话】,每轮包含双方各一次发言(共4句) 2. 对话场景:{dialog_scene} @@ -99,14 +107,13 @@ def process_prompt(promptFromUE: str) -> Tuple[bool, str]: * 避免任务指引或玩家交互内容 * 结尾保持对话未完成感 5. 输出格式: - 必须确保输出内容的第一行是三个连字符`---`,最后一行也是三个连字符`---` - 中间内容严格按以下顺序排列,禁止添加任何额外说明或换行: - --- + {persons[0]["name"]}:[第一轮发言] {persons[1]["name"]}:[第一轮回应] {persons[0]["name"]}:[第二轮发言] {persons[1]["name"]}:[第二轮回应] - --- + + 6.重要!若未按此格式输出,请重新生成直至完全符合 """ return True, prompt @@ -144,34 +151,63 @@ async def generateAIChat(prompt: str, websocket: WebSocket): stream = False, messages = receivemessage, options={ - "temperature": 0.8, + "temperature": random.uniform(1.0, 1.5), "repeat_penalty": 1.2, # 抑制重复 - "top_p": 0.9, + "top_p": random.uniform(0.7, 0.95), "num_ctx": 4096, # 上下文长度 - "seed": 42 + "seed": int(time.time() * 1000) % 1000000 } ) except ResponseError as e: if e.status_code == 503: print("🔄 服务不可用,5秒后重试...") - return await senddata(websocket, -1, messages={"ollama 服务不可用"}) + return await senddata(websocket, -1, messages=["ollama 服务不可用"]) except Exception as e: print(f"🔥 未预料错误: {str(e)}") - return await senddata(websocket, -1, messages={"未预料错误"}) + return await senddata(websocket, -1, messages=["未预料错误"]) logger.log(logging.INFO, "接口调用耗时 :" + str(time.time() - starttime)) logger.log(logging.INFO, "AI生成" + response['message']['content']) #处理ai输出内容 think_remove_text = re.sub(r'.*?', '', response['message']['content'], flags=re.DOTALL) - pattern = r".*---(.*?)---" # .* 吞掉前面所有字符,定位最后一组 + pattern = r".*(.*?)" # .* 吞掉前面所有字符,定位最后一组 match = re.search(pattern, think_remove_text, re.DOTALL) if not match: - await senddata(websocket, -1, messages={"ai生成格式不正确"}) + if await reGenerateAIChat(prompt, websocket): + pass + else: + logger.log(logging.ERROR, "请更换prompt,或者升级模型大小") + await senddata(websocket, -1, messages=["请更换prompt,或者升级模型大小"]) + else: core_dialog = match.group(1).strip() - logger.log(logging.INFO, "AI内容处理:\n" + core_dialog) - await senddata(websocket, 1, core_dialog.split('\n')) - + dialog_lines = core_dialog.split('\n') + if len(dialog_lines) != 4: + if await reGenerateAIChat(prompt, websocket): + pass + else: + logger.log(logging.ERROR, "请更换prompt,或者升级模型大小") + await senddata(websocket, -1, messages=["请更换prompt,或者升级模型大小"]) + else: + logger.log(logging.INFO, "AI的输出正确:\n" + core_dialog) + global regenerateCount + regenerateCount = 0 + await senddata(websocket, 1, dialog_lines) + +regenerateCount = 1 +async def reGenerateAIChat(prompt: str, websocket: WebSocket): + global regenerateCount + if regenerateCount < maxAIRegerateCount: + regenerateCount += 1 + logger.log(logging.ERROR, f"AI输出格式不正确,重新进行生成 {regenerateCount}/{maxAIRegerateCount}") + await senddata(websocket, 2, messages=["ai生成格式不正确, 重新进行生成"]) + await asyncio.sleep(0) + await generateAIChat(prompt, websocket) + return True + else: + regenerateCount = 0 + logger.log(logging.ERROR, "输出不正确 超过最大生成次数") + return False if __name__ == "__main__": # 创建并启动服务器线程