Update training parameters
parent 1fd8319715
commit de05061bc4
@@ -98,28 +98,32 @@ def create_lora_config():
     return config

 def prepare_dataset(data_path, tokenizer):
-    """Prepare the dataset"""
+    """Prepare the dataset (with added robustness)"""
     print(f"Loading dataset from: {data_path}")

     # Load the JSON data
     with open(data_path, 'r', encoding='utf-8') as f:
         data = json.load(f)

-    print(f"Total samples: {len(data)}")
+    print(f"Total samples before filtering: {len(data)}")

     # Convert to Dataset format
     dataset = Dataset.from_list(data)

     # Apply the preprocessing function
-    def tokenize_function(examples):
-        return process_func(examples, tokenizer)

     tokenized_dataset = dataset.map(
-        tokenize_function,
+        lambda example: process_func(example, tokenizer),
         remove_columns=dataset.column_names,
-        batched=False
+        batched=False  # process_func expects single examples
     )

+    # Key step: filter out samples that become empty after preprocessing
+    original_size = len(tokenized_dataset)
+    tokenized_dataset = tokenized_dataset.filter(lambda example: len(example.get("input_ids", [])) > 0)
+    filtered_size = len(tokenized_dataset)
+
+    print(f"Total samples after filtering: {filtered_size} ({original_size - filtered_size} samples removed)")

     return tokenized_dataset

 def train_lora_model(model_path, data_path, output_dir):
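The new filter step is the substance of the "robustness" change: any sample whose `input_ids` comes back empty from `process_func` is dropped before it can produce a zero-length example during training. A minimal, self-contained sketch of the same map-then-filter pattern, with a hypothetical stand-in for `process_func` (the real one lives outside this hunk):

```python
from datasets import Dataset

def process_func(example, tokenizer=None):
    # Hypothetical stand-in for the script's real preprocessor: it returns
    # empty input_ids when there is nothing to tokenize.
    text = example.get("text", "")
    return {"input_ids": [ord(c) for c in text]}  # fake token ids, one per character

data = [{"text": "hello"}, {"text": ""}, {"text": "world"}]
dataset = Dataset.from_list(data)

tokenized = dataset.map(
    lambda example: process_func(example),
    remove_columns=dataset.column_names,
    batched=False,
)
# Drop samples whose preprocessing produced no tokens, as the commit does.
tokenized = tokenized.filter(lambda example: len(example.get("input_ids", [])) > 0)
print(len(tokenized))  # 2: the empty sample was removed
```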
@@ -152,17 +156,17 @@ def train_lora_model(model_path, data_path, output_dir):
         logging_steps=10,
         num_train_epochs=3,  # more epochs so the model fully learns the character traits
         save_steps=50,
-        learning_rate=5e-5,  # slightly higher learning rate
+        learning_rate=2e-5,  # lower learning rate for more stable training
         warmup_ratio=0.1,
-        max_grad_norm=1.0,
+        max_grad_norm=1.0,  # keep gradient clipping
         save_on_each_node=True,
         gradient_checkpointing=True,
         gradient_checkpointing_kwargs={"use_reentrant": True},
         dataloader_pin_memory=False,  # reduce memory usage
         remove_unused_columns=False,
         report_to="none",
-        #bf16=True,
-        #fp16=True,  # use mixed-precision training
+        bf16=True,  # explicitly enable bf16 to match the dtype the model is loaded in
+        #fp16=False,  # make sure fp16 stays disabled
         save_total_limit=3,  # keep only the 3 most recent checkpoints
     )

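The precision change is the other substantive edit: `bf16=True` is now set explicitly rather than left commented out, so the trainer's compute dtype matches the dtype the model is loaded in. A sketch of how the two sides are meant to line up, assuming the script loads the base model in bf16 as the new comment implies (the loading code is not part of this hunk, and running this requires a bf16-capable GPU):

```python
import torch
from transformers import AutoModelForCausalLM, TrainingArguments

# Load the base model in bf16 (assumed here; the actual loading code is outside this diff).
model = AutoModelForCausalLM.from_pretrained(
    "/mnt/g/Project02/AITrain/Qwen/Qwen3-8B-AWQ",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

args = TrainingArguments(
    output_dir="./output/NPC_Dialogue_LoRA",
    learning_rate=2e-5,
    num_train_epochs=3,
    bf16=True,  # trainer compute dtype now matches the bf16 weights
    # fp16 stays at its default of False; enabling both bf16 and fp16 is an error
)
```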
@@ -271,7 +275,7 @@ def test_trained_model(model_path, lora_path):

 def main():
     # Configure paths
-    model_path = '/mnt/e/AI/Project02/AITrain/Qwen/Qwen3-8B-AWQ'  # base model path
+    model_path = '/mnt/g/Project02/AITrain/Qwen/Qwen3-8B-AWQ'  # base model path
     data_path = './npc_dialogue_dataset.json'  # training data path
     output_dir = './output/NPC_Dialogue_LoRA'  # output directory

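Since the base model directory moved from /mnt/e to /mnt/g, a stale path would otherwise only surface deep inside model loading. A small sketch of a fail-fast existence check (this check is not part of the commit, just an illustration):

```python
import os

# Paths as configured in main(); the check itself is hypothetical.
model_path = '/mnt/g/Project02/AITrain/Qwen/Qwen3-8B-AWQ'
data_path = './npc_dialogue_dataset.json'

for p in (model_path, data_path):
    if not os.path.exists(p):
        raise FileNotFoundError(f"Configured path does not exist: {p}")
```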