Skip to content

Commit

Permalink
Refactor stacked version of FP8 Grouped Gemm for reduced overhead (py…
Browse files Browse the repository at this point in the history
…torch#3699)

Summary:
Pull Request resolved: pytorch#3699

X-link: facebookresearch/FBGEMM#780

Currently, the stacked version of FP8 grouped gemm accepts lists of tensor inputs and produces a single tensor output. This reduces quite a bit of overhead when cuda graphs are used, but still requires splitting input tensors in prefill which can be costly. This diff updates the input types of stacked grouped gemm to support single tensors. Notably, since M varies across group and we do no padding, this change requires that we provide a new input tensor called `M_offsets` that indicates the row that each group begins at within in the first input. We create M_offsets by taking the cumulative sum of M for each group, which we may be able to further optimize.

This diff also includes a long overdue refactor of grouped gemm setup for nvidia such that we only launch a single kernel rather than one per group. This should reduce overhead by quite a bit in some cases.

Differential Revision: D69544396
  • Loading branch information
jwfromm authored and facebook-github-bot committed Feb 17, 2025
1 parent 610ea2e commit f5c437a
Show file tree
Hide file tree
Showing 73 changed files with 459 additions and 348 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ at::Tensor f8f8bf16_rowwise_batched_impl(
int B = XQ.size(0);
int M = XQ.size(1);
int N = WQ.size(1);
int K = XQ.size(2);
int K = WQ.size(2);

int StrideA = K;
int StrideB = K;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,48 @@ void set_static_kernel_args(
}
}

__global__ void set_kernel_args_m_offsets_kernel(
KernelArguments* kernel_args,
ADataType* XQ,
BDataType* WQ,
D0DataType* w_scale,
D1DataType* x_scale,
EDataType* output,
int64_t* M_offsets,
int M,
int N,
int K,
int group_count) {
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
// Each thread is responsible for setting up the arguments for one group.
if (thread_idx < group_count) {
// Compute offsets for this group.
int offset_M;
int kernel_M;
if (thread_idx == 0) {
offset_M = 0;
kernel_M = M_offsets[thread_idx];
} else {
offset_M = M_offsets[thread_idx - 1];
kernel_M = M_offsets[thread_idx] - offset_M;
}
KernelArguments kernel_group_args = {
XQ + (offset_M * K),
WQ + (thread_idx * N * K),
{w_scale + (thread_idx * N), x_scale + offset_M},
output + (offset_M * N),
kernel_M,
N,
K,
K,
K,
{0, 0},
N};
// Write kernel args to memory.
kernel_args[thread_idx] = kernel_group_args;
}
}

__global__ void set_kernel_args_fixed_nk_kernel(
KernelArguments* kernel_args,
ADataType* XQ,
Expand Down Expand Up @@ -252,56 +294,85 @@ void set_dynamic_kernel_args(
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor output,
at::Tensor zero_start_index_M) {
std::optional<at::Tensor> zero_start_index_M,
std::optional<at::Tensor> M_offsets) {
// Get current cuda stream.
auto stream = at::cuda::getCurrentHIPStream().stream();
int group_count = XQ.size(0);
// Confirm M is on the proper device.
TORCH_CHECK(
XQ.device() == zero_start_index_M.device(),
"zero_start_index_M and inputs must be on the same device.");
int group_count;
// Check provided tensors are valid.
TORCH_CHECK(
zero_start_index_M.size(0) == group_count,
"zero_start_index_M must have an entry for each group.");
TORCH_CHECK(
zero_start_index_M.dtype() == at::kLong,
"zero_start_index_M must be int64.");
zero_start_index_M.has_value() != M_offsets.has_value(),
"One of zero_start_index_M or M_offsets must be provided.");
if (zero_start_index_M.has_value()) {
group_count = zero_start_index_M.value().size(0);
TORCH_CHECK(
XQ.device() == zero_start_index_M.value().device(),
"zero_start_index_M and inputs must be on the same device.");
TORCH_CHECK(
zero_start_index_M.value().dtype() == at::kLong,
"zero_start_index_M must be int64.");
}
if (M_offsets.has_value()) {
group_count = M_offsets.value().size(0);
TORCH_CHECK(
XQ.device() == M_offsets.value().device(),
"M_offsets and inputs must be on the same device.");
TORCH_CHECK(
M_offsets.value().dtype() == at::kLong, "M_offsets must be int64.");
}

// We assume that M, N, and K are fixed across groups.
// The actual m values are sstored in the passed M tensor.
int M = XQ.size(1);
int K = XQ.size(2);
// When m_offsets is used XQ is shape [tota_M, K]. When zero_start_index_M is
// used it is shape [G, M, K].
int M = XQ.size(XQ.dim() - 2);
int K = WQ.size(2);
int N = WQ.size(1);

// Launch a kernel that sets kernel argument memory.
// Each thread sets one float4 which corresponds to 8 bf16 values.
const int BLOCK_SIZE = 8;
TORCH_CHECK(
N % BLOCK_SIZE == 0, "N must be divisible 8 for dynamic grouped gemm.");
int block_factor = std::max(group_count, (group_count * M * N) / BLOCK_SIZE);
int blockSize = std::min(512, block_factor);
int numBlocks = (block_factor + blockSize - 1) / blockSize;
set_kernel_args_fixed_nk_kernel<<<numBlocks, blockSize, 0, stream>>>(
reinterpret_cast<KernelArguments*>(kernel_args.data_ptr()),
reinterpret_cast<ADataType*>(XQ.data_ptr()),
reinterpret_cast<BDataType*>(WQ.data_ptr()),
reinterpret_cast<D0DataType*>(w_scale.data_ptr()),
reinterpret_cast<D1DataType*>(x_scale.data_ptr()),
reinterpret_cast<EDataType*>(output.data_ptr()),
reinterpret_cast<int64_t*>(zero_start_index_M.data_ptr()),
M,
N,
K,
group_count);
if (zero_start_index_M.has_value()) {
const int BLOCK_SIZE = 8;
TORCH_CHECK(
N % BLOCK_SIZE == 0, "N must be divisible 8 for dynamic grouped gemm.");
int block_factor =
std::max(group_count, (group_count * M * N) / BLOCK_SIZE);
int blockSize = std::min(512, block_factor);
int numBlocks = (block_factor + blockSize - 1) / blockSize;
set_kernel_args_fixed_nk_kernel<<<numBlocks, blockSize, 0, stream>>>(
reinterpret_cast<KernelArguments*>(kernel_args.data_ptr()),
reinterpret_cast<ADataType*>(XQ.data_ptr()),
reinterpret_cast<BDataType*>(WQ.data_ptr()),
reinterpret_cast<D0DataType*>(w_scale.data_ptr()),
reinterpret_cast<D1DataType*>(x_scale.data_ptr()),
reinterpret_cast<EDataType*>(output.data_ptr()),
reinterpret_cast<int64_t*>(zero_start_index_M.value().data_ptr()),
M,
N,
K,
group_count);
} else {
int blockSize = std::min(512, group_count);
int numBlocks = (group_count + blockSize - 1) / blockSize;
set_kernel_args_m_offsets_kernel<<<numBlocks, blockSize, 0, stream>>>(
reinterpret_cast<KernelArguments*>(kernel_args.data_ptr()),
reinterpret_cast<ADataType*>(XQ.data_ptr()),
reinterpret_cast<BDataType*>(WQ.data_ptr()),
reinterpret_cast<D0DataType*>(w_scale.data_ptr()),
reinterpret_cast<D1DataType*>(x_scale.data_ptr()),
reinterpret_cast<EDataType*>(output.data_ptr()),
reinterpret_cast<int64_t*>(M_offsets.value().data_ptr()),
M,
N,
K,
group_count);
}
}

template <typename OutputType>
OutputType _f8f8bf16_rowwise_grouped(
std::vector<at::Tensor> f8f8bf16_rowwise_grouped(
at::TensorList XQ,
at::TensorList WQ,
at::TensorList x_scale,
at::TensorList w_scale,
std::optional<OutputType> output = std::nullopt) {
std::optional<std::vector<at::Tensor>> output = std::nullopt) {
// Check that input datatypes are valid.
// First confirm that there are the same number of groups in all inputs.
TORCH_CHECK(
Expand Down Expand Up @@ -334,62 +405,40 @@ OutputType _f8f8bf16_rowwise_grouped(
TORCH_CHECK(ws.dtype() == at::kFloat, "Scales must be float32.");
}

OutputType Y;
std::vector<at::Tensor> Y;
// Need to handle different output modes separately.
// First handle tensor list output.
if constexpr (std::is_same_v<OutputType, std::vector<at::Tensor>>) {
if (output.has_value()) {
Y = output.value();
if (output.has_value()) {
Y = output.value();
TORCH_CHECK(
Y.size() == group_count,
"Output and input must have same number of groups.");
// Check that output shapes are correct.
for (int i = 0; i < group_count; i++) {
int M = XQ[i].size(0);
int N = WQ[i].size(0);
int out_M = Y[i].size(0);
int out_N = Y[i].size(1);
TORCH_CHECK(
Y.size() == group_count,
"Output and input must have same number of groups.");
// Check that output shapes are correct.
for (int i = 0; i < group_count; i++) {
int M = XQ[i].size(0);
int N = WQ[i].size(0);
int out_M = Y[i].size(0);
int out_N = Y[i].size(1);
TORCH_CHECK(
M == out_M && N == out_N,
"Output tensors do not have the expected shape.");
TORCH_CHECK(
Y[i].dtype() == at::kBFloat16, "Output dtype must be bfloat16.");
}
} else {
for (int i = 0; i < group_count; i++) {
int M = XQ[i].size(0);
int N = WQ[i].size(0);
Y.push_back(at::empty({M, N}, XQ[i].options().dtype(at::kBFloat16)));
}
M == out_M && N == out_N,
"Output tensors do not have the expected shape.");
TORCH_CHECK(
Y[i].dtype() == at::kBFloat16, "Output dtype must be bfloat16.");
}
// Now handle single tensor output.
} else {
// Compute total M across groups.
int total_M = 0;
int N = WQ[0].size(0);
for (int i = 0; i < group_count; i++) {
total_M += XQ[i].size(0);
// Also make sure N is constant across shapes.
TORCH_CHECK(
WQ[i].size(0) == N,
"N must be constant across groups for stacked output.");
}
if (output.has_value()) {
Y = output.value();
// Check that shape is expected.
TORCH_CHECK(
Y.size(0) == total_M && Y.size(1) == N,
"Preallocated output should have size [total_M, N].");
} else {
Y = at::empty({total_M, N}, XQ[0].options().dtype(at::kBFloat16));
int M = XQ[i].size(0);
int N = WQ[i].size(0);
Y.push_back(at::empty({M, N}, XQ[i].options().dtype(at::kBFloat16)));
}
}

// Prepare kernel arguments by copying them to the proper device location.
at::Tensor kernel_args = at::empty(
{static_cast<long>(group_count * sizeof(KernelArguments))},
XQ[0].options().dtype(at::kByte));
set_static_kernel_args<OutputType>(kernel_args, XQ, WQ, x_scale, w_scale, Y);
set_static_kernel_args<std::vector<at::Tensor>>(
kernel_args, XQ, WQ, x_scale, w_scale, Y);

// We use the largest of each shape for heuristics.
int MaxM = 0;
Expand All @@ -400,32 +449,67 @@ OutputType _f8f8bf16_rowwise_grouped(
MaxN = max(MaxN, WQ[i].size(0));
MaxK = max(MaxK, XQ[i].size(1));
}
RowwiseGroupedKernel<at::TensorList, OutputType> selected_kernel =
rowwise_grouped_heuristic_dispatch<at::TensorList, OutputType>(
MaxM, MaxN, MaxK);
RowwiseGroupedKernel<at::TensorList, std::vector<at::Tensor>>
selected_kernel = rowwise_grouped_heuristic_dispatch<
at::TensorList,
std::vector<at::Tensor>>(MaxM, MaxN, MaxK);
return selected_kernel(XQ, WQ, x_scale, w_scale, kernel_args, Y);
}

// Wrapper function for list input list output.
std::vector<at::Tensor> f8f8bf16_rowwise_grouped(
at::TensorList XQ,
at::TensorList WQ,
at::TensorList x_scale,
at::TensorList w_scale,
std::optional<std::vector<at::Tensor>> output = std::nullopt) {
return _f8f8bf16_rowwise_grouped<std::vector<at::Tensor>>(
XQ, WQ, x_scale, w_scale, output);
}

// Wrapper function for list input single tensor output.
at::Tensor f8f8bf16_rowwise_grouped_stacked(
at::TensorList XQ,
at::TensorList WQ,
at::TensorList x_scale,
at::TensorList w_scale,
at::Tensor XQ,
at::Tensor WQ,
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor M_offsets,
std::optional<at::Tensor> output = std::nullopt) {
return _f8f8bf16_rowwise_grouped<at::Tensor>(
XQ, WQ, x_scale, w_scale, output);
// Check that input datatypes are valid.
// First confirm that there are the same number of groups in all inputs.
int group_count = M_offsets.size(0);
// XQ is expected to be shape [total_M, K].
int total_M = XQ.size(0);
// WQ is expected to be shape [G, N, K].
int N = WQ.size(1);
int K = XQ.size(1);
TORCH_CHECK(
WQ.size(0) == group_count && x_scale.numel() == total_M &&
w_scale.numel() / group_count == N,
"All inputs must have the same number of groups.");
// Iterate over inputs and check they are valid.
TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous());
TORCH_CHECK(XQ.dim() == 2, "Input XQ must be 2D (total_M,K).");
TORCH_CHECK(
XQ.dtype() == at::kFloat8_e4m3fnuz,
"Input XQ must be type float8_e4m3fnuz.");

TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous());
TORCH_CHECK(WQ.dim() == 3, "Input WQ must be 3D (G,N,K).");
TORCH_CHECK(
WQ.dtype() == at::kFloat8_e4m3fnuz,
"Input WQ must be type float8_e4m3fnuz.");
TORCH_CHECK(
WQ.size(1) >= 512 && WQ.size(2) >= 512,
"N and K must be at least 512 for grouped gemm. For smaller inputs, consider unrolling.");

TORCH_CHECK(x_scale.dtype() == at::kFloat, "Scales must be float32.");
TORCH_CHECK(w_scale.dtype() == at::kFloat, "Scales must be float32.");

// Allocate an empty output array. We will set its values to zero as part
// of kernel setup.
at::Tensor Y = at::empty({total_M, N}, XQ.options().dtype(at::kBFloat16));

// Prepare kernel arguments by copying them to the proper device location.
at::Tensor kernel_args = at::empty(
{static_cast<long>(group_count * sizeof(KernelArguments))},
XQ.options().dtype(at::kByte));
set_dynamic_kernel_args(
kernel_args, XQ, WQ, x_scale, w_scale, Y, std::nullopt, M_offsets);

RowwiseGroupedKernel<at::Tensor, at::Tensor> selected_kernel =
rowwise_grouped_heuristic_dispatch<at::Tensor, at::Tensor>(
total_M / group_count, N, K);
return selected_kernel(XQ, WQ, x_scale, w_scale, kernel_args, Y);
}

at::Tensor f8f8bf16_rowwise_grouped_dynamic(
Expand All @@ -439,7 +523,7 @@ at::Tensor f8f8bf16_rowwise_grouped_dynamic(
int group_count = XQ.size(0);
int M = XQ.size(1);
int N = WQ.size(1);
int K = XQ.size(2);
int K = WQ.size(2);
TORCH_CHECK(
WQ.size(0) == group_count && x_scale.numel() / group_count == M &&
w_scale.numel() / group_count == N,
Expand Down Expand Up @@ -473,7 +557,7 @@ at::Tensor f8f8bf16_rowwise_grouped_dynamic(
{static_cast<long>(group_count * sizeof(KernelArguments))},
XQ.options().dtype(at::kByte));
set_dynamic_kernel_args(
kernel_args, XQ, WQ, x_scale, w_scale, Y, zero_start_index_M);
kernel_args, XQ, WQ, x_scale, w_scale, Y, zero_start_index_M, std::nullopt);

RowwiseGroupedKernel<at::Tensor, at::Tensor> selected_kernel =
rowwise_grouped_heuristic_dispatch<at::Tensor, at::Tensor>(M, N, K);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_in
// Check if this input needs to be padded.
bool pad = false;
if constexpr (std::is_same_v<InputType, at::Tensor>) {
int K = XQ.size(2);
int K = WQ.size(2);
pad = K % 128 != 0;
} else {
for (int i = 0; i < XQ.size(); i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_in
// Check if this input needs to be padded.
bool pad = false;
if constexpr (std::is_same_v<InputType, at::Tensor>) {
int K = XQ.size(2);
int K = WQ.size(2);
pad = K % 128 != 0;
} else {
for (int i = 0; i < XQ.size(); i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_in
// Check if this input needs to be padded.
bool pad = false;
if constexpr (std::is_same_v<InputType, at::Tensor>) {
int K = XQ.size(2);
int K = WQ.size(2);
pad = K % 128 != 0;
} else {
for (int i = 0; i < XQ.size(); i++) {
Expand Down
Loading

0 comments on commit f5c437a

Please sign in to comment.