Project02/AIGC/main.py

363 lines
14 KiB
Python
Raw Normal View History

2025-06-06 10:28:34 +08:00
import argparse
2025-06-05 19:55:37 +08:00
import asyncio
import json
import logging
import random
2025-06-05 19:55:37 +08:00
import re
import threading
import time
from typing import List, Tuple
from fastapi import FastAPI, Request, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.websockets import WebSocketState
from h11 import ConnectionClosed
import uvicorn
from AICore import AICore
2025-07-03 19:01:27 +08:00
from DatabaseHandle import DatabaseHandle
2025-06-05 19:55:37 +08:00
from Utils.AIGCLog import AIGCLog
2025-06-05 19:55:37 +08:00
app = FastAPI(title = "AI 通信服务")
2025-06-27 17:25:46 +08:00
logger = AIGCLog(name = "AIGC", log_file = "aigc.log")
2025-06-06 10:28:34 +08:00
parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='deepseek-r1:1.5b',
help='Ollama模型名称')
args = parser.parse_args()
logger.log(logging.INFO, f"使用的模型是 {args.model}")
2025-07-09 19:58:44 +08:00
maxAIRegerateCount = 5 #最大重新生成次数
regenerateCount = 1 #当前重新生成次数
totalAIGenerateCount = 1 #客户端生成AI响应总数
currentGenerateCount = 0 #当前生成次数
lastPrompt = ""
2025-07-09 19:58:44 +08:00
character_id1 = 0
character_id2 = 0
aicore = AICore(args.model)
2025-07-09 19:58:44 +08:00
database = DatabaseHandle()
2025-06-05 19:55:37 +08:00
async def heartbeat(websocket: WebSocket):
2025-06-27 17:25:46 +08:00
pass
# while True:
# await asyncio.sleep(30) # 每30秒发送一次心跳
# try:
# await senddata(websocket, 0, [])
# except ConnectionClosed:
# break # 连接已关闭时退出循环
2025-06-05 19:55:37 +08:00
#statuscode -1 服务器运行错误 0 心跳标志 1 正常 2 输出异常
2025-07-07 09:33:56 +08:00
async def senddata(websocket: WebSocket, protocol: dict):
2025-06-05 19:55:37 +08:00
# 将AI响应发送回UE5
if websocket.client_state == WebSocketState.CONNECTED:
2025-07-07 09:33:56 +08:00
json_string = json.dumps(protocol, ensure_ascii=False)
2025-06-05 19:55:37 +08:00
await websocket.send_text(json_string)
2025-07-09 19:58:44 +08:00
async def sendprotocol(websocket: WebSocket, cmd: str, status: int, message: str, data: str):
# 将AI响应发送回UE5
protocol = {}
protocol["cmd"] = cmd
protocol["status"] = status
protocol["message"] = message
protocol["data"] = data
if websocket.client_state == WebSocketState.CONNECTED:
json_string = json.dumps(protocol, ensure_ascii=False)
await websocket.send_text(json_string)
2025-06-05 19:55:37 +08:00
# WebSocket路由处理
@app.websocket("/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: str):
await websocket.accept()
logger.log(logging.INFO, f"UE5客户端 {client_id} 已连接")
# 添加心跳任务
asyncio.create_task(heartbeat(websocket))
try:
while True:
2025-06-05 19:55:37 +08:00
# 接收UE5发来的消息
data = await websocket.receive_text()
logger.log(logging.INFO, f"收到UE5消息 [{client_id}]: {data}")
2025-07-07 09:33:56 +08:00
await process_protocol_json(data, websocket)
2025-07-09 19:58:44 +08:00
2025-06-05 19:55:37 +08:00
except WebSocketDisconnect:
#manager.disconnect(client_id)
logger.log(logging.WARNING, f"UE5客户端主动断开 [{client_id}]")
except Exception as e:
#manager.disconnect(client_id)
logger.log(logging.ERROR, f"WebSocket异常 [{client_id}]: {str(e)}")
2025-07-07 09:33:56 +08:00
async def handle_characterlist(client: WebSocket):
### 获得数据库中的角色信息###
characters = database.get_character_byname("")
protocol = {}
2025-07-10 09:30:45 +08:00
protocol["cmd"] = "RequestCharacterInfos"
protocol["status"] = 1
protocol["message"] = "success"
characterforUE = {}
characterforUE["characterInfos"] = characters
protocol["data"] = json.dumps(characterforUE)
await senddata(client, protocol)
async def handle_characternames(client: WebSocket):
### 获得数据库中的角色信息###
characters = database.get_character_byname("")
protocol = {}
protocol["cmd"] = "RequestCharacterNames"
2025-07-07 09:33:56 +08:00
protocol["status"] = 1
protocol["message"] = "success"
2025-07-07 17:18:41 +08:00
characterforUE = {}
characterforUE["characterInfos"] = characters
protocol["data"] = json.dumps(characterforUE)
2025-07-07 09:33:56 +08:00
await senddata(client, protocol)
2025-07-09 19:58:44 +08:00
async def generate_aichat(promptStr: str, client: WebSocket| None = None):
dynamic_token = str(int(time.time() % 1000))
promptStr = f"""
[动态标识码:{dynamic_token}]
""" + promptStr
logger.log(logging.INFO, "prompt:" + promptStr)
starttime = time.time()
success, response, = aicore.generateAI(promptStr)
if(success):
logger.log(logging.INFO, "接口调用耗时 :" + str(time.time() - starttime))
logger.log(logging.INFO, "AI生成" + response)
#处理ai输出内容
think_remove_text = re.sub(r'<think>.*?</think>', '', response, flags=re.DOTALL)
pattern = r".*<format>(.*?)</format>" # .* 吞掉前面所有字符,定位最后一组
match = re.search(pattern, think_remove_text, re.DOTALL)
if not match:
#生成内容格式错误
if await reGenerateAIChat(lastPrompt, client):
pass
else:
#超过重新生成次数
logger.log(logging.ERROR, "请更换prompt,或者升级模型大小")
await sendprotocol(client, "AiChatGenerate", 0, "请更换prompt,或者升级模型大小", "")
else:
#生成内容格式正确
core_dialog = match.group(1).strip()
dialog_lines = [line.strip() for line in core_dialog.split('\n') if line.strip()]
if len(dialog_lines) != 4:
#生成内容格式错误
if await reGenerateAIChat(lastPrompt, client):
pass
else:
logger.log(logging.ERROR, "请更换prompt或者升级模型大小")
await sendprotocol(client, "AiChatGenerate", 0, "请更换prompt或者升级模型大小", "")
else:
logger.log(logging.INFO, "AI的输出正确\n" + core_dialog)
global regenerateCount
regenerateCount = 0
#保存数据到数据库
database.add_chat({"character_ids":f"{character_id1},{character_id2}","chat":f"{" ".join(dialog_lines)}"})
await sendprotocol(client, "AiChatGenerate", 1, "AI生成成功", "|".join(dialog_lines))
else:
await sendprotocol(client, "AiChatGenerate", -1, "调用ollama服务失败", "")
async def handle_aichat_generate(client: WebSocket, aichat_data:str):
### 处理ai prompt###
success, prompt = process_prompt(aichat_data)
global lastPrompt
lastPrompt = prompt
# 调用AI生成响应
if(success):
#asyncio.create_task(generateAIChat(prompt, client))
global currentGenerateCount
while currentGenerateCount < totalAIGenerateCount:
currentGenerateCount += 1
await generate_aichat(prompt, client)
currentGenerateCount = 0
#全部生成完成
await sendprotocol(client, "AiChatGenerate", 2, "AI生成成功", "")
else:
#prompt生成失败
await sendprotocol(client, "AiChatGenerate", -1, "prompt convert failed", "")
2025-07-07 18:21:29 +08:00
async def handle_addcharacter(client: WebSocket, chracterJson: str):
### 添加角色到数据库 ###
character_info = json.loads(chracterJson)
id = database.add_character(character_info)
logger.log(logging.INFO, f"添加角色到数据库 id = {id}")
# id = database.add_character({"name":"张三","age":35,"personality":"成熟稳重/惜字如金","profession":"阿里巴巴算法工程师"
# ,"characterBackground":"浙大计算机系毕业专注AI优化项目","chat_style":"请在对话中表现出专业、冷静、惜字如金。用口语化的方式简短回答"})
2025-07-07 09:33:56 +08:00
async def process_protocol_json(json_str: str, client: WebSocket):
### 处理协议JSON ###
try:
protocol = json.loads(json_str)
cmd = protocol.get("cmd")
data = protocol.get("data")
2025-07-10 09:30:45 +08:00
if cmd == "RequestCharacterInfos":
2025-07-07 09:33:56 +08:00
await handle_characterlist(client)
2025-07-10 09:30:45 +08:00
elif cmd == "RequestCharacterNames":
await handle_characternames(client)
2025-07-07 18:21:29 +08:00
elif cmd == "AddCharacter":
await handle_addcharacter(client, data)
2025-07-09 19:58:44 +08:00
elif cmd == "AiChatGenerate":
await handle_aichat_generate(client, data)
2025-07-07 09:33:56 +08:00
except json.JSONDecodeError as e:
print(f"JSON解析错误: {e}")
2025-06-05 19:55:37 +08:00
def process_prompt(promptFromUE: str) -> Tuple[bool, str]:
try:
data = json.loads(promptFromUE)
2025-07-09 19:58:44 +08:00
global maxAIRegerateCount
2025-06-05 19:55:37 +08:00
# 提取数据
2025-07-09 19:58:44 +08:00
dialog_scene = data["dialogScene"]
global totalAIGenerateCount
totalAIGenerateCount = data["generateCount"]
persons = data["characterName"]
2025-06-05 19:55:37 +08:00
assert len(persons) == 2
2025-07-09 19:58:44 +08:00
characterInfo1 = database.get_character_byname(persons[0])
characterInfo2 = database.get_character_byname(persons[1])
global character_id1, character_id2
character_id1 = characterInfo1[0]["id"]
character_id2 = characterInfo2[0]["id"]
chat_history = database.get_chats_by_character_id(str(character_id1) + "," + str(character_id2))
#整理对话记录
result = result = '\n'.join([item['chat'] for item in chat_history])
2025-06-05 19:55:37 +08:00
prompt = f"""
2025-07-09 19:58:44 +08:00
#你是一个游戏NPC对话生成器。请严格按以下要求生成两个角色的日常对话
#对话的背景
{dialog_scene}
1. 生成2轮完整对话每轮包含双方各一次发言共4句
2.角色设定
{characterInfo1[0]["name"]}: {{
"姓名": {characterInfo1[0]["name"]},
"年龄": {characterInfo1[0]["age"]},
"性格": {characterInfo1[0]["personality"]},
"职业": {characterInfo1[0]["profession"]},
"背景": {characterInfo1[0]["characterBackground"]},
"语言风格": {characterInfo1[0]["chat_style"]}
}},
{characterInfo2[0]["name"]}: {{
"姓名": {characterInfo2[0]["name"]},
"年龄": {characterInfo2[0]["age"]},
"性格": {characterInfo2[0]["personality"]},
"职业": {characterInfo2[0]["profession"]},
"背景": {characterInfo2[0]["characterBackground"]},
"语言风格": {characterInfo2[0]["chat_style"]}
}}
3.参考的历史对话内容
{result}
4.输出格式
<format>
张三[第一轮发言]
李明[第一轮回应]
张三[第二轮发言]
李明[第二轮回应]
</format>
5.重要若未按此格式输出请重新生成直至完全符合
2025-06-05 19:55:37 +08:00
"""
return True, prompt
except json.JSONDecodeError as e:
print(f"JSON解析错误: {e}")
return False, ""
2025-07-09 19:58:44 +08:00
except Exception as e:
print(f"发生错误:{type(e).__name__} - {e}")
return False, ""
2025-06-05 19:55:37 +08:00
def run_webserver():
logger.log(logging.INFO, "启动web服务器 ")
#启动服务器
uvicorn.run(
app,
host="0.0.0.0",
port=8000,
timeout_keep_alive=3000,
log_level="info"
)
async def reGenerateAIChat(prompt: str, websocket: WebSocket):
global regenerateCount
if regenerateCount < maxAIRegerateCount:
regenerateCount += 1
logger.log(logging.ERROR, f"AI输出格式不正确重新进行生成 {regenerateCount}/{maxAIRegerateCount}")
2025-07-09 19:58:44 +08:00
await sendprotocol(websocket, "AiChatGenerate", 0, "ai生成格式不正确 重新进行生成", "")
await asyncio.sleep(0)
2025-07-09 19:58:44 +08:00
prompt = prompt + "补充上一次的输出格式错误严格执行prompt中的输出格式要求"
await generate_aichat(prompt, websocket)
return True
else:
regenerateCount = 0
logger.log(logging.ERROR, "输出不正确 超过最大生成次数")
return False
2025-06-05 19:55:37 +08:00
if __name__ == "__main__":
# 创建并启动服务器线程
server_thread = threading.Thread(target=run_webserver)
server_thread.daemon = True # 设为守护线程(主程序退出时自动终止)
server_thread.start()
#Test database
2025-07-09 19:58:44 +08:00
2025-07-07 18:21:29 +08:00
2025-07-03 19:01:27 +08:00
2025-07-09 19:58:44 +08:00
# id = database.add_character({"name":"李明","age":30,"personality":"活泼健谈","profession":"产品经理"
# ,"characterBackground":"公司资深产品经理","chat_style":"热情"})
2025-07-09 19:58:44 +08:00
#characters = database.get_character_byname("")
#chat_id = database.add_chat({"character_ids":"1,2","chat":"张三:[第一轮发言] 李明:[第一轮回应] 张三:[第二轮发言] 李明:[第二轮回应"})
2025-07-09 19:58:44 +08:00
#chat = database.get_chats_by_character_id(3)
#
# Test AI
2025-07-09 19:58:44 +08:00
#aicore.getPromptToken("测试功能")
2025-07-07 09:33:56 +08:00
# asyncio.run(
# generateAIChat(promptStr = f"""
# #你是一个游戏NPC对话生成器。请严格按以下要求生成两个角色的日常对话
# #对话的世界观背景是2025年的都市背景
# 1. 生成【2轮完整对话】每轮包含双方各一次发言共4句
# 2.角色设定
2025-07-02 11:14:50 +08:00
2025-07-07 09:33:56 +08:00
# "张三": {{
# "姓名": "张三",
# "年龄": 35,
# "性格": "成熟稳重/惜字如金",
# "职业": "阿里巴巴算法工程师",
# "背景": "浙大计算机系毕业专注AI优化项目",
# "对话场景": "你正在和用户聊天,用户是你的同事",
# "语言风格": "请在对话中表现出专业、冷静、惜字如金。用口语化的方式简短回答"
# }},
# "李明": {{
# "姓名": "李明",
# "年龄": 30,
# "职业": "产品经理",
# "性格": "活泼健谈"
# "背景": "公司资深产品经理",
# "对话场景": "你正在和用户聊天,用户是你的同事",
# "语言风格": "热情"
# }}
2025-07-02 11:14:50 +08:00
2025-07-07 09:33:56 +08:00
# 3.输出格式:
# <format>
# 张三:[第一轮发言]
# 李明:[第一轮回应]
# 张三:[第二轮发言]
# 李明:[第二轮回应]
# </format>
# """
# )
# )
2025-07-01 16:16:32 +08:00
2025-06-05 19:55:37 +08:00
try:
# 主线程永久挂起(监听退出信号)
while True:
time.sleep(3600) # 每1小时唤醒一次避免CPU占用
except KeyboardInterrupt:
logger.log(logging.WARNING, "接收到中断信号,程序退出 ")