添加NaN监测
This commit is contained in:
parent
de05061bc4
commit
09631041e0
@ -8,13 +8,15 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
|
import numpy as np
|
||||||
from peft import LoraConfig, PeftModel, TaskType, get_peft_model
|
from peft import LoraConfig, PeftModel, TaskType, get_peft_model
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
from transformers import TrainingArguments, Trainer, DataCollatorForSeq2Seq
|
from transformers import TrainingArguments, Trainer, DataCollatorForSeq2Seq, TrainerCallback
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
import platform
|
import platform
|
||||||
import swanlab
|
import swanlab
|
||||||
from swanlab.integration.transformers import SwanLabCallback
|
from swanlab.integration.transformers import SwanLabCallback
|
||||||
|
import logging
|
||||||
|
|
||||||
# Windows multiprocessing兼容性修复
|
# Windows multiprocessing兼容性修复
|
||||||
if platform.system() == "Windows":
|
if platform.system() == "Windows":
|
||||||
@ -25,33 +27,142 @@ os.environ['VLLM_USE_MODELSCOPE'] = 'True'
|
|||||||
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
|
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
|
||||||
os.environ["TORCH_USE_CUDA_DSA"] = "1"
|
os.environ["TORCH_USE_CUDA_DSA"] = "1"
|
||||||
|
|
||||||
|
# Configure logging: INFO and above is written both to a debug file and to the console.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        # Explicit UTF-8 so that non-ASCII text in log messages is written
        # correctly regardless of the platform's default encoding.
        logging.FileHandler('training_debug.log', encoding='utf-8'),
        logging.StreamHandler(),
    ],
)
# Module-level logger used throughout this script.
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class GradientMonitorCallback(TrainerCallback):
    """Trainer callback that logs NaN/Inf values and suspiciously large magnitudes.

    Checks performed (the callback only logs; it never changes training control):
      * ``on_step_begin``: trainable parameters (NaN, Inf, very large values).
      * ``on_pre_optimizer_step``: gradients (NaN, Inf, per-parameter and total norm).
      * ``on_step_end``: falls back to the gradient check only if it has not
        already run for this step.
      * ``on_log``: the reported training loss (NaN or very large).

    NOTE(review): the Hugging Face ``Trainer`` is expected to call
    ``model.zero_grad()`` before firing ``on_step_end``, so gradients are
    normally gone by then; that is why the gradient check lives in
    ``on_pre_optimizer_step``. That event presumably fires after gradient
    clipping, so the norms seen here may already be clipped - confirm against
    the installed transformers version.

    Args:
        param_abs_threshold: warn when any parameter's absolute value exceeds this.
        grad_norm_threshold: warn when a single parameter's gradient L2 norm exceeds this.
        total_norm_threshold: log an error when the total gradient L2 norm exceeds this.
        loss_threshold: log an error when the reported loss exceeds this.

    Attributes:
        step_count: number of ``on_step_begin`` events seen so far.
        last_grad_norm: total gradient L2 norm from the most recent gradient
            check, or ``None`` if no gradients have been inspected yet.
        last_loss: most recent loss value seen in ``on_log``, or ``None``.
    """

    def __init__(self, param_abs_threshold=1e6, grad_norm_threshold=10.0,
                 total_norm_threshold=100.0, loss_threshold=1e6):
        super().__init__()
        self.step_count = 0
        self.param_abs_threshold = param_abs_threshold
        self.grad_norm_threshold = grad_norm_threshold
        self.total_norm_threshold = total_norm_threshold
        self.loss_threshold = loss_threshold
        self.last_grad_norm = None
        self.last_loss = None
        # True once gradients were inspected for the current step, so that
        # on_step_end does not repeat (or overwrite) the report.
        self._grads_checked = False

    def on_step_begin(self, args, state, control, model=None, **kwargs):
        """Check trainable parameters for NaN/Inf/huge values before a training step."""
        self.step_count += 1
        logger.info(f"\n=== Step {self.step_count} Begin ===")

        if model is None:
            return

        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            values = param.detach()
            if values.numel() == 0:
                # min()/max() are undefined on empty tensors.
                continue

            if torch.isnan(values).any():
                logger.error(f"NaN detected in parameter: {name}")
            if torch.isinf(values).any():
                logger.error(f"Inf detected in parameter: {name}")

            # One cheap reduction per parameter; the full statistics are only
            # computed when the parameter actually looks suspicious.
            if values.abs().max().item() > self.param_abs_threshold:
                as_float = values.float()
                param_stats = {
                    'min': as_float.min().item(),
                    'max': as_float.max().item(),
                    'mean': as_float.mean().item(),
                    # std() of a single element is NaN; report 0.0 instead.
                    'std': as_float.std().item() if as_float.numel() > 1 else 0.0,
                }
                logger.warning(f"Large parameter values in {name}: {param_stats}")

    def on_pre_optimizer_step(self, args, state, control, model=None, **kwargs):
        """Inspect gradients while they are still populated (before the optimizer step)."""
        if model is not None:
            self._grads_checked = self._check_gradients(model)

    def on_step_end(self, args, state, control, model=None, **kwargs):
        """Log the end of a step; inspect gradients only if not already done."""
        logger.info(f"=== Step {self.step_count} End ===")

        if model is not None and not self._grads_checked:
            self._check_gradients(model)
        # Re-arm the check for the next step.
        self._grads_checked = False

    def on_log(self, args, state, control, model=None, logs=None, **kwargs):
        """Check the most recently reported loss for NaN or explosion.

        Uses the ``logs`` dict passed with the event when available, otherwise
        the last entry of ``state.log_history``. Both the ``'loss'`` and the
        ``'train_loss'`` keys are recognised.
        """
        entry = logs
        if not entry:
            history = getattr(state, 'log_history', None)
            entry = history[-1] if history else None
        if not entry:
            return

        loss = entry.get('loss', entry.get('train_loss'))
        if not isinstance(loss, (int, float)):
            # Entry without a numeric loss (e.g. evaluation metrics only).
            return

        self.last_loss = float(loss)
        logger.info(f"Current loss: {loss:.6f}")

        if np.isnan(loss):
            logger.error("Loss is NaN!")
        elif loss > self.loss_threshold:
            logger.error(f"Loss explosion detected: {loss}")

    def _check_gradients(self, model):
        """Log NaN/Inf/large gradients and the total norm.

        Returns True if at least one trainable parameter had a gradient,
        False otherwise (in which case nothing is reported).
        """
        squared_sum = 0.0
        nan_grads = []
        inf_grads = []
        large_grads = []
        found_any = False

        for name, param in model.named_parameters():
            if not param.requires_grad or param.grad is None:
                continue
            found_any = True
            grad = param.grad.detach()

            if torch.isnan(grad).any():
                nan_grads.append(name)
                logger.error(f"NaN gradient detected in: {name}")

            if torch.isinf(grad).any():
                inf_grads.append(name)
                logger.error(f"Inf gradient detected in: {name}")

            # Compute the norm in float32 so half-precision gradients do not
            # overflow during the reduction.
            grad_norm = grad.float().norm(2).item()
            # Multiplication instead of ``** 2``: float pow raises
            # OverflowError on huge values, multiplication yields inf.
            squared_sum += grad_norm * grad_norm

            if grad_norm > self.grad_norm_threshold:
                large_grads.append((name, grad_norm))
                logger.warning(f"Large gradient in {name}: {grad_norm:.6f}")

        if not found_any:
            return False

        total_norm = squared_sum ** 0.5
        self.last_grad_norm = total_norm
        logger.info(f"Total gradient norm: {total_norm:.6f}")

        # Summary report.
        if nan_grads:
            logger.error(f"Parameters with NaN gradients: {nan_grads}")
        if inf_grads:
            logger.error(f"Parameters with Inf gradients: {inf_grads}")
        if large_grads:
            logger.warning(f"Parameters with large gradients: {large_grads}")

        if total_norm > self.total_norm_threshold:
            logger.error(f"Gradient explosion detected! Total norm: {total_norm}")

        return True
||||||
|
|
||||||
|
|
||||||
def process_func(example, tokenizer):
|
def process_func(example, tokenizer):
|
||||||
"""数据预处理函数"""
|
"""数据预处理函数(增加数值稳定性检查)"""
|
||||||
MAX_LENGTH = 1024
|
MAX_LENGTH = 1024
|
||||||
|
|
||||||
|
try:
|
||||||
# 构建对话模板 - 专门针对角色对话
|
# 构建对话模板 - 专门针对角色对话
|
||||||
system_prompt = f"你是一个游戏中的NPC角色。{example['character']}"
|
system_prompt = f"你是一个游戏中的NPC角色。{example['character']}"
|
||||||
instruction = example['instruction']
|
instruction = example['instruction']
|
||||||
user_input = example['input']
|
user_input = example['input']
|
||||||
|
|
||||||
|
# 验证输入数据
|
||||||
|
if not isinstance(system_prompt, str) or not isinstance(instruction, str):
|
||||||
|
logger.warning(f"Invalid input data types: system_prompt={type(system_prompt)}, instruction={type(instruction)}")
|
||||||
|
return None
|
||||||
|
|
||||||
# 定义输入部分
|
# 定义输入部分
|
||||||
instruction = tokenizer(
|
instruction_tokens = tokenizer(
|
||||||
f"<s><|im_start|>system\n{system_prompt}<|im_end|>\n"
|
f"<s><|im_start|>system\n{system_prompt}<|im_end|>\n"
|
||||||
f"<|im_start|>user\n{instruction + user_input}<|im_end|>\n"
|
f"<|im_start|>user\n{instruction + user_input}<|im_end|>\n"
|
||||||
f"<|im_start|>assistant\n<think>\n\n</think>\n\n",
|
f"<|im_start|>assistant\n",
|
||||||
add_special_tokens=False
|
add_special_tokens=False
|
||||||
)
|
)
|
||||||
|
|
||||||
# 定义输出部分
|
# 定义输出部分
|
||||||
response = tokenizer(f"{example['output']}", add_special_tokens=False)
|
response_tokens = tokenizer(f"{example['output']}", add_special_tokens=False)
|
||||||
|
|
||||||
|
# 验证tokenization结果
|
||||||
|
if not instruction_tokens["input_ids"] or not response_tokens["input_ids"]:
|
||||||
|
logger.warning("Empty tokenization result")
|
||||||
|
return None
|
||||||
|
|
||||||
# 合并输入输出
|
# 合并输入输出
|
||||||
input_ids = instruction["input_ids"] + response["input_ids"] + [tokenizer.pad_token_id]
|
input_ids = instruction_tokens["input_ids"] + response_tokens["input_ids"] + [tokenizer.pad_token_id]
|
||||||
attention_mask = instruction["attention_mask"] + response["attention_mask"] + [1]
|
attention_mask = instruction_tokens["attention_mask"] + response_tokens["attention_mask"] + [1]
|
||||||
|
|
||||||
# 标签:只对输出部分计算损失
|
# 标签:只对输出部分计算损失
|
||||||
labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] + [tokenizer.pad_token_id]
|
labels = [-100] * len(instruction_tokens["input_ids"]) + response_tokens["input_ids"] + [tokenizer.pad_token_id]
|
||||||
|
|
||||||
# 截断处理
|
# 截断处理
|
||||||
if len(input_ids) > MAX_LENGTH:
|
if len(input_ids) > MAX_LENGTH:
|
||||||
@ -59,12 +170,21 @@ def process_func(example, tokenizer):
|
|||||||
attention_mask = attention_mask[:MAX_LENGTH]
|
attention_mask = attention_mask[:MAX_LENGTH]
|
||||||
labels = labels[:MAX_LENGTH]
|
labels = labels[:MAX_LENGTH]
|
||||||
|
|
||||||
|
# 最终验证
|
||||||
|
if len(input_ids) != len(attention_mask) or len(input_ids) != len(labels):
|
||||||
|
logger.error(f"Length mismatch: input_ids={len(input_ids)}, attention_mask={len(attention_mask)}, labels={len(labels)}")
|
||||||
|
return None
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"labels": labels
|
"labels": labels
|
||||||
}
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in process_func: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
def load_model_and_tokenizer(model_path):
|
def load_model_and_tokenizer(model_path):
|
||||||
"""加载模型和分词器"""
|
"""加载模型和分词器"""
|
||||||
print(f"Loading model from: {model_path}")
|
print(f"Loading model from: {model_path}")
|
||||||
@ -91,21 +211,23 @@ def create_lora_config():
|
|||||||
target_modules=["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
|
target_modules=["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
|
||||||
inference_mode=False,
|
inference_mode=False,
|
||||||
r=8, # 增加rank以提高表达能力
|
r=8, # 增加rank以提高表达能力
|
||||||
lora_alpha=16, # alpha = 2 * r
|
lora_alpha=8, # alpha = 2 * r
|
||||||
lora_dropout=0.1,
|
lora_dropout=0.1,
|
||||||
modules_to_save=["lm_head", "embed_tokens"]
|
modules_to_save=["lm_head", "embed_tokens"]
|
||||||
)
|
)
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def prepare_dataset(data_path, tokenizer):
|
def prepare_dataset(data_path, tokenizer):
|
||||||
"""准备数据集(增强健壮性)"""
|
"""准备数据集(增强健壮性和数值稳定性检查)"""
|
||||||
print(f"Loading dataset from: {data_path}")
|
print(f"Loading dataset from: {data_path}")
|
||||||
|
logger.info(f"Loading dataset from: {data_path}")
|
||||||
|
|
||||||
# 加载JSON数据
|
# 加载JSON数据
|
||||||
with open(data_path, 'r', encoding='utf-8') as f:
|
with open(data_path, 'r', encoding='utf-8') as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
|
|
||||||
print(f"Total samples before filtering: {len(data)}")
|
print(f"Total samples before filtering: {len(data)}")
|
||||||
|
logger.info(f"Total samples before filtering: {len(data)}")
|
||||||
|
|
||||||
# 转换为Dataset格式
|
# 转换为Dataset格式
|
||||||
dataset = Dataset.from_list(data)
|
dataset = Dataset.from_list(data)
|
||||||
@ -114,15 +236,35 @@ def prepare_dataset(data_path, tokenizer):
|
|||||||
tokenized_dataset = dataset.map(
|
tokenized_dataset = dataset.map(
|
||||||
lambda example: process_func(example, tokenizer),
|
lambda example: process_func(example, tokenizer),
|
||||||
remove_columns=dataset.column_names,
|
remove_columns=dataset.column_names,
|
||||||
batched=False # process_func expects single examples
|
batched=False,
|
||||||
|
desc="Tokenizing dataset"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 关键步骤:过滤掉预处理后变为空的样本
|
# 过滤掉预处理后变为空的样本或包含None的样本
|
||||||
original_size = len(tokenized_dataset)
|
original_size = len(tokenized_dataset)
|
||||||
tokenized_dataset = tokenized_dataset.filter(lambda example: len(example.get("input_ids", [])) > 0)
|
tokenized_dataset = tokenized_dataset.filter(
|
||||||
|
lambda example: example is not None and
|
||||||
|
len(example.get("input_ids", [])) > 0 and
|
||||||
|
len(example.get("labels", [])) > 0
|
||||||
|
)
|
||||||
filtered_size = len(tokenized_dataset)
|
filtered_size = len(tokenized_dataset)
|
||||||
|
|
||||||
print(f"Total samples after filtering: {filtered_size} ({original_size - filtered_size} samples removed)")
|
print(f"Total samples after filtering: {filtered_size} ({original_size - filtered_size} samples removed)")
|
||||||
|
logger.info(f"Total samples after filtering: {filtered_size} ({original_size - filtered_size} samples removed)")
|
||||||
|
|
||||||
|
# 数据质量检查
|
||||||
|
if filtered_size == 0:
|
||||||
|
logger.error("No valid samples remaining after filtering!")
|
||||||
|
raise ValueError("Dataset is empty after preprocessing")
|
||||||
|
|
||||||
|
# 检查几个样本的数据质量
|
||||||
|
for i in range(min(3, filtered_size)):
|
||||||
|
sample = tokenized_dataset[i]
|
||||||
|
logger.info(f"Sample {i} - input_ids length: {len(sample['input_ids'])}, labels length: {len(sample['labels'])}")
|
||||||
|
|
||||||
|
# 检查是否包含异常token
|
||||||
|
if any(token_id < 0 or token_id > tokenizer.vocab_size for token_id in sample['input_ids'] if token_id != -100):
|
||||||
|
logger.warning(f"Sample {i} contains out-of-vocabulary tokens")
|
||||||
|
|
||||||
return tokenized_dataset
|
return tokenized_dataset
|
||||||
|
|
||||||
@ -156,7 +298,7 @@ def train_lora_model(model_path, data_path, output_dir):
|
|||||||
logging_steps=10,
|
logging_steps=10,
|
||||||
num_train_epochs=3, # 增加训练轮数以充分学习角色特征
|
num_train_epochs=3, # 增加训练轮数以充分学习角色特征
|
||||||
save_steps=50,
|
save_steps=50,
|
||||||
learning_rate=2e-5, # 降低学习率以增加稳定性
|
learning_rate=1e-5, # 降低学习率以增加稳定性
|
||||||
warmup_ratio=0.1,
|
warmup_ratio=0.1,
|
||||||
max_grad_norm=1.0, # 保持梯度裁剪
|
max_grad_norm=1.0, # 保持梯度裁剪
|
||||||
save_on_each_node=True,
|
save_on_each_node=True,
|
||||||
@ -176,18 +318,31 @@ def train_lora_model(model_path, data_path, output_dir):
|
|||||||
experiment_name="Qwen3-8B-LoRA-experiment"
|
experiment_name="Qwen3-8B-LoRA-experiment"
|
||||||
)
|
)
|
||||||
swanlab.login(api_key="pAxFTROvv3aspmEijax46")
|
swanlab.login(api_key="pAxFTROvv3aspmEijax46")
|
||||||
|
|
||||||
|
# 创建梯度监控回调
|
||||||
|
gradient_monitor = GradientMonitorCallback()
|
||||||
|
|
||||||
# 7. 创建训练器
|
# 7. 创建训练器
|
||||||
trainer = Trainer(
|
trainer = Trainer(
|
||||||
model=model,
|
model=model,
|
||||||
args=training_args,
|
args=training_args,
|
||||||
train_dataset=train_preparedataset,
|
train_dataset=train_preparedataset,
|
||||||
data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
|
data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
|
||||||
callbacks=[swanlab_callback] # 传入之前的swanlab_callback
|
callbacks=[swanlab_callback, gradient_monitor] # 添加梯度监控回调
|
||||||
)
|
)
|
||||||
|
|
||||||
# 8. 开始训练
|
# 8. 开始训练
|
||||||
print("Starting training...")
|
print("Starting training...")
|
||||||
|
logger.info("Starting training...")
|
||||||
|
|
||||||
|
try:
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
logger.info("Training completed successfully!")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Training failed with error: {e}")
|
||||||
|
import traceback
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
raise
|
||||||
|
|
||||||
# 9. 保存最终模型
|
# 9. 保存最终模型
|
||||||
final_output_dir = os.path.join(output_dir, "final_model")
|
final_output_dir = os.path.join(output_dir, "final_model")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user