@@ -30,6 +30,7 @@ limitations under the License. */
3030
3131#if defined(__NVCC__) || defined(__HIPCC__)
3232#include " paddle/phi/kernels/funcs/index_impl.cu.h"
33+ #include " paddle/phi/kernels/funcs/rng_launch_config.h"
3334#include " paddle/phi/kernels/primitive/kernel_primitives.h"
3435#endif
3536
@@ -311,22 +312,36 @@ void distribution_and_transform(const GPUContext &dev_ctx,
311312 if (size == 0 ) return ;
312313 auto gen_cuda = dev_ctx.GetGenerator ();
313314
314- size_t block_size = 256 ;
315- size_t expect_grid_size = (size + block_size - 1 ) / block_size;
316-
317- int64_t device_id = dev_ctx.GetPlace ().GetDeviceId ();
318- const auto &prop = phi::backends::gpu::GetDeviceProperties (device_id);
319-
320- size_t max_grid_size = (prop.maxThreadsPerMultiProcessor / block_size) *
321- prop.multiProcessorCount ;
322- size_t grid_size =
323- expect_grid_size > max_grid_size ? max_grid_size : expect_grid_size;
315+ size_t block_size;
316+ size_t grid_size;
317+ uint64_t increment;
318+
319+ if (funcs::IsDeterministicRNG ()) {
320+ constexpr int kCount = DistOp::kReturnsCount ;
321+ auto cfg = funcs::GetDeterministicRNGConfig (size, kCount );
322+ block_size = cfg.block_size ;
323+ grid_size = cfg.grid_size ;
324+ increment = cfg.increment ;
325+ } else {
326+ block_size = 256 ;
327+ size_t expect_grid_size = (size + block_size - 1 ) / block_size;
328+
329+ int64_t device_id = dev_ctx.GetPlace ().GetDeviceId ();
330+ const auto &prop = phi::backends::gpu::GetDeviceProperties (device_id);
331+
332+ size_t max_grid_size = (prop.maxThreadsPerMultiProcessor / block_size) *
333+ prop.multiProcessorCount ;
334+ grid_size =
335+ expect_grid_size > max_grid_size ? max_grid_size : expect_grid_size;
336+
337+ size_t total_thread = block_size * grid_size;
338+ size_t curand4_loop_times =
339+ (size + 4 * total_thread - 1 ) / (4 * total_thread);
340+ // 'increment' should be multiple of 4
341+ increment = curand4_loop_times * 4 ;
342+ }
324343
325344 size_t total_thread = block_size * grid_size;
326- size_t curand4_loop_times =
327- (size + 4 * total_thread - 1 ) / (4 * total_thread);
328- // 'increment' should be multiple of 4
329- uint64_t increment = curand4_loop_times * 4 ;
330345
331346 auto seed_offset = gen_cuda->IncrementOffset (increment);
332347 uint64_t seed = seed_offset.first ;
0 commit comments