微调逻辑
This commit is contained in:
parent
7c8210039a
commit
38a7ac3414
@ -221,11 +221,11 @@ def run_dialogue_system():
|
|||||||
turns = int(turns_input) if turns_input.isdigit() else 4
|
turns = int(turns_input) if turns_input.isdigit() else 4
|
||||||
|
|
||||||
# 询问历史上下文设置
|
# 询问历史上下文设置
|
||||||
history_input = input("使用历史对话轮数 (默认3): ").strip()
|
history_input = input("使用历史对话轮数 (默认10): ").strip()
|
||||||
history_count = int(history_input) if history_input.isdigit() else 3
|
history_count = int(history_input) if history_input.isdigit() else 10
|
||||||
|
|
||||||
context_input = input("使用上下文信息数量 (默认10): ").strip()
|
context_input = input("使用上下文信息数量 (默认5): ").strip()
|
||||||
context_info_count = int(context_input) if context_input.isdigit() else 10
|
context_info_count = int(context_input) if context_input.isdigit() else 5
|
||||||
|
|
||||||
print(f"\n开始对话 - 主题: {user_input}")
|
print(f"\n开始对话 - 主题: {user_input}")
|
||||||
print(f"轮数: {turns}, 历史: {history_count}, 上下文: {context_info_count}")
|
print(f"轮数: {turns}, 历史: {history_count}, 上下文: {context_info_count}")
|
||||||
|
|||||||
@ -193,6 +193,16 @@ class NPCDialogueGenerator:
|
|||||||
|
|
||||||
# 移动到设备
|
# 移动到设备
|
||||||
inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
|
inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
|
||||||
|
# 计算input token数并与模型最大token数比较
|
||||||
|
input_token_count = inputs['input_ids'].shape[1]
|
||||||
|
try:
|
||||||
|
max_model_tokens = self.model.config.max_position_embeddings
|
||||||
|
except AttributeError:
|
||||||
|
max_model_tokens = 2048
|
||||||
|
|
||||||
|
if input_token_count + max_new_tokens > max_model_tokens:
|
||||||
|
print(f"警告:当前输入token数({input_token_count})加上最大生成token数({max_new_tokens})超过模型最大token数({max_model_tokens}),可能导致生成结果不完整或报错。")
|
||||||
|
|
||||||
|
|
||||||
# 生成对话
|
# 生成对话
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user