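"""FastAPI WebSocket service that bridges UE5 clients and a local Ollama model.

Summary added for orientation: the server accepts WebSocket connections from UE5,
dispatches JSON protocol commands (character queries, character creation, AI dialogue
generation) to DatabaseHandle and AICore, and sends the results back as JSON messages.
"""
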
import argparse
import asyncio
import json
import logging
import random
import re
import threading
import time
from typing import List, Tuple

from fastapi import FastAPI, Request, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.websockets import WebSocketState
from h11 import ConnectionClosed
import uvicorn

from AICore import AICore
from DatabaseHandle import DatabaseHandle
from Utils.AIGCLog import AIGCLog


app = FastAPI(title = "AI 通信服务")
logger = AIGCLog(name = "AIGC", log_file = "aigc.log")

parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='deepseek-r1:1.5b',
                    help='Ollama模型名称')

args = parser.parse_args()
logger.log(logging.INFO, f"使用的模型是 {args.model}")

maxAIRegerateCount = 5      # maximum number of regeneration attempts
regenerateCount = 1         # current regeneration attempt count
totalAIGenerateCount = 1    # total number of AI responses requested by the client
currentGenerateCount = 0    # number of responses generated so far
lastPrompt = ""
character_id1 = 0
character_id2 = 0
aicore = AICore(args.model)
database = DatabaseHandle()

async def heartbeat(websocket: WebSocket):
    pass
    # while True:
    #     await asyncio.sleep(30)  # send a heartbeat every 30 seconds
    #     try:
    #         await senddata(websocket, 0, [])
    #     except ConnectionClosed:
    #         break  # exit the loop once the connection is closed

# status code: -1 server runtime error, 0 heartbeat flag, 1 normal, 2 output error
async def senddata(websocket: WebSocket, protocol: dict):
    # send the AI response back to UE5
    if websocket.client_state == WebSocketState.CONNECTED:
        json_string = json.dumps(protocol, ensure_ascii=False)
        await websocket.send_text(json_string)

async def sendprotocol(websocket: WebSocket, cmd: str, status: int, message: str, data: str):
    # build a protocol message and send it back to UE5
    protocol = {}
    protocol["cmd"] = cmd
    protocol["status"] = status
    protocol["message"] = message
    protocol["data"] = data

    if websocket.client_state == WebSocketState.CONNECTED:
        json_string = json.dumps(protocol, ensure_ascii=False)
        await websocket.send_text(json_string)

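# Every message produced by sendprotocol is a flat JSON envelope; illustrative example
# (field names come from the code above, the values are made up):
# {
#     "cmd": "AiChatGenerate",
#     "status": 1,
#     "message": "AI生成成功",
#     "data": "张三:...|李明:...|张三:...|李明:..."
# }
# "data" is always a string: either "|"-joined dialogue lines or a JSON-encoded payload.
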
# WebSocket路由处理
|
||
@app.websocket("/ws/{client_id}")
|
||
async def websocket_endpoint(websocket: WebSocket, client_id: str):
|
||
await websocket.accept()
|
||
logger.log(logging.INFO, f"UE5客户端 {client_id} 已连接")
|
||
# 添加心跳任务
|
||
asyncio.create_task(heartbeat(websocket))
|
||
try:
|
||
while True:
|
||
|
||
# 接收UE5发来的消息
|
||
data = await websocket.receive_text()
|
||
logger.log(logging.INFO, f"收到UE5消息 [{client_id}]: {data}")
|
||
await process_protocol_json(data, websocket)
|
||
|
||
except WebSocketDisconnect:
|
||
#manager.disconnect(client_id)
|
||
logger.log(logging.WARNING, f"UE5客户端主动断开 [{client_id}]")
|
||
except Exception as e:
|
||
#manager.disconnect(client_id)
|
||
logger.log(logging.ERROR, f"WebSocket异常 [{client_id}]: {str(e)}")
|
||
|
||
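# A UE5 client is expected to connect to ws://<server-host>:8000/ws/<client_id>
# (the port and host come from uvicorn.run in run_webserver below) and exchange the
# JSON envelopes described above; <server-host> and <client_id> are placeholders.
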
async def handle_characterlist(client: WebSocket):
    ### fetch character info from the database ###
    characters = database.get_character_byname("")
    protocol = {}
    protocol["cmd"] = "RequestCharacterInfos"
    protocol["status"] = 1
    protocol["message"] = "success"
    characterforUE = {}
    characterforUE["characterInfos"] = characters
    protocol["data"] = json.dumps(characterforUE)
    await senddata(client, protocol)

async def handle_characternames(client: WebSocket):
    ### fetch character info from the database ###
    characters = database.get_character_byname("")
    protocol = {}
    protocol["cmd"] = "RequestCharacterNames"
    protocol["status"] = 1
    protocol["message"] = "success"
    characterforUE = {}
    characterforUE["characterInfos"] = characters
    protocol["data"] = json.dumps(characterforUE)
    await senddata(client, protocol)

async def generate_aichat(promptStr: str, client: WebSocket | None = None):
    # prepend a dynamic identification token to the prompt
    dynamic_token = str(int(time.time() % 1000))
    promptStr = f"""
    [动态标识码:{dynamic_token}]
    """ + promptStr
    logger.log(logging.INFO, "prompt:" + promptStr)
    starttime = time.time()
    success, response = aicore.generateAI(promptStr)
    if success:
        logger.log(logging.INFO, "接口调用耗时 :" + str(time.time() - starttime))
        logger.log(logging.INFO, "AI生成" + response)
        # post-process the AI output
        think_remove_text = re.sub(r'<think>.*?</think>', '', response, flags=re.DOTALL)
        pattern = r".*<format>(.*?)</format>"  # the leading .* consumes earlier text so the last <format> block is captured
        match = re.search(pattern, think_remove_text, re.DOTALL)
        if not match:
            # generated content has the wrong format
            if await reGenerateAIChat(lastPrompt, client):
                pass
            else:
                # exceeded the regeneration limit
                logger.log(logging.ERROR, "请更换prompt,或者升级模型大小")
                await sendprotocol(client, "AiChatGenerate", 0, "请更换prompt,或者升级模型大小", "")
        else:
            # generated content is in the expected format
            core_dialog = match.group(1).strip()
            dialog_lines = [line.strip() for line in core_dialog.split('\n') if line.strip()]
            if len(dialog_lines) != 4:
                # wrong number of dialogue lines, regenerate
                if await reGenerateAIChat(lastPrompt, client):
                    pass
                else:
                    logger.log(logging.ERROR, "请更换prompt,或者升级模型大小")
                    await sendprotocol(client, "AiChatGenerate", 0, "请更换prompt,或者升级模型大小", "")
            else:
                logger.log(logging.INFO, "AI的输出正确:\n" + core_dialog)
                global regenerateCount
                regenerateCount = 0
                # persist the dialogue to the database
                database.add_chat({"character_ids": f"{character_id1},{character_id2}", "chat": " ".join(dialog_lines)})

                await sendprotocol(client, "AiChatGenerate", 1, "AI生成成功", "|".join(dialog_lines))
    else:
        await sendprotocol(client, "AiChatGenerate", -1, "调用ollama服务失败", "")

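# A well-formed model response ends with a <format> block containing exactly four
# non-empty dialogue lines (two turns per character), matching the checks above.
# Illustrative example:
# <format>
# 张三:[第一轮发言]
# 李明:[第一轮回应]
# 张三:[第二轮发言]
# 李明:[第二轮回应]
# </format>
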
async def handle_aichat_generate(client: WebSocket, aichat_data: str):
    ### build the AI prompt ###
    success, prompt = process_prompt(aichat_data)
    global lastPrompt
    lastPrompt = prompt

    # call the AI to generate responses
    if success:
        #asyncio.create_task(generateAIChat(prompt, client))
        global currentGenerateCount
        while currentGenerateCount < totalAIGenerateCount:
            currentGenerateCount += 1
            await generate_aichat(prompt, client)

        currentGenerateCount = 0
        # all generations finished
        await sendprotocol(client, "AiChatGenerate", 2, "AI生成成功", "")
    else:
        # failed to build the prompt
        await sendprotocol(client, "AiChatGenerate", -1, "prompt convert failed", "")

async def handle_addcharacter(client: WebSocket, characterJson: str):
    ### add a character to the database ###
    character_info = json.loads(characterJson)
    id = database.add_character(character_info)
    logger.log(logging.INFO, f"添加角色到数据库 id = {id}")
    # id = database.add_character({"name":"张三","age":35,"personality":"成熟稳重/惜字如金","profession":"阿里巴巴算法工程师"
    #                             ,"characterBackground":"浙大计算机系毕业,专注AI优化项目","chat_style":"请在对话中表现出专业、冷静、惜字如金。用口语化的方式简短回答"})

async def process_protocol_json(json_str: str, client: WebSocket):
    ### dispatch an incoming protocol JSON message ###
    try:
        protocol = json.loads(json_str)
        cmd = protocol.get("cmd")
        data = protocol.get("data")
        if cmd == "RequestCharacterInfos":
            await handle_characterlist(client)
        elif cmd == "RequestCharacterNames":
            await handle_characternames(client)
        elif cmd == "AddCharacter":
            await handle_addcharacter(client, data)
        elif cmd == "AiChatGenerate":
            await handle_aichat_generate(client, data)

    except json.JSONDecodeError as e:
        print(f"JSON解析错误: {e}")

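# For "AiChatGenerate", the "data" field is expected to be a JSON string carrying the
# keys read below in process_prompt. Illustrative example (values are made up):
# {
#     "dialogScene": "2025年的都市背景",
#     "generateCount": 1,
#     "characterName": ["张三", "李明"]
# }
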
def process_prompt(promptFromUE: str) -> Tuple[bool, str]:
    try:
        data = json.loads(promptFromUE)
        global maxAIRegerateCount
        # extract the request fields
        dialog_scene = data["dialogScene"]
        global totalAIGenerateCount
        totalAIGenerateCount = data["generateCount"]
        persons = data["characterName"]

        assert len(persons) == 2
        characterInfo1 = database.get_character_byname(persons[0])
        characterInfo2 = database.get_character_byname(persons[1])
        global character_id1, character_id2
        character_id1 = characterInfo1[0]["id"]
        character_id2 = characterInfo2[0]["id"]
        chat_history = database.get_chats_by_character_id(str(character_id1) + "," + str(character_id2))
        # collect the previous dialogue records
        result = '\n'.join([item['chat'] for item in chat_history])

        prompt = f"""
        #你是一个游戏NPC对话生成器。请严格按以下要求生成两个角色的日常对话
        #对话的背景
        {dialog_scene}
        1. 生成【2轮完整对话】,每轮包含双方各一次发言(共4句)
        2.角色设定
        {characterInfo1[0]["name"]}: {{
            "姓名": {characterInfo1[0]["name"]},
            "年龄": {characterInfo1[0]["age"]},
            "性格": {characterInfo1[0]["personality"]},
            "职业": {characterInfo1[0]["profession"]},
            "背景": {characterInfo1[0]["characterBackground"]},
            "语言风格": {characterInfo1[0]["chat_style"]}
        }},
        {characterInfo2[0]["name"]}: {{
            "姓名": {characterInfo2[0]["name"]},
            "年龄": {characterInfo2[0]["age"]},
            "性格": {characterInfo2[0]["personality"]},
            "职业": {characterInfo2[0]["profession"]},
            "背景": {characterInfo2[0]["characterBackground"]},
            "语言风格": {characterInfo2[0]["chat_style"]}
        }}
        3.参考的历史对话内容
        {result}
        4.输出格式:
        <format>
        张三:[第一轮发言]
        李明:[第一轮回应]
        张三:[第二轮发言]
        李明:[第二轮回应]
        </format>
        5.重要!若未按此格式输出,请重新生成直至完全符合
        """
        return True, prompt

    except json.JSONDecodeError as e:
        print(f"JSON解析错误: {e}")
        return False, ""
    except Exception as e:
        print(f"发生错误:{type(e).__name__} - {e}")
        return False, ""

def run_webserver():
    logger.log(logging.INFO, "启动web服务器 ")
    # start the server
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
        timeout_keep_alive=3000,
        log_level="info"
    )

async def reGenerateAIChat(prompt: str, websocket: WebSocket):
    global regenerateCount
    if regenerateCount < maxAIRegerateCount:
        regenerateCount += 1
        logger.log(logging.ERROR, f"AI输出格式不正确,重新进行生成 {regenerateCount}/{maxAIRegerateCount}")
        await sendprotocol(websocket, "AiChatGenerate", 0, "ai生成格式不正确, 重新进行生成", "")
        await asyncio.sleep(0)
        prompt = prompt + "补充:上一次的输出格式错误,严格执行prompt中的输出格式要求"
        await generate_aichat(prompt, websocket)
        return True
    else:
        regenerateCount = 0
        logger.log(logging.ERROR, "输出不正确 超过最大生成次数")
        return False

if __name__ == "__main__":
    # create and start the server thread
    server_thread = threading.Thread(target=run_webserver)
    server_thread.daemon = True  # daemon thread: terminated automatically when the main program exits
    server_thread.start()

    #Test database
    # id = database.add_character({"name":"李明","age":30,"personality":"活泼健谈","profession":"产品经理"
    #                             ,"characterBackground":"公司资深产品经理","chat_style":"热情"})
    #characters = database.get_character_byname("")
    #chat_id = database.add_chat({"character_ids":"1,2","chat":"张三:[第一轮发言] 李明:[第一轮回应] 张三:[第二轮发言] 李明:[第二轮回应"})
    #chat = database.get_chats_by_character_id(3)
    #
    # Test AI
    #aicore.getPromptToken("测试功能")
    # asyncio.run(
    #     generateAIChat(promptStr = f"""
    # #你是一个游戏NPC对话生成器。请严格按以下要求生成两个角色的日常对话
    # #对话的世界观背景是2025年的都市背景
    # 1. 生成【2轮完整对话】,每轮包含双方各一次发言(共4句)
    # 2.角色设定
    #
    # "张三": {{
    #     "姓名": "张三",
    #     "年龄": 35,
    #     "性格": "成熟稳重/惜字如金",
    #     "职业": "阿里巴巴算法工程师",
    #     "背景": "浙大计算机系毕业,专注AI优化项目",
    #     "对话场景": "你正在和用户聊天,用户是你的同事",
    #     "语言风格": "请在对话中表现出专业、冷静、惜字如金。用口语化的方式简短回答"
    # }},
    # "李明": {{
    #     "姓名": "李明",
    #     "年龄": 30,
    #     "职业": "产品经理",
    #     "性格": "活泼健谈"
    #     "背景": "公司资深产品经理",
    #     "对话场景": "你正在和用户聊天,用户是你的同事",
    #     "语言风格": "热情"
    # }}
    #
    # 3.输出格式:
    # <format>
    # 张三:[第一轮发言]
    # 李明:[第一轮回应]
    # 张三:[第二轮发言]
    # 李明:[第二轮回应]
    # </format>
    # """
    #     )
    # )

    try:
        # keep the main thread alive, waiting for an exit signal
        while True:
            time.sleep(3600)  # wake up once an hour to avoid busy-waiting
    except KeyboardInterrupt:
        logger.log(logging.WARNING, "接收到中断信号,程序退出 ")