
[Enhancement] RNG: add deterministic launch config for cross-device consistency #78767

Merged
DongBaiYue merged 2 commits into PaddlePaddle:develop from DongBaiYue:feature/deterministic-rng-launch-config
Apr 24, 2026

Conversation

Contributor

@DongBaiYue DongBaiYue commented Apr 23, 2026

PR Category

Operator Mechanism

PR Types

New features

Description

Adds a FLAGS_deterministic_rng flag: when enabled, the same seed produces the same random sequence on different GPU models. It is off by default, so existing behavior is unchanged.

Problem

python3 -c "import paddle; paddle.seed(42); print(paddle.rand([1000000])[-8:].numpy())"

| Device | Output |
| --- | --- |
| H800 | [7.23e-01, 8.73e-01, 3.03e-01, 9.44e-01, 4.01e-01, 1.76e-01, 3.90e-04, 4.10e-01] |
| A100 | [4.15e-01, 1.43e-01, 4.31e-01, 8.18e-01, 6.12e-01, 3.92e-01, 6.85e-02, 5.67e-01] |
| Hygon K100-AI | [7.44e-01, 3.75e-01, 8.44e-01, 1.08e-01, 3.15e-01, 4.55e-01, 9.99e-01, 3.71e-01] |

With the same seed=42, the three devices produce completely different results.

Root cause

The RNG kernels use the Philox4_32_10 PRNG, initialized per thread via curand_init(seed, subsequence=thread_id, offset, &state). grid_size is computed from device properties:

grid_size = maxThreadsPerMultiProcessor * multiProcessorCount / block_size

| Device | SM/CU | maxThreadsPerMultiProcessor | grid_size (block=256) |
| --- | --- | --- | --- |
| A100 | 108 | 2048 | 864 |
| H800 | 132 | 2048 | 1056 |
| Hygon K100-AI | 120 | 2560 | 1200 |

  1. Different grid_size → different total thread count → the same element is processed by threads with different subsequences (see the sketch below)
  2. Different increment → the Generator offset diverges → every subsequent RNG call differs as well
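
To make the failure mode concrete, here is a minimal CUDA sketch of the pattern described above. It is illustrative only, not Paddle's actual code; the kernel and launch function names are hypothetical stand-ins:

```cuda
#include <cuda_runtime.h>
#include <curand_kernel.h>

// Hypothetical kernel mirroring the problematic pattern: each thread seeds
// Philox with its thread id as the subsequence, then walks the output in a
// grid-stride loop. Because the stride (total thread count) is device
// dependent, element i lands on a different (subsequence, step) pair on
// each GPU model, so the same seed yields different values.
__global__ void uniform_kernel(float* out, size_t n,
                               uint64_t seed, uint64_t offset) {
  size_t tid = blockIdx.x * (size_t)blockDim.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  curand_init(seed, /*subsequence=*/tid, offset, &state);
  for (size_t i = tid; i < n; i += (size_t)gridDim.x * blockDim.x) {
    out[i] = curand_uniform(&state);
  }
}

void launch_uniform(float* out, size_t n, uint64_t seed, uint64_t offset,
                    const cudaDeviceProp& prop) {
  const int block_size = 256;
  // Device dependent: 864 on A100, 1056 on H800, 1200 on K100-AI.
  const int grid_size =
      prop.maxThreadsPerMultiProcessor * prop.multiProcessorCount / block_size;
  uniform_kernel<<<grid_size, block_size>>>(out, n, seed, offset);
}
```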

Solution

Flag interface

  • FLAGS_deterministic_rng (bool, default false): master switch
  • FLAGS_deterministic_rng_grid (int32, default 1024): grid size cap; cross-device consistency requires all devices to use the same value

Idea: with the flag enabled, the RNG kernels use a fixed launch configuration (block_size=256, grid_size=min(numel/256, FLAGS_deterministic_rng_grid)), making the thread-to-element mapping device independent without touching any kernel code.
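
A minimal sketch of what the fixed configuration could look like. The flag and helper names come from this PR's description and commit message, but the body below (including the ceil-division rounding) is an assumption for illustration, not the exact contents of rng_launch_config.h:

```cuda
#include <algorithm>
#include <cstdint>

// Stand-ins for the real gflags definitions in paddle/common/flags.cc.
extern bool FLAGS_deterministic_rng;
extern int32_t FLAGS_deterministic_rng_grid;

struct RNGLaunchConfig {
  int grid_size;
  int block_size;
};

// Device-independent launch config: block_size is fixed at 256 and
// grid_size depends only on numel and the cap (identical on every
// device), never on SM count or maxThreadsPerMultiProcessor.
inline RNGLaunchConfig GetDeterministicRNGConfig(int64_t numel) {
  constexpr int kBlockSize = 256;
  const int64_t blocks_needed = (numel + kBlockSize - 1) / kBlockSize;
  const int grid_size = static_cast<int>(std::max<int64_t>(
      1, std::min<int64_t>(blocks_needed, FLAGS_deterministic_rng_grid)));
  return {grid_size, kBlockSize};
}
```

As long as every device runs with the same FLAGS_deterministic_rng_grid, each element index maps to the same (thread, iteration) pair everywhere, so the Philox subsequences line up.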

Why not StatelessPhilox: PyTorch's legacy path (DistributionTemplates.h) has the same device dependence as Paddle. PyTorch's newer StatelessPhilox (StatelessPhilox4x32.cuh) fixes subsequence=0 and addresses the full 128-bit counter space through the offset, fully decoupling results from the thread mapping, but adopting it would mean rewriting every RNG kernel. This PR changes only the host-side launch logic, achieves the same cross-device consistency, and is far less invasive.
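
For contrast, a hedged sketch of the stateless idea (not PyTorch's actual StatelessPhilox4x32.cuh code): the element index goes into the offset while the subsequence stays 0, so the output is independent of launch geometry, but every RNG kernel would have to be rewritten in this style:

```cuda
#include <curand_kernel.h>

// Illustrative stateless-style kernel, not PyTorch's implementation.
// out[i] depends only on (seed, base_offset, i): the subsequence is
// fixed at 0 and the element index is folded into the offset, so any
// grid/block shape gives identical results. One counter block (four
// 32-bit outputs) is spent per element here for simplicity.
__global__ void stateless_uniform(float* out, size_t n,
                                  uint64_t seed, uint64_t base_offset) {
  for (size_t i = blockIdx.x * (size_t)blockDim.x + threadIdx.x; i < n;
       i += (size_t)gridDim.x * blockDim.x) {
    curandStatePhilox4_32_10_t state;
    curand_init(seed, /*subsequence=*/0, base_offset + i, &state);
    out[i] = curand_uniform(&state);
  }
}
```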

paddle.set_flags({'FLAGS_deterministic_rng': True})
paddle.seed(42)
paddle.rand([1000000])  # identical results on all devices

Coverage

| Entry function | Covered operators | Modified file |
| --- | --- | --- |
| distribution_and_transform | paddle.rand, paddle.uniform, Tensor.uniform_, paddle.randn, paddle.normal, paddle.randint, paddle.exponential_, paddle.nn.functional.rrelu | distribution_helper.h |
| DropoutFwGPUKernelDriver | paddle.nn.functional.dropout | dropout_impl.cu.h |
| GetRandomCudaProp | fused dropout+add | fused_dropout_add_utils.h |

Modified files

| File | Change |
| --- | --- |
| paddle/common/flags.cc | define the two flags |
| paddle/phi/kernels/funcs/rng_launch_config.h | new helper header |
| paddle/phi/kernels/funcs/distribution_helper.h | add an if/else in distribution_and_transform |
| paddle/phi/kernels/funcs/dropout_impl.cu.h | add an if/else in DropoutFwGPUKernelDriver |
| paddle/phi/kernels/fusion/gpu/fused_dropout_add_utils.h | add an if/else in GetRandomCudaProp |

Performance impact

The default FLAGS_deterministic_rng_grid of 1024 is larger than the SM/CU count of every device above (A100=108, H800=132, K100-AI=120), which is enough to saturate memory bandwidth, so the performance impact is expected to be small. The default behavior (flag off) is unaffected.

Accuracy impact

The default behavior (flag off) is unaffected.


CLAassistant commented Apr 23, 2026

CLA assistant check
All committers have signed the CLA.


paddle-bot Bot commented Apr 23, 2026

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

When FLAGS_deterministic_rng is enabled, RNG kernels use a fixed
grid_size and block_size instead of device-dependent values, ensuring
the same seed produces identical random sequences across GPU types.

Two new flags:
- FLAGS_deterministic_rng (bool, default=false): enable the feature
- FLAGS_deterministic_rng_grid (int32, default=1024): grid size cap

Modified files:
- flags.cc: define the two flags
- rng_launch_config.h: new helper (IsDeterministicRNG, GetDeterministicRNGConfig)
- distribution_helper.h: if/else branch in distribution_and_transform
- dropout_impl.cu.h: if/else branch in DropoutFwGPUKernelDriver
- fused_dropout_add_utils.h: if/else branch in GetRandomCudaProp

Default (flag off) behavior is unchanged.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@DongBaiYue DongBaiYue force-pushed the feature/deterministic-rng-launch-config branch from bda210a to d3fb25d on April 23, 2026 12:05
@PaddlePaddle PaddlePaddle deleted a comment from CLAassistant Apr 23, 2026
…tic-rng-launch-config

# Conflicts:
#	paddle/phi/kernels/fusion/gpu/fused_dropout_add_utils.h
@codecov-commenter

Codecov Report

✅ All modified and coverable lines are covered by tests.
⚠️ Please upload report for BASE (develop@d6e489c). Learn more about missing BASE report.

Additional details and impacted files
@@             Coverage Diff             @@
##             develop    #78767   +/-   ##
===========================================
  Coverage           ?   100.00%           
===========================================
  Files              ?         1           
  Lines              ?         2           
  Branches           ?         0           
===========================================
  Hits               ?         2           
  Misses             ?         0           
  Partials           ?         0           


Contributor

@wanghuancoder wanghuancoder left a comment

LGTM

Contributor

@From00 From00 left a comment

LGTM

Contributor

@yongqiangma yongqiangma left a comment

LGTM

@DongBaiYue DongBaiYue merged commit 272eee3 into PaddlePaddle:develop Apr 24, 2026
133 of 140 checks passed
@risemeup1111

❌ Cherry-pick failed: Conflicts detected when cherry-picking to release/3.3. Please resolve manually.

@risemeup1111

❌ Cherry-pick failed: Conflicts detected when cherry-picking to release/3.4. Please resolve manually.
