Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 62 additions & 26 deletions csrc/flash_attn/src/flash_bwd_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ make_tiled_copy_C_warpcontiguousN(Copy_Atom<Args...> const& copy_atom,

template <int THREADS_PER_ROW, typename Engine0, typename Layout0,
typename Engine1, typename Layout1, typename Engine2, typename Layout2>
inline __device__ void dot_do_o(Tensor<Engine0, Layout0> const &do_, Tensor<Engine0, Layout0> const &o,
__forceinline__ __device__ void dot_do_o(Tensor<Engine0, Layout0> const &do_, Tensor<Engine0, Layout0> const &o,
Tensor<Engine1, Layout1> &dP_sum, Tensor<Engine2, Layout2> &sdPsum,
const int gdP_col_stride, const float scale) {
static_assert(Layout0::rank == 3, "Only support 3D Tensor");
Expand Down Expand Up @@ -425,7 +425,7 @@ inline __device__ void convert_dKV(const Params &params) {
////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_first, bool Is_last, bool Is_attn_mask, bool Seq_parallel=false, typename Params>
inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {
__forceinline__ __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {

const bool Is_sparse_attn_mask = params.flashmask_downstart_ptr != nullptr;
int flashmask_startrow = 0;
Expand Down Expand Up @@ -488,9 +488,32 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
const bool flashmask_has_end = params.flashmask_downend_ptr != nullptr;
int flashmask_upendrow = params.seqlen_q;

#define SPARSE_MASKED_DOWN \
(((m_block * kBlockM) >= flashmask_downstartmax) && (!flashmask_has_end || (m_block + 1) * kBlockM < flashmask_downendmin))

#define SPARSE_MASKED_UP \
(!Is_causal && (m_block + 1) * kBlockM < flashmask_upendmin && (!flashmask_has_end || m_block * kBlockM >= flashmask_upstartmax))

#define SPARSE_MASKED \
(SPARSE_MASKED_DOWN || SPARSE_MASKED_UP)

const bool enable_mask_bypass = params.enable_mask_bypass;

if (Is_sparse_attn_mask && enable_mask_bypass) {
int flashmask_downstartmax = std::numeric_limits<int>::max();
int flashmask_downendmin = 0;
int flashmask_upendmin = 0;
int flashmask_upstartmax = std::numeric_limits<int>::max();

if(params.flashmask_downstart_nblockmax != nullptr)
flashmask_downstartmax = gSparseMaskDownMax[n_block];
if(params.flashmask_downend_nblockmin != nullptr)
flashmask_downendmin = gSparseMaskDownEndMin[n_block];
if(params.flashmask_upend_nblockmin != nullptr)
flashmask_upendmin = gSparseMaskUpMin[n_block];
if(params.flashmask_upstart_nblockmax != nullptr)
flashmask_upstartmax = gSparseMaskUpStartMax[n_block];

if (Is_sparse_attn_mask && enable_mask_bypass && !flashmask_has_end) {
m_block_max = min(m_block_max,
cute::ceil_div(gSparseMaskDownMax[n_block], kBlockM));
/*
Expand Down Expand Up @@ -744,7 +767,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in

int m_block = m_block_max - 1;
int m_block_min = !Is_causal ? 0 : (n_block * kBlockN) / kBlockM;
if(Is_sparse_attn_mask && enable_mask_bypass){
if(Is_sparse_attn_mask && enable_mask_bypass && !flashmask_has_end){
if (!Is_causal) {
m_block_min = max(m_block_min, gSparseMaskUpMin[n_block] / kBlockM);
}
Expand Down Expand Up @@ -922,8 +945,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
// cute::copy(smem_tiled_copy_KV, tSsK(_, _, k), tSrK_copy_view(_, _, k));
// }
// if (cute::thread0()) { print(tSrK); }
flash::gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma_sdp,
smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV);

if (!SPARSE_MASKED) {
flash::gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma_sdp,
smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV);
}

// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (row=(2, MMA_M), col=(2, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
Expand Down Expand Up @@ -1005,7 +1031,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
}
// if (cute::thread(32, 0)) { print(scores); }
// Compute the exponential value.
flash::scale_apply_exp2</*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
if (!SPARSE_MASKED) {
flash::scale_apply_exp2</*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
}
if (Is_dropout) {
uint32_t warp_id = tidx / 32;
uint32_t block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS;
Expand Down Expand Up @@ -1048,21 +1076,23 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in

// if (cute::thread0()) { print(dP_sum); }

flash::gemm</*A_in_regs=*/false, /*B_in_regs=*/Kernel_traits::Is_V_in_regs>(
acc_dp, tdPrdO, tdPrV, tdPsdO, tdPsV, tiled_mma_sdp,
smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV
);

// Reshape acc_dp from (MMA=4, MMA_M, MMA_N) to (row=(2, MMA_M), col=(2, MMA_N))
Tensor dS = make_tensor(acc_dp.data(), scores.layout());
auto pointwise_mult = [](float p, float dp, float d) {
return p * (!Is_dropout || p >= 0 ? dp - d : d);
};
#pragma unroll
for (int mi = 0; mi < size<0>(dS); ++mi) {
if (!SPARSE_MASKED) {
flash::gemm</*A_in_regs=*/false, /*B_in_regs=*/Kernel_traits::Is_V_in_regs>(
acc_dp, tdPrdO, tdPrV, tdPsdO, tdPsV, tiled_mma_sdp,
smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV
);

// Reshape acc_dp from (MMA=4, MMA_M, MMA_N) to (row=(2, MMA_M), col=(2, MMA_N))
auto pointwise_mult = [](float p, float dp, float d) {
return p * (!Is_dropout || p >= 0 ? dp - d : d);
};
#pragma unroll
for (int ni = 0; ni < size<1>(dS); ++ni) {
dS(mi, ni) = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi));
for (int mi = 0; mi < size<0>(dS); ++mi) {
#pragma unroll
for (int ni = 0; ni < size<1>(dS); ++ni) {
dS(mi, ni) = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi));
}
}
}
// if (cute::thread0()) { print(dS); }
Expand Down Expand Up @@ -1104,8 +1134,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
// flash::gemm_A_in_regs(acc_dv, tdVrPt, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_thr_copy_QdOt);
// Tensor tdKrdSt = make_tensor(tdSrdS.data(), tdVrPt.layout());
// flash::gemm_A_in_regs(acc_dk, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_thr_copy_QdOt);
flash::gemm(acc_dv, tdVrPt, tdVrdO, tdVsPt, tdVsdOt, tiled_mma_dkv,
smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt);
if (!SPARSE_MASKED) {
flash::gemm(acc_dv, tdVrPt, tdVrdO, tdVsPt, tdVsdOt, tiled_mma_dkv,
smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt);
}
// if (cute::thread0() && n_block == 0 && m_block == 0) { print(tdVrPt); }
// if (cute::thread0()) { print(acc_dv); }

Expand All @@ -1124,8 +1156,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
}
}

flash::gemm(acc_dq, tdQrdS, tdQrKt, tdQsdS, tdQsKt, tiled_mma_dq,
smem_tiled_copy_dS, smem_tiled_copy_Kt, smem_thr_copy_dS, smem_thr_copy_Kt);
if (!SPARSE_MASKED) {
flash::gemm(acc_dq, tdQrdS, tdQrKt, tdQsdS, tdQsKt, tiled_mma_dq,
smem_tiled_copy_dS, smem_tiled_copy_Kt, smem_thr_copy_dS, smem_thr_copy_Kt);
}
// if (cute::thread0()) { print(acc_dq); }

if (m_block > m_block_min) {
Expand Down Expand Up @@ -1163,8 +1197,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
}

flash::gemm(acc_dk, tdKrdSt, tdKrQt, tdKsdSt, tdKsQt, tiled_mma_dkv,
smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt);
if (!SPARSE_MASKED) {
flash::gemm(acc_dk, tdKrdSt, tdKrQt, tdKsdSt, tdKsQt, tiled_mma_dkv,
smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt);
}
// if (cute::thread0()) { print(acc_dk); }
if (Double_buffer) { // Double buffer for sQ
tdKsQt.data() = tdKsQt.data() + (m_block % 2 == 0 ? size(sQ) : -size(sQ));
Expand Down
4 changes: 0 additions & 4 deletions csrc/flash_attn/src/flash_bwd_launch_template.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,6 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
const bool is_attn_mask = params.attn_mask_ptr != nullptr;
const bool is_deterministic = params.num_splits == 1;
// printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
if (params.flashmask_downend_ptr != nullptr) {
// bypass is not supported for flashmask_downend
params.enable_mask_bypass = false;
}
prepare_sparsemask<Kernel_traits>(params, stream);
BOOL_SWITCH(params.is_causal, IsCausalConst, [&] {
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
Expand Down