Project02/AIGC/main.py
2025-06-27 17:25:46 +08:00

253 lines
9.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import argparse
import asyncio
import json
import logging
import random
import re
import threading
import time
from typing import List, Tuple
from fastapi import FastAPI, Request, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.websockets import WebSocketState
from h11 import ConnectionClosed
import uvicorn
from Utils.AIGCLog import AIGCLog
from ollama import Client, ResponseError
app = FastAPI(title = "AI 通信服务")
logger = AIGCLog(name = "AIGC", log_file = "aigc.log")
parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='deepseek-r1:1.5b',
help='Ollama模型名称')
args = parser.parse_args()
logger.log(logging.INFO, f"使用的模型是 {args.model}")
maxAIRegerateCount = 5
lastPrompt = ""
#初始化ollama客户端
ollamaClient = Client(host='http://localhost:11434')
async def heartbeat(websocket: WebSocket):
pass
# while True:
# await asyncio.sleep(30) # 每30秒发送一次心跳
# try:
# await senddata(websocket, 0, [])
# except ConnectionClosed:
# break # 连接已关闭时退出循环
#statuscode -1 服务器运行错误 0 心跳标志 1 正常 2 输出异常
async def senddata(websocket: WebSocket, statusCode: int, messages: List[str]):
# 将AI响应发送回UE5
if websocket.client_state == WebSocketState.CONNECTED:
data = {
"statuscode": statusCode,
"messages": messages
}
json_string = json.dumps(data, ensure_ascii=False)
await websocket.send_text(json_string)
# WebSocket路由处理
@app.websocket("/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: str):
await websocket.accept()
logger.log(logging.INFO, f"UE5客户端 {client_id} 已连接")
# 添加心跳任务
asyncio.create_task(heartbeat(websocket))
try:
while True:
# 接收UE5发来的消息
data = await websocket.receive_text()
logger.log(logging.INFO, f"收到UE5消息 [{client_id}]: {data}")
success, prompt = process_prompt(data)
global lastPrompt
lastPrompt = prompt
# 调用AI生成响应
if(success):
asyncio.create_task(generateAIChat(prompt, websocket))
await senddata(websocket, 0, [])
else:
await senddata(websocket, -1, [])
except WebSocketDisconnect:
#manager.disconnect(client_id)
logger.log(logging.WARNING, f"UE5客户端主动断开 [{client_id}]")
except Exception as e:
#manager.disconnect(client_id)
logger.log(logging.ERROR, f"WebSocket异常 [{client_id}]: {str(e)}")
def process_prompt(promptFromUE: str) -> Tuple[bool, str]:
try:
data = json.loads(promptFromUE)
# 提取数据
dialog_scene = data["dialogContent"]["dialogScene"]
persons = data["persons"]
assert len(persons) == 2
for person in persons:
print(f" 姓名: {person['name']}, 职业: {person['job']}")
prompt = f"""
你是一个游戏NPC对话生成器。请严格按以下要求生成两个路人NPC{persons[0]["name"]}{persons[1]["name"]})的日常对话:
1. 生成【2轮完整对话】每轮包含双方各一次发言共4句
2. 对话场景:{dialog_scene}
3. 角色设定:
{persons[0]["name"]}{persons[0]["job"]}
{persons[1]["name"]}{persons[1]["job"]}
4. 对话要求:
* 每轮对话需自然衔接,体现生活细节
* 避免任务指引或玩家交互内容
* 结尾保持对话未完成感
5. 输出格式:
<format>
{persons[0]["name"]}[第一轮发言]
{persons[1]["name"]}[第一轮回应]
{persons[0]["name"]}[第二轮发言]
{persons[1]["name"]}[第二轮回应]
</format>
6.重要!若未按此格式输出,请重新生成直至完全符合
"""
return True, prompt
except json.JSONDecodeError as e:
print(f"JSON解析错误: {e}")
return False, ""
except KeyError as e:
print(f"缺少必要字段: {e}")
def run_webserver():
logger.log(logging.INFO, "启动web服务器 ")
#启动服务器
uvicorn.run(
app,
host="0.0.0.0",
port=8000,
timeout_keep_alive=3000,
log_level="info"
)
async def generateAIChat(prompt: str, websocket: WebSocket):
#动态标识吗 防止重复输入导致的结果重复
dynamic_token = str(int(time.time() % 1000))
prompt = f"""
[动态标识码:{dynamic_token}]
""" + prompt
logger.log(logging.INFO, "prompt:" + prompt)
starttime = time.time()
receivemessage=[
{"role": "system", "content": prompt}
]
try:
response = ollamaClient.chat(
model = args.model,
stream = False,
messages = receivemessage,
options={
"temperature": random.uniform(1.0, 1.5),
"repeat_penalty": 1.2, # 抑制重复
"top_p": random.uniform(0.7, 0.95),
"num_ctx": 4096, # 上下文长度
"seed": int(time.time() * 1000) % 1000000
}
)
except ResponseError as e:
if e.status_code == 503:
print("🔄 服务不可用5秒后重试...")
return await senddata(websocket, -1, messages=["ollama 服务不可用"])
except Exception as e:
print(f"🔥 未预料错误: {str(e)}")
return await senddata(websocket, -1, messages=["未预料错误"])
logger.log(logging.INFO, "接口调用耗时 :" + str(time.time() - starttime))
logger.log(logging.INFO, "AI生成" + response['message']['content'])
#处理ai输出内容
think_remove_text = re.sub(r'<think>.*?</think>', '', response['message']['content'], flags=re.DOTALL)
pattern = r".*<format>(.*?)</format>" # .* 吞掉前面所有字符,定位最后一组
match = re.search(pattern, think_remove_text, re.DOTALL)
if not match:
if await reGenerateAIChat(lastPrompt, websocket):
pass
else:
logger.log(logging.ERROR, "请更换prompt或者升级模型大小")
await senddata(websocket, -1, messages=["请更换prompt或者升级模型大小"])
else:
core_dialog = match.group(1).strip()
dialog_lines = [line for line in core_dialog.split('\n') if line.strip()]
if len(dialog_lines) != 4:
if await reGenerateAIChat(lastPrompt, websocket):
pass
else:
logger.log(logging.ERROR, "请更换prompt或者升级模型大小")
await senddata(websocket, -1, messages=["请更换prompt或者升级模型大小"])
else:
logger.log(logging.INFO, "AI的输出正确\n" + core_dialog)
global regenerateCount
regenerateCount = 0
await senddata(websocket, 1, dialog_lines)
regenerateCount = 1
async def reGenerateAIChat(prompt: str, websocket: WebSocket):
global regenerateCount
if regenerateCount < maxAIRegerateCount:
regenerateCount += 1
logger.log(logging.ERROR, f"AI输出格式不正确重新进行生成 {regenerateCount}/{maxAIRegerateCount}")
await senddata(websocket, 2, messages=["ai生成格式不正确 重新进行生成"])
await asyncio.sleep(0)
prompt = prompt + "补充上一次的输出格式错误严格执行prompt中第5条的输出格式要求"
await generateAIChat(prompt, websocket)
return True
else:
regenerateCount = 0
logger.log(logging.ERROR, "输出不正确 超过最大生成次数")
return False
if __name__ == "__main__":
# 创建并启动服务器线程
server_thread = threading.Thread(target=run_webserver)
server_thread.daemon = True # 设为守护线程(主程序退出时自动终止)
server_thread.start()
## Test
# generateAIChat(f"""
# 你是一个游戏NPC对话生成器。请严格按以下要求生成两个路人NPCA和B的日常对话
# 1. 生成【2轮完整对话】每轮包含双方各一次发言共4句
# 2. 对话场景:中世纪奇幻小镇的日常场景(如市场/酒馆/街道)
# 3. 角色设定:
# - NPC A随机职业铁匠/农夫/商人/卫兵等)
# - NPC B随机职业不同于A
# 4. 对话要求:
# * 每轮对话需自然衔接,体现生活细节
# * 避免任务指引或玩家交互内容
# * 结尾保持对话未完成感
# 5. 输出格式(严格遵循,
# ---
# A[第一轮发言]
# B[第一轮回应]
# A[第二轮发言]
# B[第二轮回应]
# ---
# """
# )
try:
# 主线程永久挂起(监听退出信号)
while True:
time.sleep(3600) # 每1小时唤醒一次避免CPU占用
except KeyboardInterrupt:
logger.log(logging.WARNING, "接收到中断信号,程序退出 ")