@@ -211,6 +211,11 @@ struct CUBlas<float> {
211211 static void TRSM_BATCH (ARGS... args) {
212212 PADDLE_ENFORCE_GPU_SUCCESS (phi::dynload::cublasStrsmBatched (args...));
213213 }
214+
215+ template <typename ... ARGS>
216+ static void DOT (ARGS... args) {
217+ PADDLE_ENFORCE_GPU_SUCCESS (phi::dynload::cublasSdot_v2 (args...));
218+ }
214219};
215220
216221template <>
@@ -302,6 +307,11 @@ struct CUBlas<double> {
302307 static void TRSM_BATCH (ARGS... args) {
303308 PADDLE_ENFORCE_GPU_SUCCESS (phi::dynload::cublasDtrsmBatched (args...));
304309 }
310+
311+ template <typename ... ARGS>
312+ static void DOT (ARGS... args) {
313+ PADDLE_ENFORCE_GPU_SUCCESS (phi::dynload::cublasDdot_v2 (args...));
314+ }
305315};
306316
307317template <>
@@ -559,6 +569,26 @@ struct CUBlas<phi::float16> {
559569 " cublasGemmEx_64 is not supported on cuda < 12.3" ));
560570#endif
561571 }
572+
573+ static void DOT (cublasHandle_t handle,
574+ int n,
575+ const phi::float16 *x,
576+ const int incx,
577+ const phi::float16 *y,
578+ const int incy,
579+ phi::float16 *result) {
580+ PADDLE_ENFORCE_GPU_SUCCESS (phi::dynload::cublasDotEx (handle,
581+ n,
582+ x,
583+ CUDA_R_16F,
584+ incx,
585+ y,
586+ CUDA_R_16F,
587+ incy,
588+ result,
589+ CUDA_R_16F,
590+ CUDA_R_32F));
591+ }
562592};
563593
564594template <>
@@ -908,6 +938,23 @@ struct CUBlas<phi::complex64> {
908938 info,
909939 batch_size));
910940 }
941+
942+ static void DOT (cublasHandle_t handle,
943+ int n,
944+ const phi::complex64 *x,
945+ const int incx,
946+ const phi::complex64 *y,
947+ const int incy,
948+ phi::complex64 *result) {
949+ PADDLE_ENFORCE_GPU_SUCCESS (phi::dynload::cublasCdotu_v2 (
950+ handle,
951+ n,
952+ reinterpret_cast <const cuFloatComplex *>(x),
953+ incx,
954+ reinterpret_cast <const cuFloatComplex *>(y),
955+ incy,
956+ reinterpret_cast <cuFloatComplex *>(result)));
957+ }
911958};
912959
913960template <>
@@ -1257,6 +1304,23 @@ struct CUBlas<phi::complex128> {
12571304 info,
12581305 batch_size));
12591306 }
1307+
1308+ static void DOT (cublasHandle_t handle,
1309+ int n,
1310+ const phi::complex128 *x,
1311+ const int incx,
1312+ const phi::complex128 *y,
1313+ const int incy,
1314+ phi::complex128 *result) {
1315+ PADDLE_ENFORCE_GPU_SUCCESS (phi::dynload::cublasZdotu_v2 (
1316+ handle,
1317+ n,
1318+ reinterpret_cast <const cuDoubleComplex *>(x),
1319+ incx,
1320+ reinterpret_cast <const cuDoubleComplex *>(y),
1321+ incy,
1322+ reinterpret_cast <cuDoubleComplex *>(result)));
1323+ }
12601324};
12611325
12621326inline void CheckGEMMNSize (int64_t N) {
@@ -2289,6 +2353,38 @@ void Blas<phi::GPUContext>::AXPY(int n, T alpha, const T *x, T *y) const {
22892353 });
22902354}
22912355
// Computes result = dot(x, y) for length-n strided vectors (strides
// incx / incy) on the GPU, dispatching to the type-specific
// CUBlas<T>::DOT wrapper under the context's cuBLAS handle.
// NOTE(review): whether *result is a host or device pointer follows the
// handle's cuBLAS pointer mode — confirm against CublasCall's setup.
template <>
template <typename T>
void Blas<phi::GPUContext>::CUDOT(
    int n, const T *x, int incx, const T *y, int incy, T *result) const {
  auto dot_functor = [&](cublasHandle_t handle) {
    CUBlas<T>::DOT(handle, n, x, incx, y, incy, result);
  };
  dev_ctx_.CublasCall(dot_functor);
}
2364+
// bfloat16 specialization of CUDOT: calls the mixed-precision
// cublasDotEx directly with bf16 (CUDA_R_16BF) storage for x / y /
// result and fp32 (CUDA_R_32F) accumulation.
// NOTE(review): CUDA_R_16BF is a CUDA 11+ data type — if older toolkits
// must build this file, a CUDA_VERSION guard (as used by neighboring
// entry points) may be needed; confirm the minimum supported toolkit.
template <>
template <>
inline void Blas<phi::GPUContext>::CUDOT(int n,
                                         const phi::bfloat16 *x,
                                         int incx,
                                         const phi::bfloat16 *y,
                                         int incy,
                                         phi::bfloat16 *result) const {
  constexpr cudaDataType_t kStorageType = CUDA_R_16BF;  // x / y / result
  constexpr cudaDataType_t kComputeType = CUDA_R_32F;   // accumulator
  dev_ctx_.CublasCall([&](cublasHandle_t handle) {
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDotEx(handle,
                                                         n,
                                                         x,
                                                         kStorageType,
                                                         incx,
                                                         y,
                                                         kStorageType,
                                                         incy,
                                                         result,
                                                         kStorageType,
                                                         kComputeType));
  });
}
2387+
22922388template <>
22932389template <typename T>
22942390void Blas<phi::GPUContext>::SCAL (int n, const T alpha, T *x) const {
0 commit comments