Add recipe configs validating #10954

Merged: 31 commits from boxiangw/add-config-validation into main, Dec 6, 2024
Commits
c6f00fe
Init with parallelism validation
BoxiangW Oct 19, 2024
c971ed0
Apply isort and black reformatting
BoxiangW Oct 19, 2024
d33b591
Add config checks
BoxiangW Oct 31, 2024
d32ced6
Apply isort and black reformatting
BoxiangW Oct 31, 2024
751c1da
fix
BoxiangW Oct 31, 2024
9efa022
Merge branch 'main' into boxiangw/add-config-validation
BoxiangW Nov 4, 2024
3774662
Merge branch 'main' into boxiangw/add-config-validation
BoxiangW Nov 4, 2024
d30ad07
Add unit tests on configs validation
BoxiangW Nov 4, 2024
9f81ca9
Add test
BoxiangW Nov 4, 2024
c3133d2
Add copyright
BoxiangW Nov 4, 2024
ad807a7
Change test
BoxiangW Nov 4, 2024
8d663b5
Merge branch 'main' into boxiangw/add-config-validation
BoxiangW Nov 5, 2024
eaee5ea
Merge branch 'main' into boxiangw/add-config-validation
BoxiangW Nov 5, 2024
b167124
testing
BoxiangW Nov 5, 2024
c5ed16f
Test
BoxiangW Nov 5, 2024
f7c62aa
Merge branch 'main' into boxiangw/add-config-validation
BoxiangW Nov 6, 2024
2c1fb7f
Merge branch 'main' into boxiangw/add-config-validation
BoxiangW Nov 6, 2024
1d95c21
Add mbs and gbs to datamodule
BoxiangW Nov 7, 2024
45725f9
Add temp fix
BoxiangW Nov 7, 2024
4adc9cc
Fix trainer device issue
BoxiangW Nov 7, 2024
877f7ed
Fix test
BoxiangW Nov 7, 2024
3dff963
Fix devices and num_nodes issue
BoxiangW Nov 8, 2024
7358a6f
Fix T5 issue
BoxiangW Nov 8, 2024
e6f7aa2
Merge branch 'main' into boxiangw/add-config-validation
BoxiangW Nov 8, 2024
6450f0c
Merge branch 'main' into boxiangw/add-config-validation
BoxiangW Nov 8, 2024
7789fbc
Fix bug
BoxiangW Nov 12, 2024
2e46fea
Merge branch 'main' into boxiangw/add-config-validation
BoxiangW Nov 12, 2024
1794ba3
Remove print
BoxiangW Nov 12, 2024
94e7abc
Add if condition for nl.MegatronStrategy
BoxiangW Nov 19, 2024
9ba9dcf
Apply isort and black reformatting
BoxiangW Nov 19, 2024
72a5db8
Merge branch 'main' into boxiangw/add-config-validation
BoxiangW Nov 19, 2024
103 changes: 103 additions & 0 deletions nemo/collections/llm/api.py
@@ -13,6 +13,7 @@
# limitations under the License.

import os
import warnings
from copy import deepcopy
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
@@ -145,6 +146,7 @@ def pretrain(
        >>> llm.pretrain(model, data, trainer)
        PosixPath('/path/to/log_dir')
    """
    _validate_config(model, data, trainer, log=log, resume=resume, optim=optim)
    return train(
        model=model,
        data=data,
@@ -195,6 +197,7 @@ def finetune(
        PosixPath('/path/to/log_dir')
    """

    _validate_config(model, data, trainer, log=log, resume=resume, optim=optim, model_transform=peft)
    return train(
        model=model,
        data=data,
@@ -875,3 +878,103 @@ def _set_with_io(obj, attr, value):
    setattr(obj, attr, value)
    if hasattr(obj, "__io__") and hasattr(value, "__io__"):
        setattr(obj.__io__, attr, deepcopy(value.__io__))


def _validate_config(
    model: pl.LightningModule,
    data: pl.LightningDataModule,
    trainer: Trainer,
    log: Optional[NeMoLogger] = None,
    resume: Optional[AutoResume] = None,
    optim: Optional[OptimizerModule] = None,
    tokenizer: Optional[TokenizerType] = None,
    model_transform: Optional[Union[PEFT, ModelTransform, Callable]] = None,
) -> None:

    ## Model validation
    assert getattr(model.config, "seq_length", 1) > 0
    assert getattr(model.config, "max_position_embeddings", 1) > 0
    assert model.config.num_layers > 0
    assert model.config.hidden_size > 0
    assert model.config.num_attention_heads > 0
    assert model.config.ffn_hidden_size > 0

    if hasattr(model.config, "seq_length"):
        if getattr(model.config, "max_position_embeddings", None) is not None:
            assert model.config.seq_length <= model.config.max_position_embeddings

    ## Data validation
    assert data.micro_batch_size > 0
    assert data.global_batch_size > 0
    assert data.seq_length > 0

    assert (
        data.global_batch_size % data.micro_batch_size == 0
    ), "Global batch size must be divisible by micro batch size in data module."

    ## Trainer validation

    # MegatronStrategy validation
    if isinstance(trainer.strategy, nl.MegatronStrategy):
        # Basic validation
        assert trainer.strategy.tensor_model_parallel_size > 0
        assert trainer.strategy.pipeline_model_parallel_size > 0
        assert trainer.strategy.context_parallel_size > 0

        # DP validation
        assert (trainer.num_devices * trainer.num_nodes) % (
            trainer.strategy.tensor_model_parallel_size
            * trainer.strategy.pipeline_model_parallel_size
            * trainer.strategy.context_parallel_size
        ) == 0, "Number of GPUs must be divisible by the product of all parallelism sizes for data parallel."

        assert (
            data.global_batch_size
            % (
                data.micro_batch_size
                * (
                    (trainer.num_devices * trainer.num_nodes)
                    / (
                        trainer.strategy.tensor_model_parallel_size
                        * trainer.strategy.pipeline_model_parallel_size
                        * trainer.strategy.context_parallel_size
                    )
                )
            )
            == 0
        ), "Global batch size must be divisible by the product of micro batch size and data parallel size."

        # TP/SP validation
        if trainer.strategy.tensor_model_parallel_size == 1:
            if trainer.strategy.sequence_parallel:
                warnings.warn("Disabling sequence parallelism because tensor model parallelism is disabled")
                trainer.strategy.sequence_parallel = False

        # PP/VP validation
        if trainer.strategy.pipeline_model_parallel_size > 1:
            assert (
                trainer.strategy.pipeline_dtype is not None
            ), "pipeline_dtype must be set if pipeline model parallelism is enabled"
        else:
            if trainer.strategy.virtual_pipeline_model_parallel_size is not None:
                warnings.warn("Disabling virtual pipeline parallelism because pipeline model parallelism is disabled")
                trainer.strategy.virtual_pipeline_model_parallel_size = None
            if trainer.strategy.pipeline_dtype is not None:
                warnings.warn("Setting pipeline dtype to None because pipeline model parallelism is disabled")
                trainer.strategy.pipeline_dtype = None

        # CP validation
        if trainer.strategy.context_parallel_size > 1:
            if model.config.seq_length is not None:
                assert (
                    model.config.seq_length % (trainer.strategy.context_parallel_size * 2) == 0
                ), "Sequence length must be divisible by 2 * context parallel size if context parallel is used."

        # EP validation
        if trainer.strategy.expert_model_parallel_size > 1:
            assert (
                model.config.num_moe_experts is not None
            ), "num_moe_experts must not be None to use expert model parallelism"
            assert (
                model.config.num_moe_experts % trainer.strategy.expert_model_parallel_size == 0
            ), "Number of experts must be a multiple of expert_model_parallel_size."
6 changes: 4 additions & 2 deletions nemo/collections/llm/gpt/data/mock.py
@@ -48,6 +48,8 @@ def __init__(
    ):
        super().__init__()
        self.seq_length = seq_length
        self.micro_batch_size = micro_batch_size
        self.global_batch_size = global_batch_size
        self.num_train_samples = num_train_samples
        self.num_val_samples = num_val_samples
        self.num_test_samples = num_test_samples
@@ -65,8 +67,8 @@ def __init__(

        self.data_sampler = MegatronDataSampler(
            seq_len=self.seq_length,
            micro_batch_size=micro_batch_size,
            global_batch_size=global_batch_size,
            micro_batch_size=self.micro_batch_size,
            global_batch_size=self.global_batch_size,
            rampup_batch_size=rampup_batch_size,
        )
6 changes: 4 additions & 2 deletions nemo/collections/llm/gpt/data/pre_training.py
@@ -189,6 +189,8 @@ def __init__(

        self.build_kwargs = build_kwargs
        self.seq_length = seq_length
        self.micro_batch_size = micro_batch_size
        self.global_batch_size = global_batch_size
        self.tokenizer = tokenizer
        self.num_workers = num_workers
        self.pin_memory = pin_memory
@@ -211,8 +213,8 @@ def __init__(
        self.tokenizer = tokenizer or get_nmt_tokenizer("megatron", "GPT2BPETokenizer")
        self.data_sampler = MegatronDataSampler(
            seq_len=self.seq_length,
            micro_batch_size=micro_batch_size,
            global_batch_size=global_batch_size,
            micro_batch_size=self.micro_batch_size,
            global_batch_size=self.global_batch_size,
            rampup_batch_size=rampup_batch_size,
        )
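Both datamodule diffs above make the same change: micro_batch_size and global_batch_size are now stored as attributes in __init__ rather than only being forwarded to MegatronDataSampler, so that _validate_config can read them directly off the datamodule. A short sketch of the resulting contract, with illustrative values:

# Sketch: after this PR, batch sizes are plain attributes on the datamodule,
# which is what the data checks in _validate_config rely on. Values illustrative.
from nemo.collections import llm

data = llm.MockDataModule(seq_length=4096, global_batch_size=16, micro_batch_size=2)
assert data.micro_batch_size == 2    # now readable by _validate_config
assert data.global_batch_size == 16
assert data.global_batch_size % data.micro_batch_size == 0  # the divisibility check above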
174 changes: 174 additions & 0 deletions tests/collections/llm/test_api.py
@@ -0,0 +1,174 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import nemo_run as run
import pytest
import torch

from nemo import lightning as nl
from nemo.collections import llm
from nemo.collections.llm.api import _validate_config
from nemo.collections.llm.gpt.model.llama import Llama3Config8B, LlamaModel


class TestValidateConfig:

    def reset_configs(self):
        model = LlamaModel(config=run.Config(Llama3Config8B))
        data = llm.MockDataModule(seq_length=4096, global_batch_size=16, micro_batch_size=2)
        trainer = nl.Trainer(strategy=nl.MegatronStrategy())
        return model, data, trainer

    def test_model_validation(self):
        model, data, trainer = self.reset_configs()
        _validate_config(model, data, trainer)

        with pytest.raises(AssertionError):
            model, data, trainer = self.reset_configs()
            model.config.seq_length = 0
            _validate_config(model, data, trainer)

        with pytest.raises(AssertionError):
            model, data, trainer = self.reset_configs()
            model.config.num_layers = 0
            _validate_config(model, data, trainer)

        with pytest.raises(AssertionError):
            model, data, trainer = self.reset_configs()
            model.config.hidden_size = 0
            _validate_config(model, data, trainer)

        with pytest.raises(AssertionError):
            model, data, trainer = self.reset_configs()
            model.config.num_attention_heads = 0
            _validate_config(model, data, trainer)

        with pytest.raises(AssertionError):
            model, data, trainer = self.reset_configs()
            model.config.ffn_hidden_size = 0
            _validate_config(model, data, trainer)

    def test_data_validation(self):
        model, data, trainer = self.reset_configs()
        _validate_config(model, data, trainer)

        with pytest.raises(AssertionError):
            model, data, trainer = self.reset_configs()
            data.micro_batch_size = 0
            _validate_config(model, data, trainer)

        with pytest.raises(AssertionError):
            model, data, trainer = self.reset_configs()
            data.global_batch_size = 0
            _validate_config(model, data, trainer)

        with pytest.raises(AssertionError):
            model, data, trainer = self.reset_configs()
            data.seq_length = 0
            _validate_config(model, data, trainer)

        with pytest.raises(AssertionError):
            model, data, trainer = self.reset_configs()
            data.micro_batch_size = 3
            data.global_batch_size = 128
            _validate_config(model, data, trainer)

    def test_trainer_validation(self):
        model, data, trainer = self.reset_configs()
        _validate_config(model, data, trainer)

        # Basic validation
        with pytest.raises(AssertionError):
            model, data, trainer = self.reset_configs()
            trainer.strategy.tensor_model_parallel_size = 0
            _validate_config(model, data, trainer)

        with pytest.raises(AssertionError):
            model, data, trainer = self.reset_configs()
            trainer.strategy.pipeline_model_parallel_size = 0
            _validate_config(model, data, trainer)

        with pytest.raises(AssertionError):
            model, data, trainer = self.reset_configs()
            trainer.strategy.context_parallel_size = 0
            _validate_config(model, data, trainer)

        # DP validation
        with pytest.raises(AssertionError):
            model, data, trainer = self.reset_configs()
            trainer.strategy.tensor_model_parallel_size = 8
            trainer.strategy.pipeline_model_parallel_size = 2
            _validate_config(model, data, trainer)

        with pytest.raises(AssertionError):
            model, data, trainer = self.reset_configs()
            trainer.strategy.tensor_model_parallel_size = 3
            trainer.strategy.pipeline_model_parallel_size = 2
            _validate_config(model, data, trainer)

        with pytest.raises(AssertionError):
            model, data, trainer = self.reset_configs()
            data.global_batch_size = 3
            data.micro_batch_size = 1
            trainer.strategy.tensor_model_parallel_size = 2
            trainer.strategy.pipeline_model_parallel_size = 2
            _validate_config(model, data, trainer)

        # TP/SP validation
        model, data, trainer = self.reset_configs()
        trainer.strategy.tensor_model_parallel_size = 1
        trainer.strategy.sequence_parallel = True
        _validate_config(model, data, trainer)
        assert trainer.strategy.sequence_parallel is False

        # PP/VP validation
        with pytest.raises(AssertionError):
            model, data, trainer = self.reset_configs()
            trainer.strategy.pipeline_model_parallel_size = 2
            trainer.strategy.pipeline_dtype = None
            _validate_config(model, data, trainer)

        model, data, trainer = self.reset_configs()
        trainer.strategy.pipeline_model_parallel_size = 1
        trainer.strategy.virtual_pipeline_model_parallel_size = 2
        trainer.strategy.pipeline_dtype = torch.bfloat16
        _validate_config(model, data, trainer)
        assert trainer.strategy.virtual_pipeline_model_parallel_size is None
        assert trainer.strategy.pipeline_dtype is None

        # CP validation
        with pytest.raises(AssertionError):
            model, data, trainer = self.reset_configs()
            model.config.seq_length = 5
            trainer.strategy.context_parallel_size = 2
            _validate_config(model, data, trainer)

        with pytest.raises(AssertionError):
            model, data, trainer = self.reset_configs()
            model.config.seq_length = 2
            trainer.strategy.context_parallel_size = 2
            _validate_config(model, data, trainer)

        # EP validation
        with pytest.raises(AssertionError):
            model, data, trainer = self.reset_configs()
            model.config.num_moe_experts = None
            trainer.strategy.expert_model_parallel_size = 2
            _validate_config(model, data, trainer)

        with pytest.raises(AssertionError):
            model, data, trainer = self.reset_configs()
            model.config.num_moe_experts = 3
            trainer.strategy.expert_model_parallel_size = 2
            _validate_config(model, data, trainer)
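As the TP/SP and PP/VP cases above show, some checks warn and correct the strategy in place rather than raising. A rough usage sketch mirroring the TP/SP case outside of pytest; configuration values are illustrative:

# Sketch: with TP=1 (the MegatronStrategy default), _validate_config warns
# and disables sequence parallelism instead of raising an AssertionError.
import warnings

import nemo_run as run
from nemo import lightning as nl
from nemo.collections import llm
from nemo.collections.llm.api import _validate_config
from nemo.collections.llm.gpt.model.llama import Llama3Config8B, LlamaModel

model = LlamaModel(config=run.Config(Llama3Config8B))
data = llm.MockDataModule(seq_length=4096, global_batch_size=16, micro_batch_size=2)
trainer = nl.Trainer(strategy=nl.MegatronStrategy())
trainer.strategy.sequence_parallel = True  # inconsistent with TP=1

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    _validate_config(model, data, trainer)

assert trainer.strategy.sequence_parallel is False  # corrected in place
assert any("sequence parallelism" in str(w.message) for w in caught)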