Skip to content

Commit 804d958

Browse files
committed
polish
1 parent 023b1ae commit 804d958

1 file changed

Lines changed: 9 additions & 13 deletions

File tree

paddle/phi/kernels/gpu/argsort_kernel.cu

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
#include "paddle/phi/backends/gpu/gpu_context.h"
2222
#include "paddle/phi/backends/gpu/gpu_info.h"
2323
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
24-
#include "paddle/phi/core/kernel_registry.h"
2524
#include "paddle/phi/common/memory_utils.h"
25+
#include "paddle/phi/core/kernel_registry.h"
2626
#include "paddle/phi/kernels/funcs/blas/blas.h"
2727
#include "paddle/phi/kernels/funcs/cub.h"
2828
#include "paddle/phi/kernels/funcs/math_function.h"
@@ -284,15 +284,13 @@ void PerSort(const GPUContext& dev_ctx,
284284
bool stable,
285285
bool descending) {
286286
#ifdef PADDLE_WITH_CUDA
287-
phi::memory_utils::ThrustAllocator<cudaStream_t> allocator(
288-
dev_ctx.GetPlace(), dev_ctx.stream());
289-
const auto& exec_policy =
290-
thrust::cuda::par(allocator).on(dev_ctx.stream());
287+
phi::memory_utils::ThrustAllocator<cudaStream_t> allocator(dev_ctx.GetPlace(),
288+
dev_ctx.stream());
289+
const auto& exec_policy = thrust::cuda::par(allocator).on(dev_ctx.stream());
291290
#else
292-
phi::memory_utils::ThrustAllocator<hipStream_t> allocator(
293-
dev_ctx.GetPlace(), dev_ctx.stream());
294-
const auto& exec_policy =
295-
thrust::hip::par(allocator).on(dev_ctx.stream());
291+
phi::memory_utils::ThrustAllocator<hipStream_t> allocator(dev_ctx.GetPlace(),
292+
dev_ctx.stream());
293+
const auto& exec_policy = thrust::hip::par(allocator).on(dev_ctx.stream());
296294
#endif
297295
if (stable) {
298296
if (descending) {
@@ -358,13 +356,11 @@ void ArgsortKernel(const Context& dev_ctx,
358356
#ifdef PADDLE_WITH_CUDA
359357
phi::memory_utils::ThrustAllocator<cudaStream_t> allocator(
360358
dev_ctx.GetPlace(), dev_ctx.stream());
361-
const auto& exec_policy =
362-
thrust::cuda::par(allocator).on(dev_ctx.stream());
359+
const auto& exec_policy = thrust::cuda::par(allocator).on(dev_ctx.stream());
363360
#else
364361
phi::memory_utils::ThrustAllocator<hipStream_t> allocator(
365362
dev_ctx.GetPlace(), dev_ctx.stream());
366-
const auto& exec_policy =
367-
thrust::hip::par(allocator).on(dev_ctx.stream());
363+
const auto& exec_policy = thrust::hip::par(allocator).on(dev_ctx.stream());
368364
#endif
369365
auto cu_stream = dev_ctx.stream();
370366
thrust::sequence(exec_policy, ids_data, ids_data + size);

0 commit comments

Comments
 (0)