diff --git a/AIGC/AICore.py b/AIGC/AICore.py
index b01ae16..521bd18 100644
--- a/AIGC/AICore.py
+++ b/AIGC/AICore.py
@@ -1,21 +1,47 @@
+import logging
+from typing import Tuple
import requests
from ollama import Client, ResponseError
import tiktoken
+import random
+from Utils.AIGCLog import AIGCLog
class AICore:
modelMaxTokens = 128000
# 初始化 DeepSeek 使用的 Tokenizer (cl100k_base)
encoder = tiktoken.get_encoding("cl100k_base")
+ logger = AIGCLog(name = "AIGC", log_file = "aigc.log")
-
def __init__(self, model):
#初始化ollama客户端
self.ollamaClient = Client(host='http://localhost:11434', headers={'x-some-header': 'some-value'})
+ self.modelName = model
response = self.ollamaClient.show(model)
modelMaxTokens = response.modelinfo['qwen2.context_length']
def getPromptToken(self, prompt)-> int:
tokens = self.encoder.encode(prompt)
return len(tokens)
-
\ No newline at end of file
+
+ def generateAI(self, promptStr: str) -> Tuple[bool, str]:
+ try:
+ response = self.ollamaClient.generate(
+ model = self.modelName,
+ stream = False,
+ prompt = promptStr,
+ options={
+ "temperature": random.uniform(1.0, 1.5),
+ "repeat_penalty": 1.2, # 抑制重复
+ "top_p": random.uniform(0.7, 0.95),
+ "num_ctx": 4096, # 上下文长度
+ }
+ )
+ return True, response.response
+ except ResponseError as e:
+ if e.status_code == 503:
+ print("🔄 服务不可用,5秒后重试...")
+                return False, "ollama 服务不可用"
+ except Exception as e:
+ print(f"🔥 未预料错误: {str(e)}")
+ return False, "未预料错误"
\ No newline at end of file
diff --git a/AIGC/DatabaseHandle.py b/AIGC/DatabaseHandle.py
index 02f94e7..43ba7b7 100644
--- a/AIGC/DatabaseHandle.py
+++ b/AIGC/DatabaseHandle.py
@@ -103,17 +103,18 @@ class DatabaseHandle:
conn.commit()
return cursor.lastrowid
- def get_chats_by_character_id(self, character_id: int) -> list:
+ def get_chats_by_character_id(self, character_id: str) -> list:
"""
根据角色ID查询聊天记录(target_id为空时返回全部数据)
:param target_id: 目标角色ID(None时返回全部记录)
:return: 聊天记录字典列表
"""
+
+ sorted_ids = sorted(character_id.split(","), key=int) # 按数值升序
+ normalized_param = ",".join(sorted_ids)
with self._get_connection() as conn:
cursor = conn.cursor()
- sql = "SELECT * FROM chat_records WHERE ',' || character_ids || ',' LIKE '%,' || ? || ',%'"
- params = (str(character_id))
- cursor.execute(sql, params)
- # 转换结果为字典列表
+ sql = "SELECT * FROM chat_records WHERE character_ids = ?"
+ cursor.execute(sql, (normalized_param,))
columns = [col[0] for col in cursor.description]
return [dict(zip(columns, row)) for row in cursor.fetchall()]
\ No newline at end of file
diff --git a/AIGC/StartCommand.bat b/AIGC/StartCommand.bat
index a467633..229d749 100644
--- a/AIGC/StartCommand.bat
+++ b/AIGC/StartCommand.bat
@@ -2,7 +2,7 @@
chcp 65001 > nul
set OLLAMA_MODEL=deepseek-r1:7b
rem 启动Ollama服务
-start "Ollama DeepSeek" cmd /k ollama run %OLLAMA_MODEL%
+start "Ollama DeepSeek" cmd /k ollama serve
rem 检测11434端口是否就绪
echo 等待Ollama服务启动...
diff --git a/AIGC/data.db b/AIGC/data.db
index 83fb5d4..01eaf47 100644
Binary files a/AIGC/data.db and b/AIGC/data.db differ
diff --git a/AIGC/main.py b/AIGC/main.py
index 6c81eae..b3fb640 100644
--- a/AIGC/main.py
+++ b/AIGC/main.py
@@ -26,11 +26,15 @@ parser.add_argument('--model', type=str, default='deepseek-r1:1.5b',
args = parser.parse_args()
logger.log(logging.INFO, f"使用的模型是 {args.model}")
-maxAIRegerateCount = 5
+maxAIRegerateCount = 5 #最大重新生成次数
+regenerateCount = 1 #当前重新生成次数
+totalAIGenerateCount = 1 #客户端生成AI响应总数
+currentGenerateCount = 0 #当前生成次数
lastPrompt = ""
-
+character_id1 = 0
+character_id2 = 0
aicore = AICore(args.model)
-
+database = DatabaseHandle()
async def heartbeat(websocket: WebSocket):
pass
@@ -48,6 +52,18 @@ async def senddata(websocket: WebSocket, protocol: dict):
json_string = json.dumps(protocol, ensure_ascii=False)
await websocket.send_text(json_string)
+async def sendprotocol(websocket: WebSocket, cmd: str, status: int, message: str, data: str):
+ # 将AI响应发送回UE5
+ protocol = {}
+ protocol["cmd"] = cmd
+ protocol["status"] = status
+ protocol["message"] = message
+ protocol["data"] = data
+
+ if websocket.client_state == WebSocketState.CONNECTED:
+ json_string = json.dumps(protocol, ensure_ascii=False)
+ await websocket.send_text(json_string)
+
# WebSocket路由处理
@app.websocket("/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: str):
@@ -62,20 +78,7 @@ async def websocket_endpoint(websocket: WebSocket, client_id: str):
data = await websocket.receive_text()
logger.log(logging.INFO, f"收到UE5消息 [{client_id}]: {data}")
await process_protocol_json(data, websocket)
-
- # success, prompt = process_prompt(data)
- # global lastPrompt
- # lastPrompt = prompt
- # # 调用AI生成响应
- # if(success):
- # asyncio.create_task(generateAIChat(prompt, websocket))
- # await senddata(websocket, 0, [])
- # else:
- # await senddata(websocket, -1, [])
-
-
-
-
+
except WebSocketDisconnect:
#manager.disconnect(client_id)
logger.log(logging.WARNING, f"UE5客户端主动断开 [{client_id}]")
@@ -95,6 +98,72 @@ async def handle_characterlist(client: WebSocket):
protocol["data"] = json.dumps(characterforUE)
await senddata(client, protocol)
+async def generate_aichat(promptStr: str, client: WebSocket| None = None):
+ dynamic_token = str(int(time.time() % 1000))
+ promptStr = f"""
+ [动态标识码:{dynamic_token}]
+ """ + promptStr
+ logger.log(logging.INFO, "prompt:" + promptStr)
+ starttime = time.time()
+ success, response, = aicore.generateAI(promptStr)
+ if(success):
+ logger.log(logging.INFO, "接口调用耗时 :" + str(time.time() - starttime))
+ logger.log(logging.INFO, "AI生成" + response)
+ #处理ai输出内容
+        think_remove_text = re.sub(r'<think>.*?</think>', '', response, flags=re.DOTALL)
+        pattern = r".*<对话>(.*?)</对话>" # .* 吞掉前面所有字符,定位最后一组
+ match = re.search(pattern, think_remove_text, re.DOTALL)
+ if not match:
+ #生成内容格式错误
+ if await reGenerateAIChat(lastPrompt, client):
+ pass
+ else:
+ #超过重新生成次数
+ logger.log(logging.ERROR, "请更换prompt,或者升级模型大小")
+ await sendprotocol(client, "AiChatGenerate", 0, "请更换prompt,或者升级模型大小", "")
+ else:
+ #生成内容格式正确
+ core_dialog = match.group(1).strip()
+ dialog_lines = [line.strip() for line in core_dialog.split('\n') if line.strip()]
+ if len(dialog_lines) != 4:
+ #生成内容格式错误
+ if await reGenerateAIChat(lastPrompt, client):
+ pass
+ else:
+ logger.log(logging.ERROR, "请更换prompt,或者升级模型大小")
+ await sendprotocol(client, "AiChatGenerate", 0, "请更换prompt,或者升级模型大小", "")
+ else:
+ logger.log(logging.INFO, "AI的输出正确:\n" + core_dialog)
+ global regenerateCount
+ regenerateCount = 0
+ #保存数据到数据库
+                database.add_chat({"character_ids":f"{character_id1},{character_id2}","chat":" ".join(dialog_lines)})
+
+ await sendprotocol(client, "AiChatGenerate", 1, "AI生成成功", "|".join(dialog_lines))
+ else:
+ await sendprotocol(client, "AiChatGenerate", -1, "调用ollama服务失败", "")
+
+async def handle_aichat_generate(client: WebSocket, aichat_data:str):
+ ### 处理ai prompt###
+ success, prompt = process_prompt(aichat_data)
+ global lastPrompt
+ lastPrompt = prompt
+
+ # 调用AI生成响应
+ if(success):
+ #asyncio.create_task(generateAIChat(prompt, client))
+ global currentGenerateCount
+ while currentGenerateCount < totalAIGenerateCount:
+ currentGenerateCount += 1
+ await generate_aichat(prompt, client)
+
+ currentGenerateCount = 0
+ #全部生成完成
+ await sendprotocol(client, "AiChatGenerate", 2, "AI生成成功", "")
+ else:
+ #prompt生成失败
+ await sendprotocol(client, "AiChatGenerate", -1, "prompt convert failed", "")
+
async def handle_addcharacter(client: WebSocket, chracterJson: str):
### 添加角色到数据库 ###
character_info = json.loads(chracterJson)
@@ -113,6 +182,8 @@ async def process_protocol_json(json_str: str, client: WebSocket):
await handle_characterlist(client)
elif cmd == "AddCharacter":
await handle_addcharacter(client, data)
+ elif cmd == "AiChatGenerate":
+ await handle_aichat_generate(client, data)
except json.JSONDecodeError as e:
print(f"JSON解析错误: {e}")
@@ -121,42 +192,65 @@ async def process_protocol_json(json_str: str, client: WebSocket):
def process_prompt(promptFromUE: str) -> Tuple[bool, str]:
try:
data = json.loads(promptFromUE)
+ global maxAIRegerateCount
# 提取数据
- dialog_scene = data["dialogContent"]["dialogScene"]
- persons = data["persons"]
+ dialog_scene = data["dialogScene"]
+ global totalAIGenerateCount
+ totalAIGenerateCount = data["generateCount"]
+ persons = data["characterName"]
assert len(persons) == 2
- for person in persons:
- print(f" 姓名: {person['name']}, 职业: {person['job']}")
-
+ characterInfo1 = database.get_character_byname(persons[0])
+ characterInfo2 = database.get_character_byname(persons[1])
+ global character_id1, character_id2
+ character_id1 = characterInfo1[0]["id"]
+ character_id2 = characterInfo2[0]["id"]
+ chat_history = database.get_chats_by_character_id(str(character_id1) + "," + str(character_id2))
+ #整理对话记录
+        result = '\n'.join([item['chat'] for item in chat_history])
+
+
prompt = f"""
- 你是一个游戏NPC对话生成器。请严格按以下要求生成两个路人NPC({persons[0]["name"]}和{persons[1]["name"]})的日常对话:
- 1. 生成【2轮完整对话】,每轮包含双方各一次发言(共4句)
- 2. 对话场景:{dialog_scene}
- 3. 角色设定:
- {persons[0]["name"]}:{persons[0]["job"]}
- {persons[1]["name"]}:{persons[1]["job"]}
- 4. 对话要求:
- * 每轮对话需自然衔接,体现生活细节
- * 避免任务指引或玩家交互内容
- * 结尾保持对话未完成感
- 5. 输出格式:
-
- {persons[0]["name"]}:[第一轮发言]
- {persons[1]["name"]}:[第一轮回应]
- {persons[0]["name"]}:[第二轮发言]
- {persons[1]["name"]}:[第二轮回应]
-
-
- 6.重要!若未按此格式输出,请重新生成直至完全符合
+ #你是一个游戏NPC对话生成器。请严格按以下要求生成两个角色的日常对话
+ #对话的背景
+ {dialog_scene}
+ 1. 生成【2轮完整对话】,每轮包含双方各一次发言(共4句)
+ 2.角色设定
+ {characterInfo1[0]["name"]}: {{
+ "姓名": {characterInfo1[0]["name"]},
+ "年龄": {characterInfo1[0]["age"]},
+ "性格": {characterInfo1[0]["personality"]},
+ "职业": {characterInfo1[0]["profession"]},
+ "背景": {characterInfo1[0]["characterBackground"]},
+ "语言风格": {characterInfo1[0]["chat_style"]}
+ }},
+ {characterInfo2[0]["name"]}: {{
+ "姓名": {characterInfo2[0]["name"]},
+ "年龄": {characterInfo2[0]["age"]},
+ "性格": {characterInfo2[0]["personality"]},
+ "职业": {characterInfo2[0]["profession"]},
+ "背景": {characterInfo2[0]["characterBackground"]},
+ "语言风格": {characterInfo2[0]["chat_style"]}
+ }}
+ 3.参考的历史对话内容
+ {result}
+ 4.输出格式:
+        <对话>
+ 张三:[第一轮发言]
+ 李明:[第一轮回应]
+ 张三:[第二轮发言]
+ 李明:[第二轮回应]
+        </对话>
+ 5.重要!若未按此格式输出,请重新生成直至完全符合
"""
return True, prompt
except json.JSONDecodeError as e:
print(f"JSON解析错误: {e}")
return False, ""
- except KeyError as e:
- print(f"缺少必要字段: {e}")
+ except Exception as e:
+ print(f"发生错误:{type(e).__name__} - {e}")
+ return False, ""
@@ -173,91 +267,15 @@ def run_webserver():
log_level="info"
)
-async def generateAIChat(promptStr: str, websocket: WebSocket| None = None):
- #动态标识吗 防止重复输入导致的结果重复
- dynamic_token = str(int(time.time() % 1000))
- promptStr = f"""
- [动态标识码:{dynamic_token}]
- """ + promptStr
- logger.log(logging.INFO, "prompt:" + promptStr)
- starttime = time.time()
- receivemessage=[
- {"role": "system", "content": promptStr}
- ]
- try:
-
- # response = ollamaClient.chat(
- # model = args.model,
- # stream = False,
- # messages = receivemessage,
- # options={
- # "temperature": random.uniform(1.0, 1.5),
- # "repeat_penalty": 1.2, # 抑制重复
- # "top_p": random.uniform(0.7, 0.95),
- # "num_ctx": 4096, # 上下文长度
- # "seed": int(time.time() * 1000) % 1000000
- # }
- # )
- response = ollamaClient.generate(
- model = args.model,
- stream = False,
- prompt = promptStr,
- options={
- "temperature": random.uniform(1.0, 1.5),
- "repeat_penalty": 1.2, # 抑制重复
- "top_p": random.uniform(0.7, 0.95),
- "num_ctx": 4096, # 上下文长度
- }
- )
-
- except ResponseError as e:
- if e.status_code == 503:
- print("🔄 服务不可用,5秒后重试...")
- return await senddata(websocket, -1, messages=["ollama 服务不可用"])
- except Exception as e:
- print(f"🔥 未预料错误: {str(e)}")
- return await senddata(websocket, -1, messages=["未预料错误"])
- logger.log(logging.INFO, "接口调用耗时 :" + str(time.time() - starttime))
- #aiResponse = response['message']['content']
- aiResponse = response['response']
- logger.log(logging.INFO, "AI生成" + aiResponse)
- #处理ai输出内容
- think_remove_text = re.sub(r'.*?', '', aiResponse, flags=re.DOTALL)
- pattern = r".*(.*?)" # .* 吞掉前面所有字符,定位最后一组
- match = re.search(pattern, think_remove_text, re.DOTALL)
-
- if not match:
- if await reGenerateAIChat(lastPrompt, websocket):
- pass
- else:
- logger.log(logging.ERROR, "请更换prompt,或者升级模型大小")
- await senddata(websocket, -1, messages=["请更换prompt,或者升级模型大小"])
-
- else:
- core_dialog = match.group(1).strip()
- dialog_lines = [line for line in core_dialog.split('\n') if line.strip()]
- if len(dialog_lines) != 4:
- if await reGenerateAIChat(lastPrompt, websocket):
- pass
- else:
- logger.log(logging.ERROR, "请更换prompt,或者升级模型大小")
- await senddata(websocket, -1, messages=["请更换prompt,或者升级模型大小"])
- else:
- logger.log(logging.INFO, "AI的输出正确:\n" + core_dialog)
- global regenerateCount
- regenerateCount = 0
- await senddata(websocket, 1, dialog_lines)
-
-regenerateCount = 1
async def reGenerateAIChat(prompt: str, websocket: WebSocket):
global regenerateCount
if regenerateCount < maxAIRegerateCount:
regenerateCount += 1
logger.log(logging.ERROR, f"AI输出格式不正确,重新进行生成 {regenerateCount}/{maxAIRegerateCount}")
- await senddata(websocket, 2, messages=["ai生成格式不正确, 重新进行生成"])
+ await sendprotocol(websocket, "AiChatGenerate", 0, "ai生成格式不正确, 重新进行生成", "")
await asyncio.sleep(0)
- prompt = prompt + "补充:上一次的输出格式错误,严格执行prompt中第5条的输出格式要求"
- await generateAIChat(prompt, websocket)
+ prompt = prompt + "补充:上一次的输出格式错误,严格执行prompt中的输出格式要求"
+ await generate_aichat(prompt, websocket)
return True
else:
regenerateCount = 0
@@ -272,19 +290,18 @@ if __name__ == "__main__":
#Test database
- database = DatabaseHandle()
+
- id = database.add_character({"name":"李明","age":30,"personality":"活泼健谈","profession":"产品经理"
- ,"characterBackground":"公司资深产品经理","chat_style":"热情"})
+ # id = database.add_character({"name":"李明","age":30,"personality":"活泼健谈","profession":"产品经理"
+ # ,"characterBackground":"公司资深产品经理","chat_style":"热情"})
- characters = database.get_character_byname("")
+ #characters = database.get_character_byname("")
#chat_id = database.add_chat({"character_ids":"1,2","chat":"张三:[第一轮发言] 李明:[第一轮回应] 张三:[第二轮发言] 李明:[第二轮回应"})
- chat = database.get_chats_by_character_id(3)
- if id == 0:
- logger.log(logging.ERROR, f"角色 张三已经添加到数据库")
+ #chat = database.get_chats_by_character_id(3)
+ #
# Test AI
- aicore.getPromptToken("测试功能")
+ #aicore.getPromptToken("测试功能")
# asyncio.run(
# generateAIChat(promptStr = f"""
# #你是一个游戏NPC对话生成器。请严格按以下要求生成两个角色的日常对话