微调逻辑

This commit is contained in:
997146918 2025-08-18 15:08:13 +08:00
parent 7c8210039a
commit 38a7ac3414
2 changed files with 15 additions and 5 deletions

View File

@ -221,11 +221,11 @@ def run_dialogue_system():
turns = int(turns_input) if turns_input.isdigit() else 4 turns = int(turns_input) if turns_input.isdigit() else 4
# 询问历史上下文设置 # 询问历史上下文设置
history_input = input("使用历史对话轮数 (默认3): ").strip() history_input = input("使用历史对话轮数 (默认10): ").strip()
history_count = int(history_input) if history_input.isdigit() else 3 history_count = int(history_input) if history_input.isdigit() else 10
context_input = input("使用上下文信息数量 (默认10): ").strip() context_input = input("使用上下文信息数量 (默认5): ").strip()
context_info_count = int(context_input) if context_input.isdigit() else 10 context_info_count = int(context_input) if context_input.isdigit() else 5
print(f"\n开始对话 - 主题: {user_input}") print(f"\n开始对话 - 主题: {user_input}")
print(f"轮数: {turns}, 历史: {history_count}, 上下文: {context_info_count}") print(f"轮数: {turns}, 历史: {history_count}, 上下文: {context_info_count}")

View File

@ -193,6 +193,16 @@ class NPCDialogueGenerator:
# 移动到设备 # 移动到设备
inputs = {k: v.to(self.model.device) for k, v in inputs.items()} inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
# 计算input token数并与模型最大token数比较
input_token_count = inputs['input_ids'].shape[1]
try:
max_model_tokens = self.model.config.max_position_embeddings
except AttributeError:
max_model_tokens = 2048
if input_token_count + max_new_tokens > max_model_tokens:
print(f"警告当前输入token数({input_token_count})加上最大生成token数({max_new_tokens})超过模型最大token数({max_model_tokens}),可能导致生成结果不完整或报错。")
# 生成对话 # 生成对话
with torch.no_grad(): with torch.no_grad():