Swizzle tiles in matmul without introducing larger grid due to nondivisible splits #3942

@jacobhinkle

Description

In both the Ampere and Hopper matmul schedulers, we first schedule a tile grid in two dimensions. Then, if params.grid_swizzle_factor != 1, we Split the first dimension by that factor and Merge the inner part of the split with the other tile grid dimension.

// split [I1/factor, factor, I2]
// reorder [I1/factor, I2, factor]
// merge [I1/factor, I2*factor]

The result is another 2D grid in which the split and merge have mixed the original two dimensions together. For non-persistent kernels we parallelize this grid as BIDx, BIDy; for persistent kernels we merge the two dimensions together.

Image

The problem with this approach is that it introduces new blocks when the split is nondivisible. Let M, N be the original extents of the tile grid dimensions that are being swizzled. After the swizzle described above we have dimensions ceilDiv(M, factor), N*factor, so the number of tiles changes from M*N to ceilDiv(M, factor)*factor*N. This is strictly larger whenever M%factor != 0. For example, with M = 5, N = 3, and factor = 4, the grid grows from 15 tiles to 24.
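To make the overhead concrete, here is a minimal sketch that computes the tile counts before and after the swizzle. The helper names (tilesBeforeSwizzle, tilesAfterSwizzle, extraTiles) are hypothetical and are not part of nvFuser; ceilDiv is reimplemented locally so the snippet stands alone.

```cpp
#include <cassert>
#include <cstdint>

// Round-up integer division, matching the ceilDiv used in the text.
int64_t ceilDiv(int64_t a, int64_t b) {
  assert(b > 0);
  return (a + b - 1) / b;
}

// Number of tiles in the original M x N tile grid.
int64_t tilesBeforeSwizzle(int64_t M, int64_t N) {
  return M * N;
}

// Number of tiles after split-by-factor, reorder, and merge:
// the grid becomes [ceilDiv(M, factor), N * factor].
int64_t tilesAfterSwizzle(int64_t M, int64_t N, int64_t factor) {
  return ceilDiv(M, factor) * (N * factor);
}

// Extra tiles introduced by a nondivisible split (0 when M % factor == 0).
int64_t extraTiles(int64_t M, int64_t N, int64_t factor) {
  return tilesAfterSwizzle(M, N, factor) - tilesBeforeSwizzle(M, N);
}
```

With M = 5, N = 3, factor = 4 this reports 9 extra tiles; with M = 8 it reports none, since the split is then divisible.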

Proposed approach

We currently have 1D->1D and 2D->2D swizzles represented by the Swizzle and Swizzle2D classes. I propose to add another class SwizzleMerge to represent 2D->1D swizzles that preserve the number of elements. The only swizzle type implemented for this new class would skip the overflowing indices (red in this diagram):

Image

Details

The current Split+Merge+Merge lets us convert a linear index i in the range [0, ceilDiv(M, factor) * N * factor - 1] into a pair of indices f(i), g(i) as follows:

megarow = i / (N*factor);
megarow_pos = i % (N*factor);
f(i) = factor * megarow + megarow_pos % factor;
g(i) = megarow_pos / factor;
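
The mapping above can be sketched as standalone C++ to show where the overflowing indices come from. This is only an illustration under the definitions in this issue: TileCoord, currentSwizzle, and countOverflowingIndices are hypothetical names, not nvFuser code.

```cpp
#include <cassert>
#include <cstdint>

// A (row, col) position in the unswizzled tile grid.
struct TileCoord {
  int64_t row;
  int64_t col;
};

// Current mapping: linear index i in [0, ceilDiv(M, factor) * N * factor)
// to the pair (f(i), g(i)). Note that M does not appear here, which is why
// the row can land past the end of the real grid.
TileCoord currentSwizzle(int64_t i, int64_t N, int64_t factor) {
  assert(i >= 0 && N > 0 && factor > 0);
  int64_t megarow = i / (N * factor);
  int64_t megarow_pos = i % (N * factor);
  return {factor * megarow + megarow_pos % factor, megarow_pos / factor};
}

// Counts linear indices whose row falls outside the grid (f(i) >= M).
int64_t countOverflowingIndices(int64_t M, int64_t N, int64_t factor) {
  int64_t total = ((M + factor - 1) / factor) * N * factor;
  int64_t overflow = 0;
  for (int64_t i = 0; i < total; ++i) {
    if (currentSwizzle(i, N, factor).row >= M) {
      ++overflow;
    }
  }
  return overflow;
}
```

For M = 5, N = 3, factor = 4, the second megarow covers rows 4 through 7, but only row 4 exists, so 9 of the 24 linear indices map to rows that are out of bounds.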

Instead, I propose that when indexing SwizzleMerge we map a linear output index i in the range [0, M * N - 1] to a pair of input indices p(i), q(i) like so:

megarow = i / (N*factor);
megarow_pos = i - megarow * (N*factor);
megarow_height = (megarow == M / factor) ? (M - (M / factor) * factor) : factor;
p(i) = factor * megarow + megarow_pos % megarow_height;
q(i) = megarow_pos / megarow_height;

Notice that this has a naturally recursive structure: we begin with a linear index and produce as outputs an offset, new sizes for the megarow, and a new linear index within the megarow. This will let us easily generalize the swizzle to accept a list of factors.
