更新参数

This commit is contained in:
997146918 2025-08-18 18:25:19 +08:00
parent ade8e69742
commit 1a4b4d8b76
3 changed files with 13 additions and 14 deletions

View File

@ -602,7 +602,6 @@ class DualAIDialogueEngine:
context=context_str, context=context_str,
dialogue_history = dialogue_history, dialogue_history = dialogue_history,
history_context_count = history_context_count, history_context_count = history_context_count,
temperature=0.8,
max_new_tokens=150 max_new_tokens=150
) )

View File

@ -140,7 +140,7 @@ def run_dialogue_system():
conv_mgr = ConversationManager("./conversation_data/conversations.db") conv_mgr = ConversationManager("./conversation_data/conversations.db")
# 检查模型路径 # 检查模型路径
base_model_path = '/mnt/g/Project02/AITrain/Qwen/Qwen3-4B' base_model_path = '/mnt/e/AI/Project02/AITrain/Qwen/Qwen3-4B'
lora_model_path = './output/NPC_Dialogue_LoRA/final_model' lora_model_path = './output/NPC_Dialogue_LoRA/final_model'
if not os.path.exists(base_model_path): if not os.path.exists(base_model_path):
@ -221,11 +221,11 @@ def run_dialogue_system():
turns = int(turns_input) if turns_input.isdigit() else 4 turns = int(turns_input) if turns_input.isdigit() else 4
# 询问历史上下文设置 # 询问历史上下文设置
history_input = input("使用历史对话轮数 (默认10): ").strip() history_input = input("使用历史对话轮数 (默认2): ").strip()
history_count = int(history_input) if history_input.isdigit() else 10 history_count = int(history_input) if history_input.isdigit() else 2
context_input = input("使用上下文信息数量 (默认5): ").strip() context_input = input("使用上下文信息数量 (默认10): ").strip()
context_info_count = int(context_input) if context_input.isdigit() else 5 context_info_count = int(context_input) if context_input.isdigit() else 10
print(f"\n开始对话 - 主题: {user_input}") print(f"\n开始对话 - 主题: {user_input}")
print(f"轮数: {turns}, 历史: {history_count}, 上下文: {context_info_count}") print(f"轮数: {turns}, 历史: {history_count}, 上下文: {context_info_count}")
@ -267,7 +267,7 @@ def create_demo_scenario():
conv_mgr = ConversationManager("./conversation_data/demo_conversations.db") conv_mgr = ConversationManager("./conversation_data/demo_conversations.db")
# 检查模型路径 # 检查模型路径
base_model_path = '/mnt/g/Project02/AITrain/Qwen/Qwen3-4B' base_model_path = '/mnt/e/AI/Project02/AITrain/Qwen/Qwen3-4B'
lora_model_path = './output/NPC_Dialogue_LoRA/final_model' lora_model_path = './output/NPC_Dialogue_LoRA/final_model'
if not os.path.exists(base_model_path): if not os.path.exists(base_model_path):

View File

@ -133,9 +133,9 @@ class NPCDialogueGenerator:
) )
# 如果有LoRA模型则加载 # 如果有LoRA模型则加载
if self.lora_model_path: # if self.lora_model_path:
print(f"Loading LoRA weights from: {self.lora_model_path}") # print(f"Loading LoRA weights from: {self.lora_model_path}")
self.model = PeftModel.from_pretrained(self.model, self.lora_model_path) # self.model = PeftModel.from_pretrained(self.model, self.lora_model_path)
def generate_character_dialogue( def generate_character_dialogue(
self, self,
@ -172,7 +172,7 @@ class NPCDialogueGenerator:
system_prompt = self._build_system_prompt(profile, context, dialogue_history, history_context_count) system_prompt = self._build_system_prompt(profile, context, dialogue_history, history_context_count)
# 构建用户输入 # 构建用户输入
user_input = "请说一段符合你角色设定的话" user_input = "请说一段符合你角色设定的话,保持对话的连贯性"
# 准备消息 # 准备消息
@ -210,10 +210,10 @@ class NPCDialogueGenerator:
**inputs, **inputs,
max_new_tokens=max_new_tokens, max_new_tokens=max_new_tokens,
do_sample=True, do_sample=True,
temperature=temperature, temperature=0.95,
top_p=top_p, top_p=0.92,
pad_token_id=self.tokenizer.eos_token_id, pad_token_id=self.tokenizer.eos_token_id,
repetition_penalty=1.1 repetition_penalty=1.15
) )
# 解码输出 # 解码输出