Update parameters
parent ade8e69742
commit 1a4b4d8b76
@@ -602,7 +602,6 @@ class DualAIDialogueEngine:
             context=context_str,
             dialogue_history = dialogue_history,
             history_context_count = history_context_count,
-            temperature=0.8,
             max_new_tokens=150
         )
 
@@ -140,7 +140,7 @@ def run_dialogue_system():
     conv_mgr = ConversationManager("./conversation_data/conversations.db")
 
     # 检查模型路径
-    base_model_path = '/mnt/g/Project02/AITrain/Qwen/Qwen3-4B'
+    base_model_path = '/mnt/e/AI/Project02/AITrain/Qwen/Qwen3-4B'
     lora_model_path = './output/NPC_Dialogue_LoRA/final_model'
 
     if not os.path.exists(base_model_path):
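The only change in this hunk is the WSL mount path of the base model (moved from /mnt/g to /mnt/e/AI). A short, hedged sketch of the path check that follows in the function, assuming it simply aborts when either directory is missing; the exact error handling is outside this hunk:

import os

# Paths as set in this commit (base model now on the /mnt/e/AI mount).
base_model_path = '/mnt/e/AI/Project02/AITrain/Qwen/Qwen3-4B'
lora_model_path = './output/NPC_Dialogue_LoRA/final_model'

# Hypothetical early-exit check, for illustration only.
for label, path in [("base model", base_model_path), ("LoRA adapter", lora_model_path)]:
    if not os.path.exists(path):
        raise FileNotFoundError(f"{label} not found at: {path}")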
@@ -221,11 +221,11 @@ def run_dialogue_system():
     turns = int(turns_input) if turns_input.isdigit() else 4
 
     # 询问历史上下文设置
-    history_input = input("使用历史对话轮数 (默认10): ").strip()
-    history_count = int(history_input) if history_input.isdigit() else 10
+    history_input = input("使用历史对话轮数 (默认2): ").strip()
+    history_count = int(history_input) if history_input.isdigit() else 2
 
-    context_input = input("使用上下文信息数量 (默认5): ").strip()
-    context_info_count = int(context_input) if context_input.isdigit() else 5
+    context_input = input("使用上下文信息数量 (默认10): ").strip()
+    context_info_count = int(context_input) if context_input.isdigit() else 10
 
     print(f"\n开始对话 - 主题: {user_input}")
     print(f"轮数: {turns}, 历史: {history_count}, 上下文: {context_info_count}")
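All three prompts repeat the same `int(x) if x.isdigit() else default` idiom, now with the new defaults (history: 2 turns, context: 10 items). A hypothetical helper, not part of this commit, shows the pattern once:

def ask_int(prompt: str, default: int) -> int:
    """Read an integer from stdin; fall back to `default` on empty or
    non-numeric input (same behaviour as the isdigit() pattern above)."""
    raw = input(prompt).strip()
    return int(raw) if raw.isdigit() else default

# Defaults after this commit: 2 history turns, 10 context items.
history_count = ask_int("使用历史对话轮数 (默认2): ", 2)
context_info_count = ask_int("使用上下文信息数量 (默认10): ", 10)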
@@ -267,7 +267,7 @@ def create_demo_scenario():
     conv_mgr = ConversationManager("./conversation_data/demo_conversations.db")
 
     # 检查模型路径
-    base_model_path = '/mnt/g/Project02/AITrain/Qwen/Qwen3-4B'
+    base_model_path = '/mnt/e/AI/Project02/AITrain/Qwen/Qwen3-4B'
     lora_model_path = './output/NPC_Dialogue_LoRA/final_model'
 
     if not os.path.exists(base_model_path):
@@ -133,9 +133,9 @@ class NPCDialogueGenerator:
         )
 
         # 如果有LoRA模型,则加载
-        if self.lora_model_path:
-            print(f"Loading LoRA weights from: {self.lora_model_path}")
-            self.model = PeftModel.from_pretrained(self.model, self.lora_model_path)
+        # if self.lora_model_path:
+        #     print(f"Loading LoRA weights from: {self.lora_model_path}")
+        #     self.model = PeftModel.from_pretrained(self.model, self.lora_model_path)
 
     def generate_character_dialogue(
         self,
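Commenting out this block means the generator now loads only the base Qwen3-4B weights and skips the fine-tuned LoRA adapter. For reference, a hedged sketch of the base-plus-LoRA loading path that was disabled, using the standard transformers/peft calls; variable names outside the diff are illustrative:

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_model_path = '/mnt/e/AI/Project02/AITrain/Qwen/Qwen3-4B'
lora_model_path = './output/NPC_Dialogue_LoRA/final_model'

tokenizer = AutoTokenizer.from_pretrained(base_model_path)
model = AutoModelForCausalLM.from_pretrained(base_model_path, device_map="auto")

# The step this commit comments out: wrapping the base model with the
# fine-tuned LoRA adapter. Without it, generation uses the base weights only.
model = PeftModel.from_pretrained(model, lora_model_path)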
@@ -172,7 +172,7 @@ class NPCDialogueGenerator:
         system_prompt = self._build_system_prompt(profile, context, dialogue_history, history_context_count)
 
         # 构建用户输入
-        user_input = "请说一段符合你角色设定的话。"
+        user_input = "请说一段符合你角色设定的话,保持对话的连贯性。"
 
 
         # 准备消息
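The longer user instruction now also asks the model to keep the dialogue coherent. The system prompt from `_build_system_prompt` and this instruction are then packed into chat-format messages; a minimal sketch of that templating step, assuming the usual `apply_chat_template` flow for Qwen-style chat models (the actual message-preparation code is outside this hunk; `tokenizer` and `model` come from the loading sketch above):

# Placeholder standing in for the output of _build_system_prompt().
system_prompt = "(role and scene description built by _build_system_prompt)"
user_input = "请说一段符合你角色设定的话,保持对话的连贯性。"

messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": user_input},
]
# Render the messages with the model's chat template and tokenize them.
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
inputs = tokenizer(text, return_tensors="pt").to(model.device)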
@@ -210,10 +210,10 @@ class NPCDialogueGenerator:
             **inputs,
             max_new_tokens=max_new_tokens,
             do_sample=True,
-            temperature=temperature,
-            top_p=top_p,
+            temperature=0.95,
+            top_p=0.92,
             pad_token_id=self.tokenizer.eos_token_id,
-            repetition_penalty=1.1
+            repetition_penalty=1.15
         )
 
         # 解码输出
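With these values hard-coded, the `temperature` and `top_p` arguments accepted by `generate_character_dialogue` no longer affect generation. A hedged sketch of the generation and decode step as it stands after this commit, continuing from the loading and templating sketches above; names outside the diff are assumptions:

import torch

# `model`, `tokenizer`, and `inputs` follow the earlier sketches;
# 150 stands in for the max_new_tokens value passed by the caller.
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=150,
        do_sample=True,
        temperature=0.95,                 # previously the `temperature` parameter
        top_p=0.92,                       # previously the `top_p` parameter
        pad_token_id=tokenizer.eos_token_id,
        repetition_penalty=1.15,          # raised from 1.1
    )

# Decode only the newly generated tokens, dropping the prompt.
response = tokenizer.decode(
    outputs[0][inputs["input_ids"].shape[1]:],
    skip_special_tokens=True,
)
print(response)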