Description
Summary
This PR, pytorch/pytorch#108070, updates the CUTLASS pin in PyTorch from 3.1 to 3.2.1. The FlashAttention tests are currently failing with this update.
The update causes the FlashAttention kernel to hit an illegal memory access (IMA).
A minimal repro for this is:
import torch
from torch.nn.functional import scaled_dot_product_attention
# Changing seq_len from 129 to 128 causes the IMA to stop, likely because we then do only 1 iteration
query = torch.randn(1, 1, 129, 8, device="cuda", dtype=torch.bfloat16) # (batch, num_heads, seq_len, embed_dim)
key = torch.randn(1, 1, 129, 8, device="cuda", dtype=torch.bfloat16) # (batch, num_heads, seq_len, embed_dim)
value = torch.randn(1, 1, 129, 8, device="cuda", dtype=torch.bfloat16) # (batch, num_heads, seq_len, embed_dim)
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    scaled_dot_product_attention(query, key, value)
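As a side note on the seq_len comment above, here is a small sketch (my addition, not part of the original repro) that runs the same call at seq_len 128 and 129 and forces a synchronize. Without compute-sanitizer the out-of-bounds reads do not necessarily raise an error, so a clean run here is not conclusive; the memcheck run below remains the authoritative check.
import torch
from torch.nn.functional import scaled_dot_product_attention

for seq_len in (128, 129):
    # Same shapes as the repro: (batch, num_heads, seq_len, embed_dim)
    q, k, v = [torch.randn(1, 1, seq_len, 8, device="cuda", dtype=torch.bfloat16) for _ in range(3)]
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
        out = scaled_dot_product_attention(q, k, v)
    torch.cuda.synchronize()  # surface any asynchronous CUDA error at this point
    print(f"seq_len={seq_len}: completed, output shape {tuple(out.shape)}")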
Running the minimal repro above (saved as cutlass_repro.py) with compute-sanitizer --tool memcheck python cutlass_repro.py
produces 64 errors of the form:
========= Invalid __global__ read of size 16 bytes
========= at 0x2bb0 in void pytorch_flash::flash_fwd_kernel<pytorch_flash::Flash_fwd_kernel_traits<(int)32, (int)128, (int)128, (int)4, (bool)0, (bool)0, cutlass::bfloat16_t, pytorch_flash::Flash_kernel_traits<(int)32, (int)128, (int)128, (int)4, cutlass::bfloat16_t>>, (bool)0, (bool)0, (bool)0, (bool)0, (bool)0>(pytorch_flash::Flash_fwd_params)
========= by thread (28,0,0) in block (1,0,0)
========= Address 0x7f6f25e00a70 is out of bounds
========= and is 8,336,181,105 bytes after the nearest allocation at 0x7f6d35000100 of size 512 bytes
========= Saved host backtrace up to driver entry point at kernel launch time
========= Host Frame: [0x304fd2]
========= in /lib64/libcuda.so.1
Full Repro Steps
FlashAttention requires an A100 or newer to run (I have validated on both A100 and H100); a quick environment check is sketched after the build script below.
Follow the setup instructions for building PyTorch from source here:
https://github.com/pytorch/pytorch#from-source
Before building from source, check out the above PR.
You can use GitHub's CLI tool (gh) to do this.
Install gh: conda install gh --channel conda-forge
Then check out the PR:
gh pr checkout 108070
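Optionally, as a sanity check that the checkout actually moved the pin, the sketch below (my addition; it assumes the usual third_party/cutlass submodule layout and CUTLASS's include/cutlass/version.h macros, with submodules synced) prints the pinned CUTLASS version:
import re
from pathlib import Path

# Read the CUTLASS version macros from the pinned submodule (hypothetical path, adjust if needed)
header = Path("third_party/cutlass/include/cutlass/version.h").read_text()
parts = [re.search(rf"#define CUTLASS_{name}\s+(\d+)", header).group(1)
         for name in ("MAJOR", "MINOR", "PATCH")]
print("Pinned CUTLASS version:", ".".join(parts))  # expect 3.2.1 on the PR branch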
For much faster builds, you can use these environment variables to turn off the parts of the build that don't matter for this repro:
#!/bin/bash
export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"}
export USE_DISTRIBUTED=1
export BUILD_TEST=0
export BUILD_CAFFE2_OPS=0
export USE_FBGEMM=0
export BUILD_CAFFE2=0
export USE_OPENCV=0
export USE_QNNPACK=0
export USE_XNNPACK=0
export DEBUG=0
export USE_KINETO=1
export USE_CUDA=1
export TORCH_SHOW_CPP_STACKTRACES=1
export USE_GOLD_LINKER=1
export USE_NCCL=0
export WERROR=1
# export TORCH_CUDA_ARCH_LIST=9.0
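After the build finishes, a quick check like the sketch below (my addition) confirms the environment can actually take the FlashAttention path, since flash requires sm80 or newer:
import torch

print("torch:", torch.__version__)
print("device:", torch.cuda.get_device_name(0))
print("compute capability:", torch.cuda.get_device_capability(0))  # needs (8, 0) or higher
print("flash SDP enabled:", torch.backends.cuda.flash_sdp_enabled())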