Describe the problem
Training reached step 640+ and then reported "Nan grad norm occurs, please check it". How should I adjust? The error log and training config are below.
Error log:
2024-11-24 02:35:28,531 INFO pipeline.py:770 in record_current_batch_training_metrics -- tflops=229.64161376788178 step=642 loss=2.1725409030914307 real_tgs=719.84 tgs (tokens/gpu/second)=720.46 tgs/last_tgs_1=720.46 tgs/tgs_all=716.63 tgs/tgs_avg=722.47 tgs/tgs_SMA=722.84 tgs/last_tgs_10=723.59 tgs/last_tgs_50=722.37 lr=9.808683347597691e-05 loss_scale=16384.0 grad_norm={'0_default': 2.156615516531637, '1_fp32': 0.0} micro_num=16 num_consumed_tokens=674234368 inf_nan_skip_batches=3 num_samples_in_batch=102 largest_length=7543 largest_batch=11 smallest_batch=2 adam_beta2=0.95 fwd_bwd_time=181.49 bwd_time=120.98 acc=0.5494 perplexity=9.0888 acc/zh=0.5494 tokens/zh=1047661 loss_from_metric=2.207 loss/zh=2.207
2024-11-24 02:38:29,509 INFO pipeline.py:770 in record_current_batch_training_metrics -- tflops=230.8624288772513 step=643 loss=2.046380043029785 real_tgs=723.69 tgs (tokens/gpu/second)=724.29 tgs/last_tgs_1=724.29 tgs/tgs_all=716.64 tgs/tgs_avg=722.47 tgs/tgs_SMA=722.86 tgs/last_tgs_10=723.59 tgs/last_tgs_50=722.37 lr=9.8080255311898e-05 loss_scale=16384.0 grad_norm={'0_default': 1.7590333217221075, '1_fp32': 0.0} micro_num=16 num_consumed_tokens=675282944 inf_nan_skip_batches=3 num_samples_in_batch=108 largest_length=6803 largest_batch=10 smallest_batch=2 adam_beta2=0.95 fwd_bwd_time=180.53 bwd_time=120.24 acc=0.5622 perplexity=8.3492 acc/zh=0.5622 tokens/zh=1047699 loss_from_metric=2.1222 loss/zh=2.1222
2024-11-24 02:41:32,346 INFO pipeline.py:770 in record_current_batch_training_metrics -- tflops=228.51365321663746 step=644 loss=2.3020739555358887 real_tgs=716.29 tgs (tokens/gpu/second)=716.93 tgs/last_tgs_1=716.93 tgs/tgs_all=716.64 tgs/tgs_avg=722.46 tgs/tgs_SMA=722.75 tgs/last_tgs_10=723.59 tgs/last_tgs_50=722.37 lr=9.807366610463195e-05 loss_scale=16384.0 grad_norm={'0_default': 1.1157139483556522, '1_fp32': 0.0} micro_num=16 num_consumed_tokens=676331520 inf_nan_skip_batches=3 num_samples_in_batch=78 largest_length=7693 largest_batch=9 smallest_batch=2 adam_beta2=0.95 fwd_bwd_time=182.39 bwd_time=121.65 acc=0.5478 perplexity=9.0845 acc/zh=0.5478 tokens/zh=1047634 loss_from_metric=2.2066 loss/zh=2.2066
2024-11-24 02:44:33,969 INFO pipeline.py:770 in record_current_batch_training_metrics -- tflops=230.04208104683008 step=645 loss=2.0069305896759033 real_tgs=721.13 tgs (tokens/gpu/second)=721.72 tgs/last_tgs_1=721.72 tgs/tgs_all=716.65 tgs/tgs_avg=722.46 tgs/tgs_SMA=722.74 tgs/last_tgs_10=723.59 tgs/last_tgs_50=722.37 lr=9.806706585586778e-05 loss_scale=16384.0 grad_norm={'0_default': 1.1654005105502763, '1_fp32': 0.0} micro_num=16 num_consumed_tokens=677380096 inf_nan_skip_batches=3 num_samples_in_batch=88 largest_length=7649 largest_batch=10 smallest_batch=2 adam_beta2=0.95 fwd_bwd_time=181.17 bwd_time=120.76 acc=0.5658 perplexity=8.2723 acc/zh=0.5658 tokens/zh=1047719 loss_from_metric=2.1129 loss/zh=2.1129
2024-11-24 02:47:35,188 INFO pipeline.py:770 in record_current_batch_training_metrics -- tflops=230.55453843330315 step=646 loss=1.931685209274292 real_tgs=722.71 tgs (tokens/gpu/second)=723.33 tgs/last_tgs_1=723.33 tgs/tgs_all=716.66 tgs/tgs_avg=722.46 tgs/tgs_SMA=722.73 tgs/last_tgs_10=723.59 tgs/last_tgs_50=722.37 lr=9.806045456729744e-05 loss_scale=16384.0 grad_norm={'0_default': 1.6949513837430747, '1_fp32': 0.0} micro_num=16 num_consumed_tokens=678428672 inf_nan_skip_batches=3 num_samples_in_batch=87 largest_length=5648 largest_batch=10 smallest_batch=3 adam_beta2=0.95 fwd_bwd_time=180.77 bwd_time=120.48 acc=0.5596 perplexity=8.4299 acc/zh=0.5596 tokens/zh=1047665 loss_from_metric=2.1318 loss/zh=2.1318
2024-11-24 02:50:36,860 WARNING hybrid_zero_optim.py:791 in _step -- Nan grad norm occurs, please check it.
2024-11-24 02:53:37,825 WARNING hybrid_zero_optim.py:791 in _step -- Nan grad norm occurs, please check it.
2024-11-24 02:56:38,023 WARNING hybrid_zero_optim.py:791 in _step -- Nan grad norm occurs, please check it.
2024-11-24 02:59:38,265 WARNING hybrid_zero_optim.py:791 in _step -- Nan grad norm occurs, please check it.
2024-11-24 03:02:38,901 WARNING hybrid_zero_optim.py:791 in _step -- Nan grad norm occurs, please check it.
2024-11-24 03:05:40,250 WARNING hybrid_zero_optim.py:791 in _step -- Nan grad norm occurs, please check it.
2024-11-24 03:08:41,350 WARNING hybrid_zero_optim.py:791 in _step -- Nan grad norm occurs, please check it.
2024-11-24 03:11:41,621 WARNING hybrid_zero_optim.py:791 in _step -- Nan grad norm occurs, please check it.
2024-11-24 03:14:42,524 WARNING hybrid_zero_optim.py:791 in _step -- Nan grad norm occurs, please check it.
2024-11-24 03:17:42,189 WARNING hybrid_zero_optim.py:791 in _step -- Nan grad norm occurs, please check it.
2024-11-24 03:20:42,677 WARNING hybrid_zero_optim.py:791 in _step -- Nan grad norm occurs, please check it.
2024-11-24 03:23:43,389 WARNING hybrid_zero_optim.py:791 in _step -- Nan grad norm occurs, please check it.
2024-11-24 03:26:44,261 WARNING hybrid_zero_optim.py:791 in _step -- Nan grad norm occurs, please check it.
2024-11-24 03:29:45,623 WARNING hybrid_zero_optim.py:791 in _step -- Nan grad norm occurs, please check it.
2024-11-24 03:32:45,724 WARNING hybrid_zero_optim.py:791 in _step -- Nan grad norm occurs, please check it.
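To pinpoint where the gradients first blow up, a minimal log-scanning sketch like the one below can pull the per-step grad_norm and loss_scale out of this log format ("training.log" is a hypothetical path; the regex assumes exactly the format shown above):

import re

# Hypothetical helper: scan an InternEvo-style training log for the per-step
# loss_scale and default-group grad norm, and flag the NaN warnings.
step_re = re.compile(r"step=(\d+).*?loss_scale=([\d.]+).*?grad_norm=\{'0_default': ([^,]+),")

with open("training.log") as f:
    for line in f:
        m = step_re.search(line)
        if m:
            step, scale, norm = m.groups()
            print(f"step {step}: loss_scale={scale} grad_norm={norm}")
        elif "Nan grad norm" in line:
            print("NaN warning:", line.strip())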
Config:
JOB_NAME = "32b_qwen2_train"
model_type = "QWEN2"
DO_ALERT = False
VOCAB_SIZE = 152064
SEQ_LEN = 8192
HIDDEN_SIZE = 5120
NUM_ATTENTION_HEAD = 40
NUM_KV_ATTENTION_HEAD = 8
# MLP_RATIO = 5.25
# MLP_RATIO = 3.4
MLP_RATIO = 5.4
NUM_LAYER = 64
MODEL_ONLY_FOLDER = "/data"
# Ckpt folder format:
# fs: 'local:/mnt/nfs/XXX'
SAVE_CKPT_FOLDER = "local:./save_model/pretrain"
# boto3 Ckpt folder format:
# import os
# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint
# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm"
CHECKPOINT_EVERY = 323
ckpt = dict(
enable_save_ckpt=True, # enable ckpt save.
enable_internevo2hf_ckpt=True, # enable ckpt save for huggingface format.
save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt.
auto_resume=True,
checkpoint_every=CHECKPOINT_EVERY,
async_upload=True, # async ckpt upload. (only works for boto3 ckpt)
async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporary files during asynchronous upload.
oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency.
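# i.e. with CHECKPOINT_EVERY = 323, int(323 / 2) = 161, so a snapshot is saved every 161 steps.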
)
TRAIN_FOLDER = "/root/data"
VALID_FOLDER = None # "/path/to/dataset"
data = dict(
seq_len=SEQ_LEN,
# micro_num is the number of micro-batches contained in one gradient update
micro_num=16,
# packed_length = micro_bsz * SEQ_LEN
micro_bsz=1,
# defaults to the value of micro_num
valid_micro_num=1,
# defaults to 0, means disable evaluate
valid_every=0,
pack_sample_into_one=False,
total_steps=6267,
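# skip_batches: a batch/step range to skip; the "630-700" below presumably steps over a suspect data span.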
skip_batches="630-700",
# rampup_batch_size (str): A string with three space-separated integers representing the
# starting batch size, the increment, and the number of steps between
# each increment. For example, "192 24 8" means that the batch size (micro_num)
# starts at 192 and increases by 24 every 8 steps. Defaults to None.
# (IMPORTANT): The interval step size is 'micro_bsz'. (A parsing sketch follows this dict.)
rampup_batch_size="",
# Datasets with fewer than 50 rows will be discarded
min_length=50,
train_folder=TRAIN_FOLDER,
valid_folder=VALID_FOLDER,
empty_cache_and_diag_interval=0,
diag_outlier_ratio=1.1,
)
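# Hypothetical helper (not InternEvo API): how the rampup_batch_size string
# documented in the data dict above would be parsed. The empty string used here
# disables ramp-up, so micro_num stays fixed at 16.
def _sketch_parse_rampup(spec):
    if not spec:
        return None
    start, increment, interval = (int(x) for x in spec.split())
    return start, increment, interval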
grad_scaler = dict(
fp16=dict(
# the initial loss scale, defaults to 2**16
initial_scale=2**16,
# the minimum loss scale, defaults to None
min_scale=1,
# the number of consecutive steps without overflow before the loss scale is increased
growth_interval=1000,
),
# the multiplication factor for increasing loss scale, defaults to 2
growth_factor=2,
# the multiplication factor for decreasing loss scale, defaults to 0.5
backoff_factor=0.5,
# the maximum loss scale, defaults to None
max_scale=2**24,
# the number of overflows before decreasing loss scale, defaults to 2
hysteresis=2,
)
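# Illustrative sketch (hypothetical helper, not InternEvo source): the dynamic
# loss-scale update implied by the grad_scaler settings above. The scale grows by
# growth_factor after growth_interval clean steps and backs off by backoff_factor
# after `hysteresis` overflows, clamped to [min_scale, max_scale].
def _sketch_update_scale(scale, overflow, clean_steps, overflows,
                         growth_interval=1000, growth_factor=2, backoff_factor=0.5,
                         hysteresis=2, min_scale=1, max_scale=2**24):
    if overflow:
        clean_steps, overflows = 0, overflows + 1
        if overflows >= hysteresis:
            scale, overflows = max(scale * backoff_factor, min_scale), 0
    else:
        clean_steps += 1
        if clean_steps >= growth_interval:
            scale, clean_steps = min(scale * growth_factor, max_scale), 0
    return scale, clean_steps, overflows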
hybrid_zero_optimizer = dict(
# Enable low_level_optimizer overlap_communication
overlap_sync_grad=True,
overlap_sync_param=False,
# bucket size for nccl communication params
reduce_bucket_size=512 * 1024 * 1024,
# grad clipping
clip_grad_norm=1.0,
)
loss = dict(
label_smoothing=0,
)
adam = dict(
lr=1e-4,
adam_beta1=0.9,
adam_beta2=0.95,
adam_beta2_c=0,
adam_eps=1e-8,
weight_decay=0.01,
)
lr_scheduler = dict(
total_steps=data["total_steps"],
init_steps=0, # optimizer_warmup_step
warmup_ratio=0.01,
eta_min=1e-5,
last_epoch=-1,
)
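# Note: warmup_ratio=0.01 of total_steps=6267 is ~63 warmup steps, so by the
# step-640 failure the lr (9.8e-05 in the log) is already past warmup and decaying
# toward eta_min=1e-5.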
beta2_scheduler = dict(
init_beta2=adam["adam_beta2"],
c=adam["adam_beta2_c"],
cur_iter=-1,
)
use_fp32_norm = False
model = dict(
checkpoint=True,
num_chunks=1,
num_attention_heads=NUM_ATTENTION_HEAD,
num_kv_attention_heads=NUM_KV_ATTENTION_HEAD,
embed_split_hidden=True,
vocab_size=VOCAB_SIZE,
embed_grad_scale=1,
parallel_output=True,
hidden_size=HIDDEN_SIZE,
num_layers=NUM_LAYER,
qkv_bias=True,
o_bias=False,
mlp_ratio=MLP_RATIO,
apply_post_layer_norm=False,
dtype="torch.float16",
norm_type="rmsnorm",
layer_norm_epsilon=1e-6,
use_flash_attn=True,
# Whether the odd and even columns of the query and key in the model are normally interleaved.
# If True, the odd and even columns keep their natural interleaved order; if False,
# the model has pre-concatenated all odd columns in front and all even columns behind,
# in order to improve RoPE's computational efficiency.
# Example:
# qk_interleaved = True: q[-1] = [q1,q2,q3,q4,q5,q6,...], k[-1] = [k1,k2,k3,k4,k5,k6,...]
# qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...]
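# Illustrative: with qk_interleaved=False, a naturally ordered row q maps to
# q[0::2] + q[1::2] -> [q1, q3, q5, ..., q2, q4, q6, ...]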
qk_interleaved=False,
multiple_of=1,
rope_base=1000000,
use_sliding_window=False,
sliding_window=131072,
max_window_layers=70,
)
parallel = dict(
zero1=dict(size=-1),
tensor=dict(size=1, mode="isp"),
pipeline=dict(size=1, interleaved_overlap=True),
weight=dict(size=8, overlap=True),
)
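# Note (interpretation per InternEvo conventions, not part of the original config):
# zero1 size=-1 shards optimizer states across the entire data-parallel group;
# tensor mode "isp" selects ISP sequence parallelism, with 8-way weight parallelism
# and communication overlap set via the weight dict.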
cudnn_deterministic = False
cudnn_benchmark = False
monitor = dict(
# feishu alert configs
alert=dict(
enable_feishu_alert=DO_ALERT,
feishu_alert_address=None, # feishu webhook to send alert message
light_monitor_address=None, # light_monitor address to send heartbeat
alert_file_path=f"llm_alter/{JOB_NAME}_alert.log",
),
tensorboard=dict(
queue_max_length=10,
),
)
generation = dict(
ckpt_folder="/path/to/saved/ckpt",
output_folder="/path/to/save/generation",
batch_size=1,
eos_id=[2, 0],
bos_id=1,
max_length=100,
do_sample=True,
temperature=1.0,
top_k=50,
top_p=1.0,
repetition_penalty=1,
length_penalty=1.0,
)