微调逻辑
This commit is contained in:
parent
7c8210039a
commit
38a7ac3414
@@ -221,11 +221,11 @@ def run_dialogue_system():
|
||||
turns = int(turns_input) if turns_input.isdigit() else 4
|
||||
|
||||
# 询问历史上下文设置
|
||||
history_input = input("使用历史对话轮数 (默认3): ").strip()
|
||||
history_count = int(history_input) if history_input.isdigit() else 3
|
||||
history_input = input("使用历史对话轮数 (默认10): ").strip()
|
||||
history_count = int(history_input) if history_input.isdigit() else 10
|
||||
|
||||
context_input = input("使用上下文信息数量 (默认10): ").strip()
|
||||
context_info_count = int(context_input) if context_input.isdigit() else 10
|
||||
context_input = input("使用上下文信息数量 (默认5): ").strip()
|
||||
context_info_count = int(context_input) if context_input.isdigit() else 5
|
||||
|
||||
print(f"\n开始对话 - 主题: {user_input}")
|
||||
print(f"轮数: {turns}, 历史: {history_count}, 上下文: {context_info_count}")
|
||||
|
||||
@@ -193,6 +193,16 @@ class NPCDialogueGenerator:
|
||||
|
||||
# 移动到设备
|
||||
inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
|
||||
# 计算input token数并与模型最大token数比较
|
||||
input_token_count = inputs['input_ids'].shape[1]
|
||||
try:
|
||||
max_model_tokens = self.model.config.max_position_embeddings
|
||||
except AttributeError:
|
||||
max_model_tokens = 2048
|
||||
|
||||
if input_token_count + max_new_tokens > max_model_tokens:
|
||||
print(f"警告:当前输入token数({input_token_count})加上最大生成token数({max_new_tokens})超过模型最大token数({max_model_tokens}),可能导致生成结果不完整或报错。")
|
||||
|
||||
|
||||
# 生成对话
|
||||
with torch.no_grad():
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user