Tweak logic

997146918 2025-08-18 15:08:13 +08:00
parent 7c8210039a
commit 38a7ac3414
2 changed files with 15 additions and 5 deletions


@@ -221,11 +221,11 @@ def run_dialogue_system():
     turns = int(turns_input) if turns_input.isdigit() else 4
     # 询问历史上下文设置
-    history_input = input("使用历史对话轮数 (默认3): ").strip()
-    history_count = int(history_input) if history_input.isdigit() else 3
+    history_input = input("使用历史对话轮数 (默认10): ").strip()
+    history_count = int(history_input) if history_input.isdigit() else 10
-    context_input = input("使用上下文信息数量 (默认10): ").strip()
-    context_info_count = int(context_input) if context_input.isdigit() else 10
+    context_input = input("使用上下文信息数量 (默认5): ").strip()
+    context_info_count = int(context_input) if context_input.isdigit() else 5
     print(f"\n开始对话 - 主题: {user_input}")
     print(f"轮数: {turns}, 历史: {history_count}, 上下文: {context_info_count}")


@@ -136,7 +136,7 @@ class NPCDialogueGenerator:
         if self.lora_model_path:
             print(f"Loading LoRA weights from: {self.lora_model_path}")
             self.model = PeftModel.from_pretrained(self.model, self.lora_model_path)

     def generate_character_dialogue(
         self,
         character_name: str,
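
For reference, a minimal standalone sketch of the LoRA loading pattern visible in the hunk above, assuming a Hugging Face transformers base model plus a peft adapter; the model name and adapter path below are placeholders, not values taken from this repository:

    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import PeftModel

    base_model_name = "Qwen/Qwen2-7B-Instruct"  # placeholder base model, not from this repo
    lora_model_path = "./lora_adapter"           # placeholder adapter directory

    # Load the base causal LM and its tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    model = AutoModelForCausalLM.from_pretrained(base_model_name)

    # Attach the fine-tuned LoRA adapter weights on top of the base model,
    # mirroring the PeftModel.from_pretrained call in NPCDialogueGenerator.
    model = PeftModel.from_pretrained(model, lora_model_path)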
@@ -193,7 +193,17 @@ class NPCDialogueGenerator:
         # 移动到设备
         inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
+        # 计算input token数并与模型最大token数比较
+        input_token_count = inputs['input_ids'].shape[1]
+        try:
+            max_model_tokens = self.model.config.max_position_embeddings
+        except AttributeError:
+            max_model_tokens = 2048
+        if input_token_count + max_new_tokens > max_model_tokens:
+            print(f"警告当前输入token数({input_token_count})加上最大生成token数({max_new_tokens})超过模型最大token数({max_model_tokens}),可能导致生成结果不完整或报错。")
         # 生成对话
         with torch.no_grad():
             outputs = self.model.generate(
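
The guard added in this hunk can be read in isolation as follows: a minimal sketch, assuming a transformers-style model whose config may or may not expose max_position_embeddings; the helper name and the 2048 fallback are illustrative only:

    def check_token_budget(model, inputs, max_new_tokens, fallback=2048):
        """Warn if prompt tokens plus the generation budget exceed the model's context window."""
        input_token_count = inputs["input_ids"].shape[1]
        # Not every config exposes max_position_embeddings; fall back to a conservative default.
        max_model_tokens = getattr(model.config, "max_position_embeddings", fallback)
        if input_token_count + max_new_tokens > max_model_tokens:
            print(
                f"Warning: input tokens ({input_token_count}) plus max_new_tokens "
                f"({max_new_tokens}) exceed the model limit ({max_model_tokens}); "
                "output may be truncated or generation may fail."
            )
        return max_model_tokens

Calling such a check immediately before model.generate(...) keeps the warning next to the call it protects, which is how the commit places it.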