diff --git a/AITrain/dual_ai_dialogue_system.py b/AITrain/dual_ai_dialogue_system.py index 9880112..ab4de91 100644 --- a/AITrain/dual_ai_dialogue_system.py +++ b/AITrain/dual_ai_dialogue_system.py @@ -602,7 +602,6 @@ class DualAIDialogueEngine: context=context_str, dialogue_history = dialogue_history, history_context_count = history_context_count, - temperature=0.8, max_new_tokens=150 ) diff --git a/AITrain/main_controller.py b/AITrain/main_controller.py index ad1784f..e6c0f51 100644 --- a/AITrain/main_controller.py +++ b/AITrain/main_controller.py @@ -140,7 +140,7 @@ def run_dialogue_system(): conv_mgr = ConversationManager("./conversation_data/conversations.db") # 检查模型路径 - base_model_path = '/mnt/g/Project02/AITrain/Qwen/Qwen3-4B' + base_model_path = '/mnt/e/AI/Project02/AITrain/Qwen/Qwen3-4B' lora_model_path = './output/NPC_Dialogue_LoRA/final_model' if not os.path.exists(base_model_path): @@ -221,11 +221,11 @@ def run_dialogue_system(): turns = int(turns_input) if turns_input.isdigit() else 4 # 询问历史上下文设置 - history_input = input("使用历史对话轮数 (默认10): ").strip() - history_count = int(history_input) if history_input.isdigit() else 10 + history_input = input("使用历史对话轮数 (默认2): ").strip() + history_count = int(history_input) if history_input.isdigit() else 2 - context_input = input("使用上下文信息数量 (默认5): ").strip() - context_info_count = int(context_input) if context_input.isdigit() else 5 + context_input = input("使用上下文信息数量 (默认10): ").strip() + context_info_count = int(context_input) if context_input.isdigit() else 10 print(f"\n开始对话 - 主题: {user_input}") print(f"轮数: {turns}, 历史: {history_count}, 上下文: {context_info_count}") @@ -267,7 +267,7 @@ def create_demo_scenario(): conv_mgr = ConversationManager("./conversation_data/demo_conversations.db") # 检查模型路径 - base_model_path = '/mnt/g/Project02/AITrain/Qwen/Qwen3-4B' + base_model_path = '/mnt/e/AI/Project02/AITrain/Qwen/Qwen3-4B' lora_model_path = './output/NPC_Dialogue_LoRA/final_model' if not os.path.exists(base_model_path): diff --git a/AITrain/npc_dialogue_generator.py b/AITrain/npc_dialogue_generator.py index 5ccce21..bac40aa 100644 --- a/AITrain/npc_dialogue_generator.py +++ b/AITrain/npc_dialogue_generator.py @@ -133,9 +133,9 @@ class NPCDialogueGenerator: ) # 如果有LoRA模型,则加载 - if self.lora_model_path: - print(f"Loading LoRA weights from: {self.lora_model_path}") - self.model = PeftModel.from_pretrained(self.model, self.lora_model_path) + # if self.lora_model_path: + # print(f"Loading LoRA weights from: {self.lora_model_path}") + # self.model = PeftModel.from_pretrained(self.model, self.lora_model_path) def generate_character_dialogue( self, @@ -172,7 +172,7 @@ class NPCDialogueGenerator: system_prompt = self._build_system_prompt(profile, context, dialogue_history, history_context_count) # 构建用户输入 - user_input = "请说一段符合你角色设定的话。" + user_input = "请说一段符合你角色设定的话,保持对话的连贯性。" # 准备消息 @@ -210,10 +210,10 @@ class NPCDialogueGenerator: **inputs, max_new_tokens=max_new_tokens, do_sample=True, - temperature=temperature, - top_p=top_p, + temperature=0.95, + top_p=0.92, pad_token_id=self.tokenizer.eos_token_id, - repetition_penalty=1.1 + repetition_penalty=1.15 ) # 解码输出