|
21 | 21 | #include "paddle/phi/backends/gpu/gpu_context.h" |
22 | 22 | #include "paddle/phi/backends/gpu/gpu_info.h" |
23 | 23 | #include "paddle/phi/backends/gpu/gpu_launch_config.h" |
24 | | -#include "paddle/phi/core/kernel_registry.h" |
25 | 24 | #include "paddle/phi/common/memory_utils.h" |
| 25 | +#include "paddle/phi/core/kernel_registry.h" |
26 | 26 | #include "paddle/phi/kernels/funcs/blas/blas.h" |
27 | 27 | #include "paddle/phi/kernels/funcs/cub.h" |
28 | 28 | #include "paddle/phi/kernels/funcs/math_function.h" |
@@ -284,15 +284,13 @@ void PerSort(const GPUContext& dev_ctx, |
284 | 284 | bool stable, |
285 | 285 | bool descending) { |
286 | 286 | #ifdef PADDLE_WITH_CUDA |
287 | | - phi::memory_utils::ThrustAllocator<cudaStream_t> allocator( |
288 | | - dev_ctx.GetPlace(), dev_ctx.stream()); |
289 | | - const auto& exec_policy = |
290 | | - thrust::cuda::par(allocator).on(dev_ctx.stream()); |
| 287 | + phi::memory_utils::ThrustAllocator<cudaStream_t> allocator(dev_ctx.GetPlace(), |
| 288 | + dev_ctx.stream()); |
| 289 | + const auto& exec_policy = thrust::cuda::par(allocator).on(dev_ctx.stream()); |
291 | 290 | #else |
292 | | - phi::memory_utils::ThrustAllocator<hipStream_t> allocator( |
293 | | - dev_ctx.GetPlace(), dev_ctx.stream()); |
294 | | - const auto& exec_policy = |
295 | | - thrust::hip::par(allocator).on(dev_ctx.stream()); |
| 291 | + phi::memory_utils::ThrustAllocator<hipStream_t> allocator(dev_ctx.GetPlace(), |
| 292 | + dev_ctx.stream()); |
| 293 | + const auto& exec_policy = thrust::hip::par(allocator).on(dev_ctx.stream()); |
296 | 294 | #endif |
297 | 295 | if (stable) { |
298 | 296 | if (descending) { |
@@ -358,13 +356,11 @@ void ArgsortKernel(const Context& dev_ctx, |
358 | 356 | #ifdef PADDLE_WITH_CUDA |
359 | 357 | phi::memory_utils::ThrustAllocator<cudaStream_t> allocator( |
360 | 358 | dev_ctx.GetPlace(), dev_ctx.stream()); |
361 | | - const auto& exec_policy = |
362 | | - thrust::cuda::par(allocator).on(dev_ctx.stream()); |
| 359 | + const auto& exec_policy = thrust::cuda::par(allocator).on(dev_ctx.stream()); |
363 | 360 | #else |
364 | 361 | phi::memory_utils::ThrustAllocator<hipStream_t> allocator( |
365 | 362 | dev_ctx.GetPlace(), dev_ctx.stream()); |
366 | | - const auto& exec_policy = |
367 | | - thrust::hip::par(allocator).on(dev_ctx.stream()); |
| 363 | + const auto& exec_policy = thrust::hip::par(allocator).on(dev_ctx.stream()); |
368 | 364 | #endif |
369 | 365 | auto cu_stream = dev_ctx.stream(); |
370 | 366 | thrust::sequence(exec_policy, ids_data, ids_data + size); |
|
0 commit comments