Skip to content

[FIX Blackwell release3.3] layer_norm backward#78777

Closed
zhengshengning wants to merge 3 commits intoPaddlePaddle:release/3.3from
zhengshengning:fix_layer_norm_backward_33
Closed

[FIX Blackwell release3.3] layer_norm backward#78777
zhengshengning wants to merge 3 commits intoPaddlePaddle:release/3.3from
zhengshengning:fix_layer_norm_backward_33

Conversation

@zhengshengning
Copy link
Copy Markdown
Contributor

@zhengshengning zhengshengning commented Apr 23, 2026

PR Category

Operator Mechanism

PR Types

Bug fixes

Description

devPR:#78782

根因:GammaBetaBackwardCUDAKernelTemplate 使用 block_dim_x=32, block_dim_y=32 =1024线程/block。在Blackwell (SM 100)架构上,编译器为每个线程生成了更多的寄存器,导致 1024 线程 × 寄存器数/线程 超出了SM的寄存器文件容量。Kernel完全没有执行("too many resources requested for launch"),输出buffer保留了未初始化的脏数据(Inf/NaN)。H卡(Hopper SM 90)上同样的kernel没有问题,因为寄存器使用量更低。

修复:给GammaBetaBackwardCUDAKernelTemplate加上__launch_bounds__(block_dim_x* block_dim_y)

是否引起精度变化

@paddle-bot
Copy link
Copy Markdown

paddle-bot Bot commented Apr 23, 2026

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

Copy link
Copy Markdown
Contributor

@wanghuancoder wanghuancoder left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants