更新训练参数
This commit is contained in:
parent
1fd8319715
commit
de05061bc4
@ -98,28 +98,32 @@ def create_lora_config():
|
||||
return config
|
||||
|
||||
def prepare_dataset(data_path, tokenizer):
    """Load the JSON dataset, tokenize every sample, and drop empty results.

    Args:
        data_path: Path to a UTF-8 encoded JSON file whose top-level value
            is a list of samples.
        tokenizer: Tokenizer forwarded unchanged to ``process_func``.

    Returns:
        The tokenized dataset, with every sample whose ``input_ids`` is
        missing, ``None``, or empty filtered out.

    Raises:
        ValueError: If the JSON top-level value is not a list.
    """
    print(f"Loading dataset from: {data_path}")

    # Load the raw JSON data.
    with open(data_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # Fail early with a clear message instead of an opaque error from
    # Dataset.from_list when the file has an unexpected top-level shape.
    if not isinstance(data, list):
        raise ValueError(
            f"Expected a JSON list of samples in {data_path}, "
            f"got {type(data).__name__}"
        )

    print(f"Total samples before filtering: {len(data)}")

    # Convert to the Dataset format.
    dataset = Dataset.from_list(data)

    # Apply the preprocessing function one example at a time
    # (batched=False: process_func expects single examples).
    tokenized_dataset = dataset.map(
        lambda example: process_func(example, tokenizer),
        remove_columns=dataset.column_names,
        batched=False,
    )

    def _has_tokens(example):
        # Treat a missing key, an explicit None, and an empty sequence alike;
        # calling len() on None would raise TypeError.
        input_ids = example.get("input_ids")
        return input_ids is not None and len(input_ids) > 0

    # Key step: drop samples that became empty after preprocessing.
    original_size = len(tokenized_dataset)
    tokenized_dataset = tokenized_dataset.filter(_has_tokens)
    filtered_size = len(tokenized_dataset)

    print(f"Total samples after filtering: {filtered_size} ({original_size - filtered_size} samples removed)")

    return tokenized_dataset
|
||||
|
||||
def train_lora_model(model_path, data_path, output_dir):
|
||||
@ -152,17 +156,17 @@ def train_lora_model(model_path, data_path, output_dir):
|
||||
logging_steps=10,
|
||||
num_train_epochs=3, # 增加训练轮数以充分学习角色特征
|
||||
save_steps=50,
|
||||
learning_rate=5e-5, # 稍微提高学习率
|
||||
learning_rate=2e-5, # 降低学习率以增加稳定性
|
||||
warmup_ratio=0.1,
|
||||
max_grad_norm=1.0,
|
||||
max_grad_norm=1.0, # 保持梯度裁剪
|
||||
save_on_each_node=True,
|
||||
gradient_checkpointing=True,
|
||||
gradient_checkpointing_kwargs={"use_reentrant": True},
|
||||
dataloader_pin_memory=False, # 减少内存使用
|
||||
remove_unused_columns=False,
|
||||
report_to="none",
|
||||
#bf16=True,
|
||||
#fp16=True, # 使用混合精度训练
|
||||
bf16=True, # 显式启用bf16以匹配模型加载类型
|
||||
#fp16=False, # 确保fp16被禁用
|
||||
save_total_limit=3, # 只保留最新的3个检查点
|
||||
)
|
||||
|
||||
@ -271,7 +275,7 @@ def test_trained_model(model_path, lora_path):
|
||||
|
||||
def main():
|
||||
# 配置路径
|
||||
model_path = '/mnt/e/AI/Project02/AITrain/Qwen/Qwen3-8B-AWQ' # 基础模型路径
|
||||
model_path = '/mnt/g/Project02/AITrain/Qwen/Qwen3-8B-AWQ' # 基础模型路径
|
||||
data_path = './npc_dialogue_dataset.json' # 训练数据路径
|
||||
output_dir = './output/NPC_Dialogue_LoRA' # 输出目录
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user