diff --git a/AITrain/train_npc_dialogue_lora.py b/AITrain/train_npc_dialogue_lora.py
index a534856..e1ed680 100644
--- a/AITrain/train_npc_dialogue_lora.py
+++ b/AITrain/train_npc_dialogue_lora.py
@@ -98,28 +98,32 @@ def create_lora_config():
     return config
 
 def prepare_dataset(data_path, tokenizer):
-    """Prepare the dataset"""
+    """Prepare the dataset (with added robustness)"""
     print(f"Loading dataset from: {data_path}")
 
     # Load the JSON data
     with open(data_path, 'r', encoding='utf-8') as f:
         data = json.load(f)
 
-    print(f"Total samples: {len(data)}")
+    print(f"Total samples before filtering: {len(data)}")
 
     # Convert to the Dataset format
     dataset = Dataset.from_list(data)
 
     # Apply the preprocessing function
-    def tokenize_function(examples):
-        return process_func(examples, tokenizer)
-
     tokenized_dataset = dataset.map(
-        tokenize_function,
+        lambda example: process_func(example, tokenizer),
         remove_columns=dataset.column_names,
-        batched=False
+        batched=False  # process_func expects single examples
     )
 
+    # Key step: filter out samples that end up empty after preprocessing
+    original_size = len(tokenized_dataset)
+    tokenized_dataset = tokenized_dataset.filter(lambda example: len(example.get("input_ids", [])) > 0)
+    filtered_size = len(tokenized_dataset)
+
+    print(f"Total samples after filtering: {filtered_size} ({original_size - filtered_size} samples removed)")
+
     return tokenized_dataset
 
 def train_lora_model(model_path, data_path, output_dir):
@@ -152,17 +156,17 @@ def train_lora_model(model_path, data_path, output_dir):
         logging_steps=10,
         num_train_epochs=3,  # more epochs so the model fully learns the character traits
         save_steps=50,
-        learning_rate=5e-5,  # slightly higher learning rate
+        learning_rate=2e-5,  # lower learning rate for more stable training
         warmup_ratio=0.1,
-        max_grad_norm=1.0,
+        max_grad_norm=1.0,  # keep gradient clipping
         save_on_each_node=True,
         gradient_checkpointing=True,
         gradient_checkpointing_kwargs={"use_reentrant": True},
         dataloader_pin_memory=False,  # reduce memory usage
         remove_unused_columns=False,
         report_to="none",
-        #bf16=True,
-        #fp16=True,  # use mixed-precision training
+        bf16=True,  # explicitly enable bf16 to match the dtype the model is loaded in
+        #fp16=False,  # make sure fp16 stays disabled
         save_total_limit=3,  # keep only the 3 most recent checkpoints
     )
 
@@ -271,7 +275,7 @@ def test_trained_model(model_path, lora_path):
 
 def main():
     # Configure paths
-    model_path = '/mnt/e/AI/Project02/AITrain/Qwen/Qwen3-8B-AWQ'  # base model path
+    model_path = '/mnt/g/Project02/AITrain/Qwen/Qwen3-8B-AWQ'  # base model path
     data_path = './npc_dialogue_dataset.json'  # training data path
     output_dir = './output/NPC_Dialogue_LoRA'  # output directory
 
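
Reviewer note: the new map-then-filter step assumes process_func signals an unusable sample by returning empty fields (e.g. an empty input_ids list) rather than raising. Below is a minimal, self-contained sketch of that pattern; toy_process_func, MAX_LENGTH, and the sample data are hypothetical stand-ins for illustration, not code from this repo.

from datasets import Dataset

MAX_LENGTH = 32  # hypothetical cutoff, just for the sketch

def toy_process_func(example):
    # Stand-in tokenizer: one fake token id per character.
    ids = [ord(c) for c in example["text"]]
    if len(ids) > MAX_LENGTH:
        # Signal an unusable sample with empty fields instead of raising,
        # so the later .filter() pass can drop it.
        return {"input_ids": [], "attention_mask": [], "labels": []}
    return {"input_ids": ids, "attention_mask": [1] * len(ids), "labels": ids}

data = [{"text": "short sample"}, {"text": "x" * 100}]  # second sample exceeds the cutoff
dataset = Dataset.from_list(data)

tokenized = dataset.map(
    toy_process_func,
    remove_columns=dataset.column_names,
    batched=False,  # mirrors the patch: one example at a time
)

original_size = len(tokenized)
tokenized = tokenized.filter(lambda example: len(example.get("input_ids", [])) > 0)
print(f"Samples kept after filtering: {len(tokenized)} of {original_size}")  # -> 1 of 2

On the bf16 change: since the flag is now hard-coded to True, it may be worth guarding it with torch.cuda.is_bf16_supported() so the script fails loudly (or falls back to fp32) on pre-Ampere GPUs that lack bf16 support.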