更新参数
This commit is contained in:
parent
ade8e69742
commit
1a4b4d8b76
@ -602,7 +602,6 @@ class DualAIDialogueEngine:
|
||||
context=context_str,
|
||||
dialogue_history = dialogue_history,
|
||||
history_context_count = history_context_count,
|
||||
temperature=0.8,
|
||||
max_new_tokens=150
|
||||
)
|
||||
|
||||
|
||||
@ -140,7 +140,7 @@ def run_dialogue_system():
|
||||
conv_mgr = ConversationManager("./conversation_data/conversations.db")
|
||||
|
||||
# 检查模型路径
|
||||
base_model_path = '/mnt/g/Project02/AITrain/Qwen/Qwen3-4B'
|
||||
base_model_path = '/mnt/e/AI/Project02/AITrain/Qwen/Qwen3-4B'
|
||||
lora_model_path = './output/NPC_Dialogue_LoRA/final_model'
|
||||
|
||||
if not os.path.exists(base_model_path):
|
||||
@ -221,11 +221,11 @@ def run_dialogue_system():
|
||||
turns = int(turns_input) if turns_input.isdigit() else 4
|
||||
|
||||
# 询问历史上下文设置
|
||||
history_input = input("使用历史对话轮数 (默认10): ").strip()
|
||||
history_count = int(history_input) if history_input.isdigit() else 10
|
||||
history_input = input("使用历史对话轮数 (默认2): ").strip()
|
||||
history_count = int(history_input) if history_input.isdigit() else 2
|
||||
|
||||
context_input = input("使用上下文信息数量 (默认5): ").strip()
|
||||
context_info_count = int(context_input) if context_input.isdigit() else 5
|
||||
context_input = input("使用上下文信息数量 (默认10): ").strip()
|
||||
context_info_count = int(context_input) if context_input.isdigit() else 10
|
||||
|
||||
print(f"\n开始对话 - 主题: {user_input}")
|
||||
print(f"轮数: {turns}, 历史: {history_count}, 上下文: {context_info_count}")
|
||||
@ -267,7 +267,7 @@ def create_demo_scenario():
|
||||
conv_mgr = ConversationManager("./conversation_data/demo_conversations.db")
|
||||
|
||||
# 检查模型路径
|
||||
base_model_path = '/mnt/g/Project02/AITrain/Qwen/Qwen3-4B'
|
||||
base_model_path = '/mnt/e/AI/Project02/AITrain/Qwen/Qwen3-4B'
|
||||
lora_model_path = './output/NPC_Dialogue_LoRA/final_model'
|
||||
|
||||
if not os.path.exists(base_model_path):
|
||||
|
||||
@ -133,9 +133,9 @@ class NPCDialogueGenerator:
|
||||
)
|
||||
|
||||
# 如果有LoRA模型,则加载
|
||||
if self.lora_model_path:
|
||||
print(f"Loading LoRA weights from: {self.lora_model_path}")
|
||||
self.model = PeftModel.from_pretrained(self.model, self.lora_model_path)
|
||||
# if self.lora_model_path:
|
||||
# print(f"Loading LoRA weights from: {self.lora_model_path}")
|
||||
# self.model = PeftModel.from_pretrained(self.model, self.lora_model_path)
|
||||
|
||||
def generate_character_dialogue(
|
||||
self,
|
||||
@ -172,7 +172,7 @@ class NPCDialogueGenerator:
|
||||
system_prompt = self._build_system_prompt(profile, context, dialogue_history, history_context_count)
|
||||
|
||||
# 构建用户输入
|
||||
user_input = "请说一段符合你角色设定的话。"
|
||||
user_input = "请说一段符合你角色设定的话,保持对话的连贯性。"
|
||||
|
||||
|
||||
# 准备消息
|
||||
@ -210,10 +210,10 @@ class NPCDialogueGenerator:
|
||||
**inputs,
|
||||
max_new_tokens=max_new_tokens,
|
||||
do_sample=True,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
temperature=0.95,
|
||||
top_p=0.92,
|
||||
pad_token_id=self.tokenizer.eos_token_id,
|
||||
repetition_penalty=1.1
|
||||
repetition_penalty=1.15
|
||||
)
|
||||
|
||||
# 解码输出
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user