diff --git a/paddle/phi/core/tensor_utils.cc b/paddle/phi/core/tensor_utils.cc
index cdc81f47047d6c..6adcb22354f085 100644
--- a/paddle/phi/core/tensor_utils.cc
+++ b/paddle/phi/core/tensor_utils.cc
@@ -994,7 +994,7 @@ template phi::dtype::complex<double> GetValue(const DenseTensor* x);
 template <typename T>
 std::vector<T> GetVectorFromTensor(const DenseTensor* x) {
   std::vector<T> vec_new_data;
-  if (phi::TransToProtoVarType(x->dtype()) == ProtoDataType::INT32) {
+  if (x->dtype() == DataType::INT32) {
     auto* data = x->data<int>();
     DenseTensor cpu_attr_tensor;
     if (x->place().GetType() != phi::AllocationType::CPU) {
@@ -1004,7 +1004,7 @@ std::vector<T> GetVectorFromTensor(const DenseTensor* x) {
       data = cpu_attr_tensor.data<int>();
     }
     vec_new_data = std::vector<T>(data, data + x->numel());
-  } else if (phi::TransToProtoVarType(x->dtype()) == ProtoDataType::INT64) {
+  } else if (x->dtype() == DataType::INT64) {
     auto* data = x->data<int64_t>();
     DenseTensor cpu_attr_tensor;
     if (x->place().GetType() != phi::AllocationType::CPU) {
@@ -1018,7 +1018,7 @@
   } else {
     PADDLE_THROW(common::errors::InvalidArgument(
         "The dtype of Tensor must be int32 or int64, but received: %s",
-        phi::TransToProtoVarType(x->dtype())));
+        x->dtype()));
   }
   return vec_new_data;
 }
@@ -1046,20 +1046,18 @@ std::vector<T> _GetVectorFromTensor(const DenseTensor* x) {
 
 template <>
 std::vector<float> GetVectorFromTensor(const DenseTensor* x) {
-  if (phi::TransToProtoVarType(x->dtype()) != ProtoDataType::FP32) {
+  if (x->dtype() != DataType::FLOAT32) {
     PADDLE_THROW(common::errors::InvalidArgument(
-        "The dtype of Tensor must be float32, but received: %s",
-        phi::TransToProtoVarType(x->dtype())));
+        "The dtype of Tensor must be float32, but received: %s", x->dtype()));
   }
   return _GetVectorFromTensor<float>(x);
 }
 
 template <>
 std::vector<double> GetVectorFromTensor(const DenseTensor* x) {
-  if (phi::TransToProtoVarType(x->dtype()) != ProtoDataType::FP64) {
+  if (x->dtype() != DataType::FLOAT64) {
     PADDLE_THROW(common::errors::InvalidArgument(
-        "The dtype of Tensor must be float64, but received: %s",
-        phi::TransToProtoVarType(x->dtype())));
+        "The dtype of Tensor must be float64, but received: %s", x->dtype()));
   }
   return _GetVectorFromTensor<double>(x);
 }
diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc
index a5b400b4516d33..9f43f551e127b1 100644
--- a/paddle/phi/infermeta/multiary.cc
+++ b/paddle/phi/infermeta/multiary.cc
@@ -2849,11 +2849,12 @@ void FusionGroupInferMeta(const std::vector<const MetaTensor*>& ins,
   }
 
   for (size_t j = 0; j < num_outs; ++j) {
-    if (outs_dtype[j] == phi::TransToProtoVarType(DataType::FLOAT16)) {
+    DataType out_dtype = TransToPhiDataType(outs_dtype[j]);
+    if (out_dtype == DataType::FLOAT16) {
       outs[j]->set_dtype(DataType::FLOAT16);
-    } else if (outs_dtype[j] == phi::TransToProtoVarType(DataType::FLOAT32)) {
+    } else if (out_dtype == DataType::FLOAT32) {
       outs[j]->set_dtype(DataType::FLOAT32);
-    } else if (outs_dtype[j] == phi::TransToProtoVarType(DataType::FLOAT64)) {
+    } else if (out_dtype == DataType::FLOAT64) {
       outs[j]->set_dtype(DataType::FLOAT64);
     }
   }
diff --git a/paddle/phi/kernels/cpu/pyramid_hash_kernel.cc b/paddle/phi/kernels/cpu/pyramid_hash_kernel.cc
index d3ce4a285b0c05..a3f435b8869f4f 100644
--- a/paddle/phi/kernels/cpu/pyramid_hash_kernel.cc
+++ b/paddle/phi/kernels/cpu/pyramid_hash_kernel.cc
@@ -228,8 +228,8 @@ void CPUPyramidHashOPKernel(const Context& dev_ctx,
   if (iter != iter_end) {
     exit(1);
   }
-  auto weight_type = TransToProtoVarType(_blobs_0->dtype());
-  if (_is_training == 0 && weight_type != ProtoDataType::INT8) {
+  auto weight_type = _blobs_0->dtype();
+  if (_is_training == 0 && weight_type != DataType::INT8) {
     funcs::axpy_noadd<T>(
         top_data, top_data, top->dims()[0] * top->dims()[1], _drop_out_percent);
   }
diff --git a/paddle/phi/kernels/cpu/tdm_sampler_kernel.cc b/paddle/phi/kernels/cpu/tdm_sampler_kernel.cc
index aeea907ba48fb7..ea0b0095ef8379 100644
--- a/paddle/phi/kernels/cpu/tdm_sampler_kernel.cc
+++ b/paddle/phi/kernels/cpu/tdm_sampler_kernel.cc
@@ -262,9 +262,9 @@ void TDMSamplerKernel(const Context &dev_ctx,
                       DenseTensor *out,
                       DenseTensor *labels,
                       DenseTensor *mask) {
-  const auto &input_type = TransToProtoVarType(x.dtype());
+  const auto &input_type = x.dtype();
   bool input_type_match =
-      input_type == ProtoDataType::INT32 || input_type == ProtoDataType::INT64;
+      input_type == DataType::INT32 || input_type == DataType::INT64;
   PADDLE_ENFORCE_EQ(input_type_match,
                     true,
                     common::errors::InvalidArgument(
@@ -274,9 +274,9 @@ void TDMSamplerKernel(const Context &dev_ctx,
                         DataTypeToString(DataType::INT32),
                         DataTypeToString(DataType::INT64)));
 
-  const auto &travel_type = TransToProtoVarType(travel.dtype());
-  bool travel_type_match = travel_type == ProtoDataType::INT32 ||
-                           travel_type == ProtoDataType::INT64;
+  const auto &travel_type = travel.dtype();
+  bool travel_type_match =
+      travel_type == DataType::INT32 || travel_type == DataType::INT64;
   PADDLE_ENFORCE_EQ(travel_type_match,
                     true,
                     common::errors::InvalidArgument(
@@ -286,9 +286,9 @@ void TDMSamplerKernel(const Context &dev_ctx,
                         DataTypeToString(DataType::INT32),
                         DataTypeToString(DataType::INT64)));
 
-  const auto &layer_type = TransToProtoVarType(layer.dtype());
+  const auto &layer_type = layer.dtype();
   bool layer_type_match =
-      layer_type == ProtoDataType::INT32 || layer_type == ProtoDataType::INT64;
+      layer_type == DataType::INT32 || layer_type == DataType::INT64;
   PADDLE_ENFORCE_EQ(layer_type_match,
                     true,
                     common::errors::InvalidArgument(
@@ -305,10 +305,9 @@ void TDMSamplerKernel(const Context &dev_ctx,
                         DataTypeToString(travel.dtype()),
                         DataTypeToString(layer.dtype())));
 
-  auto output_type = static_cast<ProtoDataType>(dtype);
+  auto output_type = TransToPhiDataType(dtype);
 
-  if (travel_type == ProtoDataType::INT32 &&
-      output_type == ProtoDataType::INT32) {
+  if (travel_type == DataType::INT32 && output_type == DataType::INT32) {
     TDMSamplerInner<Context, T, int, int>(dev_ctx,
                                           x,
                                           travel,
@@ -320,8 +319,7 @@
                                           out,
                                           labels,
                                           mask);
-  } else if (travel_type == ProtoDataType::INT64 &&
-             output_type == ProtoDataType::INT32) {
+  } else if (travel_type == DataType::INT64 && output_type == DataType::INT32) {
     TDMSamplerInner<Context, T, int64_t, int>(dev_ctx,
                                               x,
                                               travel,
@@ -333,8 +331,7 @@
                                               out,
                                               labels,
                                               mask);
-  } else if (travel_type == ProtoDataType::INT32 &&
-             output_type == ProtoDataType::INT64) {
+  } else if (travel_type == DataType::INT32 && output_type == DataType::INT64) {
     TDMSamplerInner<Context, T, int, int64_t>(dev_ctx,
                                               x,
                                               travel,
@@ -346,8 +343,7 @@
                                               out,
                                               labels,
                                               mask);
-  } else if (travel_type == ProtoDataType::INT64 &&
-             output_type == ProtoDataType::INT64) {
+  } else if (travel_type == DataType::INT64 && output_type == DataType::INT64) {
     TDMSamplerInner<Context, T, int64_t, int64_t>(dev_ctx,
                                                   x,
                                                   travel,
diff --git a/paddle/phi/kernels/fusion/gpu/fusion_group_kernel.cu b/paddle/phi/kernels/fusion/gpu/fusion_group_kernel.cu
index 5d498961eeec19..3fcf21da26e780 100644
--- a/paddle/phi/kernels/fusion/gpu/fusion_group_kernel.cu
+++ b/paddle/phi/kernels/fusion/gpu/fusion_group_kernel.cu
@@ -28,15 +28,14 @@ static void MutableMultiTypeData(std::vector<phi::DenseTensor*>* var,
                                  const std::vector<int>& data_type,
                                  const Context& dev_ctx) {
   for (size_t i = 0; i < var->size(); i++) {
-    if (data_type[i] == phi::TransToProtoVarType(phi::DataType::FLOAT32)) {
+    DataType dtype = TransToPhiDataType(data_type[i]);
+    if (dtype == DataType::FLOAT32) {
       dev_ctx.template Alloc<float>((*var)[i],
                                     (*var)[i]->numel() * sizeof(float));
-    } else if (data_type[i] ==
-               phi::TransToProtoVarType(phi::DataType::FLOAT16)) {
-      dev_ctx.template Alloc<phi::float16>(
-          (*var)[i], (*var)[i]->numel() * sizeof(phi::float16));
-    } else if (data_type[i] ==
-               phi::TransToProtoVarType(phi::DataType::FLOAT64)) {
+    } else if (dtype == DataType::FLOAT16) {
+      dev_ctx.template Alloc<float16>((*var)[i],
+                                      (*var)[i]->numel() * sizeof(float16));
+    } else if (dtype == DataType::FLOAT64) {
       dev_ctx.template Alloc<double>((*var)[i],
                                      (*var)[i]->numel() * sizeof(double));
     }
@@ -66,25 +65,23 @@ void FusionGroupKernel(const Context& dev_ctx,
   args.push_back(&n);
   std::vector<void*> ptrs(num_ins + num_outs);
   for (size_t i = 0; i < num_ins; ++i) {
-    if (inputs_dtype[i] == phi::TransToProtoVarType(phi::DataType::FLOAT16)) {
-      ptrs[i] = ins[i]->data<phi::float16>();
-    } else if (inputs_dtype[i] ==
-               phi::TransToProtoVarType(phi::DataType::FLOAT32)) {
+    DataType input_dtype = TransToPhiDataType(inputs_dtype[i]);
+    if (input_dtype == DataType::FLOAT16) {
+      ptrs[i] = ins[i]->data<float16>();
+    } else if (input_dtype == DataType::FLOAT32) {
       ptrs[i] = ins[i]->data<float>();
-    } else if (inputs_dtype[i] ==
-               phi::TransToProtoVarType(phi::DataType::FLOAT64)) {
+    } else if (input_dtype == DataType::FLOAT64) {
       ptrs[i] = ins[i]->data<double>();
     }
     args.push_back(&ptrs[i]);
   }
   for (size_t j = 0; j < num_outs; ++j) {
-    if (outs_dtype[j] == phi::TransToProtoVarType(phi::DataType::FLOAT16)) {
-      ptrs[num_ins + j] = outs[j]->data<phi::float16>();
-    } else if (outs_dtype[j] ==
-               phi::TransToProtoVarType(phi::DataType::FLOAT32)) {
+    DataType out_dtype = TransToPhiDataType(outs_dtype[j]);
+    if (out_dtype == DataType::FLOAT16) {
+      ptrs[num_ins + j] = outs[j]->data<float16>();
+    } else if (out_dtype == DataType::FLOAT32) {
       ptrs[num_ins + j] = outs[j]->data<float>();
-    } else if (outs_dtype[j] ==
-               phi::TransToProtoVarType(phi::DataType::FLOAT64)) {
+    } else if (out_dtype == DataType::FLOAT64) {
       ptrs[num_ins + j] = outs[j]->data<double>();
     }
     args.push_back(&ptrs[num_ins + j]);
diff --git a/paddle/phi/kernels/fusion/onednn/fusion_gru_kernel.cc b/paddle/phi/kernels/fusion/onednn/fusion_gru_kernel.cc
index 932d8516d5538f..9f9f6b126582dd 100644
--- a/paddle/phi/kernels/fusion/onednn/fusion_gru_kernel.cc
+++ b/paddle/phi/kernels/fusion/onednn/fusion_gru_kernel.cc
@@ -489,14 +489,13 @@ void RunKernel(const phi::OneDNNContext& dev_ctx,
   std::shared_ptr<dnnl::memory> h0_memory_p, weight_h_memory_p,
       weight_x_memory_p;
-  if (phi::TransToProtoVarType(weight_h.dtype()) == phi::ProtoDataType::FP32) {
+  if (weight_h.dtype() == DataType::FLOAT32) {
     h0_memory_p = handler.template AcquireH0Memory<float>(h0.get_ptr());
     weight_x_memory_p =
         handler.template AcquireWeightXMemory<float>(&weight_x, origin_mode);
     weight_h_memory_p =
         handler.template AcquireWeightHMemory<float>(&weight_h, origin_mode);
-  } else if (phi::TransToProtoVarType(weight_h.dtype()) ==
-             phi::ProtoDataType::BF16) {
+  } else if (weight_h.dtype() == DataType::BFLOAT16) {
     h0_memory_p = handler.template AcquireH0Memory<bfloat16>(h0.get_ptr());
     weight_x_memory_p = handler.template AcquireWeightXMemory<bfloat16>(
         &weight_x, origin_mode);
     weight_h_memory_p = handler.template AcquireWeightHMemory<bfloat16>(
         &weight_h, origin_mode);
diff --git a/paddle/phi/kernels/onednn/matmul_kernel.cc b/paddle/phi/kernels/onednn/matmul_kernel.cc
index 0b80eb9946e2b2..c98cfafe8dcfb3 100644
--- a/paddle/phi/kernels/onednn/matmul_kernel.cc
+++ b/paddle/phi/kernels/onednn/matmul_kernel.cc
@@ -449,9 +449,9 @@ std::shared_ptr<MatMulFactory<XT, YT, OT>> GetPrimitiveFactory(
     const DenseTensor *input_y,
     const engine &onednn_engine) {
   std::string key = funcs::CreateKey(dev_ctx,
-                                     phi::TransToProtoVarType(input_x->dtype()),
+                                     TransToProtoVarType(input_x->dtype()),
                                      vectorize(input_x->dims()),
-                                     phi::TransToProtoVarType(input_y->dtype()),
+                                     TransToProtoVarType(input_y->dtype()),
                                      vectorize(input_y->dims()),
                                      dev_ctx.GetOutputsName("Out")[0]);
   key = funcs::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
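
Note on the pattern applied above: each hunk replaces a round-trip through the
legacy proto enum (TransToProtoVarType(t.dtype()) == ProtoDataType::X) with a
direct comparison of phi::DataType values (t.dtype() == DataType::X), and
integer dtype attributes are converted once with TransToPhiDataType instead of
being cast to ProtoDataType. The one place TransToProtoVarType survives is the
oneDNN matmul cache key, where only the phi:: qualifier is dropped. Below is a
minimal, self-contained sketch of the before/after; DenseTensor, DataType,
ProtoDataType and TransToProtoVarType here are simplified stand-ins, not
Paddle's real headers.

// Illustrative stand-ins only, not Paddle's real types.
#include <cstdio>

enum class ProtoDataType { INT32, INT64 };  // legacy framework proto enum
enum class DataType { INT32, INT64 };       // phi dtype enum

struct DenseTensor {
  DataType dtype_;
  DataType dtype() const { return dtype_; }
};

// Old style: convert the phi dtype to the proto enum, then compare.
inline ProtoDataType TransToProtoVarType(DataType d) {
  return d == DataType::INT32 ? ProtoDataType::INT32 : ProtoDataType::INT64;
}

inline bool IsInt32Old(const DenseTensor& t) {
  return TransToProtoVarType(t.dtype()) == ProtoDataType::INT32;
}

// New style: compare phi::DataType values directly; no conversion step.
inline bool IsInt32New(const DenseTensor& t) {
  return t.dtype() == DataType::INT32;
}

int main() {
  DenseTensor t{DataType::INT32};
  std::printf("%d %d\n", IsInt32Old(t), IsInt32New(t));  // prints: 1 1
  return 0;
}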