Update training parameters

997146918 2025-08-11 11:21:03 +08:00
parent 1fd8319715
commit de05061bc4


@@ -98,28 +98,32 @@ def create_lora_config():
     return config

 def prepare_dataset(data_path, tokenizer):
-    """Prepare the dataset"""
+    """Prepare the dataset (with added robustness)"""
     print(f"Loading dataset from: {data_path}")

     # Load the JSON data
     with open(data_path, 'r', encoding='utf-8') as f:
         data = json.load(f)
-    print(f"Total samples: {len(data)}")
+    print(f"Total samples before filtering: {len(data)}")

     # Convert to a Dataset
     dataset = Dataset.from_list(data)

     # Apply the preprocessing function
-    def tokenize_function(examples):
-        return process_func(examples, tokenizer)
     tokenized_dataset = dataset.map(
-        tokenize_function,
+        lambda example: process_func(example, tokenizer),
         remove_columns=dataset.column_names,
-        batched=False
+        batched=False  # process_func expects single examples
     )
+
+    # Key step: filter out samples that became empty after preprocessing
+    original_size = len(tokenized_dataset)
+    tokenized_dataset = tokenized_dataset.filter(lambda example: len(example.get("input_ids", [])) > 0)
+    filtered_size = len(tokenized_dataset)
+    print(f"Total samples after filtering: {filtered_size} ({original_size - filtered_size} samples removed)")

     return tokenized_dataset

 def train_lora_model(model_path, data_path, output_dir):
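The new map-then-filter step only pays off if process_func can return empty input_ids for samples it cannot tokenize. A minimal sketch of that pattern, with a hypothetical stand-in for process_func (the real one is defined elsewhere in this file and is not shown in the diff):

# Sketch of the map-then-filter pattern above. This process_func is a
# hypothetical stand-in; the real one lives elsewhere in this script.
from datasets import Dataset

def process_func(example, tokenizer=None):
    text = example.get("instruction", "")
    if not text:  # malformed sample -> empty output, filtered out below
        return {"input_ids": [], "attention_mask": [], "labels": []}
    ids = [ord(c) for c in text]  # fake tokenization, just for the sketch
    return {"input_ids": ids, "attention_mask": [1] * len(ids), "labels": ids}

data = [{"instruction": "hello"}, {"instruction": ""}]
ds = Dataset.from_list(data)
tok = ds.map(lambda ex: process_func(ex), remove_columns=ds.column_names, batched=False)
tok = tok.filter(lambda ex: len(ex.get("input_ids", [])) > 0)
print(len(tok))  # 1 -- the empty sample was dropped instead of breaking the collator

Dropping empty samples here, rather than letting them reach the data collator, turns a mid-training crash into a logged count of removed samples.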
@@ -152,17 +156,17 @@ def train_lora_model(model_path, data_path, output_dir):
         logging_steps=10,
         num_train_epochs=3,  # more epochs so the role persona is fully learned
         save_steps=50,
-        learning_rate=5e-5,  # slightly higher learning rate
+        learning_rate=2e-5,  # lower learning rate for more stability
         warmup_ratio=0.1,
-        max_grad_norm=1.0,
+        max_grad_norm=1.0,  # keep gradient clipping
         save_on_each_node=True,
         gradient_checkpointing=True,
         gradient_checkpointing_kwargs={"use_reentrant": True},
         dataloader_pin_memory=False,  # reduce memory usage
         remove_unused_columns=False,
         report_to="none",
-        #bf16=True,
-        #fp16=True,  # use mixed-precision training
+        bf16=True,  # explicitly enable bf16 to match the model load dtype
+        #fp16=False,  # make sure fp16 stays disabled
         save_total_limit=3,  # keep only the 3 most recent checkpoints
     )
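The bf16=True change is about keeping the autocast dtype consistent with the dtype the weights were loaded in. A sketch of that pairing; the torch.bfloat16 load and the placeholder model path are assumptions, since the loading code is outside this diff:

# Sketch: keep the TrainingArguments compute dtype consistent with the
# load dtype. The bfloat16 load here is an assumption about this script.
import torch
from transformers import AutoModelForCausalLM, TrainingArguments

model = AutoModelForCausalLM.from_pretrained(
    "path/to/base-model",        # placeholder; the script uses model_path
    torch_dtype=torch.bfloat16,  # assumed load dtype
    device_map="auto",
)
args = TrainingArguments(
    output_dir="./output/NPC_Dialogue_LoRA",
    bf16=True,   # autocast in bfloat16, matching the weights above
    fp16=False,  # fp16 autocast over bf16 weights is a common source of NaNs
)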
@@ -271,7 +275,7 @@ def test_trained_model(model_path, lora_path):
 def main():
     # Configure paths
-    model_path = '/mnt/e/AI/Project02/AITrain/Qwen/Qwen3-8B-AWQ'  # base model path
+    model_path = '/mnt/g/Project02/AITrain/Qwen/Qwen3-8B-AWQ'  # base model path
     data_path = './npc_dialogue_dataset.json'  # training data path
     output_dir = './output/NPC_Dialogue_LoRA'  # output directory
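For reference, a hedged sketch of how main() presumably wires these paths together, based only on the function names visible in the hunks above; everything past the path setup is outside this diff:

def main():
    model_path = '/mnt/g/Project02/AITrain/Qwen/Qwen3-8B-AWQ'  # base model path
    data_path = './npc_dialogue_dataset.json'                  # training data path
    output_dir = './output/NPC_Dialogue_LoRA'                  # output directory

    # Presumed flow: train, then smoke-test the resulting adapter
    train_lora_model(model_path, data_path, output_dir)

if __name__ == "__main__":
    main()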