diff --git a/AITrain/train_npc_dialogue_lora.py b/AITrain/train_npc_dialogue_lora.py index b46ce5b..8408e72 100644 --- a/AITrain/train_npc_dialogue_lora.py +++ b/AITrain/train_npc_dialogue_lora.py @@ -39,7 +39,7 @@ def process_func(example, tokenizer): instruction = tokenizer( f"<|im_start|>system\n{system_prompt}<|im_end|>\n" f"<|im_start|>user\n{instruction + user_input}<|im_end|>\n" - f"<|im_start|>assistant\n", + f"<|im_start|>assistant\n\n\n\n\n", add_special_tokens=False ) @@ -142,12 +142,12 @@ def train_lora_model(model_path, data_path, output_dir): model.config.use_cache = False # 关闭缓存以节省显存 # 5. 准备数据集 - train_dataset = prepare_dataset(data_path, tokenizer) + train_preparedataset = prepare_dataset(data_path, tokenizer) # 6. 配置训练参数 - 针对3080显卡优化 training_args = TrainingArguments( output_dir=output_dir, - per_device_train_batch_size=1, # 减小batch size + per_device_train_batch_size=2, # 减小batch size gradient_accumulation_steps=4, # 增加梯度累积 logging_steps=10, num_train_epochs=3, # 增加训练轮数以充分学习角色特征 @@ -161,6 +161,7 @@ def train_lora_model(model_path, data_path, output_dir): dataloader_pin_memory=False, # 减少内存使用 remove_unused_columns=False, report_to="none", + #bf16=True, #fp16=True, # 使用混合精度训练 save_total_limit=3, # 只保留最新的3个检查点 ) @@ -175,7 +176,7 @@ def train_lora_model(model_path, data_path, output_dir): trainer = Trainer( model=model, args=training_args, - train_dataset=train_dataset, + train_dataset=train_preparedataset, data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True), callbacks=[swanlab_callback] # 传入之前的swanlab_callback ) @@ -270,13 +271,13 @@ def test_trained_model(model_path, lora_path): def main(): # 配置路径 - model_path = '/mnt/g/Project02/AITrain/Qwen/Qwen3-8B-AWQ' # 基础模型路径 + model_path = '/mnt/e/AI/Project02/AITrain/Qwen/Qwen3-8B-AWQ' # 基础模型路径 data_path = './npc_dialogue_dataset.json' # 训练数据路径 output_dir = './output/NPC_Dialogue_LoRA' # 输出目录 - #####test - final_model_path = os.path.join(output_dir, "final_model") - test_trained_model(model_path, final_model_path) + # #####test + # final_model_path = os.path.join(output_dir, "final_model") + # test_trained_model(model_path, final_model_path) # 确保数据文件存在 if not os.path.exists(data_path): print(f"数据文件不存在: {data_path}")