Adjust training parameters

parent 5c27e14038
commit 66dade2f3f
@@ -39,7 +39,7 @@ def process_func(example, tokenizer):
     instruction = tokenizer(
         f"<s><|im_start|>system\n{system_prompt}<|im_end|>\n"
         f"<|im_start|>user\n{instruction + user_input}<|im_end|>\n"
-        f"<|im_start|>assistant\n",
+        f"<|im_start|>assistant\n<think>\n\n</think>\n\n",
         add_special_tokens=False
     )
 
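The change above appends an empty `<think>\n\n</think>` block to the assistant header, matching Qwen3's non-thinking reply format so the LoRA is trained to answer without emitting a reasoning trace. For context, a minimal sketch of how a `process_func` like this one typically turns a sample into training tensors; the `example` field names and the `-100` label masking are assumptions, not part of this diff:

```python
def process_func(example, tokenizer, system_prompt, max_length=512):
    # Prompt half: everything up to and including the empty think block.
    instruction = tokenizer(
        f"<s><|im_start|>system\n{system_prompt}<|im_end|>\n"
        f"<|im_start|>user\n{example['instruction'] + example['input']}<|im_end|>\n"
        f"<|im_start|>assistant\n<think>\n\n</think>\n\n",
        add_special_tokens=False
    )
    # Response half: the reply the model should learn to produce.
    response = tokenizer(example["output"] + "<|im_end|>", add_special_tokens=False)

    input_ids = instruction["input_ids"] + response["input_ids"]
    attention_mask = instruction["attention_mask"] + response["attention_mask"]
    # Mask prompt tokens with -100 so loss is computed only on the reply.
    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]
    return {
        "input_ids": input_ids[:max_length],
        "attention_mask": attention_mask[:max_length],
        "labels": labels[:max_length],
    }
```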
@@ -142,12 +142,12 @@ def train_lora_model(model_path, data_path, output_dir):
     model.config.use_cache = False  # disable the KV cache to save GPU memory
 
     # 5. Prepare the dataset
-    train_dataset = prepare_dataset(data_path, tokenizer)
+    train_preparedataset = prepare_dataset(data_path, tokenizer)
 
     # 6. Configure training arguments - tuned for an RTX 3080
     training_args = TrainingArguments(
         output_dir=output_dir,
-        per_device_train_batch_size=1,  # reduce the batch size
+        per_device_train_batch_size=2,  # reduce the batch size
         gradient_accumulation_steps=4,  # increase gradient accumulation
         logging_steps=10,
         num_train_epochs=3,  # more epochs to fully learn the character traits
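Note how the two knobs in this hunk interact: with `per_device_train_batch_size` raised from 1 to 2 and `gradient_accumulation_steps` kept at 4, the effective batch per optimizer step doubles from 4 to 8, assuming a single GPU:

```python
# Effective batch size per optimizer step; num_devices=1 is an assumption
# (one RTX 3080, per the comment in the diff).
per_device_train_batch_size = 2
gradient_accumulation_steps = 4
num_devices = 1
print(per_device_train_batch_size * gradient_accumulation_steps * num_devices)  # 8
```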
@@ -161,6 +161,7 @@ def train_lora_model(model_path, data_path, output_dir):
         dataloader_pin_memory=False,  # reduce memory usage
         remove_unused_columns=False,
         report_to="none",
+        #bf16=True,
         #fp16=True,  # use mixed-precision training
         save_total_limit=3,  # keep only the 3 most recent checkpoints
     )
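Both mixed-precision flags remain commented out; the commit only adds `bf16` as a second option. If either were to be enabled, the RTX 3080 (Ampere) does support bf16; a quick check one could run before flipping `bf16=True` in `TrainingArguments`:

```python
import torch

# Ampere GPUs such as the RTX 3080 report bf16 support; otherwise fp16
# (or full fp32) is the fallback.
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    print("bf16 supported: consider bf16=True")
else:
    print("no bf16: consider fp16=True or full fp32")
```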
@@ -175,7 +176,7 @@ def train_lora_model(model_path, data_path, output_dir):
     trainer = Trainer(
         model=model,
         args=training_args,
-        train_dataset=train_dataset,
+        train_dataset=train_preparedataset,
         data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
         callbacks=[swanlab_callback]  # pass in the swanlab_callback created earlier
     )
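`DataCollatorForSeq2Seq` pads each batch dynamically and, by default, pads `labels` with -100 so the padded positions are ignored by the loss, which matches the label masking done in `process_func`. A small standalone illustration (the tokenizer path is assumed to resolve to the model used in this repo):

```python
from transformers import AutoTokenizer, DataCollatorForSeq2Seq

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B-AWQ")  # path assumed
collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True)

# Two samples of unequal length: the collator pads both to the longer one,
# using the pad token for input_ids and -100 for labels.
features = [
    {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "labels": [-100, 2, 3]},
    {"input_ids": [1, 2], "attention_mask": [1, 1], "labels": [-100, 2]},
]
batch = collator(features)
print(batch["labels"])  # the shorter row is padded with -100
```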
@@ -270,13 +271,13 @@ def test_trained_model(model_path, lora_path):
 
 def main():
     # Configure paths
-    model_path = '/mnt/g/Project02/AITrain/Qwen/Qwen3-8B-AWQ'  # base model path
+    model_path = '/mnt/e/AI/Project02/AITrain/Qwen/Qwen3-8B-AWQ'  # base model path
     data_path = './npc_dialogue_dataset.json'  # training data path
     output_dir = './output/NPC_Dialogue_LoRA'  # output directory
 
-    #####test
-    final_model_path = os.path.join(output_dir, "final_model")
-    test_trained_model(model_path, final_model_path)
+    # #####test
+    # final_model_path = os.path.join(output_dir, "final_model")
+    # test_trained_model(model_path, final_model_path)
     # Make sure the data file exists
     if not os.path.exists(data_path):
         print(f"Data file not found: {data_path}")
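With the three test lines commented out, `main()` now proceeds straight to the data-file check and training instead of running inference first. For reference, a hedged sketch of what `test_trained_model` presumably does: load the AWQ base model, attach the LoRA adapter from `final_model`, and generate once. None of the internals below are shown in this diff:

```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

def test_trained_model(model_path, lora_path):
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    base = AutoModelForCausalLM.from_pretrained(
        model_path, device_map="auto", torch_dtype=torch.float16
    )
    model = PeftModel.from_pretrained(base, lora_path)  # attach LoRA weights
    # Same chat template as training, including the empty think block.
    prompt = "<|im_start|>user\nhello<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    out = model.generate(**inputs, max_new_tokens=64)
    # Decode only the newly generated tokens.
    print(tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))
```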