Allow separate sub-DAG for load and compute warp groups with warp-specialized circular buffering. #3941

Open
@rdspring1

Description

Why? Support persistent and ping-pong matmul kernels.

Context:
The circular-buffer for-loop is the first serial for-loop to the left of the computeAt position, and circular buffering is applied to the load cacheAfter TensorViews. Persistent matmul scheduling creates multiple serial for-loops: a grid-stride for-loop over output tiles and a cta-k for-loop.

Scheduling Proposal:

  • Merge the output-tile and cta-k iterDomains.
  • Apply circular buffering to the load cacheAfter TensorView.
  • Keep separate output-tile and cta-k iterDomains on the wgmma consumer.

Why? The load cacheAfter TensorViews then have a single serial iterDomain for circular buffering, which matches the current circular buffering implementation.
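
To make the merged iterDomain concrete, here is a minimal host-side sketch in Python. It is not nvFuser code: `merged_load_schedule`, `num_tiles`, `num_k`, and `num_stages` are hypothetical names for illustration, and the model assumes each CTA simply walks its output tiles in order. It shows the load warp group iterating one serial loop of extent output-tile * cta-k and picking the circular-buffer stage from the merged index.

```python
def merged_load_schedule(num_tiles, num_k, num_stages):
    """Walk the merged (output-tile * cta-k) loop as one serial loop.

    Returns (tile, k, stage) for every iteration of the load warp group.
    All names here are illustrative, not nvFuser identifiers.
    """
    schedule = []
    for i in range(num_tiles * num_k):   # the single serial iterDomain
        tile, k = divmod(i, num_k)       # undo the merge: outer = tile, inner = k
        stage = i % num_stages           # circular-buffer slot for this load
        schedule.append((tile, k, stage))
    return schedule


schedule = merged_load_schedule(num_tiles=2, num_k=3, num_stages=4)
for tile, k, stage in schedule:
    print(f"tile={tile} k={k} -> stage {stage}")
```

In this model the stage index keeps advancing across the output-tile boundary, so the loads for the next tile can be issued while the previous tile is still being consumed.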

Problem: The output-tile and cta-k iterDomains cannot be merged for the compute warp groups, because the store of the matmul results to global memory has no cta-k iterDomain. Therefore, the wgmma consumer cannot be inlined with the cacheAfter TMA load, which breaks the current circular buffering implementation.

Does the consumer of circular-buffered inputs need to be inlined?
This restriction seems unnecessary for warp-specialized circular buffering.

Lowering Proposal: For warp-specialized circular buffering, track separate for-loops for the load and compute warp groups.
Restriction: The compute for-loop must be derived from the load for-loop.
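
One way to read the restriction is sketched below in Python. This is a model built on assumptions, not the proposed lowering: the function names are hypothetical, and the phase bit assumes a stage's mbarrier parity flips each time the stage is reused. The compute loop nest recovers the load loop's linear index as `tile * num_k + k`, so both warp groups agree on stage and phase. A compute loop that indexes stages by `k` alone drifts out of step whenever cta-k is not a multiple of the stage count.

```python
def load_stages(num_tiles, num_k, num_stages):
    # Load warp group: one merged serial loop over output-tile * cta-k.
    return [(i % num_stages, (i // num_stages) % 2)
            for i in range(num_tiles * num_k)]


def compute_stages(num_tiles, num_k, num_stages):
    # Compute warp group: separate output-tile and cta-k loops whose
    # linear index is derived from the load loop.
    out = []
    for tile in range(num_tiles):
        for k in range(num_k):
            i = tile * num_k + k          # same index the load loop used
            out.append((i % num_stages, (i // num_stages) % 2))
    return out


def naive_compute_stages(num_tiles, num_k, num_stages):
    # Not derived from the load loop: the stage restarts at every output tile.
    return [(k % num_stages, (k // num_stages) % 2)
            for _ in range(num_tiles) for k in range(num_k)]


derived_ok = compute_stages(2, 3, 4) == load_stages(2, 3, 4)
naive_ok = naive_compute_stages(2, 3, 4) == load_stages(2, 3, 4)
print(derived_ok, naive_ok)
```

With 2 output tiles, 3 cta-k iterations, and 4 stages, the derived loop matches the load sequence and the naive one does not. In the real kernel that mismatch would mean waiting on a stage the load warp has not filled.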

Pseudo-code:

mbarrier init
if (tma-load) {
  decrease_register_limit(40);
  for (output-tile) {
    for (cta-k) {
      if (elect-sync) {
        mbarrier wait for empty stage
        mbarrier arriveExpectTx for tma load
        tma load operand A and B cta tiles for stage
      }
    }
  }
} else { // compute warp-group
  increase_register_limit(232);
  mbarrier arrive to signal all stages are empty
  for (output-tile) {
    for (cta-k) {
      mbarrier wait for full stage
      for (warp-k) {
        wgmma_fence;
        wgmma_64m_256n_16k;
      }
      wgmma_commit;
      wgmma_wait;
      mbarrier arrive to signal current stage is empty
    }
    wgmma_wait;
    convert fp32 results to bf16
    stmatrix from registers to shared memory
    block_sync();
    tma store from shared to global memory
    tma_store_commit;
    tma_store_wait;
  }
}
destroy mbarrier
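
The full/empty handshake in the pseudo-code can be exercised on the host with a small Python model. It is only a sketch under stated assumptions: two threads stand in for the load and compute warp groups, one pair of `threading.Semaphore` objects per stage stands in for the mbarriers, and `run_pipeline` and its parameters are hypothetical names. Register limits, wgmma, and the TMA store epilogue are not modeled.

```python
import threading


def run_pipeline(num_tiles, num_k, num_stages):
    # One "empty" and one "full" semaphore per stage stand in for the mbarriers.
    empty = [threading.Semaphore(0) for _ in range(num_stages)]
    full = [threading.Semaphore(0) for _ in range(num_stages)]
    buffer = [None] * num_stages
    results = []

    def load_warp():
        for i in range(num_tiles * num_k):      # merged output-tile * cta-k loop
            stage = i % num_stages
            empty[stage].acquire()              # wait for empty stage
            buffer[stage] = divmod(i, num_k)    # stand-in for the tma load
            full[stage].release()               # stage is now full

    def compute_warp_group():
        for stage in range(num_stages):
            empty[stage].release()              # signal all stages are empty
        for tile in range(num_tiles):
            acc = []
            for k in range(num_k):
                stage = (tile * num_k + k) % num_stages  # derived from load loop
                full[stage].acquire()           # wait for full stage
                acc.append(buffer[stage])       # stand-in for wgmma
                empty[stage].release()          # signal current stage is empty
            results.append((tile, acc))         # stand-in for the store epilogue

    threads = [threading.Thread(target=load_warp),
               threading.Thread(target=compute_warp_group)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


print(run_pipeline(num_tiles=2, num_k=3, num_stages=2))
```

Each output tile should receive exactly its own cta-k operands in order, regardless of how the two threads interleave, because a stage is only overwritten after the compute side has released it.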