Skip to content

Commit fa7502c

Browse files
DongBaiYue and claude committed
Add deterministic RNG launch config for cross-device consistency
When FLAGS_deterministic_rng is enabled, RNG kernels use a fixed grid_size and block_size instead of device-dependent values, ensuring the same seed produces identical random sequences across GPU types. Two new flags: - FLAGS_deterministic_rng (bool, default=false): enable the feature - FLAGS_deterministic_rng_grid (int32, default=1024): grid size cap Modified files: - flags.cc: define the two flags - rng_launch_config.h: new helper (IsDeterministicRNG, GetDeterministicRNGConfig) - distribution_helper.h: if/else branch in distribution_and_transform - dropout_impl.cu.h: if/else branch in DropoutFwGPUKernelDriver - fused_dropout_add_utils.h: if/else branch in GetRandomCudaProp Default (flag off) behavior is unchanged. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 2b9f8b6 commit fa7502c

5 files changed

Lines changed: 166 additions & 38 deletions

File tree

paddle/common/flags.cc

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,36 @@ PHI_DEFINE_EXPORTED_bool(
230230
"operator. The autotuning algorithm may be non-deterministic. If "
231231
"true, the algorithm is deterministic.");
232232

233+
/**
234+
* GPU RNG related FLAG
235+
* Name: FLAGS_deterministic_rng
236+
* Since Version: 3.3
237+
* Value Range: bool, default=false
238+
* Example: paddle.set_flags({'FLAGS_deterministic_rng': True})
239+
* Note: Fix RNG kernel launch config so same seed gives same results
240+
* across GPU types.
241+
*/
242+
PHI_DEFINE_EXPORTED_bool(
243+
deterministic_rng,
244+
false,
245+
"Enable cross-device RNG consistency by fixing GPU kernel launch "
246+
"configuration. When true, RNG kernels use a fixed grid/block size "
247+
"so that the same seed produces identical results across GPU types.");
248+
249+
/**
250+
* GPU RNG related FLAG
251+
* Name: FLAGS_deterministic_rng_grid
252+
* Since Version: 3.3
253+
* Value Range: int32, default=1024
254+
* Example: paddle.set_flags({'FLAGS_deterministic_rng_grid': 4096})
255+
* Note: Grid size cap used when FLAGS_deterministic_rng is enabled.
256+
* Cross-device consistency requires the same value on all devices.
257+
*/
258+
PHI_DEFINE_EXPORTED_int32(
259+
deterministic_rng_grid,
260+
1024,
261+
"Grid size cap when FLAGS_deterministic_rng is enabled.");
262+
233263
/**
234264
* CUDA related FLAG
235265
* Name: FLAGS_embedding_deterministic

paddle/phi/kernels/funcs/distribution_helper.h

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ limitations under the License. */
3030

3131
#if defined(__NVCC__) || defined(__HIPCC__)
3232
#include "paddle/phi/kernels/funcs/index_impl.cu.h"
33+
#include "paddle/phi/kernels/funcs/rng_launch_config.h"
3334
#include "paddle/phi/kernels/primitive/kernel_primitives.h"
3435
#endif
3536

@@ -311,22 +312,36 @@ void distribution_and_transform(const GPUContext &dev_ctx,
311312
if (size == 0) return;
312313
auto gen_cuda = dev_ctx.GetGenerator();
313314

314-
size_t block_size = 256;
315-
size_t expect_grid_size = (size + block_size - 1) / block_size;
316-
317-
int64_t device_id = dev_ctx.GetPlace().GetDeviceId();
318-
const auto &prop = phi::backends::gpu::GetDeviceProperties(device_id);
319-
320-
size_t max_grid_size = (prop.maxThreadsPerMultiProcessor / block_size) *
321-
prop.multiProcessorCount;
322-
size_t grid_size =
323-
expect_grid_size > max_grid_size ? max_grid_size : expect_grid_size;
315+
size_t block_size;
316+
size_t grid_size;
317+
uint64_t increment;
318+
319+
if (funcs::IsDeterministicRNG()) {
320+
constexpr int kCount = DistOp::kReturnsCount;
321+
auto cfg = funcs::GetDeterministicRNGConfig(size, kCount);
322+
block_size = cfg.block_size;
323+
grid_size = cfg.grid_size;
324+
increment = cfg.increment;
325+
} else {
326+
block_size = 256;
327+
size_t expect_grid_size = (size + block_size - 1) / block_size;
328+
329+
int64_t device_id = dev_ctx.GetPlace().GetDeviceId();
330+
const auto &prop = phi::backends::gpu::GetDeviceProperties(device_id);
331+
332+
size_t max_grid_size = (prop.maxThreadsPerMultiProcessor / block_size) *
333+
prop.multiProcessorCount;
334+
grid_size =
335+
expect_grid_size > max_grid_size ? max_grid_size : expect_grid_size;
336+
337+
size_t total_thread = block_size * grid_size;
338+
size_t curand4_loop_times =
339+
(size + 4 * total_thread - 1) / (4 * total_thread);
340+
// 'increment' should be multiple of 4
341+
increment = curand4_loop_times * 4;
342+
}
324343

325344
size_t total_thread = block_size * grid_size;
326-
size_t curand4_loop_times =
327-
(size + 4 * total_thread - 1) / (4 * total_thread);
328-
// 'increment' should be multiple of 4
329-
uint64_t increment = curand4_loop_times * 4;
330345

331346
auto seed_offset = gen_cuda->IncrementOffset(increment);
332347
uint64_t seed = seed_offset.first;

paddle/phi/kernels/funcs/dropout_impl.cu.h

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -302,19 +302,32 @@ void DropoutFwGPUKernelDriver(
302302
// VectorizedRandomGenerator use curand_uniform4, so kVecSize is 4;
303303
constexpr int kVecSize =
304304
phi::funcs::uniform_distribution<float>::kReturnsCount;
305-
auto gpu_config =
306-
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, x_numel, kVecSize);
307-
size_t grid_size = gpu_config.GetGridSize();
308-
size_t block_size = gpu_config.GetBlockSize();
309-
310-
int64_t device_id = dev_ctx.GetPlace().GetDeviceId();
311-
const auto& prop = phi::backends::gpu::GetDeviceProperties(device_id);
312-
size_t max_grid_size = prop.maxThreadsPerMultiProcessor *
313-
prop.multiProcessorCount / block_size;
314-
grid_size = std::min(grid_size, max_grid_size);
315-
316-
auto offset =
317-
((x_numel - 1) / (grid_size * block_size * kVecSize) + 1) * kVecSize;
305+
306+
size_t grid_size;
307+
size_t block_size;
308+
size_t offset;
309+
310+
if (phi::funcs::IsDeterministicRNG()) {
311+
auto cfg = phi::funcs::GetDeterministicRNGConfig(x_numel, kVecSize);
312+
grid_size = cfg.grid_size;
313+
block_size = cfg.block_size;
314+
offset = cfg.increment;
315+
} else {
316+
auto gpu_config =
317+
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, x_numel, kVecSize);
318+
grid_size = gpu_config.GetGridSize();
319+
block_size = gpu_config.GetBlockSize();
320+
321+
int64_t device_id = dev_ctx.GetPlace().GetDeviceId();
322+
const auto& prop = phi::backends::gpu::GetDeviceProperties(device_id);
323+
size_t max_grid_size = prop.maxThreadsPerMultiProcessor *
324+
prop.multiProcessorCount / block_size;
325+
grid_size = std::min(grid_size, max_grid_size);
326+
327+
offset =
328+
((x_numel - 1) / (grid_size * block_size * kVecSize) + 1) * kVecSize;
329+
}
330+
318331
size_t main_offset =
319332
size / (block_size * kVecSize) * (block_size * kVecSize);
320333

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include <algorithm>
18+
#include <cstddef>
19+
#include <cstdint>
20+
21+
#include "paddle/common/flags.h"
22+
23+
COMMON_DECLARE_bool(deterministic_rng);
24+
COMMON_DECLARE_int32(deterministic_rng_grid);
25+
26+
namespace phi {
27+
namespace funcs {
28+
29+
// Whether deterministic RNG kernel launch configuration is enabled
// (controlled by FLAGS_deterministic_rng; default is off).
inline bool IsDeterministicRNG() {
  return FLAGS_deterministic_rng;
}
30+
31+
struct RNGLaunchConfig {
32+
size_t grid_size;
33+
size_t block_size;
34+
uint64_t increment;
35+
};
36+
37+
// Cross-device consistency requires the same FLAGS_deterministic_rng_grid.
38+
// vec_size: elements per thread per loop iteration (kReturnsCount).
39+
inline RNGLaunchConfig GetDeterministicRNGConfig(int64_t numel,
40+
int vec_size = 4) {
41+
RNGLaunchConfig config;
42+
constexpr size_t kBlockSize = 256;
43+
size_t grid_cap = static_cast<size_t>(FLAGS_deterministic_rng_grid);
44+
size_t needed = (static_cast<size_t>(numel) + kBlockSize - 1) / kBlockSize;
45+
config.grid_size = std::min(needed, grid_cap);
46+
config.block_size = kBlockSize;
47+
48+
size_t total_thread = config.grid_size * config.block_size;
49+
size_t loop_times =
50+
(static_cast<size_t>(numel) + vec_size * total_thread - 1) /
51+
(vec_size * total_thread);
52+
config.increment = static_cast<uint64_t>(loop_times * vec_size);
53+
54+
return config;
55+
}
56+
57+
} // namespace funcs
58+
} // namespace phi

paddle/phi/kernels/fusion/gpu/fused_dropout_add_utils.h

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#pragma once
1616

1717
#include "paddle/phi/kernels/funcs/distribution_helper.h"
18+
#include "paddle/phi/kernels/funcs/rng_launch_config.h"
1819

1920
namespace phi {
2021
namespace fusion {
@@ -23,17 +24,28 @@ template <typename Context>
2324
static inline std::vector<size_t> GetRandomCudaProp(int64_t numel,
2425
const Context& dev_ctx) {
2526
constexpr int kVecSize = funcs::uniform_distribution<float>::kReturnsCount;
26-
auto gpu_config =
27-
backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel, kVecSize);
28-
size_t grid_size = gpu_config.GetGridSize();
29-
size_t block_size = gpu_config.GetBlockSize();
30-
int64_t device_id = dev_ctx.GetPlace().GetDeviceId();
31-
const auto& prop = phi::backends::gpu::GetDeviceProperties(device_id);
32-
size_t max_grid_size =
33-
prop.maxThreadsPerMultiProcessor * prop.multiProcessorCount / block_size;
34-
grid_size = std::min(grid_size, max_grid_size);
35-
auto offset =
36-
((numel - 1) / (grid_size * block_size * kVecSize) + 1) * kVecSize;
27+
28+
size_t grid_size;
29+
size_t block_size;
30+
size_t offset;
31+
32+
if (funcs::IsDeterministicRNG()) {
33+
auto cfg = funcs::GetDeterministicRNGConfig(numel, kVecSize);
34+
grid_size = cfg.grid_size;
35+
block_size = cfg.block_size;
36+
offset = cfg.increment;
37+
} else {
38+
auto gpu_config =
39+
backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel, kVecSize);
40+
grid_size = gpu_config.GetGridSize();
41+
block_size = gpu_config.GetBlockSize();
42+
int64_t device_id = dev_ctx.GetPlace().GetDeviceId();
43+
const auto& prop = phi::backends::gpu::GetDeviceProperties(device_id);
44+
size_t max_grid_size = prop.maxThreadsPerMultiProcessor *
45+
prop.multiProcessorCount / block_size;
46+
grid_size = std::min(grid_size, max_grid_size);
47+
offset = ((numel - 1) / (grid_size * block_size * kVecSize) + 1) * kVecSize;
48+
}
3749
size_t main_offset =
3850
numel / (block_size * kVecSize) * (block_size * kVecSize);
3951
return {grid_size, block_size, offset, main_offset};

0 commit comments

Comments
 (0)