Add Paddle linear_attn GDN/KDA operators and benchmark support #133

Open
YJMSTR wants to merge 4 commits into PaddlePaddle:main from YJMSTR:gdn_and_kda

Conversation

@YJMSTR YJMSTR commented Apr 17, 2026

  • migrate the GDN and KDA linear attention operator stack into flashmask/linear_attn from flash-linear-attention
  • add shared Triton/Paddle utilities, l2norm support, common chunk kernels, and fused recurrent paths required by both operators
  • register paddle benchmark ops and runner entrypoints for chunk/recurrent GDN and KDA performance comparison
  • add Paddle-side GDN/KDA correctness tests covering forward/backward parity, varlen paths, gate-in-kernel paths, transpose-state handling, and inference-oriented state outputs
  • keep the test suite focused on operator correctness rather than internal implementation details

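For readers unfamiliar with the operators being migrated, the fused recurrent GDN path implements a gated delta rule recurrence. Below is a minimal NumPy reference sketch; the shapes, gate convention (log-decay `g`, write strength `beta`), and function name are illustrative only and do not mirror the kernel's exact API.

```python
import numpy as np

def fused_recurrent_gdn_ref(q, k, v, g, beta):
    """Naive reference for a gated delta rule recurrence.

    q, k: (T, Dk); v: (T, Dv); g: (T,) log-decay gates; beta: (T,) write
    strengths. Conventions are illustrative, not the kernel's exact semantics.
    """
    T, Dk = q.shape
    Dv = v.shape[1]
    S = np.zeros((Dk, Dv))           # running key -> value state
    o = np.empty((T, Dv))
    for t in range(T):
        S = np.exp(g[t]) * S         # gated decay of the state
        # delta-rule update: erase the state's current prediction for k_t,
        # then write v_t along the k_t direction, scaled by beta_t
        S = S + beta[t] * np.outer(k[t], v[t] - k[t] @ S)
        o[t] = q[t] @ S              # read out with the query
    return o
```

The chunked kernels in this PR compute the same recurrence blockwise for throughput; a loop like this is only useful as a correctness oracle, which is roughly the role the added parity tests play.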

CLAassistant commented Apr 17, 2026

CLA assistant check
All committers have signed the CLA.

@YJMSTR YJMSTR marked this pull request as ready for review April 20, 2026 07:45

YJMSTR commented Apr 20, 2026

Benchmark on H800, pure Paddle env, FLA_BENCHMARK=1, warmup = 100 ms, rep = 500 ms.
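The warmup/rep budget and the median/p20/p80 columns can be sketched in plain Python as follows. This is a hypothetical stand-in for the benchmark runner (the real one times GPU kernels with proper synchronization, not Python calls); the function name and signature are illustrative.

```python
import time
import numpy as np

def bench(fn, warmup_ms=100.0, rep_ms=500.0):
    """Run fn under a warmup budget, then measure for rep_ms and
    report (median, p20, p80) latencies in milliseconds."""
    deadline = time.perf_counter() + warmup_ms / 1e3
    while time.perf_counter() < deadline:        # warm up for ~warmup_ms
        fn()
    times = []
    deadline = time.perf_counter() + rep_ms / 1e3
    while time.perf_counter() < deadline:        # measure for ~rep_ms
        t0 = time.perf_counter()
        fn()
        times.append((time.perf_counter() - t0) * 1e3)
    p50, p20, p80 = np.quantile(times, [0.5, 0.2, 0.8])
    return p50, p20, p80
```

Reporting p20/p80 alongside the median, as the tables below do, makes run-to-run jitter visible without hiding it in a single mean.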

============================================================================================
  Machine: NVIDIA H800 | Paddle 3.4.0.dev20260407
============================================================================================
  op                 mode       B      T    H    D   median(ms)      p20(ms)      p80(ms)
  ------------------ ------- ---- ------ ---- ---- ------------ ------------ ------------
  chunk_gdn          fwd        1   8192   96  128        2.033        2.030        2.034
  chunk_kda          fwd        1   8192   96  128        3.268        3.266        3.270
  recurrent_gdn      fwd        1   8192   96  128       12.065       12.055       12.073
  recurrent_kda      fwd        1   8192   96  128       21.727       21.706       21.746
  chunk_gdn          fwd        2  16384   16  128        1.342        1.341        1.344
  chunk_kda          fwd        2  16384   16  128        2.299        2.297        2.301
  recurrent_gdn      fwd        2  16384   16  128       22.165       22.160       22.175
  recurrent_kda      fwd        2  16384   16  128       39.508       39.425       39.530
  chunk_gdn          fwd        4   2048   16  128        0.622        0.604        0.651
  chunk_kda          fwd        4   2048   16  128        0.715        0.700        0.727
  recurrent_gdn      fwd        4   2048   16  128        2.851        2.847        2.855
  recurrent_kda      fwd        4   2048   16  128        5.164        5.147        5.173
  chunk_gdn          fwd        4   4096   64  128        2.589        2.585        2.594
  chunk_kda          fwd        4   4096   64  128        4.165        4.163        4.167
  recurrent_gdn      fwd        4   4096   64  128       12.649       12.620       12.669
  recurrent_kda      fwd        4   4096   64  128       23.368       23.357       23.384
  chunk_gdn          fwd        8   1024    8   64        0.614        0.591        0.648
  chunk_kda          fwd        8   1024    8   64        0.718        0.696        0.750
  recurrent_gdn      fwd        8   1024    8   64        1.256        1.255        1.258
  recurrent_kda      fwd        8   1024    8   64        2.373        2.368        2.377
  chunk_gdn          fwd        8   2048   32  256        2.958        2.955        2.962
  chunk_kda          fwd        8   2048   32  256        5.195        5.190        5.199
  recurrent_gdn      fwd        8   2048   32  256       27.461       27.401       27.667
  recurrent_kda      fwd        8   2048   32  256       45.525       45.483       45.554
  chunk_gdn          fwdbwd     1   8192   96  128        7.265        7.261        7.275
  chunk_kda          fwdbwd     1   8192   96  128       13.648       13.644       13.651
  chunk_gdn          fwdbwd     2  16384   16  128        4.647        4.644        4.650
  chunk_kda          fwdbwd     2  16384   16  128        9.233        9.230        9.238
  chunk_gdn          fwdbwd     4   2048   16  128        1.579        1.543        1.630
  chunk_kda          fwdbwd     4   2048   16  128        2.405        2.403        2.410
  chunk_gdn          fwdbwd     4   4096   64  128        8.881        8.878        8.887
  chunk_kda          fwdbwd     4   4096   64  128       17.626       17.621       17.629
  chunk_gdn          fwdbwd     8   1024    8   64        1.507        1.486        1.557
  chunk_kda          fwdbwd     8   1024    8   64        1.845        1.823        1.871
  chunk_gdn          fwdbwd     8   2048   32  256       13.104       13.081       13.122
  chunk_kda          fwdbwd     8   2048   32  256       20.102       20.086       20.116
============================================================================================
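The gap between the chunk_* and recurrent_* rows reflects the usual chunkwise trade-off: matmuls over blocks of tokens instead of T sequential state updates. As an illustration only (plain ungated causal linear attention, not the GDN/KDA kernels themselves), both schedules compute the same result:

```python
import numpy as np

def causal_linear_attn_recurrent(q, k, v):
    """o_t = sum_{s<=t} (q_t . k_s) v_s, one token at a time."""
    S = np.zeros((q.shape[1], v.shape[1]))
    o = np.empty((q.shape[0], v.shape[1]))
    for t in range(q.shape[0]):
        S += np.outer(k[t], v[t])
        o[t] = q[t] @ S
    return o

def causal_linear_attn_chunked(q, k, v, C=16):
    """Same result, C tokens at a time: one masked matmul pair per chunk
    (intra-chunk) plus a carried state (inter-chunk)."""
    T, Dk = q.shape
    S = np.zeros((Dk, v.shape[1]))
    o = np.empty((T, v.shape[1]))
    for s in range(0, T, C):
        Q, K, V = q[s:s+C], k[s:s+C], v[s:s+C]
        mask = np.tril(np.ones((len(Q), len(Q))))  # causal, inclusive
        o[s:s+C] = Q @ S + (mask * (Q @ K.T)) @ V  # inter- + intra-chunk
        S += K.T @ V                               # carry state forward
    return o
```

The real chunk kernels additionally fold the decay gates and delta-rule corrections into the block matmuls, which is where the extra complexity of the common chunk utilities in this PR comes from.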

Paddle + Torch env:

============================================================================================
  Machine: NVIDIA H800 | Paddle 3.4.0.dev20260407
============================================================================================
  op                 mode       B      T    H    D   median(ms)      p20(ms)      p80(ms)
  ------------------ ------- ---- ------ ---- ---- ------------ ------------ ------------
  chunk_gdn          fwd        1   8192   96  128        2.027        2.024        2.031
  chunk_kda          fwd        1   8192   96  128        3.264        3.261        3.267
  recurrent_gdn      fwd        1   8192   96  128       12.611       12.604       12.621
  recurrent_kda      fwd        1   8192   96  128       19.857       19.834       19.880
  chunk_gdn          fwd        2  16384   16  128        1.328        1.327        1.330
  chunk_kda          fwd        2  16384   16  128        2.298        2.296        2.299
  recurrent_gdn      fwd        2  16384   16  128       23.316       23.257       23.328
  recurrent_kda      fwd        2  16384   16  128       36.378       36.333       36.404
  chunk_gdn          fwd        4   2048   16  128        0.712        0.700        0.732
  chunk_kda          fwd        4   2048   16  128        0.851        0.836        0.867
  recurrent_gdn      fwd        4   2048   16  128        3.005        2.998        3.009
  recurrent_kda      fwd        4   2048   16  128        4.712        4.699        4.733
  chunk_gdn          fwd        4   4096   64  128        2.577        2.574        2.580
  chunk_kda          fwd        4   4096   64  128        4.193        4.191        4.196
  recurrent_gdn      fwd        4   4096   64  128       13.118       13.078       13.160
  recurrent_kda      fwd        4   4096   64  128       22.053       22.032       22.077
  chunk_gdn          fwd        8   1024    8   64        0.701        0.688        0.721
  chunk_kda          fwd        8   1024    8   64        0.830        0.816        0.857
  recurrent_gdn      fwd        8   1024    8   64        1.262        1.260        1.265
  recurrent_kda      fwd        8   1024    8   64        2.318        2.315        2.321
  chunk_gdn          fwd        8   2048   32  256        2.950        2.947        2.954
  chunk_kda          fwd        8   2048   32  256        5.214        5.211        5.219
  recurrent_gdn      fwd        8   2048   32  256       28.753       28.651       28.816
  recurrent_kda      fwd        8   2048   32  256       39.391       39.372       39.419
  chunk_gdn          fwdbwd     1   8192   96  128        7.055        7.046        7.071
  chunk_kda          fwdbwd     1   8192   96  128       13.721       13.716       13.723
  chunk_gdn          fwdbwd     2  16384   16  128        4.654        4.651        4.658
  chunk_kda          fwdbwd     2  16384   16  128        9.398        9.394        9.401
  chunk_gdn          fwdbwd     4   2048   16  128        1.764        1.748        1.779
  chunk_kda          fwdbwd     4   2048   16  128        2.451        2.449        2.453
  chunk_gdn          fwdbwd     4   4096   64  128        8.844        8.841        8.847
  chunk_kda          fwdbwd     4   4096   64  128       17.906       17.903       17.914
  chunk_gdn          fwdbwd     8   1024    8   64        1.733        1.714        1.752
  chunk_kda          fwdbwd     8   1024    8   64        2.112        2.096        2.135
  chunk_gdn          fwdbwd     8   2048   32  256       13.160       13.144       13.173
  chunk_kda          fwdbwd     8   2048   32  256       20.331       20.320       20.337
============================================================================================
