更新参数

This commit is contained in:
997146918 2025-08-18 18:25:19 +08:00
parent ade8e69742
commit 1a4b4d8b76
3 changed files with 13 additions and 14 deletions

View File

@ -602,7 +602,6 @@ class DualAIDialogueEngine:
context=context_str,
dialogue_history = dialogue_history,
history_context_count = history_context_count,
temperature=0.8,
max_new_tokens=150
)

View File

@ -140,7 +140,7 @@ def run_dialogue_system():
conv_mgr = ConversationManager("./conversation_data/conversations.db")
# 检查模型路径
base_model_path = '/mnt/g/Project02/AITrain/Qwen/Qwen3-4B'
base_model_path = '/mnt/e/AI/Project02/AITrain/Qwen/Qwen3-4B'
lora_model_path = './output/NPC_Dialogue_LoRA/final_model'
if not os.path.exists(base_model_path):
@ -221,11 +221,11 @@ def run_dialogue_system():
turns = int(turns_input) if turns_input.isdigit() else 4
# 询问历史上下文设置
history_input = input("使用历史对话轮数 (默认10): ").strip()
history_count = int(history_input) if history_input.isdigit() else 10
history_input = input("使用历史对话轮数 (默认2): ").strip()
history_count = int(history_input) if history_input.isdigit() else 2
context_input = input("使用上下文信息数量 (默认5): ").strip()
context_info_count = int(context_input) if context_input.isdigit() else 5
context_input = input("使用上下文信息数量 (默认10): ").strip()
context_info_count = int(context_input) if context_input.isdigit() else 10
print(f"\n开始对话 - 主题: {user_input}")
print(f"轮数: {turns}, 历史: {history_count}, 上下文: {context_info_count}")
@ -267,7 +267,7 @@ def create_demo_scenario():
conv_mgr = ConversationManager("./conversation_data/demo_conversations.db")
# 检查模型路径
base_model_path = '/mnt/g/Project02/AITrain/Qwen/Qwen3-4B'
base_model_path = '/mnt/e/AI/Project02/AITrain/Qwen/Qwen3-4B'
lora_model_path = './output/NPC_Dialogue_LoRA/final_model'
if not os.path.exists(base_model_path):

View File

@ -133,9 +133,9 @@ class NPCDialogueGenerator:
)
# 如果有LoRA模型则加载
if self.lora_model_path:
print(f"Loading LoRA weights from: {self.lora_model_path}")
self.model = PeftModel.from_pretrained(self.model, self.lora_model_path)
# if self.lora_model_path:
# print(f"Loading LoRA weights from: {self.lora_model_path}")
# self.model = PeftModel.from_pretrained(self.model, self.lora_model_path)
def generate_character_dialogue(
self,
@ -172,7 +172,7 @@ class NPCDialogueGenerator:
system_prompt = self._build_system_prompt(profile, context, dialogue_history, history_context_count)
# 构建用户输入
user_input = "请说一段符合你角色设定的话"
user_input = "请说一段符合你角色设定的话,保持对话的连贯性"
# 准备消息
@ -210,10 +210,10 @@ class NPCDialogueGenerator:
**inputs,
max_new_tokens=max_new_tokens,
do_sample=True,
temperature=temperature,
top_p=top_p,
temperature=0.95,
top_p=0.92,
pad_token_id=self.tokenizer.eos_token_id,
repetition_penalty=1.1
repetition_penalty=1.15
)
# 解码输出