diff --git a/AITrain/main_controller.py b/AITrain/main_controller.py index 717f6a9..ad1784f 100644 --- a/AITrain/main_controller.py +++ b/AITrain/main_controller.py @@ -221,11 +221,11 @@ def run_dialogue_system(): turns = int(turns_input) if turns_input.isdigit() else 4 # 询问历史上下文设置 - history_input = input("使用历史对话轮数 (默认3): ").strip() - history_count = int(history_input) if history_input.isdigit() else 3 + history_input = input("使用历史对话轮数 (默认10): ").strip() + history_count = int(history_input) if history_input.isdigit() else 10 - context_input = input("使用上下文信息数量 (默认10): ").strip() - context_info_count = int(context_input) if context_input.isdigit() else 10 + context_input = input("使用上下文信息数量 (默认5): ").strip() + context_info_count = int(context_input) if context_input.isdigit() else 5 print(f"\n开始对话 - 主题: {user_input}") print(f"轮数: {turns}, 历史: {history_count}, 上下文: {context_info_count}") diff --git a/AITrain/npc_dialogue_generator.py b/AITrain/npc_dialogue_generator.py index b936bf1..5ccce21 100644 --- a/AITrain/npc_dialogue_generator.py +++ b/AITrain/npc_dialogue_generator.py @@ -136,7 +136,7 @@ class NPCDialogueGenerator: if self.lora_model_path: print(f"Loading LoRA weights from: {self.lora_model_path}") self.model = PeftModel.from_pretrained(self.model, self.lora_model_path) - + def generate_character_dialogue( self, character_name: str, @@ -193,7 +193,17 @@ class NPCDialogueGenerator: # 移动到设备 inputs = {k: v.to(self.model.device) for k, v in inputs.items()} + # 计算input token数并与模型最大token数比较 + input_token_count = inputs['input_ids'].shape[1] + try: + max_model_tokens = self.model.config.max_position_embeddings + except AttributeError: + max_model_tokens = 2048 + if input_token_count + max_new_tokens > max_model_tokens: + print(f"警告:当前输入token数({input_token_count})加上最大生成token数({max_new_tokens})超过模型最大token数({max_model_tokens}),可能导致生成结果不完整或报错。") + + # 生成对话 with torch.no_grad(): outputs = self.model.generate(