diff --git a/AIGC/AICore.py b/AIGC/AICore.py index b01ae16..521bd18 100644 --- a/AIGC/AICore.py +++ b/AIGC/AICore.py @@ -1,21 +1,47 @@ +import logging +from typing import Tuple import requests from ollama import Client, ResponseError import tiktoken +import random +from Utils.AIGCLog import AIGCLog class AICore: modelMaxTokens = 128000 # 初始化 DeepSeek 使用的 Tokenizer (cl100k_base) encoder = tiktoken.get_encoding("cl100k_base") + logger = AIGCLog(name = "AIGC", log_file = "aigc.log") - def __init__(self, model): #初始化ollama客户端 self.ollamaClient = Client(host='http://localhost:11434', headers={'x-some-header': 'some-value'}) + self.modelName = model response = self.ollamaClient.show(model) modelMaxTokens = response.modelinfo['qwen2.context_length'] def getPromptToken(self, prompt)-> int: tokens = self.encoder.encode(prompt) return len(tokens) - \ No newline at end of file + + def generateAI(self, promptStr: str) -> Tuple[bool, str]: + try: + response = self.ollamaClient.generate( + model = self.modelName, + stream = False, + prompt = promptStr, + options={ + "temperature": random.uniform(1.0, 1.5), + "repeat_penalty": 1.2, # 抑制重复 + "top_p": random.uniform(0.7, 0.95), + "num_ctx": 4096, # 上下文长度 + } + ) + return True, response.response + except ResponseError as e: + if e.status_code == 503: + print("🔄 服务不可用,5秒后重试...") + return False,"ollama 服务不可用" + except Exception as e: + print(f"🔥 未预料错误: {str(e)}") + return False, "未预料错误" \ No newline at end of file diff --git a/AIGC/DatabaseHandle.py b/AIGC/DatabaseHandle.py index 02f94e7..43ba7b7 100644 --- a/AIGC/DatabaseHandle.py +++ b/AIGC/DatabaseHandle.py @@ -103,17 +103,18 @@ class DatabaseHandle: conn.commit() return cursor.lastrowid - def get_chats_by_character_id(self, character_id: int) -> list: + def get_chats_by_character_id(self, character_id: str) -> list: """ 根据角色ID查询聊天记录(target_id为空时返回全部数据) :param target_id: 目标角色ID(None时返回全部记录) :return: 聊天记录字典列表 """ + + sorted_ids = sorted(character_id.split(","), key=int) # 按数值升序 + normalized_param = ",".join(sorted_ids) with self._get_connection() as conn: cursor = conn.cursor() - sql = "SELECT * FROM chat_records WHERE ',' || character_ids || ',' LIKE '%,' || ? || ',%'" - params = (str(character_id)) - cursor.execute(sql, params) - # 转换结果为字典列表 + sql = "SELECT * FROM chat_records WHERE character_ids = ?" + cursor.execute(sql, (normalized_param,)) columns = [col[0] for col in cursor.description] return [dict(zip(columns, row)) for row in cursor.fetchall()] \ No newline at end of file diff --git a/AIGC/StartCommand.bat b/AIGC/StartCommand.bat index a467633..229d749 100644 --- a/AIGC/StartCommand.bat +++ b/AIGC/StartCommand.bat @@ -2,7 +2,7 @@ chcp 65001 > nul set OLLAMA_MODEL=deepseek-r1:7b rem 启动Ollama服务 -start "Ollama DeepSeek" cmd /k ollama run %OLLAMA_MODEL% +start "Ollama DeepSeek" cmd /k ollama serve rem 检测11434端口是否就绪 echo 等待Ollama服务启动... diff --git a/AIGC/data.db b/AIGC/data.db index 83fb5d4..01eaf47 100644 Binary files a/AIGC/data.db and b/AIGC/data.db differ diff --git a/AIGC/main.py b/AIGC/main.py index 6c81eae..b3fb640 100644 --- a/AIGC/main.py +++ b/AIGC/main.py @@ -26,11 +26,15 @@ parser.add_argument('--model', type=str, default='deepseek-r1:1.5b', args = parser.parse_args() logger.log(logging.INFO, f"使用的模型是 {args.model}") -maxAIRegerateCount = 5 +maxAIRegerateCount = 5 #最大重新生成次数 +regenerateCount = 1 #当前重新生成次数 +totalAIGenerateCount = 1 #客户端生成AI响应总数 +currentGenerateCount = 0 #当前生成次数 lastPrompt = "" - +character_id1 = 0 +character_id2 = 0 aicore = AICore(args.model) - +database = DatabaseHandle() async def heartbeat(websocket: WebSocket): pass @@ -48,6 +52,18 @@ async def senddata(websocket: WebSocket, protocol: dict): json_string = json.dumps(protocol, ensure_ascii=False) await websocket.send_text(json_string) +async def sendprotocol(websocket: WebSocket, cmd: str, status: int, message: str, data: str): + # 将AI响应发送回UE5 + protocol = {} + protocol["cmd"] = cmd + protocol["status"] = status + protocol["message"] = message + protocol["data"] = data + + if websocket.client_state == WebSocketState.CONNECTED: + json_string = json.dumps(protocol, ensure_ascii=False) + await websocket.send_text(json_string) + # WebSocket路由处理 @app.websocket("/ws/{client_id}") async def websocket_endpoint(websocket: WebSocket, client_id: str): @@ -62,20 +78,7 @@ async def websocket_endpoint(websocket: WebSocket, client_id: str): data = await websocket.receive_text() logger.log(logging.INFO, f"收到UE5消息 [{client_id}]: {data}") await process_protocol_json(data, websocket) - - # success, prompt = process_prompt(data) - # global lastPrompt - # lastPrompt = prompt - # # 调用AI生成响应 - # if(success): - # asyncio.create_task(generateAIChat(prompt, websocket)) - # await senddata(websocket, 0, []) - # else: - # await senddata(websocket, -1, []) - - - - + except WebSocketDisconnect: #manager.disconnect(client_id) logger.log(logging.WARNING, f"UE5客户端主动断开 [{client_id}]") @@ -95,6 +98,72 @@ async def handle_characterlist(client: WebSocket): protocol["data"] = json.dumps(characterforUE) await senddata(client, protocol) +async def generate_aichat(promptStr: str, client: WebSocket| None = None): + dynamic_token = str(int(time.time() % 1000)) + promptStr = f""" + [动态标识码:{dynamic_token}] + """ + promptStr + logger.log(logging.INFO, "prompt:" + promptStr) + starttime = time.time() + success, response, = aicore.generateAI(promptStr) + if(success): + logger.log(logging.INFO, "接口调用耗时 :" + str(time.time() - starttime)) + logger.log(logging.INFO, "AI生成" + response) + #处理ai输出内容 + think_remove_text = re.sub(r'.*?', '', response, flags=re.DOTALL) + pattern = r".*(.*?)" # .* 吞掉前面所有字符,定位最后一组 + match = re.search(pattern, think_remove_text, re.DOTALL) + if not match: + #生成内容格式错误 + if await reGenerateAIChat(lastPrompt, client): + pass + else: + #超过重新生成次数 + logger.log(logging.ERROR, "请更换prompt,或者升级模型大小") + await sendprotocol(client, "AiChatGenerate", 0, "请更换prompt,或者升级模型大小", "") + else: + #生成内容格式正确 + core_dialog = match.group(1).strip() + dialog_lines = [line.strip() for line in core_dialog.split('\n') if line.strip()] + if len(dialog_lines) != 4: + #生成内容格式错误 + if await reGenerateAIChat(lastPrompt, client): + pass + else: + logger.log(logging.ERROR, "请更换prompt,或者升级模型大小") + await sendprotocol(client, "AiChatGenerate", 0, "请更换prompt,或者升级模型大小", "") + else: + logger.log(logging.INFO, "AI的输出正确:\n" + core_dialog) + global regenerateCount + regenerateCount = 0 + #保存数据到数据库 + database.add_chat({"character_ids":f"{character_id1},{character_id2}","chat":f"{" ".join(dialog_lines)}"}) + + await sendprotocol(client, "AiChatGenerate", 1, "AI生成成功", "|".join(dialog_lines)) + else: + await sendprotocol(client, "AiChatGenerate", -1, "调用ollama服务失败", "") + +async def handle_aichat_generate(client: WebSocket, aichat_data:str): + ### 处理ai prompt### + success, prompt = process_prompt(aichat_data) + global lastPrompt + lastPrompt = prompt + + # 调用AI生成响应 + if(success): + #asyncio.create_task(generateAIChat(prompt, client)) + global currentGenerateCount + while currentGenerateCount < totalAIGenerateCount: + currentGenerateCount += 1 + await generate_aichat(prompt, client) + + currentGenerateCount = 0 + #全部生成完成 + await sendprotocol(client, "AiChatGenerate", 2, "AI生成成功", "") + else: + #prompt生成失败 + await sendprotocol(client, "AiChatGenerate", -1, "prompt convert failed", "") + async def handle_addcharacter(client: WebSocket, chracterJson: str): ### 添加角色到数据库 ### character_info = json.loads(chracterJson) @@ -113,6 +182,8 @@ async def process_protocol_json(json_str: str, client: WebSocket): await handle_characterlist(client) elif cmd == "AddCharacter": await handle_addcharacter(client, data) + elif cmd == "AiChatGenerate": + await handle_aichat_generate(client, data) except json.JSONDecodeError as e: print(f"JSON解析错误: {e}") @@ -121,42 +192,65 @@ async def process_protocol_json(json_str: str, client: WebSocket): def process_prompt(promptFromUE: str) -> Tuple[bool, str]: try: data = json.loads(promptFromUE) + global maxAIRegerateCount # 提取数据 - dialog_scene = data["dialogContent"]["dialogScene"] - persons = data["persons"] + dialog_scene = data["dialogScene"] + global totalAIGenerateCount + totalAIGenerateCount = data["generateCount"] + persons = data["characterName"] assert len(persons) == 2 - for person in persons: - print(f" 姓名: {person['name']}, 职业: {person['job']}") - + characterInfo1 = database.get_character_byname(persons[0]) + characterInfo2 = database.get_character_byname(persons[1]) + global character_id1, character_id2 + character_id1 = characterInfo1[0]["id"] + character_id2 = characterInfo2[0]["id"] + chat_history = database.get_chats_by_character_id(str(character_id1) + "," + str(character_id2)) + #整理对话记录 + result = result = '\n'.join([item['chat'] for item in chat_history]) + + prompt = f""" - 你是一个游戏NPC对话生成器。请严格按以下要求生成两个路人NPC({persons[0]["name"]}和{persons[1]["name"]})的日常对话: - 1. 生成【2轮完整对话】,每轮包含双方各一次发言(共4句) - 2. 对话场景:{dialog_scene} - 3. 角色设定: - {persons[0]["name"]}:{persons[0]["job"]} - {persons[1]["name"]}:{persons[1]["job"]} - 4. 对话要求: - * 每轮对话需自然衔接,体现生活细节 - * 避免任务指引或玩家交互内容 - * 结尾保持对话未完成感 - 5. 输出格式: - - {persons[0]["name"]}:[第一轮发言] - {persons[1]["name"]}:[第一轮回应] - {persons[0]["name"]}:[第二轮发言] - {persons[1]["name"]}:[第二轮回应] - - - 6.重要!若未按此格式输出,请重新生成直至完全符合 + #你是一个游戏NPC对话生成器。请严格按以下要求生成两个角色的日常对话 + #对话的背景 + {dialog_scene} + 1. 生成【2轮完整对话】,每轮包含双方各一次发言(共4句) + 2.角色设定 + {characterInfo1[0]["name"]}: {{ + "姓名": {characterInfo1[0]["name"]}, + "年龄": {characterInfo1[0]["age"]}, + "性格": {characterInfo1[0]["personality"]}, + "职业": {characterInfo1[0]["profession"]}, + "背景": {characterInfo1[0]["characterBackground"]}, + "语言风格": {characterInfo1[0]["chat_style"]} + }}, + {characterInfo2[0]["name"]}: {{ + "姓名": {characterInfo2[0]["name"]}, + "年龄": {characterInfo2[0]["age"]}, + "性格": {characterInfo2[0]["personality"]}, + "职业": {characterInfo2[0]["profession"]}, + "背景": {characterInfo2[0]["characterBackground"]}, + "语言风格": {characterInfo2[0]["chat_style"]} + }} + 3.参考的历史对话内容 + {result} + 4.输出格式: + + 张三:[第一轮发言] + 李明:[第一轮回应] + 张三:[第二轮发言] + 李明:[第二轮回应] + + 5.重要!若未按此格式输出,请重新生成直至完全符合 """ return True, prompt except json.JSONDecodeError as e: print(f"JSON解析错误: {e}") return False, "" - except KeyError as e: - print(f"缺少必要字段: {e}") + except Exception as e: + print(f"发生错误:{type(e).__name__} - {e}") + return False, "" @@ -173,91 +267,15 @@ def run_webserver(): log_level="info" ) -async def generateAIChat(promptStr: str, websocket: WebSocket| None = None): - #动态标识吗 防止重复输入导致的结果重复 - dynamic_token = str(int(time.time() % 1000)) - promptStr = f""" - [动态标识码:{dynamic_token}] - """ + promptStr - logger.log(logging.INFO, "prompt:" + promptStr) - starttime = time.time() - receivemessage=[ - {"role": "system", "content": promptStr} - ] - try: - - # response = ollamaClient.chat( - # model = args.model, - # stream = False, - # messages = receivemessage, - # options={ - # "temperature": random.uniform(1.0, 1.5), - # "repeat_penalty": 1.2, # 抑制重复 - # "top_p": random.uniform(0.7, 0.95), - # "num_ctx": 4096, # 上下文长度 - # "seed": int(time.time() * 1000) % 1000000 - # } - # ) - response = ollamaClient.generate( - model = args.model, - stream = False, - prompt = promptStr, - options={ - "temperature": random.uniform(1.0, 1.5), - "repeat_penalty": 1.2, # 抑制重复 - "top_p": random.uniform(0.7, 0.95), - "num_ctx": 4096, # 上下文长度 - } - ) - - except ResponseError as e: - if e.status_code == 503: - print("🔄 服务不可用,5秒后重试...") - return await senddata(websocket, -1, messages=["ollama 服务不可用"]) - except Exception as e: - print(f"🔥 未预料错误: {str(e)}") - return await senddata(websocket, -1, messages=["未预料错误"]) - logger.log(logging.INFO, "接口调用耗时 :" + str(time.time() - starttime)) - #aiResponse = response['message']['content'] - aiResponse = response['response'] - logger.log(logging.INFO, "AI生成" + aiResponse) - #处理ai输出内容 - think_remove_text = re.sub(r'.*?', '', aiResponse, flags=re.DOTALL) - pattern = r".*(.*?)" # .* 吞掉前面所有字符,定位最后一组 - match = re.search(pattern, think_remove_text, re.DOTALL) - - if not match: - if await reGenerateAIChat(lastPrompt, websocket): - pass - else: - logger.log(logging.ERROR, "请更换prompt,或者升级模型大小") - await senddata(websocket, -1, messages=["请更换prompt,或者升级模型大小"]) - - else: - core_dialog = match.group(1).strip() - dialog_lines = [line for line in core_dialog.split('\n') if line.strip()] - if len(dialog_lines) != 4: - if await reGenerateAIChat(lastPrompt, websocket): - pass - else: - logger.log(logging.ERROR, "请更换prompt,或者升级模型大小") - await senddata(websocket, -1, messages=["请更换prompt,或者升级模型大小"]) - else: - logger.log(logging.INFO, "AI的输出正确:\n" + core_dialog) - global regenerateCount - regenerateCount = 0 - await senddata(websocket, 1, dialog_lines) - -regenerateCount = 1 async def reGenerateAIChat(prompt: str, websocket: WebSocket): global regenerateCount if regenerateCount < maxAIRegerateCount: regenerateCount += 1 logger.log(logging.ERROR, f"AI输出格式不正确,重新进行生成 {regenerateCount}/{maxAIRegerateCount}") - await senddata(websocket, 2, messages=["ai生成格式不正确, 重新进行生成"]) + await sendprotocol(websocket, "AiChatGenerate", 0, "ai生成格式不正确, 重新进行生成", "") await asyncio.sleep(0) - prompt = prompt + "补充:上一次的输出格式错误,严格执行prompt中第5条的输出格式要求" - await generateAIChat(prompt, websocket) + prompt = prompt + "补充:上一次的输出格式错误,严格执行prompt中的输出格式要求" + await generate_aichat(prompt, websocket) return True else: regenerateCount = 0 @@ -272,19 +290,18 @@ if __name__ == "__main__": #Test database - database = DatabaseHandle() + - id = database.add_character({"name":"李明","age":30,"personality":"活泼健谈","profession":"产品经理" - ,"characterBackground":"公司资深产品经理","chat_style":"热情"}) + # id = database.add_character({"name":"李明","age":30,"personality":"活泼健谈","profession":"产品经理" + # ,"characterBackground":"公司资深产品经理","chat_style":"热情"}) - characters = database.get_character_byname("") + #characters = database.get_character_byname("") #chat_id = database.add_chat({"character_ids":"1,2","chat":"张三:[第一轮发言] 李明:[第一轮回应] 张三:[第二轮发言] 李明:[第二轮回应"}) - chat = database.get_chats_by_character_id(3) - if id == 0: - logger.log(logging.ERROR, f"角色 张三已经添加到数据库") + #chat = database.get_chats_by_character_id(3) + # # Test AI - aicore.getPromptToken("测试功能") + #aicore.getPromptToken("测试功能") # asyncio.run( # generateAIChat(promptStr = f""" # #你是一个游戏NPC对话生成器。请严格按以下要求生成两个角色的日常对话