Skip to content

Commit 045ae2d

Browse files
committed
rename sign function in slogdet kernel
1 parent 839a4ed commit 045ae2d

2 files changed

Lines changed: 6 additions & 6 deletions

File tree

paddle/phi/kernels/gpu/slogdeterminant_kernel.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ struct SlogDeterminantV2Functor {
389389
VLOG(2) << "det value: " << matrix.determinant();
390390
VLOG(2) << "matrix val: " << matrix;
391391
auto det_val = matrix.determinant();
392-
sign_vec.push_back(phi::sign(det_val));
392+
sign_vec.push_back(slogdet_sign(det_val));
393393
det_val >= 0
394394
? log_vec.push_back(std::log(det_val))
395395
: log_vec.push_back(std::log(std::abs(
@@ -557,7 +557,7 @@ struct SlogDeterminantV2Functor<phi::dtype::complex<T>, Context> {
557557
std::complex<T> det_val = matrix.determinant();
558558
T abs_det_val = std::abs(det_val);
559559
sign_vec.push_back(static_cast<phi::dtype::complex<T>>(
560-
phi::sign(det_val, static_cast<std::complex<T>>(abs_det_val))));
560+
slogdet_sign(det_val, static_cast<std::complex<T>>(abs_det_val))));
561561
log_vec.push_back(std::log(abs_det_val));
562562
}
563563
TensorFromVector(sign_vec, dev_ctx, sign);

paddle/phi/kernels/impl/slogdeterminant_kernel_impl.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,13 @@ namespace phi {
3030

3131
// T is not complex
3232
template <typename T>
33-
T sign(T val) {
33+
T slogdet_sign(T val) {
3434
return static_cast<T>(T(0) < val) - (val < T(0));
3535
}
3636

3737
// T is complex
3838
template <typename T>
39-
T sign(T det, T modulus) {
39+
T slogdet_sign(T det, T modulus) {
4040
return det / modulus;
4141
}
4242

@@ -209,7 +209,7 @@ struct SlogDeterminantV2Functor {
209209
VLOG(2) << "det value: " << matrix.determinant();
210210
VLOG(2) << "matrix val: " << matrix;
211211
T det_val = matrix.determinant();
212-
sign_data[i] = phi::sign(det_val);
212+
sign_data[i] = slogdet_sign(det_val);
213213
det_val >= 0
214214
? logdet_data[i] = std::log(det_val)
215215
: logdet_data[i] = std::log(std::abs(
@@ -270,7 +270,7 @@ struct SlogDeterminantV2Functor<dtype::complex<T>, Context> {
270270
logdet_data[i] = -std::numeric_limits<T>::infinity();
271271
} else {
272272
sign_data[i] = static_cast<Complex_T>(
273-
phi::sign(det_val, static_cast<std::complex<T>>(abs_det_val)));
273+
slogdet_sign(det_val, static_cast<std::complex<T>>(abs_det_val)));
274274
logdet_data[i] = std::log(abs_det_val);
275275
}
276276
}

0 commit comments

Comments
 (0)